use vyre_primitives::graph::string_diagram::monoidal_compose_cpu_into;
/// Reusable output buffers for [`composition_associates_with_scratch`].
///
/// Each field receives one intermediate composite of the associativity check,
/// so a caller that keeps the scratch around can reuse the same `Vec`s across
/// calls instead of building new ones each time.
#[derive(Debug, Default)]
pub struct StringDiagramRewriteScratch {
    /// Output of composing `f` with `g`.
    gf: Vec<f64>,
    /// Output of composing `gf` with `h`.
    h_after_gf: Vec<f64>,
    /// Output of composing `g` with `h`.
    hg: Vec<f64>,
    /// Output of composing `f` with `hg`.
    hg_after_f: Vec<f64>,
}

impl StringDiagramRewriteScratch {
    /// Creates a scratch whose four buffers are all empty.
    #[must_use]
    pub fn new() -> Self {
        Self {
            gf: Vec::new(),
            h_after_gf: Vec::new(),
            hg: Vec::new(),
            hg_after_f: Vec::new(),
        }
    }
}
/// Composes the arrows `f` and `g` and returns the result in a new `Vec`.
///
/// Allocating convenience wrapper around [`compose_ir_arrows_into`]: it hands
/// a fresh, empty vector to that function and returns whatever was written.
///
/// NOTE(review): the tests in this file pass `f` with `a * b` entries and `g`
/// with `b * c` entries and expect `a * c` entries back; the exact layout is
/// defined by the underlying primitive — confirm there.
#[must_use]
pub fn compose_ir_arrows(f: &[f64], g: &[f64], a: u32, b: u32, c: u32) -> Vec<f64> {
    let mut composed = Vec::new();
    compose_ir_arrows_into(f, g, a, b, c, &mut composed);
    composed
}
pub fn compose_ir_arrows_into(f: &[f64], g: &[f64], a: u32, b: u32, c: u32, out: &mut Vec<f64>) {
use crate::observability::{bump, string_diagram_ir_rewrite_calls};
bump(&string_diagram_ir_rewrite_calls);
monoidal_compose_cpu_into(f, g, a, b, c, out);
}
/// Returns the `n`-by-`n` identity arrow as a flat vector of `n * n` entries.
///
/// Allocating convenience wrapper around [`identity_arrow_into`].
#[must_use]
pub fn identity_arrow(n: u32) -> Vec<f64> {
    let mut identity = Vec::new();
    identity_arrow_into(n, &mut identity);
    identity
}
/// Overwrites `out` with the `n`-by-`n` identity arrow.
///
/// Any previous contents of `out` are discarded. Afterwards `out` holds
/// `n * n` entries: `1.0` at every index `i * n + i` and `0.0` elsewhere.
/// For `n == 0` the buffer is left empty.
pub fn identity_arrow_into(n: u32, out: &mut Vec<f64>) {
    let side = n as usize;

    // Start from an all-zero buffer of the right size, dropping stale data.
    out.clear();
    out.resize(side * side, 0.0);

    // Consecutive diagonal entries are `side + 1` slots apart in the flat
    // buffer, starting at index 0.
    for diagonal in out.iter_mut().step_by(side + 1) {
        *diagonal = 1.0;
    }
}
/// Checks whether composing `f`, `g` and `h` gives the same result under both
/// groupings, using temporary buffers.
///
/// Convenience wrapper around [`composition_associates_with_scratch`] that
/// supplies a freshly created, empty scratch for the duration of the call.
#[must_use]
pub fn composition_associates(
    f: &[f64],
    g: &[f64],
    h: &[f64],
    a: u32,
    b: u32,
    c: u32,
    d: u32,
) -> bool {
    composition_associates_with_scratch(
        f,
        g,
        h,
        a,
        b,
        c,
        d,
        &mut StringDiagramRewriteScratch::default(),
    )
}
/// Checks whether composing `f`, `g` and `h` gives the same result under both
/// groupings, reusing the buffers in `scratch`.
///
/// Four compositions are performed through [`compose_ir_arrows_into`]:
///
/// 1. `f` with `g` (dimensions `a, b, c`) into `scratch.gf`;
/// 2. `scratch.gf` with `h` (dimensions `a, c, d`) into `scratch.h_after_gf`;
/// 3. `g` with `h` (dimensions `b, c, d`) into `scratch.hg`;
/// 4. `f` with `scratch.hg` (dimensions `a, b, d`) into `scratch.hg_after_f`.
///
/// Returns `true` when the results of steps 2 and 4 have the same length and
/// every pair of entries `x`, `y` satisfies
/// `|x - y| < 1e-9 * (1 + |x| + |y|)`. Results of different lengths are
/// reported as not equal instead of being compared on their common prefix.
/// A pair involving NaN never satisfies the inequality, so it yields `false`.
#[must_use]
#[allow(clippy::too_many_arguments)] // three arrows, four dimensions and the scratch
pub fn composition_associates_with_scratch(
    f: &[f64],
    g: &[f64],
    h: &[f64],
    a: u32,
    b: u32,
    c: u32,
    d: u32,
    scratch: &mut StringDiagramRewriteScratch,
) -> bool {
    /// Mixed absolute/relative tolerance factor for the entry-wise comparison.
    const TOL: f64 = 1e-9;

    // First grouping: compose f with g, then the result with h.
    compose_ir_arrows_into(f, g, a, b, c, &mut scratch.gf);
    compose_ir_arrows_into(&scratch.gf, h, a, c, d, &mut scratch.h_after_gf);

    // Second grouping: compose g with h, then f with that result.
    compose_ir_arrows_into(g, h, b, c, d, &mut scratch.hg);
    compose_ir_arrows_into(f, &scratch.hg, a, b, d, &mut scratch.hg_after_f);

    let left = scratch.h_after_gf.as_slice();
    let right = scratch.hg_after_f.as_slice();

    // `zip` stops at the shorter input, so without this guard a length
    // mismatch (e.g. one empty result) would compare only a prefix and could
    // pass vacuously.
    if left.len() != right.len() {
        return false;
    }

    left.iter()
        .zip(right)
        .all(|(&x, &y)| (x - y).abs() < TOL * (1.0 + x.abs() + y.abs()))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Entry-wise comparison with tolerance `1e-9 * (1 + |x| + |y|)`;
    /// slices of different lengths never compare equal.
    fn approx_eq_vec(a: &[f64], b: &[f64]) -> bool {
        a.len() == b.len()
            && a.iter()
                .zip(b)
                .all(|(&x, &y)| (x - y).abs() < 1e-9 * (1.0 + x.abs() + y.abs()))
    }

    #[test]
    fn identity_left_unit() {
        // Composing `f` with the identity as second arrow must return `f`.
        let f = [1.0, 2.0, 3.0, 4.0];
        let id = identity_arrow(2);
        let composed = compose_ir_arrows(&f, &id, 2, 2, 2);
        assert!(approx_eq_vec(&composed, &f));
    }

    #[test]
    fn identity_right_unit() {
        // Composing the identity as first arrow with `f` must return `f`.
        let f = [1.0, 2.0, 3.0, 4.0];
        let id = identity_arrow(2);
        let composed = compose_ir_arrows(&id, &f, 2, 2, 2);
        assert!(approx_eq_vec(&composed, &f));
    }

    #[test]
    fn composition_associativity_holds() {
        let f = [1.0, 0.5, -0.25, 0.5];
        let g = [0.5, 0.5, 0.5, -0.5];
        let h = [1.0, 0.0, 0.0, 1.0];
        assert!(composition_associates(&f, &g, &h, 2, 2, 2, 2));
    }

    #[test]
    fn rectangular_composition_dimensions() {
        // 6 entries with (a, b) = (2, 3) and 12 entries with (b, c) = (3, 4)
        // must produce a * c = 8 entries.
        let f = [1.0; 6];
        let g = [1.0; 12];
        let composed = compose_ir_arrows(&f, &g, 2, 3, 4);
        assert_eq!(composed.len(), 8);
    }

    #[test]
    fn identity_arrow_size_matches() {
        let id = identity_arrow(3);
        assert_eq!(id.len(), 9);
        // Diagonal entries.
        for idx in [0, 4, 8] {
            assert_eq!(id[idx], 1.0);
        }
        // A sample of off-diagonal entries.
        for idx in [1, 3] {
            assert_eq!(id[idx], 0.0);
        }
    }

    #[test]
    fn reusable_outputs_preserve_associativity() {
        let f = [1.0, 0.5, -0.25, 0.5];
        let g = [0.5, 0.5, 0.5, -0.5];
        let h = [1.0, 0.0, 0.0, 1.0];
        let mut scratch = StringDiagramRewriteScratch::new();
        let associates = composition_associates_with_scratch(&f, &g, &h, 2, 2, 2, 2, &mut scratch);
        assert!(associates);
    }
}