burn_central_runtime/inference/
builder.rs

1use super::core::Inference;
2use super::job::JobHandle;
3use crate::input::RoutineInput;
4use burn::prelude::Backend;
5use std::marker::PhantomData;
6
7/// Builder returned by [`Inference::infer`] allowing configuration (devices, state) before
8/// executing the job via [`StrappedInferenceJobBuilder::run`] or spawning via [`StrappedInferenceJobBuilder::spawn`].
9pub struct StrappedInferenceJobBuilder<'a, B: Backend, M, I: RoutineInput, O, S, Flag> {
10    pub(crate) inference: &'a Inference<B, M, I, O, S>,
11    pub(crate) input: InferenceJobBuilder<B, I, S, Flag>,
12}
13
14impl<'a, B, M, I, O, S, Flag> StrappedInferenceJobBuilder<'a, B, M, I, O, S, Flag>
15where
16    B: Backend,
17    M: Send + 'static,
18    I: RoutineInput + 'static,
19    O: Send + 'static,
20    S: Send + Sync + 'static,
21{
22    /// Specify the devices to be exposed to the handler (order preserved; first often primary).
23    pub fn with_devices(mut self, devices: impl IntoIterator<Item = B::Device>) -> Self {
24        self.input = self.input.with_devices(devices);
25        self
26    }
27}
28
29impl<'a, B, M, I, O, S> StrappedInferenceJobBuilder<'a, B, M, I, O, S, StateMissing>
30where
31    B: Backend,
32    M: Send + 'static,
33    I: RoutineInput + 'static,
34    O: Send + 'static,
35    S: Send + Sync + 'static,
36{
37    /// Provide state to the handler. Consumed exactly once during handler execution.
38    pub fn with_state(
39        self,
40        state: S,
41    ) -> StrappedInferenceJobBuilder<'a, B, M, I, O, S, StateProvided> {
42        StrappedInferenceJobBuilder {
43            inference: self.inference,
44            input: self.input.with_state(state),
45        }
46    }
47}
48
49/// Internal job builder accumulating input + devices + optional state before conversion to an executable job.
50pub struct InferenceJobBuilder<B: Backend, I: RoutineInput, S, Flag> {
51    pub(crate) input: <I as RoutineInput>::Inner<'static>,
52    pub(crate) devices: Vec<B::Device>,
53    pub(crate) state: Option<S>,
54    _flag: PhantomData<Flag>,
55}
56
57impl<B, I, S, Flag> InferenceJobBuilder<B, I, S, Flag>
58where
59    B: Backend,
60    I: RoutineInput + 'static,
61    S: Send + Sync + 'static,
62{
63    /// Create a new job builder with the provided routine input payload.
64    pub fn new(input: <I as RoutineInput>::Inner<'static>) -> Self {
65        Self {
66            input,
67            devices: Vec::new(),
68            state: None,
69            _flag: PhantomData,
70        }
71    }
72
73    /// Set the devices collection for this job.
74    pub fn with_devices(mut self, devices: impl IntoIterator<Item = B::Device>) -> Self {
75        self.devices = devices.into_iter().collect();
76        self
77    }
78}
79
/// Marker type indicating the job state has not been supplied.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct StateMissing;

/// Marker type indicating the job state has been supplied.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct StateProvided;
84
85impl<B, I, S> InferenceJobBuilder<B, I, S, StateMissing>
86where
87    B: Backend,
88    I: RoutineInput + 'static,
89    S: Send + Sync + 'static,
90{
91    /// Attach state to the job; transitions the builder into the `StateProvided` phase.
92    pub fn with_state(self, state: S) -> InferenceJobBuilder<B, I, S, StateProvided> {
93        InferenceJobBuilder {
94            input: self.input,
95            devices: self.devices,
96            state: Some(state),
97            _flag: PhantomData,
98        }
99    }
100}
101
102impl<B, I, S> InferenceJobBuilder<B, I, S, StateProvided>
103where
104    B: Backend,
105    I: RoutineInput + 'static,
106    S: Send + Sync + 'static,
107{
108    /// Finalize the builder into an [`InferenceJob`]. Panics if state missing (by design of type-state pattern).
109    pub fn build(self) -> InferenceJob<B, I, S> {
110        InferenceJob {
111            input: self.input,
112            devices: self.devices,
113            state: self.state.expect("state must be set"),
114        }
115    }
116}
117
118impl<'a, B, M, I, O> StrappedInferenceJobBuilder<'a, B, M, I, O, (), StateMissing>
119where
120    B: Backend,
121    M: Send + 'static,
122    I: RoutineInput + 'static,
123    O: Send + 'static,
124{
125    /// Spawn the inference job on a background thread returning a [`JobHandle`].
126    pub fn spawn(self) -> JobHandle<O>
127    where
128        <I as RoutineInput>::Inner<'static>: Send,
129    {
130        let job = InferenceJob {
131            input: self.input.input,
132            devices: self.input.devices,
133            state: (),
134        };
135        self.inference.spawn(job)
136    }
137
138    /// Run the inference job to completion collecting all outputs eagerly.
139    pub fn run(self) -> Result<Vec<O>, super::error::InferenceError> {
140        let job = InferenceJob {
141            input: self.input.input,
142            devices: self.input.devices,
143            state: (),
144        };
145        self.inference.run(job)
146    }
147}
148
149impl<'a, B, M, I, O, S> StrappedInferenceJobBuilder<'a, B, M, I, O, S, StateProvided>
150where
151    B: Backend,
152    M: Send + 'static,
153    I: RoutineInput + 'static,
154    O: Send + 'static,
155    S: Send + Sync + 'static,
156{
157    /// Spawn the inference job with provided user state.
158    pub fn spawn(self) -> JobHandle<O>
159    where
160        <I as RoutineInput>::Inner<'static>: Send,
161    {
162        let job = InferenceJob {
163            input: self.input.input,
164            devices: self.input.devices,
165            state: self.input.state.expect("state must be set"),
166        };
167        self.inference.spawn(job)
168    }
169
170    /// Run the inference job to completion (stateful variant) collecting all outputs.
171    pub fn run(self) -> Result<Vec<O>, super::error::InferenceError> {
172        let job = InferenceJob {
173            input: self.input.input,
174            devices: self.input.devices,
175            state: self.input.state.expect("state must be set"),
176        };
177        self.inference.run(job)
178    }
179}
180
181/// Concrete job containing fully specified execution parameters passed to the runtime.
182pub struct InferenceJob<B: Backend, I: RoutineInput, S> {
183    pub(crate) input: <I as RoutineInput>::Inner<'static>,
184    pub(crate) devices: Vec<B::Device>,
185    pub(crate) state: S,
186}
187
188impl<B, I, S> InferenceJob<B, I, S>
189where
190    B: Backend,
191    I: RoutineInput + 'static,
192    S: Send + Sync + 'static,
193{
194    /// Create a new builder for an inference job for the given input payload.
195    pub fn builder(
196        input: <I as RoutineInput>::Inner<'static>,
197    ) -> InferenceJobBuilder<B, I, S, StateMissing> {
198        InferenceJobBuilder::new(input)
199    }
200}