Skip to main content

burn_central_inference/
inference.rs

1use std::sync::Arc;
2
3use crate::{InferenceWriter, InferenceWriterChannel};
4
5// TODO: maybe this should require send + sync
6pub trait Inference {
7    type Input;
8    type Output;
9
10    fn infer(&self, input: Self::Input, writer: InferenceWriter<Self::Output>);
11}
12
13pub struct InferenceWrapper<I, O> {
14    inner: Arc<dyn Inference<Input = I, Output = O> + Send + Sync>,
15}
16
17impl<I, O> Clone for InferenceWrapper<I, O> {
18    fn clone(&self) -> Self {
19        Self {
20            inner: Arc::clone(&self.inner),
21        }
22    }
23}
24
25impl<I, O> InferenceWrapper<I, O> {
26    fn new<T>(inference: T) -> Self
27    where
28        T: Inference<Input = I, Output = O> + Send + Sync + 'static,
29    {
30        Self {
31            inner: Arc::new(inference),
32        }
33    }
34}
35
36impl<T, I, O> From<T> for InferenceWrapper<I, O>
37where
38    T: Inference<Input = I, Output = O> + Send + Sync + 'static,
39{
40    fn from(inference: T) -> Self {
41        Self::new(inference)
42    }
43}
44
45impl<I, O> InferenceWrapper<I, O> {
46    pub fn infer<T: InferenceWriterChannel<O> + 'static>(&self, input: I, writer: T) {
47        self.inner
48            .infer(input, InferenceWriter::from_channel(writer));
49    }
50}