bb_ops/backends/cpu/
mod.rs1pub mod ops;
6pub mod opset;
7pub mod tensor;
8
9use bb_derive::Backend;
10use serde::{Deserialize, Serialize};
11
12use bb_dsl::concrete::{ComponentPackage, ConcreteComponent, RestoreError};
13use bb_runtime::component::AnyComponent;
14
15pub use opset::{ONNX_DOMAIN, ONNX_V1_OPSET, ONNX_VERSION};
16pub use tensor::{CpuTensor, CpuTensorError};
17
18pub mod graph_walker;
19pub use graph_walker::{execute_graph, BackendError};
20
21#[derive(Clone, Debug, Default, Serialize, Deserialize, Backend)]
24pub struct CpuBackend;
25
26impl CpuBackend {
27 pub fn new() -> Self {
29 Self
30 }
31
32 pub fn alloc_tensor(&self, shape: Vec<i64>) -> CpuTensor {
35 CpuTensor::zeros(shape)
36 }
37
38 pub fn wrap_array(&self, array: ndarray::ArrayD<f32>) -> CpuTensor {
41 CpuTensor::from_array(array)
42 }
43}
44
45impl AnyComponent for CpuBackend {
46 fn as_any(&self) -> &dyn std::any::Any {
47 self
48 }
49 fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
50 self
51 }
52}
53
54impl ConcreteComponent for CpuBackend {
55 const TYPE_NAME: &'static str = "bytesandbrains::backends::cpu::CpuBackend";
56 const PACKAGE: ComponentPackage = ComponentPackage::Framework;
57
58 type Config = ();
59 type Error = std::convert::Infallible;
60
61 fn new(_config: &Self::Config) -> Result<Self, Self::Error> {
62 Ok(Self)
63 }
64
65 fn serialize(&self) -> Vec<u8> {
66 bincode::serialize(self).expect("CpuBackend serde is infallible")
67 }
68
69 fn restore(bytes: &[u8]) -> Result<Self, RestoreError> {
70 bincode::deserialize(bytes).map_err(RestoreError::Malformed)
71 }
72}
73
74#[doc(hidden)]
78fn __cpu_backend_serialize(erased: &dyn bb_ir::component::ErasedComponent) -> Vec<u8> {
79 let any: &dyn std::any::Any = erased;
80 let concrete: &CpuBackend = any
81 .downcast_ref::<CpuBackend>()
82 .expect("inventory downcast: CpuBackend by TYPE_NAME");
83 <CpuBackend as ConcreteComponent>::serialize(concrete)
84}
85#[doc(hidden)]
86fn __cpu_backend_restore(
87 bytes: &[u8],
88) -> Result<Box<dyn bb_ir::component::ErasedComponent>, RestoreError> {
89 <CpuBackend as ConcreteComponent>::restore(bytes)
90 .map(|v| Box::new(v) as Box<dyn bb_ir::component::ErasedComponent>)
91}
92#[doc(hidden)]
93fn __cpu_backend_construct(
94 cfg: &dyn std::any::Any,
95) -> Result<Box<dyn bb_ir::component::ErasedComponent>, bb_ir::component::ConstructError> {
96 let _typed: &() = cfg
97 .downcast_ref::<()>()
98 .ok_or_else(|| bb_ir::component::ConstructError {
99 type_name: "bytesandbrains::backends::cpu::CpuBackend",
100 detail: format!(
101 "config type mismatch: expected `()`, got `{:?}`",
102 cfg.type_id(),
103 ),
104 })?;
105 <CpuBackend as ConcreteComponent>::new(_typed)
106 .map(|v| Box::new(v) as Box<dyn bb_ir::component::ErasedComponent>)
107 .map_err(|e| bb_ir::component::ConstructError {
108 type_name: "bytesandbrains::backends::cpu::CpuBackend",
109 detail: format!("{e}"),
110 })
111}
112bb_ir::registry::inventory::submit! {
113 bb_ir::registry::ConcreteComponentRegistration {
114 type_name: "bytesandbrains::backends::cpu::CpuBackend",
115 package: ComponentPackage::Framework,
116 serialize_fn: __cpu_backend_serialize,
117 restore_fn: __cpu_backend_restore,
118 construct_fn: __cpu_backend_construct,
119 dependencies: &[],
120 }
121}
122
123impl bb_ir::types::Storage for CpuTensor {
125 const TYPE: &'static bb_ir::types::TypeNode = &bb_ir::types::TYPE_TENSOR_F32;
126}
127
/// Per-thread call counter, compiled only for tests or when the
/// `test-components` feature is enabled.
///
/// The count lives in a `thread_local!`, so each thread observes only its
/// own bumps.
#[cfg(any(test, feature = "test-components"))]
mod dispatch_counter {
    use std::cell::Cell;

    thread_local! {
        static COUNT: Cell<usize> = const { Cell::new(0) };
    }

    /// Adds one to this thread's counter, wrapping around on overflow.
    pub fn bump() {
        COUNT.with(|counter| {
            let next = counter.get().wrapping_add(1);
            counter.set(next);
        });
    }

    /// Returns this thread's current count.
    pub fn read() -> usize {
        COUNT.with(Cell::get)
    }

    /// Sets this thread's counter back to zero.
    pub fn reset() {
        COUNT.with(|counter| counter.set(0));
    }
}
148
149#[cfg(any(test, feature = "test-components"))]
150pub use dispatch_counter::{read as dispatch_count, reset as reset_dispatch_count};
151
152impl bb_runtime::contracts::Backend for CpuBackend {
156 type Error = graph_walker::BackendError;
157 type Tensor = CpuTensor;
158
159 fn execute(
160 &self,
161 graph: &bb_ir::proto::onnx::GraphProto,
162 inputs: std::collections::HashMap<String, Self::Tensor>,
163 _attrs: bb_runtime::contracts::backend::BackendAttrs<'_>,
164 ) -> Result<std::collections::HashMap<String, Self::Tensor>, Self::Error> {
165 #[cfg(any(test, feature = "test-components"))]
166 dispatch_counter::bump();
167 graph_walker::execute_graph(self, graph, inputs)
168 }
169
170 fn materialize_from_wire(
174 &self,
175 type_hash: u64,
176 bytes: Vec<u8>,
177 ) -> Result<Self::Tensor, Self::Error> {
178 use bb_runtime::contracts::backend_default_walk::BackendWalkError;
179 let expected_hash = bb_ir::slot_value::type_hash_of::<CpuTensor>();
180 if type_hash != expected_hash {
181 return Err(graph_walker::BackendError::DefaultWalker(
182 BackendWalkError::WireMaterializeFailed {
183 type_hash,
184 reason: format!(
185 "expected CpuTensor type_hash {expected_hash:#018x}, got {type_hash:#018x}",
186 ),
187 },
188 ));
189 }
190 let charged_bytes = bytes.len();
191 let wire: CpuTensor = bincode::deserialize(&bytes).map_err(|e| {
194 graph_walker::BackendError::DefaultWalker(BackendWalkError::WireMaterializeFailed {
195 type_hash,
196 reason: format!("bincode decode: {e}"),
197 })
198 })?;
199 Ok(CpuTensor::from_wire_buffer(
202 wire.0.data.clone(),
203 charged_bytes,
204 ))
205 }
206}
207