Skip to main content

bb_ops/backends/cpu/
tensor.rs

//! `CpuTensor` — `Arc`-shared handle to a `CpuBackendBuffer` holding
//! an `ndarray::ArrayD<f32>`. Cloning the handle is an `Arc::clone`
//! (O(1) refcount bump); the underlying buffer is owned by the
//! backend and may be pooled or freshly allocated by `CpuBackend`.
//!
//! Storage is `ndarray::ArrayD<f32>`: heap-allocated with dynamic
//! rank, row-major, with the broadcasting and axis-walking primitives
//! ndarray already provides. Hot kernels downcast to `Ix2` / `Ix3` via
//! `.into_dimensionality::<...>()` for typed-dimension performance and
//! convert back to `IxDyn` at the boundary.
//!
//! Phase C scope: f32 only. Backend-side extensions for f64 / i32 /
//! i64 / bool land via the optional `extension_opsets()`
//! declaration when those types are exercised.
15
16use std::sync::Arc;
17
18use serde::de::{self, MapAccess, Visitor};
19use serde::ser::SerializeStruct;
20use serde::{Deserialize, Deserializer, Serialize, Serializer};
21
22use bb_ir::proto::onnx::TensorProto;
23use bb_ir::tensor::{Tensor, TensorSerializationError};
24use bb_ir::types::TYPE_TENSOR_F32;
25use bb_ir::{register_charged_bytes, register_type_node};
26use bb_runtime::slot_value::SlotValue;
27use ndarray::{ArrayD, IxDyn};
28
29register_type_node!(CpuTensor, &TYPE_TENSOR_F32);
30// Backend-mediated wire receive stamps the byte charge into the
31// buffer; the slot-table writer reads it back through the default
32// SlotValue::charged_bytes body so the engine can release the
33// admission charge on overwrite / eviction.
34register_charged_bytes!(CpuTensor, |t: &CpuTensor| t.0.charged_bytes);
35
/// Numeric tag ONNX uses for `DataType::FLOAT` (32-bit float elements).
pub const ONNX_FLOAT: i32 = 1_i32;
38
39/// Backend-owned buffer behind a [`CpuTensor`] handle. Holds the
40/// `ndarray` storage plus the byte count charged against the
41/// `NodeConfig::ingress_byte_budget` at materialization time so the
42/// slot-table writer can release the charge on overwrite / eviction.
43#[derive(Debug)]
44pub struct CpuBackendBuffer {
45    /// f32 storage, heap-dynamic rank, row-major.
46    pub(crate) data: ArrayD<f32>,
47    /// i64-typed shape cache (ONNX convention). Always equals
48    /// `data.shape().iter().map(|&n| n as i64).collect()`.
49    pub(crate) dims_i64: Vec<i64>,
50    /// Wire-byte count charged at the ingress boundary; carriers
51    /// holding this buffer surface this through
52    /// `SlotValue::charged_bytes` for budget release on slot
53    /// overwrite. Zero for tensors that did not arrive via the wire
54    /// (kernel outputs, test fixtures).
55    pub(crate) charged_bytes: usize,
56}
57
58/// f32-dense CPU-resident tensor handle. `Arc`-shared so intra-Node
59/// clones (FedAvg's per-peer buffer insert, slot-table writes, etc.)
60/// are refcount bumps rather than `Vec<f32>` deep copies. The
61/// underlying [`CpuBackendBuffer`] is owned by the backend, which is
62/// free to pool / reuse / free the storage at a later milestone
63/// without API churn (the handle shape stays identical).
64#[derive(Clone, Debug)]
65pub struct CpuTensor(pub(crate) Arc<CpuBackendBuffer>);
66
/// Failure returned by [`CpuTensor::new_checked`].
///
/// The proto boundary (`Tensor::from_proto`) reports its own failures
/// through `TensorSerializationError`, not this type.
#[derive(Debug)]
pub enum CpuTensorError {
    /// The product of the dims disagrees with the number of data
    /// elements handed to [`CpuTensor::new_checked`].
    ShapeMismatch {
        /// Element count implied by the dims (their product).
        expected: usize,
        /// Observed element count.
        got: usize,
    },
}

impl std::fmt::Display for CpuTensorError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Single-variant enum, so this destructuring is irrefutable.
        let Self::ShapeMismatch { expected, got } = self;
        write!(
            f,
            "CpuTensor shape mismatch: dims product {expected} ≠ data.len {got}"
        )
    }
}

impl std::error::Error for CpuTensorError {}
92
93impl CpuTensor {
94    /// Wrap an existing `ArrayD<f32>` in a fresh backend buffer.
95    /// `charged_bytes = 0` — kernel outputs and test fixtures don't
96    /// arrive via the wire and therefore don't hold an ingress
97    /// charge. The wire path uses the crate-private
98    /// `from_wire_buffer` helper from `CpuBackend::materialize_from_wire`.
99    pub fn from_array(data: ArrayD<f32>) -> Self {
100        let dims_i64 = data.shape().iter().map(|&n| n as i64).collect();
101        Self(Arc::new(CpuBackendBuffer {
102            data,
103            dims_i64,
104            charged_bytes: 0,
105        }))
106    }
107
108    /// Construct from ONNX-signed shape + flat row-major data via
109    /// `ndarray::ArrayD::from_shape_vec`. Equivalent to
110    /// [`Self::new`] but spelled out for callers preferring the
111    /// builder-style name.
112    pub fn from_vec(shape: Vec<i64>, data: Vec<f32>) -> Self {
113        Self::new(shape, data)
114    }
115
116    /// Borrow the underlying ndarray.
117    pub fn as_array(&self) -> &ArrayD<f32> {
118        &self.0.data
119    }
120
121    /// Clone the backend buffer into an owned `ArrayD<f32>`. The
122    /// `Arc` shape means the buffer cannot be unwrapped in-place
123    /// (other handles may share it), so this always pays the
124    /// `ndarray` deep copy. Test-only callers needing flat data
125    /// should prefer [`Self::flat_data`].
126    pub fn into_array(self) -> ArrayD<f32> {
127        self.0.data.clone()
128    }
129
130    /// Test-helper that returns the cached i64 shape. Real callers
131    /// use the [`Tensor::dims`] trait method or
132    /// [`Self::as_array`]`.shape()`.
133    #[doc(hidden)]
134    pub fn dims_vec(&self) -> &[i64] {
135        &self.0.dims_i64
136    }
137
138    /// Test-helper that materializes a flat row-major copy. Real
139    /// callers iterate `self.as_array()` directly to avoid the
140    /// allocation.
141    #[doc(hidden)]
142    pub fn flat_data(&self) -> Vec<f32> {
143        self.0.data.iter().copied().collect()
144    }
145
146    /// Construct from ONNX-signed shape + flat row-major data.
147    /// Panics if the dims product doesn't match the data length —
148    /// callers needing strict checking use [`Self::new_checked`].
149    pub fn new(dims: Vec<i64>, data: Vec<f32>) -> Self {
150        let shape: Vec<usize> = dims.iter().map(|&d| d.max(0) as usize).collect();
151        let array = ArrayD::from_shape_vec(IxDyn(&shape), data)
152            .expect("CpuTensor::new shape × data mismatch");
153        Self::from_array(array)
154    }
155
156    /// Construct + validate `dims_product(&dims) == data.len()`.
157    pub fn new_checked(dims: Vec<i64>, data: Vec<f32>) -> Result<Self, CpuTensorError> {
158        let expected = dims_product(&dims);
159        if expected != data.len() {
160            return Err(CpuTensorError::ShapeMismatch {
161                expected,
162                got: data.len(),
163            });
164        }
165        let shape: Vec<usize> = dims.iter().map(|&d| d.max(0) as usize).collect();
166        let array = ArrayD::from_shape_vec(IxDyn(&shape), data)
167            .map_err(|_| CpuTensorError::ShapeMismatch { expected, got: 0 })?;
168        Ok(Self::from_array(array))
169    }
170
171    /// Construct a zero-filled tensor with the given shape.
172    pub fn zeros(dims: Vec<i64>) -> Self {
173        let shape: Vec<usize> = dims.iter().map(|&d| d.max(0) as usize).collect();
174        Self::from_array(ArrayD::zeros(IxDyn(&shape)))
175    }
176
177    /// Construct a ones-filled tensor with the given shape.
178    pub fn ones(dims: Vec<i64>) -> Self {
179        let shape: Vec<usize> = dims.iter().map(|&d| d.max(0) as usize).collect();
180        Self::from_array(ArrayD::ones(IxDyn(&shape)))
181    }
182
183    /// Observe the underlying `Arc<CpuBackendBuffer>` strong-refcount.
184    /// One strong holder means the caller holds the only handle to
185    /// the buffer; future pooling implementations (`v2`) read this to
186    /// decide whether to return the buffer to the pool on drop. Tests
187    /// use this to prove the wire-decode path lands a single carrier
188    /// in the slot table with no spurious clones.
189    pub fn strong_count(&self) -> usize {
190        Arc::strong_count(&self.0)
191    }
192
193    /// Wrap a kernel-supplied `ArrayD<f32>` plus a wire-byte charge
194    /// in a fresh backend buffer. Used by
195    /// `CpuBackend::materialize_from_wire` so the resulting tensor
196    /// carries the charge that the slot-table writer releases on
197    /// eviction.
198    pub(crate) fn from_wire_buffer(data: ArrayD<f32>, charged_bytes: usize) -> Self {
199        let dims_i64 = data.shape().iter().map(|&n| n as i64).collect();
200        Self(Arc::new(CpuBackendBuffer {
201            data,
202            dims_i64,
203            charged_bytes,
204        }))
205    }
206}
207
/// Element count implied by ONNX-signed `dims`.
///
/// Negative dims count as `0`, and the empty shape (a scalar) yields `1`.
///
/// The product saturates at `usize::MAX` instead of overflowing. The dims
/// can come straight off the wire (`from_proto`), and a plain `product()`
/// would panic in debug builds or wrap in release builds. A saturated
/// result can never equal a real `Vec<f32>` length, so callers comparing
/// it against `data.len()` still reject the input.
///
/// The result is exact whenever the true product fits in `usize`, and a
/// zero dim anywhere still yields `0`.
fn dims_product(dims: &[i64]) -> usize {
    dims.iter()
        // `try_from` can only fail where `usize` is narrower than 64 bits.
        // Saturate there rather than truncate.
        .map(|&d| usize::try_from(d.max(0)).unwrap_or(usize::MAX))
        .fold(1, usize::saturating_mul)
}
211
212impl std::fmt::Display for CpuTensor {
213    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
214        write!(
215            f,
216            "CpuTensor(dims={:?}, len={})",
217            self.0.data.shape(),
218            self.0.data.len(),
219        )
220    }
221}
222
223impl Tensor for CpuTensor {
224    type Scalar = f32;
225
226    fn dims(&self) -> &[i64] {
227        &self.0.dims_i64
228    }
229
230    fn len(&self) -> usize {
231        self.0.data.len()
232    }
233
234    fn to_proto(&self) -> TensorProto {
235        let dims: Vec<i64> = self.0.data.shape().iter().map(|&n| n as i64).collect();
236        let float_data: Vec<f32> = self.0.data.iter().copied().collect();
237        TensorProto {
238            dims,
239            data_type: ONNX_FLOAT,
240            float_data,
241            ..Default::default()
242        }
243    }
244
245    fn from_proto(proto: TensorProto) -> Result<Self, TensorSerializationError> {
246        if proto.data_type != ONNX_FLOAT {
247            return Err(TensorSerializationError::ElementTypeMismatch {
248                expected: ONNX_FLOAT,
249                found: proto.data_type,
250            });
251        }
252        // Prefer `float_data`; fall back to raw bytes if shipping
253        // tools encoded via `raw_data` (4-byte little-endian floats).
254        let data = if !proto.float_data.is_empty() {
255            proto.float_data
256        } else if !proto.raw_data.is_empty() {
257            if proto.raw_data.len() % 4 != 0 {
258                return Err(TensorSerializationError::ShapeError(format!(
259                    "raw_data length {} not divisible by 4",
260                    proto.raw_data.len(),
261                )));
262            }
263            let mut out = Vec::with_capacity(proto.raw_data.len() / 4);
264            for chunk in proto.raw_data.chunks_exact(4) {
265                out.push(f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]));
266            }
267            out
268        } else {
269            Vec::new()
270        };
271        let expected = dims_product(&proto.dims);
272        if expected != data.len() {
273            return Err(TensorSerializationError::ShapeError(format!(
274                "dims product {expected} doesn't match data len {len}",
275                len = data.len()
276            )));
277        }
278        let shape: Vec<usize> = proto.dims.iter().map(|&d| d.max(0) as usize).collect();
279        let array = ArrayD::from_shape_vec(IxDyn(&shape), data).map_err(|e| {
280            TensorSerializationError::ShapeError(format!("ndarray::from_shape_vec: {e}"))
281        })?;
282        Ok(Self::from_array(array))
283    }
284}
285
// Hand-written Serialize / Deserialize: skip the `Arc` indirection
// on the wire (a remote receiver wants the buffer's contents, not
// the local refcount cell) and skip `charged_bytes` (a snapshot
// replay restarts ingress accounting from zero). Deserialization
// allocates a fresh `CpuBackendBuffer` wrapped in a brand-new `Arc`.
291
292impl Serialize for CpuTensor {
293    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
294    where
295        S: Serializer,
296    {
297        let mut s = serializer.serialize_struct("CpuTensor", 2)?;
298        s.serialize_field("data", &self.0.data)?;
299        s.serialize_field("dims_i64", &self.0.dims_i64)?;
300        s.end()
301    }
302}
303
304impl<'de> Deserialize<'de> for CpuTensor {
305    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
306    where
307        D: Deserializer<'de>,
308    {
309        #[derive(Deserialize)]
310        #[serde(field_identifier, rename_all = "snake_case")]
311        enum Field {
312            Data,
313            DimsI64,
314        }
315
316        struct CpuTensorVisitor;
317
318        impl<'de> Visitor<'de> for CpuTensorVisitor {
319            type Value = CpuTensor;
320
321            fn expecting(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
322                f.write_str("struct CpuTensor")
323            }
324
325            fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
326            where
327                A: de::SeqAccess<'de>,
328            {
329                let data: ArrayD<f32> = seq.next_element()?.ok_or_else(|| {
330                    de::Error::invalid_length(0, &"struct CpuTensor with 2 fields")
331                })?;
332                let dims_i64: Vec<i64> = seq.next_element()?.ok_or_else(|| {
333                    de::Error::invalid_length(1, &"struct CpuTensor with 2 fields")
334                })?;
335                Ok(CpuTensor(Arc::new(CpuBackendBuffer {
336                    data,
337                    dims_i64,
338                    charged_bytes: 0,
339                })))
340            }
341
342            fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
343            where
344                A: MapAccess<'de>,
345            {
346                let mut data: Option<ArrayD<f32>> = None;
347                let mut dims_i64: Option<Vec<i64>> = None;
348                while let Some(key) = map.next_key()? {
349                    match key {
350                        Field::Data => {
351                            if data.is_some() {
352                                return Err(de::Error::duplicate_field("data"));
353                            }
354                            data = Some(map.next_value()?);
355                        }
356                        Field::DimsI64 => {
357                            if dims_i64.is_some() {
358                                return Err(de::Error::duplicate_field("dims_i64"));
359                            }
360                            dims_i64 = Some(map.next_value()?);
361                        }
362                    }
363                }
364                let data = data.ok_or_else(|| de::Error::missing_field("data"))?;
365                let dims_i64 = dims_i64.ok_or_else(|| de::Error::missing_field("dims_i64"))?;
366                Ok(CpuTensor(Arc::new(CpuBackendBuffer {
367                    data,
368                    dims_i64,
369                    charged_bytes: 0,
370                })))
371            }
372        }
373
374        const FIELDS: &[&str] = &["data", "dims_i64"];
375        deserializer.deserialize_struct("CpuTensor", FIELDS, CpuTensorVisitor)
376    }
377}
378
379// `Tensor` implies the framework's blanket `SlotValue` impl so
380// `CpuTensor` can be passed by-ref into `dispatch_atomic` inputs.
381// The blanket lives in `bytesandbrains::tensor`; this re-export
382// reminds readers that the contract is satisfied.
383const _: fn() = || {
384    fn _check<T: SlotValue>() {}
385    _check::<CpuTensor>();
386};
387