burn-central-inference 0.6.0

Inference contracts and adapters for the Burn Central SDK.
Documentation
use crate::{InferenceWrapper, InferenceWriterChannel, writer::InferenceWriterError};

use crossbeam::channel as cb;

use std::{
    sync::{
        Arc,
        atomic::{AtomicBool, Ordering},
    },
    thread,
    time::Duration,
};

/// Type-erased, heap-allocated error that can be sent and shared across threads.
///
/// Used as the error payload of [`StreamEvent::Error`] and as the error type
/// yielded by the [`InferenceStream`] iterator.
pub type BoxError = Box<dyn std::error::Error + Send + Sync + 'static>;

/// Message passed from the inference worker to the consumer of an
/// [`InferenceStream`].
pub enum StreamEvent<O> {
    /// A single output value written by the inference.
    Output(O),
    /// An error reported by the inference.
    Error(BoxError),
    /// The inference finished; the iterator stops when it receives this event.
    ///
    /// NOTE(review): the `Duration` is whatever the inference passes to
    /// `finish` — presumably the elapsed inference time; confirm with the caller.
    Done(Duration),
}

/// Consumer-side handle to an inference running on a worker thread.
///
/// Iterating yields the outputs (or errors) sent by the worker. Dropping the
/// stream requests cancellation and then waits for the worker thread to exit.
pub struct InferenceStream<O> {
    /// Receiving end of the event channel fed by the worker.
    rx: cb::Receiver<StreamEvent<O>>,
    /// Cancellation flag shared with the worker-side channel.
    cancel: Arc<AtomicBool>,
    /// Handle of the worker thread; `None` once it has been joined.
    worker: Option<thread::JoinHandle<()>>,
}

impl<O> InferenceStream<O> {
    pub fn cancel(&self) {
        self.cancel.store(true, Ordering::Release);
    }

    fn join(&mut self) {
        if let Some(worker) = self.worker.take() {
            worker.join().unwrap();
        }
    }
}

impl<O> Drop for InferenceStream<O> {
    fn drop(&mut self) {
        self.cancel();
        self.join();
    }
}

impl<O> Iterator for InferenceStream<O> {
    type Item = Result<O, BoxError>;

    fn next(&mut self) -> Option<Self::Item> {
        match self.rx.recv().ok()? {
            StreamEvent::Output(o) => Some(Ok(o)),
            StreamEvent::Error(e) => Some(Err(e)),
            StreamEvent::Done(_) => None,
        }
    }
}

/// Worker-side writer that forwards inference events into the channel read by
/// an [`InferenceStream`].
struct StreamingChannel<O> {
    /// Sending end of the event channel.
    tx: cb::Sender<StreamEvent<O>>,
    /// Cancellation flag shared with the consumer-side stream.
    cancel: Arc<AtomicBool>,
}

impl<O> InferenceWriterChannel<O> for StreamingChannel<O>
where
    O: Send + Sync + 'static,
{
    fn write(&self, output: O) -> Result<(), InferenceWriterError> {
        if self.cancel.load(Ordering::Acquire) {
            return Err(InferenceWriterError::Cancelled);
        }

        self.tx
            .send(StreamEvent::Output(output))
            .map_err(|e| InferenceWriterError::Unknown(Box::new(e)))
    }

    fn error(&self, error: BoxError) -> Result<(), InferenceWriterError> {
        self.tx
            .send(StreamEvent::Error(error))
            .map_err(|e| InferenceWriterError::Unknown(Box::new(e)))
    }

    fn finish(&self, duration: Duration) {
        let _ = self.tx.send(StreamEvent::Done(duration));
    }
}

// No explicit cleanup happens here: the body is empty, so the fields are
// simply dropped as usual once this returns.
//
// NOTE(review): this impl does nothing beyond making `StreamingChannel` a
// `Drop` type (which forbids moving fields out of it). Confirm whether it is
// a leftover or a placeholder for teardown logic that was never added.
impl<O> Drop for StreamingChannel<O> {
    fn drop(&mut self) {}
}

/// Runs an [`InferenceWrapper`] on a freshly spawned thread and exposes its
/// results as an [`InferenceStream`].
pub struct DirectInference<I, O> {
    /// The wrapped inference; cloned for every call to `stream`.
    inner: InferenceWrapper<I, O>,
}

impl<I, O> DirectInference<I, O>
where
    I: Send + 'static,
    O: Send + Sync + 'static,
{
    /// Wraps `inference` so it can be run via [`DirectInference::stream`].
    pub fn new(inference: InferenceWrapper<I, O>) -> Self {
        Self { inner: inference }
    }

    /// Starts the inference on `input` and returns a stream of its results.
    ///
    /// Each call spawns a new worker thread that runs a clone of the wrapped
    /// inference. Events travel over an unbounded channel, so the worker is
    /// never blocked by a slow consumer. Dropping the returned stream requests
    /// cancellation and waits for the worker to exit.
    pub fn stream(&self, input: I) -> InferenceStream<O> {
        let (tx, rx) = cb::unbounded();
        let cancel = Arc::new(AtomicBool::new(false));

        // The worker-side channel is the only sender: moving `tx` (rather than
        // cloning it) makes it explicit that the channel disconnects as soon
        // as the worker drops its writer.
        let channel = StreamingChannel {
            tx,
            cancel: Arc::clone(&cancel),
        };

        let inference = self.inner.clone();

        let worker = thread::spawn(move || {
            inference.infer(input, channel);
        });

        InferenceStream {
            rx,
            cancel,
            worker: Some(worker),
        }
    }
}