Skip to main content

bb_ir/
tensor.rs

//! Tensor + Scalar abstractions. The framework crate ships no
//! concrete tensor type; backends implement these traits over their
//! own storage of choice.
//!
//! - `Scalar` - every scalar projects to `f32` for cross-backend
//!   interop. The framework ships universal impls for `f32` + `f64`;
//!   backends add others.
//! - `Tensor` - the contract every backend's concrete tensor type
//!   implements: shape + total length + canonical ONNX
//!   `TensorProto` round-trip.
11
12use crate::proto::onnx::TensorProto;
13
// ---------------------------------------------------------------
// Scalar
// ---------------------------------------------------------------
17
/// A scalar value usable in tensors. Every scalar projects to
/// `f32` for cross-backend interop.
///
/// Backends pick their own concrete scalar set; the framework
/// ships universal primitive impls for `f32` + `f64` only.
/// Concrete backends may add impls for `i32`, `i64`, `u32`, `u64`,
/// `bool`, `f16` (bf16, fp8 variants), etc.
pub trait Scalar: Copy + Send + Sync + 'static {
    /// Projection to `f32`.
    ///
    /// The projection is exact for `f32` itself and for narrower float
    /// formats (f16, bf16). It is also exact for any integer whose
    /// magnitude is at most `2^24`, which covers every `i8`/`u8`/`i16`/`u16`.
    ///
    /// It can be lossy for `f64` and for 32/64-bit integers beyond
    /// `±2^24`, because `f32` has only a 24-bit significand. For example,
    /// `16_777_217_i32` rounds to `16_777_216.0`.
    fn to_f32(&self) -> f32;
}
30
31impl Scalar for f32 {
32    fn to_f32(&self) -> f32 {
33        *self
34    }
35}
36
37impl Scalar for f64 {
38    fn to_f32(&self) -> f32 {
39        *self as f32
40    }
41}
42
// ---------------------------------------------------------------
// Tensor
// ---------------------------------------------------------------
46
47/// The contract every backend's concrete tensor type implements.
48///
49/// Backend impls land in integration crates (e.g. `bb-cpu-onnx`).
50/// The framework crate ships no concrete tensor type - the
51/// `Tensor` trait IS the contract.
52///
53/// The serde + `Clone` bounds are what make every concrete tensor
54/// type a [`crate::slot_value::SlotValue`] via the universal
55/// blanket - tensors ride slots, wire envelopes, and snapshots
56/// through the same bincode encoding path as every other value.
57pub trait Tensor:
58    Clone
59    + std::fmt::Debug
60    + std::fmt::Display
61    + Send
62    + Sync
63    + 'static
64    + serde::Serialize
65    + serde::de::DeserializeOwned
66{
67    /// The scalar element type this tensor holds.
68    type Scalar: Scalar;
69
70    /// Tensor shape. ONNX-compatible signed-dim convention; `-1`
71    /// for dynamic dims.
72    fn dims(&self) -> &[i64];
73
74    /// Total element count across all dims. For dynamic-dim
75    /// tensors callers must resolve concrete dims before consulting.
76    fn len(&self) -> usize;
77
78    /// `true` when the tensor holds zero elements.
79    fn is_empty(&self) -> bool {
80        self.len() == 0
81    }
82
83    /// Serialize to canonical ONNX `TensorProto`. The result is
84    /// portable across backends declaring the same scalar type.
85    fn to_proto(&self) -> TensorProto;
86
87    /// Deserialize from canonical ONNX `TensorProto`. Returns an
88    /// error if the proto's `elem_type` / shape doesn't match
89    /// `Self`'s expectations.
90    fn from_proto(proto: TensorProto) -> Result<Self, TensorSerializationError>;
91}
92
/// Errors surfaced by [`Tensor::from_proto`].
///
/// Derives `Clone`, `PartialEq` and `Eq` in addition to `Debug`.
/// This lets callers store, duplicate, and compare errors, e.g. with
/// `assert_eq!` in tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TensorSerializationError {
    /// The proto's `elem_type` did not match the impl's expected scalar.
    ElementTypeMismatch {
        /// What the impl expected (ONNX `DataType` enum value).
        expected: i32,
        /// The `elem_type` value the proto actually held.
        found: i32,
    },
    /// The proto's shape could not be interpreted as the impl's tensor
    /// layout (e.g. byte-count mismatch, malformed dim list).
    ShapeError(String),
    /// Impl-specific deserialization failure.
    Custom(String),
}
109
110impl std::fmt::Display for TensorSerializationError {
111    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
112        match self {
113            Self::ElementTypeMismatch { expected, found } => {
114                write!(
115                    f,
116                    "tensor elem_type mismatch: expected {expected}, found {found}"
117                )
118            }
119            Self::ShapeError(m) => write!(f, "tensor shape error: {m}"),
120            Self::Custom(m) => write!(f, "tensor serialization failure: {m}"),
121        }
122    }
123}
124
125impl std::error::Error for TensorSerializationError {}
126