use vyre_primitives::math::qsvt::{qsvt_apply_cpu_into, qsvt_block_encode_cpu_into};
#[must_use]
pub fn negative_truncator_coeffs(k_steps: u32) -> Vec<f64> {
let mut out = Vec::new();
negative_truncator_coeffs_into(k_steps, &mut out);
out
}
pub fn negative_truncator_coeffs_into(k_steps: u32, out: &mut Vec<f64>) {
let pi = std::f64::consts::PI;
let all = [
-1.0 / pi,
-0.5,
-2.0 / (3.0 * pi),
0.0,
2.0 / (15.0 * pi),
0.0,
-2.0 / (35.0 * pi),
0.0,
];
out.clear();
out.extend(all.iter().take(k_steps as usize).copied());
}
#[must_use]
pub fn transport_residual(
dispatch_cost: &[f64],
weights: &[f64],
n: u32,
chebyshev_order: u32,
) -> Vec<f64> {
let mut scaled = Vec::new();
let mut coeffs = Vec::new();
let mut out = Vec::new();
let mut t_prev = Vec::new();
let mut t_curr = Vec::new();
let mut t_next = Vec::new();
transport_residual_into(
dispatch_cost,
weights,
n,
chebyshev_order,
&mut scaled,
&mut coeffs,
&mut out,
&mut t_prev,
&mut t_curr,
&mut t_next,
);
out
}
#[allow(clippy::too_many_arguments)]
pub fn transport_residual_into(
dispatch_cost: &[f64],
weights: &[f64],
n: u32,
chebyshev_order: u32,
scaled: &mut Vec<f64>,
coeffs: &mut Vec<f64>,
out: &mut Vec<f64>,
t_prev: &mut Vec<f64>,
t_curr: &mut Vec<f64>,
t_next: &mut Vec<f64>,
) {
use crate::observability::{bump, qsvt_matrix_function_fusion_calls};
bump(&qsvt_matrix_function_fusion_calls);
assert!(
chebyshev_order > 0 && chebyshev_order <= 8,
"Fix: chebyshev_order must be in 1..=8 for 0.6, got {chebyshev_order}."
);
let n_us = n as usize;
assert_eq!(dispatch_cost.len(), n_us * n_us);
assert_eq!(weights.len(), n_us);
let _frobenius = qsvt_block_encode_cpu_into(dispatch_cost, n, scaled);
negative_truncator_coeffs_into(chebyshev_order, coeffs);
qsvt_apply_cpu_into(scaled, weights, coeffs, n, out, t_prev, t_curr, t_next);
}
#[must_use]
pub fn fusion_affinity(transport_residual: &[f64]) -> Vec<f64> {
let mut out = Vec::new();
fusion_affinity_into(transport_residual, &mut out);
out
}
pub fn fusion_affinity_into(transport_residual: &[f64], out: &mut Vec<f64>) {
out.clear();
out.reserve(transport_residual.len());
out.extend(transport_residual.iter().map(|&v| -v.abs()));
}
#[cfg(test)]
mod tests {
use super::*;
fn approx_eq(a: f64, b: f64) -> bool {
(a - b).abs() < 1e-3 * (1.0 + a.abs() + b.abs())
}
#[test]
fn truncator_first_coeff_is_negative_inverse_pi() {
let coeffs = negative_truncator_coeffs(1);
assert_eq!(coeffs.len(), 1);
assert!(approx_eq(coeffs[0], -1.0 / std::f64::consts::PI));
}
#[test]
fn truncator_high_order_caps_at_eight() {
let coeffs = negative_truncator_coeffs(8);
assert_eq!(coeffs.len(), 8);
assert!(approx_eq(coeffs[3], 0.0));
assert!(approx_eq(coeffs[5], 0.0));
assert!(approx_eq(coeffs[7], 0.0));
}
#[test]
fn transport_residual_zero_weights_yields_zero() {
let cost = vec![1.0, 0.5, 0.5, 1.0];
let weights = vec![0.0, 0.0];
let result = transport_residual(&cost, &weights, 2, 4);
assert!(result.iter().all(|&v| approx_eq(v, 0.0)));
}
#[test]
fn fusion_affinity_inverts_residual_magnitude() {
let residual = vec![1.0, -2.5, 0.0, 0.5];
let aff = fusion_affinity(&residual);
assert!(approx_eq(aff[0], -1.0));
assert!(approx_eq(aff[1], -2.5));
assert!(approx_eq(aff[2], 0.0));
assert!(approx_eq(aff[3], -0.5));
}
#[test]
fn transport_residual_into_reuses_qsvt_buffers() {
let cost = vec![1.0, 0.5, 0.5, 1.0];
let weights = vec![1.0, 1.0];
let mut scaled = Vec::with_capacity(8);
let mut coeffs = Vec::with_capacity(8);
let mut out = Vec::with_capacity(8);
let mut t_prev = Vec::with_capacity(8);
let mut t_curr = Vec::with_capacity(8);
let mut t_next = Vec::with_capacity(8);
let pointers = [
scaled.as_ptr(),
coeffs.as_ptr(),
out.as_ptr(),
t_prev.as_ptr(),
t_curr.as_ptr(),
t_next.as_ptr(),
];
transport_residual_into(
&cost,
&weights,
2,
4,
&mut scaled,
&mut coeffs,
&mut out,
&mut t_prev,
&mut t_curr,
&mut t_next,
);
assert_eq!(out.len(), 2);
for ptr in [
scaled.as_ptr(),
coeffs.as_ptr(),
out.as_ptr(),
t_prev.as_ptr(),
t_curr.as_ptr(),
t_next.as_ptr(),
] {
assert!(pointers.contains(&ptr));
}
}
#[test]
fn transport_residual_runs_on_small_matrix() {
let cost = vec![1.0, 0.5, 0.3, 0.5, 1.0, 0.4, 0.3, 0.4, 1.0];
let weights = vec![1.0, 1.0, 1.0];
let result = transport_residual(&cost, &weights, 3, 4);
assert_eq!(result.len(), 3);
}
#[test]
#[should_panic(expected = "chebyshev_order")]
fn rejects_zero_chebyshev_order() {
let cost = vec![1.0; 4];
let weights = vec![1.0; 2];
let _ = transport_residual(&cost, &weights, 2, 0);
}
}