// limen_core/compute.rs
1//! Compute backend and model traits (dyn-free, explicit; no defaults).
2//!
3//! Backends implement `ComputeBackend<InP, OutP>` and return a concrete `Model`
4//! that implements `ComputeModel<InP, OutP>`. The hot path is `infer_one`,
5//! which performs exactly one synchronous inference step for a single input.
6//!
7//! Backends are monomorphized by generics (no dynamic dispatch). The model API
8//! is intentionally decoupled from graph/queue details so nodes own batching,
9//! backpressure, and telemetry.
10
11use crate::errors::InferenceError;
12use crate::memory::MemoryClass;
13use crate::message::payload::Payload;
14use crate::prelude::Batch;
15
/// Capability descriptor of a compute backend.
///
/// Reported by `ComputeBackend::capabilities` so nodes can size batches and
/// pick completion strategies without probing the backend at runtime.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct BackendCapabilities {
    /// Whether the backend supports device streams (async/event completion).
    device_streams: bool,
    /// Maximum supported batch size, if any.
    max_batch: Option<usize>,
    /// Bitfield for supported data types (backend-defined; optional use).
    dtype_mask: u64,
}

impl BackendCapabilities {
    /// Create a new `BackendCapabilities`.
    ///
    /// `const` so capability descriptors can live in `const`/`static` items.
    #[inline]
    #[must_use]
    pub const fn new(device_streams: bool, max_batch: Option<usize>, dtype_mask: u64) -> Self {
        Self {
            device_streams,
            max_batch,
            dtype_mask,
        }
    }

    /// Whether the backend supports device streams (async/event completion).
    ///
    /// NOTE(review): returns `&bool` rather than `bool`; by-value would be the
    /// idiom for `Copy` fields but changing it now would break existing
    /// callers that dereference.
    #[inline]
    #[must_use]
    pub const fn device_streams(&self) -> &bool {
        &self.device_streams
    }

    /// Maximum supported batch size, if any.
    #[inline]
    #[must_use]
    pub const fn max_batch(&self) -> &Option<usize> {
        &self.max_batch
    }

    /// Bitfield for supported data types (backend-defined; optional use).
    #[inline]
    #[must_use]
    pub const fn dtype_mask(&self) -> &u64 {
        &self.dtype_mask
    }
}
57
/// Model metadata describing input/output shapes and preferences.
///
/// Returned by `ComputeModel::metadata`; nodes can use the placement
/// preferences to stage buffers and the byte limits as admission hints.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ModelMetadata {
    /// Preferred input memory class (Host/Pinned/Device).
    preferred_input: MemoryClass,
    /// Preferred output memory class.
    preferred_output: MemoryClass,
    /// Optional maximum input size in bytes (admission hint); `None` = no limit advertised.
    max_input_bytes: Option<usize>,
    /// Optional maximum output size in bytes; `None` = no limit advertised.
    max_output_bytes: Option<usize>,
}
71
72impl ModelMetadata {
73 /// Create a new `ModelMetadata`.
74 #[inline]
75 pub fn new(
76 preferred_input: MemoryClass,
77 preferred_output: MemoryClass,
78 max_input_bytes: Option<usize>,
79 max_output_bytes: Option<usize>,
80 ) -> Self {
81 Self {
82 preferred_input,
83 preferred_output,
84 max_input_bytes,
85 max_output_bytes,
86 }
87 }
88
89 /// Preferred input memory class (Host/Pinned/Device).
90 #[inline]
91 pub fn preferred_input(&self) -> &MemoryClass {
92 &self.preferred_input
93 }
94
95 /// Preferred output memory class.
96 #[inline]
97 pub fn preferred_output(&self) -> &MemoryClass {
98 &self.preferred_output
99 }
100
101 /// Optional maximum input size in bytes (admission hint).
102 #[inline]
103 pub fn max_input_bytes(&self) -> &Option<usize> {
104 &self.max_input_bytes
105 }
106
107 /// Optional maximum output size in bytes.
108 #[inline]
109 pub fn max_output_bytes(&self) -> &Option<usize> {
110 &self.max_output_bytes
111 }
112}
113
114/// A loaded model that can perform inference.
115pub trait ComputeModel<InP: Payload, OutP: Payload> {
116 /// Prepare internal state (allocate work buffers, compile kernels, etc.).
117 fn init(&mut self) -> Result<(), InferenceError>;
118
119 /// Single-item inference (1×1).
120 fn infer_one(&mut self, inp: &InP, out: &mut OutP) -> Result<(), InferenceError>;
121
122 /// Optional: batched inference. Default loops `infer_one`.
123 #[inline]
124 fn infer_batch(
125 &mut self,
126 inps: Batch<'_, InP>,
127 outs: &mut [OutP],
128 ) -> Result<(), InferenceError> {
129 // Default: call infer_one using references into the Batch's messages.
130 for (m, o) in inps.messages().iter().zip(outs.iter_mut()) {
131 // Message::payload() returns &InP so we pass a reference, no move/clones.
132 self.infer_one(m.payload(), o)?;
133 }
134 Ok(())
135 }
136
137 /// Ensure outstanding device work is complete (if any).
138 fn drain(&mut self) -> Result<(), InferenceError>;
139
140 /// Reset internal state to a known baseline (drop caches, etc.).
141 fn reset(&mut self) -> Result<(), InferenceError>;
142
143 /// Return model metadata (I/O placement preferences, limits).
144 fn metadata(&self) -> ModelMetadata;
145}
146
/// A dyn-free engine that constructs models and reports capabilities.
///
/// Generic over input/output payloads so each backend is monomorphized; no
/// trait-object indirection is involved in loading or running a model.
pub trait ComputeBackend<InP: Payload, OutP: Payload> {
    /// Concrete model type (no trait objects).
    type Model: ComputeModel<InP, OutP>;

    /// Backend-specific error.
    type Error;

    /// Backend-chosen borrowed descriptor used to load a model.
    ///
    /// A generic associated type so the descriptor may borrow caller data
    /// only for the duration of the `load_model` call.
    ///
    /// Examples:
    /// - on `std`: `type ModelDescriptor<'desc> = &'desc ModelArtifact;`
    /// - on `no_std`: `type ModelDescriptor<'desc> = &'desc [u8];`
    type ModelDescriptor<'desc>
    where
        Self: 'desc;

    /// Capability report.
    fn capabilities(&self) -> BackendCapabilities;

    /// Load a model from a descriptor.
    ///
    /// Returns the backend's concrete `Self::Model`, or `Self::Error` when
    /// the descriptor cannot be turned into a usable model.
    fn load_model<'desc>(
        &self,
        desc: Self::ModelDescriptor<'desc>,
    ) -> Result<Self::Model, Self::Error>;
}
173
/// A simple artifact passed to backends for model creation (POC-friendly).
///
/// The bytes are held behind an `Arc`, so cloning an artifact is a refcount
/// bump rather than a copy of the model blob.
#[cfg(feature = "std")]
#[non_exhaustive]
#[derive(Debug, Clone)]
pub struct ModelArtifact {
    /// Raw bytes of a model file or an engine-specific blob.
    bytes: std::sync::Arc<Vec<u8>>,
    /// Optional label or path hint.
    label: Option<String>,
}
184
#[cfg(feature = "std")]
impl ModelArtifact {
    /// Construct from an Arc of bytes and an optional label.
    #[inline]
    pub fn new(bytes: std::sync::Arc<Vec<u8>>, label: Option<String>) -> Self {
        Self { bytes, label }
    }

    /// Construct from raw bytes.
    #[inline]
    pub fn from_bytes(bytes: Vec<u8>) -> Self {
        Self {
            bytes: std::sync::Arc::new(bytes),
            label: None,
        }
    }

    /// Convenience: load from a file path.
    ///
    /// # Errors
    /// Propagates any I/O error from reading `path`.
    pub fn from_file<P: AsRef<std::path::Path>>(path: P) -> std::io::Result<Self> {
        let bytes = std::fs::read(path)?;
        Ok(Self::from_bytes(bytes))
    }

    /// Access the bytes (cloned Arc): a refcount bump, no byte copy.
    #[inline]
    #[must_use]
    pub fn bytes(&self) -> std::sync::Arc<Vec<u8>> {
        self.bytes.clone()
    }

    /// Borrow the bytes without touching the refcount.
    ///
    /// Prefer this over [`Self::bytes`] when the caller only needs to read.
    #[inline]
    #[must_use]
    pub fn as_bytes(&self) -> &[u8] {
        self.bytes.as_slice()
    }

    /// Access the optional label as `Option<&str>`.
    #[inline]
    #[must_use]
    pub fn label(&self) -> Option<&str> {
        self.label.as_deref()
    }
}