use crate::dtype::DType;
use crate::error::{Error, Result};
use crate::runtime::cpu::{CpuDevice, CpuRuntime};
use crate::tensor::Tensor;
/// Maps `f` over every element of `x`, routing each value through `f64`.
///
/// Each element is widened to `f64`, handed to `f`, and the result is narrowed
/// back to the tensor's own element type. The output tensor is built with
/// `Tensor::from_slice` from the mapped values, `x`'s shape and `device`.
///
/// Supported dtypes: `F32` and `F64` unconditionally, `F16`/`BF16` behind the
/// `f16` feature, and `FP8E4M3`/`FP8E5M2` behind the `fp8` feature. The FP8
/// types are converted through `f32` on both sides of the `f64` call.
///
/// This function itself never produces an `Err`; every handled dtype yields
/// `Ok`.
///
/// # Panics
///
/// Hits `unreachable!` for any dtype not listed above; the caller is expected
/// to have validated the dtype beforehand.
fn apply_unary_via_f64<F>(
    x: &Tensor<CpuRuntime>,
    device: &CpuDevice,
    f: F,
) -> Result<Tensor<CpuRuntime>>
where
    F: Fn(f64) -> f64,
{
    let shape = x.shape();
    match x.dtype() {
        DType::F32 => {
            let input: Vec<f32> = x.to_vec();
            // f32 -> f64 is lossless; the narrowing cast back may round.
            let output: Vec<f32> = input
                .into_iter()
                .map(|elem| f(f64::from(elem)) as f32)
                .collect();
            Ok(Tensor::from_slice(&output, shape, device))
        }
        DType::F64 => {
            let input: Vec<f64> = x.to_vec();
            let output: Vec<f64> = input.into_iter().map(&f).collect();
            Ok(Tensor::from_slice(&output, shape, device))
        }
        #[cfg(feature = "f16")]
        DType::F16 => {
            let input: Vec<half::f16> = x.to_vec();
            let output: Vec<half::f16> = input
                .into_iter()
                .map(|elem| {
                    let mapped = f(elem.to_f64());
                    half::f16::from_f64(mapped)
                })
                .collect();
            Ok(Tensor::from_slice(&output, shape, device))
        }
        #[cfg(feature = "f16")]
        DType::BF16 => {
            let input: Vec<half::bf16> = x.to_vec();
            let output: Vec<half::bf16> = input
                .into_iter()
                .map(|elem| {
                    let mapped = f(elem.to_f64());
                    half::bf16::from_f64(mapped)
                })
                .collect();
            Ok(Tensor::from_slice(&output, shape, device))
        }
        #[cfg(feature = "fp8")]
        DType::FP8E4M3 => {
            let input: Vec<crate::dtype::FP8E4M3> = x.to_vec();
            let output: Vec<crate::dtype::FP8E4M3> = input
                .into_iter()
                .map(|elem| {
                    // FP8 values travel through f32 on the way in and out.
                    let widened = elem.to_f32() as f64;
                    crate::dtype::FP8E4M3::from_f32(f(widened) as f32)
                })
                .collect();
            Ok(Tensor::from_slice(&output, shape, device))
        }
        #[cfg(feature = "fp8")]
        DType::FP8E5M2 => {
            let input: Vec<crate::dtype::FP8E5M2> = x.to_vec();
            let output: Vec<crate::dtype::FP8E5M2> = input
                .into_iter()
                .map(|elem| {
                    // FP8 values travel through f32 on the way in and out.
                    let widened = elem.to_f32() as f64;
                    crate::dtype::FP8E5M2::from_f32(f(widened) as f32)
                })
                .collect();
            Ok(Tensor::from_slice(&output, shape, device))
        }
        _ => unreachable!("dtype validated by caller"),
    }
}
fn apply_binary_via_f64<F>(
a: &Tensor<CpuRuntime>,
b: &Tensor<CpuRuntime>,
device: &CpuDevice,
f: F,
) -> Result<Tensor<CpuRuntime>>
where
F: Fn(f64, f64) -> f64,
{
if a.shape() != b.shape() {
return Err(Error::ShapeMismatch {
expected: a.shape().to_vec(),
got: b.shape().to_vec(),
});
}
match a.dtype() {
DType::F32 => {
let a_data: Vec<f32> = a.to_vec();
let b_data: Vec<f32> = b.to_vec();
let result: Vec<f32> = a_data
.iter()
.zip(b_data.iter())
.map(|(&av, &bv)| f(av as f64, bv as f64) as f32)
.collect();
Ok(Tensor::from_slice(&result, a.shape(), device))
}
DType::F64 => {
let a_data: Vec<f64> = a.to_vec();
let b_data: Vec<f64> = b.to_vec();
let result: Vec<f64> = a_data
.iter()
.zip(b_data.iter())
.map(|(&av, &bv)| f(av, bv))
.collect();
Ok(Tensor::from_slice(&result, a.shape(), device))
}
#[cfg(feature = "f16")]
DType::F16 => {
let a_data: Vec<half::f16> = a.to_vec();
let b_data: Vec<half::f16> = b.to_vec();
let result: Vec<half::f16> = a_data
.iter()
.zip(b_data.iter())
.map(|(&av, &bv)| half::f16::from_f64(f(av.to_f64(), bv.to_f64())))
.collect();
Ok(Tensor::from_slice(&result, a.shape(), device))
}
#[cfg(feature = "f16")]
DType::BF16 => {
let a_data: Vec<half::bf16> = a.to_vec();
let b_data: Vec<half::bf16> = b.to_vec();
let result: Vec<half::bf16> = a_data
.iter()
.zip(b_data.iter())
.map(|(&av, &bv)| half::bf16::from_f64(f(av.to_f64(), bv.to_f64())))
.collect();
Ok(Tensor::from_slice(&result, a.shape(), device))
}
#[cfg(feature = "fp8")]
DType::FP8E4M3 => {
let a_data: Vec<crate::dtype::FP8E4M3> = a.to_vec();
let b_data: Vec<crate::dtype::FP8E4M3> = b.to_vec();
let result: Vec<crate::dtype::FP8E4M3> =
a_data
.iter()
.zip(b_data.iter())
.map(|(&av, &bv)| {
crate::dtype::FP8E4M3::from_f32(
f(av.to_f32() as f64, bv.to_f32() as f64) as f32
)
})
.collect();
Ok(Tensor::from_slice(&result, a.shape(), device))
}
#[cfg(feature = "fp8")]
DType::FP8E5M2 => {
let a_data: Vec<crate::dtype::FP8E5M2> = a.to_vec();
let b_data: Vec<crate::dtype::FP8E5M2> = b.to_vec();
let result: Vec<crate::dtype::FP8E5M2> =
a_data
.iter()
.zip(b_data.iter())
.map(|(&av, &bv)| {
crate::dtype::FP8E5M2::from_f32(
f(av.to_f32() as f64, bv.to_f32() as f64) as f32
)
})
.collect();
Ok(Tensor::from_slice(&result, a.shape(), device))
}
_ => unreachable!("dtype validated by caller"),
}
}
/// Applies the scalar function `f` to every element of `x`.
///
/// Thin public entry point that forwards directly to `apply_unary_via_f64`,
/// so elements are evaluated in `f64` and converted back to `x`'s dtype.
///
/// # Panics
///
/// Panics (via `unreachable!` in the delegate) when `x` has a dtype the
/// delegate does not handle; callers must validate the dtype first.
pub fn apply_unary<F: Fn(f64) -> f64>(
    x: &Tensor<CpuRuntime>,
    device: &CpuDevice,
    f: F,
) -> Result<Tensor<CpuRuntime>> {
    apply_unary_via_f64(x, device, f)
}
/// Applies the two-argument scalar function `f` to paired elements of `a` and `b`.
///
/// Thin public entry point that forwards directly to `apply_binary_via_f64`,
/// so elements are evaluated in `f64` and converted back to `a`'s dtype.
///
/// # Errors
///
/// Returns `Error::ShapeMismatch` (from the delegate) when `a` and `b` do not
/// have the same shape.
///
/// # Panics
///
/// Panics (via `unreachable!` in the delegate) when `a` has a dtype the
/// delegate does not handle; callers must validate the dtype first.
pub fn apply_binary<F: Fn(f64, f64) -> f64>(
    a: &Tensor<CpuRuntime>,
    b: &Tensor<CpuRuntime>,
    device: &CpuDevice,
    f: F,
) -> Result<Tensor<CpuRuntime>> {
    apply_binary_via_f64(a, b, device, f)
}
/// Applies the three-argument scalar function `f` to aligned elements of `a`,
/// `b` and `c`, computing in `f64`.
///
/// Each triple of elements is widened to `f64`, passed to `f`, and the result
/// is narrowed back to `a`'s element type. The output tensor is built with
/// `Tensor::from_slice` from the combined values, `a`'s shape and `device`.
///
/// Supported dtypes match the unary and binary helpers in this file: `F32`
/// and `F64` unconditionally, `F16`/`BF16` behind the `f16` feature, and
/// `FP8E4M3`/`FP8E5M2` behind the `fp8` feature. The FP8 types are converted
/// through `f32` on both sides of the `f64` call.
///
/// NOTE(review): only `a.dtype()` is inspected; `b` and `c` are read out as
/// the same element type. Presumably callers guarantee matching dtypes —
/// confirm.
///
/// # Errors
///
/// Returns `Error::ShapeMismatch` when the shape of `b` or `c` differs from
/// the shape of `a` (`b` is checked first); no broadcasting is attempted.
///
/// # Panics
///
/// Hits `unreachable!` for any dtype not listed above; the caller is expected
/// to have validated the dtype beforehand.
pub fn apply_ternary<F>(
    a: &Tensor<CpuRuntime>,
    b: &Tensor<CpuRuntime>,
    c: &Tensor<CpuRuntime>,
    device: &CpuDevice,
    f: F,
) -> Result<Tensor<CpuRuntime>>
where
    F: Fn(f64, f64, f64) -> f64,
{
    if a.shape() != b.shape() {
        return Err(Error::ShapeMismatch {
            expected: a.shape().to_vec(),
            got: b.shape().to_vec(),
        });
    }
    if a.shape() != c.shape() {
        return Err(Error::ShapeMismatch {
            expected: a.shape().to_vec(),
            got: c.shape().to_vec(),
        });
    }
    match a.dtype() {
        DType::F32 => {
            let a_data: Vec<f32> = a.to_vec();
            let b_data: Vec<f32> = b.to_vec();
            let c_data: Vec<f32> = c.to_vec();
            let result: Vec<f32> = a_data
                .iter()
                .zip(b_data.iter())
                .zip(c_data.iter())
                .map(|((&av, &bv), &cv)| f(av as f64, bv as f64, cv as f64) as f32)
                .collect();
            Ok(Tensor::from_slice(&result, a.shape(), device))
        }
        DType::F64 => {
            let a_data: Vec<f64> = a.to_vec();
            let b_data: Vec<f64> = b.to_vec();
            let c_data: Vec<f64> = c.to_vec();
            let result: Vec<f64> = a_data
                .iter()
                .zip(b_data.iter())
                .zip(c_data.iter())
                .map(|((&av, &bv), &cv)| f(av, bv, cv))
                .collect();
            Ok(Tensor::from_slice(&result, a.shape(), device))
        }
        // The reduced-precision arms below mirror the unary/binary helpers so
        // that ternary ops do not panic on dtypes the sibling helpers accept.
        #[cfg(feature = "f16")]
        DType::F16 => {
            let a_data: Vec<half::f16> = a.to_vec();
            let b_data: Vec<half::f16> = b.to_vec();
            let c_data: Vec<half::f16> = c.to_vec();
            let result: Vec<half::f16> = a_data
                .iter()
                .zip(b_data.iter())
                .zip(c_data.iter())
                .map(|((&av, &bv), &cv)| {
                    half::f16::from_f64(f(av.to_f64(), bv.to_f64(), cv.to_f64()))
                })
                .collect();
            Ok(Tensor::from_slice(&result, a.shape(), device))
        }
        #[cfg(feature = "f16")]
        DType::BF16 => {
            let a_data: Vec<half::bf16> = a.to_vec();
            let b_data: Vec<half::bf16> = b.to_vec();
            let c_data: Vec<half::bf16> = c.to_vec();
            let result: Vec<half::bf16> = a_data
                .iter()
                .zip(b_data.iter())
                .zip(c_data.iter())
                .map(|((&av, &bv), &cv)| {
                    half::bf16::from_f64(f(av.to_f64(), bv.to_f64(), cv.to_f64()))
                })
                .collect();
            Ok(Tensor::from_slice(&result, a.shape(), device))
        }
        #[cfg(feature = "fp8")]
        DType::FP8E4M3 => {
            let a_data: Vec<crate::dtype::FP8E4M3> = a.to_vec();
            let b_data: Vec<crate::dtype::FP8E4M3> = b.to_vec();
            let c_data: Vec<crate::dtype::FP8E4M3> = c.to_vec();
            let result: Vec<crate::dtype::FP8E4M3> = a_data
                .iter()
                .zip(b_data.iter())
                .zip(c_data.iter())
                .map(|((&av, &bv), &cv)| {
                    // FP8 values travel through f32 on the way in and out.
                    crate::dtype::FP8E4M3::from_f32(
                        f(av.to_f32() as f64, bv.to_f32() as f64, cv.to_f32() as f64) as f32,
                    )
                })
                .collect();
            Ok(Tensor::from_slice(&result, a.shape(), device))
        }
        #[cfg(feature = "fp8")]
        DType::FP8E5M2 => {
            let a_data: Vec<crate::dtype::FP8E5M2> = a.to_vec();
            let b_data: Vec<crate::dtype::FP8E5M2> = b.to_vec();
            let c_data: Vec<crate::dtype::FP8E5M2> = c.to_vec();
            let result: Vec<crate::dtype::FP8E5M2> = a_data
                .iter()
                .zip(b_data.iter())
                .zip(c_data.iter())
                .map(|((&av, &bv), &cv)| {
                    // FP8 values travel through f32 on the way in and out.
                    crate::dtype::FP8E5M2::from_f32(
                        f(av.to_f32() as f64, bv.to_f32() as f64, cv.to_f32() as f64) as f32,
                    )
                })
                .collect();
            Ok(Tensor::from_slice(&result, a.shape(), device))
        }
        _ => unreachable!("dtype validated by caller"),
    }
}
/// Applies `f(n, v)` to every element `v` of `x`, with the integer `n` held fixed.
///
/// Binds `n` into a one-argument closure and delegates to
/// `apply_unary_via_f64`, so elements are evaluated in `f64` and converted
/// back to `x`'s dtype.
///
/// # Panics
///
/// Panics (via `unreachable!` in the delegate) when `x` has a dtype the
/// delegate does not handle; callers must validate the dtype first.
pub fn apply_unary_with_int<F>(
    x: &Tensor<CpuRuntime>,
    device: &CpuDevice,
    n: i32,
    f: F,
) -> Result<Tensor<CpuRuntime>>
where
    F: Fn(i32, f64) -> f64,
{
    let with_n = |value: f64| f(n, value);
    apply_unary_via_f64(x, device, with_n)
}
/// Applies `f(n, m, v)` to every element `v` of `x`, with the integers `n`
/// and `m` held fixed.
///
/// Binds `n` and `m` into a one-argument closure and delegates to
/// `apply_unary_via_f64`, so elements are evaluated in `f64` and converted
/// back to `x`'s dtype.
///
/// # Panics
///
/// Panics (via `unreachable!` in the delegate) when `x` has a dtype the
/// delegate does not handle; callers must validate the dtype first.
pub fn apply_unary_with_two_ints<F>(
    x: &Tensor<CpuRuntime>,
    device: &CpuDevice,
    n: i32,
    m: i32,
    f: F,
) -> Result<Tensor<CpuRuntime>>
where
    F: Fn(i32, i32, f64) -> f64,
{
    let with_orders = |value: f64| f(n, m, value);
    apply_unary_via_f64(x, device, with_orders)
}
/// Applies `f(n, m, av, bv)` to paired elements `av`/`bv` of `a` and `b`, with
/// the integers `n` and `m` held fixed.
///
/// Binds `n` and `m` into a two-argument closure and delegates to
/// `apply_binary_via_f64`, so elements are evaluated in `f64` and converted
/// back to `a`'s dtype.
///
/// # Errors
///
/// Returns `Error::ShapeMismatch` (from the delegate) when `a` and `b` do not
/// have the same shape.
///
/// # Panics
///
/// Panics (via `unreachable!` in the delegate) when `a` has a dtype the
/// delegate does not handle; callers must validate the dtype first.
pub fn apply_binary_with_two_ints<F>(
    a: &Tensor<CpuRuntime>,
    b: &Tensor<CpuRuntime>,
    device: &CpuDevice,
    n: i32,
    m: i32,
    f: F,
) -> Result<Tensor<CpuRuntime>>
where
    F: Fn(i32, i32, f64, f64) -> f64,
{
    let with_orders = |lhs: f64, rhs: f64| f(n, m, lhs, rhs);
    apply_binary_via_f64(a, b, device, with_orders)
}
/// Applies `f(a, b, c, v)` to every element `v` of `z`, with the scalars `a`,
/// `b` and `c` held fixed.
///
/// Binds the three scalars into a one-argument closure and delegates to
/// `apply_unary_via_f64`, so elements are evaluated in `f64` and converted
/// back to `z`'s dtype.
///
/// # Panics
///
/// Panics (via `unreachable!` in the delegate) when `z` has a dtype the
/// delegate does not handle; callers must validate the dtype first.
pub fn apply_unary_with_three_f64s<F>(
    z: &Tensor<CpuRuntime>,
    device: &CpuDevice,
    a: f64,
    b: f64,
    c: f64,
    f: F,
) -> Result<Tensor<CpuRuntime>>
where
    F: Fn(f64, f64, f64, f64) -> f64,
{
    let with_params = |value: f64| f(a, b, c, value);
    apply_unary_via_f64(z, device, with_params)
}
/// Applies `f(a, b, v)` to every element `v` of `z`, with the scalars `a` and
/// `b` held fixed.
///
/// Binds the two scalars into a one-argument closure and delegates to
/// `apply_unary_via_f64`, so elements are evaluated in `f64` and converted
/// back to `z`'s dtype.
///
/// # Panics
///
/// Panics (via `unreachable!` in the delegate) when `z` has a dtype the
/// delegate does not handle; callers must validate the dtype first.
pub fn apply_unary_with_two_f64s<F>(
    z: &Tensor<CpuRuntime>,
    device: &CpuDevice,
    a: f64,
    b: f64,
    f: F,
) -> Result<Tensor<CpuRuntime>>
where
    F: Fn(f64, f64, f64) -> f64,
{
    let with_params = |value: f64| f(a, b, value);
    apply_unary_via_f64(z, device, with_params)
}