use std::sync::Arc;
use std::time::Duration;
use async_trait::async_trait;
use net::adapter::{Adapter, ShardPollResult};
use net::bus::EventBus;
use net::config::{EventBusConfig, ScalingPolicy};
use net::error::AdapterError;
use net::event::{Batch, Event};
use parking_lot::Mutex;
use serde_json::json;
/// One `on_batch` call as recorded by `RecordingAdapter`.
///
/// Note: the derived `Debug` output (including field order) is embedded in an
/// assertion failure message in this file.
#[derive(Debug, Clone, PartialEq, Eq)]
struct BatchObservation {
    /// `shard_id` of the delivered batch.
    shard_id: u16,
    /// `sequence_start` of the delivered batch.
    sequence_start: u64,
    /// Number of events in the batch (`batch.len()`).
    len: usize,
    /// `process_nonce` stamped on the delivered batch.
    process_nonce: u64,
}
/// Shared handle to the batch observations recorded by a `RecordingAdapter`.
type BatchHandle = Arc<Mutex<Vec<BatchObservation>>>;
/// Shared handle to per-event msg-id tuples `(shard_id, sequence_start, index_within_batch)`.
type MsgIdHandle = Arc<Mutex<Vec<(u16, u64, usize)>>>;
/// Test adapter that records every delivered batch, and one msg-id tuple per event,
/// into shared buffers that the test can still read after the bus owns the adapter.
#[derive(Clone)]
struct RecordingAdapter {
    /// Batch-level observations, appended by `on_batch`.
    batches: BatchHandle,
    /// Per-event `(shard_id, sequence_start, i)` tuples, appended by `on_batch`.
    msg_ids: MsgIdHandle,
}
impl RecordingAdapter {
    /// Creates an adapter together with handles to its two recording buffers.
    ///
    /// Returns `(adapter, batches, msg_ids)`. The handles point at the same storage as
    /// the adapter itself, so they keep observing after the adapter is boxed into a bus.
    fn new() -> (Self, BatchHandle, MsgIdHandle) {
        let batches: BatchHandle = Arc::new(Mutex::new(Vec::new()));
        let msg_ids: MsgIdHandle = Arc::new(Mutex::new(Vec::new()));
        let adapter = Self {
            batches: Arc::clone(&batches),
            msg_ids: Arc::clone(&msg_ids),
        };
        (adapter, batches, msg_ids)
    }
}
/// Recording implementation: `init`, `flush` and `shutdown` are no-ops,
/// `poll_shard` always returns an empty result, and `on_batch` appends to the
/// shared buffers.
#[async_trait]
impl Adapter for RecordingAdapter {
    async fn init(&mut self) -> Result<(), AdapterError> {
        Ok(())
    }

    /// Records one `(shard_id, sequence_start, i)` tuple per event in the batch,
    /// then one `BatchObservation` for the batch as a whole.
    async fn on_batch(&self, batch: std::sync::Arc<Batch>) -> Result<(), AdapterError> {
        let (shard_id, sequence_start, len) = (batch.shard_id, batch.sequence_start, batch.len());

        // The guard is a temporary, so the msg-id lock is released at the end of
        // this statement, before the batches lock is taken.
        self.msg_ids
            .lock()
            .extend((0..len).map(|offset| (shard_id, sequence_start, offset)));

        let observation = BatchObservation {
            shard_id,
            sequence_start,
            len,
            process_nonce: batch.process_nonce,
        };
        self.batches.lock().push(observation);
        Ok(())
    }

    async fn flush(&self) -> Result<(), AdapterError> {
        Ok(())
    }

    async fn shutdown(&self) -> Result<(), AdapterError> {
        Ok(())
    }

    async fn poll_shard(
        &self,
        _shard_id: u16,
        _from_id: Option<&str>,
        _limit: usize,
    ) -> Result<ShardPollResult, AdapterError> {
        Ok(ShardPollResult::empty())
    }

    fn name(&self) -> &'static str {
        "recording"
    }
}
/// Builds the default test configuration.
///
/// The configuration has `num_shards` initial shards and 1024-slot ring buffers. Its
/// scaling policy allows 1..=16 shards with a 1 ns cooldown — presumably chosen so that
/// consecutive scale operations in a test are never throttled.
fn config(num_shards: u16) -> EventBusConfig {
    EventBusConfig::builder()
        .num_shards(num_shards)
        .ring_buffer_capacity(1024)
        .scaling(ScalingPolicy {
            min_shards: 1,
            max_shards: 16,
            cooldown: Duration::from_nanos(1),
            ..Default::default()
        })
        .build()
        .unwrap()
}
/// Scales a 2-shard bus up by 2, ingests 2 000 events, flushes, then scales down by 2.
///
/// Every `(shard, sequence_start, i)` tuple seen by the adapter must be unique. A
/// duplicate means the stranded-flush path collided with a batch the worker already
/// emitted.
#[tokio::test]
async fn stranded_flush_does_not_collide_with_worker_msg_ids() {
    let (adapter, batches, msg_ids) = RecordingAdapter::new();
    let bus = EventBus::new_with_adapter(config(2), Box::new(adapter))
        .await
        .unwrap();

    let added = bus.manual_scale_up(2).await.unwrap();
    assert_eq!(added.len(), 2);

    // Ingest results are ignored: this test checks uniqueness, not delivery counts.
    for i in 0..2_000u64 {
        let _ = bus.ingest(Event::new(json!({"i": i})));
    }
    bus.flush().await.unwrap();

    let removed = bus.manual_scale_down(2).await.unwrap();
    assert_eq!(removed.len(), 2);
    bus.shutdown().await.unwrap();

    let ids = msg_ids.lock().clone();
    let unique = {
        let mut deduped = ids.clone();
        deduped.sort();
        deduped.dedup();
        deduped.len()
    };
    assert_eq!(
        unique,
        ids.len(),
        "duplicate (shard, sequence_start, i) msg-id tuples observed — \
         stranded-flush collided with worker batch. \
         Total batches: {}, total msg-ids: {}, unique: {}",
        batches.lock().len(),
        ids.len(),
        unique,
    );
}
/// Each shard may deliver at most one batch whose `sequence_start` is 0.
///
/// A second such batch means the stranded-flush path restarted the shard's sequence
/// instead of continuing after the worker's batches.
#[tokio::test]
async fn at_most_one_batch_per_shard_uses_sequence_start_zero() {
    use std::collections::HashMap;

    let (adapter, batches, _) = RecordingAdapter::new();
    let bus = EventBus::new_with_adapter(config(2), Box::new(adapter))
        .await
        .unwrap();
    let _ = bus.manual_scale_up(2).await.unwrap();

    for i in 0..1_000u64 {
        let _ = bus.ingest(Event::new(json!({"i": i})));
    }
    bus.flush().await.unwrap();

    let _ = bus.manual_scale_down(2).await.unwrap();
    bus.shutdown().await.unwrap();

    let observations = batches.lock().clone();
    // shard_id -> number of batches that started at sequence 0
    let zero_starts = observations
        .iter()
        .filter(|obs| obs.sequence_start == 0)
        .fold(HashMap::<u16, usize>::new(), |mut tally, obs| {
            *tally.entry(obs.shard_id).or_default() += 1;
            tally
        });

    for (shard_id, count) in &zero_starts {
        assert!(
            *count <= 1,
            "shard {} produced {} batches with sequence_start=0 — \
             stranded-flush re-used the worker's first-batch sequence, \
             colliding under JetStream dedup. \
             Recorded batches: {:?}",
            shard_id,
            count,
            observations,
        );
    }
}
/// Ingests a small burst and immediately scales down by 2, with no explicit flush.
///
/// Because there is no flush, events may still be pending in the retiring shards'
/// workers when those shards are finalized. Every event must still reach the adapter
/// exactly once.
#[tokio::test]
async fn events_in_flight_at_finalize_reach_adapter() {
    const N: u64 = 100;

    let (adapter, _batches, msg_ids) = RecordingAdapter::new();
    let bus = EventBus::new_with_adapter(config(2), Box::new(adapter))
        .await
        .unwrap();
    let added = bus.manual_scale_up(2).await.unwrap();
    assert_eq!(added.len(), 2);

    // Count accepted ingests, so that a rejected event is reported as an ingest
    // failure instead of being misattributed to the finalize race checked below.
    let mut accepted: usize = 0;
    for i in 0..N {
        if bus.ingest(Event::new(json!({"i": i}))).is_ok() {
            accepted += 1;
        }
    }
    assert_eq!(
        accepted, N as usize,
        "ingest accepted only {accepted} of {N} events; the delivery assertions below \
         would not be measuring the finalize race",
    );

    let _ = bus.manual_scale_down(2).await.unwrap();
    bus.shutdown().await.unwrap();

    // Snapshot once, so both assertions inspect the same data.
    let ids = msg_ids.lock().clone();
    let total_seen = ids.len();
    assert_eq!(
        total_seen, N as usize,
        "expected exactly {N} events delivered to adapter; got {total_seen}. \
         Events lost between BatchWorker pending state and stranded-flush \
         (race window)",
    );

    let mut sorted = ids.clone();
    sorted.sort();
    sorted.dedup();
    assert_eq!(
        sorted.len(),
        ids.len(),
        "duplicate msg-id tuples observed during in-flight finalize — \
         BatchWorker's pending batch raced with stranded-flush",
    );
}
/// A manual scale-up must succeed even when the scaling policy has a long (30 s)
/// cooldown: requesting 4 extra shards on a 1-shard bus adds all 4 in one call.
#[tokio::test]
async fn manual_scale_up_succeeds_with_nonzero_cooldown() {
    let (adapter, _batches, _msg_ids) = RecordingAdapter::new();
    let cfg = EventBusConfig::builder()
        .num_shards(1)
        .ring_buffer_capacity(1024)
        .scaling(ScalingPolicy {
            min_shards: 1,
            max_shards: 8,
            cooldown: Duration::from_secs(30),
            ..Default::default()
        })
        .build()
        .unwrap();
    let bus = EventBus::new_with_adapter(cfg, Box::new(adapter))
        .await
        .unwrap();

    let added = bus
        .manual_scale_up(4)
        .await
        .expect("manual_scale_up under nonzero cooldown must succeed");
    assert_eq!(added.len(), 4, "all 4 requested shards must be added");

    bus.shutdown().await.unwrap();
}
/// Wraps a `RecordingAdapter` and sleeps for `delay` before forwarding each batch.
struct SlowRecordingAdapter {
    /// Adapter that actually records the batches.
    inner: RecordingAdapter,
    /// Sleep applied before each `on_batch` is forwarded.
    delay: Duration,
}
/// Transparent delaying wrapper: every trait method forwards to the inner
/// `RecordingAdapter`, and `on_batch` first sleeps for `delay`.
#[async_trait]
impl Adapter for SlowRecordingAdapter {
    async fn init(&mut self) -> Result<(), AdapterError> {
        // Forward rather than short-circuit, so the wrapper stays transparent if
        // the inner adapter's init ever does real work.
        self.inner.init().await
    }

    async fn on_batch(&self, batch: std::sync::Arc<Batch>) -> Result<(), AdapterError> {
        // Presumably this widens the window in which workers are still busy when
        // shards are retired.
        tokio::time::sleep(self.delay).await;
        self.inner.on_batch(batch).await
    }

    async fn flush(&self) -> Result<(), AdapterError> {
        self.inner.flush().await
    }

    async fn shutdown(&self) -> Result<(), AdapterError> {
        self.inner.shutdown().await
    }

    async fn poll_shard(
        &self,
        shard_id: u16,
        from_id: Option<&str>,
        limit: usize,
    ) -> Result<ShardPollResult, AdapterError> {
        self.inner.poll_shard(shard_id, from_id, limit).await
    }

    fn name(&self) -> &'static str {
        "slow_recording"
    }
}
/// Every delivered batch must carry the same `process_nonce`.
///
/// This covers batches emitted by regular workers and batches emitted by the
/// stranded-flush path during scale-down. The bus is configured with a unique
/// producer-nonce file, which is removed when the test ends, whether it passes or fails.
#[tokio::test]
async fn stranded_flush_uses_bus_producer_nonce() {
    /// Removes the wrapped path when dropped, so the temp nonce file is cleaned up
    /// even when an assertion below panics and the stack unwinds.
    struct RemoveOnDrop(std::path::PathBuf);

    impl Drop for RemoveOnDrop {
        fn drop(&mut self) {
            // Best-effort: the file may never have been created.
            let _ = std::fs::remove_file(&self.0);
        }
    }

    let (recording, batches, _msg_ids) = RecordingAdapter::new();
    let slow = SlowRecordingAdapter {
        inner: recording,
        delay: Duration::from_millis(5),
    };

    // The pid and wall-clock nanos make the path unique per test run.
    let pid = std::process::id();
    let nanos = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .map(|d| d.as_nanos())
        .unwrap_or(0);
    let nonce_path = std::env::temp_dir().join(format!("net-test-stranded-nonce-{pid}-{nanos}"));
    // Declared before `bus`, so it is dropped after the bus.
    let _cleanup = RemoveOnDrop(nonce_path.clone());

    let policy = ScalingPolicy {
        min_shards: 1,
        max_shards: 16,
        cooldown: Duration::from_nanos(1),
        ..Default::default()
    };
    let config = EventBusConfig::builder()
        .num_shards(2)
        .ring_buffer_capacity(2048)
        .scaling(policy)
        .producer_nonce_path(&nonce_path)
        .build()
        .unwrap();
    let bus = EventBus::new_with_adapter(config, Box::new(slow))
        .await
        .unwrap();

    let added = bus.manual_scale_up(2).await.unwrap();
    assert_eq!(added.len(), 2);
    for i in 0..5_000u64 {
        let _ = bus.ingest(Event::new(json!({"i": i})));
    }
    let _ = bus.manual_scale_down(2).await.unwrap();
    bus.shutdown().await.unwrap();

    let observations = batches.lock().clone();
    assert!(
        !observations.is_empty(),
        "expected the recording adapter to have observed at least one batch",
    );
    let first_nonce = observations[0].process_nonce;
    for (i, obs) in observations.iter().enumerate() {
        assert_eq!(
            obs.process_nonce, first_nonce,
            "batch {i} (shard {}, seq {}, len {}) stamped a different \
             nonce ({:#x}) than the first batch ({:#x}) — the \
             stranded-flush path must use the bus's producer_nonce",
            obs.shard_id, obs.sequence_start, obs.len, obs.process_nonce, first_nonce,
        );
    }
}
/// Uses a slowed adapter and no explicit flush before scaling down by 2, so events
/// can be left for the stranded-flush path. Every msg-id tuple the adapter sees must
/// still be unique.
#[tokio::test]
async fn stranded_flush_with_real_stranded_events_uses_post_worker_sequence() {
    let (recording, batches, msg_ids) = RecordingAdapter::new();
    let adapter = SlowRecordingAdapter {
        inner: recording,
        delay: Duration::from_millis(5),
    };
    let config = EventBusConfig::builder()
        .num_shards(2)
        .ring_buffer_capacity(2048)
        .scaling(ScalingPolicy {
            min_shards: 1,
            max_shards: 16,
            cooldown: Duration::from_nanos(1),
            ..Default::default()
        })
        .build()
        .unwrap();
    let bus = EventBus::new_with_adapter(config, Box::new(adapter))
        .await
        .unwrap();

    let added = bus.manual_scale_up(2).await.unwrap();
    assert_eq!(added.len(), 2);

    (0..5_000u64).for_each(|i| {
        let _ = bus.ingest(Event::new(json!({"i": i})));
    });

    let _ = bus.manual_scale_down(2).await.unwrap();
    bus.shutdown().await.unwrap();

    let ids = msg_ids.lock().clone();
    let mut unique_ids = ids.clone();
    unique_ids.sort_unstable();
    unique_ids.dedup();
    assert_eq!(
        unique_ids.len(),
        ids.len(),
        "duplicate msg-id tuples observed in stranded-flush path. \
         Total batches: {}, total msg-ids: {}, unique: {}",
        batches.lock().len(),
        ids.len(),
        unique_ids.len(),
    );
}
/// Runs three cycles on a 2-shard bus, each of which adds one shard, ingests 200 events,
/// flushes and removes one shard. Every accepted event must reach the adapter, and no
/// msg-id tuple may repeat across cycles.
#[tokio::test]
async fn repeated_scale_cycles_preserve_every_event_with_unique_msg_ids() {
    let (adapter, _batches, msg_ids) = RecordingAdapter::new();
    let bus = EventBus::new_with_adapter(config(2), Box::new(adapter))
        .await
        .unwrap();

    let mut total_ingested = 0u64;
    for cycle in 0..3 {
        let _ = bus.manual_scale_up(1).await.unwrap();
        // Only events the bus accepted are expected downstream.
        let accepted = (0..200u64)
            .map(|i| bus.ingest(Event::new(json!({"cycle": cycle, "i": i}))))
            .filter(Result::is_ok)
            .count();
        total_ingested += accepted as u64;
        bus.flush().await.unwrap();
        let _ = bus.manual_scale_down(1).await.unwrap();
    }
    bus.shutdown().await.unwrap();

    let ids = msg_ids.lock().clone();
    let delivered = ids.len() as u64;
    assert_eq!(
        delivered,
        total_ingested,
        "{} ingested events; adapter saw {}",
        total_ingested,
        ids.len(),
    );

    let mut sorted = ids.clone();
    sorted.sort();
    sorted.dedup();
    assert_eq!(
        sorted.len(),
        ids.len(),
        "duplicate msg-id tuples observed across scale cycles",
    );
}
/// Adapter whose `on_batch` never completes; used to exercise timeout-bounded teardown.
struct WedgedAdapter;
/// Every method except `on_batch` succeeds immediately; `on_batch` never resolves.
#[async_trait]
impl Adapter for WedgedAdapter {
    async fn init(&mut self) -> Result<(), AdapterError> {
        Ok(())
    }

    /// Never completes: the returned future stays pending forever.
    async fn on_batch(&self, _batch: std::sync::Arc<Batch>) -> Result<(), AdapterError> {
        std::future::pending::<Result<(), AdapterError>>().await
    }

    async fn flush(&self) -> Result<(), AdapterError> {
        Ok(())
    }

    async fn shutdown(&self) -> Result<(), AdapterError> {
        Ok(())
    }

    async fn poll_shard(
        &self,
        _shard_id: u16,
        _from_id: Option<&str>,
        _limit: usize,
    ) -> Result<ShardPollResult, AdapterError> {
        Ok(ShardPollResult::empty())
    }

    fn name(&self) -> &'static str {
        "wedged"
    }
}
/// With an adapter whose `on_batch` never completes, `manual_scale_down` must still
/// return `Ok` in bounded time. The bus is configured with a 150 ms `adapter_timeout`.
///
/// Two limits keep the failure modes distinct:
/// - `HANG_GUARD` aborts a call that never returns at all.
/// - `MAX_ELAPSED` is strictly smaller than `HANG_GUARD`; it flags a teardown that
///   finishes, but too slowly.
///
/// If both limits were equal, the elapsed-time assertion could never fail.
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn manual_scale_down_returns_within_bounded_time_when_adapter_wedged() {
    const MAX_ELAPSED: Duration = Duration::from_secs(15);
    const HANG_GUARD: Duration = Duration::from_secs(30);

    let policy = ScalingPolicy {
        min_shards: 1,
        max_shards: 16,
        cooldown: Duration::from_nanos(1),
        ..Default::default()
    };
    let cfg = EventBusConfig::builder()
        .num_shards(2)
        .ring_buffer_capacity(1024)
        .scaling(policy)
        .adapter_timeout(Duration::from_millis(150))
        .build()
        .unwrap();
    let bus = EventBus::new_with_adapter(cfg, Box::new(WedgedAdapter))
        .await
        .unwrap();

    let added = bus.manual_scale_up(2).await.unwrap();
    assert_eq!(added.len(), 2);
    for i in 0..2_000u64 {
        let _ = bus.ingest(Event::new(json!({"i": i})));
    }

    let started = std::time::Instant::now();
    let result = tokio::time::timeout(HANG_GUARD, bus.manual_scale_down(2))
        .await
        .unwrap_or_else(|_| {
            panic!(
                "manual_scale_down hung past {HANG_GUARD:?} — timeout-bounded teardown regressed"
            )
        });
    result.expect("manual_scale_down returned Err");

    let elapsed = started.elapsed();
    assert!(
        elapsed < MAX_ELAPSED,
        "manual_scale_down took {:?}; expected < {:?} (bounded by adapter_timeout)",
        elapsed,
        MAX_ELAPSED,
    );

    // Best-effort: shutdown may itself be slowed by the wedged adapter; the
    // behavior under test has already been checked above.
    let _ = tokio::time::timeout(Duration::from_secs(5), bus.shutdown()).await;
}
/// Checks that `pending_in_rings()` reports undelivered events before the bus is dropped.
///
/// The 2-shard bus uses a batch config that should not fire during the test: min/max
/// size of 10 000, which exceeds the 2 × 1024 ring capacity, and a 60 s max delay. The
/// adapter also sleeps 60 s per batch. Events therefore remain in the ring buffers, and
/// the bus is dropped without `shutdown()`.
#[tokio::test]
async fn total_pending_in_rings_reports_stranded_count() {
    const N: u64 = 5_000;

    let cfg = EventBusConfig::builder()
        .num_shards(2)
        .ring_buffer_capacity(1024)
        .scaling(ScalingPolicy {
            min_shards: 1,
            max_shards: 4,
            cooldown: Duration::from_nanos(1),
            ..Default::default()
        })
        .batch(net::config::BatchConfig {
            min_size: 10_000,
            max_size: 10_000,
            max_delay: Duration::from_secs(60),
            adaptive: false,
            velocity_window: Duration::from_millis(100),
        })
        .build()
        .unwrap();

    let (recording, _batches, _msg_ids) = RecordingAdapter::new();
    let adapter = SlowRecordingAdapter {
        inner: recording,
        delay: Duration::from_secs(60),
    };
    let bus = EventBus::new_with_adapter(cfg, Box::new(adapter))
        .await
        .unwrap();

    // Ingest results are deliberately ignored; only what remains in the rings matters.
    (0..N).for_each(|i| {
        let _ = bus.ingest(Event::new(json!({"i": i})));
    });

    let pending_in_rings = bus.pending_in_rings();
    assert!(
        pending_in_rings > 0,
        "expected events still in ring buffers before drop \
         (got {}); the Drop impl's stranded-in-rings increment \
         would be silently 0 and `shutdown_was_lossy` would not \
         be set, masking the data-loss incident",
        pending_in_rings,
    );

    // Dropped without `shutdown()` on purpose, so the bus's Drop path handles the
    // events still in the rings.
    drop(bus);
}