use std::path::PathBuf;
use std::sync::Arc;
use serde::de::{self, Deserializer, SeqAccess, Visitor};
use ferrotorch_core::grad_fns::arithmetic::{abs, add, div, mul, neg, pow, sqrt, sub};
use ferrotorch_core::grad_fns::comparison::{where_, where_bt};
use ferrotorch_core::grad_fns::reduction::sum as grad_sum;
use ferrotorch_core::ops::elementwise::{
binary_map, fast_add, fast_cos, fast_div, fast_exp, fast_log, fast_mul, fast_sigmoid, fast_sin,
fast_sub, fast_tanh, logsumexp, logsumexp_dim, mean, nanmean, nansum, scalar_map, simd_add_f32,
simd_add_f64, simd_exp_f32, simd_exp_f64, simd_log_f32, simd_mul_f32, simd_mul_f64,
simd_sqrt_f32, sum, sum_axis, unary_map,
};
use ferrotorch_core::{BoolTensor, Device, Tensor, TensorStorage};
use serde::Deserialize;
/// Per-category comparison tolerances and slice-comparison assertions.
///
/// Finite values are compared against `tol * max(|expected|, 1)`, i.e. an
/// absolute bound for small magnitudes and a relative one for large ones.
/// Non-finite values must match by category: any NaN matches any NaN, and
/// infinities must agree in sign.
mod tolerance {
    // Tolerances are grouped by kernel family; each family has an f32/f64 pair.
    pub const F32_ELEMENTWISE: f32 = 1e-6;
    pub const F64_ELEMENTWISE: f64 = 1e-12;
    pub const F32_TRANSCENDENTAL: f32 = 1e-5;
    pub const F64_TRANSCENDENTAL: f64 = 1e-10;
    pub const F32_FAST: f32 = 1e-5;
    pub const F64_FAST: f64 = 1e-10;
    pub const F32_SIMD: f32 = 1e-6;
    pub const F64_SIMD: f64 = 1e-12;
    pub const F32_REDUCTION: f32 = 1e-5;
    pub const F64_REDUCTION: f64 = 1e-12;
    #[allow(dead_code, reason = "used by `gpu` cfg-gated module")]
    pub const F32_GPU: f32 = 1e-5;
    #[allow(dead_code, reason = "used by `gpu` cfg-gated module")]
    pub const F64_GPU: f64 = 1e-10;

    /// Assert that `actual` matches `expected` element-wise within `tol`.
    ///
    /// # Panics
    /// On a length mismatch, on any non-finite mismatch, or when
    /// `|actual - expected| > tol * max(|expected|, 1)` at some index.
    pub fn assert_close_f32(actual: &[f32], expected: &[f32], tol: f32, label: &str) {
        let (n_actual, n_expected) = (actual.len(), expected.len());
        assert_eq!(
            n_actual,
            n_expected,
            "{label}: length mismatch (actual={}, expected={})",
            n_actual,
            n_expected
        );
        let pairs = actual.iter().copied().zip(expected.iter().copied());
        for (i, (a, e)) in pairs.enumerate() {
            if a.is_finite() && e.is_finite() {
                let diff = (a - e).abs();
                let allowed = tol * e.abs().max(1.0);
                assert!(
                    diff <= allowed,
                    "{label}: index {i} delta {diff:.3e} exceeds tol {tol:.3e} \
                     (actual={a}, expected={e})"
                );
                continue;
            }
            // At least one side is NaN or infinite: require a category match.
            let both_nan = a.is_nan() && e.is_nan();
            let identical = a.to_bits() == e.to_bits();
            let same_inf = a.is_infinite() && e.is_infinite() && a.signum() == e.signum();
            if !(both_nan || identical || same_inf) {
                panic!("{label}: index {i} non-finite mismatch (actual={a}, expected={e})");
            }
        }
    }

    /// Assert that `actual` matches `expected` element-wise within `tol`.
    ///
    /// # Panics
    /// On a length mismatch, on any non-finite mismatch, or when
    /// `|actual - expected| > tol * max(|expected|, 1)` at some index.
    pub fn assert_close_f64(actual: &[f64], expected: &[f64], tol: f64, label: &str) {
        let (n_actual, n_expected) = (actual.len(), expected.len());
        assert_eq!(
            n_actual,
            n_expected,
            "{label}: length mismatch (actual={}, expected={})",
            n_actual,
            n_expected
        );
        let pairs = actual.iter().copied().zip(expected.iter().copied());
        for (i, (a, e)) in pairs.enumerate() {
            if a.is_finite() && e.is_finite() {
                let diff = (a - e).abs();
                let allowed = tol * e.abs().max(1.0);
                assert!(
                    diff <= allowed,
                    "{label}: index {i} delta {diff:.3e} exceeds tol {tol:.3e} \
                     (actual={a}, expected={e})"
                );
                continue;
            }
            // At least one side is NaN or infinite: require a category match.
            let both_nan = a.is_nan() && e.is_nan();
            let identical = a.to_bits() == e.to_bits();
            let same_inf = a.is_infinite() && e.is_infinite() && a.signum() == e.signum();
            if !(both_nan || identical || same_inf) {
                panic!("{label}: index {i} non-finite mismatch (actual={a}, expected={e})");
            }
        }
    }
}
/// Decoded list of fixture floats.
///
/// Entries are JSON numbers or one of the string sentinels `"Infinity"`,
/// `"-Infinity"` or `"NaN"`, already converted to `f64`.
#[derive(Debug)]
struct F64ListSentinel(Vec<f64>);

impl F64ListSentinel {
    /// Borrow the decoded values as a slice.
    fn as_slice(&self) -> &[f64] {
        self.0.as_slice()
    }
}
/// A single fixture number: either a JSON number or a non-finite sentinel string.
struct FloatOrSentinel(f64);

/// `serde` visitor that accepts JSON numbers (float or integer) and the
/// strings `"Infinity"`, `"-Infinity"` and `"NaN"`.
struct FloatOrSentinelVisitor;

impl<'de> Visitor<'de> for FloatOrSentinelVisitor {
    type Value = FloatOrSentinel;

    fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        f.write_str("an f64 or one of \"Infinity\"/\"-Infinity\"/\"NaN\"")
    }

    fn visit_f64<E: de::Error>(self, v: f64) -> Result<Self::Value, E> {
        Ok(FloatOrSentinel(v))
    }

    // Integer literals are widened with `as f64` (rounds beyond 2^53).
    fn visit_i64<E: de::Error>(self, v: i64) -> Result<Self::Value, E> {
        self.visit_f64(v as f64)
    }

    fn visit_u64<E: de::Error>(self, v: u64) -> Result<Self::Value, E> {
        self.visit_f64(v as f64)
    }

    fn visit_str<E: de::Error>(self, v: &str) -> Result<Self::Value, E> {
        let value = match v {
            "Infinity" => f64::INFINITY,
            "-Infinity" => f64::NEG_INFINITY,
            "NaN" => f64::NAN,
            other => return Err(E::custom(format!("unexpected float sentinel {other:?}"))),
        };
        Ok(FloatOrSentinel(value))
    }
}

impl<'de> serde::Deserialize<'de> for FloatOrSentinel {
    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
        // The token may be a number or a string, so let the format decide.
        deserializer.deserialize_any(FloatOrSentinelVisitor)
    }
}
struct F64ListSentinelVisitor;
impl<'de> Visitor<'de> for F64ListSentinelVisitor {
type Value = F64ListSentinel;
fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
f.write_str("a list of floats with optional Infinity/-Infinity/NaN sentinels")
}
fn visit_seq<A: SeqAccess<'de>>(self, mut seq: A) -> Result<Self::Value, A::Error> {
let mut out = Vec::new();
while let Some(FloatOrSentinel(v)) = seq.next_element()? {
out.push(v);
}
Ok(F64ListSentinel(out))
}
}
impl<'de> serde::Deserialize<'de> for F64ListSentinel {
fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
deserializer.deserialize_seq(F64ListSentinelVisitor)
}
}
/// Root object of `elementwise.json`, as parsed by `load_fixtures`.
#[derive(Debug, Deserialize)]
struct FixtureFile {
    /// Generation-environment metadata; kept only for diagnostics.
    #[allow(dead_code, reason = "metadata used for diagnostics")]
    metadata: FixtureMetadata,
    /// All fixture cases. Tests select from these with `cases_for`.
    fixtures: Vec<Fixture>,
}
/// Environment metadata stored alongside the fixtures.
///
/// Every field is required by the deserializer. Per the `dead_code`
/// allowances, the fields are only used for diagnostics, except
/// `cuda_available`, which the `gpu` cfg-gated module reads.
#[derive(Debug, Deserialize)]
struct FixtureMetadata {
    #[allow(dead_code, reason = "diagnostics only")]
    torch_version: String,
    /// `None`/absent when no CUDA version was recorded.
    #[allow(dead_code, reason = "diagnostics only")]
    cuda_version: Option<String>,
    #[allow(dead_code, reason = "consumed by `gpu` cfg-gated module")]
    cuda_available: bool,
    #[allow(dead_code, reason = "diagnostics only")]
    python_executable: String,
    #[allow(dead_code, reason = "diagnostics only")]
    python_platform: String,
    // NOTE(review): presumably a timestamp string; the format is not checked here.
    #[allow(dead_code, reason = "diagnostics only")]
    generated_at: String,
    #[allow(dead_code, reason = "diagnostics only")]
    rng_seed: u64,
}
/// One fixture case from `elementwise.json`.
///
/// `op`, `dtype` and `device` are required. Every other key is optional
/// (`#[serde(default)]` → `None` when absent) and is read only by the tests
/// whose op uses it. Unknown keys are rejected (`deny_unknown_fields`), so a
/// new generator field fails the parse instead of being silently ignored.
/// Float payloads use `F64ListSentinel`, so they may contain the strings
/// `"Infinity"`, `"-Infinity"` and `"NaN"`.
#[derive(Debug, Deserialize)]
#[serde(deny_unknown_fields)]
struct Fixture {
    /// Operation name; matched exactly by `cases_for`.
    op: String,
    /// Optional case tag, included in some failure labels.
    #[serde(default)]
    tag: Option<String>,
    /// Element dtype; the tests dispatch on `"float32"` and `"float64"`.
    dtype: String,
    /// Device label (e.g. `"cpu"`); matched exactly by `cases_for`.
    device: String,
    /// Shape of the first operand.
    #[serde(default)]
    a_shape: Option<Vec<usize>>,
    /// Shape of the second operand (binary ops).
    #[serde(default)]
    b_shape: Option<Vec<usize>>,
    #[serde(default)]
    #[allow(
        dead_code,
        reason = "deserialized for fixture-shape stability and future shape-checks"
    )]
    out_shape: Option<Vec<usize>>,
    /// Flattened values of the first operand.
    #[serde(default)]
    a_data: Option<F64ListSentinel>,
    /// Flattened values of the second operand.
    #[serde(default)]
    b_data: Option<F64ListSentinel>,
    /// Expected forward output (flattened).
    #[serde(default)]
    out_values: Option<F64ListSentinel>,
    /// Expected gradient of `sum(out)` w.r.t. the first operand.
    #[serde(default)]
    grad_a: Option<F64ListSentinel>,
    /// Expected gradient of `sum(out)` w.r.t. the second operand.
    #[serde(default)]
    grad_b: Option<F64ListSentinel>,
    /// Expected gradient w.r.t. `x` in `where` cases.
    #[serde(default)]
    grad_x: Option<F64ListSentinel>,
    /// Expected gradient w.r.t. `y` in `where` cases.
    #[serde(default)]
    grad_y: Option<F64ListSentinel>,
    /// Boolean condition mask for `where` cases.
    #[serde(default)]
    cond: Option<Vec<bool>>,
    /// Shape of the `x` branch in `where` cases.
    #[serde(default)]
    x_shape: Option<Vec<usize>>,
    /// Shape of the `y` branch in `where` cases.
    #[serde(default)]
    y_shape: Option<Vec<usize>>,
    /// Values of the `x` branch in `where` cases.
    #[serde(default)]
    x_data: Option<F64ListSentinel>,
    /// Values of the `y` branch in `where` cases.
    #[serde(default)]
    y_data: Option<F64ListSentinel>,
    /// Exponent for `pow` cases.
    #[serde(default)]
    exp: Option<f64>,
    /// Scalar operand for `scalar_map` cases.
    #[serde(default)]
    scalar: Option<f64>,
    // NOTE(review): `axis`, `keepdim`, `min` and `max` are not read in this
    // part of the file; presumably they are for reduction/clamp cases — confirm.
    #[serde(default)]
    axis: Option<usize>,
    #[serde(default)]
    keepdim: Option<bool>,
    #[serde(default)]
    min: Option<f64>,
    #[serde(default)]
    max: Option<f64>,
}
/// Read and parse `tests/conformance/fixtures/elementwise.json` under the
/// crate root.
///
/// # Panics
/// If the file cannot be read (the message points at the regeneration
/// script) or its contents do not parse as a [`FixtureFile`].
fn load_fixtures() -> FixtureFile {
    let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
    path.extend(["tests", "conformance", "fixtures", "elementwise.json"]);
    let raw = match std::fs::read(&path) {
        Ok(raw) => raw,
        Err(e) => panic!(
            "read {} failed: {e}. Regenerate via \
             `python3 scripts/regenerate_elementwise_fixtures.py`",
            path.display()
        ),
    };
    match serde_json::from_slice(&raw) {
        Ok(file) => file,
        Err(e) => panic!("parse {}: {e}", path.display()),
    }
}
/// Collect, in file order, every fixture whose `op` and `device` match exactly.
fn cases_for<'a>(file: &'a FixtureFile, op: &str, device: &str) -> Vec<&'a Fixture> {
    let mut selected = Vec::new();
    for fixture in &file.fixtures {
        if fixture.op == op && fixture.device == device {
            selected.push(fixture);
        }
    }
    selected
}
/// Copy a tensor's elements into a host `Vec<f32>`, moving it to the CPU
/// first if it lives elsewhere.
fn read_back_f32(t: &Tensor<f32>) -> Vec<f32> {
    if !t.is_cpu() {
        let host = t.cpu().expect("D2H readback");
        return host.data().expect("read CPU data after readback").to_vec();
    }
    t.data().expect("read CPU data").to_vec()
}
/// Copy a tensor's elements into a host `Vec<f64>`, moving it to the CPU
/// first if it lives elsewhere.
fn read_back_f64(t: &Tensor<f64>) -> Vec<f64> {
    if !t.is_cpu() {
        let host = t.cpu().expect("D2H readback");
        return host.data().expect("read CPU data after readback").to_vec();
    }
    t.data().expect("read CPU data").to_vec()
}
/// Build a CPU `f32` tensor from f64 fixture values.
///
/// Each value is narrowed with `as f32` (round to nearest).
fn make_cpu_f32(data: &[f64], shape: &[usize], requires_grad: bool) -> Tensor<f32> {
    let narrowed = data.iter().copied().map(|x| x as f32).collect::<Vec<f32>>();
    let storage = TensorStorage::cpu(narrowed);
    Tensor::from_storage(storage, shape.to_vec(), requires_grad).expect("make_cpu_f32")
}
/// Build a CPU `f64` tensor holding a copy of `data`.
fn make_cpu_f64(data: &[f64], shape: &[usize], requires_grad: bool) -> Tensor<f64> {
    let storage = TensorStorage::cpu(Vec::from(data));
    Tensor::from_storage(storage, Vec::from(shape), requires_grad).expect("make_cpu_f64")
}
/// Move `t` to `device` when that is a CUDA device; otherwise return it as-is.
fn upload_f32(t: Tensor<f32>, device: Device) -> Tensor<f32> {
    match device {
        Device::Cuda(_) => t.to(device).expect("upload to cuda"),
        _ => t,
    }
}
/// Move `t` to `device` when that is a CUDA device; otherwise return it as-is.
fn upload_f64(t: Tensor<f64>, device: Device) -> Tensor<f64> {
    match device {
        Device::Cuda(_) => t.to(device).expect("upload to cuda"),
        _ => t,
    }
}
/// Compare f32 results with f64 reference values.
///
/// The references are first rounded to f32 so both sides share a precision.
fn check_f32(label: &str, actual: &[f32], expected: &[f64], tol: f32) {
    let mut expected_f32 = Vec::with_capacity(expected.len());
    expected_f32.extend(expected.iter().map(|v| *v as f32));
    tolerance::assert_close_f32(actual, expected_f32.as_slice(), tol, label);
}
/// Compare f64 results with f64 reference values (thin wrapper kept for
/// symmetry with `check_f32`).
fn check_f64(label: &str, actual: &[f64], expected: &[f64], tol: f64) {
    let (got, want) = (actual, expected);
    tolerance::assert_close_f64(got, want, tol, label)
}
fn run_binary_cpu(op_name: &str, op: BinaryOp) {
let file = load_fixtures();
let cases = cases_for(&file, op_name, "cpu");
assert!(
!cases.is_empty(),
"no CPU fixtures for op {op_name:?} — regenerate elementwise.json"
);
for f in cases {
let label = format!("{op_name} cpu tag={:?} dtype={}", f.tag, f.dtype);
let a_shape = f.a_shape.as_ref().expect("a_shape");
let b_shape = f.b_shape.as_ref().expect("b_shape");
let a_data = f
.a_data
.as_ref()
.map(F64ListSentinel::as_slice)
.expect("a_data");
let b_data = f
.b_data
.as_ref()
.map(F64ListSentinel::as_slice)
.expect("b_data");
let expected = f
.out_values
.as_ref()
.map(F64ListSentinel::as_slice)
.expect("out_values");
let grad_a_exp = f
.grad_a
.as_ref()
.map(F64ListSentinel::as_slice)
.expect("grad_a");
let grad_b_exp = f
.grad_b
.as_ref()
.map(F64ListSentinel::as_slice)
.expect("grad_b");
match f.dtype.as_str() {
"float32" => {
let a = make_cpu_f32(a_data, a_shape, false);
let b = make_cpu_f32(b_data, b_shape, false);
let c = op.apply_f32(&a, &b);
let actual = read_back_f32(&c);
check_f32(
&format!("{label} fwd"),
&actual,
expected,
tolerance::F32_TRANSCENDENTAL,
);
let a_g = make_cpu_f32(a_data, a_shape, true);
let b_g = make_cpu_f32(b_data, b_shape, true);
let out = op.apply_f32(&a_g, &b_g);
let loss = grad_sum(&out).expect("sum");
loss.backward().expect("backward");
let ga = a_g.grad().unwrap().expect("grad_a");
let gb = b_g.grad().unwrap().expect("grad_b");
check_f32(
&format!("{label} grad_a"),
&read_back_f32(&ga),
grad_a_exp,
tolerance::F32_TRANSCENDENTAL,
);
check_f32(
&format!("{label} grad_b"),
&read_back_f32(&gb),
grad_b_exp,
tolerance::F32_TRANSCENDENTAL,
);
}
"float64" => {
let a = make_cpu_f64(a_data, a_shape, false);
let b = make_cpu_f64(b_data, b_shape, false);
let c = op.apply_f64(&a, &b);
check_f64(
&format!("{label} fwd"),
&read_back_f64(&c),
expected,
tolerance::F64_TRANSCENDENTAL,
);
let a_g = make_cpu_f64(a_data, a_shape, true);
let b_g = make_cpu_f64(b_data, b_shape, true);
let out = op.apply_f64(&a_g, &b_g);
let loss = grad_sum(&out).expect("sum");
loss.backward().expect("backward");
let ga = a_g.grad().unwrap().expect("grad_a");
let gb = b_g.grad().unwrap().expect("grad_b");
check_f64(
&format!("{label} grad_a"),
&read_back_f64(&ga),
grad_a_exp,
tolerance::F64_TRANSCENDENTAL,
);
check_f64(
&format!("{label} grad_b"),
&read_back_f64(&gb),
grad_b_exp,
tolerance::F64_TRANSCENDENTAL,
);
}
other => panic!("unhandled dtype {other:?}"),
}
}
}
/// Differentiable binary arithmetic ops exercised by the CPU fixture tests.
#[derive(Clone, Copy)]
enum BinaryOp {
    Add,
    Sub,
    Mul,
    Div,
}

impl BinaryOp {
    /// Apply this op to two `f32` tensors, panicking with the op name on error.
    fn apply_f32(self, a: &Tensor<f32>, b: &Tensor<f32>) -> Tensor<f32> {
        match self {
            Self::Add => add(a, b).expect("add"),
            Self::Sub => sub(a, b).expect("sub"),
            Self::Mul => mul(a, b).expect("mul"),
            Self::Div => div(a, b).expect("div"),
        }
    }

    /// Apply this op to two `f64` tensors, panicking with the op name on error.
    fn apply_f64(self, a: &Tensor<f64>, b: &Tensor<f64>) -> Tensor<f64> {
        match self {
            Self::Add => add(a, b).expect("add"),
            Self::Sub => sub(a, b).expect("sub"),
            Self::Mul => mul(a, b).expect("mul"),
            Self::Div => div(a, b).expect("div"),
        }
    }
}
/// Forward/gradient parity for `add` on CPU.
#[test]
fn cpu_add() {
    let op = BinaryOp::Add;
    run_binary_cpu("add", op);
}

/// Forward/gradient parity for `sub` on CPU.
#[test]
fn cpu_sub() {
    let op = BinaryOp::Sub;
    run_binary_cpu("sub", op);
}

/// Forward/gradient parity for `mul` on CPU.
#[test]
fn cpu_mul() {
    let op = BinaryOp::Mul;
    run_binary_cpu("mul", op);
}

/// Forward/gradient parity for `div` on CPU.
#[test]
fn cpu_div() {
    let op = BinaryOp::Div;
    run_binary_cpu("div", op);
}
/// Differentiable unary ops exercised by the CPU fixture tests.
#[derive(Clone, Copy)]
enum UnaryOp {
    Neg,
    Abs,
    Sqrt,
}

impl UnaryOp {
    /// Apply this op to an `f32` tensor, panicking with the op name on error.
    fn apply_f32(self, a: &Tensor<f32>) -> Tensor<f32> {
        match self {
            Self::Neg => neg(a).expect("neg"),
            Self::Abs => abs(a).expect("abs"),
            Self::Sqrt => sqrt(a).expect("sqrt"),
        }
    }

    /// Apply this op to an `f64` tensor, panicking with the op name on error.
    fn apply_f64(self, a: &Tensor<f64>) -> Tensor<f64> {
        match self {
            Self::Neg => neg(a).expect("neg"),
            Self::Abs => abs(a).expect("abs"),
            Self::Sqrt => sqrt(a).expect("sqrt"),
        }
    }
}
fn run_unary_cpu(op_name: &str, op: UnaryOp) {
let file = load_fixtures();
let cases = cases_for(&file, op_name, "cpu");
assert!(
!cases.is_empty(),
"no CPU fixtures for op {op_name:?} — regenerate elementwise.json"
);
for f in cases {
let label = format!("{op_name} cpu tag={:?} dtype={}", f.tag, f.dtype);
let shape = f.a_shape.as_ref().expect("a_shape");
let a_data = f
.a_data
.as_ref()
.map(F64ListSentinel::as_slice)
.expect("a_data");
let expected = f
.out_values
.as_ref()
.map(F64ListSentinel::as_slice)
.expect("out_values");
let grad_a_exp = f
.grad_a
.as_ref()
.map(F64ListSentinel::as_slice)
.expect("grad_a");
match f.dtype.as_str() {
"float32" => {
let a = make_cpu_f32(a_data, shape, false);
let c = op.apply_f32(&a);
check_f32(
&format!("{label} fwd"),
&read_back_f32(&c),
expected,
tolerance::F32_TRANSCENDENTAL,
);
let a_g = make_cpu_f32(a_data, shape, true);
let out = op.apply_f32(&a_g);
let loss = grad_sum(&out).expect("sum");
loss.backward().expect("backward");
let ga = a_g.grad().unwrap().expect("grad_a");
check_f32(
&format!("{label} grad_a"),
&read_back_f32(&ga),
grad_a_exp,
tolerance::F32_TRANSCENDENTAL,
);
}
"float64" => {
let a = make_cpu_f64(a_data, shape, false);
let c = op.apply_f64(&a);
check_f64(
&format!("{label} fwd"),
&read_back_f64(&c),
expected,
tolerance::F64_TRANSCENDENTAL,
);
let a_g = make_cpu_f64(a_data, shape, true);
let out = op.apply_f64(&a_g);
let loss = grad_sum(&out).expect("sum");
loss.backward().expect("backward");
let ga = a_g.grad().unwrap().expect("grad_a");
check_f64(
&format!("{label} grad_a"),
&read_back_f64(&ga),
grad_a_exp,
tolerance::F64_TRANSCENDENTAL,
);
}
_ => unreachable!(),
}
}
}
/// Forward/gradient parity for `neg` on CPU.
#[test]
fn cpu_neg() {
    let op = UnaryOp::Neg;
    run_unary_cpu("neg", op);
}

/// Forward/gradient parity for `abs` on CPU.
#[test]
fn cpu_abs() {
    let op = UnaryOp::Abs;
    run_unary_cpu("abs", op);
}

/// Forward/gradient parity for `sqrt` on CPU.
#[test]
fn cpu_sqrt() {
    let op = UnaryOp::Sqrt;
    run_unary_cpu("sqrt", op);
}
/// Forward/gradient parity for `pow(a, exp)` on CPU, where `exp` is a
/// per-case scalar exponent from the fixture.
///
/// # Panics
/// If no fixtures exist, a required field is missing, a comparison fails,
/// or a case carries an unsupported dtype (named in the message).
#[test]
fn cpu_pow() {
    /// Borrow a required numeric fixture field, panicking with `what` if absent.
    fn values<'a>(field: &'a Option<F64ListSentinel>, what: &str) -> &'a [f64] {
        field.as_ref().map(F64ListSentinel::as_slice).expect(what)
    }

    let file = load_fixtures();
    let cases = cases_for(&file, "pow", "cpu");
    assert!(!cases.is_empty(), "no CPU fixtures for pow");
    for f in cases {
        let label = format!("pow cpu tag={:?} dtype={} exp={:?}", f.tag, f.dtype, f.exp);
        let shape = f.a_shape.as_ref().expect("a_shape");
        let a_data = values(&f.a_data, "a_data");
        let exp = f.exp.expect("exp");
        let expected = values(&f.out_values, "out_values");
        let grad_a_exp = values(&f.grad_a, "grad_a");
        match f.dtype.as_str() {
            "float32" => {
                let tol = tolerance::F32_TRANSCENDENTAL;
                let a = make_cpu_f32(a_data, shape, false);
                let c = pow(&a, exp).expect("pow");
                check_f32(&format!("{label} fwd"), &read_back_f32(&c), expected, tol);

                let a_g = make_cpu_f32(a_data, shape, true);
                let out = pow(&a_g, exp).expect("pow");
                grad_sum(&out).expect("sum").backward().expect("backward");
                let ga = a_g.grad().unwrap().expect("grad_a");
                check_f32(&format!("{label} grad_a"), &read_back_f32(&ga), grad_a_exp, tol);
            }
            "float64" => {
                let tol = tolerance::F64_TRANSCENDENTAL;
                let a = make_cpu_f64(a_data, shape, false);
                let c = pow(&a, exp).expect("pow");
                check_f64(&format!("{label} fwd"), &read_back_f64(&c), expected, tol);

                let a_g = make_cpu_f64(a_data, shape, true);
                let out = pow(&a_g, exp).expect("pow");
                grad_sum(&out).expect("sum").backward().expect("backward");
                let ga = a_g.grad().unwrap().expect("grad_a");
                check_f64(&format!("{label} grad_a"), &read_back_f64(&ga), grad_a_exp, tol);
            }
            // Fixture-provided dtype: reachable, so report it explicitly.
            other => panic!("unhandled dtype {other:?}"),
        }
    }
}
/// Edge-case fixtures checked forward-only against recorded outputs:
/// `pow(a, 0.0)`, `sqrt` (`sqrt_zero` cases) and `div` (`div_zero` cases).
///
/// # Panics
/// If any of the three edge-case groups has no CPU fixtures (so the test
/// cannot pass vacuously), a required field is missing, a value mismatches,
/// or a case carries an unsupported dtype (named in the message).
#[test]
fn cpu_edge_cases() {
    /// Borrow a required numeric fixture field, panicking with `what` if absent.
    fn values<'a>(field: &'a Option<F64ListSentinel>, what: &str) -> &'a [f64] {
        field.as_ref().map(F64ListSentinel::as_slice).expect(what)
    }

    let file = load_fixtures();

    let pow_cases = cases_for(&file, "pow_zero_exp", "cpu");
    assert!(!pow_cases.is_empty(), "no CPU fixtures for pow_zero_exp");
    for f in pow_cases {
        let label = format!("pow_zero_exp cpu dtype={}", f.dtype);
        let shape = f.a_shape.as_ref().expect("a_shape");
        let a_data = values(&f.a_data, "a_data");
        let expected = values(&f.out_values, "out_values");
        match f.dtype.as_str() {
            "float32" => {
                let a = make_cpu_f32(a_data, shape, false);
                let c = pow(&a, 0.0).expect("pow");
                check_f32(&label, &read_back_f32(&c), expected, tolerance::F32_ELEMENTWISE);
            }
            "float64" => {
                let a = make_cpu_f64(a_data, shape, false);
                let c = pow(&a, 0.0).expect("pow");
                check_f64(&label, &read_back_f64(&c), expected, tolerance::F64_ELEMENTWISE);
            }
            other => panic!("unhandled dtype {other:?}"),
        }
    }

    let sqrt_cases = cases_for(&file, "sqrt_zero", "cpu");
    assert!(!sqrt_cases.is_empty(), "no CPU fixtures for sqrt_zero");
    for f in sqrt_cases {
        let label = format!("sqrt_zero cpu dtype={}", f.dtype);
        let shape = f.a_shape.as_ref().expect("a_shape");
        let a_data = values(&f.a_data, "a_data");
        let expected = values(&f.out_values, "out_values");
        match f.dtype.as_str() {
            "float32" => {
                let a = make_cpu_f32(a_data, shape, false);
                let c = sqrt(&a).expect("sqrt");
                check_f32(&label, &read_back_f32(&c), expected, tolerance::F32_ELEMENTWISE);
            }
            "float64" => {
                let a = make_cpu_f64(a_data, shape, false);
                let c = sqrt(&a).expect("sqrt");
                check_f64(&label, &read_back_f64(&c), expected, tolerance::F64_ELEMENTWISE);
            }
            other => panic!("unhandled dtype {other:?}"),
        }
    }

    let div_cases = cases_for(&file, "div_zero", "cpu");
    assert!(!div_cases.is_empty(), "no CPU fixtures for div_zero");
    for f in div_cases {
        let label = format!("div_zero cpu dtype={}", f.dtype);
        let a_shape = f.a_shape.as_ref().expect("a_shape");
        let b_shape = f.b_shape.as_ref().expect("b_shape");
        let a_data = values(&f.a_data, "a_data");
        let b_data = values(&f.b_data, "b_data");
        let expected = values(&f.out_values, "out_values");
        match f.dtype.as_str() {
            "float32" => {
                let a = make_cpu_f32(a_data, a_shape, false);
                let b = make_cpu_f32(b_data, b_shape, false);
                let c = div(&a, &b).expect("div");
                check_f32(&label, &read_back_f32(&c), expected, tolerance::F32_ELEMENTWISE);
            }
            "float64" => {
                let a = make_cpu_f64(a_data, a_shape, false);
                let b = make_cpu_f64(b_data, b_shape, false);
                let c = div(&a, &b).expect("div");
                check_f64(&label, &read_back_f64(&c), expected, tolerance::F64_ELEMENTWISE);
            }
            other => panic!("unhandled dtype {other:?}"),
        }
    }
}
fn run_where_for_device(device_label: &str, device: Device) {
let file = load_fixtures();
let cases = cases_for(&file, "where", device_label);
assert!(!cases.is_empty(), "no fixtures for where on {device_label}");
for f in cases {
let label = format!("where {device_label} dtype={}", f.dtype);
let cond = f.cond.as_ref().expect("cond").clone();
let x_shape = f.x_shape.as_ref().expect("x_shape");
let y_shape = f.y_shape.as_ref().expect("y_shape");
let x_data = f
.x_data
.as_ref()
.map(F64ListSentinel::as_slice)
.expect("x_data");
let y_data = f
.y_data
.as_ref()
.map(F64ListSentinel::as_slice)
.expect("y_data");
let expected = f
.out_values
.as_ref()
.map(F64ListSentinel::as_slice)
.expect("out_values");
let grad_x_exp = f
.grad_x
.as_ref()
.map(F64ListSentinel::as_slice)
.expect("grad_x");
let grad_y_exp = f
.grad_y
.as_ref()
.map(F64ListSentinel::as_slice)
.expect("grad_y");
match f.dtype.as_str() {
"float32" => {
let x = upload_f32(make_cpu_f32(x_data, x_shape, false), device);
let y = upload_f32(make_cpu_f32(y_data, y_shape, false), device);
let out = where_(&cond, &x, &y).expect("where_");
check_f32(
&format!("{label} where_ fwd"),
&read_back_f32(&out),
expected,
tolerance::F32_ELEMENTWISE,
);
let cond_bt = BoolTensor::from_vec(cond.clone(), x_shape.clone()).expect("bt");
let out_bt = where_bt(&cond_bt, &x, &y).expect("where_bt");
check_f32(
&format!("{label} where_bt fwd"),
&read_back_f32(&out_bt),
expected,
tolerance::F32_ELEMENTWISE,
);
let x_g = upload_f32(make_cpu_f32(x_data, x_shape, true), device);
let y_g = upload_f32(make_cpu_f32(y_data, y_shape, true), device);
let out_g = where_(&cond, &x_g, &y_g).expect("where_ grad");
grad_sum(&out_g).expect("sum").backward().expect("backward");
let gx = x_g.grad().unwrap().expect("grad_x");
let gy = y_g.grad().unwrap().expect("grad_y");
check_f32(
&format!("{label} grad_x"),
&read_back_f32(&gx),
grad_x_exp,
tolerance::F32_ELEMENTWISE,
);
check_f32(
&format!("{label} grad_y"),
&read_back_f32(&gy),
grad_y_exp,
tolerance::F32_ELEMENTWISE,
);
}
"float64" => {
let x = upload_f64(make_cpu_f64(x_data, x_shape, false), device);
let y = upload_f64(make_cpu_f64(y_data, y_shape, false), device);
let out = where_(&cond, &x, &y).expect("where_");
check_f64(
&format!("{label} where_ fwd"),
&read_back_f64(&out),
expected,
tolerance::F64_ELEMENTWISE,
);
let cond_bt = BoolTensor::from_vec(cond.clone(), x_shape.clone()).expect("bt");
let out_bt = where_bt(&cond_bt, &x, &y).expect("where_bt");
check_f64(
&format!("{label} where_bt fwd"),
&read_back_f64(&out_bt),
expected,
tolerance::F64_ELEMENTWISE,
);
let x_g = upload_f64(make_cpu_f64(x_data, x_shape, true), device);
let y_g = upload_f64(make_cpu_f64(y_data, y_shape, true), device);
let out_g = where_(&cond, &x_g, &y_g).expect("where_ grad");
grad_sum(&out_g).expect("sum").backward().expect("backward");
let gx = x_g.grad().unwrap().expect("grad_x");
let gy = y_g.grad().unwrap().expect("grad_y");
check_f64(
&format!("{label} grad_x"),
&read_back_f64(&gx),
grad_x_exp,
tolerance::F64_ELEMENTWISE,
);
check_f64(
&format!("{label} grad_y"),
&read_back_f64(&gy),
grad_y_exp,
tolerance::F64_ELEMENTWISE,
);
}
_ => unreachable!(),
}
}
}
/// `where_` / `where_bt` forward and gradient parity on CPU.
#[test]
fn cpu_where() {
    let device = Device::Cpu;
    run_where_for_device("cpu", device);
}
/// `binary_map` with a user closure (`max(x, y) - min(x, y)`) against the
/// `binary_map_maxmin` fixtures.
///
/// # Panics
/// If no fixtures exist, a required field is missing, a value mismatches,
/// or a case carries an unsupported dtype (named in the message).
#[test]
fn cpu_binary_map_higher_order() {
    /// Borrow a required numeric fixture field, panicking with `what` if absent.
    fn values<'a>(field: &'a Option<F64ListSentinel>, what: &str) -> &'a [f64] {
        field.as_ref().map(F64ListSentinel::as_slice).expect(what)
    }

    let file = load_fixtures();
    let cases = cases_for(&file, "binary_map_maxmin", "cpu");
    assert!(!cases.is_empty(), "no fixtures for binary_map_maxmin");
    for f in cases {
        let label = format!("binary_map_maxmin cpu dtype={}", f.dtype);
        let a_shape = f.a_shape.as_ref().expect("a_shape");
        let b_shape = f.b_shape.as_ref().expect("b_shape");
        let a_data = values(&f.a_data, "a_data");
        let b_data = values(&f.b_data, "b_data");
        let expected = values(&f.out_values, "out_values");
        match f.dtype.as_str() {
            "float32" => {
                let a = make_cpu_f32(a_data, a_shape, false);
                let b = make_cpu_f32(b_data, b_shape, false);
                let out = binary_map(&a, &b, |x, y| x.max(y) - x.min(y)).expect("binary_map");
                check_f32(&label, &read_back_f32(&out), expected, tolerance::F32_ELEMENTWISE);
            }
            "float64" => {
                let a = make_cpu_f64(a_data, a_shape, false);
                let b = make_cpu_f64(b_data, b_shape, false);
                let out = binary_map(&a, &b, |x, y| x.max(y) - x.min(y)).expect("binary_map");
                check_f64(&label, &read_back_f64(&out), expected, tolerance::F64_ELEMENTWISE);
            }
            // Fixture-provided dtype: reachable, so report it explicitly.
            other => panic!("unhandled dtype {other:?}"),
        }
    }
}
/// `scalar_map` with a user closure (`x * x + s`) against the
/// `scalar_map_sqplus` fixtures; `s` comes from the fixture's `scalar` field.
///
/// # Panics
/// If no fixtures exist, a required field is missing, a value mismatches,
/// or a case carries an unsupported dtype (named in the message).
#[test]
fn cpu_scalar_map_higher_order() {
    /// Borrow a required numeric fixture field, panicking with `what` if absent.
    fn values<'a>(field: &'a Option<F64ListSentinel>, what: &str) -> &'a [f64] {
        field.as_ref().map(F64ListSentinel::as_slice).expect(what)
    }

    let file = load_fixtures();
    let cases = cases_for(&file, "scalar_map_sqplus", "cpu");
    assert!(!cases.is_empty(), "no fixtures for scalar_map_sqplus");
    for f in cases {
        let label = format!("scalar_map_sqplus cpu dtype={}", f.dtype);
        let a_shape = f.a_shape.as_ref().expect("a_shape");
        let a_data = values(&f.a_data, "a_data");
        let scalar = f.scalar.expect("scalar");
        let expected = values(&f.out_values, "out_values");
        match f.dtype.as_str() {
            "float32" => {
                let a = make_cpu_f32(a_data, a_shape, false);
                // The fixture stores the scalar as f64; narrow it for the f32 path.
                let s = scalar as f32;
                let out = scalar_map(&a, s, |x, s| x * x + s).expect("scalar_map");
                check_f32(&label, &read_back_f32(&out), expected, tolerance::F32_ELEMENTWISE);
            }
            "float64" => {
                let a = make_cpu_f64(a_data, a_shape, false);
                let out = scalar_map(&a, scalar, |x, s| x * x + s).expect("scalar_map");
                check_f64(&label, &read_back_f64(&out), expected, tolerance::F64_ELEMENTWISE);
            }
            // Fixture-provided dtype: reachable, so report it explicitly.
            other => panic!("unhandled dtype {other:?}"),
        }
    }
}
/// `unary_map` with a user closure (`tan`) against the `unary_map_tan` fixtures.
///
/// # Panics
/// If no fixtures exist, a required field is missing, a value mismatches,
/// or a case carries an unsupported dtype (named in the message).
#[test]
fn cpu_unary_map_higher_order() {
    /// Borrow a required numeric fixture field, panicking with `what` if absent.
    fn values<'a>(field: &'a Option<F64ListSentinel>, what: &str) -> &'a [f64] {
        field.as_ref().map(F64ListSentinel::as_slice).expect(what)
    }

    let file = load_fixtures();
    let cases = cases_for(&file, "unary_map_tan", "cpu");
    assert!(!cases.is_empty(), "no fixtures for unary_map_tan");
    for f in cases {
        let label = format!("unary_map_tan cpu dtype={}", f.dtype);
        let a_shape = f.a_shape.as_ref().expect("a_shape");
        let a_data = values(&f.a_data, "a_data");
        let expected = values(&f.out_values, "out_values");
        match f.dtype.as_str() {
            "float32" => {
                let a = make_cpu_f32(a_data, a_shape, false);
                let out = unary_map(&a, |x| x.tan()).expect("unary_map");
                check_f32(&label, &read_back_f32(&out), expected, tolerance::F32_TRANSCENDENTAL);
            }
            "float64" => {
                let a = make_cpu_f64(a_data, a_shape, false);
                let out = unary_map(&a, |x| x.tan()).expect("unary_map");
                check_f64(&label, &read_back_f64(&out), expected, tolerance::F64_TRANSCENDENTAL);
            }
            // Fixture-provided dtype: reachable, so report it explicitly.
            other => panic!("unhandled dtype {other:?}"),
        }
    }
}
/// Assert that a fast-path f32 result agrees with the canonical
/// implementation's result, using the same rule as `tolerance::assert_close_f32`.
fn check_canon_consistency_f32(label: &str, fast_actual: &[f32], canon_actual: &[f32], tol: f32) {
    let (got, reference) = (fast_actual, canon_actual);
    tolerance::assert_close_f32(got, reference, tol, label)
}
/// Assert that a fast-path f64 result agrees with the canonical
/// implementation's result, using the same rule as `tolerance::assert_close_f64`.
fn check_canon_consistency_f64(label: &str, fast_actual: &[f64], canon_actual: &[f64], tol: f64) {
    let (got, reference) = (fast_actual, canon_actual);
    tolerance::assert_close_f64(got, reference, tol, label)
}
/// `fast_{add,sub,mul,div}` on CPU.
///
/// Each result is checked against the fixture and against the canonical
/// autograd op (`add`/`sub`/`mul`/`div`) on the same inputs.
///
/// # Panics
/// If an op has no fixtures, a required field is missing, a comparison
/// fails, or a case carries an unsupported dtype (named in the message).
#[test]
fn cpu_fast_binary_ops() {
    /// Borrow a required numeric fixture field, panicking with `what` if absent.
    fn values<'a>(field: &'a Option<F64ListSentinel>, what: &str) -> &'a [f64] {
        field.as_ref().map(F64ListSentinel::as_slice).expect(what)
    }

    let file = load_fixtures();
    for op_name in ["fast_add", "fast_sub", "fast_mul", "fast_div"] {
        let cases = cases_for(&file, op_name, "cpu");
        assert!(!cases.is_empty(), "no fixtures for {op_name}");
        for f in cases {
            let label = format!("{op_name} cpu dtype={}", f.dtype);
            let a_shape = f.a_shape.as_ref().expect("a_shape");
            let b_shape = f.b_shape.as_ref().expect("b_shape");
            let a_data = values(&f.a_data, "a_data");
            let b_data = values(&f.b_data, "b_data");
            let expected = values(&f.out_values, "out_values");
            match f.dtype.as_str() {
                "float32" => {
                    let a = make_cpu_f32(a_data, a_shape, false);
                    let b = make_cpu_f32(b_data, b_shape, false);
                    // One dispatch pairs each fast kernel with its canonical op.
                    let (fast_t, canon_t): (Tensor<f32>, Tensor<f32>) = match op_name {
                        "fast_add" => (fast_add(&a, &b).unwrap(), add(&a, &b).unwrap()),
                        "fast_sub" => (fast_sub(&a, &b).unwrap(), sub(&a, &b).unwrap()),
                        "fast_mul" => (fast_mul(&a, &b).unwrap(), mul(&a, &b).unwrap()),
                        "fast_div" => (fast_div(&a, &b).unwrap(), div(&a, &b).unwrap()),
                        _ => unreachable!("op_name comes from the fixed list above"),
                    };
                    let fa = read_back_f32(&fast_t);
                    let ca = read_back_f32(&canon_t);
                    check_f32(
                        &format!("{label} parity-vs-torch"),
                        &fa,
                        expected,
                        tolerance::F32_FAST,
                    );
                    check_canon_consistency_f32(
                        &format!("{label} parity-vs-canonical"),
                        &fa,
                        &ca,
                        tolerance::F32_FAST,
                    );
                }
                "float64" => {
                    let a = make_cpu_f64(a_data, a_shape, false);
                    let b = make_cpu_f64(b_data, b_shape, false);
                    // One dispatch pairs each fast kernel with its canonical op.
                    let (fast_t, canon_t): (Tensor<f64>, Tensor<f64>) = match op_name {
                        "fast_add" => (fast_add(&a, &b).unwrap(), add(&a, &b).unwrap()),
                        "fast_sub" => (fast_sub(&a, &b).unwrap(), sub(&a, &b).unwrap()),
                        "fast_mul" => (fast_mul(&a, &b).unwrap(), mul(&a, &b).unwrap()),
                        "fast_div" => (fast_div(&a, &b).unwrap(), div(&a, &b).unwrap()),
                        _ => unreachable!("op_name comes from the fixed list above"),
                    };
                    let fa = read_back_f64(&fast_t);
                    let ca = read_back_f64(&canon_t);
                    check_f64(
                        &format!("{label} parity-vs-torch"),
                        &fa,
                        expected,
                        tolerance::F64_FAST,
                    );
                    check_canon_consistency_f64(
                        &format!("{label} parity-vs-canonical"),
                        &fa,
                        &ca,
                        tolerance::F64_FAST,
                    );
                }
                // Fixture-provided dtype: reachable, so report it explicitly.
                other => panic!("unhandled dtype {other:?}"),
            }
        }
    }
}
/// `fast_{exp,log,sigmoid,tanh,sin,cos}` on CPU.
///
/// Each result is checked against the fixture and against a scalar std-lib
/// reference applied through `unary_map`.
///
/// # Panics
/// If an op has no fixtures, a required field is missing, a comparison
/// fails, or a case carries an unsupported dtype (named in the message).
#[test]
fn cpu_fast_unary_ops() {
    /// Borrow a required numeric fixture field, panicking with `what` if absent.
    fn values<'a>(field: &'a Option<F64ListSentinel>, what: &str) -> &'a [f64] {
        field.as_ref().map(F64ListSentinel::as_slice).expect(what)
    }

    let file = load_fixtures();
    for op_name in [
        "fast_exp",
        "fast_log",
        "fast_sigmoid",
        "fast_tanh",
        "fast_sin",
        "fast_cos",
    ] {
        let cases = cases_for(&file, op_name, "cpu");
        assert!(!cases.is_empty(), "no fixtures for {op_name}");
        for f in cases {
            let label = format!("{op_name} cpu dtype={}", f.dtype);
            let a_shape = f.a_shape.as_ref().expect("a_shape");
            let a_data = values(&f.a_data, "a_data");
            let expected = values(&f.out_values, "out_values");
            match f.dtype.as_str() {
                "float32" => {
                    let a = make_cpu_f32(a_data, a_shape, false);
                    let (fast_t, canon_t): (Tensor<f32>, Tensor<f32>) = match op_name {
                        "fast_exp" => (
                            fast_exp(&a).unwrap(),
                            unary_map(&a, |x: f32| x.exp()).unwrap(),
                        ),
                        "fast_log" => (
                            fast_log(&a).unwrap(),
                            unary_map(&a, |x: f32| x.ln()).unwrap(),
                        ),
                        "fast_sigmoid" => (
                            fast_sigmoid(&a).unwrap(),
                            unary_map(&a, |x: f32| 1.0 / (1.0 + (-x).exp())).unwrap(),
                        ),
                        "fast_tanh" => (
                            fast_tanh(&a).unwrap(),
                            unary_map(&a, |x: f32| x.tanh()).unwrap(),
                        ),
                        "fast_sin" => (
                            fast_sin(&a).unwrap(),
                            unary_map(&a, |x: f32| x.sin()).unwrap(),
                        ),
                        "fast_cos" => (
                            fast_cos(&a).unwrap(),
                            unary_map(&a, |x: f32| x.cos()).unwrap(),
                        ),
                        _ => unreachable!("op_name comes from the fixed list above"),
                    };
                    let fa = read_back_f32(&fast_t);
                    let ca = read_back_f32(&canon_t);
                    check_f32(
                        &format!("{label} parity-vs-torch"),
                        &fa,
                        expected,
                        tolerance::F32_FAST,
                    );
                    check_canon_consistency_f32(
                        &format!("{label} parity-vs-canonical"),
                        &fa,
                        &ca,
                        tolerance::F32_FAST,
                    );
                }
                "float64" => {
                    let a = make_cpu_f64(a_data, a_shape, false);
                    let (fast_t, canon_t): (Tensor<f64>, Tensor<f64>) = match op_name {
                        "fast_exp" => (
                            fast_exp(&a).unwrap(),
                            unary_map(&a, |x: f64| x.exp()).unwrap(),
                        ),
                        "fast_log" => (
                            fast_log(&a).unwrap(),
                            unary_map(&a, |x: f64| x.ln()).unwrap(),
                        ),
                        "fast_sigmoid" => (
                            fast_sigmoid(&a).unwrap(),
                            unary_map(&a, |x: f64| 1.0 / (1.0 + (-x).exp())).unwrap(),
                        ),
                        "fast_tanh" => (
                            fast_tanh(&a).unwrap(),
                            unary_map(&a, |x: f64| x.tanh()).unwrap(),
                        ),
                        "fast_sin" => (
                            fast_sin(&a).unwrap(),
                            unary_map(&a, |x: f64| x.sin()).unwrap(),
                        ),
                        "fast_cos" => (
                            fast_cos(&a).unwrap(),
                            unary_map(&a, |x: f64| x.cos()).unwrap(),
                        ),
                        _ => unreachable!("op_name comes from the fixed list above"),
                    };
                    let fa = read_back_f64(&fast_t);
                    let ca = read_back_f64(&canon_t);
                    check_f64(
                        &format!("{label} parity-vs-torch"),
                        &fa,
                        expected,
                        tolerance::F64_FAST,
                    );
                    check_canon_consistency_f64(
                        &format!("{label} parity-vs-canonical"),
                        &fa,
                        &ca,
                        tolerance::F64_FAST,
                    );
                }
                // Fixture-provided dtype: reachable, so report it explicitly.
                other => panic!("unhandled dtype {other:?}"),
            }
        }
    }
}
/// SIMD kernels on CPU.
///
/// Each result is checked against the fixture and against the equivalent
/// canonical implementation. The f32 and f64 kernel sets are listed
/// separately because each kernel is dtype-specific.
#[test]
fn cpu_simd_ops() {
    /// Borrow a required numeric fixture field.
    fn values(field: &Option<F64ListSentinel>) -> &[f64] {
        field.as_ref().map(F64ListSentinel::as_slice).unwrap()
    }

    /// Build the right-hand operand of a binary f32 case.
    fn rhs_f32(f: &Fixture) -> Tensor<f32> {
        let b_shape = f.b_shape.as_ref().unwrap();
        make_cpu_f32(values(&f.b_data), b_shape, false)
    }

    /// Build the right-hand operand of a binary f64 case.
    fn rhs_f64(f: &Fixture) -> Tensor<f64> {
        let b_shape = f.b_shape.as_ref().unwrap();
        make_cpu_f64(values(&f.b_data), b_shape, false)
    }

    let file = load_fixtures();

    let f32_kernels = [
        "simd_add_f32",
        "simd_mul_f32",
        "simd_exp_f32",
        "simd_log_f32",
        "simd_sqrt_f32",
    ];
    for op_name in f32_kernels {
        let cases = cases_for(&file, op_name, "cpu");
        assert!(!cases.is_empty(), "no fixtures for {op_name}");
        for f in cases {
            let label = format!("{op_name} cpu");
            let a_shape = f.a_shape.as_ref().unwrap();
            let a_data = values(&f.a_data);
            let exp = values(&f.out_values);
            let a = make_cpu_f32(a_data, a_shape, false);
            let (simd_t, canon_t): (Tensor<f32>, Tensor<f32>) = match op_name {
                "simd_add_f32" => {
                    let b = rhs_f32(f);
                    (simd_add_f32(&a, &b).unwrap(), add(&a, &b).unwrap())
                }
                "simd_mul_f32" => {
                    let b = rhs_f32(f);
                    (simd_mul_f32(&a, &b).unwrap(), mul(&a, &b).unwrap())
                }
                "simd_exp_f32" => (
                    simd_exp_f32(&a).unwrap(),
                    unary_map(&a, |x: f32| x.exp()).unwrap(),
                ),
                "simd_log_f32" => (
                    simd_log_f32(&a).unwrap(),
                    unary_map(&a, |x: f32| x.ln()).unwrap(),
                ),
                "simd_sqrt_f32" => (simd_sqrt_f32(&a).unwrap(), sqrt(&a).unwrap()),
                _ => unreachable!(),
            };
            let sa = read_back_f32(&simd_t);
            let ca = read_back_f32(&canon_t);
            check_f32(&format!("{label} parity-vs-torch"), &sa, exp, tolerance::F32_SIMD);
            check_canon_consistency_f32(
                &format!("{label} parity-vs-canonical"),
                &sa,
                &ca,
                tolerance::F32_SIMD,
            );
        }
    }

    let f64_kernels = ["simd_add_f64", "simd_mul_f64", "simd_exp_f64"];
    for op_name in f64_kernels {
        let cases = cases_for(&file, op_name, "cpu");
        assert!(!cases.is_empty(), "no fixtures for {op_name}");
        for f in cases {
            let label = format!("{op_name} cpu");
            let a_shape = f.a_shape.as_ref().unwrap();
            let a_data = values(&f.a_data);
            let exp = values(&f.out_values);
            let a = make_cpu_f64(a_data, a_shape, false);
            let (simd_t, canon_t): (Tensor<f64>, Tensor<f64>) = match op_name {
                "simd_add_f64" => {
                    let b = rhs_f64(f);
                    (simd_add_f64(&a, &b).unwrap(), add(&a, &b).unwrap())
                }
                "simd_mul_f64" => {
                    let b = rhs_f64(f);
                    (simd_mul_f64(&a, &b).unwrap(), mul(&a, &b).unwrap())
                }
                "simd_exp_f64" => (
                    simd_exp_f64(&a).unwrap(),
                    unary_map(&a, |x: f64| x.exp()).unwrap(),
                ),
                _ => unreachable!(),
            };
            let sa = read_back_f64(&simd_t);
            let ca = read_back_f64(&canon_t);
            check_f64(&format!("{label} parity-vs-torch"), &sa, exp, tolerance::F64_SIMD);
            check_canon_consistency_f64(
                &format!("{label} parity-vs-canonical"),
                &sa,
                &ca,
                tolerance::F64_SIMD,
            );
        }
    }
}
/// CPU conformance for the full reduction `sum`.
///
/// For every `sum` fixture on `cpu`:
/// * forward: the non-autograd `sum` result is compared with `out_values`;
/// * backward: a `requires_grad` copy of the input is reduced with the autograd
///   `sum` (`grad_sum`), back-propagated, and the leaf gradient is compared
///   with `grad_a`.
#[test]
fn cpu_sum() {
    let fixtures = load_fixtures();
    let sum_cases = cases_for(&fixtures, "sum", "cpu");
    assert!(!sum_cases.is_empty(), "no fixtures for sum");
    for case in sum_cases {
        let label = format!("sum cpu dtype={}", case.dtype);
        let in_shape = case.a_shape.as_ref().unwrap();
        let in_values = case.a_data.as_ref().map(F64ListSentinel::as_slice).unwrap();
        let want_out = case
            .out_values
            .as_ref()
            .map(F64ListSentinel::as_slice)
            .unwrap();
        let want_grad = case.grad_a.as_ref().map(F64ListSentinel::as_slice).unwrap();
        match case.dtype.as_str() {
            "float32" => {
                // Forward pass on a tensor that does not track gradients.
                let plain = make_cpu_f32(in_values, in_shape, false);
                let total = sum(&plain).expect("sum");
                check_f32(
                    &format!("{label} fwd"),
                    &read_back_f32(&total),
                    want_out,
                    tolerance::F32_REDUCTION,
                );
                // Backward pass through the autograd reduction.
                let leaf = make_cpu_f32(in_values, in_shape, true);
                let tracked_total = grad_sum(&leaf).expect("grad sum");
                tracked_total.backward().expect("backward");
                let leaf_grad = leaf.grad().unwrap().expect("grad_a");
                check_f32(
                    &format!("{label} grad_a"),
                    &read_back_f32(&leaf_grad),
                    want_grad,
                    tolerance::F32_ELEMENTWISE,
                );
            }
            "float64" => {
                // Forward pass on a tensor that does not track gradients.
                let plain = make_cpu_f64(in_values, in_shape, false);
                let total = sum(&plain).expect("sum");
                check_f64(
                    &format!("{label} fwd"),
                    &read_back_f64(&total),
                    want_out,
                    tolerance::F64_REDUCTION,
                );
                // Backward pass through the autograd reduction.
                let leaf = make_cpu_f64(in_values, in_shape, true);
                let tracked_total = grad_sum(&leaf).expect("grad sum");
                tracked_total.backward().expect("backward");
                let leaf_grad = leaf.grad().unwrap().expect("grad_a");
                check_f64(
                    &format!("{label} grad_a"),
                    &read_back_f64(&leaf_grad),
                    want_grad,
                    tolerance::F64_ELEMENTWISE,
                );
            }
            _ => unreachable!(),
        }
    }
}
/// CPU conformance for `sum_axis` (forward only).
///
/// Each `sum_axis` fixture on `cpu` supplies an input, an `axis`, and the
/// expected reduced values; the result is compared per dtype.
#[test]
fn cpu_sum_axis() {
    let fixtures = load_fixtures();
    let axis_cases = cases_for(&fixtures, "sum_axis", "cpu");
    assert!(!axis_cases.is_empty(), "no fixtures for sum_axis");
    for case in axis_cases {
        let label = format!("sum_axis cpu axis={:?} dtype={}", case.axis, case.dtype);
        let in_shape = case.a_shape.as_ref().unwrap();
        let in_values = case.a_data.as_ref().map(F64ListSentinel::as_slice).unwrap();
        let axis = case.axis.expect("axis");
        let want = case
            .out_values
            .as_ref()
            .map(F64ListSentinel::as_slice)
            .unwrap();
        if case.dtype == "float32" {
            let input = make_cpu_f32(in_values, in_shape, false);
            let reduced = sum_axis(&input, axis).expect("sum_axis");
            check_f32(&label, &read_back_f32(&reduced), want, tolerance::F32_REDUCTION);
        } else if case.dtype == "float64" {
            let input = make_cpu_f64(in_values, in_shape, false);
            let reduced = sum_axis(&input, axis).expect("sum_axis");
            check_f64(&label, &read_back_f64(&reduced), want, tolerance::F64_REDUCTION);
        } else {
            unreachable!();
        }
    }
}
/// CPU conformance for the full reduction `mean`.
///
/// Forward: `mean` is compared with `out_values`.
/// Gradient: the expected `grad_a` is reproduced by back-propagating through
/// the autograd `sum` (`grad_sum`) and dividing every gradient element by
/// `numel()`.
///
/// NOTE(review): this path never invokes `mean`'s own backward, so a broken
/// mean gradient would go unnoticed here. Consider switching to an
/// autograd-aware mean if the crate provides one.
#[test]
fn cpu_mean() {
    let fixtures = load_fixtures();
    let mean_cases = cases_for(&fixtures, "mean", "cpu");
    assert!(!mean_cases.is_empty(), "no fixtures for mean");
    for case in mean_cases {
        let label = format!("mean cpu dtype={}", case.dtype);
        let in_shape = case.a_shape.as_ref().unwrap();
        let in_values = case.a_data.as_ref().map(F64ListSentinel::as_slice).unwrap();
        let want_out = case
            .out_values
            .as_ref()
            .map(F64ListSentinel::as_slice)
            .unwrap();
        let want_grad = case.grad_a.as_ref().map(F64ListSentinel::as_slice).unwrap();
        match case.dtype.as_str() {
            "float32" => {
                let plain = make_cpu_f32(in_values, in_shape, false);
                let avg = mean(&plain).expect("mean");
                check_f32(
                    &format!("{label} fwd"),
                    &read_back_f32(&avg),
                    want_out,
                    tolerance::F32_REDUCTION,
                );
                let leaf = make_cpu_f32(in_values, in_shape, true);
                let tracked_total = grad_sum(&leaf).expect("sum");
                tracked_total.backward().expect("backward");
                let leaf_grad = leaf.grad().unwrap().expect("grad_a");
                // Scale d(sum)/dx by 1/numel to obtain the expected d(mean)/dx.
                let count = leaf.numel() as f32;
                let mean_grad: Vec<f32> = read_back_f32(&leaf_grad)
                    .iter()
                    .map(|&g| g / count)
                    .collect();
                check_f32(
                    &format!("{label} grad_a"),
                    &mean_grad,
                    want_grad,
                    tolerance::F32_ELEMENTWISE,
                );
            }
            "float64" => {
                let plain = make_cpu_f64(in_values, in_shape, false);
                let avg = mean(&plain).expect("mean");
                check_f64(
                    &format!("{label} fwd"),
                    &read_back_f64(&avg),
                    want_out,
                    tolerance::F64_REDUCTION,
                );
                let leaf = make_cpu_f64(in_values, in_shape, true);
                let tracked_total = grad_sum(&leaf).expect("sum");
                tracked_total.backward().expect("backward");
                let leaf_grad = leaf.grad().unwrap().expect("grad_a");
                // Scale d(sum)/dx by 1/numel to obtain the expected d(mean)/dx.
                let count = leaf.numel() as f64;
                let mean_grad: Vec<f64> = read_back_f64(&leaf_grad)
                    .iter()
                    .map(|&g| g / count)
                    .collect();
                check_f64(
                    &format!("{label} grad_a"),
                    &mean_grad,
                    want_grad,
                    tolerance::F64_ELEMENTWISE,
                );
            }
            _ => unreachable!(),
        }
    }
}
/// CPU conformance for `nansum` (forward only): each `nansum` fixture on `cpu`
/// is reduced and compared with `out_values` using the reduction tolerance.
#[test]
fn cpu_nansum() {
    let fixtures = load_fixtures();
    let nansum_cases = cases_for(&fixtures, "nansum", "cpu");
    assert!(!nansum_cases.is_empty(), "no fixtures for nansum");
    for case in nansum_cases {
        let label = format!("nansum cpu dtype={}", case.dtype);
        let in_shape = case.a_shape.as_ref().unwrap();
        let in_values = case.a_data.as_ref().map(F64ListSentinel::as_slice).unwrap();
        let want = case
            .out_values
            .as_ref()
            .map(F64ListSentinel::as_slice)
            .unwrap();
        if case.dtype == "float32" {
            let input = make_cpu_f32(in_values, in_shape, false);
            let total = nansum(&input).expect("nansum");
            check_f32(&label, &read_back_f32(&total), want, tolerance::F32_REDUCTION);
        } else if case.dtype == "float64" {
            let input = make_cpu_f64(in_values, in_shape, false);
            let total = nansum(&input).expect("nansum");
            check_f64(&label, &read_back_f64(&total), want, tolerance::F64_REDUCTION);
        } else {
            unreachable!();
        }
    }
}
/// CPU conformance for `nanmean` (forward only): each `nanmean` fixture on
/// `cpu` is reduced and compared with `out_values` using the reduction
/// tolerance.
#[test]
fn cpu_nanmean() {
    let fixtures = load_fixtures();
    let nanmean_cases = cases_for(&fixtures, "nanmean", "cpu");
    assert!(!nanmean_cases.is_empty(), "no fixtures for nanmean");
    for case in nanmean_cases {
        let label = format!("nanmean cpu dtype={}", case.dtype);
        let in_shape = case.a_shape.as_ref().unwrap();
        let in_values = case.a_data.as_ref().map(F64ListSentinel::as_slice).unwrap();
        let want = case
            .out_values
            .as_ref()
            .map(F64ListSentinel::as_slice)
            .unwrap();
        if case.dtype == "float32" {
            let input = make_cpu_f32(in_values, in_shape, false);
            let avg = nanmean(&input).expect("nanmean");
            check_f32(&label, &read_back_f32(&avg), want, tolerance::F32_REDUCTION);
        } else if case.dtype == "float64" {
            let input = make_cpu_f64(in_values, in_shape, false);
            let avg = nanmean(&input).expect("nanmean");
            check_f64(&label, &read_back_f64(&avg), want, tolerance::F64_REDUCTION);
        } else {
            unreachable!();
        }
    }
}
/// CPU conformance for the full reduction `logsumexp` (forward only).
///
/// Results are compared with `out_values` using the transcendental tolerance,
/// since the op involves `exp`/`log`.
#[test]
fn cpu_logsumexp() {
    let fixtures = load_fixtures();
    let lse_cases = cases_for(&fixtures, "logsumexp", "cpu");
    assert!(!lse_cases.is_empty(), "no fixtures for logsumexp");
    for case in lse_cases {
        let label = format!("logsumexp cpu tag={:?} dtype={}", case.tag, case.dtype);
        let in_shape = case.a_shape.as_ref().unwrap();
        let in_values = case.a_data.as_ref().map(F64ListSentinel::as_slice).unwrap();
        let want = case
            .out_values
            .as_ref()
            .map(F64ListSentinel::as_slice)
            .unwrap();
        if case.dtype == "float32" {
            let input = make_cpu_f32(in_values, in_shape, false);
            let lse = logsumexp(&input).expect("logsumexp");
            check_f32(
                &label,
                &read_back_f32(&lse),
                want,
                tolerance::F32_TRANSCENDENTAL,
            );
        } else if case.dtype == "float64" {
            let input = make_cpu_f64(in_values, in_shape, false);
            let lse = logsumexp(&input).expect("logsumexp");
            check_f64(
                &label,
                &read_back_f64(&lse),
                want,
                tolerance::F64_TRANSCENDENTAL,
            );
        } else {
            unreachable!();
        }
    }
}
/// CPU conformance for `logsumexp_dim` (forward only).
///
/// Each fixture supplies the reduction `axis` and `keepdim` flag; results are
/// compared with `out_values` using the transcendental tolerance.
#[test]
fn cpu_logsumexp_dim() {
    let fixtures = load_fixtures();
    let lse_cases = cases_for(&fixtures, "logsumexp_dim", "cpu");
    assert!(!lse_cases.is_empty(), "no fixtures for logsumexp_dim");
    for case in lse_cases {
        let label = format!(
            "logsumexp_dim cpu tag={:?} dtype={} axis={:?} keepdim={:?}",
            case.tag, case.dtype, case.axis, case.keepdim
        );
        let in_shape = case.a_shape.as_ref().unwrap();
        let in_values = case.a_data.as_ref().map(F64ListSentinel::as_slice).unwrap();
        let axis = case.axis.expect("axis");
        let keepdim = case.keepdim.expect("keepdim");
        let want = case
            .out_values
            .as_ref()
            .map(F64ListSentinel::as_slice)
            .unwrap();
        if case.dtype == "float32" {
            let input = make_cpu_f32(in_values, in_shape, false);
            let lse = logsumexp_dim(&input, axis, keepdim).expect("logsumexp_dim");
            check_f32(
                &label,
                &read_back_f32(&lse),
                want,
                tolerance::F32_TRANSCENDENTAL,
            );
        } else if case.dtype == "float64" {
            let input = make_cpu_f64(in_values, in_shape, false);
            let lse = logsumexp_dim(&input, axis, keepdim).expect("logsumexp_dim");
            check_f64(
                &label,
                &read_back_f64(&lse),
                want,
                tolerance::F64_TRANSCENDENTAL,
            );
        } else {
            unreachable!();
        }
    }
}
/// Returns the address of the value managed by the storage `Arc` behind `t`.
///
/// The in-place tests compare this pointer before and after an op. Equal
/// pointers mean the tensor still refers to the same storage allocation, so
/// the op did not swap in a new one.
fn storage_arc_id<T: ferrotorch_core::Float>(t: &Tensor<T>) -> *const TensorStorage<T> {
    let shared_storage = t.inner_storage_arc();
    Arc::as_ptr(shared_storage)
}
/// CPU conformance for the tensor-tensor in-place ops `add_`, `sub_`, `mul_`,
/// `div_`.
///
/// For each fixture the op mutates `t` in place. The test checks:
/// * the storage `Arc` pointer is identical before and after the call;
/// * the mutated values match `out_values`.
#[test]
fn cpu_inplace_add_sub_mul_div() {
    let fixtures = load_fixtures();
    for op_name in ["add_", "sub_", "mul_", "div_"] {
        let op_cases = cases_for(&fixtures, op_name, "cpu");
        assert!(!op_cases.is_empty(), "no fixtures for {op_name}");
        for case in op_cases {
            let label = format!("{op_name} cpu dtype={}", case.dtype);
            let lhs_shape = case.a_shape.as_ref().unwrap();
            let rhs_shape = case.b_shape.as_ref().unwrap();
            let lhs_values = case.a_data.as_ref().map(F64ListSentinel::as_slice).unwrap();
            let rhs_values = case.b_data.as_ref().map(F64ListSentinel::as_slice).unwrap();
            let want = case
                .out_values
                .as_ref()
                .map(F64ListSentinel::as_slice)
                .unwrap();
            if case.dtype == "float32" {
                let t = make_cpu_f32(lhs_values, lhs_shape, false);
                let rhs = make_cpu_f32(rhs_values, rhs_shape, false);
                let storage_before = storage_arc_id(&t);
                match op_name {
                    "add_" => {
                        t.add_(&rhs).expect("add_");
                    }
                    "sub_" => {
                        t.sub_(&rhs).expect("sub_");
                    }
                    "mul_" => {
                        t.mul_(&rhs).expect("mul_");
                    }
                    "div_" => {
                        t.div_(&rhs).expect("div_");
                    }
                    _ => unreachable!(),
                }
                let storage_after = storage_arc_id(&t);
                assert_eq!(
                    storage_before, storage_after,
                    "{label}: in-place op replaced the storage Arc — \
                     mutation must happen through the existing Arc"
                );
                check_f32(
                    &format!("{label} value"),
                    &read_back_f32(&t),
                    want,
                    tolerance::F32_ELEMENTWISE,
                );
            } else if case.dtype == "float64" {
                let t = make_cpu_f64(lhs_values, lhs_shape, false);
                let rhs = make_cpu_f64(rhs_values, rhs_shape, false);
                let storage_before = storage_arc_id(&t);
                match op_name {
                    "add_" => {
                        t.add_(&rhs).expect("add_");
                    }
                    "sub_" => {
                        t.sub_(&rhs).expect("sub_");
                    }
                    "mul_" => {
                        t.mul_(&rhs).expect("mul_");
                    }
                    "div_" => {
                        t.div_(&rhs).expect("div_");
                    }
                    _ => unreachable!(),
                }
                let storage_after = storage_arc_id(&t);
                assert_eq!(storage_before, storage_after, "{label}: storage Arc replaced");
                check_f64(
                    &format!("{label} value"),
                    &read_back_f64(&t),
                    want,
                    tolerance::F64_ELEMENTWISE,
                );
            } else {
                unreachable!();
            }
        }
    }
}
/// CPU conformance for the scalar in-place ops `add_scalar_` and `mul_scalar_`.
///
/// For each fixture the op mutates `t` in place using the fixture's `scalar`.
/// The test checks that the storage `Arc` pointer is unchanged and that the
/// values match `out_values`. For `float32` fixtures, the fixture's `f64`
/// scalar is narrowed with `as f32`.
#[test]
fn cpu_inplace_scalar_ops() {
    let file = load_fixtures();
    // The op name alone selects the dispatch arm below; no side flag needed.
    for op_name in ["add_scalar_", "mul_scalar_"] {
        let cases = cases_for(&file, op_name, "cpu");
        assert!(!cases.is_empty(), "no fixtures for {op_name}");
        for f in cases {
            let label = format!("{op_name} cpu dtype={}", f.dtype);
            let a_shape = f.a_shape.as_ref().unwrap();
            let a_data = f.a_data.as_ref().map(F64ListSentinel::as_slice).unwrap();
            let scalar = f.scalar.expect("scalar");
            let exp = f
                .out_values
                .as_ref()
                .map(F64ListSentinel::as_slice)
                .unwrap();
            match f.dtype.as_str() {
                "float32" => {
                    let t = make_cpu_f32(a_data, a_shape, false);
                    let before_id = storage_arc_id(&t);
                    match op_name {
                        "add_scalar_" => {
                            t.add_scalar_(scalar as f32).expect("add_scalar_");
                        }
                        "mul_scalar_" => {
                            t.mul_scalar_(scalar as f32).expect("mul_scalar_");
                        }
                        _ => unreachable!(),
                    }
                    let after_id = storage_arc_id(&t);
                    assert_eq!(before_id, after_id, "{label}: storage Arc replaced");
                    check_f32(
                        &format!("{label} value"),
                        &read_back_f32(&t),
                        exp,
                        tolerance::F32_ELEMENTWISE,
                    );
                }
                "float64" => {
                    let t = make_cpu_f64(a_data, a_shape, false);
                    let before_id = storage_arc_id(&t);
                    match op_name {
                        "add_scalar_" => {
                            t.add_scalar_(scalar).expect("add_scalar_");
                        }
                        "mul_scalar_" => {
                            t.mul_scalar_(scalar).expect("mul_scalar_");
                        }
                        _ => unreachable!(),
                    }
                    let after_id = storage_arc_id(&t);
                    assert_eq!(before_id, after_id, "{label}: storage Arc replaced");
                    check_f64(
                        &format!("{label} value"),
                        &read_back_f64(&t),
                        exp,
                        tolerance::F64_ELEMENTWISE,
                    );
                }
                _ => unreachable!(),
            }
        }
    }
}
/// CPU conformance for the in-place ops `fill_`, `zero_`, and `clamp_`.
///
/// `fill_` reads the fixture's `scalar`, and `clamp_` reads its `min`/`max`
/// bounds (narrowed with `as f32` for `float32`). After each op the test checks
/// that the storage `Arc` pointer is unchanged and that the values match
/// `out_values`.
#[test]
fn cpu_inplace_fill_zero_clamp() {
    let fixtures = load_fixtures();
    for op_name in ["fill_", "zero_", "clamp_"] {
        let op_cases = cases_for(&fixtures, op_name, "cpu");
        assert!(!op_cases.is_empty(), "no fixtures for {op_name}");
        for case in op_cases {
            let label = format!("{op_name} cpu dtype={}", case.dtype);
            let in_shape = case.a_shape.as_ref().unwrap();
            let in_values = case.a_data.as_ref().map(F64ListSentinel::as_slice).unwrap();
            let want = case
                .out_values
                .as_ref()
                .map(F64ListSentinel::as_slice)
                .unwrap();
            if case.dtype == "float32" {
                let t = make_cpu_f32(in_values, in_shape, false);
                let storage_before = storage_arc_id(&t);
                match op_name {
                    "fill_" => {
                        let value = case.scalar.expect("scalar") as f32;
                        t.fill_(value).expect("fill_");
                    }
                    "zero_" => {
                        t.zero_().expect("zero_");
                    }
                    "clamp_" => {
                        let lower = case.min.expect("min") as f32;
                        let upper = case.max.expect("max") as f32;
                        t.clamp_(lower, upper).expect("clamp_");
                    }
                    _ => unreachable!(),
                }
                let storage_after = storage_arc_id(&t);
                assert_eq!(storage_before, storage_after, "{label}: storage Arc replaced");
                check_f32(
                    &format!("{label} value"),
                    &read_back_f32(&t),
                    want,
                    tolerance::F32_ELEMENTWISE,
                );
            } else if case.dtype == "float64" {
                let t = make_cpu_f64(in_values, in_shape, false);
                let storage_before = storage_arc_id(&t);
                match op_name {
                    "fill_" => {
                        let value = case.scalar.expect("scalar");
                        t.fill_(value).expect("fill_");
                    }
                    "zero_" => {
                        t.zero_().expect("zero_");
                    }
                    "clamp_" => {
                        let lower = case.min.expect("min");
                        let upper = case.max.expect("max");
                        t.clamp_(lower, upper).expect("clamp_");
                    }
                    _ => unreachable!(),
                }
                let storage_after = storage_arc_id(&t);
                assert_eq!(storage_before, storage_after, "{label}: storage Arc replaced");
                check_f64(
                    &format!("{label} value"),
                    &read_back_f64(&t),
                    want,
                    tolerance::F64_ELEMENTWISE,
                );
            } else {
                unreachable!();
            }
        }
    }
}
/// An in-place op on a leaf tensor with `requires_grad = true` must return an
/// error. The error is accepted if its `Display` text mentions "in-place" or
/// its `Debug` text mentions "InvalidArgument".
#[test]
fn cpu_inplace_rejects_requires_grad_leaf() {
    let grad_leaf = make_cpu_f32(&[1.0, 2.0, 3.0], &[3], true);
    let err = grad_leaf.add_scalar_(1.0).unwrap_err();
    let recognized =
        format!("{err}").contains("in-place") || format!("{err:?}").contains("InvalidArgument");
    assert!(
        recognized,
        "expected InvalidArgument for add_scalar_ on a requires_grad leaf, got {err:?}"
    );
}
#[cfg(feature = "gpu")]
mod gpu {
    //! CUDA (`cuda:0`) conformance tests; compiled only with `--features gpu`.
    //!
    //! Inputs are built on the CPU with the parent module's fixture helpers and
    //! uploaded to `Device::Cuda(0)` before each op runs. Every test initializes
    //! the CUDA backend once and requires the fixture file to have been generated
    //! with CUDA available. Every test also asserts that at least one matching
    //! fixture exists, so a stale fixture file cannot make a test pass without
    //! checking anything.

    use super::*;
    use std::sync::Once;

    /// One-shot guard for CUDA backend initialization in this test process.
    static GPU_INIT: Once = Once::new();

    /// Runs `ferrotorch_gpu::init_cuda_backend` at most once per process.
    ///
    /// # Panics
    /// Panics if backend initialization returns an error.
    fn ensure_cuda_backend() {
        GPU_INIT.call_once(|| {
            ferrotorch_gpu::init_cuda_backend()
                .expect("CUDA backend must initialize for the GPU conformance suite");
        });
    }

    /// Panics unless the fixture metadata records `cuda_available`.
    fn require_cuda_fixtures(file: &FixtureFile) {
        if !file.metadata.cuda_available {
            panic!(
                "fixtures/elementwise.json was generated without CUDA — \
                 regenerate on a CUDA-enabled host before running --features gpu tests"
            );
        }
    }

    /// Runs the binary op `op` on `cuda:0` for every `op_name` CUDA fixture.
    ///
    /// Forward: the result must stay on CUDA and match `out_values`.
    /// Backward: `grad_sum(op(a, b))` is back-propagated and both input
    /// gradients are compared with `grad_a` / `grad_b`.
    fn run_binary_gpu(op_name: &str, op: BinaryOp) {
        ensure_cuda_backend();
        let file = load_fixtures();
        require_cuda_fixtures(&file);
        let cases = cases_for(&file, op_name, "cuda:0");
        assert!(!cases.is_empty(), "no CUDA fixtures for op {op_name:?}");
        for f in cases {
            let label = format!("{op_name} cuda:0 tag={:?} dtype={}", f.tag, f.dtype);
            let a_shape = f.a_shape.as_ref().unwrap();
            let b_shape = f.b_shape.as_ref().unwrap();
            let a_data = f.a_data.as_ref().map(F64ListSentinel::as_slice).unwrap();
            let b_data = f.b_data.as_ref().map(F64ListSentinel::as_slice).unwrap();
            let expected = f
                .out_values
                .as_ref()
                .map(F64ListSentinel::as_slice)
                .unwrap();
            let grad_a_exp = f.grad_a.as_ref().map(F64ListSentinel::as_slice).unwrap();
            let grad_b_exp = f.grad_b.as_ref().map(F64ListSentinel::as_slice).unwrap();
            match f.dtype.as_str() {
                "float32" => {
                    let a = upload_f32(make_cpu_f32(a_data, a_shape, false), Device::Cuda(0));
                    let b = upload_f32(make_cpu_f32(b_data, b_shape, false), Device::Cuda(0));
                    let c = op.apply_f32(&a, &b);
                    assert!(c.is_cuda(), "{label}: result not on CUDA");
                    check_f32(
                        &format!("{label} fwd"),
                        &read_back_f32(&c),
                        expected,
                        tolerance::F32_GPU,
                    );
                    let a_g = upload_f32(make_cpu_f32(a_data, a_shape, true), Device::Cuda(0));
                    let b_g = upload_f32(make_cpu_f32(b_data, b_shape, true), Device::Cuda(0));
                    let out = op.apply_f32(&a_g, &b_g);
                    grad_sum(&out).expect("sum").backward().expect("backward");
                    let ga = a_g.grad().unwrap().expect("grad_a");
                    let gb = b_g.grad().unwrap().expect("grad_b");
                    check_f32(
                        &format!("{label} grad_a"),
                        &read_back_f32(&ga),
                        grad_a_exp,
                        tolerance::F32_GPU,
                    );
                    check_f32(
                        &format!("{label} grad_b"),
                        &read_back_f32(&gb),
                        grad_b_exp,
                        tolerance::F32_GPU,
                    );
                }
                "float64" => {
                    let a = upload_f64(make_cpu_f64(a_data, a_shape, false), Device::Cuda(0));
                    let b = upload_f64(make_cpu_f64(b_data, b_shape, false), Device::Cuda(0));
                    let c = op.apply_f64(&a, &b);
                    assert!(c.is_cuda(), "{label}: result not on CUDA");
                    check_f64(
                        &format!("{label} fwd"),
                        &read_back_f64(&c),
                        expected,
                        tolerance::F64_GPU,
                    );
                    let a_g = upload_f64(make_cpu_f64(a_data, a_shape, true), Device::Cuda(0));
                    let b_g = upload_f64(make_cpu_f64(b_data, b_shape, true), Device::Cuda(0));
                    let out = op.apply_f64(&a_g, &b_g);
                    grad_sum(&out).expect("sum").backward().expect("backward");
                    let ga = a_g.grad().unwrap().expect("grad_a");
                    let gb = b_g.grad().unwrap().expect("grad_b");
                    check_f64(
                        &format!("{label} grad_a"),
                        &read_back_f64(&ga),
                        grad_a_exp,
                        tolerance::F64_GPU,
                    );
                    check_f64(
                        &format!("{label} grad_b"),
                        &read_back_f64(&gb),
                        grad_b_exp,
                        tolerance::F64_GPU,
                    );
                }
                _ => unreachable!(),
            }
        }
    }

    #[test]
    fn gpu_add() {
        run_binary_gpu("add", BinaryOp::Add);
    }

    #[test]
    fn gpu_sub() {
        run_binary_gpu("sub", BinaryOp::Sub);
    }

    #[test]
    fn gpu_mul() {
        run_binary_gpu("mul", BinaryOp::Mul);
    }

    #[test]
    fn gpu_div() {
        run_binary_gpu("div", BinaryOp::Div);
    }

    /// Runs the unary op `op` on `cuda:0` for every `op_name` CUDA fixture.
    ///
    /// Forward: the result must stay on CUDA and match `out_values`.
    /// Backward: `grad_sum(op(a))` is back-propagated and the input gradient is
    /// compared with `grad_a`.
    fn run_unary_gpu(op_name: &str, op: UnaryOp) {
        ensure_cuda_backend();
        let file = load_fixtures();
        require_cuda_fixtures(&file);
        let cases = cases_for(&file, op_name, "cuda:0");
        assert!(!cases.is_empty(), "no CUDA fixtures for {op_name}");
        for f in cases {
            let label = format!("{op_name} cuda:0 tag={:?} dtype={}", f.tag, f.dtype);
            let shape = f.a_shape.as_ref().unwrap();
            let a_data = f.a_data.as_ref().map(F64ListSentinel::as_slice).unwrap();
            let expected = f
                .out_values
                .as_ref()
                .map(F64ListSentinel::as_slice)
                .unwrap();
            let grad_a_exp = f.grad_a.as_ref().map(F64ListSentinel::as_slice).unwrap();
            match f.dtype.as_str() {
                "float32" => {
                    let a = upload_f32(make_cpu_f32(a_data, shape, false), Device::Cuda(0));
                    let c = op.apply_f32(&a);
                    assert!(c.is_cuda(), "{label}: result not on CUDA");
                    check_f32(
                        &format!("{label} fwd"),
                        &read_back_f32(&c),
                        expected,
                        tolerance::F32_GPU,
                    );
                    let a_g = upload_f32(make_cpu_f32(a_data, shape, true), Device::Cuda(0));
                    let out = op.apply_f32(&a_g);
                    grad_sum(&out).expect("sum").backward().expect("backward");
                    let ga = a_g.grad().unwrap().expect("grad_a");
                    check_f32(
                        &format!("{label} grad_a"),
                        &read_back_f32(&ga),
                        grad_a_exp,
                        tolerance::F32_GPU,
                    );
                }
                "float64" => {
                    let a = upload_f64(make_cpu_f64(a_data, shape, false), Device::Cuda(0));
                    let c = op.apply_f64(&a);
                    assert!(c.is_cuda(), "{label}: result not on CUDA");
                    check_f64(
                        &format!("{label} fwd"),
                        &read_back_f64(&c),
                        expected,
                        tolerance::F64_GPU,
                    );
                    let a_g = upload_f64(make_cpu_f64(a_data, shape, true), Device::Cuda(0));
                    let out = op.apply_f64(&a_g);
                    grad_sum(&out).expect("sum").backward().expect("backward");
                    let ga = a_g.grad().unwrap().expect("grad_a");
                    check_f64(
                        &format!("{label} grad_a"),
                        &read_back_f64(&ga),
                        grad_a_exp,
                        tolerance::F64_GPU,
                    );
                }
                _ => unreachable!(),
            }
        }
    }

    #[test]
    fn gpu_neg() {
        run_unary_gpu("neg", UnaryOp::Neg);
    }

    #[test]
    fn gpu_abs() {
        run_unary_gpu("abs", UnaryOp::Abs);
    }

    #[test]
    fn gpu_sqrt() {
        run_unary_gpu("sqrt", UnaryOp::Sqrt);
    }

    /// `pow(a, exp)` on `cuda:0`: forward values, CUDA residency, and the
    /// gradient of `grad_sum(pow(a, exp))`, compared with the fixtures.
    #[test]
    fn gpu_pow() {
        ensure_cuda_backend();
        let file = load_fixtures();
        require_cuda_fixtures(&file);
        let cases = cases_for(&file, "pow", "cuda:0");
        assert!(!cases.is_empty(), "no CUDA fixtures for pow");
        for f in cases {
            let label = format!("pow cuda:0 dtype={} exp={:?}", f.dtype, f.exp);
            let shape = f.a_shape.as_ref().unwrap();
            let a_data = f.a_data.as_ref().map(F64ListSentinel::as_slice).unwrap();
            let exp = f.exp.expect("exp");
            let expected = f
                .out_values
                .as_ref()
                .map(F64ListSentinel::as_slice)
                .unwrap();
            let grad_a_exp = f.grad_a.as_ref().map(F64ListSentinel::as_slice).unwrap();
            match f.dtype.as_str() {
                "float32" => {
                    let a = upload_f32(make_cpu_f32(a_data, shape, false), Device::Cuda(0));
                    let c = pow(&a, exp).expect("pow");
                    assert!(c.is_cuda(), "{label}: result not on CUDA");
                    check_f32(
                        &format!("{label} fwd"),
                        &read_back_f32(&c),
                        expected,
                        tolerance::F32_GPU,
                    );
                    let a_g = upload_f32(make_cpu_f32(a_data, shape, true), Device::Cuda(0));
                    let out = pow(&a_g, exp).expect("pow");
                    grad_sum(&out).expect("sum").backward().expect("backward");
                    let ga = a_g.grad().unwrap().expect("grad_a");
                    check_f32(
                        &format!("{label} grad_a"),
                        &read_back_f32(&ga),
                        grad_a_exp,
                        tolerance::F32_GPU,
                    );
                }
                "float64" => {
                    let a = upload_f64(make_cpu_f64(a_data, shape, false), Device::Cuda(0));
                    let c = pow(&a, exp).expect("pow");
                    assert!(c.is_cuda(), "{label}: result not on CUDA");
                    // Same GPU tolerance family as the f32 branch and the other GPU tests.
                    check_f64(
                        &format!("{label} fwd"),
                        &read_back_f64(&c),
                        expected,
                        tolerance::F64_GPU,
                    );
                    let a_g = upload_f64(make_cpu_f64(a_data, shape, true), Device::Cuda(0));
                    let out = pow(&a_g, exp).expect("pow");
                    grad_sum(&out).expect("sum").backward().expect("backward");
                    let ga = a_g.grad().unwrap().expect("grad_a");
                    check_f64(
                        &format!("{label} grad_a"),
                        &read_back_f64(&ga),
                        grad_a_exp,
                        tolerance::F64_GPU,
                    );
                }
                _ => unreachable!(),
            }
        }
    }

    #[test]
    fn gpu_where() {
        ensure_cuda_backend();
        // Like the other GPU tests, fail early with the "regenerate on a CUDA
        // host" message if the fixture file has no CUDA data.
        require_cuda_fixtures(&load_fixtures());
        run_where_for_device("cuda:0", Device::Cuda(0));
    }

    /// Forward-only check of the autograd `sum` on `cuda:0` against the fixtures.
    #[test]
    fn gpu_sum_forward_only() {
        ensure_cuda_backend();
        let file = load_fixtures();
        require_cuda_fixtures(&file);
        let cases = cases_for(&file, "sum", "cuda:0");
        // Without this, an absent fixture set would let the test pass vacuously.
        assert!(!cases.is_empty(), "no CUDA fixtures for sum");
        for f in cases {
            let label = format!("sum cuda:0 dtype={}", f.dtype);
            let shape = f.a_shape.as_ref().unwrap();
            let a_data = f.a_data.as_ref().map(F64ListSentinel::as_slice).unwrap();
            let exp = f
                .out_values
                .as_ref()
                .map(F64ListSentinel::as_slice)
                .unwrap();
            match f.dtype.as_str() {
                "float32" => {
                    let a = upload_f32(make_cpu_f32(a_data, shape, false), Device::Cuda(0));
                    let s = grad_sum(&a).expect("sum");
                    check_f32(&label, &read_back_f32(&s), exp, tolerance::F32_GPU);
                }
                "float64" => {
                    let a = upload_f64(make_cpu_f64(a_data, shape, false), Device::Cuda(0));
                    let s = grad_sum(&a).expect("sum");
                    check_f64(&label, &read_back_f64(&s), exp, tolerance::F64_GPU);
                }
                _ => unreachable!(),
            }
        }
    }

    /// `float32` in-place `add_`/`sub_`/`mul_`/`div_` on `cuda:0`: values match
    /// the fixtures, the storage `Arc` pointer is unchanged, and the tensor
    /// stays on CUDA.
    #[test]
    fn gpu_inplace_f32() {
        ensure_cuda_backend();
        let file = load_fixtures();
        require_cuda_fixtures(&file);
        for op_name in ["add_", "sub_", "mul_", "div_"] {
            let cases: Vec<_> = cases_for(&file, op_name, "cuda:0")
                .into_iter()
                .filter(|f| f.dtype == "float32")
                .collect();
            // Without this, missing float32 fixtures would let the test pass vacuously.
            assert!(!cases.is_empty(), "no float32 CUDA fixtures for {op_name}");
            for f in cases {
                let label = format!("{op_name} cuda:0 dtype={}", f.dtype);
                let a_shape = f.a_shape.as_ref().unwrap();
                let b_shape = f.b_shape.as_ref().unwrap();
                let a_data = f.a_data.as_ref().map(F64ListSentinel::as_slice).unwrap();
                let b_data = f.b_data.as_ref().map(F64ListSentinel::as_slice).unwrap();
                let exp = f
                    .out_values
                    .as_ref()
                    .map(F64ListSentinel::as_slice)
                    .unwrap();
                let t = upload_f32(make_cpu_f32(a_data, a_shape, false), Device::Cuda(0));
                let other = upload_f32(make_cpu_f32(b_data, b_shape, false), Device::Cuda(0));
                let before_id = storage_arc_id(&t);
                match op_name {
                    "add_" => {
                        t.add_(&other).expect("add_");
                    }
                    "sub_" => {
                        t.sub_(&other).expect("sub_");
                    }
                    "mul_" => {
                        t.mul_(&other).expect("mul_");
                    }
                    "div_" => {
                        t.div_(&other).expect("div_");
                    }
                    _ => unreachable!(),
                }
                let after_id = storage_arc_id(&t);
                assert_eq!(
                    before_id, after_id,
                    "{label}: GPU in-place op must mutate through the same storage Arc"
                );
                assert!(t.is_cuda(), "{label}: tensor left CUDA after in-place op");
                check_f32(
                    &format!("{label} value"),
                    &read_back_f32(&t),
                    exp,
                    tolerance::F32_GPU,
                );
            }
        }
    }

    /// `float32` in-place scalar/fill/zero/clamp ops on `cuda:0`: values match
    /// the fixtures, the storage `Arc` pointer is unchanged, and the tensor stays
    /// on CUDA.
    #[test]
    fn gpu_inplace_scalar_fill_clamp_f32() {
        ensure_cuda_backend();
        let file = load_fixtures();
        require_cuda_fixtures(&file);
        for op_name in ["add_scalar_", "mul_scalar_", "fill_", "zero_", "clamp_"] {
            let cases: Vec<_> = cases_for(&file, op_name, "cuda:0")
                .into_iter()
                .filter(|f| f.dtype == "float32")
                .collect();
            // Without this, missing float32 fixtures would let the test pass vacuously.
            assert!(!cases.is_empty(), "no float32 CUDA fixtures for {op_name}");
            for f in cases {
                let label = format!("{op_name} cuda:0 dtype={}", f.dtype);
                let shape = f.a_shape.as_ref().unwrap();
                let a_data = f.a_data.as_ref().map(F64ListSentinel::as_slice).unwrap();
                let exp = f
                    .out_values
                    .as_ref()
                    .map(F64ListSentinel::as_slice)
                    .unwrap();
                let t = upload_f32(make_cpu_f32(a_data, shape, false), Device::Cuda(0));
                let before_id = storage_arc_id(&t);
                match op_name {
                    "add_scalar_" => {
                        t.add_scalar_(f.scalar.expect("scalar") as f32)
                            .expect("add_scalar_");
                    }
                    "mul_scalar_" => {
                        t.mul_scalar_(f.scalar.expect("scalar") as f32)
                            .expect("mul_scalar_");
                    }
                    "fill_" => {
                        t.fill_(f.scalar.expect("scalar") as f32).expect("fill_");
                    }
                    "zero_" => {
                        t.zero_().expect("zero_");
                    }
                    "clamp_" => {
                        let lo = f.min.expect("min") as f32;
                        let hi = f.max.expect("max") as f32;
                        t.clamp_(lo, hi).expect("clamp_");
                    }
                    _ => unreachable!(),
                }
                let after_id = storage_arc_id(&t);
                assert_eq!(before_id, after_id, "{label}: storage Arc replaced");
                assert!(t.is_cuda(), "{label}: tensor left CUDA after in-place op");
                check_f32(
                    &format!("{label} value"),
                    &read_back_f32(&t),
                    exp,
                    tolerance::F32_GPU,
                );
            }
        }
    }
}
/// Guards against a stale or partial fixture file.
///
/// Every op exercised by this conformance suite must have at least one fixture
/// in the file returned by `load_fixtures` (on any device or dtype). All missing
/// ops are reported in a single failure, so a regeneration gap is visible at
/// once rather than one op per run.
#[test]
fn fixture_file_covers_every_phase21_op() {
    let file = load_fixtures();
    // Only presence matters, so a set of op names is enough (no per-op counts).
    let present: std::collections::HashSet<&str> =
        file.fixtures.iter().map(|f| f.op.as_str()).collect();
    let required = [
        "add",
        "sub",
        "mul",
        "div",
        "neg",
        "abs",
        "sqrt",
        "pow",
        "pow_zero_exp",
        "sqrt_zero",
        "div_zero",
        "where",
        "binary_map_maxmin",
        "scalar_map_sqplus",
        "unary_map_tan",
        "fast_add",
        "fast_sub",
        "fast_mul",
        "fast_div",
        "fast_exp",
        "fast_log",
        "fast_sigmoid",
        "fast_tanh",
        "fast_sin",
        "fast_cos",
        "simd_add_f32",
        "simd_mul_f32",
        "simd_exp_f32",
        "simd_log_f32",
        "simd_sqrt_f32",
        "simd_add_f64",
        "simd_mul_f64",
        "simd_exp_f64",
        "sum",
        "sum_axis",
        "mean",
        "nansum",
        "nanmean",
        "logsumexp",
        "logsumexp_dim",
        "add_",
        "sub_",
        "mul_",
        "div_",
        "add_scalar_",
        "mul_scalar_",
        "fill_",
        "zero_",
        "clamp_",
    ];
    let missing: Vec<&str> = required
        .iter()
        .copied()
        .filter(|op| !present.contains(op))
        .collect();
    assert!(missing.is_empty(), "fixture file missing ops {missing:?}");
}