use std::sync::{
Arc,
atomic::{AtomicBool, AtomicUsize, Ordering},
};
use crate::observer::{InferenceWriterObserver, InferenceWriterStats};
/// Errors returned by [`InferenceWriter`] operations and by
/// [`InferenceWriterChannel`] implementations.
#[derive(Debug, thiserror::Error)]
pub enum InferenceWriterError {
    /// The inference was cancelled. When a channel returns this from `write`
    /// or `error`, the writer marks itself cancelled and notifies its observer.
    #[error("inference was cancelled")]
    Cancelled,
    /// Any other failure, wrapping the underlying boxed error.
    #[error("unknown error: {0}")]
    Unknown(Box<dyn std::error::Error + Send + Sync>),
}
/// Forwards outputs and errors to an [`InferenceWriterChannel`].
///
/// It counts what the channel accepted, records cancellation, and reports a
/// summary to an optional observer when the writer finishes (from `Drop`).
// NOTE: fields are dropped in declaration order after `Drop::drop` runs;
// keep the order stable if the channel/observer drop order matters.
pub struct InferenceWriter<O> {
    /// Destination for outputs and errors.
    channel: Box<dyn InferenceWriterChannel<O>>,
    /// Creation time; the duration reported on finish is measured from here.
    instant: std::time::Instant,
    /// Optional observer notified of writes, errors, cancellation and finish.
    observer: Option<Arc<dyn InferenceWriterObserver>>,
    /// Number of outputs the channel accepted.
    outputs: AtomicUsize,
    /// Number of errors the channel accepted.
    errors: AtomicUsize,
    /// Set once the channel reports `InferenceWriterError::Cancelled`.
    cancelled: AtomicBool,
    /// Set when finishing; prevents `on_finish` from being reported twice.
    finished: AtomicBool,
}
impl<O> InferenceWriter<O> {
    /// Creates a writer that forwards outputs and errors to `channel`.
    ///
    /// The elapsed-time clock starts here; the duration passed to the channel
    /// and observer on finish is measured from this call.
    pub(crate) fn new(channel: Box<dyn InferenceWriterChannel<O>>) -> Self {
        Self {
            channel,
            instant: std::time::Instant::now(),
            observer: None,
            outputs: AtomicUsize::new(0),
            errors: AtomicUsize::new(0),
            cancelled: AtomicBool::new(false),
            finished: AtomicBool::new(false),
        }
    }

    /// Convenience constructor that boxes a concrete channel type.
    pub(crate) fn from_channel<C>(channel: C) -> Self
    where
        C: InferenceWriterChannel<O> + 'static,
    {
        Self::new(Box::new(channel))
    }

    /// Attaches `observer`, replacing any previously attached one.
    ///
    /// The observer is notified on each accepted write and error, on
    /// cancellation, and once with summary stats when the writer finishes.
    /// This consumes the writer and returns it. Discarding the result drops the
    /// writer, which finishes it immediately, so the result must be used.
    #[must_use]
    pub fn with_observer(mut self, observer: Arc<dyn InferenceWriterObserver>) -> Self {
        self.observer = Some(observer);
        self
    }

    /// Sends one output through the channel.
    ///
    /// On success, increments the output counter and calls the observer's
    /// `on_write`.
    ///
    /// # Errors
    ///
    /// Returns the channel's error unchanged. If it is
    /// [`InferenceWriterError::Cancelled`], the writer is also marked cancelled
    /// and the observer's `on_cancelled` is invoked.
    pub fn write(&self, output: O) -> Result<(), InferenceWriterError> {
        self.channel
            .write(output)
            .map_err(|err| self.record_failure(err))?;
        // Relaxed is enough: the counter is only read when finishing.
        self.outputs.fetch_add(1, Ordering::Relaxed);
        self.notify(|observer| observer.on_write());
        Ok(())
    }

    /// Sends an error through the channel.
    ///
    /// On success, increments the error counter and calls the observer's
    /// `on_error`.
    ///
    /// # Errors
    ///
    /// Returns the channel's error unchanged. If it is
    /// [`InferenceWriterError::Cancelled`], the writer is also marked cancelled
    /// and the observer's `on_cancelled` is invoked.
    pub fn error<E>(&self, error: E) -> Result<(), InferenceWriterError>
    where
        E: Into<Box<dyn std::error::Error + Send + Sync>>,
    {
        self.channel
            .error(error.into())
            .map_err(|err| self.record_failure(err))?;
        // Relaxed is enough: the counter is only read when finishing.
        self.errors.fetch_add(1, Ordering::Relaxed);
        self.notify(|observer| observer.on_error());
        Ok(())
    }

    /// Finishes the writer by reporting the elapsed time to the channel and a
    /// stats snapshot to the observer.
    ///
    /// The `finished` flag is checked first, so both the channel and the
    /// observer are notified at most once even if this runs more than once.
    fn finish(&self) {
        if self.finished.swap(true, Ordering::AcqRel) {
            return;
        }
        let duration = self.instant.elapsed();
        self.channel.finish(duration);
        self.notify(|observer| {
            observer.on_finish(&InferenceWriterStats {
                duration,
                outputs: self.outputs.load(Ordering::Acquire),
                errors: self.errors.load(Ordering::Acquire),
                cancelled: self.cancelled.load(Ordering::Acquire),
            });
        });
    }

    /// Bookkeeping for a failed channel call; returns `err` unchanged.
    ///
    /// A `Cancelled` error marks the writer cancelled and notifies the
    /// observer. Other errors pass through without side effects.
    fn record_failure(&self, err: InferenceWriterError) -> InferenceWriterError {
        if matches!(&err, InferenceWriterError::Cancelled) {
            self.cancelled.store(true, Ordering::Release);
            self.notify(|observer| observer.on_cancelled());
        }
        err
    }

    /// Invokes `f` with the attached observer, if there is one.
    fn notify(&self, f: impl FnOnce(&dyn InferenceWriterObserver)) {
        if let Some(observer) = self.observer.as_deref() {
            f(observer);
        }
    }
}
impl<O> Drop for InferenceWriter<O> {
    // Finishing is tied to the writer's lifetime: going out of scope is what
    // reports the elapsed time to the channel and the summary to the observer.
    fn drop(&mut self) {
        Self::finish(self)
    }
}
/// Sink that an [`InferenceWriter`] forwards its outputs and errors to.
pub trait InferenceWriterChannel<O> {
    /// Delivers one output.
    ///
    /// Returning [`InferenceWriterError::Cancelled`] causes the writer to mark
    /// itself cancelled and notify its observer. On `Ok`, the writer counts the
    /// output.
    fn write(&self, output: O) -> Result<(), InferenceWriterError>;
    /// Delivers an error produced during inference.
    ///
    /// Returning [`InferenceWriterError::Cancelled`] causes the writer to mark
    /// itself cancelled and notify its observer. On `Ok`, the writer counts the
    /// error.
    fn error(
        &self,
        error: Box<dyn std::error::Error + Send + Sync>,
    ) -> Result<(), InferenceWriterError>;
    /// Called when the writer finishes (from its `Drop`), with the time elapsed
    /// since the writer was created.
    fn finish(&self, duration: std::time::Duration);
}