Skip to main content

bb_runtime/contracts/
model.rs

//! `bb::Model` — Contract trait for ML models.
//!
//! Each method takes the engine's `&mut RuntimeResourceRef<'_>` ctx
//! plus a [`CompletionHandle`] and returns a [`ContractResponse`]. See
//! [`crate::contracts::index`] for the sync (Now) vs async (Later)
//! semantics.
//!
//! ## Associated type: `Tensor`
//!
//! One associated type covers input tensors, output tensors, parameter
//! vectors, gradients, and deltas. Mixed-precision (e.g. f32 input +
//! f16 weights + f32 output) is handled by wiring [`Codec`] nodes
//! around the model in the Module body — not by multiplying associated
//! types per port.
//!
//! [`Codec`]: crate::contracts::Codec
17
18use crate::completion::{CompletionHandle, ContractResponse};
19use crate::runtime::RuntimeResourceRef;
20
21/// User-facing Contract trait for an ML model.
22pub trait Model: Send + Sync {
23    /// Tensor storage type. One associated type covers
24    /// input/output/params/grad/delta. Implement as `[f32]` for
25    /// flat f32 tensors.
26    type Tensor: ?Sized + bb_ir::types::Storage;
27    /// Library-maker-defined error type.
28    type Error: std::error::Error + std::fmt::Display + Send + Sync + 'static;
29
30    /// Forward pass: `input → output`. `ctx` is the per-dispatch
31    /// runtime surface; impls reach their declared `#[depends(...)]`
32    /// siblings through [`RuntimeResourceRef::dependency`].
33    fn forward(
34        &mut self,
35        ctx: &mut RuntimeResourceRef<'_>,
36        input: &Self::Tensor,
37        completion: CompletionHandle<Box<Self::Tensor>, Self::Error>,
38    ) -> ContractResponse<Box<Self::Tensor>, Self::Error>;
39
40    /// Load parameters wholesale.
41    fn load_parameters(
42        &mut self,
43        ctx: &mut RuntimeResourceRef<'_>,
44        params: &Self::Tensor,
45        completion: CompletionHandle<(), Self::Error>,
46    ) -> ContractResponse<(), Self::Error>;
47
48    /// Backward pass: accumulate gradients given upstream gradient.
49    fn backward(
50        &mut self,
51        ctx: &mut RuntimeResourceRef<'_>,
52        grad: &Self::Tensor,
53        completion: CompletionHandle<(), Self::Error>,
54    ) -> ContractResponse<(), Self::Error>;
55
56    /// Apply a parameter delta in-place.
57    fn apply_delta(
58        &mut self,
59        ctx: &mut RuntimeResourceRef<'_>,
60        delta: &Self::Tensor,
61        completion: CompletionHandle<(), Self::Error>,
62    ) -> ContractResponse<(), Self::Error>;
63
64    /// Compute loss: `(input, target) → scalar score`. Returns `f32`
65    /// regardless of the tensor element type — loss is always a
66    /// framework-fixed scalar.
67    fn compute_loss(
68        &mut self,
69        ctx: &mut RuntimeResourceRef<'_>,
70        input: &Self::Tensor,
71        target: &Self::Tensor,
72        completion: CompletionHandle<f32, Self::Error>,
73    ) -> ContractResponse<f32, Self::Error>;
74
75    /// Snapshot the current parameter tensor (owned — async
76    /// serialization needs owned values).
77    fn params(
78        &self,
79        ctx: &mut RuntimeResourceRef<'_>,
80        completion: CompletionHandle<Box<Self::Tensor>, Self::Error>,
81    ) -> ContractResponse<Box<Self::Tensor>, Self::Error>;
82}