use crate::compute::{BackendCapabilities, ComputeBackend, ComputeModel, ModelMetadata};
use crate::edge::Edge;
use crate::errors::{InferenceError, NodeError};
use crate::memory::PlacementAcceptance;
use crate::message::{payload::Payload, Message};
use crate::node::{Node, NodeCapabilities, NodeKind, ProcessResult, StepContext, StepResult};
use crate::policy::NodePolicy;
use crate::prelude::{MemoryManager, PlatformClock, Telemetry};

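/// Converts a backend [`InferenceError`] into a generic [`NodeError`],
/// preserving the original error code.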
#[inline]
fn map_inference_err(e: InferenceError) -> NodeError {
    NodeError::execution_failed().with_code(*e.code())
}

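/// A graph node that runs inference through a compute backend.
///
/// The node owns a model loaded from `B`, exposes exactly one input and one
/// output port, and reuses a single scratch output buffer across messages.
/// `MAX_BATCH` is a compile-time cap on how many messages a single
/// `step_batch` call may process.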
pub struct InferenceModel<B, InP, OutP, const MAX_BATCH: usize>
where
    B: ComputeBackend<InP, OutP>,
    InP: Payload,
    OutP: Payload + Default + Copy,
{
    /// Backend that loaded `model`; retained for ownership but not
    /// referenced again after construction.
    #[allow(dead_code)]
    backend: B,
    /// Model instance that performs the actual inference.
    model: B::Model,
    /// Backend capabilities captured at construction time.
    backend_caps: BackendCapabilities,
    /// Metadata reported by the loaded model.
    model_meta: ModelMetadata,

    /// Capabilities this node advertises to the graph.
    node_caps: NodeCapabilities,
    /// Scheduling, batching, and budget policy for this node.
    node_policy: NodePolicy,
    /// Placement acceptance for the single input port.
    input_acceptance: [PlacementAcceptance; 1],
    /// Placement acceptance for the single output port.
    output_acceptance: [PlacementAcceptance; 1],

    /// Reusable output buffer written by `infer_one` for each message.
    scratch_out: OutP,

    _pd: core::marker::PhantomData<InP>,
}

impl<B, InP, OutP, const MAX_BATCH: usize> InferenceModel<B, InP, OutP, MAX_BATCH>
where
    B: ComputeBackend<InP, OutP>,
    InP: Payload,
    OutP: Payload + Default + Copy,
{
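    /// Loads a model from `desc` on `backend` and wraps it as a graph node.
    ///
    /// Backend capabilities and model metadata are captured once here so
    /// later queries do not have to go back to the backend.
    ///
    /// A minimal usage sketch; `MyBackend`, `MY_DESC`, and the `Default`
    /// calls below are hypothetical stand-ins, not APIs defined in this
    /// crate:
    ///
    /// ```ignore
    /// let node = InferenceModel::<_, InFrame, OutScores, 8>::new(
    ///     MyBackend::new(),
    ///     MY_DESC,
    ///     NodePolicy::default(),
    ///     NodeCapabilities::default(),
    ///     [PlacementAcceptance::default()],
    ///     [PlacementAcceptance::default()],
    /// )?;
    /// ```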
    pub fn new<'desc>(
        backend: B,
        desc: B::ModelDescriptor<'desc>,
        node_policy: NodePolicy,
        node_caps: NodeCapabilities,
        input_acceptance: [PlacementAcceptance; 1],
        output_acceptance: [PlacementAcceptance; 1],
    ) -> Result<Self, B::Error> {
        let backend_caps = backend.capabilities();
        let model = backend.load_model(desc)?;
        let model_meta = model.metadata();

        Ok(Self {
            backend,
            model,
            backend_caps,
            model_meta,
            node_caps,
            node_policy,
            input_acceptance,
            output_acceptance,
            scratch_out: OutP::default(),
            _pd: core::marker::PhantomData,
        })
    }

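    /// Returns the capabilities the backend reported at construction.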
    #[inline]
    pub fn backend_capabilities(&self) -> BackendCapabilities {
        self.backend_caps
    }

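    /// Returns the metadata the model reported at construction.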
    #[inline]
    pub fn model_metadata(&self) -> ModelMetadata {
        self.model_meta
    }
}

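// Single-input, single-output node: each popped message is run through the
// model and the result is emitted as one output message.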
impl<B, InP, OutP, const MAX_BATCH: usize> Node<1, 1, InP, OutP>
    for InferenceModel<B, InP, OutP, MAX_BATCH>
where
    B: ComputeBackend<InP, OutP>,
    InP: Payload + Default + Copy,
    OutP: Payload + Default + Copy,
{
    #[inline]
    fn describe_capabilities(&self) -> NodeCapabilities {
        self.node_caps
    }

    #[inline]
    fn input_acceptance(&self) -> [PlacementAcceptance; 1] {
        self.input_acceptance
    }

    #[inline]
    fn output_acceptance(&self) -> [PlacementAcceptance; 1] {
        self.output_acceptance
    }

    #[inline]
    fn policy(&self) -> NodePolicy {
        self.node_policy
    }

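    /// Test- and bench-only hook for overriding the node policy.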
    #[cfg(any(test, feature = "bench"))]
    fn set_policy(&mut self, policy: NodePolicy) {
        self.node_policy = policy;
    }

    #[inline]
    fn node_kind(&self) -> NodeKind {
        NodeKind::Model
    }

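    /// Nothing to do at initialization; model setup is deferred to `start`.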
    #[inline]
    fn initialize<C, T>(&mut self, _clock: &C, _telemetry: &mut T) -> Result<(), NodeError>
    where
        T: Telemetry,
    {
        Ok(())
    }

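    /// Initializes the underlying model when the node starts.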
    #[inline]
    fn start<C, T>(&mut self, _clock: &C, _telemetry: &mut T) -> Result<(), NodeError>
    where
        T: Telemetry,
    {
        self.model.init().map_err(map_inference_err)
    }

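    /// Runs a single inference: the input payload is fed to the model, the
    /// result is written into the scratch buffer, and a new output message
    /// is built that reuses the incoming header.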
    #[inline]
    fn process_message<C>(
        &mut self,
        msg: &Message<InP>,
        _sys_clock: &C,
    ) -> Result<ProcessResult<OutP>, NodeError>
    where
        C: PlatformClock + Sized,
    {
        let inp: &InP = msg.payload();
        self.model
            .infer_one(inp, &mut self.scratch_out)
            .map_err(map_inference_err)?;

        let hdr = *msg.header();
        // `take` moves the result out and resets the scratch buffer to its
        // default value for the next call.
        let out_msg = Message::new(hdr, core::mem::take(&mut self.scratch_out));

        Ok(ProcessResult::Output(out_msg))
    }

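    /// Pops one message from input port 0 and processes it.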
    #[inline]
    fn step<'g, 't, 'c, InQ, OutQ, InM, OutM, C, Tel>(
        &mut self,
        ctx: &mut StepContext<'g, 't, 'c, 1, 1, InP, OutP, InQ, OutQ, InM, OutM, C, Tel>,
    ) -> Result<StepResult, NodeError>
    where
        InQ: Edge,
        OutQ: Edge,
        InM: MemoryManager<InP>,
        OutM: MemoryManager<OutP>,
        C: PlatformClock + Sized,
        Tel: Telemetry + Sized,
    {
        // Copy the clock handle out first so the closure does not borrow
        // `ctx` while `pop_and_process` holds it mutably (same pattern as
        // `step_batch` below).
        let clock = ctx.clock;
        ctx.pop_and_process(0, |msg| self.process_message(msg, clock))
    }

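    /// Pops and processes up to `nmax` messages in one call, where `nmax` is
    /// the smallest of the policy's fixed batch size, the backend's batch
    /// limit, and `MAX_BATCH`. Falls back to `step` when the effective batch
    /// size is 1.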
    #[inline]
    fn step_batch<'g, 't, 'c, InQ, OutQ, InM, OutM, C, Tel>(
        &mut self,
        ctx: &mut StepContext<'g, 't, 'c, 1, 1, InP, OutP, InQ, OutQ, InM, OutM, C, Tel>,
    ) -> Result<StepResult, NodeError>
    where
        InQ: Edge,
        OutQ: Edge,
        InM: MemoryManager<InP>,
        OutM: MemoryManager<OutP>,
        C: PlatformClock + Sized,
        Tel: Telemetry + Sized,
    {
        // Effective batch size: the policy's request, clamped by what the
        // backend supports and by this node's compile-time cap.
        let want = self.node_policy.batching().fixed_n().unwrap_or(1);
        let backend_cap = self.backend_caps.max_batch().unwrap_or(usize::MAX);
        let nmax = core::cmp::min(core::cmp::min(want, backend_cap), MAX_BATCH);

        if nmax <= 1 {
            return self.step(ctx);
        }

        // Copy these out so the closure does not borrow `ctx` or fields of
        // `self` while `pop_batch_and_process` holds `ctx` mutably.
        let node_policy = self.node_policy;
        let clock = ctx.clock;

        ctx.pop_batch_and_process(0, nmax, &node_policy, |msg| {
            self.process_message(msg, clock)
        })
    }

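    /// On a watchdog timeout, yields until the policy's watchdog backoff has
    /// elapsed; with no backoff configured, yields until the current tick so
    /// the node is immediately eligible to run again.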
    #[inline]
    fn on_watchdog_timeout<C, Tel>(
        &mut self,
        clock: &C,
        _telemetry: &mut Tel,
    ) -> Result<StepResult, NodeError>
    where
        C: PlatformClock + Sized,
        Tel: Telemetry,
    {
        if let Some(backoff) = self.node_policy.budget().watchdog_ticks() {
            let until = clock.now_ticks().saturating_add(*backoff);
            Ok(StepResult::YieldUntil(until))
        } else {
            Ok(StepResult::YieldUntil(clock.now_ticks()))
        }
    }

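    /// Drains any in-flight work and resets the model when the node stops.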
    #[inline]
    fn stop<C, Tel>(&mut self, _clock: &C, _telemetry: &mut Tel) -> Result<(), NodeError>
    where
        Tel: Telemetry,
    {
        self.model.drain().map_err(map_inference_err)?;
        self.model.reset().map_err(map_inference_err)
    }
}