limen_core/node/model.rs

//! Inference node (1×in → 1×out) that runs a generic `ComputeBackend` model.
//!
//! # Design
//! - **No dynamic dispatch**: backend and model are monomorphized by generics.
//! - **No `unsafe`** in the hot path.
//! - **Batching**:
//!   - `no_std` / no-`alloc`: stack-bounded batching up to `MAX_BATCH`.
//!   - `alloc`: uses `Vec` for flexible batch sizing.
//! - **Queues/telemetry** are accessed only via `StepContext`.
//! - **Zero-copy** preferences are expressed through `PlacementAcceptance`.
//!
//! This node delegates inference to the model (`infer_one` / `infer_batch`), and
//! pushes outputs directly to the provided output edge. It never copies unless
//! required by payload semantics or batch buffering.
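//!
//! # Example
//!
//! A minimal construction sketch, not a definitive API: `MyBackend`, `InFrame`,
//! `OutFrame`, and `my_model_descriptor` are hypothetical stand-ins for a concrete
//! `ComputeBackend`, its payload types, and its model descriptor, and the
//! `Default` policy/capability/acceptance values are assumed placeholders. The
//! import path is inferred from this file's location.
//!
//! ```rust,ignore
//! use limen_core::node::model::InferenceModel;
//!
//! // `8` caps the stack-bounded (no-alloc) batch path.
//! let node = InferenceModel::<MyBackend, InFrame, OutFrame, 8>::new(
//!     MyBackend::default(),
//!     my_model_descriptor,                  // backend-specific model descriptor
//!     NodePolicy::default(),
//!     NodeCapabilities::default(),
//!     [PlacementAcceptance::default()],     // input placement preference
//!     [PlacementAcceptance::default()],     // output placement preference
//! )?;
//!
//! let _caps = node.backend_capabilities();
//! ```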

use crate::compute::{BackendCapabilities, ComputeBackend, ComputeModel, ModelMetadata};
use crate::edge::Edge;
use crate::errors::{InferenceError, NodeError};
use crate::memory::PlacementAcceptance;
use crate::message::{payload::Payload, Message};
use crate::node::{Node, NodeCapabilities, NodeKind, ProcessResult, StepContext, StepResult};
use crate::policy::NodePolicy;
use crate::prelude::{MemoryManager, PlatformClock, Telemetry};

// --- local helpers: map backend/queue errors into NodeError (no From impls required)
#[inline]
fn map_inference_err(e: InferenceError) -> NodeError {
    NodeError::execution_failed().with_code(*e.code())
}

/// Generic 1×1 inference node for any backend (dyn-free).
///
/// - `MAX_BATCH` is a compile-time cap used for the no-alloc path.
/// - When `alloc` is enabled, the batched path uses `Vec` (still no unsafe).
pub struct InferenceModel<B, InP, OutP, const MAX_BATCH: usize>
where
    B: ComputeBackend<InP, OutP>,
    InP: Payload,
    OutP: Payload + Default + Copy,
{
    /// Backend instance used solely for model creation and capability query.
    /// Kept to preserve type ownership and avoid dynamic dispatch.
    #[allow(dead_code)]
    backend: B,
    /// The loaded model instance that performs inference.
    model: B::Model,
    /// Backend capabilities snapshot (e.g., max batch size, streams).
    backend_caps: BackendCapabilities,
    /// Model metadata snapshot (I/O placement, size hints).
    model_meta: ModelMetadata,

    /// Declared capabilities of this node (streams, degrade tiers).
    node_caps: NodeCapabilities,
    /// Node policy bundle (batching, budget, deadlines).
    node_policy: NodePolicy,
    /// Zero-copy placement acceptance for the input port.
    input_acceptance: [PlacementAcceptance; 1],
    /// Zero-copy placement acceptance for the output port.
    output_acceptance: [PlacementAcceptance; 1],

    /// Reusable output for the 1× fast path (constructed once).
    scratch_out: OutP,

    _pd: core::marker::PhantomData<InP>,
}

impl<B, InP, OutP, const MAX_BATCH: usize> InferenceModel<B, InP, OutP, MAX_BATCH>
where
    B: ComputeBackend<InP, OutP>,
    InP: Payload,
    OutP: Payload + Default + Copy,
{
    /// Construct a new `InferenceModel` node.
    ///
    /// - `backend`: concrete compute backend (e.g., Tract, TFLM adapter).
    /// - `desc`: backend-specific, borrowed model descriptor (e.g., bytes, artifact).
    /// - `node_policy`: batching/budget/deadline policies for the node.
    /// - `node_caps`: advertised capabilities (e.g., device streams).
    /// - `input_acceptance` / `output_acceptance`: zero-copy placement preferences.
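    ///
    /// # Errors
    ///
    /// Returns `B::Error` if the backend fails to load the model from `desc`;
    /// no node is constructed in that case.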
    pub fn new<'desc>(
        backend: B,
        desc: B::ModelDescriptor<'desc>,
        node_policy: NodePolicy,
        node_caps: NodeCapabilities,
        input_acceptance: [PlacementAcceptance; 1],
        output_acceptance: [PlacementAcceptance; 1],
    ) -> Result<Self, B::Error> {
        let backend_caps = backend.capabilities();
        let model = backend.load_model(desc)?;
        let model_meta = model.metadata();

        Ok(Self {
            backend,
            model,
            backend_caps,
            model_meta,
            node_caps,
            node_policy,
            input_acceptance,
            output_acceptance,
            scratch_out: OutP::default(),
            _pd: core::marker::PhantomData,
        })
    }

    /// Return cached backend capabilities for this node.
    #[inline]
    pub fn backend_capabilities(&self) -> BackendCapabilities {
        self.backend_caps
    }

    /// Return cached model metadata for this node.
    #[inline]
    pub fn model_metadata(&self) -> ModelMetadata {
        self.model_meta
    }
}

impl<B, InP, OutP, const MAX_BATCH: usize> Node<1, 1, InP, OutP>
    for InferenceModel<B, InP, OutP, MAX_BATCH>
where
    B: ComputeBackend<InP, OutP>,
    InP: Payload + Default + Copy,
    OutP: Payload + Default + Copy,
{
    #[inline]
    fn describe_capabilities(&self) -> NodeCapabilities {
        self.node_caps
    }

    #[inline]
    fn input_acceptance(&self) -> [PlacementAcceptance; 1] {
        self.input_acceptance
    }

    #[inline]
    fn output_acceptance(&self) -> [PlacementAcceptance; 1] {
        self.output_acceptance
    }

    #[inline]
    fn policy(&self) -> NodePolicy {
        self.node_policy
    }

    /// **TEST ONLY** method used to override batching policies for node contract tests.
    #[cfg(any(test, feature = "bench"))]
    fn set_policy(&mut self, policy: NodePolicy) {
        self.node_policy = policy;
    }

    #[inline]
    fn node_kind(&self) -> NodeKind {
        NodeKind::Model
    }

    #[inline]
    fn initialize<C, T>(&mut self, _clock: &C, _telemetry: &mut T) -> Result<(), NodeError>
    where
        T: Telemetry,
    {
        Ok(())
    }

    #[inline]
    fn start<C, T>(&mut self, _clock: &C, _telemetry: &mut T) -> Result<(), NodeError>
    where
        T: Telemetry,
    {
        self.model.init().map_err(map_inference_err)
    }

    #[inline]
    fn process_message<C>(
        &mut self,
        msg: &Message<InP>,
        _sys_clock: &C,
    ) -> Result<ProcessResult<OutP>, NodeError>
    where
        C: PlatformClock + Sized,
    {
        // Run single-item inference into the reusable scratch output.
        let inp: &InP = msg.payload();
        self.model
            .infer_one(inp, &mut self.scratch_out)
            .map_err(map_inference_err)?;

        // Build output message reusing header from input.
        let hdr = *msg.header();
        let out_msg = Message::new(hdr, core::mem::take(&mut self.scratch_out));

        Ok(ProcessResult::Output(out_msg))
    }

    #[inline]
    fn step<'g, 't, 'c, InQ, OutQ, InM, OutM, C, Tel>(
        &mut self,
        ctx: &mut StepContext<'g, 't, 'c, 1, 1, InP, OutP, InQ, OutQ, InM, OutM, C, Tel>,
    ) -> Result<StepResult, NodeError>
    where
        InQ: Edge,
        OutQ: Edge,
        InM: MemoryManager<InP>,
        OutM: MemoryManager<OutP>,
        C: PlatformClock + Sized,
        Tel: Telemetry + Sized,
    {
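        // Fast path: copy the clock handle out of the context, then pop a single
        // message from input port 0 and run it through `process_message`.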
        let clock = ctx.clock;
        ctx.pop_and_process(0, |msg| self.process_message(msg, clock))
    }

    #[inline]
    fn step_batch<'g, 't, 'c, InQ, OutQ, InM, OutM, C, Tel>(
        &mut self,
        ctx: &mut StepContext<'g, 't, 'c, 1, 1, InP, OutP, InQ, OutQ, InM, OutM, C, Tel>,
    ) -> Result<StepResult, NodeError>
    where
        InQ: Edge,
        OutQ: Edge,
        InM: MemoryManager<InP>,
        OutM: MemoryManager<OutP>,
        C: PlatformClock + Sized,
        Tel: Telemetry + Sized,
    {
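        // Effective batch size: the minimum of the policy's fixed batch size, the
        // backend's advertised maximum, and the compile-time `MAX_BATCH` cap.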
        let want = self.node_policy.batching().fixed_n().unwrap_or(1);
        let backend_cap = self.backend_caps.max_batch().unwrap_or(usize::MAX);
        let nmax = core::cmp::min(core::cmp::min(want, backend_cap), MAX_BATCH);

        if nmax <= 1 {
            return self.step(ctx);
        }

        let node_policy = self.node_policy;
        let clock = ctx.clock;

        ctx.pop_batch_and_process(0, nmax, &node_policy, |msg| {
            self.process_message(msg, clock)
        })
    }

    #[inline]
    fn on_watchdog_timeout<C, Tel>(
        &mut self,
        clock: &C,
        _telemetry: &mut Tel,
    ) -> Result<StepResult, NodeError>
    where
        C: PlatformClock + Sized,
        Tel: Telemetry,
    {
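        // Back off by the policy's watchdog tick budget when one is configured;
        // otherwise yield until "now", i.e. become runnable again immediately.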
        if let Some(backoff) = self.node_policy.budget().watchdog_ticks() {
            let until = clock.now_ticks().saturating_add(*backoff);
            Ok(StepResult::YieldUntil(until))
        } else {
            Ok(StepResult::YieldUntil(clock.now_ticks()))
        }
    }

    #[inline]
    fn stop<C, Tel>(&mut self, _clock: &C, _telemetry: &mut Tel) -> Result<(), NodeError>
    where
        Tel: Telemetry,
    {
        self.model.drain().map_err(map_inference_err)?;
        self.model.reset().map_err(map_inference_err)
    }
}