burn_central_inference/
writer.rs1use std::sync::{
2 Arc,
3 atomic::{AtomicBool, AtomicUsize, Ordering},
4};
5
6use crate::observer::{InferenceWriterObserver, InferenceWriterStats};
7
8#[derive(Debug, thiserror::Error)]
10pub enum InferenceWriterError {
11 #[error("inference was cancelled")]
12 Cancelled,
13 #[error("unknown error: {0}")]
14 Unknown(Box<dyn std::error::Error + Send + Sync>),
15}
16
17pub struct InferenceWriter<O> {
19 channel: Box<dyn InferenceWriterChannel<O>>,
20 instant: std::time::Instant,
21 observer: Option<Arc<dyn InferenceWriterObserver>>,
22 outputs: AtomicUsize,
23 errors: AtomicUsize,
24 cancelled: AtomicBool,
25 finished: AtomicBool,
26}
27
28impl<O> InferenceWriter<O> {
29 pub(crate) fn new(channel: Box<dyn InferenceWriterChannel<O>>) -> Self {
30 Self {
31 channel,
32 instant: std::time::Instant::now(),
33 observer: None,
34 outputs: AtomicUsize::new(0),
35 errors: AtomicUsize::new(0),
36 cancelled: AtomicBool::new(false),
37 finished: AtomicBool::new(false),
38 }
39 }
40
41 pub(crate) fn from_channel<C>(channel: C) -> Self
42 where
43 C: InferenceWriterChannel<O> + 'static,
44 {
45 Self::new(Box::new(channel))
46 }
47
48 pub fn with_observer(mut self, observer: Arc<dyn InferenceWriterObserver>) -> Self {
49 self.observer = Some(observer);
50 self
51 }
52
53 pub fn write(&self, output: O) -> Result<(), InferenceWriterError> {
55 match self.channel.write(output) {
56 Ok(()) => {
57 self.outputs.fetch_add(1, Ordering::Relaxed);
58 if let Some(ref observer) = self.observer {
59 observer.on_write();
60 }
61 Ok(())
62 }
63 Err(err) => {
64 if matches!(&err, InferenceWriterError::Cancelled) {
65 self.cancelled.store(true, Ordering::Release);
66 if let Some(ref observer) = self.observer {
67 observer.on_cancelled();
68 }
69 }
70 Err(err)
71 }
72 }
73 }
74
75 pub fn error<E>(&self, error: E) -> Result<(), InferenceWriterError>
77 where
78 E: Into<Box<dyn std::error::Error + Send + Sync>>,
79 {
80 match self.channel.error(error.into()) {
81 Ok(()) => {
82 self.errors.fetch_add(1, Ordering::Relaxed);
83 if let Some(ref observer) = self.observer {
84 observer.on_error();
85 }
86 Ok(())
87 }
88 Err(err) => {
89 if matches!(&err, InferenceWriterError::Cancelled) {
90 self.cancelled.store(true, Ordering::Release);
91 if let Some(ref observer) = self.observer {
92 observer.on_cancelled();
93 }
94 }
95 Err(err)
96 }
97 }
98 }
99
100 fn finish(&self) {
101 let duration = self.instant.elapsed();
102 self.channel.finish(duration);
103
104 if self.finished.swap(true, Ordering::AcqRel) {
105 return;
106 }
107
108 if let Some(ref observer) = self.observer {
109 observer.on_finish(&InferenceWriterStats {
110 duration,
111 outputs: self.outputs.load(Ordering::Acquire),
112 errors: self.errors.load(Ordering::Acquire),
113 cancelled: self.cancelled.load(Ordering::Acquire),
114 });
115 }
116 }
117}
118
119impl<O> Drop for InferenceWriter<O> {
121 fn drop(&mut self) {
122 self.finish();
123 }
124}
125
126pub trait InferenceWriterChannel<O> {
129 fn write(&self, output: O) -> Result<(), InferenceWriterError>;
131 fn error(
133 &self,
134 error: Box<dyn std::error::Error + Send + Sync>,
135 ) -> Result<(), InferenceWriterError>;
136 fn finish(&self, duration: std::time::Duration);
138}