use std::sync::Arc;
use serde::de::{self, MapAccess, Visitor};
use serde::ser::SerializeStruct;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use bb_ir::proto::onnx::TensorProto;
use bb_ir::tensor::{Tensor, TensorSerializationError};
use bb_ir::types::TYPE_TENSOR_F32;
use bb_ir::{register_charged_bytes, register_type_node};
use bb_runtime::slot_value::SlotValue;
use ndarray::{ArrayD, IxDyn};
// Ties `CpuTensor` to the `TYPE_TENSOR_F32` type descriptor through `bb_ir`'s
// `register_type_node!` macro.
// NOTE(review): the macro expansion is not visible here — confirm what the
// registration actually does.
register_type_node!(CpuTensor, &TYPE_TENSOR_F32);
// Supplies the accessor that `register_charged_bytes!` uses for `CpuTensor`:
// it reads the `charged_bytes` field of the tensor's shared buffer.
register_charged_bytes!(CpuTensor, |t: &CpuTensor| t.0.charged_bytes);
/// Element-type tag that `to_proto` writes into `TensorProto::data_type` and
/// that `from_proto` requires before decoding.
///
/// NOTE(review): presumably ONNX's `TensorProto.DataType.FLOAT` — confirm
/// against the ONNX schema.
pub const ONNX_FLOAT: i32 = 1;
/// Backing storage held behind the `Arc` inside a `CpuTensor`.
#[derive(Debug)]
pub struct CpuBackendBuffer {
/// Element data as a dynamic-rank `f32` array.
pub(crate) data: ArrayD<f32>,
/// Dimensions as `i64`; the constructors in this file derive it from
/// `data.shape()`.
pub(crate) dims_i64: Vec<i64>,
/// Byte count exposed through `register_charged_bytes!`. Within this file it
/// is 0 except for tensors built by `CpuTensor::from_wire_buffer`.
pub(crate) charged_bytes: usize,
}
/// Handle to a reference-counted `CpuBackendBuffer`.
///
/// The derived `Clone` clones the `Arc`, so clones share one buffer rather
/// than copying the element data.
#[derive(Clone, Debug)]
pub struct CpuTensor(pub(crate) Arc<CpuBackendBuffer>);
/// Errors produced by the fallible `CpuTensor` constructors.
#[derive(Debug)]
pub enum CpuTensorError {
    /// The number of elements implied by the dims differs from the number of
    /// elements supplied.
    ShapeMismatch {
        /// Element count implied by the requested dims.
        expected: usize,
        /// Element count actually supplied.
        got: usize,
    },
}

impl std::fmt::Display for CpuTensorError {
    /// Renders a one-line, human-readable description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The enum has a single variant, so this pattern is irrefutable; adding
        // a variant later turns this into a compile error rather than a silent gap.
        let Self::ShapeMismatch { expected, got } = self;
        write!(
            f,
            "CpuTensor shape mismatch: dims product {expected} ≠ data.len {got}",
        )
    }
}

impl std::error::Error for CpuTensorError {}
impl CpuTensor {
    /// Converts signed dims into an `ndarray` shape, clamping negative entries
    /// to 0 (so a negative dim yields an empty axis).
    fn shape_from_dims(dims: &[i64]) -> Vec<usize> {
        dims.iter().map(|&d| d.max(0) as usize).collect()
    }

    /// Wraps an existing array. The resulting tensor has `charged_bytes == 0`.
    pub fn from_array(data: ArrayD<f32>) -> Self {
        Self::from_wire_buffer(data, 0)
    }

    /// Alias for [`CpuTensor::new`].
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as [`CpuTensor::new`].
    pub fn from_vec(shape: Vec<i64>, data: Vec<f32>) -> Self {
        Self::new(shape, data)
    }

    /// Borrows the underlying array.
    pub fn as_array(&self) -> &ArrayD<f32> {
        &self.0.data
    }

    /// Consumes the handle and returns the underlying array.
    ///
    /// When this handle is the only owner of the buffer the array is moved out
    /// without copying; otherwise the shared array is cloned.
    pub fn into_array(self) -> ArrayD<f32> {
        match Arc::try_unwrap(self.0) {
            Ok(buffer) => buffer.data,
            Err(shared) => shared.data.clone(),
        }
    }

    /// Returns the stored `i64` dims.
    #[doc(hidden)]
    pub fn dims_vec(&self) -> &[i64] {
        &self.0.dims_i64
    }

    /// Copies every element, in the array's iteration order, into a new `Vec`.
    #[doc(hidden)]
    pub fn flat_data(&self) -> Vec<f32> {
        self.0.data.iter().copied().collect()
    }

    /// Builds a tensor from dims and flat data. Negative dims are clamped to 0.
    ///
    /// # Panics
    ///
    /// Panics if `ndarray` rejects the shape for `data` (for example when the
    /// element count does not match). Use [`CpuTensor::new_checked`] for a
    /// non-panicking variant.
    pub fn new(dims: Vec<i64>, data: Vec<f32>) -> Self {
        let array = ArrayD::from_shape_vec(IxDyn(&Self::shape_from_dims(&dims)), data)
            .expect("CpuTensor::new shape × data mismatch");
        Self::from_array(array)
    }

    /// Fallible counterpart of [`CpuTensor::new`].
    ///
    /// # Errors
    ///
    /// Returns [`CpuTensorError::ShapeMismatch`] when the product of `dims`
    /// (negatives clamped to 0) differs from `data.len()`, or when `ndarray`
    /// rejects the shape. In both cases `got` is the real length of `data`.
    pub fn new_checked(dims: Vec<i64>, data: Vec<f32>) -> Result<Self, CpuTensorError> {
        let expected = dims_product(&dims);
        // Capture the length up front: `data` is moved into `from_shape_vec`
        // below, and the error path still needs to report it.
        let got = data.len();
        if expected != got {
            return Err(CpuTensorError::ShapeMismatch { expected, got });
        }
        ArrayD::from_shape_vec(IxDyn(&Self::shape_from_dims(&dims)), data)
            .map(Self::from_array)
            .map_err(|_| CpuTensorError::ShapeMismatch { expected, got })
    }

    /// Builds a tensor of the given dims filled with `0.0`. Negative dims are
    /// clamped to 0.
    pub fn zeros(dims: Vec<i64>) -> Self {
        Self::from_array(ArrayD::zeros(IxDyn(&Self::shape_from_dims(&dims))))
    }

    /// Builds a tensor of the given dims filled with `1.0`. Negative dims are
    /// clamped to 0.
    pub fn ones(dims: Vec<i64>) -> Self {
        Self::from_array(ArrayD::ones(IxDyn(&Self::shape_from_dims(&dims))))
    }

    /// Number of strong `Arc` references currently pointing at the buffer.
    pub fn strong_count(&self) -> usize {
        Arc::strong_count(&self.0)
    }

    /// Wraps an array and records `charged_bytes` alongside it. The stored
    /// `i64` dims are derived from the array's shape.
    pub(crate) fn from_wire_buffer(data: ArrayD<f32>, charged_bytes: usize) -> Self {
        let dims_i64 = data.shape().iter().map(|&n| n as i64).collect();
        Self(Arc::new(CpuBackendBuffer {
            data,
            dims_i64,
            charged_bytes,
        }))
    }
}
/// Number of elements implied by `dims`.
///
/// Negative entries are clamped to 0, an empty slice yields 1, and the
/// multiplication saturates at `usize::MAX` instead of overflowing. A
/// saturated result can then never equal a real buffer length, so oversized
/// dims fail the caller's length comparison rather than panicking (debug
/// builds) or wrapping to a small number (release builds).
fn dims_product(dims: &[i64]) -> usize {
    dims.iter()
        .map(|&d| d.max(0) as usize)
        .fold(1usize, usize::saturating_mul)
}
impl std::fmt::Display for CpuTensor {
    /// Writes a short summary containing the array's shape and element count.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let array = &self.0.data;
        let (shape, element_count) = (array.shape(), array.len());
        write!(f, "CpuTensor(dims={:?}, len={})", shape, element_count)
    }
}
impl Tensor for CpuTensor {
    type Scalar = f32;

    /// Returns the stored `i64` dims.
    fn dims(&self) -> &[i64] {
        self.0.dims_i64.as_slice()
    }

    /// Total number of elements in the array.
    fn len(&self) -> usize {
        self.0.data.len()
    }

    /// Encodes the tensor as a `TensorProto` tagged `ONNX_FLOAT`, with the
    /// elements placed in `float_data` and every other field left at its default.
    fn to_proto(&self) -> TensorProto {
        let array = &self.0.data;
        TensorProto {
            dims: array.shape().iter().map(|&extent| extent as i64).collect(),
            data_type: ONNX_FLOAT,
            float_data: array.iter().copied().collect(),
            ..Default::default()
        }
    }

    /// Decodes a `TensorProto` into a tensor.
    ///
    /// Elements come from `float_data` when it is non-empty; otherwise
    /// `raw_data` is read as consecutive little-endian `f32` values. Negative
    /// dims are clamped to 0.
    ///
    /// # Errors
    ///
    /// * `ElementTypeMismatch` when `data_type` is not `ONNX_FLOAT`.
    /// * `ShapeError` when `raw_data` is not a multiple of 4 bytes, when the
    ///   dims product differs from the element count, or when `ndarray`
    ///   rejects the shape.
    fn from_proto(proto: TensorProto) -> Result<Self, TensorSerializationError> {
        let TensorProto {
            dims,
            data_type,
            float_data,
            raw_data,
            ..
        } = proto;

        if data_type != ONNX_FLOAT {
            return Err(TensorSerializationError::ElementTypeMismatch {
                expected: ONNX_FLOAT,
                found: data_type,
            });
        }

        let values: Vec<f32> = if !float_data.is_empty() {
            float_data
        } else {
            // An empty `raw_data` passes the divisibility check and decodes to
            // an empty vector, which covers the "no payload at all" case.
            if raw_data.len() % 4 != 0 {
                return Err(TensorSerializationError::ShapeError(format!(
                    "raw_data length {} not divisible by 4",
                    raw_data.len(),
                )));
            }
            raw_data
                .chunks_exact(4)
                .map(|word| f32::from_le_bytes([word[0], word[1], word[2], word[3]]))
                .collect()
        };

        let expected = dims_product(&dims);
        if expected != values.len() {
            return Err(TensorSerializationError::ShapeError(format!(
                "dims product {expected} doesn't match data len {len}",
                len = values.len()
            )));
        }

        let shape: Vec<usize> = dims.iter().map(|&d| d.max(0) as usize).collect();
        ArrayD::from_shape_vec(IxDyn(&shape), values)
            .map(Self::from_array)
            .map_err(|e| {
                TensorSerializationError::ShapeError(format!("ndarray::from_shape_vec: {e}"))
            })
    }
}
impl Serialize for CpuTensor {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
let mut s = serializer.serialize_struct("CpuTensor", 2)?;
s.serialize_field("data", &self.0.data)?;
s.serialize_field("dims_i64", &self.0.dims_i64)?;
s.end()
}
}
impl<'de> Deserialize<'de> for CpuTensor {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
#[derive(Deserialize)]
#[serde(field_identifier, rename_all = "snake_case")]
enum Field {
Data,
DimsI64,
}
struct CpuTensorVisitor;
impl<'de> Visitor<'de> for CpuTensorVisitor {
type Value = CpuTensor;
fn expecting(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str("struct CpuTensor")
}
fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
where
A: de::SeqAccess<'de>,
{
let data: ArrayD<f32> = seq.next_element()?.ok_or_else(|| {
de::Error::invalid_length(0, &"struct CpuTensor with 2 fields")
})?;
let dims_i64: Vec<i64> = seq.next_element()?.ok_or_else(|| {
de::Error::invalid_length(1, &"struct CpuTensor with 2 fields")
})?;
Ok(CpuTensor(Arc::new(CpuBackendBuffer {
data,
dims_i64,
charged_bytes: 0,
})))
}
fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
where
A: MapAccess<'de>,
{
let mut data: Option<ArrayD<f32>> = None;
let mut dims_i64: Option<Vec<i64>> = None;
while let Some(key) = map.next_key()? {
match key {
Field::Data => {
if data.is_some() {
return Err(de::Error::duplicate_field("data"));
}
data = Some(map.next_value()?);
}
Field::DimsI64 => {
if dims_i64.is_some() {
return Err(de::Error::duplicate_field("dims_i64"));
}
dims_i64 = Some(map.next_value()?);
}
}
}
let data = data.ok_or_else(|| de::Error::missing_field("data"))?;
let dims_i64 = dims_i64.ok_or_else(|| de::Error::missing_field("dims_i64"))?;
Ok(CpuTensor(Arc::new(CpuBackendBuffer {
data,
dims_i64,
charged_bytes: 0,
})))
}
}
const FIELDS: &[&str] = &["data", "dims_i64"];
deserializer.deserialize_struct("CpuTensor", FIELDS, CpuTensorVisitor)
}
}
// Compile-time check only: instantiating `_check::<CpuTensor>` requires
// `CpuTensor: SlotValue`, so the build fails if that bound stops holding.
// The closure is stored in an unnamed const and is never called from here.
const _: fn() = || {
fn _check<T: SlotValue>() {}
_check::<CpuTensor>();
};