use std::sync::Arc;
use std::time::Duration;
use coralstack_cmd_ipc::prelude::*;
use coralstack_cmd_ipc::Config;
use futures::executor::{block_on, ThreadPool};
use futures::task::SpawnExt;
// Request payload for the `math.add` command (`MathService::add`).
#[payload]
struct AddReq {
    // Left operand of the addition.
    a: i64,
    // Right operand of the addition.
    b: i64,
}
// Request payload for the `math.sub` command (`MathService::sub`), which
// computes `a - b`.
#[payload]
struct SubReq {
    // Value subtracted from (minuend).
    a: i64,
    // Value being subtracted (subtrahend).
    b: i64,
}
// Stateless unit struct whose `#[command]` methods (below) are registered on a
// registry via `MathService.register(..)`.
struct MathService;
// Command service under test. `#[command_service]` exposes each `#[command]`
// method; the tests reach them through the generated `math_service` module
// (`math_service::Add`, `math_service::Sub`, `math_service::Ping`).
#[command_service]
impl MathService {
    // `math.add`: returns `req.a + req.b`. Plain `+`, so i64 overflow panics in
    // debug builds and wraps in release builds.
    #[command("math.add", description = "Add two integers")]
    async fn add(&self, req: AddReq) -> Result<i64, CommandError> {
        Ok(req.a + req.b)
    }
    // `math.sub`: returns `req.a - req.b` (same overflow behavior as `add`).
    #[command("math.sub")]
    async fn sub(&self, req: SubReq) -> Result<i64, CommandError> {
        Ok(req.a - req.b)
    }
    // `_internal.ping`: always returns "pong". `private_command_stays_local`
    // checks that the peer registry gets `CommandError::NotFound` for it while
    // the owning registry can still execute it — presumably because of the
    // `_internal.` prefix; confirm against the registry's visibility rules.
    #[command("_internal.ping")]
    async fn ping(&self, _: ()) -> Result<String, CommandError> {
        Ok("pong".into())
    }
}
// Free-function command `greet`: returns "hello, <name>". `#[command]`
// generates the `register_greet` helper and the `GreetCommand` type that the
// tests below use.
#[command("greet")]
async fn greet(name: String) -> Result<String, CommandError> {
    Ok(format!("hello, {name}"))
}
/// Builds a test `Config` for a registry named `id`.
///
/// `router`, when `Some`, becomes the registry's `router_channel`. Both TTLs
/// are fixed at five seconds and `max_in_flight_per_channel` is 256.
fn config(id: &str, router: Option<&str>) -> Config {
    let ttl = Duration::from_secs(5);
    let router_channel = router.map(String::from);
    Config {
        id: Some(id.into()),
        router_channel,
        request_ttl: ttl,
        event_ttl: ttl,
        max_in_flight_per_channel: 256,
    }
}
/// Creates two registries joined by an in-memory channel pair and starts both
/// channel drivers on a new `ThreadPool`.
///
/// Registry A is configured from `a_id`/`a_router`, registry B from
/// `b_id`/`b_router`. The pool that runs the drivers is returned so callers
/// can keep it alive (the tests bind it to `_pool`, not `_`), presumably so
/// the drivers keep running for the duration of the test.
fn wire_pair(
    a_id: &str,
    b_id: &str,
    a_router: Option<&str>,
    b_router: Option<&str>,
) -> (CommandRegistry, CommandRegistry, ThreadPool) {
    // `pair` is called with B's id first; its first end is given to A.
    let (end_a, end_b) = InMemoryChannel::pair(b_id, a_id);
    let end_a: Arc<dyn CommandChannel> = end_a;
    let end_b: Arc<dyn CommandChannel> = end_b;

    let registry_a = CommandRegistry::new(config(a_id, a_router));
    let registry_b = CommandRegistry::new(config(b_id, b_router));
    let executor = ThreadPool::new().unwrap();

    // Register A's end, then B's, collecting the driver futures.
    let (driver_a, driver_b) = block_on(async {
        let first = registry_a.register_channel(end_a).await.unwrap();
        let second = registry_b.register_channel(end_b).await.unwrap();
        (first, second)
    });
    executor.spawn(driver_a).unwrap();
    executor.spawn(driver_b).unwrap();

    (registry_a, registry_b, executor)
}
/// `MathService` commands registered on the worker are callable from the
/// root registry across the channel.
#[test]
fn impl_block_macro_registers_and_executes_across_channel() {
    // "worker" is configured with "root" as its router channel.
    let (reg_a, reg_b, _pool) = wire_pair("root", "worker", None, Some("root"));
    block_on(async {
        // Commands live on the worker (reg_b) and are executed from root (reg_a).
        MathService.register(&reg_b).await.unwrap();
        let sum: i64 = reg_a
            .execute::<math_service::Add>(AddReq { a: 2, b: 3 })
            .await
            .unwrap();
        assert_eq!(sum, 5);
        let diff: i64 = reg_a
            .execute::<math_service::Sub>(SubReq { a: 10, b: 4 })
            .await
            .unwrap();
        assert_eq!(diff, 6);
    });
}
/// Typed `execute` through the macro-generated `math_service::*` paths works
/// on a single, unconnected registry.
#[test]
fn strict_execute_via_nested_module_path() {
    let reg = CommandRegistry::new(config("solo", None));
    block_on(async {
        MathService.register(&reg).await.unwrap();
        let sum: i64 = reg
            .execute::<math_service::Add>(AddReq { a: 7, b: 8 })
            .await
            .unwrap();
        assert_eq!(sum, 15);
        let diff: i64 = reg
            .execute::<math_service::Sub>(SubReq { a: 20, b: 5 })
            .await
            .unwrap();
        assert_eq!(diff, 15);
    });
}
// Typed event declared with `#[event("worker.ready")]`. The round-trip test
// below emits it from the worker registry and subscribes to it on root.
#[event("worker.ready")]
struct WorkerReady {
    // Worker identifier (the test sends "w1").
    worker_id: String,
    // Count carried by the event (the test sends 7).
    command_count: u32,
}
/// An event emitted on the worker is delivered exactly once to a typed
/// subscriber on the root registry.
#[test]
fn typed_event_emit_and_on_round_trips_across_channel() {
    let (reg_a, reg_b, _pool) = wire_pair("root", "worker", None, Some("root"));
    let hits = Arc::new(std::sync::Mutex::new(Vec::<WorkerReady>::new()));
    let sink = Arc::clone(&hits);
    // Bound to `_unsub` (not `_`) so the returned handle lives until the end
    // of the test.
    let _unsub = reg_a.on::<WorkerReady>(move |event| {
        sink.lock().unwrap().push(event);
    });
    reg_b
        .emit(WorkerReady {
            worker_id: "w1".into(),
            command_count: 7,
        })
        .unwrap();

    // The channel drivers run on the pool's threads, so this thread can simply
    // sleep between checks. Bounded to ~1s, matching the previous 50 x 20ms budget.
    // The lock guard is a temporary in the `while` condition, so it is
    // released before each sleep.
    let deadline = std::time::Instant::now() + Duration::from_secs(1);
    while hits.lock().unwrap().is_empty() && std::time::Instant::now() < deadline {
        std::thread::sleep(Duration::from_millis(20));
    }

    let seen = hits.lock().unwrap();
    assert_eq!(seen.len(), 1, "expected exactly one WorkerReady delivery");
    assert_eq!(seen[0].worker_id, "w1");
    assert_eq!(seen[0].command_count, 7);
}
/// The `#[command]` free-function macro yields a `register_greet` factory
/// and a typed `GreetCommand` usable with `execute`.
#[test]
fn free_fn_macro_registers_via_factory() {
    let (reg_a, _reg_b, _pool) = wire_pair("root", "worker", None, Some("root"));
    block_on(async {
        register_greet(&reg_a).await.unwrap();
        let hello: String = reg_a
            .execute::<GreetCommand>("world".to_string())
            .await
            .unwrap();
        assert_eq!(hello, "hello, world");
    });
}
/// `_internal.ping` registered on the worker is invisible to root
/// (`NotFound`) but still executes on the worker itself.
#[test]
fn private_command_stays_local() {
    let (reg_a, reg_b, _pool) = wire_pair("root", "worker", None, Some("root"));
    block_on(async {
        MathService.register(&reg_b).await.unwrap();
        let err = reg_a
            .execute_dyn("_internal.ping", serde_json::Value::Null)
            .await
            .unwrap_err();
        assert!(
            matches!(err, CommandError::NotFound(_)),
            "expected NotFound for a private command, got {err:?}"
        );
        let got: String = reg_b.execute::<math_service::Ping>(()).await.unwrap();
        assert_eq!(got, "pong");
    });
}
/// Registering a macro-generated command publishes request and response
/// schemas in `list_commands`.
#[test]
fn free_fn_macro_exposes_schema_after_registration() {
    let reg = CommandRegistry::new(config("solo", None));
    block_on(async {
        register_greet(&reg).await.unwrap();
    });
    let def = reg
        .list_commands()
        .into_iter()
        .find(|d| d.id == "greet")
        .expect("greet should be registered");
    let schema = def.schema.expect("macro should populate schema");
    assert!(schema.request.is_some(), "greet should have a request schema");
    let resp = schema
        .response
        .expect("greet should have a response schema")
        .to_string();
    // `greet` returns `String`, so the serialized response schema should
    // mention "string".
    assert!(
        resp.contains("string"),
        "unexpected response schema: {resp}"
    );
}
// Hand-written `Command` (no macros) whose schema includes keys the registry
// is expected to strip; see `registry_normalizes_hand_written_schema`.
struct UnnormalizedCommand;
impl Command for UnnormalizedCommand {
    type Request = serde_json::Value;
    type Response = serde_json::Value;
    const ID: &'static str = "hand.rolled";

    // Ignores its input and always answers JSON `null`; the test only cares
    // about the schema.
    async fn handle(&self, _req: serde_json::Value) -> Result<serde_json::Value, CommandError> {
        let reply = serde_json::Value::Null;
        Ok(reply)
    }

    // Deliberately "unnormalized" schema: `$schema`, `title`, and `format`
    // keys are present and `additionalProperties` is absent.
    fn schema(&self) -> Option<coralstack_cmd_ipc::CommandSchema> {
        let request = serde_json::json!({
            "$schema": "http://json-schema.org/draft-07/schema#",
            "title": "HandRolled",
            "type": "object",
            "properties": {
                "n": { "title": "int64", "type": "integer", "format": "int64" }
            },
            "required": ["n"]
        });
        let response = serde_json::json!({
            "title": "HandRolledOut",
            "type": "integer",
            "format": "int64"
        });
        Some(coralstack_cmd_ipc::CommandSchema {
            request: Some(request),
            response: Some(response),
        })
    }
}
/// The registry normalizes a hand-written schema on registration. It strips
/// `title`, `$schema` and `format`, and closes the request object with
/// `additionalProperties: false`.
#[test]
fn registry_normalizes_hand_written_schema() {
    // `block_on` is already imported at the top of the file; no local `use`.
    let reg = CommandRegistry::new(config("solo", None));
    block_on(async {
        reg.register_command(UnnormalizedCommand).await.unwrap();
    });
    let defs = reg.list_commands();
    let def = defs
        .iter()
        .find(|d| d.id == "hand.rolled")
        .expect("hand.rolled should be registered");
    let schema = def
        .schema
        .as_ref()
        .expect("registry should keep the hand-written schema");
    let req = schema.request.as_ref().expect("request schema should be present");
    let resp = schema.response.as_ref().expect("response schema should be present");

    assert!(req.get("title").is_none(), "title leaked: {req}");
    assert!(req.get("$schema").is_none(), "$schema leaked: {req}");
    assert!(resp.get("title").is_none(), "title leaked: {resp}");
    assert_eq!(req["additionalProperties"], serde_json::Value::Bool(false));
    assert!(
        req["properties"]["n"].get("format").is_none(),
        "format leaked on property: {req}"
    );
    assert!(
        resp.get("format").is_none(),
        "format leaked on response: {resp}"
    );
}