use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::time::Duration;
use nv_core::error::StageError;
use nv_core::id::StageId;
use nv_perception::StageOutput;
use nv_perception::batch::BatchEntry;
use super::config::BatchConfig;
use super::metrics::{BatchMetrics, BatchMetricsInner};
/// Safety margin added on top of `max_latency` when `config.response_timeout`
/// is unset — bounds how long `submit_and_wait` blocks for a response.
const DEFAULT_RESPONSE_TIMEOUT: Duration = Duration::from_secs(5);
/// A submission queued for the batch coordinator, paired with the channel
/// on which the submitter is blocked awaiting the result.
pub(super) struct PendingEntry {
    // The payload to be batched and handed to the stage.
    pub(super) entry: BatchEntry,
    // One-shot reply channel (bounded at 1); the submitter blocks on the
    // receiving end in `submit_and_wait`.
    pub(super) response_tx: std::sync::mpsc::SyncSender<Result<StageOutput, StageError>>,
    // Per-feed in-flight counter reserved by `submit_and_wait`. NOTE(review):
    // only the failure-path decrement is visible in this file — presumably the
    // coordinator decrements it after processing; confirm at the consumer.
    pub(super) in_flight_guard: Option<Arc<AtomicUsize>>,
}
/// Cheaply cloneable handle for submitting entries to a batch coordinator.
/// All clones share the same inner state via `Arc`.
#[derive(Clone)]
pub struct BatchHandle {
    pub(crate) inner: Arc<BatchHandleInner>,
}
/// Shared state behind every clone of a [`BatchHandle`].
pub(crate) struct BatchHandleInner {
    // Bounded channel into the coordinator; `try_send` is used so a full
    // queue rejects instead of blocking.
    pub(super) submit_tx: std::sync::mpsc::SyncSender<PendingEntry>,
    // Shared counters (submissions, rejections, timeouts) snapshotted by
    // `BatchHandle::metrics`.
    pub(super) metrics: Arc<BatchMetricsInner>,
    // Identifier of the stage this handle feeds.
    pub(super) processor_id: StageId,
    // Batching parameters: max batch size, latency window, response timeout,
    // per-feed in-flight cap.
    pub(super) config: BatchConfig,
    // Capabilities reported for the underlying stage, if any were recorded.
    pub(super) capabilities: Option<nv_perception::stage::StageCapabilities>,
}
impl BatchHandle {
    /// Identifier of the stage this batcher feeds.
    #[must_use]
    pub fn processor_id(&self) -> StageId {
        self.inner.processor_id
    }

    /// Capabilities reported for the underlying stage, if any were recorded.
    #[must_use]
    pub fn capabilities(&self) -> Option<&nv_perception::stage::StageCapabilities> {
        self.inner.capabilities.as_ref()
    }

    /// Snapshot of the shared batching metrics, normalized against the
    /// configured maximum batch size.
    #[must_use]
    pub fn metrics(&self) -> BatchMetrics {
        let max_batch = self.inner.config.max_batch_size as u64;
        self.inner.metrics.snapshot(max_batch)
    }

    /// Bumps the timeout counter in the shared metrics.
    pub(crate) fn record_timeout(&self) {
        self.inner.metrics.record_timeout();
    }

    /// Submits `entry` to the batch coordinator and blocks until a response
    /// arrives or the response deadline elapses.
    ///
    /// When `in_flight` is provided, the counter acts as a per-feed admission
    /// cap: it is incremented before queueing, and if the pre-increment value
    /// already reached `max_in_flight_per_feed` the increment is rolled back
    /// and the submission is rejected. Any queueing failure likewise rolls the
    /// increment back here; on success the cloned counter travels with the
    /// entry (presumably the coordinator decrements it after processing —
    /// confirm at the consumer).
    ///
    /// # Errors
    ///
    /// [`BatchSubmitError::InFlightCapReached`], [`BatchSubmitError::QueueFull`],
    /// [`BatchSubmitError::CoordinatorShutdown`],
    /// [`BatchSubmitError::ProcessingFailed`], or [`BatchSubmitError::Timeout`].
    pub(crate) fn submit_and_wait(
        &self,
        entry: BatchEntry,
        in_flight: Option<&Arc<AtomicUsize>>,
    ) -> Result<StageOutput, BatchSubmitError> {
        // Rendezvous channel for the single response this entry will receive.
        let (reply_tx, reply_rx) = std::sync::mpsc::sync_channel(1);
        self.inner.metrics.record_submission();

        // Reserve an in-flight slot up front; every failure path below must
        // undo the reservation.
        let guard = match in_flight {
            Some(counter) => {
                let before = counter.fetch_add(1, Ordering::Acquire);
                if before >= self.inner.config.max_in_flight_per_feed {
                    counter.fetch_sub(1, Ordering::Release);
                    self.inner.metrics.record_rejection();
                    return Err(BatchSubmitError::InFlightCapReached);
                }
                Some(Arc::clone(counter))
            }
            None => None,
        };

        let pending = PendingEntry {
            entry,
            response_tx: reply_tx,
            in_flight_guard: guard.clone(),
        };
        if let Err(send_err) = self.inner.submit_tx.try_send(pending) {
            // The entry never made it into the queue — release the slot.
            if let Some(counter) = &guard {
                counter.fetch_sub(1, Ordering::Release);
            }
            self.inner.metrics.record_rejection();
            return Err(match send_err {
                std::sync::mpsc::TrySendError::Full(_) => BatchSubmitError::QueueFull,
                std::sync::mpsc::TrySendError::Disconnected(_) => {
                    BatchSubmitError::CoordinatorShutdown
                }
            });
        }

        // Wait out the batch window plus a safety margin before declaring
        // the response lost.
        let margin = self
            .inner
            .config
            .response_timeout
            .unwrap_or(DEFAULT_RESPONSE_TIMEOUT);
        let deadline = self.inner.config.max_latency + margin;
        match reply_rx.recv_timeout(deadline) {
            Ok(Ok(output)) => Ok(output),
            Ok(Err(stage_err)) => Err(BatchSubmitError::ProcessingFailed(stage_err)),
            Err(std::sync::mpsc::RecvTimeoutError::Timeout) => Err(BatchSubmitError::Timeout),
            Err(std::sync::mpsc::RecvTimeoutError::Disconnected) => {
                Err(BatchSubmitError::CoordinatorShutdown)
            }
        }
    }
}
/// Reasons a batched submission can fail before, during, or after queueing.
#[derive(Debug)]
pub(crate) enum BatchSubmitError {
    /// The coordinator's submission queue was full (`try_send` returned `Full`).
    QueueFull,
    /// The coordinator side of a channel was disconnected.
    CoordinatorShutdown,
    /// The stage processed the entry but returned an error.
    ProcessingFailed(StageError),
    /// No response arrived within `max_latency` plus the response timeout.
    Timeout,
    /// The per-feed in-flight cap (`max_in_flight_per_feed`) was already reached.
    InFlightCapReached,
}