use std::future::Future;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use coralstack_cmd_ipc::{
Command, CommandChannel, CommandError, CommandRegistry, Config, DynEvent, InMemoryChannel,
};
use futures::channel::oneshot;
use futures::executor::{block_on, ThreadPool};
use futures::future::join_all;
use futures::lock::Mutex as AsyncMutex;
use futures::task::SpawnExt;
use futures::FutureExt;
use serde::{Deserialize, Serialize};
use serde_json::json;
/// Runtime-agnostic sleep for tests: resolves roughly `ms` milliseconds after
/// this function is *called*.
///
/// The timer thread is spawned immediately, not on first poll, so the delay
/// is measured from the call site. One OS thread is spawned per call.
fn sleep_ms(ms: u64) -> impl Future<Output = ()> {
    let delay = Duration::from_millis(ms);
    let (wake_tx, wake_rx) = oneshot::channel();
    std::thread::spawn(move || {
        std::thread::sleep(delay);
        // The receiver may already have been dropped; that is not an error.
        let _ = wake_tx.send(());
    });
    // Completion and cancellation of the sender both count as "done".
    wake_rx.map(|_| ())
}
/// Builds the baseline registry [`Config`] shared by these tests.
///
/// * `id` - stored as the registry id.
/// * `router` - optional router channel id, copied into `router_channel`.
///
/// TTLs are fixed at 10 s for requests and 2 s for events, and the per-channel
/// in-flight limit is 256.
fn config(id: &str, router: Option<&str>) -> Config {
    let id = Some(id.into());
    let router_channel = router.map(|name| name.to_string());
    Config {
        id,
        router_channel,
        request_ttl: Duration::from_secs(10),
        event_ttl: Duration::from_secs(2),
        max_in_flight_per_channel: 256,
    }
}
fn config_with_cap(id: &str, router: Option<&str>, cap: usize) -> Config {
let mut c = config(id, router);
c.max_in_flight_per_channel = cap;
c
}
/// Connects two registries through an in-memory channel pair and starts both
/// channel drivers on `pool`.
///
/// Returns `(registry_a, registry_b, channel_held_by_a, channel_held_by_b)`.
///
/// # Panics
///
/// Panics if either config has no `id`, if registering a channel fails, or if
/// a driver cannot be spawned on `pool`.
fn wire(
    cfg_a: Config,
    cfg_b: Config,
    pool: &ThreadPool,
) -> (
    CommandRegistry,
    CommandRegistry,
    Arc<dyn CommandChannel>,
    Arc<dyn CommandChannel>,
) {
    let id_a = cfg_a.id.clone().unwrap();
    let id_b = cfg_b.id.clone().unwrap();

    // NOTE(review): the end given to registry A is created with B's id (and
    // vice versa) — presumably each end is labelled with its remote peer;
    // confirm against `InMemoryChannel::pair`.
    let (end_a, end_b) = InMemoryChannel::pair(id_b, id_a);
    let ch_for_a: Arc<dyn CommandChannel> = end_a;
    let ch_for_b: Arc<dyn CommandChannel> = end_b;

    let reg_a = CommandRegistry::new(cfg_a);
    let reg_b = CommandRegistry::new(cfg_b);

    // Register A first, then B; the returned drivers are spawned in the same order.
    let (driver_a, driver_b) = block_on(async {
        let driver_a = reg_a.register_channel(ch_for_a.clone()).await.unwrap();
        let driver_b = reg_b.register_channel(ch_for_b.clone()).await.unwrap();
        (driver_a, driver_b)
    });
    pool.spawn(driver_a).unwrap();
    pool.spawn(driver_b).unwrap();

    (reg_a, reg_b, ch_for_a, ch_for_b)
}
/// Command whose handler sleeps for a caller-chosen time before replying.
struct SlowCmd;

/// Request payload for [`SlowCmd`].
#[derive(Serialize, Deserialize)]
struct SleepReq {
    /// How long the handler sleeps, in milliseconds.
    ms: u64,
    /// Opaque label echoed back in the response.
    tag: String,
}

/// Response payload for [`SlowCmd`].
#[derive(Serialize, Deserialize, Debug, PartialEq)]
struct TaggedResp {
    /// The `tag` copied from the request.
    tag: String,
    /// Milliseconds elapsed between handler entry and the end of its sleep.
    finished_at_ms: u64,
}

impl Command for SlowCmd {
    const ID: &'static str = "slow";
    type Request = SleepReq;
    type Response = TaggedResp;

    /// Sleeps for `req.ms` milliseconds, then echoes `req.tag` together with
    /// the measured handler duration.
    async fn handle(&self, req: SleepReq) -> Result<TaggedResp, CommandError> {
        let SleepReq { ms, tag } = req;
        let entered = Instant::now();
        sleep_ms(ms).await;
        let finished_at_ms = entered.elapsed().as_millis() as u64;
        Ok(TaggedResp {
            tag,
            finished_at_ms,
        })
    }
}
/// Command that replies immediately with the request prefixed by `got:`.
struct FastCmd;

impl Command for FastCmd {
    const ID: &'static str = "fast";
    type Request = String;
    type Response = String;

    /// Returns `"got:"` followed by `req`; never fails.
    async fn handle(&self, req: String) -> Result<String, CommandError> {
        let mut reply = String::from("got:");
        reply.push_str(&req);
        Ok(reply)
    }
}
/// Command whose handler parks until a shared release signal fires, while
/// tracking how many invocations are inside the handler at once.
struct BarrierCmd {
    /// Number of invocations currently inside `handle`.
    counter: Arc<AtomicUsize>,
    /// Largest value `counter` has been observed to reach.
    high_water: Arc<AtomicUsize>,
    /// Release signal awaited by every invocation; `None` means "do not wait".
    release: Arc<AsyncMutex<Option<futures::future::Shared<oneshot::Receiver<()>>>>>,
}

impl Command for BarrierCmd {
    const ID: &'static str = "barrier";
    type Request = ();
    type Response = ();

    /// Records entry, raises the high-water mark if needed, waits for the
    /// release signal (when one is installed), then records exit.
    async fn handle(&self, _req: ()) -> Result<(), CommandError> {
        let in_flight = self.counter.fetch_add(1, Ordering::SeqCst) + 1;

        // Atomic "store if greater": replaces a hand-written compare-exchange
        // retry loop with the equivalent standard-library primitive.
        self.high_water.fetch_max(in_flight, Ordering::SeqCst);

        // Clone the signal out so the mutex guard is dropped before we park;
        // otherwise concurrent invocations would queue on the lock instead.
        let gate = self.release.lock().await.as_ref().cloned();
        if let Some(gate) = gate {
            // A dropped sender also releases the barrier, hence the ignored result.
            let _ = gate.await;
        }

        self.counter.fetch_sub(1, Ordering::SeqCst);
        Ok(())
    }
}
/// A fast command issued while a slow one is being handled on the same
/// channel must complete without waiting for the slow handler.
///
/// The slow call is driven concurrently with the fast one via `join`. A
/// future does no work until it is polled, so merely constructing the slow
/// call and awaiting it after the assertions would not guarantee that the
/// slow request was in flight while the fast call was being timed.
#[test]
fn fast_command_does_not_wait_for_slow_one() {
    let pool = ThreadPool::new().unwrap();
    let (reg_a, reg_b, _ca, _cb) = wire(config("a", None), config("b", Some("a")), &pool);
    block_on(async {
        reg_a.register_command(SlowCmd).await.unwrap();
        reg_a.register_command(FastCmd).await.unwrap();

        let slow_call = reg_b.execute::<SlowCmd>(SleepReq {
            ms: 300,
            tag: "slow".into(),
        });

        let fast_call = async {
            // Give the slow request a head start before issuing the fast one.
            sleep_ms(20).await;
            let started = Instant::now();
            let fast = reg_b.execute::<FastCmd>("hi".to_string()).await.unwrap();
            let fast_elapsed = started.elapsed();
            assert_eq!(fast, "got:hi");
            assert!(
                fast_elapsed < Duration::from_millis(250),
                "fast call took {fast_elapsed:?}, head-of-line blocked by slow handler"
            );
        };

        // `join` polls `slow_call` first, so the slow request is started
        // before the 20 ms head-start timer in `fast_call` elapses.
        let (slow, ()) = futures::future::join(slow_call, fast_call).await;
        let slow = slow.unwrap();
        assert_eq!(slow.tag, "slow");
    });
}
/// With `max_in_flight_per_channel` set to 4 on the handling side, ten
/// concurrent `barrier` calls must never have more than four handlers running
/// at once, and all ten must complete once the barrier is released.
#[test]
fn backpressure_cap_limits_concurrent_handlers() {
    let pool = ThreadPool::new().unwrap();
    let cfg_a = config_with_cap("a", None, 4);
    let (reg_a, reg_b, _ca, _cb) = wire(cfg_a, config("b", Some("a")), &pool);

    let counter = Arc::new(AtomicUsize::new(0));
    let high_water = Arc::new(AtomicUsize::new(0));
    let (release_tx, release_rx) = oneshot::channel::<()>();
    let release = Arc::new(AsyncMutex::new(Some(release_rx.shared())));

    block_on(async {
        let barrier = BarrierCmd {
            counter: Arc::clone(&counter),
            high_water: Arc::clone(&high_water),
            release,
        };
        reg_a.register_command(barrier).await.unwrap();

        // Each call runs on the pool so that all ten are driven concurrently.
        let handles: Vec<_> = (0..10)
            .map(|_| {
                let reg = reg_b.clone();
                pool.spawn_with_handle(async move {
                    reg.execute_dyn("barrier", json!(null)).await
                })
                .unwrap()
            })
            .collect();

        // Let the handlers pile up against the barrier before sampling.
        sleep_ms(300).await;
        let hw = high_water.load(Ordering::SeqCst);
        assert!(hw > 0, "no handler ever started");
        assert!(hw <= 4, "high water {hw} exceeds cap of 4");

        let _ = release_tx.send(());
        for outcome in join_all(handles).await {
            outcome.unwrap();
        }
        assert_eq!(counter.load(Ordering::SeqCst), 0);
    });
}
/// Ten overlapping `slow` calls with different sleep times must each receive
/// the response carrying their own tag, regardless of completion order.
#[test]
fn responses_match_originating_requests() {
    let pool = ThreadPool::new().unwrap();
    let (reg_a, reg_b, _ca, _cb) = wire(config("a", None), config("b", Some("a")), &pool);
    block_on(async {
        reg_a.register_command(SlowCmd).await.unwrap();

        // (sleep in ms, tag) — deliberately out of order so replies interleave.
        let pattern: [(u64, &'static str); 10] = [
            (200, "A"),
            (10, "B"),
            (150, "C"),
            (5, "D"),
            (80, "E"),
            (40, "F"),
            (20, "G"),
            (120, "H"),
            (1, "I"),
            (60, "J"),
        ];

        let mut calls = Vec::with_capacity(pattern.len());
        for &(ms, tag) in pattern.iter() {
            calls.push(reg_b.execute::<SlowCmd>(SleepReq {
                ms,
                tag: tag.to_string(),
            }));
        }

        // `join_all` yields results in submission order, so they line up with `pattern`.
        let replies = join_all(calls).await;
        for (&(_, expected_tag), reply) in pattern.iter().zip(replies.iter()) {
            let reply = reply.as_ref().unwrap();
            assert_eq!(&reply.tag, expected_tag, "thid correlation failed");
        }
    });
}
/// Events emitted while a slow command is being handled must still be
/// delivered promptly: 100 `tick` events have to arrive within roughly
/// 200 ms even though a 400 ms handler is running.
///
/// The slow call is driven concurrently with the event burst via `join`. A
/// future does no work until it is polled, so merely constructing the slow
/// call would not guarantee a handler is in flight while events are emitted.
#[test]
fn events_flow_while_slow_handler_in_flight() {
    let pool = ThreadPool::new().unwrap();
    let (reg_a, reg_b, _ca, _cb) = wire(config("a", None), config("b", Some("a")), &pool);
    let received = Arc::new(AtomicUsize::new(0));
    {
        let received = received.clone();
        // The value returned by `on_dyn` is deliberately leaked.
        // NOTE(review): presumably dropping it would remove the listener — confirm.
        std::mem::forget(reg_a.on_dyn("tick", move |_| {
            received.fetch_add(1, Ordering::SeqCst);
        }));
    }
    block_on(async {
        reg_a.register_command(SlowCmd).await.unwrap();

        let slow_call = reg_b.execute::<SlowCmd>(SleepReq {
            ms: 400,
            tag: "slow".into(),
        });

        let event_burst = async {
            // Give the slow request time to reach its handler first.
            sleep_ms(30).await;
            for _ in 0..100 {
                reg_b.emit(DynEvent::new("tick", json!(null))).unwrap();
            }
            // Poll for up to ~200 ms, well short of the 400 ms handler.
            for _ in 0..20 {
                if received.load(Ordering::SeqCst) == 100 {
                    break;
                }
                sleep_ms(10).await;
            }
            let count = received.load(Ordering::SeqCst);
            assert_eq!(
                count, 100,
                "events stalled behind handler: only {count}/100"
            );
        };

        // `join` polls `slow_call` first, so the slow request is started
        // before any event is emitted.
        let (slow, ()) = futures::future::join(slow_call, event_burst).await;
        let _ = slow.unwrap();
    });
}
/// Closing the caller's channel while a slow handler is running must fail the
/// pending call with `CommandError::ChannelDisconnected`, and the handler
/// finishing afterwards must not cause a panic.
///
/// The slow call is driven concurrently with the close via `join`. A future
/// does no work until it is polled, so merely constructing the call and
/// awaiting it after `close()` would not guarantee the request was in flight
/// when the channel went away.
#[test]
fn channel_close_during_slow_handler_is_clean() {
    let pool = ThreadPool::new().unwrap();
    let (reg_a, reg_b, _ca, cb) = wire(config("a", None), config("b", Some("a")), &pool);
    block_on(async {
        reg_a.register_command(SlowCmd).await.unwrap();

        let slow_call = reg_b.execute::<SlowCmd>(SleepReq {
            ms: 300,
            tag: "slow".into(),
        });

        let close_later = async {
            // Let the request reach the handler before pulling the channel.
            sleep_ms(50).await;
            cb.close().await;
        };

        // `join` polls `slow_call` first, so the request is started before
        // the channel is closed.
        let (slow, ()) = futures::future::join(slow_call, close_later).await;
        let err = slow.unwrap_err();
        assert!(
            matches!(err, CommandError::ChannelDisconnected),
            "expected ChannelDisconnected, got {err:?}"
        );

        // Outlive the 300 ms handler so its late completion runs inside the test.
        sleep_ms(400).await;
    });
}
/// Registry A is connected to B and C, and both use A as their router. A slow
/// call from B must not delay a quick call from C to the same command.
///
/// The slow call is driven concurrently with the quick one via `join`. A
/// future does no work until it is polled, so merely constructing the slow
/// call would not guarantee it was in flight while the quick call was timed.
#[test]
fn forward_execute_does_not_serialize_remote_calls() {
    let pool = ThreadPool::new().unwrap();
    let (ch_b_for_a, ch_a_for_b) = InMemoryChannel::pair("b", "a");
    let (ch_c_for_a, ch_a_for_c) = InMemoryChannel::pair("c", "a");
    let ch_b_for_a: Arc<dyn CommandChannel> = ch_b_for_a;
    let ch_a_for_b: Arc<dyn CommandChannel> = ch_a_for_b;
    let ch_c_for_a: Arc<dyn CommandChannel> = ch_c_for_a;
    let ch_a_for_c: Arc<dyn CommandChannel> = ch_a_for_c;
    let reg_a = CommandRegistry::new(config("a", None));
    let reg_b = CommandRegistry::new(config("b", Some("a")));
    let reg_c = CommandRegistry::new(config("c", Some("a")));
    block_on(async {
        // A holds one channel per peer; B and C each hold their channel to A.
        let drv = reg_a.register_channel(ch_b_for_a.clone()).await.unwrap();
        pool.spawn(drv).unwrap();
        let drv = reg_a.register_channel(ch_c_for_a.clone()).await.unwrap();
        pool.spawn(drv).unwrap();
        let drv = reg_b.register_channel(ch_a_for_b.clone()).await.unwrap();
        pool.spawn(drv).unwrap();
        let drv = reg_c.register_channel(ch_a_for_c.clone()).await.unwrap();
        pool.spawn(drv).unwrap();

        reg_a.register_command(SlowCmd).await.unwrap();
        // NOTE(review): presumably lets the registration propagate to B and C — confirm.
        sleep_ms(100).await;

        let slow_call = reg_b.execute::<SlowCmd>(SleepReq {
            ms: 300,
            tag: "slow".into(),
        });

        let fast_call = async {
            // Give the slow request a head start before issuing the quick one.
            sleep_ms(20).await;
            let started = Instant::now();
            let fast = reg_c
                .execute::<SlowCmd>(SleepReq {
                    ms: 10,
                    tag: "fast".into(),
                })
                .await
                .unwrap();
            let fast_elapsed = started.elapsed();
            assert_eq!(fast.tag, "fast");
            assert!(
                fast_elapsed < Duration::from_millis(250),
                "fast forwarded call took {fast_elapsed:?}, blocked by slow forwarded call"
            );
        };

        // `join` polls `slow_call` first, so the slow request is started
        // before the quick call is issued.
        let (slow, ()) = futures::future::join(slow_call, fast_call).await;
        let _ = slow.unwrap();
    });
}