Skip to main content

burn_central_inference/
writer.rs

1use std::sync::{
2    Arc,
3    atomic::{AtomicBool, AtomicUsize, Ordering},
4};
5
6use crate::observer::{InferenceWriterObserver, InferenceWriterStats};
7
8/// Errors that can occur when writing to an inference channel.
9#[derive(Debug, thiserror::Error)]
10pub enum InferenceWriterError {
11    #[error("inference was cancelled")]
12    Cancelled,
13    #[error("unknown error: {0}")]
14    Unknown(Box<dyn std::error::Error + Send + Sync>),
15}
16
17/// Communication channel for an inference task, allowing the app to send outputs and errors back to the session.
18pub struct InferenceWriter<O> {
19    channel: Box<dyn InferenceWriterChannel<O>>,
20    instant: std::time::Instant,
21    observer: Option<Arc<dyn InferenceWriterObserver>>,
22    outputs: AtomicUsize,
23    errors: AtomicUsize,
24    cancelled: AtomicBool,
25    finished: AtomicBool,
26}
27
28impl<O> InferenceWriter<O> {
29    pub(crate) fn new(channel: Box<dyn InferenceWriterChannel<O>>) -> Self {
30        Self {
31            channel,
32            instant: std::time::Instant::now(),
33            observer: None,
34            outputs: AtomicUsize::new(0),
35            errors: AtomicUsize::new(0),
36            cancelled: AtomicBool::new(false),
37            finished: AtomicBool::new(false),
38        }
39    }
40
41    pub(crate) fn from_channel<C>(channel: C) -> Self
42    where
43        C: InferenceWriterChannel<O> + 'static,
44    {
45        Self::new(Box::new(channel))
46    }
47
48    pub fn with_observer(mut self, observer: Arc<dyn InferenceWriterObserver>) -> Self {
49        self.observer = Some(observer);
50        self
51    }
52
53    /// Respond with an output item. This can be called multiple times to emit multiple items.
54    pub fn write(&self, output: O) -> Result<(), InferenceWriterError> {
55        match self.channel.write(output) {
56            Ok(()) => {
57                self.outputs.fetch_add(1, Ordering::Relaxed);
58                if let Some(ref observer) = self.observer {
59                    observer.on_write();
60                }
61                Ok(())
62            }
63            Err(err) => {
64                if matches!(&err, InferenceWriterError::Cancelled) {
65                    self.cancelled.store(true, Ordering::Release);
66                    if let Some(ref observer) = self.observer {
67                        observer.on_cancelled();
68                    }
69                }
70                Err(err)
71            }
72        }
73    }
74
75    /// Signal an error on the inference.
76    pub fn error<E>(&self, error: E) -> Result<(), InferenceWriterError>
77    where
78        E: Into<Box<dyn std::error::Error + Send + Sync>>,
79    {
80        match self.channel.error(error.into()) {
81            Ok(()) => {
82                self.errors.fetch_add(1, Ordering::Relaxed);
83                if let Some(ref observer) = self.observer {
84                    observer.on_error();
85                }
86                Ok(())
87            }
88            Err(err) => {
89                if matches!(&err, InferenceWriterError::Cancelled) {
90                    self.cancelled.store(true, Ordering::Release);
91                    if let Some(ref observer) = self.observer {
92                        observer.on_cancelled();
93                    }
94                }
95                Err(err)
96            }
97        }
98    }
99
100    fn finish(&self) {
101        let duration = self.instant.elapsed();
102        self.channel.finish(duration);
103
104        if self.finished.swap(true, Ordering::AcqRel) {
105            return;
106        }
107
108        if let Some(ref observer) = self.observer {
109            observer.on_finish(&InferenceWriterStats {
110                duration,
111                outputs: self.outputs.load(Ordering::Acquire),
112                errors: self.errors.load(Ordering::Acquire),
113                cancelled: self.cancelled.load(Ordering::Acquire),
114            });
115        }
116    }
117}
118
119/// When the `InferenceWriter` is dropped, it signals that the inference has finished, allowing the channel to perform any necessary cleanup or finalization.
120impl<O> Drop for InferenceWriter<O> {
121    fn drop(&mut self) {
122        self.finish();
123    }
124}
125
126/// Trait representing an inference task that can be executed with a given input and a writer for outputs.
127/// The inference implementation is responsible for writing outputs and errors to the provided writer, which will be sent back to the session.
128pub trait InferenceWriterChannel<O> {
129    /// Write an output item to the channel. This can be called multiple times to emit multiple items.
130    fn write(&self, output: O) -> Result<(), InferenceWriterError>;
131    /// Signal an error on the inference, which will be sent back to the session.
132    fn error(
133        &self,
134        error: Box<dyn std::error::Error + Send + Sync>,
135    ) -> Result<(), InferenceWriterError>;
136    /// Called when the `InferenceWriter` is dropped, allowing the channel to perform any necessary cleanup or finalization.
137    fn finish(&self, duration: std::time::Duration);
138}