// burn_central_runtime/inference/core.rs
use super::context::InferenceContext;
use super::context::InferenceOutput;
use super::error::InferenceError;
use super::init::Init;
use super::job::JobHandle;
use super::streaming::{CancelToken, CollectEmitter, SyncChannelEmitter};
use crate::inference::model::ModelHost;
use crate::input::RoutineInput;
use crate::routine::ExecutorRoutineWrapper;
use crate::{InferenceJob, InferenceJobBuilder, IntoRoutine, Routine, StrappedInferenceJobBuilder};
use burn::prelude::Backend;
use std::marker::PhantomData;
use std::sync::{Arc, Mutex};

/// Shared, thread-safe handle to the user-provided inference routine.
/// The routine consumes input `I` inside an [`InferenceContext`] and reports
/// results through the context's emitter rather than a return value (`Out = ()`).
type ArcInferenceHandler<B, M, I, O, S> =
    Arc<dyn Routine<InferenceContext<B, M, O, S>, In = I, Out = ()>>;
18
/// A configured inference endpoint: a hosted model plus the handler routine
/// that runs against it.
///
/// `B` is the burn backend, `M` the model type, `I` the routine input,
/// `O` the emitted output item, and `S` optional per-job state (defaults to `()`).
pub struct Inference<B: Backend, M, I, O, S = ()> {
    /// Endpoint identifier (callers derive it from the handler's type name).
    pub id: String,
    /// Host owning the model; presumably runs it on its own thread
    /// (`ModelHost::spawn`) — TODO confirm threading model.
    model: ModelHost<M>,
    /// The user handler, ref-counted so spawned jobs can run it on worker threads.
    handler: ArcInferenceHandler<B, M, I, O, S>,
}
27
28impl<B, M, I, O, S> Inference<B, M, I, O, S>
29where
30 B: Backend,
31 M: Send + 'static,
32 I: RoutineInput + 'static,
33 O: Send + 'static,
34 S: Send + Sync + 'static,
35{
36 pub(crate) fn new(id: String, handler: ArcInferenceHandler<B, M, I, O, S>, model: M) -> Self {
37 Self {
38 id,
39 model: ModelHost::spawn(model),
40 handler,
41 }
42 }
43
44 pub fn infer(
46 &'_ self,
47 input: I::Inner<'static>,
48 ) -> StrappedInferenceJobBuilder<'_, B, M, I, O, S, super::builder::StateMissing> {
49 StrappedInferenceJobBuilder {
50 inference: self,
51 input: InferenceJobBuilder::new(input),
52 }
53 }
54
55 pub fn run(&self, job: InferenceJob<B, I, S>) -> Result<Vec<O>, InferenceError> {
57 let collector = Arc::new(CollectEmitter::new());
58 let input = job.input;
59 let devices = job.devices;
60 let state = job.state;
61 {
62 let mut ctx = InferenceContext {
63 id: self.id.clone(),
64 devices: devices.into_iter().collect(),
65 model: self.model.accessor(),
66 emitter: collector.clone(),
67 cancel_token: CancelToken::new(),
68 state: Mutex::new(Some(state)),
69 };
70 self.handler
71 .run(input, &mut ctx)
72 .map_err(|e| InferenceError::HandlerExecutionFailed(e.into()))?;
73 }
74 let stream = Arc::try_unwrap(collector)
75 .map_err(|_| InferenceError::Unexpected("Failed to unwrap collector".to_string()))?
76 .into_inner();
77 Ok(stream)
78 }
79
80 pub fn spawn(&self, job: super::builder::InferenceJob<B, I, S>) -> JobHandle<O>
82 where
83 <I as RoutineInput>::Inner<'static>: Send,
84 {
85 let id = self.id.clone();
86 let input = job.input;
87 let devices = job.devices;
88 let state = job.state;
89 let (stream_tx, stream_rx) = crossbeam::channel::unbounded();
90 let cancel_token = CancelToken::new();
91 let mut ctx = InferenceContext {
92 id: id.clone(),
93 devices: devices.into_iter().collect(),
94 model: self.model.accessor(),
95 emitter: Arc::new(SyncChannelEmitter::new(stream_tx)),
96 cancel_token: cancel_token.clone(),
97 state: Mutex::new(Some(state)),
98 };
99 let handler = self.handler.clone();
100 let join = std::thread::spawn(move || {
101 handler
102 .run(input, &mut ctx)
103 .map_err(|e| InferenceError::HandlerExecutionFailed(e.into()))
104 });
105 JobHandle::new(id, stream_rx, cancel_token, join)
106 }
107
108 pub fn into_model(self) -> M {
110 self.model.into_model()
111 }
112}
113
/// Entry point for constructing an [`Inference`]; fixes the backend `B`
/// before a model is attached.
pub struct InferenceBuilder<B> {
    // Zero-sized: only carries the backend type parameter.
    phantom_data: PhantomData<B>,
}
118
119impl<B: Backend> Default for InferenceBuilder<B> {
120 fn default() -> Self {
121 Self::new()
122 }
123}
124
125impl<B: Backend> InferenceBuilder<B> {
126 pub fn new() -> Self {
128 Self {
129 phantom_data: Default::default(),
130 }
131 }
132
133 pub fn init<M, InitArgs>(
135 self,
136 args: &InitArgs,
137 device: &B::Device,
138 ) -> Result<LoadedInferenceBuilder<B, M>, M::Error>
139 where
140 M: Init<B, InitArgs>,
141 InitArgs: Send + 'static,
142 {
143 let model = M::init(args, device)?;
144 Ok(LoadedInferenceBuilder {
145 model,
146 phantom_data: Default::default(),
147 })
148 }
149
150 pub fn with_model<M>(self, model: M) -> LoadedInferenceBuilder<B, M> {
152 LoadedInferenceBuilder {
153 model,
154 phantom_data: Default::default(),
155 }
156 }
157}
158
/// Builder stage holding an initialized model, awaiting a handler routine
/// (see `build`).
pub struct LoadedInferenceBuilder<B: Backend, M> {
    model: M,
    // Carries the backend type parameter; no runtime data.
    phantom_data: PhantomData<B>,
}
164
165impl<B, M> LoadedInferenceBuilder<B, M>
166where
167 B: Backend,
168 M: Send + 'static,
169{
170 pub fn build<F, I, O, RO, Marker, S>(self, handler: F) -> Inference<B, M, I, O, S>
172 where
173 F: IntoRoutine<InferenceContext<B, M, O, S>, I, RO, Marker>,
174 I: RoutineInput + 'static,
175 O: Send + 'static,
176 S: Send + Sync + 'static,
177 RO: InferenceOutput<B, M, O, S> + Sync + 'static,
178 {
179 Inference::new(
180 crate::type_name::fn_type_name::<F>(),
181 Arc::new(ExecutorRoutineWrapper::new(IntoRoutine::into_routine(
182 handler,
183 ))),
184 self.model,
185 )
186 }
187}