use std::fmt::Debug;
use burn::{
Tensor,
prelude::{
Backend,
Shape,
},
tensor::{
BasicOps,
DType,
},
};
use crate::{
burner::descriptors::{
ParamDesc,
ParamKindBinding,
TensorKindDesc,
},
errors::{
BunsenError,
BunsenResult,
},
};
/// Descriptor of a tensor: its kind, element dtype, and shape.
///
/// Holds metadata only; no tensor data is stored in this struct.
#[derive(Debug, Clone, PartialEq)]
pub struct TensorDesc {
/// The tensor kind descriptor.
kind: TensorKindDesc,
/// The element data type.
dtype: DType,
/// The tensor shape.
shape: Shape,
}
impl<B, const R: usize, K> From<&Tensor<B, R, K>> for TensorDesc
where
B: Backend,
K: BasicOps<B>,
K: ParamKindBinding,
{
fn from(param: &Tensor<B, R, K>) -> Self {
Self {
kind: TensorKindDesc::for_kind::<K>(),
dtype: param.dtype(),
shape: param.shape(),
}
}
}
/// Parses a dtype name into the matching [`DType`].
///
/// Matching is exact (case-sensitive). The recognised names are:
/// `F64`, `F32`, `Flex32`, `F16`, `BF16`, `I64`, `I32`, `I16`, `I8`,
/// `U64`, `U32`, `U16` and `U8`.
///
/// # Errors
///
/// Returns [`BunsenError::External`] carrying an `Invalid dtype: …` message
/// when `dtype` is not one of the names listed above.
pub fn dtype_from_str(dtype: &str) -> BunsenResult<DType> {
    let known = match dtype {
        // Floating-point types.
        "F64" => Some(DType::F64),
        "F32" => Some(DType::F32),
        "Flex32" => Some(DType::Flex32),
        "F16" => Some(DType::F16),
        "BF16" => Some(DType::BF16),
        // Signed integers.
        "I64" => Some(DType::I64),
        "I32" => Some(DType::I32),
        "I16" => Some(DType::I16),
        "I8" => Some(DType::I8),
        // Unsigned integers.
        "U64" => Some(DType::U64),
        "U32" => Some(DType::U32),
        "U16" => Some(DType::U16),
        "U8" => Some(DType::U8),
        _ => None,
    };

    // The error message is only built on the failure path.
    known.ok_or_else(|| BunsenError::External(format!("Invalid dtype: {}", dtype)))
}
impl TensorDesc {
pub fn new(
kind: TensorKindDesc,
dtype: DType,
shape: Shape,
) -> Self {
Self { kind, dtype, shape }
}
pub fn kind(&self) -> TensorKindDesc {
self.kind
}
pub fn dtype(&self) -> DType {
self.dtype
}
pub fn shape(&self) -> &Shape {
&self.shape
}
pub fn rank(&self) -> usize {
self.shape.rank()
}
pub fn num_elements(&self) -> usize {
self.shape.num_elements()
}
pub fn size_estimate(&self) -> usize {
self.dtype.size() * self.num_elements()
}
}
/// Alias for a [`ParamDesc`] whose type parameter is [`TensorDesc`].
pub type TensorParamDesc = ParamDesc<TensorDesc>;
#[cfg(test)]
// The only test here is gated on the `cuda` feature; when that feature is
// off, the glob import below has no users.
#[allow(unused)]
mod tests {
    use super::*;

    /// Checks the descriptor built from 2x3 float, bool and int tensors.
    #[test]
    #[cfg(feature = "cuda")]
    fn test_tensor_desc() {
        type B = burn::backend::Cuda;

        // Shared assertions for a 2x3 tensor of any kind: the descriptor
        // must report the expected kind and dtype, the [2, 3] shape, rank 2,
        // and a size estimate of six elements of `dtype`.
        fn check<K>(
            tensor: &Tensor<B, 2, K>,
            kind: TensorKindDesc,
            dtype: DType,
        ) where
            K: BasicOps<B> + ParamKindBinding,
        {
            let desc = TensorDesc::from(tensor);
            assert_eq!(desc.kind, kind);
            assert_eq!(desc.dtype, dtype);
            assert_eq!(desc.shape, Shape::new([2, 3]));
            assert_eq!(desc.rank(), 2);
            assert_eq!(desc.size_estimate(), dtype.size() * 2 * 3);
        }

        let device = Default::default();

        // Float tensor: the dtype is pinned to F32.
        {
            let floats: Tensor<B, 2> = Tensor::ones([2, 3], &device);
            check(&floats, TensorKindDesc::Float, DType::F32);
        }

        // Bool tensor: the expected dtype is whatever the tensor reports.
        {
            let bools: Tensor<B, 2, burn::tensor::Bool> = Tensor::zeros([2, 3], &device);
            check(&bools, TensorKindDesc::Bool, bools.dtype());
        }

        // Int tensor: the expected dtype is whatever the tensor reports.
        {
            let ints: Tensor<B, 2, burn::tensor::Int> = Tensor::zeros([2, 3], &device);
            check(&ints, TensorKindDesc::Int, ints.dtype());
        }
    }
}