// burn_central_inference/inference.rs
use std::fmt;
use std::sync::Arc;

use crate::{InferenceWriter, InferenceWriterChannel};

5pub trait Inference {
7 type Input;
8 type Output;
9
10 fn infer(&self, input: Self::Input, writer: InferenceWriter<Self::Output>);
11}
12
13pub struct InferenceWrapper<I, O> {
14 inner: Arc<dyn Inference<Input = I, Output = O> + Send + Sync>,
15}
16
17impl<I, O> Clone for InferenceWrapper<I, O> {
18 fn clone(&self) -> Self {
19 Self {
20 inner: Arc::clone(&self.inner),
21 }
22 }
23}
24
25impl<I, O> InferenceWrapper<I, O> {
26 fn new<T>(inference: T) -> Self
27 where
28 T: Inference<Input = I, Output = O> + Send + Sync + 'static,
29 {
30 Self {
31 inner: Arc::new(inference),
32 }
33 }
34}
35
36impl<T, I, O> From<T> for InferenceWrapper<I, O>
37where
38 T: Inference<Input = I, Output = O> + Send + Sync + 'static,
39{
40 fn from(inference: T) -> Self {
41 Self::new(inference)
42 }
43}
44
45impl<I, O> InferenceWrapper<I, O> {
46 pub fn infer<T: InferenceWriterChannel<O> + 'static>(&self, input: I, writer: T) {
47 self.inner
48 .infer(input, InferenceWriter::from_channel(writer));
49 }
50}