use std::path::PathBuf;
use ferrotorch_core::grad_fns::arithmetic::{add, mul};
use ferrotorch_core::grad_fns::reduction::sum;
use ferrotorch_core::storage::TensorStorage;
use ferrotorch_core::{FerrotorchResult, Tensor};
use ferrotorch_jit::TracedModule;
use ferrotorch_jit_script::script;
use serde::Deserialize;
/// Top-level shape of `tests/conformance/fixtures.json`, as parsed by `load_fixtures`.
///
/// Only `fixtures` is consulted by the tests. `metadata` is never read (hence the
/// `dead_code` allow), but it is still deserialized, so its fields must be present
/// in the file for parsing to succeed.
#[derive(Debug, Deserialize)]
struct FixtureFile {
    // Provenance header; parsed but unused by the assertions in this file.
    #[allow(dead_code, reason = "metadata kept for forward-compat and diagnostics")]
    metadata: FixtureMetadata,
    // All conformance cases; looked up by name via `fixtures_for`.
    fixtures: Vec<Fixture>,
}
/// Provenance header of the fixture file (presumably written by
/// `scripts/regenerate_jit_script_fixtures.py` — confirm against the script).
///
/// No test reads these fields. They are all required `String`s, so a header
/// missing any of them fails deserialization. They are retained for version-drift
/// diagnostics and forward compatibility.
#[allow(
    dead_code,
    reason = "provenance metadata kept for forward-compat and version-drift diagnostics"
)]
#[derive(Debug, Deserialize)]
struct FixtureMetadata {
    torch_version: String,
    python_executable: String,
    python_platform: String,
    generated_at: String,
    description: String,
}
/// One conformance case from `fixtures.json`.
///
/// Every operand / reference vector is optional because each case populates only
/// the fields its test reads; the tests `expect` the ones they need. Values are
/// stored as `f64` and narrowed by the tensor helpers when an `f32` tensor is built.
#[derive(Debug, Deserialize)]
struct Fixture {
    /// Case name used by `fixtures_for` for lookup (first match wins).
    case: String,
    /// Operation kind; tests in this file expect `"script"` or `"script_error"`.
    op: String,
    /// Human-readable summary; spec-only error cases must mention `compile_error`.
    description: String,
    /// Operands for single-call cases: `a`, `b`, `c`, and weights `w`.
    #[serde(default)]
    input_a: Option<Vec<f64>>,
    #[serde(default)]
    input_b: Option<Vec<f64>>,
    #[serde(default)]
    input_c: Option<Vec<f64>>,
    #[serde(default)]
    input_w: Option<Vec<f64>>,
    /// Operands for the first and second calls of the module-reuse case.
    #[serde(default)]
    input_a_first: Option<Vec<f64>>,
    #[serde(default)]
    input_w_first: Option<Vec<f64>>,
    #[serde(default)]
    input_a_second: Option<Vec<f64>>,
    #[serde(default)]
    input_w_second: Option<Vec<f64>>,
    /// Presumably the element-type name recorded by the generator; not read by tests.
    #[serde(default)]
    #[allow(dead_code, reason = "kept for fixture-schema parity and diagnostics")]
    dtype: Option<String>,
    /// Reference output for single-call cases.
    #[serde(default)]
    expected_output: Option<Vec<f64>>,
    /// Reference outputs for the two calls of the module-reuse case.
    #[serde(default)]
    expected_output_first: Option<Vec<f64>>,
    #[serde(default)]
    expected_output_second: Option<Vec<f64>>,
    /// Set for spec-only cases with no PyTorch reference (must contain `spec-only`);
    /// executable `script` cases require it to be absent.
    #[serde(default)]
    cascade_skip: Option<String>,
}
/// Reads and parses `tests/conformance/fixtures.json` relative to the crate root.
///
/// # Panics
///
/// Panics if the file cannot be read (the message points at the regeneration
/// script) or if its contents do not deserialize into [`FixtureFile`].
fn load_fixtures() -> FixtureFile {
    let path: PathBuf = [
        env!("CARGO_MANIFEST_DIR"),
        "tests",
        "conformance",
        "fixtures.json",
    ]
    .iter()
    .collect();

    let raw = match std::fs::read(&path) {
        Ok(raw) => raw,
        Err(e) => panic!(
            "read {} failed: {e}. Regenerate via \
             scripts/regenerate_jit_script_fixtures.py",
            path.display()
        ),
    };

    match serde_json::from_slice(&raw) {
        Ok(parsed) => parsed,
        Err(e) => panic!("parse {}: {e}", path.display()),
    }
}
/// Returns the first fixture whose `case` equals `case`, or `None` if absent.
fn fixtures_for<'a>(file: &'a FixtureFile, case: &str) -> Option<&'a Fixture> {
    for candidate in &file.fixtures {
        if candidate.case == case {
            return Some(candidate);
        }
    }
    None
}
/// Builds a non-grad 1-D CPU `f32` tensor from `f64` fixture data.
///
/// Each value is narrowed with an `as f32` cast. Panics if tensor construction fails.
fn t1d_f32(data: &[f64]) -> Tensor<f32> {
    let shape = vec![data.len()];
    let narrowed = data.iter().map(|&v| v as f32).collect::<Vec<f32>>();
    Tensor::from_storage(TensorStorage::cpu(narrowed), shape, false).unwrap()
}
/// Builds a non-grad 1-D CPU `f64` tensor holding a copy of `data`.
///
/// Panics if tensor construction fails.
fn t1d_f64(data: &[f64]) -> Tensor<f64> {
    let storage = TensorStorage::cpu(Vec::from(data));
    Tensor::from_storage(storage, vec![data.len()], false).unwrap()
}
// Scripted weighted sum: element-wise `mul(a, w)` followed by a `sum` reduction.
//
// This item is rewritten by the `#[script]` attribute macro; comments use `//` so no
// extra `#[doc]` attributes are handed to the macro. Callers in this file pass two
// `Tensor<f32>` values and unwrap the result into a `TracedModule<f32>`, so the macro
// evidently changes the signature. NOTE(review): confirm the exact expansion in
// ferrotorch_jit_script.
#[script]
fn cs_weighted_sum(a: Tensor<f32>, w: Tensor<f32>) -> FerrotorchResult<Tensor<f32>> {
    let prod = mul(&a, &w)?;
    sum(&prod)
}
// Scripted three-operand addition: `(a + b) + c`, element-wise.
//
// Rewritten by the `#[script]` macro (plain `//` comments keep the macro input
// unchanged). Callers in this file unwrap its result into a `TracedModule<f32>` and
// replay it with `forward_multi` on three tensors.
#[script]
fn cs_three_arg_add(
    a: Tensor<f32>,
    b: Tensor<f32>,
    c: Tensor<f32>,
) -> FerrotorchResult<Tensor<f32>> {
    let ab = add(&a, &b)?;
    add(&ab, &c)
}
// `f64` twin of `cs_weighted_sum`: element-wise `mul(a, w)` then `sum`.
//
// Rewritten by the `#[script]` macro (plain `//` comments keep the macro input
// unchanged). Used to check that scripting preserves the element type: the caller
// annotates the result as `TracedModule<f64>`.
#[script]
fn cs_weighted_sum_f64(a: Tensor<f64>, w: Tensor<f64>) -> FerrotorchResult<Tensor<f64>> {
    let prod = mul(&a, &w)?;
    sum(&prod)
}
/// Asserts that `actual` matches the `f64` reference `expected` element-wise at `f32` precision.
///
/// Each reference value is narrowed to `f32` first. The tolerance is `1e-5` scaled
/// by `max(|expected|, 1)`, so it is absolute near zero and relative for large magnitudes.
///
/// # Panics
///
/// Panics (prefixed with `label`) on a length mismatch, or at the first index whose
/// difference exceeds the tolerance. A NaN difference never satisfies `<=` and so fails.
fn assert_close_f32(actual: &[f32], expected: &[f64], label: &str) {
    const TOL: f32 = 1e-5;
    assert_eq!(
        actual.len(),
        expected.len(),
        "{label}: length mismatch (actual={}, expected={})",
        actual.len(),
        expected.len()
    );
    let pairs = actual
        .iter()
        .copied()
        .zip(expected.iter().map(|&x| x as f32));
    for (i, (a, e_f32)) in pairs.enumerate() {
        let bound = TOL * e_f32.abs().max(1.0);
        let diff = (a - e_f32).abs();
        assert!(
            diff <= bound,
            "{label}: index {i}: actual={a} expected={e_f32} diff={diff:.3e} tol={:.3e}",
            bound
        );
    }
}
/// Asserts that `actual` matches `expected` element-wise at `f64` precision.
///
/// The tolerance is `1e-10` scaled by `max(|expected|, 1)`, so it is absolute near
/// zero and relative for larger magnitudes.
///
/// # Panics
///
/// Panics (prefixed with `label`) on a length mismatch, or at the first index whose
/// difference exceeds the tolerance. A NaN difference never satisfies `<=` and so
/// fails. The message reports the effective tolerance, like `assert_close_f32`.
fn assert_close_f64(actual: &[f64], expected: &[f64], label: &str) {
    const TOL: f64 = 1e-10;
    assert_eq!(
        actual.len(),
        expected.len(),
        "{label}: length mismatch (actual={}, expected={})",
        actual.len(),
        expected.len()
    );
    for (i, (&a, &e)) in actual.iter().zip(expected).enumerate() {
        let tol = TOL * e.abs().max(1.0);
        let diff = (a - e).abs();
        assert!(
            diff <= tol,
            "{label}: index {i}: actual={a} expected={e} diff={diff:.3e} tol={tol:.3e}",
        );
    }
}
/// Checks that the fixture file has every case the tests depend on, with no
/// duplicate case names.
#[test]
fn fixture_file_covers_every_case() {
    let file = load_fixtures();

    // `fixtures_for` returns the first entry with a matching `case`, so a duplicated
    // case name would silently shadow later entries. Reject duplicates up front.
    let mut seen = std::collections::HashSet::new();
    for fx in &file.fixtures {
        assert!(
            seen.insert(fx.case.as_str()),
            "fixture file contains duplicate case {:?} — regenerate via \
             scripts/regenerate_jit_script_fixtures.py",
            fx.case
        );
    }

    let required = [
        "two_arg_weighted_sum_f32",
        "three_arg_add_f32",
        "module_save_load_roundtrip_f32",
        "scalar_type_preservation_f64",
        "module_reuse_f32",
        "unrecognized_return_type_emits_compile_error",
        "missing_return_type_emits_compile_error",
    ];
    for case in required {
        assert!(
            fixtures_for(&file, case).is_some(),
            "fixture file missing case {case:?} — regenerate via \
             scripts/regenerate_jit_script_fixtures.py"
        );
    }
}
/// Scripts `cs_weighted_sum` on the fixture inputs and compares its replayed
/// output against the PyTorch reference.
#[test]
fn script_two_arg_weighted_sum_f32() {
    let file = load_fixtures();
    let fx =
        fixtures_for(&file, "two_arg_weighted_sum_f32").expect("fixture two_arg_weighted_sum_f32");
    assert_eq!(fx.op, "script");
    assert!(fx.cascade_skip.is_none(), "unexpected cascade_skip");

    let (input_a, input_w, expected) = (
        fx.input_a.as_deref().expect("input_a"),
        fx.input_w.as_deref().expect("input_w"),
        fx.expected_output.as_deref().expect("expected_output"),
    );

    // Trace with the fixture inputs, then replay the same inputs through the module.
    let module: TracedModule<f32> = cs_weighted_sum(t1d_f32(input_a), t1d_f32(input_w)).unwrap();
    let inputs = [t1d_f32(input_a), t1d_f32(input_w)];
    let result = module.forward_multi(&inputs).unwrap();
    assert_close_f32(
        result.data().expect("read data"),
        expected,
        "script two_arg_weighted_sum_f32",
    );
}
/// Scripts the three-operand `cs_three_arg_add` and checks the replayed output
/// against the reference.
#[test]
fn script_three_arg_add_f32() {
    let file = load_fixtures();
    let Some(fx) = fixtures_for(&file, "three_arg_add_f32") else {
        panic!("fixture three_arg_add_f32");
    };
    assert_eq!(fx.op, "script");
    assert!(fx.cascade_skip.is_none(), "unexpected cascade_skip");

    let [input_a, input_b, input_c] = [
        (&fx.input_a, "input_a"),
        (&fx.input_b, "input_b"),
        (&fx.input_c, "input_c"),
    ]
    .map(|(field, name)| field.as_deref().expect(name));
    let expected = fx.expected_output.as_deref().expect("expected_output");

    let module: TracedModule<f32> =
        cs_three_arg_add(t1d_f32(input_a), t1d_f32(input_b), t1d_f32(input_c)).unwrap();
    let replay = [input_a, input_b, input_c].map(t1d_f32);
    let result = module.forward_multi(&replay).unwrap();
    let values = result.data().expect("read data");
    assert_close_f32(values, expected, "script three_arg_add_f32");
}
/// Serializes a scripted module with `to_bytes` and reloads it with `from_bytes`.
///
/// Verifies that the freshly traced module matches the reference, that the reloaded
/// module matches the reference, and that the roundtrip leaves the output unchanged.
#[test]
fn script_module_save_load_roundtrip_f32() {
    let file = load_fixtures();
    let fx = fixtures_for(&file, "module_save_load_roundtrip_f32")
        .expect("fixture module_save_load_roundtrip_f32");
    assert_eq!(fx.op, "script");
    assert!(fx.cascade_skip.is_none(), "unexpected cascade_skip");
    let input_a = fx.input_a.as_deref().expect("input_a");
    let input_w = fx.input_w.as_deref().expect("input_w");
    let expected = fx.expected_output.as_deref().expect("expected_output");

    let module = cs_weighted_sum(t1d_f32(input_a), t1d_f32(input_w)).unwrap();
    // Evaluate the freshly traced module before serializing, so a failure can be
    // attributed to tracing rather than to (de)serialization.
    let before_out = module
        .forward_multi(&[t1d_f32(input_a), t1d_f32(input_w)])
        .unwrap();
    let before = before_out.data().expect("read data before save");
    assert_close_f32(
        before,
        expected,
        "script module_save_load_roundtrip_f32 (before save)",
    );

    let bytes = module.to_bytes();
    let loaded: TracedModule<f32> = TracedModule::<f32>::from_bytes(&bytes).unwrap();
    let result = loaded
        .forward_multi(&[t1d_f32(input_a), t1d_f32(input_w)])
        .unwrap();
    let actual = result.data().expect("read data");
    assert_close_f32(actual, expected, "script module_save_load_roundtrip_f32");
    // Same graph, same inputs: the roundtrip must not perturb the computation.
    assert_eq!(
        actual, before,
        "loaded module output differs from the original module's output"
    );
}
/// Scripts the `f64` weighted sum and checks the result at `f64` tolerance.
#[test]
fn script_scalar_type_preservation_f64() {
    let file = load_fixtures();
    let Some(fx) = fixtures_for(&file, "scalar_type_preservation_f64") else {
        panic!("fixture scalar_type_preservation_f64");
    };
    assert_eq!(fx.op, "script");
    assert!(fx.cascade_skip.is_none(), "unexpected cascade_skip");

    let (input_a, input_w, expected) = (
        fx.input_a.as_deref().expect("input_a"),
        fx.input_w.as_deref().expect("input_w"),
        fx.expected_output.as_deref().expect("expected_output"),
    );

    // Annotating the module as `TracedModule<f64>` makes the compiler confirm the
    // scripted function keeps its f64 element type.
    let module: TracedModule<f64> =
        cs_weighted_sum_f64(t1d_f64(input_a), t1d_f64(input_w)).unwrap();
    let output = module
        .forward_multi(&[t1d_f64(input_a), t1d_f64(input_w)])
        .unwrap();
    assert_close_f64(
        output.data().expect("read data"),
        expected,
        "script scalar_type_preservation_f64",
    );
}
/// Traces `cs_weighted_sum` once, then calls the same module twice with different
/// inputs, checking each call against its own reference output.
#[test]
fn script_module_reuse_f32() {
    let file = load_fixtures();
    let fx = fixtures_for(&file, "module_reuse_f32").expect("fixture module_reuse_f32");
    assert_eq!(fx.op, "script");
    assert!(fx.cascade_skip.is_none(), "unexpected cascade_skip");

    let a_first = fx.input_a_first.as_deref().expect("input_a_first");
    let w_first = fx.input_w_first.as_deref().expect("input_w_first");
    let a_second = fx.input_a_second.as_deref().expect("input_a_second");
    let w_second = fx.input_w_second.as_deref().expect("input_w_second");
    let exp_first = fx.expected_output_first.as_deref().expect("expected_output_first");
    let exp_second = fx.expected_output_second.as_deref().expect("expected_output_second");

    // The module is traced with the first inputs only; the second call checks that
    // it generalizes to new data rather than replaying the traced values.
    let module: TracedModule<f32> = cs_weighted_sum(t1d_f32(a_first), t1d_f32(w_first)).unwrap();
    let calls = [
        (a_first, w_first, exp_first, "r1 data", "module_reuse first call"),
        (a_second, w_second, exp_second, "r2 data", "module_reuse second call"),
    ];
    for (a, w, expected, data_msg, label) in calls {
        let out = module.forward_multi(&[t1d_f32(a), t1d_f32(w)]).unwrap();
        assert_close_f32(out.data().expect(data_msg), expected, label);
    }
}
/// Validates the spec-only `script_error` fixtures, which have no PyTorch reference.
///
/// The set of `op == "script_error"` cases in the file must equal
/// `spec_only_cases`. Each listed case must carry a `cascade_skip` containing
/// `spec-only` and a description mentioning `compile_error`.
#[test]
fn script_error_cases_are_cascade_skipped_in_fixtures() {
    let file = load_fixtures();
    let spec_only_cases = [
        "unrecognized_return_type_emits_compile_error",
        "missing_return_type_emits_compile_error",
    ];

    // Compare the sets of names (sorted, duplicates preserved) rather than only
    // their counts. A new or renamed `script_error` fixture then fails here with
    // both lists shown, instead of slipping past the per-case checks below.
    let mut in_fixture: Vec<&str> = file
        .fixtures
        .iter()
        .filter(|f| f.op == "script_error")
        .map(|f| f.case.as_str())
        .collect();
    in_fixture.sort_unstable();
    let mut in_test = spec_only_cases.to_vec();
    in_test.sort_unstable();
    assert_eq!(
        in_fixture, in_test,
        "op=script_error cases in the fixture file (left) differ from spec_only_cases \
         (right); add new case names to spec_only_cases"
    );

    for case_name in spec_only_cases {
        let fx = fixtures_for(&file, case_name)
            .unwrap_or_else(|| panic!("missing fixture for spec-only case {case_name:?}"));
        assert_eq!(
            fx.op, "script_error",
            "{case_name}: expected op=script_error"
        );
        let skip = fx.cascade_skip.as_deref().unwrap_or_else(|| {
            panic!(
                "{case_name}: spec-only fixture must have cascade_skip set; \
                 set cascade_skip = \"spec-only marker, no PyTorch reference\""
            )
        });
        assert!(
            skip.contains("spec-only"),
            "{case_name}: cascade_skip must contain 'spec-only', got: {skip:?}"
        );
        assert!(
            fx.description.contains("compile_error"),
            "{case_name}: description must mention 'compile_error', got: {:?}",
            fx.description
        );
    }
}