1use std::sync::Arc;
17
18use serde::de::{self, MapAccess, Visitor};
19use serde::ser::SerializeStruct;
20use serde::{Deserialize, Deserializer, Serialize, Serializer};
21
22use bb_ir::proto::onnx::TensorProto;
23use bb_ir::tensor::{Tensor, TensorSerializationError};
24use bb_ir::types::TYPE_TENSOR_F32;
25use bb_ir::{register_charged_bytes, register_type_node};
26use bb_runtime::slot_value::SlotValue;
27use ndarray::{ArrayD, IxDyn};
28
// Associates `CpuTensor` with the `TYPE_TENSOR_F32` type descriptor imported from `bb_ir::types`.
// NOTE(review): what the registration does at runtime is defined by the macro in `bb_ir` — confirm there.
register_type_node!(CpuTensor, &TYPE_TENSOR_F32);
// Supplies the accessor used to report a tensor's charged byte count: it reads the
// `charged_bytes` field of the shared backing buffer.
register_charged_bytes!(CpuTensor, |t: &CpuTensor| t.0.charged_bytes);
35
/// Element-type tag stored in `TensorProto::data_type` by `to_proto` and required by
/// `from_proto`. The value `1` is `FLOAT` (32-bit float) in ONNX's `TensorProto.DataType`.
pub const ONNX_FLOAT: i32 = 1;
38
/// Backing storage shared by every clone of a [`CpuTensor`].
#[derive(Debug)]
pub struct CpuBackendBuffer {
    /// The tensor elements, held as a dynamically-ranked `ndarray` array of `f32`.
    pub(crate) data: ArrayD<f32>,
    /// The shape expressed as `i64` extents. The constructors in this file derive it
    /// from `data.shape()` so it can be handed out as a borrowed slice.
    pub(crate) dims_i64: Vec<i64>,
    /// Byte count reported through `register_charged_bytes!`. It is `0` for tensors built
    /// by `from_array` or deserialization and caller-supplied for `from_wire_buffer`.
    // NOTE(review): the accounting scheme that consumes this value is not visible here — confirm.
    pub(crate) charged_bytes: usize,
}
57
/// An `f32` tensor whose storage lives behind an [`Arc`].
///
/// `Clone` is derived, so cloning a `CpuTensor` clones the `Arc` handle (bumping the
/// reference count) rather than copying the element data.
#[derive(Clone, Debug)]
pub struct CpuTensor(pub(crate) Arc<CpuBackendBuffer>);
66
/// Failures reported by the fallible `CpuTensor` constructors.
#[derive(Debug)]
pub enum CpuTensorError {
    /// The element count implied by the requested dims disagrees with the data supplied.
    ShapeMismatch {
        /// Element count implied by the dims.
        expected: usize,
        /// Element count reported for the data buffer.
        got: usize,
    },
}

impl std::fmt::Display for CpuTensorError {
    /// Renders a one-line, human-readable description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The enum has a single variant, so this destructuring cannot fail; adding a
        // variant later turns this line into a compile error instead of a silent gap.
        let Self::ShapeMismatch { expected, got } = self;
        f.write_fmt(format_args!(
            "CpuTensor shape mismatch: dims product {expected} ≠ data.len {got}",
        ))
    }
}

impl std::error::Error for CpuTensorError {}
92
93impl CpuTensor {
94 pub fn from_array(data: ArrayD<f32>) -> Self {
100 let dims_i64 = data.shape().iter().map(|&n| n as i64).collect();
101 Self(Arc::new(CpuBackendBuffer {
102 data,
103 dims_i64,
104 charged_bytes: 0,
105 }))
106 }
107
108 pub fn from_vec(shape: Vec<i64>, data: Vec<f32>) -> Self {
113 Self::new(shape, data)
114 }
115
116 pub fn as_array(&self) -> &ArrayD<f32> {
118 &self.0.data
119 }
120
121 pub fn into_array(self) -> ArrayD<f32> {
127 self.0.data.clone()
128 }
129
130 #[doc(hidden)]
134 pub fn dims_vec(&self) -> &[i64] {
135 &self.0.dims_i64
136 }
137
138 #[doc(hidden)]
142 pub fn flat_data(&self) -> Vec<f32> {
143 self.0.data.iter().copied().collect()
144 }
145
146 pub fn new(dims: Vec<i64>, data: Vec<f32>) -> Self {
150 let shape: Vec<usize> = dims.iter().map(|&d| d.max(0) as usize).collect();
151 let array = ArrayD::from_shape_vec(IxDyn(&shape), data)
152 .expect("CpuTensor::new shape × data mismatch");
153 Self::from_array(array)
154 }
155
156 pub fn new_checked(dims: Vec<i64>, data: Vec<f32>) -> Result<Self, CpuTensorError> {
158 let expected = dims_product(&dims);
159 if expected != data.len() {
160 return Err(CpuTensorError::ShapeMismatch {
161 expected,
162 got: data.len(),
163 });
164 }
165 let shape: Vec<usize> = dims.iter().map(|&d| d.max(0) as usize).collect();
166 let array = ArrayD::from_shape_vec(IxDyn(&shape), data)
167 .map_err(|_| CpuTensorError::ShapeMismatch { expected, got: 0 })?;
168 Ok(Self::from_array(array))
169 }
170
171 pub fn zeros(dims: Vec<i64>) -> Self {
173 let shape: Vec<usize> = dims.iter().map(|&d| d.max(0) as usize).collect();
174 Self::from_array(ArrayD::zeros(IxDyn(&shape)))
175 }
176
177 pub fn ones(dims: Vec<i64>) -> Self {
179 let shape: Vec<usize> = dims.iter().map(|&d| d.max(0) as usize).collect();
180 Self::from_array(ArrayD::ones(IxDyn(&shape)))
181 }
182
183 pub fn strong_count(&self) -> usize {
190 Arc::strong_count(&self.0)
191 }
192
193 pub(crate) fn from_wire_buffer(data: ArrayD<f32>, charged_bytes: usize) -> Self {
199 let dims_i64 = data.shape().iter().map(|&n| n as i64).collect();
200 Self(Arc::new(CpuBackendBuffer {
201 data,
202 dims_i64,
203 charged_bytes,
204 }))
205 }
206}
207
/// Number of elements implied by `dims`.
///
/// Negative extents count as `0`, and an empty `dims` yields `1` (the empty product).
/// The multiplication saturates at `usize::MAX` instead of overflowing, so oversized
/// dims can never wrap around to a small value that happens to match a real buffer
/// length (and never trigger the debug-build overflow panic of `Iterator::product`).
fn dims_product(dims: &[i64]) -> usize {
    dims.iter()
        // Lossless conversion: an extent that does not fit in `usize` saturates rather
        // than being truncated by an `as` cast.
        .map(|&d| usize::try_from(d.max(0)).unwrap_or(usize::MAX))
        .fold(1usize, usize::saturating_mul)
}
211
212impl std::fmt::Display for CpuTensor {
213 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
214 write!(
215 f,
216 "CpuTensor(dims={:?}, len={})",
217 self.0.data.shape(),
218 self.0.data.len(),
219 )
220 }
221}
222
223impl Tensor for CpuTensor {
224 type Scalar = f32;
225
226 fn dims(&self) -> &[i64] {
227 &self.0.dims_i64
228 }
229
230 fn len(&self) -> usize {
231 self.0.data.len()
232 }
233
234 fn to_proto(&self) -> TensorProto {
235 let dims: Vec<i64> = self.0.data.shape().iter().map(|&n| n as i64).collect();
236 let float_data: Vec<f32> = self.0.data.iter().copied().collect();
237 TensorProto {
238 dims,
239 data_type: ONNX_FLOAT,
240 float_data,
241 ..Default::default()
242 }
243 }
244
245 fn from_proto(proto: TensorProto) -> Result<Self, TensorSerializationError> {
246 if proto.data_type != ONNX_FLOAT {
247 return Err(TensorSerializationError::ElementTypeMismatch {
248 expected: ONNX_FLOAT,
249 found: proto.data_type,
250 });
251 }
252 let data = if !proto.float_data.is_empty() {
255 proto.float_data
256 } else if !proto.raw_data.is_empty() {
257 if proto.raw_data.len() % 4 != 0 {
258 return Err(TensorSerializationError::ShapeError(format!(
259 "raw_data length {} not divisible by 4",
260 proto.raw_data.len(),
261 )));
262 }
263 let mut out = Vec::with_capacity(proto.raw_data.len() / 4);
264 for chunk in proto.raw_data.chunks_exact(4) {
265 out.push(f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]));
266 }
267 out
268 } else {
269 Vec::new()
270 };
271 let expected = dims_product(&proto.dims);
272 if expected != data.len() {
273 return Err(TensorSerializationError::ShapeError(format!(
274 "dims product {expected} doesn't match data len {len}",
275 len = data.len()
276 )));
277 }
278 let shape: Vec<usize> = proto.dims.iter().map(|&d| d.max(0) as usize).collect();
279 let array = ArrayD::from_shape_vec(IxDyn(&shape), data).map_err(|e| {
280 TensorSerializationError::ShapeError(format!("ndarray::from_shape_vec: {e}"))
281 })?;
282 Ok(Self::from_array(array))
283 }
284}
285
286impl Serialize for CpuTensor {
293 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
294 where
295 S: Serializer,
296 {
297 let mut s = serializer.serialize_struct("CpuTensor", 2)?;
298 s.serialize_field("data", &self.0.data)?;
299 s.serialize_field("dims_i64", &self.0.dims_i64)?;
300 s.end()
301 }
302}
303
304impl<'de> Deserialize<'de> for CpuTensor {
305 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
306 where
307 D: Deserializer<'de>,
308 {
309 #[derive(Deserialize)]
310 #[serde(field_identifier, rename_all = "snake_case")]
311 enum Field {
312 Data,
313 DimsI64,
314 }
315
316 struct CpuTensorVisitor;
317
318 impl<'de> Visitor<'de> for CpuTensorVisitor {
319 type Value = CpuTensor;
320
321 fn expecting(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
322 f.write_str("struct CpuTensor")
323 }
324
325 fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
326 where
327 A: de::SeqAccess<'de>,
328 {
329 let data: ArrayD<f32> = seq.next_element()?.ok_or_else(|| {
330 de::Error::invalid_length(0, &"struct CpuTensor with 2 fields")
331 })?;
332 let dims_i64: Vec<i64> = seq.next_element()?.ok_or_else(|| {
333 de::Error::invalid_length(1, &"struct CpuTensor with 2 fields")
334 })?;
335 Ok(CpuTensor(Arc::new(CpuBackendBuffer {
336 data,
337 dims_i64,
338 charged_bytes: 0,
339 })))
340 }
341
342 fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
343 where
344 A: MapAccess<'de>,
345 {
346 let mut data: Option<ArrayD<f32>> = None;
347 let mut dims_i64: Option<Vec<i64>> = None;
348 while let Some(key) = map.next_key()? {
349 match key {
350 Field::Data => {
351 if data.is_some() {
352 return Err(de::Error::duplicate_field("data"));
353 }
354 data = Some(map.next_value()?);
355 }
356 Field::DimsI64 => {
357 if dims_i64.is_some() {
358 return Err(de::Error::duplicate_field("dims_i64"));
359 }
360 dims_i64 = Some(map.next_value()?);
361 }
362 }
363 }
364 let data = data.ok_or_else(|| de::Error::missing_field("data"))?;
365 let dims_i64 = dims_i64.ok_or_else(|| de::Error::missing_field("dims_i64"))?;
366 Ok(CpuTensor(Arc::new(CpuBackendBuffer {
367 data,
368 dims_i64,
369 charged_bytes: 0,
370 })))
371 }
372 }
373
374 const FIELDS: &[&str] = &["data", "dims_i64"];
375 deserializer.deserialize_struct("CpuTensor", FIELDS, CpuTensorVisitor)
376 }
377}
378
// Compile-time check that `CpuTensor` satisfies the `SlotValue` bound. The closure is
// never called; it only has to type-check, so the build fails here if `CpuTensor`
// stops implementing `SlotValue`.
const _: fn() = || {
    // Generic helper whose sole purpose is to carry the `SlotValue` bound.
    fn _check<T: SlotValue>() {}
    _check::<CpuTensor>();
};
387