Skip to main content

bb_ops/backends/cpu/
mod.rs

//! Pure-Rust reference CPU backend. Implements `bb::Backend` for
//! the `ai.onnx v1` 51-op subset over ndarray. Gated by the
//! `cpu-backend` feature.
4
5pub mod ops;
6pub mod opset;
7pub mod tensor;
8
9use bb_derive::Backend;
10use serde::{Deserialize, Serialize};
11
12use bb_dsl::concrete::{ComponentPackage, ConcreteComponent, RestoreError};
13use bb_runtime::component::AnyComponent;
14
15pub use opset::{ONNX_DOMAIN, ONNX_V1_OPSET, ONNX_VERSION};
16pub use tensor::{CpuTensor, CpuTensorError};
17
18pub mod graph_walker;
19pub use graph_walker::{execute_graph, BackendError};
20
21/// Reference CPU backend. Dispatches its opset onto ndarray
22/// kernels; storage is `ArrayD<f32>` end-to-end.
23#[derive(Clone, Debug, Default, Serialize, Deserialize, Backend)]
24pub struct CpuBackend;
25
26impl CpuBackend {
27    /// Construct a fresh backend.
28    pub fn new() -> Self {
29        Self
30    }
31
32    /// Allocate a zero-initialised tensor with the given shape.
33    /// Single allocation seam for future pool / arena strategies.
34    pub fn alloc_tensor(&self, shape: Vec<i64>) -> CpuTensor {
35        CpuTensor::zeros(shape)
36    }
37
38    /// Wrap a kernel-produced `ArrayD<f32>` as a `CpuTensor`.
39    /// Single wrapping seam for future pool / arena strategies.
40    pub fn wrap_array(&self, array: ndarray::ArrayD<f32>) -> CpuTensor {
41        CpuTensor::from_array(array)
42    }
43}
44
45impl AnyComponent for CpuBackend {
46    fn as_any(&self) -> &dyn std::any::Any {
47        self
48    }
49    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
50        self
51    }
52}
53
54impl ConcreteComponent for CpuBackend {
55    const TYPE_NAME: &'static str = "bytesandbrains::backends::cpu::CpuBackend";
56    const PACKAGE: ComponentPackage = ComponentPackage::Framework;
57
58    type Config = ();
59    type Error = std::convert::Infallible;
60
61    fn new(_config: &Self::Config) -> Result<Self, Self::Error> {
62        Ok(Self)
63    }
64
65    fn serialize(&self) -> Vec<u8> {
66        bincode::serialize(self).expect("CpuBackend serde is infallible")
67    }
68
69    fn restore(bytes: &[u8]) -> Result<Self, RestoreError> {
70        bincode::deserialize(bytes).map_err(RestoreError::Malformed)
71    }
72}
73
74// Hand-rolled equivalent of `#[derive(bb::Concrete)]` so the
75// manual `ConcreteComponent` impl above continues to drive inventory
76// registration.
77#[doc(hidden)]
78fn __cpu_backend_serialize(erased: &dyn bb_ir::component::ErasedComponent) -> Vec<u8> {
79    let any: &dyn std::any::Any = erased;
80    let concrete: &CpuBackend = any
81        .downcast_ref::<CpuBackend>()
82        .expect("inventory downcast: CpuBackend by TYPE_NAME");
83    <CpuBackend as ConcreteComponent>::serialize(concrete)
84}
85#[doc(hidden)]
86fn __cpu_backend_restore(
87    bytes: &[u8],
88) -> Result<Box<dyn bb_ir::component::ErasedComponent>, RestoreError> {
89    <CpuBackend as ConcreteComponent>::restore(bytes)
90        .map(|v| Box::new(v) as Box<dyn bb_ir::component::ErasedComponent>)
91}
92#[doc(hidden)]
93fn __cpu_backend_construct(
94    cfg: &dyn std::any::Any,
95) -> Result<Box<dyn bb_ir::component::ErasedComponent>, bb_ir::component::ConstructError> {
96    let _typed: &() = cfg
97        .downcast_ref::<()>()
98        .ok_or_else(|| bb_ir::component::ConstructError {
99            type_name: "bytesandbrains::backends::cpu::CpuBackend",
100            detail: format!(
101                "config type mismatch: expected `()`, got `{:?}`",
102                cfg.type_id(),
103            ),
104        })?;
105    <CpuBackend as ConcreteComponent>::new(_typed)
106        .map(|v| Box::new(v) as Box<dyn bb_ir::component::ErasedComponent>)
107        .map_err(|e| bb_ir::component::ConstructError {
108            type_name: "bytesandbrains::backends::cpu::CpuBackend",
109            detail: format!("{e}"),
110        })
111}
112bb_ir::registry::inventory::submit! {
113    bb_ir::registry::ConcreteComponentRegistration {
114        type_name: "bytesandbrains::backends::cpu::CpuBackend",
115        package: ComponentPackage::Framework,
116        serialize_fn: __cpu_backend_serialize,
117        restore_fn: __cpu_backend_restore,
118        construct_fn: __cpu_backend_construct,
119        dependencies: &[],
120    }
121}
122
123/// `CpuTensor` → `TYPE_TENSOR_F32` in the polymorphism tree.
124impl bb_ir::types::Storage for CpuTensor {
125    const TYPE: &'static bb_ir::types::TypeNode = &bb_ir::types::TYPE_TENSOR_F32;
126}
127
/// Per-thread invocation counter for `CpuBackend::execute`. Only
/// compiled for test builds or when the `test-components` feature is
/// enabled.
#[cfg(any(test, feature = "test-components"))]
mod dispatch_counter {
    use std::cell::Cell;

    thread_local! {
        /// Invocations recorded on the current thread.
        static INVOCATIONS: Cell<usize> = const { Cell::new(0) };
    }

    /// Record one more invocation on the calling thread. The count
    /// wraps around on overflow instead of panicking.
    pub fn bump() {
        INVOCATIONS.with(|counter| {
            let next = counter.get().wrapping_add(1);
            counter.set(next);
        });
    }

    /// Read the current per-thread invocation count.
    pub fn read() -> usize {
        INVOCATIONS.with(Cell::get)
    }

    /// Reset the per-thread invocation count to zero.
    pub fn reset() {
        INVOCATIONS.with(|counter| counter.set(0));
    }
}
148
149#[cfg(any(test, feature = "test-components"))]
150pub use dispatch_counter::{read as dispatch_count, reset as reset_dispatch_count};
151
152/// `bb::Backend` Contract impl. Overrides `execute` to run through
153/// `graph_walker::execute_graph` rather than the default per-op
154/// walker.
155impl bb_runtime::contracts::Backend for CpuBackend {
156    type Error = graph_walker::BackendError;
157    type Tensor = CpuTensor;
158
159    fn execute(
160        &self,
161        graph: &bb_ir::proto::onnx::GraphProto,
162        inputs: std::collections::HashMap<String, Self::Tensor>,
163        _attrs: bb_runtime::contracts::backend::BackendAttrs<'_>,
164    ) -> Result<std::collections::HashMap<String, Self::Tensor>, Self::Error> {
165        #[cfg(any(test, feature = "test-components"))]
166        dispatch_counter::bump();
167        graph_walker::execute_graph(self, graph, inputs)
168    }
169
170    /// Decode wire bytes inside the backend so the `CpuTensor`
171    /// carries the ingress byte charge for slot-table release on
172    /// overwrite. v1 uses bincode.
173    fn materialize_from_wire(
174        &self,
175        type_hash: u64,
176        bytes: Vec<u8>,
177    ) -> Result<Self::Tensor, Self::Error> {
178        use bb_runtime::contracts::backend_default_walk::BackendWalkError;
179        let expected_hash = bb_ir::slot_value::type_hash_of::<CpuTensor>();
180        if type_hash != expected_hash {
181            return Err(graph_walker::BackendError::DefaultWalker(
182                BackendWalkError::WireMaterializeFailed {
183                    type_hash,
184                    reason: format!(
185                        "expected CpuTensor type_hash {expected_hash:#018x}, got {type_hash:#018x}",
186                    ),
187                },
188            ));
189        }
190        let charged_bytes = bytes.len();
191        // Re-wrap with the carried charge so the slot-table writer
192        // can release the budget on overwrite.
193        let wire: CpuTensor = bincode::deserialize(&bytes).map_err(|e| {
194            graph_walker::BackendError::DefaultWalker(BackendWalkError::WireMaterializeFailed {
195                type_hash,
196                reason: format!("bincode decode: {e}"),
197            })
198        })?;
199        // One copy out of the discarded handle; the wire path pays
200        // this until pooling lands.
201        Ok(CpuTensor::from_wire_buffer(
202            wire.0.data.clone(),
203            charged_bytes,
204        ))
205    }
206}
207