burn_central_runtime/inference/
streaming.rs

1use derive_more::Deref;
2use std::sync::atomic::{AtomicBool, Ordering};
3use std::sync::{Arc, Mutex};
4
5/// Error returned when emitting an item fails.
6/// The item of type [T](EmitError::item) is returned to allow for potential retries.
7#[derive(Debug, thiserror::Error)]
8pub struct EmitError<T> {
9    #[source]
10    pub source: anyhow::Error,
11    pub item: T,
12}
13
14/// The sending side of an output stream for inference outputs.
15pub trait Emitter<T>: Send + Sync + 'static {
16    fn emit(&self, item: T) -> Result<(), EmitError<T>>;
17}
18
/// A token that can be used to cancel an ongoing inference job.
///
/// Clones share the same underlying flag, so cancelling through any clone is observed by
/// all of them. Cancellation is one-way: this type offers no way to clear the flag.
#[derive(Clone, Debug, Default)]
pub struct CancelToken(Arc<AtomicBool>);

impl CancelToken {
    /// Creates a fresh, not-yet-cancelled token.
    pub fn new() -> Self {
        Self::default()
    }

    /// Requests cancellation. Calling this more than once has no further effect.
    pub fn cancel(&self) {
        self.0.store(true, Ordering::SeqCst);
    }

    /// Returns `true` once [`cancel`](Self::cancel) has been called on this token or any
    /// of its clones.
    #[must_use]
    pub fn is_cancelled(&self) -> bool {
        self.0.load(Ordering::SeqCst)
    }
}
42
/// An emitter that collects all emitted items into a vector.
///
/// Items are stored behind a [`Mutex`], so the collector can be shared between threads.
/// Retrieve everything collected with [`into_inner`](Self::into_inner).
#[derive(Debug)]
pub struct CollectEmitter<T>(Mutex<Vec<T>>);

impl<T> CollectEmitter<T> {
    /// Creates an empty collector.
    pub fn new() -> Self {
        Self(Mutex::new(Vec::new()))
    }

    /// Consumes the collector and returns every item pushed so far, in push order.
    ///
    /// A poisoned lock is recovered instead of panicking. In this module the vector is
    /// only mutated by single `push` calls, so its contents remain consistent even if a
    /// thread panicked while holding the lock.
    pub fn into_inner(self) -> Vec<T> {
        self.0
            .into_inner()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }
}

impl<T> Default for CollectEmitter<T> {
    fn default() -> Self {
        Self::new()
    }
}
61
62impl<T: Send + 'static> Emitter<T> for CollectEmitter<T> {
63    fn emit(&self, item: T) -> Result<(), EmitError<T>> {
64        self.0.lock().unwrap().push(item);
65        Ok(())
66    }
67}
68
69/// Emitter implementation backed by a bounded (try_send) crossbeam channel allowing non-blocking emission.
70pub struct SyncChannelEmitter<T> {
71    tx: crossbeam::channel::Sender<T>,
72}
73
74impl<T: Send + 'static> SyncChannelEmitter<T> {
75    pub fn new(tx: crossbeam::channel::Sender<T>) -> Self {
76        Self { tx }
77    }
78}
79
80impl<T: Send + 'static> Emitter<T> for SyncChannelEmitter<T> {
81    fn emit(&self, item: T) -> Result<(), EmitError<T>> {
82        match self.tx.try_send(item) {
83            Ok(_) => Ok(()),
84            Err(crossbeam::channel::TrySendError::Full(item)) => Err(EmitError {
85                source: anyhow::anyhow!("Channel is full"),
86                item,
87            }),
88            Err(crossbeam::channel::TrySendError::Disconnected(item)) => Err(EmitError {
89                source: anyhow::anyhow!("Channel is disconnected"),
90                item,
91            }),
92        }
93    }
94}
95
96/// Lightweight cloneable wrapper exposing an [`Emitter`] implementation to user handlers.
97#[derive(Clone, Deref)]
98pub struct OutStream<T> {
99    emitter: Arc<dyn Emitter<T>>,
100}
101
102impl<T> OutStream<T> {
103    /// Create a new [`OutStream`] from a raw emitter trait object.
104    pub fn new(emitter: Arc<dyn Emitter<T>>) -> Self {
105        Self { emitter }
106    }
107}