use scirs2_core::ndarray::ArrayD;
use std::collections::HashMap;
use tensorlogic_infer::ExecutorError;
use crate::executor_f32::Scirs2Exec32;
use tensorlogic_infer::TlExecutor;
/// Downcast every element of an f64 tensor to f32.
///
/// Lossy for values that are not exactly representable in f32.
pub fn cast_f64_to_f32(t: &ArrayD<f64>) -> ArrayD<f32> {
    t.map(|&x| x as f32)
}
/// Upcast every element of an f32 tensor to f64 (always lossless).
pub fn cast_f32_to_f64(t: &ArrayD<f32>) -> ArrayD<f64> {
    t.map(|&x| f64::from(x))
}
/// Bridges f64 callers to an f32 executor: tensors are downcast and contracted
/// in f32, while gradients are summed separately in full f64 precision.
pub struct DualPrecisionBridge {
    // Inner f32 executor; owns the downcast tensors and runs einsum contractions.
    exec32: Scirs2Exec32,
    // Running per-name gradient sums, kept in f64 to avoid accumulation error.
    f64_accumulator: HashMap<String, ArrayD<f64>>,
}
impl Default for DualPrecisionBridge {
fn default() -> Self {
Self::new()
}
}
impl DualPrecisionBridge {
    /// Create an empty bridge: a fresh f32 executor and no accumulated gradients.
    pub fn new() -> Self {
        DualPrecisionBridge {
            exec32: Scirs2Exec32::new(),
            f64_accumulator: HashMap::new(),
        }
    }

    /// Downcast an f64 tensor to f32 and register it with the inner executor
    /// under `name`. The downcast is lossy for values not representable in f32.
    pub fn add_f64_tensor(&mut self, name: impl Into<String>, t: ArrayD<f64>) {
        let key: String = name.into();
        let t32 = cast_f64_to_f32(&t);
        self.exec32.add_tensor(key, t32);
    }

    /// Run an einsum in f32 over previously registered tensors and upcast the
    /// result to f64.
    ///
    /// # Errors
    /// Returns an error if any name in `input_names` has not been registered,
    /// or if the underlying executor rejects the spec/inputs.
    pub fn einsum_f32_result_f64(
        &mut self,
        spec: &str,
        input_names: &[&str],
    ) -> Result<ArrayD<f64>, ExecutorError> {
        // Clone inputs out of the executor's tensor map so no borrow of
        // `self.exec32` is held across the `einsum` call below.
        let inputs: Vec<ArrayD<f32>> = input_names
            .iter()
            .map(|name| {
                self.exec32.tensors.get(*name).cloned().ok_or_else(|| {
                    ExecutorError::InvalidEinsumSpec(format!(
                        "Tensor '{}' not found in bridge",
                        name
                    ))
                })
            })
            .collect::<Result<_, ExecutorError>>()?;
        let result32 = self.exec32.einsum(spec, &inputs)?;
        Ok(cast_f32_to_f64(&result32))
    }

    /// Add `grad` into the f64 accumulator for `name`, inserting it on first use.
    ///
    /// # Panics
    /// May panic (via ndarray broadcasting rules) if `grad`'s shape is
    /// incompatible with the shape already accumulated under `name`.
    pub fn accumulate_grad(&mut self, name: &str, grad: ArrayD<f64>) {
        // Fast path avoids allocating a zeros tensor and an owned key on every
        // call; on first use the incoming gradient is stored directly, which is
        // equivalent to zeros + grad.
        if let Some(acc) = self.f64_accumulator.get_mut(name) {
            *acc = &*acc + &grad;
        } else {
            self.f64_accumulator.insert(name.to_string(), grad);
        }
    }

    /// Borrow the accumulated f64 gradient for `name`, if any has been recorded.
    pub fn get_grad(&self, name: &str) -> Option<&ArrayD<f64>> {
        self.f64_accumulator.get(name)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::ArrayD;

    /// Build an f64 test tensor from a shape and flat data.
    fn make_f64(shape: &[usize], data: Vec<f64>) -> ArrayD<f64> {
        ArrayD::from_shape_vec(shape, data).expect("valid shape/data for test tensor")
    }

    /// Build an f32 test tensor from a shape and flat data.
    fn make_f32(shape: &[usize], data: Vec<f32>) -> ArrayD<f32> {
        ArrayD::from_shape_vec(shape, data).expect("valid shape/data for test tensor")
    }

    #[test]
    fn test_cast_f64_to_f32_values() {
        let src = make_f64(&[3], vec![1.0, 2.0, std::f64::consts::PI]);
        let down = cast_f64_to_f32(&src);
        let pi_roundtrip = f64::from(down[[2]]);
        assert!(
            (pi_roundtrip - std::f64::consts::PI).abs() < 1e-6,
            "pi approx failed: {}",
            pi_roundtrip
        );
        assert_eq!(down[[0]], 1.0_f32);
        assert_eq!(down[[1]], 2.0_f32);
    }

    #[test]
    fn test_cast_f32_to_f64_lossless() {
        let src = make_f32(&[4], vec![1.0, 2.0, 4.0, 8.0]);
        let up = cast_f32_to_f64(&src);
        // Powers of two round-trip exactly between f32 and f64.
        for (got, exp) in up.iter().zip([1.0_f64, 2.0, 4.0, 8.0].iter()) {
            assert_eq!(*got, *exp, "lossless upcast failed");
        }
    }

    #[test]
    fn test_cast_shape_preserved() {
        let flat: Vec<f64> = (0..60).map(f64::from).collect();
        let original = make_f64(&[3, 4, 5], flat);
        let down = cast_f64_to_f32(&original);
        let back = cast_f32_to_f64(&down);
        assert_eq!(original.shape(), down.shape());
        assert_eq!(down.shape(), back.shape());
    }

    #[test]
    fn test_dual_bridge_einsum() {
        let mut bridge = DualPrecisionBridge::new();
        bridge.add_f64_tensor("I", make_f64(&[2, 2], vec![1.0, 0.0, 0.0, 1.0]));
        bridge.add_f64_tensor("M", make_f64(&[2, 2], vec![3.0, 4.0, 5.0, 6.0]));
        let result = bridge
            .einsum_f32_result_f64("ij,jk->ik", &["I", "M"])
            .expect("dual bridge einsum");
        assert_eq!(result.shape(), &[2, 2]);
        let flat: Vec<f64> = result.iter().copied().collect();
        assert!(
            (flat[0] - 3.0).abs() < 1e-4,
            "expected 3.0, got {}",
            flat[0]
        );
        assert!(
            (flat[1] - 4.0).abs() < 1e-4,
            "expected 4.0, got {}",
            flat[1]
        );
    }

    #[test]
    fn test_dual_bridge_accumulate_grad() {
        let mut bridge = DualPrecisionBridge::new();
        bridge.accumulate_grad("w", make_f64(&[2], vec![1.0, 2.0]));
        bridge.accumulate_grad("w", make_f64(&[2], vec![3.0, 4.0]));
        let acc = bridge
            .get_grad("w")
            .expect("grad should exist after accumulation");
        let flat: Vec<f64> = acc.iter().copied().collect();
        assert!(
            (flat[0] - 4.0).abs() < 1e-10,
            "accumulated grad[0]={}",
            flat[0]
        );
        assert!(
            (flat[1] - 6.0).abs() < 1e-10,
            "accumulated grad[1]={}",
            flat[1]
        );
    }

    #[test]
    fn test_dual_bridge_get_nonexistent_grad() {
        let bridge = DualPrecisionBridge::new();
        assert!(bridge.get_grad("nonexistent").is_none());
    }

    #[test]
    fn test_cast_zeros() {
        let z64 = make_f64(&[3, 3], vec![0.0; 9]);
        let z32 = cast_f64_to_f32(&z64);
        assert!(z32.iter().all(|&v| v == 0.0_f32));
        let back = cast_f32_to_f64(&z32);
        assert!(back.iter().all(|&v| v == 0.0_f64));
    }

    #[test]
    fn test_cast_large_values() {
        let big = f32::MAX;
        let widened = cast_f32_to_f64(&make_f32(&[1], vec![big]));
        assert!(
            (widened[[0]] - f64::from(big)).abs() < 1.0,
            "large value cast failed"
        );
    }

    #[test]
    fn test_dual_bridge_multiple_einsums() {
        let mut bridge = DualPrecisionBridge::new();
        bridge.add_f64_tensor("A", make_f64(&[2, 2], vec![1.0, 0.0, 0.0, 2.0]));
        bridge.add_f64_tensor("B", make_f64(&[2, 2], vec![3.0, 0.0, 0.0, 4.0]));
        bridge.add_f64_tensor("C", make_f64(&[2, 2], vec![1.0, 1.0, 1.0, 1.0]));

        let ab = bridge
            .einsum_f32_result_f64("ij,jk->ik", &["A", "B"])
            .expect("first einsum");
        let ab_flat: Vec<f64> = ab.iter().copied().collect();
        assert!((ab_flat[0] - 3.0).abs() < 1e-4);
        assert!((ab_flat[3] - 8.0).abs() < 1e-4);

        let ac = bridge
            .einsum_f32_result_f64("ij,jk->ik", &["A", "C"])
            .expect("second einsum");
        let expected = [1.0_f64, 1.0, 2.0, 2.0];
        for (got, want) in ac.iter().zip(expected.iter()) {
            assert!((got - want).abs() < 1e-4);
        }
    }

    #[test]
    fn test_dual_bridge_default() {
        let bridge = DualPrecisionBridge::default();
        assert!(bridge.get_grad("any_name").is_none());
    }
}