// limen_core/compute.rs

//! Compute backend and model traits (dyn-free, explicit; no defaults).
//!
//! Backends implement `ComputeBackend<InP, OutP>` and return a concrete `Model`
//! that implements `ComputeModel<InP, OutP>`. The hot path is `infer_one`,
//! which performs exactly one synchronous inference step for a single input.
//!
//! Backends are monomorphized by generics (no dynamic dispatch). The model API
//! is intentionally decoupled from graph/queue details so nodes own batching,
//! backpressure, and telemetry.

use crate::errors::InferenceError;
use crate::memory::MemoryClass;
use crate::message::payload::Payload;
use crate::prelude::Batch;

/// Capability descriptor of a compute backend.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct BackendCapabilities {
    /// Whether the backend supports device streams (async/event completion).
    device_streams: bool,
    /// Maximum supported batch size, if any.
    max_batch: Option<usize>,
    /// Bitfield for supported data types (backend-defined; optional use).
    dtype_mask: u64,
}

impl BackendCapabilities {
    /// Create a new `BackendCapabilities`.
    #[inline]
    pub fn new(device_streams: bool, max_batch: Option<usize>, dtype_mask: u64) -> Self {
        Self {
            device_streams,
            max_batch,
            dtype_mask,
        }
    }

    /// Whether the backend supports device streams (async/event completion).
    ///
    /// Returned by value: the field is `Copy`, so a `&bool` borrow is pure
    /// noise for callers (Rust API guidelines, C-GETTER).
    #[inline]
    pub fn device_streams(&self) -> bool {
        self.device_streams
    }

    /// Maximum supported batch size, if any (`None` = no declared limit).
    #[inline]
    pub fn max_batch(&self) -> Option<usize> {
        self.max_batch
    }

    /// Bitfield for supported data types (backend-defined; optional use).
    #[inline]
    pub fn dtype_mask(&self) -> u64 {
        self.dtype_mask
    }
}
57
58/// Model metadata describing input/output shapes and preferences.
59#[non_exhaustive]
60#[derive(Debug, Clone, Copy, PartialEq, Eq)]
61pub struct ModelMetadata {
62    /// Preferred input memory class (Host/Pinned/Device).
63    preferred_input: MemoryClass,
64    /// Preferred output memory class.
65    preferred_output: MemoryClass,
66    /// Optional maximum input size in bytes (admission hint).
67    max_input_bytes: Option<usize>,
68    /// Optional maximum output size in bytes.
69    max_output_bytes: Option<usize>,
70}
71
72impl ModelMetadata {
73    /// Create a new `ModelMetadata`.
74    #[inline]
75    pub fn new(
76        preferred_input: MemoryClass,
77        preferred_output: MemoryClass,
78        max_input_bytes: Option<usize>,
79        max_output_bytes: Option<usize>,
80    ) -> Self {
81        Self {
82            preferred_input,
83            preferred_output,
84            max_input_bytes,
85            max_output_bytes,
86        }
87    }
88
89    /// Preferred input memory class (Host/Pinned/Device).
90    #[inline]
91    pub fn preferred_input(&self) -> &MemoryClass {
92        &self.preferred_input
93    }
94
95    /// Preferred output memory class.
96    #[inline]
97    pub fn preferred_output(&self) -> &MemoryClass {
98        &self.preferred_output
99    }
100
101    /// Optional maximum input size in bytes (admission hint).
102    #[inline]
103    pub fn max_input_bytes(&self) -> &Option<usize> {
104        &self.max_input_bytes
105    }
106
107    /// Optional maximum output size in bytes.
108    #[inline]
109    pub fn max_output_bytes(&self) -> &Option<usize> {
110        &self.max_output_bytes
111    }
112}
113
114/// A loaded model that can perform inference.
115pub trait ComputeModel<InP: Payload, OutP: Payload> {
116    /// Prepare internal state (allocate work buffers, compile kernels, etc.).
117    fn init(&mut self) -> Result<(), InferenceError>;
118
119    /// Single-item inference (1×1).
120    fn infer_one(&mut self, inp: &InP, out: &mut OutP) -> Result<(), InferenceError>;
121
122    /// Optional: batched inference. Default loops `infer_one`.
123    #[inline]
124    fn infer_batch(
125        &mut self,
126        inps: Batch<'_, InP>,
127        outs: &mut [OutP],
128    ) -> Result<(), InferenceError> {
129        // Default: call infer_one using references into the Batch's messages.
130        for (m, o) in inps.messages().iter().zip(outs.iter_mut()) {
131            // Message::payload() returns &InP so we pass a reference, no move/clones.
132            self.infer_one(m.payload(), o)?;
133        }
134        Ok(())
135    }
136
137    /// Ensure outstanding device work is complete (if any).
138    fn drain(&mut self) -> Result<(), InferenceError>;
139
140    /// Reset internal state to a known baseline (drop caches, etc.).
141    fn reset(&mut self) -> Result<(), InferenceError>;
142
143    /// Return model metadata (I/O placement preferences, limits).
144    fn metadata(&self) -> ModelMetadata;
145}
146
/// A dyn-free engine that constructs models and reports capabilities.
///
/// Monomorphized via generics; there are no trait objects anywhere in this
/// API, so backend calls can be inlined.
pub trait ComputeBackend<InP: Payload, OutP: Payload> {
    /// Concrete model type (no trait objects).
    type Model: ComputeModel<InP, OutP>;

    /// Backend-specific error returned by [`Self::load_model`].
    type Error;

    /// Backend-chosen borrowed descriptor used to load a model.
    ///
    /// A generic associated type so each backend picks its own borrowed
    /// input without forcing an allocation or a common format.
    ///
    /// Examples:
    /// - on `std`:    `type ModelDescriptor<'desc> = &'desc ModelArtifact;`
    /// - on `no_std`: `type ModelDescriptor<'desc> = &'desc [u8];`
    type ModelDescriptor<'desc>
    where
        Self: 'desc;

    /// Capability report (streams, batch limits, dtype support).
    fn capabilities(&self) -> BackendCapabilities;

    /// Load a model from a descriptor.
    ///
    /// Takes `&self`: loading does not require exclusive access to the
    /// backend. NOTE(review): whether one backend may host several live
    /// models concurrently is not visible here — confirm per backend.
    fn load_model<'desc>(
        &self,
        desc: Self::ModelDescriptor<'desc>,
    ) -> Result<Self::Model, Self::Error>;
}
173
/// A simple artifact passed to backends for model creation (POC-friendly).
///
/// Bytes are shared behind an `Arc`, so cloning an artifact is cheap (a
/// refcount bump, not a copy of the model blob).
#[cfg(feature = "std")]
#[non_exhaustive]
#[derive(Debug, Clone)]
pub struct ModelArtifact {
    /// Raw bytes of a model file or an engine-specific blob.
    bytes: std::sync::Arc<Vec<u8>>,
    /// Optional label or path hint.
    label: Option<String>,
}

#[cfg(feature = "std")]
impl ModelArtifact {
    /// Construct from an Arc of bytes and an optional label.
    #[inline]
    pub fn new(bytes: std::sync::Arc<Vec<u8>>, label: Option<String>) -> Self {
        Self { bytes, label }
    }

    /// Construct from raw bytes (no label).
    #[inline]
    pub fn from_bytes(bytes: Vec<u8>) -> Self {
        Self {
            bytes: std::sync::Arc::new(bytes),
            label: None,
        }
    }

    /// Convenience: load from a file path.
    ///
    /// # Errors
    /// Propagates any I/O error from `std::fs::read`.
    pub fn from_file<P: AsRef<std::path::Path>>(path: P) -> std::io::Result<Self> {
        let bytes = std::fs::read(path)?;
        Ok(Self::from_bytes(bytes))
    }

    /// Access the bytes as a shared handle (cloned `Arc`).
    ///
    /// Use this when the caller needs to keep the bytes alive independently
    /// of the artifact; for read-only access prefer [`Self::as_bytes`].
    #[inline]
    pub fn bytes(&self) -> std::sync::Arc<Vec<u8>> {
        self.bytes.clone()
    }

    /// Borrow the bytes without touching the refcount.
    #[inline]
    pub fn as_bytes(&self) -> &[u8] {
        &self.bytes
    }

    /// Access the optional label as `Option<&str>`.
    #[inline]
    pub fn label(&self) -> Option<&str> {
        self.label.as_deref()
    }
}