use alloc::vec::Vec;
use burn_backend::DType;
use burn_std::{Bytes, bf16, f16};
use crate::layout::StridedBlocks;
use crate::{FlexTensor, Layout};
/// Applies an element-wise unary operation to a float tensor, dispatching on
/// its dtype.
///
/// The `f32` operation is reused for the half-precision types (`f16`/`bf16`)
/// by widening each element to `f32`, applying `f32_op`, and narrowing back.
///
/// # Panics
/// Panics for any dtype other than `F32`, `F64`, `F16`, or `BF16`.
pub fn unary_op<F32Op, F64Op>(tensor: FlexTensor, f32_op: F32Op, f64_op: F64Op) -> FlexTensor
where
    F32Op: Fn(f32) -> f32 + Copy,
    F64Op: Fn(f64) -> f64 + Copy,
{
    match tensor.dtype() {
        DType::F16 => unary_op_typed(tensor, move |x: f16| f16::from_f32(f32_op(x.to_f32()))),
        DType::BF16 => unary_op_typed(tensor, move |x: bf16| bf16::from_f32(f32_op(x.to_f32()))),
        DType::F32 => unary_op_typed(tensor, f32_op),
        DType::F64 => unary_op_typed(tensor, f64_op),
        other => panic!("unary_op: unsupported dtype {:?}", other),
    }
}
/// Applies `op` element-wise for a concrete element type `E`.
///
/// Strategy selection, in order:
/// 1. In-place update when this tensor uniquely owns contiguous storage with
///    start offset 0 — no allocation, layout unchanged.
/// 2. Whole-buffer map when all strides are non-negative, the offset is 0 and
///    the buffer length equals `num_elements()`: every storage slot is mapped
///    and the original (possibly non-contiguous, e.g. transposed) layout is
///    kept. This is sound for a unary op because applying `op` per storage
///    slot yields `op(src[idx])` through any index mapping.
/// 3. Element-wise gather via `StridedIter` when some stride is negative
///    (flipped views); the result is materialized contiguously. Assumes the
///    iterator yields flat storage offsets in row-major logical order —
///    consistent with the flip tests in this file.
/// 4. Otherwise (slices, nonzero offsets) a block-wise walk driven by
///    `strided_blocks()`, also materialized contiguously.
fn unary_op_typed<E, Op>(mut tensor: FlexTensor, op: Op) -> FlexTensor
where
    E: burn_backend::Element + bytemuck::Pod,
    Op: Fn(E) -> E,
{
    let n = tensor.layout().num_elements();
    // Path 1: mutate our own contiguous buffer directly.
    if tensor.is_unique() && tensor.layout().is_contiguous() && tensor.layout().start_offset() == 0
    {
        let storage: &mut [E] = tensor.storage_mut();
        for x in storage[..n].iter_mut() {
            *x = op(*x);
        }
        return tensor;
    }
    let layout = tensor.layout().clone();
    let src: &[E] = tensor.storage();
    let has_negative_strides = layout.strides().iter().any(|&s| s < 0);
    // Path 2: storage holds exactly the element set — map it one-to-one and
    // reuse the existing layout.
    if !has_negative_strides && layout.start_offset() == 0 && src.len() == n {
        let result: Vec<E> = src.iter().map(|&x| op(x)).collect();
        let bytes = Bytes::from_elems(result);
        return FlexTensor::new(bytes, layout, E::dtype());
    }
    // Path 3: negative strides — gather into a fresh contiguous tensor.
    if has_negative_strides {
        let result: Vec<E> = crate::strided_index::StridedIter::new(&layout)
            .map(|idx| op(src[idx]))
            .collect();
        let bytes = Bytes::from_elems(result);
        return FlexTensor::new(
            bytes,
            Layout::contiguous(layout.shape().clone()),
            E::dtype(),
        );
    }
    // Path 4: positive-stride, non-dense layouts walked as contiguous runs.
    let result = match layout.strided_blocks() {
        StridedBlocks::Single { start, len } => {
            // One contiguous run starting at `start` (covers nonzero offsets,
            // e.g. narrowed views).
            src[start..start + len].iter().map(|&x| op(x)).collect()
        }
        StridedBlocks::Multiple {
            block_len,
            num_blocks,
            ..
        } => {
            // NOTE(review): `strided_blocks()` is recomputed here because the
            // value above was consumed by the `match` destructure; consider
            // binding it once if it is not cheap to build.
            let blocks = layout.strided_blocks();
            let mut result = Vec::with_capacity(n);
            if block_len == 1 {
                // Fully scattered elements (e.g. step-sliced innermost dim).
                for block_start in blocks.block_starts() {
                    result.push(op(src[block_start]));
                }
            } else {
                for block_start in blocks.block_starts() {
                    for i in 0..block_len {
                        result.push(op(src[block_start + i]));
                    }
                }
            }
            debug_assert_eq!(result.len(), num_blocks * block_len);
            result
        }
    };
    let bytes = Bytes::from_elems(result);
    FlexTensor::new(
        bytes,
        Layout::contiguous(layout.shape().clone()),
        E::dtype(),
    )
}
/// Element-wise natural exponential `e^x`.
pub fn exp(tensor: FlexTensor) -> FlexTensor {
    unary_op(tensor, |x: f32| x.exp(), |x: f64| x.exp())
}

/// Element-wise natural logarithm.
pub fn log(tensor: FlexTensor) -> FlexTensor {
    unary_op(tensor, |x: f32| x.ln(), |x: f64| x.ln())
}

/// Element-wise `ln(1 + x)`, numerically accurate for small `x`.
pub fn log1p(tensor: FlexTensor) -> FlexTensor {
    unary_op(tensor, |x: f32| x.ln_1p(), |x: f64| x.ln_1p())
}

/// Element-wise square root.
pub fn sqrt(tensor: FlexTensor) -> FlexTensor {
    unary_op(tensor, |x: f32| x.sqrt(), |x: f64| x.sqrt())
}
/// Element-wise absolute value for float tensors.
pub fn abs(tensor: FlexTensor) -> FlexTensor {
    // Fast path: vectorized in-place |x| when we uniquely own contiguous f32
    // storage starting at offset 0 (same precondition as the in-place path of
    // `unary_op_typed`).
    #[cfg(feature = "simd")]
    if tensor.dtype() == DType::F32
        && tensor.is_unique()
        && tensor.layout().is_contiguous()
        && tensor.layout().start_offset() == 0
    {
        let n = tensor.layout().num_elements();
        let mut tensor = tensor;
        let storage: &mut [f32] = tensor.storage_mut();
        crate::simd::abs_inplace_f32(&mut storage[..n]);
        return tensor;
    }
    // General path: scalar abs through the dtype dispatcher.
    unary_op(tensor, f32::abs, f64::abs)
}
/// Element-wise absolute value for integer tensors.
///
/// Signed types use `wrapping_abs`, so `MIN.abs()` wraps back to `MIN`
/// instead of panicking on overflow. Unsigned tensors are returned unchanged.
///
/// # Panics
/// Panics for non-integer dtypes.
pub fn int_abs(tensor: FlexTensor) -> FlexTensor {
    let dtype = tensor.dtype();
    match dtype {
        // Unsigned values are already non-negative: nothing to do.
        DType::U64 | DType::U32 | DType::U16 | DType::U8 => tensor,
        DType::I8 => unary_op_typed(tensor, |x: i8| x.wrapping_abs()),
        DType::I16 => unary_op_typed(tensor, |x: i16| x.wrapping_abs()),
        DType::I32 => unary_op_typed(tensor, |x: i32| x.wrapping_abs()),
        DType::I64 => unary_op_typed(tensor, |x: i64| x.wrapping_abs()),
        _ => panic!("int_abs: unsupported dtype {:?}", dtype),
    }
}
/// Element-wise reciprocal `1 / x`.
///
/// Follows IEEE float semantics: `1 / 0.0` yields infinity rather than
/// panicking.
pub fn recip(tensor: FlexTensor) -> FlexTensor {
    // Fast path: vectorized in-place reciprocal when we uniquely own
    // contiguous f32 storage starting at offset 0.
    #[cfg(feature = "simd")]
    if tensor.dtype() == DType::F32
        && tensor.is_unique()
        && tensor.layout().is_contiguous()
        && tensor.layout().start_offset() == 0
    {
        let n = tensor.layout().num_elements();
        let mut tensor = tensor;
        let storage: &mut [f32] = tensor.storage_mut();
        crate::simd::recip_inplace_f32(&mut storage[..n]);
        return tensor;
    }
    // General path: scalar division through the dtype dispatcher.
    unary_op(tensor, |x| 1.0 / x, |x| 1.0 / x)
}
/// Element-wise cosine (input in radians).
pub fn cos(tensor: FlexTensor) -> FlexTensor {
    unary_op(tensor, |x: f32| x.cos(), |x: f64| x.cos())
}

/// Element-wise sine (input in radians).
pub fn sin(tensor: FlexTensor) -> FlexTensor {
    unary_op(tensor, |x: f32| x.sin(), |x: f64| x.sin())
}

/// Element-wise tangent (input in radians).
pub fn tan(tensor: FlexTensor) -> FlexTensor {
    unary_op(tensor, |x: f32| x.tan(), |x: f64| x.tan())
}

/// Element-wise hyperbolic cosine.
pub fn cosh(tensor: FlexTensor) -> FlexTensor {
    unary_op(tensor, |x: f32| x.cosh(), |x: f64| x.cosh())
}

/// Element-wise hyperbolic sine.
pub fn sinh(tensor: FlexTensor) -> FlexTensor {
    unary_op(tensor, |x: f32| x.sinh(), |x: f64| x.sinh())
}

/// Element-wise hyperbolic tangent.
pub fn tanh(tensor: FlexTensor) -> FlexTensor {
    unary_op(tensor, |x: f32| x.tanh(), |x: f64| x.tanh())
}

/// Element-wise arccosine (result in radians).
pub fn acos(tensor: FlexTensor) -> FlexTensor {
    unary_op(tensor, |x: f32| x.acos(), |x: f64| x.acos())
}

/// Element-wise inverse hyperbolic cosine.
pub fn acosh(tensor: FlexTensor) -> FlexTensor {
    unary_op(tensor, |x: f32| x.acosh(), |x: f64| x.acosh())
}

/// Element-wise arcsine (result in radians).
pub fn asin(tensor: FlexTensor) -> FlexTensor {
    unary_op(tensor, |x: f32| x.asin(), |x: f64| x.asin())
}

/// Element-wise inverse hyperbolic sine.
pub fn asinh(tensor: FlexTensor) -> FlexTensor {
    unary_op(tensor, |x: f32| x.asinh(), |x: f64| x.asinh())
}

/// Element-wise arctangent (result in radians).
pub fn atan(tensor: FlexTensor) -> FlexTensor {
    unary_op(tensor, |x: f32| x.atan(), |x: f64| x.atan())
}

/// Element-wise inverse hyperbolic tangent.
pub fn atanh(tensor: FlexTensor) -> FlexTensor {
    unary_op(tensor, |x: f32| x.atanh(), |x: f64| x.atanh())
}
/// Element-wise rounding to the nearest integer, with ties rounded to the
/// even integer (banker's rounding, IEEE 754 `roundTiesToEven`).
pub fn round(tensor: FlexTensor) -> FlexTensor {
    unary_op(
        tensor,
        |x: f32| round_ties_even_f32(x),
        |x: f64| round_ties_even_f64(x),
    )
}

/// `f32` round-half-to-even wrapper.
fn round_ties_even_f32(x: f32) -> f32 {
    x.round_ties_even()
}

/// `f64` round-half-to-even wrapper.
fn round_ties_even_f64(x: f64) -> f64 {
    x.round_ties_even()
}
/// Element-wise floor (largest integer not greater than `x`).
pub fn floor(tensor: FlexTensor) -> FlexTensor {
    unary_op(tensor, |x: f32| x.floor(), |x: f64| x.floor())
}

/// Element-wise ceiling (smallest integer not less than `x`).
pub fn ceil(tensor: FlexTensor) -> FlexTensor {
    unary_op(tensor, |x: f32| x.ceil(), |x: f64| x.ceil())
}

/// Element-wise truncation (rounds toward zero).
pub fn trunc(tensor: FlexTensor) -> FlexTensor {
    unary_op(tensor, |x: f32| x.trunc(), |x: f64| x.trunc())
}
pub fn erf(tensor: FlexTensor) -> FlexTensor {
unary_op(tensor, erf_f32, erf_f64)
}
/// Single-precision error function using the Abramowitz & Stegun 7.1.26
/// polynomial approximation (maximum absolute error about 1.5e-7).
///
/// `erf` is odd, so the approximation is evaluated on `|x|` and the sign is
/// reapplied at the end.
pub fn erf_f32(x: f32) -> f32 {
    // A&S 7.1.26 coefficients a5..a1, highest degree first for Horner form.
    const COEFFS: [f32; 5] = [
        1.061_405_4,
        -1.453_152_1,
        1.421_413_8,
        -0.284_496_72,
        0.254_829_6,
    ];
    const P: f32 = 0.3275911;
    let sign = if x < 0.0 { -1.0 } else { 1.0 };
    let x = x.abs();
    let t = 1.0 / (1.0 + P * x);
    // Horner evaluation of a5*t^5 + a4*t^4 + a3*t^3 + a2*t^2 + a1*t (the
    // trailing `* t` below supplies the final power of t).
    let mut poly = COEFFS[0];
    for &c in &COEFFS[1..] {
        poly = poly * t + c;
    }
    let y = 1.0 - poly * t * (-x * x).exp();
    sign * y
}
/// Double-precision error function using the Abramowitz & Stegun 7.1.26
/// polynomial approximation (maximum absolute error about 1.5e-7 — the
/// approximation itself, not f64 precision, is the limiting factor).
///
/// `erf` is odd, so the approximation is evaluated on `|x|` and the sign is
/// reapplied at the end.
pub fn erf_f64(x: f64) -> f64 {
    // A&S 7.1.26 coefficients a5..a1, highest degree first for Horner form.
    const COEFFS: [f64; 5] = [
        1.061405429,
        -1.453152027,
        1.421413741,
        -0.284496736,
        0.254829592,
    ];
    const P: f64 = 0.3275911;
    let sign = if x < 0.0 { -1.0 } else { 1.0 };
    let x = x.abs();
    let t = 1.0 / (1.0 + P * x);
    // Horner evaluation of a5*t^5 + ... + a1*t (final `* t` applied below).
    let mut poly = COEFFS[0];
    for &c in &COEFFS[1..] {
        poly = poly * t + c;
    }
    let y = 1.0 - poly * t * (-x * x).exp();
    sign * y
}
#[cfg(test)]
mod tests {
    use super::*;
    use burn_backend::{TensorData, Tolerance};

    /// Builds a rank-1 f32 tensor holding `data`.
    fn tensor_from_vec(data: Vec<f32>) -> FlexTensor {
        let shape = burn_std::Shape::from(vec![data.len()]);
        FlexTensor::from_data(TensorData::new(data, shape.to_vec()))
    }

    #[test]
    fn test_exp() {
        let tensor = tensor_from_vec(vec![0.0, 1.0, 2.0]);
        let result = exp(tensor);
        let e = std::f32::consts::E;
        result.into_data().assert_approx_eq::<f32>(
            &TensorData::from([1.0, e, e.powi(2)]),
            Tolerance::absolute(1e-5),
        );
    }

    #[test]
    fn test_log() {
        // ln of e^0, e^1, e^2 should recover the exponents.
        let tensor = tensor_from_vec(vec![1.0, std::f32::consts::E, std::f32::consts::E.powi(2)]);
        let result = log(tensor);
        result.into_data().assert_approx_eq::<f32>(
            &TensorData::from([0.0, 1.0, 2.0]),
            Tolerance::absolute(1e-5),
        );
    }

    #[test]
    fn test_sqrt() {
        let tensor = tensor_from_vec(vec![0.0, 1.0, 4.0, 9.0]);
        let result = sqrt(tensor);
        result.into_data().assert_approx_eq::<f32>(
            &TensorData::from([0.0, 1.0, 2.0, 3.0]),
            Tolerance::absolute(1e-5),
        );
    }

    #[test]
    fn test_abs() {
        let tensor = tensor_from_vec(vec![-3.0, -1.0, 0.0, 1.0, 3.0]);
        let result = abs(tensor);
        result.into_data().assert_approx_eq::<f32>(
            &TensorData::from([3.0, 1.0, 0.0, 1.0, 3.0]),
            Tolerance::absolute(1e-5),
        );
    }

    #[test]
    fn test_sin_cos() {
        // Sample at 0, pi/2, pi where sin/cos take exact values.
        let tensor = tensor_from_vec(vec![0.0, std::f32::consts::FRAC_PI_2, std::f32::consts::PI]);
        sin(tensor.clone()).into_data().assert_approx_eq::<f32>(
            &TensorData::from([0.0, 1.0, 0.0]),
            Tolerance::absolute(1e-5),
        );
        cos(tensor).into_data().assert_approx_eq::<f32>(
            &TensorData::from([1.0, 0.0, -1.0]),
            Tolerance::absolute(1e-5),
        );
    }

    #[test]
    fn test_tanh() {
        let tensor = tensor_from_vec(vec![-2.0, 0.0, 2.0]);
        let result = tanh(tensor);
        result.into_data().assert_approx_eq::<f32>(
            &TensorData::from([(-2.0f32).tanh(), 0.0, 2.0f32.tanh()]),
            Tolerance::absolute(1e-5),
        );
    }

    #[test]
    fn test_round_floor_ceil() {
        // Half-integer inputs: round must use ties-to-even (-1.5 -> -2,
        // -0.5 -> 0, 0.5 -> 0, 1.5 -> 2), unlike round-half-away-from-zero.
        let tensor = tensor_from_vec(vec![-1.5, -0.5, 0.5, 1.5]);
        round(tensor.clone()).into_data().assert_approx_eq::<f32>(
            &TensorData::from([-2.0, 0.0, 0.0, 2.0]),
            Tolerance::absolute(1e-5),
        );
        floor(tensor.clone()).into_data().assert_approx_eq::<f32>(
            &TensorData::from([-2.0, -1.0, 0.0, 1.0]),
            Tolerance::absolute(1e-5),
        );
        ceil(tensor).into_data().assert_approx_eq::<f32>(
            &TensorData::from([-1.0, 0.0, 1.0, 2.0]),
            Tolerance::absolute(1e-5),
        );
    }

    #[test]
    fn test_erf() {
        // Loose 1e-3 tolerance: the A&S polynomial approximation is only
        // accurate to ~1.5e-7 and the reference values here are 4-digit.
        let tensor = tensor_from_vec(vec![0.0, 0.5, 1.0, 2.0]);
        let result = erf(tensor);
        result.into_data().assert_approx_eq::<f32>(
            &TensorData::from([0.0, 0.5205, 0.8427, 0.9953]),
            Tolerance::absolute(1e-3),
        );
    }

    /// Builds a rank-2 f32 tensor of shape `[rows, cols]`.
    fn tensor_2d(data: Vec<f32>, rows: usize, cols: usize) -> FlexTensor {
        FlexTensor::from_data(TensorData::new(data, vec![rows, cols]))
    }

    #[test]
    fn test_exp_transposed() {
        // Exercises the non-contiguous (transposed) path: the whole storage
        // buffer still covers exactly the element set.
        let tensor = tensor_2d(vec![0.0, 1.0, 2.0, 3.0], 2, 2);
        let transposed = tensor.transpose(0, 1);
        assert!(!transposed.is_contiguous());
        let e = std::f32::consts::E;
        exp(transposed).into_data().assert_approx_eq::<f32>(
            &TensorData::new(vec![1.0, e * e, e, e * e * e], vec![2, 2]),
            Tolerance::absolute(1e-5),
        );
    }

    #[test]
    fn test_sqrt_narrowed() {
        // Exercises a view with a nonzero start offset (narrowed slice).
        let tensor = tensor_from_vec(vec![1.0, 4.0, 9.0, 16.0, 25.0, 36.0]);
        let narrowed = tensor.narrow(0, 1, 4);
        assert!(!narrowed.is_contiguous() || narrowed.layout().start_offset() != 0);
        sqrt(narrowed).into_data().assert_approx_eq::<f32>(
            &TensorData::from([2.0, 3.0, 4.0, 5.0]),
            Tolerance::absolute(1e-5),
        );
    }

    #[test]
    fn test_abs_flipped() {
        // Exercises the negative-stride path; output is in flipped
        // (logical) order.
        let tensor = tensor_from_vec(vec![1.0, -2.0, 3.0, -4.0]);
        let flipped = crate::ops::flip::flip(tensor, &[0]);
        assert!(flipped.layout().strides()[0] < 0);
        abs(flipped).into_data().assert_approx_eq::<f32>(
            &TensorData::from([4.0, 3.0, 2.0, 1.0]),
            Tolerance::absolute(1e-5),
        );
    }

    #[test]
    fn test_sqrt_flipped_2d() {
        // Negative stride on the row axis: rows come back swapped.
        let tensor = tensor_2d(vec![1.0, 4.0, 9.0, 16.0], 2, 2);
        let flipped = crate::ops::flip::flip(tensor, &[0]);
        assert!(flipped.layout().strides()[0] < 0);
        sqrt(flipped).into_data().assert_approx_eq::<f32>(
            &TensorData::new(vec![3.0, 4.0, 1.0, 2.0], vec![2, 2]),
            Tolerance::absolute(1e-5),
        );
    }

    #[test]
    fn test_cos_flipped_axis1() {
        // Negative stride on the innermost axis: columns reversed per row.
        use std::f32::consts::{FRAC_PI_2, PI};
        let tensor = tensor_2d(vec![0.0, PI, FRAC_PI_2, 3.0 * FRAC_PI_2], 2, 2);
        let flipped = crate::ops::flip::flip(tensor, &[1]);
        assert!(flipped.layout().strides()[1] < 0);
        cos(flipped).into_data().assert_approx_eq::<f32>(
            &TensorData::new(vec![-1.0, 1.0, 0.0, 0.0], vec![2, 2]),
            Tolerance::absolute(1e-5),
        );
    }

    #[test]
    fn test_sin_step_sliced() {
        // Step-2 slice on the innermost dim: stride 2 means each element is
        // its own one-element block in the strided-blocks path.
        let tensor = FlexTensor::from_data(TensorData::new(
            vec![0.0f32, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0],
            vec![1, 8],
        ));
        let sliced = crate::ops::slice::slice(
            tensor,
            &[
                burn_backend::Slice {
                    start: 0,
                    end: None,
                    step: 1,
                },
                burn_backend::Slice {
                    start: 0,
                    end: None,
                    step: 2,
                },
            ],
        );
        assert_eq!(sliced.layout().shape().to_vec(), vec![1, 4]);
        assert_eq!(sliced.layout().strides()[1], 2);
        // Sanity-check the slice itself before applying the op.
        sliced.clone().into_data().assert_approx_eq::<f32>(
            &TensorData::new(vec![0.0, 2.0, 4.0, 6.0], vec![1, 4]),
            Tolerance::absolute(1e-6),
        );
        let expected: Vec<f32> = [0.0f32, 2.0, 4.0, 6.0].iter().map(|x| x.sin()).collect();
        sin(sliced).into_data().assert_approx_eq::<f32>(
            &TensorData::new(expected, vec![1, 4]),
            Tolerance::absolute(1e-6),
        );
    }

    #[test]
    fn test_cos_step_sliced_3d() {
        // Same step-sliced scenario on a rank-3 tensor.
        let vals: Vec<f32> = (0..12).map(|i| i as f32 * 0.5).collect();
        let tensor = FlexTensor::from_data(TensorData::new(vals, vec![1, 2, 6]));
        let sliced = crate::ops::slice::slice(
            tensor,
            &[
                burn_backend::Slice {
                    start: 0,
                    end: None,
                    step: 1,
                },
                burn_backend::Slice {
                    start: 0,
                    end: None,
                    step: 1,
                },
                burn_backend::Slice {
                    start: 0,
                    end: None,
                    step: 2,
                },
            ],
        );
        assert_eq!(sliced.layout().shape().to_vec(), vec![1, 2, 3]);
        // Sanity-check the slice itself before applying the op.
        sliced.clone().into_data().assert_approx_eq::<f32>(
            &TensorData::new(vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0], vec![1, 2, 3]),
            Tolerance::absolute(1e-6),
        );
        let expected: Vec<f32> = [0.0f32, 1.0, 2.0, 3.0, 4.0, 5.0]
            .iter()
            .map(|x| x.cos())
            .collect();
        cos(sliced).into_data().assert_approx_eq::<f32>(
            &TensorData::new(expected, vec![1, 2, 3]),
            Tolerance::absolute(1e-6),
        );
    }

    #[test]
    fn test_log_3d_transposed() {
        // Permuted rank-3 view; only range-checks the outputs since the
        // permutation scrambles element order.
        let e = std::f32::consts::E;
        let data = vec![1.0, e, e * e, e * e * e, 1.0, e, e * e, e * e * e];
        let tensor = FlexTensor::from_data(TensorData::new(data, vec![2, 2, 2]));
        let permuted = tensor.permute(&[2, 0, 1]);
        assert!(!permuted.is_contiguous());
        let result = log(permuted);
        let out: Vec<f32> = result.into_data().to_vec().unwrap();
        for &v in &out {
            // All inputs are e^0..e^3, so every log must land in [0, 3].
            assert!(v >= -0.01 && v <= 3.01, "unexpected log value: {}", v);
        }
    }

    #[test]
    fn test_round_ties_even_large_float() {
        // Values this large are already exact integers in f32, so rounding
        // must be a bit-for-bit identity (naive add-0.5 tricks would break).
        let data = vec![2e18_f32, -2e18_f32, f32::MAX, f32::MIN];
        let tensor = tensor_from_vec(data.clone());
        let result = round(tensor);
        let out: Vec<f32> = result.into_data().to_vec().unwrap();
        for (a, b) in out.iter().zip(data.iter()) {
            assert_eq!(a.to_bits(), b.to_bits());
        }
    }
}