use std::sync::Arc;
use vyre_foundation::ir::model::expr::Ident;
use vyre_foundation::ir::{BufferAccess, BufferDecl, DataType, Expr, Node, Program};
/// Stable identifier of the edge-clip operation.
///
/// It is passed to `crate::invalid_output_program` when `mp_edge_clip` is
/// given `n == 0`, and it is used as the `generator` of the `Node::Region`
/// that wraps the emitted kernel body.
pub const OP_ID: &str = "vyre-primitives::math::mp_edge_clip";
/// Builds the IR [`Program`] that clamps every input element to a single bound.
///
/// For each invocation index `i` along axis 0 with `i < n`, the program stores
/// `min(eigenvalues[i], mp_edge[0])` into `out[i]`. Invocations with `i >= n`
/// do nothing, so the dispatch may be rounded up to the workgroup size.
///
/// # Parameters
/// - `eigenvalues`: name of the read-only `U32` storage buffer at binding 0,
///   declared with `n` elements.
/// - `mp_edge`: name of the read-only `U32` storage buffer at binding 1,
///   declared with one element; only index 0 is read.
/// - `out`: name of the read-write `U32` storage buffer at binding 2,
///   declared with `n` elements.
/// - `n`: number of elements to process; must be greater than zero.
///
/// # Returns
/// A program with workgroup size `[256, 1, 1]` whose body is one
/// `Node::Region` tagged with [`OP_ID`]. When `n == 0`, the result of
/// `crate::invalid_output_program` is returned instead.
#[must_use]
pub fn mp_edge_clip(eigenvalues: &str, mp_edge: &str, out: &str, n: u32) -> Program {
    // Reject an empty problem up front rather than emitting a no-op kernel.
    if n == 0 {
        let message = format!("Fix: mp_edge_clip requires n > 0, got {n}.");
        return crate::invalid_output_program(OP_ID, out, DataType::U32, message);
    }

    // Buffer layout: binding 0 = input, binding 1 = scalar bound, binding 2 = output.
    let buffers = vec![
        BufferDecl::storage(eigenvalues, 0, BufferAccess::ReadOnly, DataType::U32).with_count(n),
        BufferDecl::storage(mp_edge, 1, BufferAccess::ReadOnly, DataType::U32).with_count(1),
        BufferDecl::storage(out, 2, BufferAccess::ReadWrite, DataType::U32).with_count(n),
    ];

    // One invocation per element along axis 0.
    let lane = Expr::InvocationId { axis: 0 };
    let edge_value = Expr::load(mp_edge, Expr::u32(0));
    let clipped = Expr::min(Expr::load(eigenvalues, lane.clone()), edge_value);

    // Bounds guard: lanes at or beyond `n` must not touch the buffers.
    let in_range = Expr::lt(lane.clone(), Expr::u32(n));
    let guarded_store = Node::if_then(in_range, vec![Node::store(out, lane, clipped)]);

    let region = Node::Region {
        generator: Ident::from(OP_ID),
        source_region: None,
        body: Arc::new(vec![guarded_store]),
    };

    Program::wrapped(buffers, [256, 1, 1], vec![region])
}
/// Computes `sigma_sq * (1 + sqrt(q))^2` with `q = min(m, n) / max(m, n)`.
///
/// This has the form of the Marchenko–Pastur upper spectral edge, with `q`
/// as the aspect ratio folded into `(0, 1]`; `m` and `n` are presumably the
/// two matrix dimensions (confirm against callers). The result is symmetric
/// in `m` and `n`.
///
/// # Parameters
/// - `m`, `n`: the two dimensions; their order does not matter.
/// - `sigma_sq`: scale factor applied to the edge (the variance, by its name).
///
/// # Returns
/// The scaled edge value, or `f64::NAN` when either dimension is zero.
#[must_use]
pub fn mp_upper_edge(m: u32, n: u32, sigma_sq: f64) -> f64 {
    let (smaller, larger) = if m <= n { (m, n) } else { (n, m) };

    // For unsigned values the smaller one is zero exactly when either is zero.
    if smaller == 0 {
        return f64::NAN;
    }

    // `u32 -> f64` is lossless, so `From` gives the same value as an `as` cast.
    let ratio = f64::from(smaller) / f64::from(larger);
    let one_plus_root = 1.0 + ratio.sqrt();
    sigma_sq * one_plus_root.powi(2)
}
/// CPU reference for the clip: replaces every value above `edge` with `edge`.
///
/// Each output element is `f64::min(value, edge)`. Following `f64::min`,
/// a NaN input element yields `edge`, and a NaN `edge` leaves the inputs
/// unchanged.
///
/// # Parameters
/// - `eigenvalues`: values to clip; may be empty.
/// - `edge`: upper bound applied to every element.
///
/// # Returns
/// A new vector of the same length as `eigenvalues`, in the same order.
#[must_use]
pub fn mp_edge_clip_cpu(eigenvalues: &[f64], edge: f64) -> Vec<f64> {
    let mut clipped = Vec::with_capacity(eigenvalues.len());
    for &value in eigenvalues {
        clipped.push(f64::min(value, edge));
    }
    clipped
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns `true` when `a` and `b` differ by less than a tolerance that
    /// scales with their magnitudes (`1e-10 * (1 + |a| + |b|)`).
    fn approx_eq(a: f64, b: f64) -> bool {
        let tolerance = 1e-10 * (1.0 + a.abs() + b.abs());
        (a - b).abs() < tolerance
    }

    /// Equal dimensions give q = 1, so the edge is (1 + 1)^2 = 4.
    #[test]
    fn cpu_mp_edge_square_matrix() {
        let upper = mp_upper_edge(100, 100, 1.0);
        assert!(approx_eq(upper, 4.0));
    }

    /// 100 x 25 gives q = 0.25, so the edge is (1 + 0.5)^2 = 2.25.
    #[test]
    fn cpu_mp_edge_tall_matrix() {
        let upper = mp_upper_edge(100, 25, 1.0);
        assert!(approx_eq(upper, 2.25));
    }

    /// Values already below the edge pass through untouched.
    #[test]
    fn cpu_clip_below_edge_unchanged() {
        let input = vec![1.0, 2.0, 3.0];
        let clipped = mp_edge_clip_cpu(&input, 4.0);
        assert_eq!(clipped, input);
    }

    /// Values above the edge are replaced by the edge itself.
    #[test]
    fn cpu_clip_above_edge_clamped() {
        let input = vec![1.0, 5.0, 10.0];
        let clipped = mp_edge_clip_cpu(&input, 4.0);
        let expected = [1.0, 4.0, 4.0];
        for (idx, want) in expected.iter().enumerate() {
            assert!(approx_eq(clipped[idx], *want));
        }
    }

    /// The emitted program declares the expected workgroup size and buffer counts.
    #[test]
    fn ir_program_buffer_layout() {
        let program = mp_edge_clip("e", "edge", "out", 16);
        assert_eq!(program.workgroup_size, [256, 1, 1]);
        assert_eq!(program.buffers[0].count(), 16);
        assert_eq!(program.buffers[1].count(), 1);
        assert_eq!(program.buffers[2].count(), 16);
    }

    /// `n == 0` must yield a program whose stats report a trap.
    #[test]
    fn zero_n_traps() {
        let program = mp_edge_clip("e", "edge", "out", 0);
        assert!(program.stats().trap());
    }
}