use sapient_core::error::{Result, SapientError};
use sapient_core::{DType, Tensor};
/// Applies `f` to every element of an `F32` tensor and returns a new tensor
/// with the same shape.
///
/// # Errors
///
/// Returns [`SapientError::TypeMismatch`] when `x` is not `DType::F32`, and
/// forwards any error produced by `Tensor::from_f32`.
fn unary_f32<F: Fn(f32) -> f32>(x: &Tensor, f: F) -> Result<Tensor> {
    // Guard clause: only f32 tensors are accepted by the unary kernels.
    let is_f32 = x.dtype() == DType::F32;
    if !is_f32 {
        let mismatch = SapientError::TypeMismatch {
            expected: "f32".into(),
            got: x.dtype().to_string(),
        };
        return Err(mismatch);
    }

    let source = x.to_f32_cow();
    let mut mapped: Vec<f32> = Vec::with_capacity(source.len());
    for &value in source.iter() {
        mapped.push(f(value));
    }
    Tensor::from_f32(&mapped, x.shape().clone())
}
/// Applies `f` element-wise to a pair of tensors read as `f32` data.
///
/// Operand combinations are tried in this order:
///
/// 1. Same element count: elements are paired positionally and the result
///    takes `a`'s shape.
/// 2. `b` holds a single element: it is combined with every element of `a`
///    (result has `a`'s shape).
/// 3. `a` holds a single element: it is combined with every element of `b`
///    (result has `b`'s shape).
/// 4. Otherwise the shapes are broadcast axis by axis, aligned at the trailing
///    axis; an axis of size 1 (or a missing leading axis) is stretched to the
///    other operand's size.
///
/// NOTE(review): case 1 compares element counts only, so two tensors with the
/// same length but different dims are still paired positionally. That
/// long-standing behaviour is kept for backward compatibility.
///
/// NOTE(review): case 4 assumes the flat data returned by `to_f32_cow` is in
/// row-major order and that `dims()` yields `usize` extents — confirm against
/// `sapient_core`.
///
/// # Errors
///
/// Returns [`SapientError::ShapeMismatch`] when none of the cases applies, and
/// forwards any error produced by `Tensor::from_f32`.
fn binary_f32<F: Fn(f32, f32) -> f32>(a: &Tensor, b: &Tensor, f: F) -> Result<Tensor> {
    /// Row-major strides for `dims`, with stride 0 on size-1 axes so that a
    /// broadcast axis keeps re-reading the same element.
    fn broadcast_strides(dims: &[usize]) -> Vec<usize> {
        let mut strides = vec![0; dims.len()];
        let mut step = 1;
        for (stride, &dim) in strides.iter_mut().zip(dims).rev() {
            if dim != 1 {
                *stride = step;
            }
            step *= dim;
        }
        strides
    }

    let a_cow = a.to_f32_cow();
    let a_data = a_cow.as_ref();
    let b_cow = b.to_f32_cow();
    let b_data = b_cow.as_ref();

    if a_data.len() == b_data.len() {
        let out: Vec<f32> = a_data
            .iter()
            .zip(b_data.iter())
            .map(|(&x, &y)| f(x, y))
            .collect();
        return Tensor::from_f32(&out, a.shape().clone());
    }
    if b_data.len() == 1 {
        let scalar = b_data[0];
        let out: Vec<f32> = a_data.iter().map(|&x| f(x, scalar)).collect();
        return Tensor::from_f32(&out, a.shape().clone());
    }
    if a_data.len() == 1 {
        let scalar = a_data[0];
        let out: Vec<f32> = b_data.iter().map(|&y| f(scalar, y)).collect();
        return Tensor::from_f32(&out, b.shape().clone());
    }

    // General case: broadcast the two shapes against each other.
    let a_dims: Vec<usize> = a.shape().dims().to_vec();
    let b_dims: Vec<usize> = b.shape().dims().to_vec();
    let mismatch = || SapientError::ShapeMismatch {
        expected: a_dims.clone(),
        got: b_dims.clone(),
    };

    // The index arithmetic below is only in bounds when each buffer holds
    // exactly the number of elements its dims describe.
    let a_len: usize = a_dims.iter().product();
    let b_len: usize = b_dims.iter().product();
    if a_len != a_data.len() || b_len != b_data.len() {
        return Err(mismatch());
    }

    // Left-pad the shorter shape with 1s so both have the same rank.
    let rank = a_dims.len().max(b_dims.len());
    let pad = |dims: &[usize]| -> Vec<usize> {
        let mut padded = vec![1; rank - dims.len()];
        padded.extend_from_slice(dims);
        padded
    };
    let a_pad = pad(&a_dims);
    let b_pad = pad(&b_dims);

    let mut out_dims = Vec::with_capacity(rank);
    for (&da, &db) in a_pad.iter().zip(&b_pad) {
        let dim = if da == db || db == 1 {
            da
        } else if da == 1 {
            db
        } else {
            return Err(mismatch());
        };
        out_dims.push(dim);
    }

    let a_strides = broadcast_strides(&a_pad);
    let b_strides = broadcast_strides(&b_pad);
    let total: usize = out_dims.iter().product();
    let mut out = Vec::with_capacity(total);
    for flat in 0..total {
        // Decompose the flat output index into per-axis coordinates and map
        // them onto each operand through its (possibly zero) strides.
        let (mut rem, mut ai, mut bi) = (flat, 0, 0);
        for axis in (0..rank).rev() {
            let coord = rem % out_dims[axis];
            rem /= out_dims[axis];
            ai += coord * a_strides[axis];
            bi += coord * b_strides[axis];
        }
        out.push(f(a_data[ai], b_data[bi]));
    }
    Tensor::from_f32(&out, out_dims)
}
/// Element-wise sum `a + b`.
///
/// # Errors
///
/// Forwards any error from `binary_f32` (for example a shape mismatch).
pub fn add(a: &Tensor, b: &Tensor) -> Result<Tensor> {
    binary_f32(a, b, <f32 as std::ops::Add>::add)
}
/// Element-wise difference `a - b`.
///
/// # Errors
///
/// Forwards any error from `binary_f32` (for example a shape mismatch).
pub fn sub(a: &Tensor, b: &Tensor) -> Result<Tensor> {
    let difference = |lhs: f32, rhs: f32| lhs - rhs;
    binary_f32(a, b, difference)
}
/// Element-wise product `a * b`.
///
/// # Errors
///
/// Forwards any error from `binary_f32` (for example a shape mismatch).
pub fn mul(a: &Tensor, b: &Tensor) -> Result<Tensor> {
    binary_f32(a, b, <f32 as std::ops::Mul>::mul)
}
/// Element-wise quotient `a / b`, using plain `f32` division (so a zero
/// divisor yields an infinity or NaN rather than an error).
///
/// # Errors
///
/// Forwards any error from `binary_f32` (for example a shape mismatch).
pub fn div(a: &Tensor, b: &Tensor) -> Result<Tensor> {
    let quotient = |numerator: f32, denominator: f32| numerator / denominator;
    binary_f32(a, b, quotient)
}
/// Element-wise power: each element of `a` raised to the matching element of
/// `b` via `f32::powf`.
///
/// # Errors
///
/// Forwards any error from `binary_f32` (for example a shape mismatch).
pub fn pow(a: &Tensor, b: &Tensor) -> Result<Tensor> {
    binary_f32(a, b, f32::powf)
}
/// Element-wise negation `-x`.
///
/// # Errors
///
/// Forwards any error from `unary_f32` (a type mismatch when `x` is not f32).
pub fn neg(x: &Tensor) -> Result<Tensor> {
    unary_f32(x, <f32 as std::ops::Neg>::neg)
}
/// Element-wise absolute value.
///
/// # Errors
///
/// Forwards any error from `unary_f32` (a type mismatch when `x` is not f32).
pub fn abs(x: &Tensor) -> Result<Tensor> {
    unary_f32(x, f32::abs)
}
/// Element-wise square root via `f32::sqrt` (negative inputs yield NaN).
///
/// # Errors
///
/// Forwards any error from `unary_f32` (a type mismatch when `x` is not f32).
pub fn sqrt(x: &Tensor) -> Result<Tensor> {
    unary_f32(x, f32::sqrt)
}
/// Element-wise natural exponential `e^x`.
///
/// # Errors
///
/// Forwards any error from `unary_f32` (a type mismatch when `x` is not f32).
pub fn exp(x: &Tensor) -> Result<Tensor> {
    unary_f32(x, f32::exp)
}
/// Element-wise natural logarithm (base e), computed with `f32::ln`.
///
/// # Errors
///
/// Forwards any error from `unary_f32` (a type mismatch when `x` is not f32).
pub fn log(x: &Tensor) -> Result<Tensor> {
    unary_f32(x, f32::ln)
}
/// Element-wise error function, evaluated with the polynomial approximation
/// in `erf_approx`.
///
/// # Errors
///
/// Forwards any error from `unary_f32` (a type mismatch when `x` is not f32).
pub fn erf(x: &Tensor) -> Result<Tensor> {
    let kernel: fn(f32) -> f32 = erf_approx;
    unary_f32(x, kernel)
}
/// Element-wise floor: the largest integer value not greater than each element.
///
/// # Errors
///
/// Forwards any error from `unary_f32` (a type mismatch when `x` is not f32).
pub fn floor(x: &Tensor) -> Result<Tensor> {
    unary_f32(x, f32::floor)
}
/// Element-wise ceiling: the smallest integer value not less than each element.
///
/// # Errors
///
/// Forwards any error from `unary_f32` (a type mismatch when `x` is not f32).
pub fn ceil(x: &Tensor) -> Result<Tensor> {
    unary_f32(x, f32::ceil)
}
/// Element-wise rounding to the nearest integer value.
///
/// Exact halfway cases go to the nearest *even* integer (`0.5 -> 0`,
/// `1.5 -> 2`, `2.5 -> 2`, `-2.5 -> -2`). This is the tie rule used by the
/// ONNX `Round` operator, NumPy and PyTorch; plain `f32::round` would instead
/// push ties away from zero. NaN and infinities pass through unchanged.
///
/// # Errors
///
/// Forwards any error from `unary_f32` (a type mismatch when `x` is not f32).
pub fn round(x: &Tensor) -> Result<Tensor> {
    unary_f32(x, |v| {
        let away = v.round();
        // `f32::round` resolves ties away from zero. If that landed on an odd
        // integer, the even neighbour is one step back toward zero. `copysign`
        // keeps the sign of a zero result (-0.5 rounds to -0.0).
        if (away - v).abs() == 0.5 && away % 2.0 != 0.0 {
            (away - v.signum()).copysign(v)
        } else {
            away
        }
    })
}
/// Rectified linear unit: negative elements (and zero) become `0.0`, positive
/// elements pass through.
///
/// NaN inputs are propagated as NaN. `f32::max` would silently turn them into
/// `0.0`, hiding upstream numeric faults and disagreeing with the other
/// activations in this file, which let NaN through.
///
/// # Errors
///
/// Forwards any error from `unary_f32` (a type mismatch when `x` is not f32).
pub fn relu(x: &Tensor) -> Result<Tensor> {
    // `NaN <= 0.0` is false, so NaN takes the pass-through branch.
    unary_f32(x, |v| if v <= 0.0 { 0.0 } else { v })
}
/// Logistic sigmoid `1 / (1 + e^(-x))`, applied element-wise.
///
/// # Errors
///
/// Forwards any error from `unary_f32` (a type mismatch when `x` is not f32).
pub fn sigmoid(x: &Tensor) -> Result<Tensor> {
    let logistic = |v: f32| {
        let denominator = 1.0 + (-v).exp();
        1.0 / denominator
    };
    unary_f32(x, logistic)
}
/// Hyperbolic-tangent activation, applied element-wise.
///
/// # Errors
///
/// Forwards any error from `unary_f32` (a type mismatch when `x` is not f32).
pub fn tanh_act(x: &Tensor) -> Result<Tensor> {
    unary_f32(x, f32::tanh)
}
/// GELU activation using the tanh approximation:
/// `0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`.
///
/// # Errors
///
/// Forwards any error from `unary_f32` (a type mismatch when `x` is not f32).
pub fn gelu(x: &Tensor) -> Result<Tensor> {
    /// sqrt(2 / pi) as an `f32` literal.
    const TANH_SCALE: f32 = 0.797_884_56;
    /// Weight of the cubic term inside the tanh argument.
    const CUBIC_WEIGHT: f32 = 0.044_715;

    /// Scalar kernel for one element.
    fn tanh_gelu(v: f32) -> f32 {
        let cubic = CUBIC_WEIGHT * v * v * v;
        let gate = (TANH_SCALE * (v + cubic)).tanh();
        0.5 * v * (1.0 + gate)
    }

    unary_f32(x, tanh_gelu)
}
/// SiLU (swish) activation `x / (1 + e^(-x))`, applied element-wise.
///
/// # Errors
///
/// Forwards any error from `unary_f32` (a type mismatch when `x` is not f32).
pub fn silu(x: &Tensor) -> Result<Tensor> {
    let swish = |v: f32| {
        let denominator = 1.0 + (-v).exp();
        v / denominator
    };
    unary_f32(x, swish)
}
/// Hard-swish activation `x * clamp(x + 3, 0, 6) / 6`, applied element-wise.
///
/// # Errors
///
/// Forwards any error from `unary_f32` (a type mismatch when `x` is not f32).
pub fn hard_swish(x: &Tensor) -> Result<Tensor> {
    unary_f32(x, |v| {
        // Piecewise-linear gate in [0, 6]; the division by 6 happens last.
        let gate = (v + 3.0).clamp(0.0, 6.0);
        v * gate / 6.0
    })
}
/// Leaky ReLU: elements `>= 0` pass through unchanged, all others are
/// multiplied by `alpha`.
///
/// # Errors
///
/// Forwards any error from `unary_f32` (a type mismatch when `x` is not f32).
pub fn leaky_relu(x: &Tensor, alpha: f32) -> Result<Tensor> {
    let kernel = |v: f32| -> f32 {
        if v >= 0.0 {
            return v;
        }
        alpha * v
    };
    unary_f32(x, kernel)
}
/// Clamps every element into the range bounded by `min` and `max`.
///
/// Either bound may be `None`, in which case that side is left unbounded.
/// The lower bound is applied first and the upper bound second, so when
/// `min > max` every element ends up equal to `max`.
///
/// NaN elements are propagated as NaN. Applying `f32::max`/`f32::min`
/// directly would replace them with one of the bounds, hiding upstream
/// numeric faults; `hard_swish` in this file already lets NaN through via
/// `clamp`.
///
/// # Errors
///
/// Forwards any error from `unary_f32` (a type mismatch when `x` is not f32).
pub fn clip(x: &Tensor, min: Option<f32>, max: Option<f32>) -> Result<Tensor> {
    unary_f32(x, |v| {
        if v.is_nan() {
            return v;
        }
        let v = min.map_or(v, |lo| v.max(lo));
        max.map_or(v, |hi| v.min(hi))
    })
}
/// Polynomial approximation of the error function for a single `f32`.
///
/// Computes `sign(x) * (1 - P(t) * t * exp(-x^2))` with
/// `t = 1 / (1 + 0.3275911 * |x|)` and `P` a degree-4 polynomial in `t`.
/// The result is odd in `x` by construction because the sign is factored out
/// before evaluating on `|x|`.
fn erf_approx(x: f32) -> f32 {
    /// Scale applied to `|x|` in the rational substitution.
    const P: f32 = 0.327_591_1;
    /// Highest-order polynomial coefficient (seed of the Horner evaluation).
    const LEADING: f32 = 1.061_405_43;
    /// Remaining coefficients, from high order down to the constant term.
    const LOWER: [f32; 4] = [-1.453_152_03, 1.421_413_74, -0.284_496_74, 0.254_829_59];

    let sign = x.signum();
    let magnitude = x.abs();
    let t = 1.0 / (1.0 + P * magnitude);

    // Horner scheme: each step computes `coefficient + accumulated * t`.
    let poly = LOWER.iter().fold(LEADING, |acc, &coef| coef + acc * t);
    let tail = poly * t * (-magnitude * magnitude).exp();

    sign * (1.0 - tail)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a rank-1 f32 tensor holding `values`.
    fn t(values: &[f32]) -> Tensor {
        let shape = vec![values.len()];
        Tensor::from_f32(values, shape).unwrap()
    }

    /// 1 + 3 must come out as 4 in the first output slot.
    #[test]
    fn test_add() {
        let lhs = t(&[1.0, 2.0]);
        let rhs = t(&[3.0, 4.0]);
        let sum = add(&lhs, &rhs).unwrap();
        let first = sum.as_f32_slice()[0];
        assert!((first - 4.0).abs() < 1e-6);
    }

    /// Negative and zero inputs map to zero; positive inputs pass through.
    #[test]
    fn test_relu() {
        let input = t(&[-1.0, 0.0, 1.0]);
        let activated = relu(&input).unwrap();
        assert_eq!(activated.as_f32_slice(), &[0.0, 0.0, 1.0]);
    }

    /// sigmoid(0) is exactly one half.
    #[test]
    fn test_sigmoid() {
        let output = sigmoid(&t(&[0.0])).unwrap();
        let at_zero = output.as_f32_slice()[0];
        assert!((at_zero - 0.5).abs() < 1e-6);
    }

    /// gelu(0) is zero.
    #[test]
    fn test_gelu() {
        let output = gelu(&t(&[0.0])).unwrap();
        let at_zero = output.as_f32_slice()[0];
        assert!(at_zero.abs() < 1e-5);
    }

    /// The scalar approximation must be (nearly) zero at the origin.
    #[test]
    fn test_erf() {
        let v = erf_approx(0.0);
        assert!(v.abs() < 1e-6, "erf(0) should be ~0, got {v}");
    }

    /// A single-element right operand is applied to every left element.
    #[test]
    fn test_scalar_broadcast() {
        let vector = t(&[1.0, 2.0, 3.0]);
        let factor = t(&[2.0]);
        let scaled = mul(&vector, &factor).unwrap();
        assert_eq!(scaled.as_f32_slice(), &[2.0, 4.0, 6.0]);
    }
}