use std::sync::Arc;
use vyre_foundation::ir::model::expr::Ident;
use vyre_foundation::ir::{BufferAccess, BufferDecl, DataType, Expr, Node, Program};
/// Stable operation identifier; also used as the generator tag of the emitted region.
pub const OP_ID: &str = "vyre-primitives::geom::clifford2_geometric_product";
/// Number of `u32` words per multivector: one per basis blade of Cl(2), i.e. 2^2.
pub const MV_COMPONENTS: u32 = 1 << 2;
/// Build the IR program for the element-wise Cl(2,0) geometric product
/// `out[t] = lhs[t] * rhs[t]` over `n_pairs` multivector pairs.
///
/// Each multivector occupies [`MV_COMPONENTS`] consecutive `u32` words in the order
/// `[s, e1, e2, e12]`. This matches the fields of [`Cl2Mv`] and the formulas in
/// [`clifford2_product_cpu`]. Every word is treated as a signed two's-complement
/// Q16.16 fixed-point value (`1.0 == 0x0001_0000`). Each blade-by-blade product
/// is computed exactly as `(a * b) >> 16` on the full 64-bit signed product,
/// truncated to 32 bits. Sums and differences wrap modulo 2^32.
///
/// Bindings: `lhs` at 0 (read-only), `rhs` at 1 (read-only), `out` at 2
/// (read-write). Each buffer holds `n_pairs * MV_COMPONENTS` words. Invocation
/// `t` (axis 0) handles pair `t`. Invocations with `t >= n_pairs` do nothing.
/// The workgroup size is `[256, 1, 1]`.
///
/// If `n_pairs == 0`, or `n_pairs * MV_COMPONENTS` does not fit in `u32`, this
/// returns `crate::invalid_output_program(..)` carrying a `Fix:` message instead.
///
/// NOTE(review): the Q16.16 encoding is inferred from the original `>> 16` shift
/// and the signed blade formulas. Confirm against the host-side encoder.
#[must_use]
pub fn clifford2_product(lhs: &str, rhs: &str, out: &str, n_pairs: u32) -> Program {
    // Word offsets of each blade inside one multivector.
    const S: u32 = 0;
    const E1: u32 = 1;
    const E2: u32 = 2;
    const E12: u32 = 3;

    if n_pairs == 0 {
        return crate::invalid_output_program(
            OP_ID,
            out,
            DataType::U32,
            format!("Fix: clifford2_product requires n_pairs > 0, got {n_pairs}."),
        );
    }
    // Reject buffer sizes that would overflow instead of panicking (debug) or
    // silently declaring a wrapped, too-small count (release).
    let word_count = match n_pairs.checked_mul(MV_COMPONENTS) {
        Some(count) => count,
        None => {
            return crate::invalid_output_program(
                OP_ID,
                out,
                DataType::U32,
                format!(
                    "Fix: clifford2_product requires n_pairs * {MV_COMPONENTS} to fit in u32, \
                     got n_pairs = {n_pairs}."
                ),
            );
        }
    };

    let t = Expr::InvocationId { axis: 0 };
    let base = Expr::mul(t.clone(), Expr::u32(MV_COMPONENTS));
    let load_l = |off: u32| Expr::load(lhs, Expr::add(base.clone(), Expr::u32(off)));
    let load_r = |off: u32| Expr::load(rhs, Expr::add(base.clone(), Expr::u32(off)));
    // Fixed-point product of lhs blade `i` with rhs blade `j`.
    let prod = |i: u32, j: u32| q16_mul(load_l(i), load_r(j));

    // Cl(2,0) multiplication table: e1^2 = e2^2 = 1, e12^2 = -1.
    let out_s = Expr::sub(
        Expr::add(Expr::add(prod(S, S), prod(E1, E1)), prod(E2, E2)),
        prod(E12, E12),
    );
    let out_1 = Expr::add(
        Expr::sub(Expr::add(prod(S, E1), prod(E1, S)), prod(E2, E12)),
        prod(E12, E2),
    );
    let out_2 = Expr::sub(
        Expr::add(Expr::add(prod(S, E2), prod(E2, S)), prod(E1, E12)),
        prod(E12, E1),
    );
    let out_12 = Expr::sub(
        Expr::add(Expr::add(prod(S, E12), prod(E12, S)), prod(E1, E2)),
        prod(E2, E1),
    );

    let store_to =
        |off: u32, val: Expr| Node::store(out, Expr::add(base.clone(), Expr::u32(off)), val);
    let body = vec![Node::if_then(
        Expr::lt(t.clone(), Expr::u32(n_pairs)),
        vec![
            store_to(S, out_s),
            store_to(E1, out_1),
            store_to(E2, out_2),
            store_to(E12, out_12),
        ],
    )];
    Program::wrapped(
        vec![
            BufferDecl::storage(lhs, 0, BufferAccess::ReadOnly, DataType::U32)
                .with_count(word_count),
            BufferDecl::storage(rhs, 1, BufferAccess::ReadOnly, DataType::U32)
                .with_count(word_count),
            BufferDecl::storage(out, 2, BufferAccess::ReadWrite, DataType::U32)
                .with_count(word_count),
        ],
        [256, 1, 1],
        vec![Node::Region {
            generator: Ident::from(OP_ID),
            source_region: None,
            body: Arc::new(body),
        }],
    )
}

/// Signed Q16.16 fixed-point multiply of two `u32` bit patterns. It yields bits
/// 16..47 of the full 64-bit two's-complement product, i.e. `(a * b) >> 16`
/// (arithmetic) truncated to 32 bits.
///
/// A plain 32-bit `(a * b) >> 16` discards product bits 32..47, so
/// `1.0 * 1.0 == 0x1_0000 * 0x1_0000` wraps to 0. It also treats negative
/// operands as huge unsigned values. This rebuilds the missing bits from 16-bit
/// halves using only wrapping `u32` mul/add/sub/shr.
///
/// NOTE(review): assumes `Expr::shr` on `U32` is a logical shift and `u32`
/// mul/add/sub wrap modulo 2^32 (as for WGSL `u32`). Confirm for every backend.
fn q16_mul(a: Expr, b: Expr) -> Expr {
    const RADIX: u32 = 1 << 16;
    let hi = |x: &Expr| Expr::shr(x.clone(), Expr::u32(16));
    let lo = |x: &Expr| Expr::sub(x.clone(), Expr::mul(hi(x), Expr::u32(RADIX)));
    // 1 when the two's-complement sign bit is set, else 0.
    let sign = |x: &Expr| Expr::shr(x.clone(), Expr::u32(31));

    let (ah, al, bh, bl) = (hi(&a), lo(&a), hi(&b), lo(&b));
    // floor(a_u * b_u / 2^16) mod 2^32 for the unsigned interpretations:
    //   ah*bh*2^16 + ah*bl + al*bh + floor(al*bl / 2^16)
    // al*bl < 2^32, so the last term is exact.
    let unsigned = Expr::add(
        Expr::add(
            Expr::mul(Expr::mul(ah.clone(), bh.clone()), Expr::u32(RADIX)),
            Expr::mul(ah, bl.clone()),
        ),
        Expr::add(
            Expr::mul(al.clone(), bh),
            Expr::shr(Expr::mul(al, bl), Expr::u32(16)),
        ),
    );
    // a_s = a_u - 2^32*sign(a), hence
    // a_s*b_s = a_u*b_u - 2^32*(sign(a)*b_u + sign(b)*a_u) (mod 2^64).
    // After >> 16 the correction term is 2^16*(sign(a)*b_u + sign(b)*a_u) (mod 2^32).
    let correction = Expr::mul(
        Expr::add(Expr::mul(sign(&a), b.clone()), Expr::mul(sign(&b), a)),
        Expr::u32(RADIX),
    );
    Expr::sub(unsigned, correction)
}
/// A multivector of the 2-D Clifford algebra Cl(2,0), stored as its four blade
/// coefficients: scalar, the two basis vectors, and the bivector `e12 = e1 e2`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Cl2Mv {
    /// Scalar (grade-0) coefficient.
    pub s: f64,
    /// Coefficient of the basis vector `e1`.
    pub e1: f64,
    /// Coefficient of the basis vector `e2`.
    pub e2: f64,
    /// Coefficient of the bivector `e12`.
    pub e12: f64,
}
impl Cl2Mv {
    /// All-zero multivector; the base that the constructors below fill in.
    const ZERO: Self = Self {
        s: 0.0,
        e1: 0.0,
        e2: 0.0,
        e12: 0.0,
    };

    /// The scalar `1`, the two-sided unit of the geometric product.
    pub const IDENTITY: Self = Self::scalar(1.0);

    /// A pure scalar multivector `s`.
    #[must_use]
    pub const fn scalar(s: f64) -> Self {
        Self { s, ..Self::ZERO }
    }

    /// A pure vector `x e1 + y e2`.
    #[must_use]
    pub const fn vector(x: f64, y: f64) -> Self {
        Self {
            e1: x,
            e2: y,
            ..Self::ZERO
        }
    }
}
/// Exact `f64` reference for the Cl(2,0) geometric product `a * b`.
///
/// Uses the multiplication table `e1^2 = e2^2 = 1`, `e12 = e1 e2`,
/// `e12^2 = -1`. The component order and term order mirror the IR program
/// built by `clifford2_product`.
#[must_use]
pub fn clifford2_product_cpu(a: Cl2Mv, b: Cl2Mv) -> Cl2Mv {
    let Cl2Mv {
        s: a0,
        e1: a1,
        e2: a2,
        e12: a3,
    } = a;
    let Cl2Mv {
        s: b0,
        e1: b1,
        e2: b2,
        e12: b3,
    } = b;
    // Terms are accumulated left to right, exactly as in the reference formulas.
    let scalar = a0 * b0 + a1 * b1 + a2 * b2 - a3 * b3;
    let vec1 = a0 * b1 + a1 * b0 - a2 * b3 + a3 * b2;
    let vec2 = a0 * b2 + a2 * b0 + a1 * b3 - a3 * b1;
    let bivec = a0 * b3 + a3 * b0 + a1 * b2 - a2 * b1;
    Cl2Mv {
        s: scalar,
        e1: vec1,
        e2: vec2,
        e12: bivec,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Relative-tolerance float comparison shared by the CPU-reference tests.
    fn approx_eq(a: f64, b: f64) -> bool {
        (a - b).abs() < 1e-10 * (1.0 + a.abs() + b.abs())
    }

    /// Component-wise `approx_eq` over all four blades.
    fn mv_eq(a: Cl2Mv, b: Cl2Mv) -> bool {
        approx_eq(a.s, b.s)
            && approx_eq(a.e1, b.e1)
            && approx_eq(a.e2, b.e2)
            && approx_eq(a.e12, b.e12)
    }

    /// The unit bivector `e12`.
    fn unit_e12() -> Cl2Mv {
        Cl2Mv {
            s: 0.0,
            e1: 0.0,
            e2: 0.0,
            e12: 1.0,
        }
    }

    /// Component-wise sum of two multivectors.
    fn mv_add(a: Cl2Mv, b: Cl2Mv) -> Cl2Mv {
        Cl2Mv {
            s: a.s + b.s,
            e1: a.e1 + b.e1,
            e2: a.e2 + b.e2,
            e12: a.e12 + b.e12,
        }
    }

    #[test]
    fn cpu_identity_left_unit() {
        let v = Cl2Mv::vector(2.0, 3.0);
        let out = clifford2_product_cpu(Cl2Mv::IDENTITY, v);
        assert!(mv_eq(out, v));
    }

    #[test]
    fn cpu_identity_right_unit() {
        let v = Cl2Mv::vector(2.0, 3.0);
        let out = clifford2_product_cpu(v, Cl2Mv::IDENTITY);
        assert!(mv_eq(out, v));
    }

    #[test]
    fn cpu_basis_vector_squares_to_one() {
        let e1 = Cl2Mv::vector(1.0, 0.0);
        let out = clifford2_product_cpu(e1, e1);
        assert!(approx_eq(out.s, 1.0));
        assert!(approx_eq(out.e1, 0.0));
        assert!(approx_eq(out.e2, 0.0));
        assert!(approx_eq(out.e12, 0.0));
    }

    #[test]
    fn cpu_basis_vector_e2_squares_to_one() {
        let e2 = Cl2Mv::vector(0.0, 1.0);
        let out = clifford2_product_cpu(e2, e2);
        assert!(mv_eq(out, Cl2Mv::IDENTITY));
    }

    #[test]
    fn cpu_e12_squared_is_minus_one() {
        let e12 = unit_e12();
        let out = clifford2_product_cpu(e12, e12);
        assert!(approx_eq(out.s, -1.0));
    }

    #[test]
    fn cpu_basis_anticommutes() {
        let e1 = Cl2Mv::vector(1.0, 0.0);
        let e2 = Cl2Mv::vector(0.0, 1.0);
        let p1 = clifford2_product_cpu(e1, e2);
        let p2 = clifford2_product_cpu(e2, e1);
        assert!(approx_eq(p1.e12, 1.0));
        assert!(approx_eq(p2.e12, -1.0));
    }

    #[test]
    fn cpu_pseudoscalar_anticommutes_with_vector() {
        let e1 = Cl2Mv::vector(1.0, 0.0);
        let e12 = unit_e12();
        let left = clifford2_product_cpu(e12, e1);
        let right = clifford2_product_cpu(e1, e12);
        assert!(approx_eq(left.e2, -1.0));
        assert!(approx_eq(right.e2, 1.0));
    }

    #[test]
    fn cpu_geometric_product_distributes() {
        let a = Cl2Mv {
            s: 1.0,
            e1: 2.0,
            e2: 3.0,
            e12: 4.0,
        };
        let b = Cl2Mv::vector(1.0, 0.0);
        let c = Cl2Mv {
            s: 0.5,
            e1: 0.0,
            e2: 0.5,
            e12: 0.0,
        };
        let lhs = clifford2_product_cpu(a, mv_add(b, c));
        let rhs = mv_add(clifford2_product_cpu(a, b), clifford2_product_cpu(a, c));
        assert!(mv_eq(lhs, rhs));
    }

    #[test]
    fn cpu_geometric_product_associates() {
        let a = Cl2Mv {
            s: 1.0,
            e1: 2.0,
            e2: 3.0,
            e12: 4.0,
        };
        let b = Cl2Mv {
            s: 0.5,
            e1: -1.0,
            e2: 0.25,
            e12: 2.0,
        };
        let c = Cl2Mv {
            s: -3.0,
            e1: 0.5,
            e2: 1.0,
            e12: -0.75,
        };
        let left = clifford2_product_cpu(clifford2_product_cpu(a, b), c);
        let right = clifford2_product_cpu(a, clifford2_product_cpu(b, c));
        assert!(mv_eq(left, right));
    }

    #[test]
    fn ir_program_buffer_layout() {
        let p = clifford2_product("a", "b", "out", 8);
        assert_eq!(p.workgroup_size, [256, 1, 1]);
        let names: Vec<&str> = p.buffers.iter().map(|b| b.name()).collect();
        assert_eq!(names, vec!["a", "b", "out"]);
        // 8 pairs * MV_COMPONENTS (4) words per buffer.
        for buf in p.buffers.iter() {
            assert_eq!(buf.count(), 32);
        }
    }

    #[test]
    fn zero_n_pairs_traps() {
        let p = clifford2_product("a", "b", "out", 0);
        assert!(p.stats().trap());
    }
}