pub use crate::backend::f16_convert::{f16_to_f32, f32_to_f16};
pub fn tensor_to_fp16_bits(t: &crate::object::Tensor<f32>) -> Vec<u16> {
t.data.iter().copied().map(f32_to_f16).collect()
}
pub fn fp16_bits_to_tensor(
bits: &[u16],
shape: crate::object::Shape,
domain: crate::domain::DomainId,
) -> crate::object::Tensor<f32> {
let data = bits.iter().copied().map(f16_to_f32).collect();
crate::object::Tensor::dense_cpu(domain, shape, data)
}
/// Creates a tensor whose data buffer is filled with copies of `value`.
///
/// The buffer length is the product of all dimensions in `shape.dims`:
/// a `Dim::Static(v)` contributes `v`, and any other `Dim` variant
/// contributes `0`. Consequently a shape containing a non-static dimension
/// yields an empty buffer, while a shape with no dimensions yields a single
/// element (the empty product is `1`).
///
/// `shape` and `domain` are passed unchanged to `Tensor::dense_cpu`.
pub fn tensor_full(
    shape: crate::object::Shape,
    value: f32,
    domain: crate::domain::DomainId,
) -> crate::object::Tensor<f32> {
    let mut element_count: usize = 1;
    for dim in shape.dims.iter() {
        let extent = match dim {
            crate::object::Dim::Static(size) => *size,
            // Any non-static dimension collapses the element count to zero.
            _ => 0,
        };
        element_count *= extent;
    }
    let data = vec![value; element_count];
    crate::object::Tensor::dense_cpu(domain, shape, data)
}
pub fn tensor_add(
a: &crate::object::Tensor<f32>,
b: &crate::object::Tensor<f32>,
) -> Result<crate::object::Tensor<f32>, crate::Error> {
if a.data.len() != b.data.len() {
return Err(crate::Error::shape(format!(
"tensor_add length mismatch: {} vs {}",
a.data.len(),
b.data.len()
)));
}
let data: Vec<f32> = a
.data
.iter()
.zip(b.data.iter())
.map(|(x, y)| x + y)
.collect();
Ok(crate::object::Tensor::dense_cpu(
a.meta.domain.clone(),
a.meta.shape.clone(),
data,
))
}