use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::time::{Duration, Instant};
use nv_core::error::StageError;
use nv_core::health::HealthEvent;
use nv_core::id::StageId;
use nv_perception::StageOutput;
use nv_perception::batch::BatchProcessor;
use tokio::sync::broadcast;
use super::config::BatchConfig;
use super::handle::{BatchHandle, BatchHandleInner, PendingEntry};
use super::metrics::BatchMetricsInner;
/// Fallback bound on how long a processor's `on_start` may run when the
/// config does not specify `startup_timeout`.
const DEFAULT_ON_START_TIMEOUT: Duration = Duration::from_secs(30);
/// How often blocking `recv_timeout` waits wake up to re-check the shutdown flag.
const SHUTDOWN_POLL_INTERVAL: Duration = Duration::from_millis(100);
/// Minimum interval between `HealthEvent::BatchError` emissions, so a
/// persistently failing processor cannot flood the health channel.
const BATCH_ERROR_THROTTLE: Duration = Duration::from_secs(1);
/// Owns the background thread that forms and dispatches batches for one
/// `BatchProcessor`, plus the cooperative-shutdown flag shared with it.
pub(crate) struct BatchCoordinator {
    // Set to true to ask the coordinator thread to exit at its next poll.
    shutdown: Arc<AtomicBool>,
    // Join handle for the coordinator thread; taken (None) by `shutdown()`.
    thread: Option<std::thread::JoinHandle<()>>,
    // Cloneable submission handle given out to callers via `handle()`.
    handle: BatchHandle,
}
/// Manual `Debug`: the join handle and shared flag carry no state worth
/// printing, so only the type name is shown (with the `{ .. }` elision).
impl std::fmt::Debug for BatchCoordinator {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut repr = f.debug_struct("BatchCoordinator");
        repr.finish_non_exhaustive()
    }
}
impl BatchCoordinator {
    /// Validate `config`, spawn the coordinator thread, run the processor's
    /// `on_start` hook on that thread, and wait (bounded) for it to report
    /// success before returning.
    ///
    /// Startup handshake: the spawned thread runs `on_start` under
    /// `catch_unwind` and reports the outcome over a one-shot channel; only
    /// on success does it enter `coordinator_loop`.
    ///
    /// # Errors
    /// - `NvError::Config` if the config fails validation.
    /// - `NvError::Runtime(ThreadSpawnFailed)` if the OS thread cannot be
    ///   spawned, `on_start` fails or panics, or `on_start` does not complete
    ///   within the startup timeout (in which case the thread may be detached,
    ///   since safe Rust cannot force-stop a blocked thread).
    pub fn start(
        mut processor: Box<dyn BatchProcessor>,
        config: BatchConfig,
        health_tx: broadcast::Sender<HealthEvent>,
    ) -> Result<Self, nv_core::error::NvError> {
        use nv_core::error::{NvError, RuntimeError};
        config.validate().map_err(NvError::Config)?;
        // Bounded submit queue: explicit capacity if configured, otherwise
        // room for four full batches (minimum 4) so submitters see backpressure.
        let queue_depth = match config.queue_capacity {
            Some(cap) => cap,
            None => config.max_batch_size.saturating_mul(4).max(4),
        };
        let (submit_tx, submit_rx) = std::sync::mpsc::sync_channel(queue_depth);
        let metrics = Arc::new(BatchMetricsInner::new());
        let shutdown = Arc::new(AtomicBool::new(false));
        // Capture identity/capabilities before the processor moves into the thread.
        let processor_id = processor.id();
        let capabilities = processor.capabilities();
        let handle = BatchHandle {
            inner: Arc::new(BatchHandleInner {
                submit_tx,
                metrics: Arc::clone(&metrics),
                processor_id,
                config: config.clone(),
                capabilities,
            }),
        };
        // One-shot channel for the startup handshake (capacity 1 so the
        // sender never blocks even if the receiver has timed out).
        let (startup_tx, startup_rx) = std::sync::mpsc::sync_channel::<Result<(), String>>(1);
        let shutdown_clone = Arc::clone(&shutdown);
        let on_start_timeout = config.startup_timeout.unwrap_or(DEFAULT_ON_START_TIMEOUT);
        let thread = std::thread::Builder::new()
            .name(format!("nv-batch-{}", processor_id))
            .spawn(move || {
                // Isolate on_start panics so they become a startup error
                // rather than a silent thread death.
                let start_result =
                    std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| processor.on_start()));
                match start_result {
                    Ok(Ok(())) => {
                        let _ = startup_tx.send(Ok(()));
                    }
                    Ok(Err(e)) => {
                        // Report failure and exit without entering the loop.
                        let _ =
                            startup_tx.send(Err(format!("batch processor on_start failed: {e}")));
                        return;
                    }
                    Err(_) => {
                        let _ = startup_tx.send(Err(format!(
                            "batch processor '{}' panicked in on_start",
                            processor_id
                        )));
                        return;
                    }
                }
                // on_start succeeded: run batches until shutdown/disconnect.
                coordinator_loop(
                    submit_rx,
                    processor,
                    config,
                    shutdown_clone,
                    metrics,
                    health_tx,
                );
            })
            .map_err(|e| {
                NvError::Runtime(RuntimeError::ThreadSpawnFailed {
                    detail: format!("batch coordinator thread: {e}"),
                })
            })?;
        // Wait for the startup handshake, bounded by the on_start timeout.
        match startup_rx.recv_timeout(on_start_timeout) {
            Ok(Ok(())) => {}
            Ok(Err(detail)) => {
                // Thread reported failure and returned; join cannot block long.
                let _ = thread.join();
                return Err(NvError::Runtime(RuntimeError::ThreadSpawnFailed { detail }));
            }
            Err(std::sync::mpsc::RecvTimeoutError::Timeout) => {
                // on_start is still running. Signal shutdown, then give the
                // thread a short grace period to finish via a helper joiner
                // thread (we cannot block start() indefinitely on join()).
                shutdown.store(true, Ordering::Relaxed);
                const STARTUP_JOIN_GRACE: Duration = Duration::from_secs(2);
                let detached = {
                    let (done_tx, done_rx) = std::sync::mpsc::channel();
                    let _ = std::thread::Builder::new()
                        .name(format!("nv-join-startup-{processor_id}"))
                        .spawn(move || {
                            let _ = thread.join();
                            let _ = done_tx.send(());
                        });
                    // Err here means the grace period elapsed (or the joiner
                    // failed to spawn): the coordinator thread is detached.
                    done_rx.recv_timeout(STARTUP_JOIN_GRACE).is_err()
                };
                let detail = if detached {
                    tracing::warn!(
                        processor = %processor_id,
                        timeout_secs = on_start_timeout.as_secs(),
                        "batch processor on_start timed out — coordinator thread detached \
                         (safe Rust cannot force-stop a blocked thread)"
                    );
                    format!(
                        "batch processor '{}' on_start did not complete within {}s; \
                         coordinator thread detached (cannot force-stop blocked on_start)",
                        processor_id,
                        on_start_timeout.as_secs(),
                    )
                } else {
                    format!(
                        "batch processor '{}' on_start did not complete within {}s",
                        processor_id,
                        on_start_timeout.as_secs(),
                    )
                };
                return Err(NvError::Runtime(RuntimeError::ThreadSpawnFailed { detail }));
            }
            Err(std::sync::mpsc::RecvTimeoutError::Disconnected) => {
                // startup_tx dropped without a message: the thread died
                // before it could report (should not happen in practice).
                return Err(NvError::Runtime(RuntimeError::ThreadSpawnFailed {
                    detail: "batch coordinator thread exited during startup".into(),
                }));
            }
        }
        Ok(Self {
            shutdown,
            thread: Some(thread),
            handle,
        })
    }
    /// Cheap clone of the submission handle for producers.
    pub fn handle(&self) -> BatchHandle {
        self.handle.clone()
    }
    /// Ask the coordinator thread to stop without waiting for it.
    pub fn signal_shutdown(&self) {
        self.shutdown.store(true, Ordering::Relaxed);
    }
    /// Signal shutdown and join the coordinator thread, waiting at most
    /// `timeout`. Returns `Some(DetachedJoin)` if the thread did not finish
    /// in time and had to be detached; `None` on a clean join.
    pub fn shutdown(mut self, timeout: Duration) -> Option<crate::runtime::DetachedJoin> {
        self.shutdown.store(true, Ordering::Relaxed);
        if let Some(thread) = self.thread.take() {
            bounded_coordinator_join(thread, self.handle.processor_id(), timeout)
        } else {
            None
        }
    }
}
impl Drop for BatchCoordinator {
    /// Best-effort shutdown: flip the flag so the coordinator thread exits at
    /// its next poll. Deliberately does not join — drop must never block.
    fn drop(&mut self) {
        self.shutdown.store(true, Ordering::Relaxed);
    }
}
/// Main body of the coordinator thread: forms batches from `rx`, dispatches
/// them to `processor.process`, and routes per-entry results back to waiters.
///
/// Batch formation: block (in shutdown-pollable slices) for a first entry,
/// then keep accepting entries until `max_batch_size` is reached or
/// `max_latency` has elapsed since the first entry arrived.
///
/// Exits when the shutdown flag is set or all senders disconnect, then drains
/// the queue and runs the processor's `on_stop` hook.
fn coordinator_loop(
    rx: std::sync::mpsc::Receiver<PendingEntry>,
    mut processor: Box<dyn BatchProcessor>,
    config: BatchConfig,
    shutdown: Arc<AtomicBool>,
    metrics: Arc<BatchMetricsInner>,
    health_tx: broadcast::Sender<HealthEvent>,
) {
    // Scratch buffers reused across iterations to avoid per-batch allocation.
    let mut batch: Vec<PendingEntry> = Vec::with_capacity(config.max_batch_size);
    let mut entries: Vec<nv_perception::batch::BatchEntry> =
        Vec::with_capacity(config.max_batch_size);
    let mut responses: Vec<std::sync::mpsc::SyncSender<Result<StageOutput, StageError>>> =
        Vec::with_capacity(config.max_batch_size);
    let mut in_flight_guards: Vec<Option<Arc<AtomicUsize>>> =
        Vec::with_capacity(config.max_batch_size);
    // Timestamp of the last emitted BatchError health event, for throttling.
    let mut last_batch_error_event: Option<Instant> = None;
    'outer: loop {
        if shutdown.load(Ordering::Relaxed) {
            break;
        }
        batch.clear();
        // Block for the first entry, waking periodically to honor shutdown.
        let first = loop {
            if shutdown.load(Ordering::Relaxed) {
                break 'outer;
            }
            match rx.recv_timeout(SHUTDOWN_POLL_INTERVAL) {
                Ok(item) => break item,
                Err(std::sync::mpsc::RecvTimeoutError::Timeout) => continue,
                Err(std::sync::mpsc::RecvTimeoutError::Disconnected) => {
                    break 'outer;
                }
            }
        };
        let formation_start = Instant::now();
        batch.push(first);
        // Latency budget starts at the first entry; fill until full or expired.
        let deadline = Instant::now() + config.max_latency;
        while batch.len() < config.max_batch_size {
            if shutdown.load(Ordering::Relaxed) {
                break;
            }
            let now = Instant::now();
            if now >= deadline {
                break;
            }
            // Wait in shutdown-pollable slices, never past the deadline.
            let wait = (deadline - now).min(SHUTDOWN_POLL_INTERVAL);
            match rx.recv_timeout(wait) {
                Ok(item) => batch.push(item),
                Err(std::sync::mpsc::RecvTimeoutError::Timeout) => continue,
                // Disconnect: dispatch the partial batch, then 'outer exits
                // on the next blocking receive.
                Err(std::sync::mpsc::RecvTimeoutError::Disconnected) => break,
            }
        }
        let formation_ns = formation_start.elapsed().as_nanos() as u64;
        let batch_size = batch.len();
        // Split each PendingEntry into its parallel components (index i of
        // each vec belongs to the same submission).
        entries.clear();
        responses.clear();
        in_flight_guards.clear();
        for pending in batch.drain(..) {
            entries.push(pending.entry);
            responses.push(pending.response_tx);
            in_flight_guards.push(pending.in_flight_guard);
        }
        // Isolate processor panics so one bad batch cannot kill the thread.
        let process_start = Instant::now();
        let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            processor.process(&mut entries)
        }));
        let processing_ns = process_start.elapsed().as_nanos() as u64;
        metrics.record_dispatch(batch_size, formation_ns, processing_ns);
        tracing::debug!(
            processor = %processor.id(),
            batch_size,
            formation_ms = formation_ns / 1_000_000,
            processing_ms = processing_ns / 1_000_000,
            "batch dispatched"
        );
        match result {
            Ok(Ok(())) => {
                // Success: deliver each entry's output to its waiter,
                // releasing the in-flight guard first.
                metrics.record_batch_success();
                for ((entry, tx), guard) in entries
                    .drain(..)
                    .zip(responses.drain(..))
                    .zip(in_flight_guards.drain(..))
                {
                    if let Some(g) = guard {
                        g.fetch_sub(1, Ordering::Release);
                    }
                    // Processor may leave output unset; fall back to default.
                    let output = entry.output.unwrap_or_default();
                    let _ = tx.send(Ok(output));
                }
            }
            Ok(Err(stage_err)) => {
                // Processor returned an error: fail every entry in the batch.
                metrics.record_batch_error();
                tracing::error!(
                    processor = %processor.id(),
                    error = %stage_err,
                    batch_size,
                    "batch processor error"
                );
                // Throttled health event so repeated failures don't flood.
                let now = Instant::now();
                let should_emit = last_batch_error_event
                    .is_none_or(|t| now.duration_since(t) >= BATCH_ERROR_THROTTLE);
                if should_emit {
                    last_batch_error_event = Some(now);
                    let _ = health_tx.send(HealthEvent::BatchError {
                        processor_id: processor.id(),
                        batch_size: batch_size as u32,
                        error: stage_err.clone(),
                    });
                }
                for (tx, guard) in responses.drain(..).zip(in_flight_guards.drain(..)) {
                    if let Some(g) = guard {
                        g.fetch_sub(1, Ordering::Release);
                    }
                    let _ = tx.send(Err(stage_err.clone()));
                }
            }
            Err(_panic) => {
                // Processor panicked: synthesize a ProcessingFailed error and
                // fail every entry, same shape as the error branch above.
                metrics.record_batch_error();
                tracing::error!(
                    processor = %processor.id(),
                    batch_size,
                    "batch processor panicked"
                );
                let err = StageError::ProcessingFailed {
                    stage_id: processor.id(),
                    detail: "batch processor panicked".into(),
                };
                let now = Instant::now();
                let should_emit = last_batch_error_event
                    .is_none_or(|t| now.duration_since(t) >= BATCH_ERROR_THROTTLE);
                if should_emit {
                    last_batch_error_event = Some(now);
                    let _ = health_tx.send(HealthEvent::BatchError {
                        processor_id: processor.id(),
                        batch_size: batch_size as u32,
                        error: err.clone(),
                    });
                }
                for (tx, guard) in responses.drain(..).zip(in_flight_guards.drain(..)) {
                    if let Some(g) = guard {
                        g.fetch_sub(1, Ordering::Release);
                    }
                    let _ = tx.send(Err(err.clone()));
                }
            }
        }
    }
    // Shutdown path: discard queued entries (their waiters observe the
    // dropped response channel), then run the stop hook.
    let drained = drain_pending(&rx);
    if drained > 0 {
        tracing::debug!(
            processor = %processor.id(),
            drained,
            "drained pending items on coordinator shutdown"
        );
    }
    drop(rx);
    call_on_stop(&mut *processor);
}
fn drain_pending(rx: &std::sync::mpsc::Receiver<PendingEntry>) -> u64 {
let mut count = 0u64;
while let Ok(pe) = rx.try_recv() {
if let Some(ref g) = pe.in_flight_guard {
g.fetch_sub(1, Ordering::Release);
}
count += 1;
}
count
}
/// Run the processor's `on_stop` hook during coordinator teardown, logging
/// (but otherwise ignoring) any error or panic it produces — shutdown must
/// proceed regardless.
fn call_on_stop(processor: &mut dyn BatchProcessor) {
    let outcome = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| processor.on_stop()));
    let Ok(stop_result) = outcome else {
        tracing::error!(
            processor = %processor.id(),
            "batch processor on_stop panicked (ignored)"
        );
        return;
    };
    if let Err(e) = stop_result {
        tracing::warn!(
            processor = %processor.id(),
            error = %e,
            "batch processor on_stop error (ignored)"
        );
    }
}
/// Join the coordinator thread with an upper bound on the wait.
///
/// A helper "joiner" thread performs the (potentially blocking) `join` and
/// reports the outcome over a channel; if the outcome does not arrive within
/// `timeout`, the coordinator thread is detached and a
/// [`crate::runtime::DetachedJoin`] is returned so the caller can reap it
/// later. Returns `None` on a timely join (a panicked thread is logged).
fn bounded_coordinator_join(
    thread: std::thread::JoinHandle<()>,
    processor_id: StageId,
    timeout: Duration,
) -> Option<crate::runtime::DetachedJoin> {
    let (done_tx, done_rx) = std::sync::mpsc::channel();
    let label = format!("nv-join-batch-{processor_id}");
    let joiner = std::thread::Builder::new()
        .name(label.clone())
        .spawn(move || {
            let _ = done_tx.send(thread.join());
        });
    match done_rx.recv_timeout(timeout) {
        Ok(join_result) => {
            // Thread finished in time; a panic result is logged but not fatal.
            if join_result.is_err() {
                tracing::error!(
                    processor = %processor_id,
                    "batch coordinator thread panicked during join"
                );
            }
            None
        }
        Err(_) => {
            // Timed out (or joiner spawn failed): hand the pending join back
            // to the caller. If the joiner never spawned there is nothing to
            // track, hence the `ok()` before building the DetachedJoin.
            tracing::warn!(
                processor = %processor_id,
                timeout_secs = timeout.as_secs(),
                "batch coordinator thread did not finish within timeout — detaching"
            );
            joiner.ok().map(|j| crate::runtime::DetachedJoin {
                label,
                done_rx,
                joiner: j,
            })
        }
    }
}