burn_central_runtime/inference/
core.rs

1use super::context::InferenceContext;
2use super::context::InferenceOutput;
3use super::error::InferenceError;
4use super::init::Init;
5use super::job::JobHandle;
6use super::streaming::{CancelToken, CollectEmitter, SyncChannelEmitter};
7use crate::inference::model::ModelHost;
8use crate::input::RoutineInput;
9use crate::routine::ExecutorRoutineWrapper;
10use crate::{InferenceJob, InferenceJobBuilder, IntoRoutine, Routine, StrappedInferenceJobBuilder};
11use burn::prelude::Backend;
12use std::marker::PhantomData;
13use std::sync::{Arc, Mutex};
14
/// Internal type alias for a routine trait object representing the user supplied inference handler.
///
/// The handler consumes the job input `I` inside an [`InferenceContext`]
/// (carrying the model `M`, output type `O`, and user state `S`). Its `Out`
/// is `()` because results are delivered through the context's emitter rather
/// than returned directly (see `Inference::run` / `Inference::spawn`).
type ArcInferenceHandler<B, M, I, O, S> =
    Arc<dyn Routine<InferenceContext<B, M, O, S>, In = I, Out = ()>>;
18
/// Inference instance wrapping a single model and a handler routine.
///
/// An `Inference` can create multiple jobs (sequentially or concurrently) without re-loading the model.
pub struct Inference<B: Backend, M, I, O, S = ()> {
    // Instance identifier; derived from the handler's type name in
    // `LoadedInferenceBuilder::build` and stamped onto each job's context.
    pub id: String,
    // Host owning the model; spawned once in `Inference::new` and accessed
    // by every job through `ModelHost::accessor`.
    model: ModelHost<M>,
    // User-supplied handler routine; `Arc` so `spawn` can run it on a
    // background thread while the instance keeps its own reference.
    handler: ArcInferenceHandler<B, M, I, O, S>,
}
27
28impl<B, M, I, O, S> Inference<B, M, I, O, S>
29where
30    B: Backend,
31    M: Send + 'static,
32    I: RoutineInput + 'static,
33    O: Send + 'static,
34    S: Send + Sync + 'static,
35{
36    pub(crate) fn new(id: String, handler: ArcInferenceHandler<B, M, I, O, S>, model: M) -> Self {
37        Self {
38            id,
39            model: ModelHost::spawn(model),
40            handler,
41        }
42    }
43
44    /// Start building an inference job for the given input payload.
45    pub fn infer(
46        &'_ self,
47        input: I::Inner<'static>,
48    ) -> StrappedInferenceJobBuilder<'_, B, M, I, O, S, super::builder::StateMissing> {
49        StrappedInferenceJobBuilder {
50            inference: self,
51            input: InferenceJobBuilder::new(input),
52        }
53    }
54
55    /// Execute the provided job synchronously and collect all emitted outputs.
56    pub fn run(&self, job: InferenceJob<B, I, S>) -> Result<Vec<O>, InferenceError> {
57        let collector = Arc::new(CollectEmitter::new());
58        let input = job.input;
59        let devices = job.devices;
60        let state = job.state;
61        {
62            let mut ctx = InferenceContext {
63                id: self.id.clone(),
64                devices: devices.into_iter().collect(),
65                model: self.model.accessor(),
66                emitter: collector.clone(),
67                cancel_token: CancelToken::new(),
68                state: Mutex::new(Some(state)),
69            };
70            self.handler
71                .run(input, &mut ctx)
72                .map_err(|e| InferenceError::HandlerExecutionFailed(e.into()))?;
73        }
74        let stream = Arc::try_unwrap(collector)
75            .map_err(|_| InferenceError::Unexpected("Failed to unwrap collector".to_string()))?
76            .into_inner();
77        Ok(stream)
78    }
79
80    /// Spawn the job on a background thread returning a [`JobHandle`]. Outputs can be read from the handle's stream.
81    pub fn spawn(&self, job: super::builder::InferenceJob<B, I, S>) -> JobHandle<O>
82    where
83        <I as RoutineInput>::Inner<'static>: Send,
84    {
85        let id = self.id.clone();
86        let input = job.input;
87        let devices = job.devices;
88        let state = job.state;
89        let (stream_tx, stream_rx) = crossbeam::channel::unbounded();
90        let cancel_token = CancelToken::new();
91        let mut ctx = InferenceContext {
92            id: id.clone(),
93            devices: devices.into_iter().collect(),
94            model: self.model.accessor(),
95            emitter: Arc::new(SyncChannelEmitter::new(stream_tx)),
96            cancel_token: cancel_token.clone(),
97            state: Mutex::new(Some(state)),
98        };
99        let handler = self.handler.clone();
100        let join = std::thread::spawn(move || {
101            handler
102                .run(input, &mut ctx)
103                .map_err(|e| InferenceError::HandlerExecutionFailed(e.into()))
104        });
105        JobHandle::new(id, stream_rx, cancel_token, join)
106    }
107
108    /// Consume the inference instance and retrieve ownership of the underlying model.
109    pub fn into_model(self) -> M {
110        self.model.into_model()
111    }
112}
113
/// Entry point builder for an [`Inference`] instance.
///
/// Holds no data; the backend type `B` is tracked only via `PhantomData` so
/// that `init` knows which `B::Device` type to accept.
pub struct InferenceBuilder<B> {
    // Zero-sized marker tying the builder to backend `B`.
    phantom_data: PhantomData<B>,
}
118
119impl<B: Backend> Default for InferenceBuilder<B> {
120    fn default() -> Self {
121        Self::new()
122    }
123}
124
125impl<B: Backend> InferenceBuilder<B> {
126    /// Create a new inference builder.
127    pub fn new() -> Self {
128        Self {
129            phantom_data: Default::default(),
130        }
131    }
132
133    /// Initialize a model implementing [`Init`] from user artifacts / arguments + target device.
134    pub fn init<M, InitArgs>(
135        self,
136        args: &InitArgs,
137        device: &B::Device,
138    ) -> Result<LoadedInferenceBuilder<B, M>, M::Error>
139    where
140        M: Init<B, InitArgs>,
141        InitArgs: Send + 'static,
142    {
143        let model = M::init(args, device)?;
144        Ok(LoadedInferenceBuilder {
145            model,
146            phantom_data: Default::default(),
147        })
148    }
149
150    /// Provide an already constructed model instance (skips the [`Init`] flow).
151    pub fn with_model<M>(self, model: M) -> LoadedInferenceBuilder<B, M> {
152        LoadedInferenceBuilder {
153            model,
154            phantom_data: Default::default(),
155        }
156    }
157}
158
/// Builder returned after a model has been loaded or supplied ready for registering a handler.
pub struct LoadedInferenceBuilder<B: Backend, M> {
    // The fully constructed model; moved into the `Inference` by `build`.
    model: M,
    // Zero-sized marker keeping the backend type `B` for the `build` step.
    phantom_data: PhantomData<B>,
}
164
165impl<B, M> LoadedInferenceBuilder<B, M>
166where
167    B: Backend,
168    M: Send + 'static,
169{
170    /// Finalize the construction of an [`Inference`] by supplying a handler routine implementation.
171    pub fn build<F, I, O, RO, Marker, S>(self, handler: F) -> Inference<B, M, I, O, S>
172    where
173        F: IntoRoutine<InferenceContext<B, M, O, S>, I, RO, Marker>,
174        I: RoutineInput + 'static,
175        O: Send + 'static,
176        S: Send + Sync + 'static,
177        RO: InferenceOutput<B, M, O, S> + Sync + 'static,
178    {
179        Inference::new(
180            crate::type_name::fn_type_name::<F>(),
181            Arc::new(ExecutorRoutineWrapper::new(IntoRoutine::into_routine(
182                handler,
183            ))),
184            self.model,
185        )
186    }
187}