1use super::core::Inference;
2use super::job::JobHandle;
3use crate::input::RoutineInput;
4use burn::prelude::Backend;
5use std::marker::PhantomData;
6
7pub struct StrappedInferenceJobBuilder<'a, B: Backend, M, I: RoutineInput, O, S, Flag> {
10 pub(crate) inference: &'a Inference<B, M, I, O, S>,
11 pub(crate) input: InferenceJobBuilder<B, I, S, Flag>,
12}
13
14impl<'a, B, M, I, O, S, Flag> StrappedInferenceJobBuilder<'a, B, M, I, O, S, Flag>
15where
16 B: Backend,
17 M: Send + 'static,
18 I: RoutineInput + 'static,
19 O: Send + 'static,
20 S: Send + Sync + 'static,
21{
22 pub fn with_devices(mut self, devices: impl IntoIterator<Item = B::Device>) -> Self {
24 self.input = self.input.with_devices(devices);
25 self
26 }
27}
28
29impl<'a, B, M, I, O, S> StrappedInferenceJobBuilder<'a, B, M, I, O, S, StateMissing>
30where
31 B: Backend,
32 M: Send + 'static,
33 I: RoutineInput + 'static,
34 O: Send + 'static,
35 S: Send + Sync + 'static,
36{
37 pub fn with_state(
39 self,
40 state: S,
41 ) -> StrappedInferenceJobBuilder<'a, B, M, I, O, S, StateProvided> {
42 StrappedInferenceJobBuilder {
43 inference: self.inference,
44 input: self.input.with_state(state),
45 }
46 }
47}
48
49pub struct InferenceJobBuilder<B: Backend, I: RoutineInput, S, Flag> {
51 pub(crate) input: <I as RoutineInput>::Inner<'static>,
52 pub(crate) devices: Vec<B::Device>,
53 pub(crate) state: Option<S>,
54 _flag: PhantomData<Flag>,
55}
56
57impl<B, I, S, Flag> InferenceJobBuilder<B, I, S, Flag>
58where
59 B: Backend,
60 I: RoutineInput + 'static,
61 S: Send + Sync + 'static,
62{
63 pub fn new(input: <I as RoutineInput>::Inner<'static>) -> Self {
65 Self {
66 input,
67 devices: Vec::new(),
68 state: None,
69 _flag: PhantomData,
70 }
71 }
72
73 pub fn with_devices(mut self, devices: impl IntoIterator<Item = B::Device>) -> Self {
75 self.devices = devices.into_iter().collect();
76 self
77 }
78}
79
/// Type-state marker: the job state has not been supplied yet.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct StateMissing;

/// Type-state marker: the job state has been supplied.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct StateProvided;
84
85impl<B, I, S> InferenceJobBuilder<B, I, S, StateMissing>
86where
87 B: Backend,
88 I: RoutineInput + 'static,
89 S: Send + Sync + 'static,
90{
91 pub fn with_state(self, state: S) -> InferenceJobBuilder<B, I, S, StateProvided> {
93 InferenceJobBuilder {
94 input: self.input,
95 devices: self.devices,
96 state: Some(state),
97 _flag: PhantomData,
98 }
99 }
100}
101
102impl<B, I, S> InferenceJobBuilder<B, I, S, StateProvided>
103where
104 B: Backend,
105 I: RoutineInput + 'static,
106 S: Send + Sync + 'static,
107{
108 pub fn build(self) -> InferenceJob<B, I, S> {
110 InferenceJob {
111 input: self.input,
112 devices: self.devices,
113 state: self.state.expect("state must be set"),
114 }
115 }
116}
117
118impl<'a, B, M, I, O> StrappedInferenceJobBuilder<'a, B, M, I, O, (), StateMissing>
119where
120 B: Backend,
121 M: Send + 'static,
122 I: RoutineInput + 'static,
123 O: Send + 'static,
124{
125 pub fn spawn(self) -> JobHandle<O>
127 where
128 <I as RoutineInput>::Inner<'static>: Send,
129 {
130 let job = InferenceJob {
131 input: self.input.input,
132 devices: self.input.devices,
133 state: (),
134 };
135 self.inference.spawn(job)
136 }
137
138 pub fn run(self) -> Result<Vec<O>, super::error::InferenceError> {
140 let job = InferenceJob {
141 input: self.input.input,
142 devices: self.input.devices,
143 state: (),
144 };
145 self.inference.run(job)
146 }
147}
148
149impl<'a, B, M, I, O, S> StrappedInferenceJobBuilder<'a, B, M, I, O, S, StateProvided>
150where
151 B: Backend,
152 M: Send + 'static,
153 I: RoutineInput + 'static,
154 O: Send + 'static,
155 S: Send + Sync + 'static,
156{
157 pub fn spawn(self) -> JobHandle<O>
159 where
160 <I as RoutineInput>::Inner<'static>: Send,
161 {
162 let job = InferenceJob {
163 input: self.input.input,
164 devices: self.input.devices,
165 state: self.input.state.expect("state must be set"),
166 };
167 self.inference.spawn(job)
168 }
169
170 pub fn run(self) -> Result<Vec<O>, super::error::InferenceError> {
172 let job = InferenceJob {
173 input: self.input.input,
174 devices: self.input.devices,
175 state: self.input.state.expect("state must be set"),
176 };
177 self.inference.run(job)
178 }
179}
180
181pub struct InferenceJob<B: Backend, I: RoutineInput, S> {
183 pub(crate) input: <I as RoutineInput>::Inner<'static>,
184 pub(crate) devices: Vec<B::Device>,
185 pub(crate) state: S,
186}
187
188impl<B, I, S> InferenceJob<B, I, S>
189where
190 B: Backend,
191 I: RoutineInput + 'static,
192 S: Send + Sync + 'static,
193{
194 pub fn builder(
196 input: <I as RoutineInput>::Inner<'static>,
197 ) -> InferenceJobBuilder<B, I, S, StateMissing> {
198 InferenceJobBuilder::new(input)
199 }
200}