bb_ir/tensor.rs
1//! Tensor + Scalar abstractions. The framework crate ships no
2//! concrete tensor type; backends implement these traits over their
3//! own storage of choice.
4//!
5//! - `Scalar` - every scalar projects to `f32` for cross-backend
6//! interop. Framework ships universal impls for `f32` + `f64`;
7//! backends add others.
8//! - `Tensor` - the contract every backend's concrete tensor type
9//! implements: shape + total length + canonical ONNX
10//! `TensorProto` round-trip.
11
12use crate::proto::onnx::TensorProto;
13
14// ---------------------------------------------------------------
15// Scalar
16// ---------------------------------------------------------------
17
/// A scalar value usable in tensors. Every scalar projects to `f32`
/// for cross-backend interop.
///
/// Backends pick their own concrete scalar set; the framework ships
/// universal primitive impls for `f32` + `f64` only. Concrete backends
/// may add impls for `i32`, `i64`, `u32`, `u64`, `bool`, `f16` (bf16,
/// fp8 variants), etc.
pub trait Scalar: Copy + Send + Sync + 'static {
    /// Projection to `f32`.
    ///
    /// A projection can only be exact for types whose every value has an
    /// exact `f32` representation: `f32` itself, narrower floats (`f16`,
    /// `bf16`, fp8 variants), `bool`, and integers of at most 16 bits
    /// (`i8`/`u8`, `i16`/`u16`).
    ///
    /// It is necessarily lossy for `f64` (precision and range) and for
    /// 32- and 64-bit integers (`i32`, `u32`, `i64`, `u64`), because
    /// magnitudes above 2^24 exceed `f32`'s 24-bit significand and cannot
    /// be represented exactly.
    fn to_f32(&self) -> f32;
}
30
31impl Scalar for f32 {
32 fn to_f32(&self) -> f32 {
33 *self
34 }
35}
36
37impl Scalar for f64 {
38 fn to_f32(&self) -> f32 {
39 *self as f32
40 }
41}
42
43// ---------------------------------------------------------------
44// Tensor
45// ---------------------------------------------------------------
46
47/// The contract every backend's concrete tensor type implements.
48///
49/// Backend impls land in integration crates (e.g. `bb-cpu-onnx`).
50/// The framework crate ships no concrete tensor type - the
51/// `Tensor` trait IS the contract.
52///
53/// The serde + `Clone` bounds are what make every concrete tensor
54/// type a [`crate::slot_value::SlotValue`] via the universal
55/// blanket - tensors ride slots, wire envelopes, and snapshots
56/// through the same bincode encoding path as every other value.
57pub trait Tensor:
58 Clone
59 + std::fmt::Debug
60 + std::fmt::Display
61 + Send
62 + Sync
63 + 'static
64 + serde::Serialize
65 + serde::de::DeserializeOwned
66{
67 /// The scalar element type this tensor holds.
68 type Scalar: Scalar;
69
70 /// Tensor shape. ONNX-compatible signed-dim convention; `-1`
71 /// for dynamic dims.
72 fn dims(&self) -> &[i64];
73
74 /// Total element count across all dims. For dynamic-dim
75 /// tensors callers must resolve concrete dims before consulting.
76 fn len(&self) -> usize;
77
78 /// `true` when the tensor holds zero elements.
79 fn is_empty(&self) -> bool {
80 self.len() == 0
81 }
82
83 /// Serialize to canonical ONNX `TensorProto`. The result is
84 /// portable across backends declaring the same scalar type.
85 fn to_proto(&self) -> TensorProto;
86
87 /// Deserialize from canonical ONNX `TensorProto`. Returns an
88 /// error if the proto's `elem_type` / shape doesn't match
89 /// `Self`'s expectations.
90 fn from_proto(proto: TensorProto) -> Result<Self, TensorSerializationError>;
91}
92
/// Errors surfaced by [`Tensor::from_proto`].
///
/// Derives `Clone` + `PartialEq`/`Eq` so callers can fan one failure out
/// to several consumers and compare errors directly (e.g. `assert_eq!`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TensorSerializationError {
    /// The proto's `elem_type` didn't match the impl's expected scalar.
    ElementTypeMismatch {
        /// What the impl expected (ONNX `DataType` enum value).
        expected: i32,
        /// What the proto held.
        found: i32,
    },
    /// The proto's shape couldn't be interpreted as the impl's tensor
    /// layout (e.g. byte-count mismatch, malformed dim list).
    ShapeError(String),
    /// Impl-specific deserialization failure.
    Custom(String),
}
109
110impl std::fmt::Display for TensorSerializationError {
111 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
112 match self {
113 Self::ElementTypeMismatch { expected, found } => {
114 write!(
115 f,
116 "tensor elem_type mismatch: expected {expected}, found {found}"
117 )
118 }
119 Self::ShapeError(m) => write!(f, "tensor shape error: {m}"),
120 Self::Custom(m) => write!(f, "tensor serialization failure: {m}"),
121 }
122 }
123}
124
125impl std::error::Error for TensorSerializationError {}
126