burn_central_runtime/inference/
streaming.rs1use derive_more::Deref;
2use std::sync::atomic::{AtomicBool, Ordering};
3use std::sync::{Arc, Mutex};
4
5#[derive(Debug, thiserror::Error)]
8pub struct EmitError<T> {
9 #[source]
10 pub source: anyhow::Error,
11 pub item: T,
12}
13
14pub trait Emitter<T>: Send + Sync + 'static {
16 fn emit(&self, item: T) -> Result<(), EmitError<T>>;
17}
18
/// A cheaply cloneable, thread-safe cancellation flag.
///
/// Every clone shares one underlying flag, so calling [`CancelToken::cancel`]
/// on any clone makes [`CancelToken::is_cancelled`] return `true` on all of
/// them. The methods here only ever set the flag, never clear it.
#[derive(Clone, Default)]
pub struct CancelToken(Arc<AtomicBool>);

impl CancelToken {
    /// Creates a fresh token that has not been cancelled.
    pub fn new() -> Self {
        // `Arc<AtomicBool>::default()` starts out as `false`.
        Self::default()
    }

    /// Marks the token (and all of its clones) as cancelled.
    pub fn cancel(&self) {
        let flag = &self.0;
        flag.store(true, Ordering::SeqCst);
    }

    /// Reports whether any clone of this token has been cancelled.
    pub fn is_cancelled(&self) -> bool {
        let Self(flag) = self;
        flag.load(Ordering::SeqCst)
    }
}

/// An emitter that buffers every item it receives in memory.
///
/// The buffer sits behind a [`Mutex`] so the collector can be shared between
/// threads; take the gathered items out with [`CollectEmitter::into_inner`].
pub struct CollectEmitter<T>(Mutex<Vec<T>>);

impl<T> CollectEmitter<T> {
    /// Creates an empty collector.
    pub fn new() -> Self {
        Self(Mutex::default())
    }

    /// Consumes the collector and returns the buffered items in push order.
    ///
    /// # Panics
    /// Panics if the inner mutex was poisoned.
    pub fn into_inner(self) -> Vec<T> {
        let Self(items) = self;
        items.into_inner().unwrap()
    }
}

// Written by hand rather than derived so that no `T: Default` bound is imposed.
impl<T> Default for CollectEmitter<T> {
    fn default() -> Self {
        Self::new()
    }
}

62impl<T: Send + 'static> Emitter<T> for CollectEmitter<T> {
63 fn emit(&self, item: T) -> Result<(), EmitError<T>> {
64 self.0.lock().unwrap().push(item);
65 Ok(())
66 }
67}
68
69pub struct SyncChannelEmitter<T> {
71 tx: crossbeam::channel::Sender<T>,
72}
73
74impl<T: Send + 'static> SyncChannelEmitter<T> {
75 pub fn new(tx: crossbeam::channel::Sender<T>) -> Self {
76 Self { tx }
77 }
78}
79
80impl<T: Send + 'static> Emitter<T> for SyncChannelEmitter<T> {
81 fn emit(&self, item: T) -> Result<(), EmitError<T>> {
82 match self.tx.try_send(item) {
83 Ok(_) => Ok(()),
84 Err(crossbeam::channel::TrySendError::Full(item)) => Err(EmitError {
85 source: anyhow::anyhow!("Channel is full"),
86 item,
87 }),
88 Err(crossbeam::channel::TrySendError::Disconnected(item)) => Err(EmitError {
89 source: anyhow::anyhow!("Channel is disconnected"),
90 item,
91 }),
92 }
93 }
94}
95
96#[derive(Clone, Deref)]
98pub struct OutStream<T> {
99 emitter: Arc<dyn Emitter<T>>,
100}
101
102impl<T> OutStream<T> {
103 pub fn new(emitter: Arc<dyn Emitter<T>>) -> Self {
105 Self { emitter }
106 }
107}