use std::sync::Arc;
use vyre_foundation::ir::model::expr::Ident;
use vyre_foundation::ir::{BufferAccess, BufferDecl, DataType, Expr, Node, Program};
/// Operation identifier used as the region generator of [`hypervector_xor_bind`] programs.
pub const BIND_OP_ID: &str = "vyre-primitives::hash::hypervector_xor_bind";
/// Operation identifier used as the region generator of [`hypervector_majority_bundle`] programs.
pub const BUNDLE_OP_ID: &str = "vyre-primitives::hash::hypervector_majority_bundle";
/// Standard hypervector dimensionality, in bits.
pub const STANDARD_DIM_BITS: u32 = 10_240;
/// Standard hypervector dimensionality, in packed `u32` words (one word per 32 bits).
pub const STANDARD_DIM_WORDS: u32 = STANDARD_DIM_BITS / u32::BITS;
/// Builds a vyre IR program that XOR-binds two packed hypervectors word by word.
///
/// Buffers:
/// - `a` (binding 0, read-only) and `b` (binding 1, read-only): `dim_words` `u32` words each.
/// - `out` (binding 2, read-write): `dim_words` `u32` words.
///
/// Each invocation along axis 0 with id `i < dim_words` writes `out[i] = a[i] ^ b[i]`.
/// Invocations with larger ids do nothing. The workgroup size is `[256, 1, 1]`. The caller
/// presumably dispatches at least `dim_words` invocations (TODO confirm at call sites).
///
/// If `dim_words == 0`, returns the program produced by `crate::invalid_output_program`
/// with a `Fix:` message instead of a binding kernel.
#[must_use]
pub fn hypervector_xor_bind(a: &str, b: &str, out: &str, dim_words: u32) -> Program {
    if dim_words == 0 {
        return crate::invalid_output_program(
            BIND_OP_ID,
            out,
            DataType::U32,
            "Fix: hypervector_xor_bind requires dim_words > 0, got 0.".to_string(),
        );
    }

    let lane = Expr::InvocationId { axis: 0 };
    // Guard against invocations that fall past the end of the vectors.
    let in_bounds = Expr::lt(lane.clone(), Expr::u32(dim_words));
    let bound_word = Expr::bitxor(Expr::load(a, lane.clone()), Expr::load(b, lane.clone()));
    let body = vec![Node::if_then(
        in_bounds,
        vec![Node::store(out, lane, bound_word)],
    )];

    let buffers = vec![
        BufferDecl::storage(a, 0, BufferAccess::ReadOnly, DataType::U32).with_count(dim_words),
        BufferDecl::storage(b, 1, BufferAccess::ReadOnly, DataType::U32).with_count(dim_words),
        BufferDecl::storage(out, 2, BufferAccess::ReadWrite, DataType::U32)
            .with_count(dim_words),
    ];
    let region = Node::Region {
        generator: Ident::from(BIND_OP_ID),
        source_region: None,
        body: Arc::new(body),
    };
    Program::wrapped(buffers, [256, 1, 1], vec![region])
}
/// Builds a vyre IR program that bundles `k` packed hypervectors by bitwise majority vote.
///
/// Buffers:
/// - `stacked` (binding 0, read-only): the `k` input vectors stored back to back. Vector `ii`
///   occupies words `ii * dim_words .. (ii + 1) * dim_words`.
/// - `out` (binding 1, read-write): `dim_words` `u32` words.
///
/// Invocation `t < dim_words` handles each of the 32 bit positions of word `t`. It counts how
/// many of the `k` vectors have that bit set, and sets the bit in `out[t]` when the count is
/// strictly greater than `k / 2`. Ties are only possible for even `k`, and they resolve to 0.
/// This matches [`majority_bundle_cpu_into`]. The workgroup size is `[256, 1, 1]`.
///
/// Invalid parameters return the program produced by `crate::invalid_output_program`, with a
/// `Fix:` message, instead of a bundling kernel. The invalid cases are:
/// - `dim_words == 0`
/// - `k == 0`
/// - `k * dim_words` does not fit in `u32`
#[must_use]
pub fn hypervector_majority_bundle(stacked: &str, out: &str, dim_words: u32, k: u32) -> Program {
    if dim_words == 0 {
        return crate::invalid_output_program(
            BUNDLE_OP_ID,
            out,
            DataType::U32,
            "Fix: hypervector_majority_bundle requires dim_words > 0, got 0.".to_string(),
        );
    }
    if k == 0 {
        return crate::invalid_output_program(
            BUNDLE_OP_ID,
            out,
            DataType::U32,
            "Fix: hypervector_majority_bundle requires k > 0, got 0.".to_string(),
        );
    }
    // `stacked` holds `k * dim_words` words, and the kernel indexes it with
    // `ii * dim_words + t`. An unchecked product would panic in debug builds or wrap the
    // declared buffer count in release builds. Validating it here also guarantees every
    // in-kernel index stays below `stacked_words`, so the device-side arithmetic cannot wrap.
    let Some(stacked_words) = k.checked_mul(dim_words) else {
        return crate::invalid_output_program(
            BUNDLE_OP_ID,
            out,
            DataType::U32,
            format!(
                "Fix: hypervector_majority_bundle requires k * dim_words <= u32::MAX, got k = {} and dim_words = {}.",
                k, dim_words
            ),
        );
    };

    let t = Expr::InvocationId { axis: 0 };
    // A bit is set only on a strict majority, i.e. count > k / 2.
    let threshold = k / 2;

    // Word `t` of input vector `ii`.
    let stacked_word = Expr::load(
        stacked,
        Expr::add(Expr::mul(Expr::var("ii"), Expr::u32(dim_words)), t.clone()),
    );
    // Adds the current bit of `stacked_word` to `count` for every input vector.
    let count_votes = Node::loop_for(
        "ii",
        Expr::u32(0),
        Expr::u32(k),
        vec![Node::assign(
            "count",
            Expr::add(
                Expr::var("count"),
                Expr::bitand(Expr::shr(stacked_word, Expr::var("bit")), Expr::u32(1)),
            ),
        )],
    );
    let per_bit = vec![
        Node::let_bind("count", Expr::u32(0)),
        count_votes,
        Node::if_then(
            Expr::gt(Expr::var("count"), Expr::u32(threshold)),
            vec![Node::assign(
                "acc",
                Expr::bitor(Expr::var("acc"), Expr::shl(Expr::u32(1), Expr::var("bit"))),
            )],
        ),
    ];
    let body = vec![Node::if_then(
        Expr::lt(t.clone(), Expr::u32(dim_words)),
        vec![
            Node::let_bind("acc", Expr::u32(0)),
            Node::loop_for("bit", Expr::u32(0), Expr::u32(32), per_bit),
            Node::store(out, t, Expr::var("acc")),
        ],
    )];

    Program::wrapped(
        vec![
            BufferDecl::storage(stacked, 0, BufferAccess::ReadOnly, DataType::U32)
                .with_count(stacked_words),
            BufferDecl::storage(out, 1, BufferAccess::ReadWrite, DataType::U32)
                .with_count(dim_words),
        ],
        [256, 1, 1],
        vec![Node::Region {
            generator: Ident::from(BUNDLE_OP_ID),
            source_region: None,
            body: Arc::new(body),
        }],
    )
}
/// CPU reference for XOR binding. Returns a new vector with `a[i] ^ b[i]` for every index
/// shared by both inputs; the longer input is truncated to the shorter one's length.
#[must_use]
pub fn xor_bind_cpu(a: &[u32], b: &[u32]) -> Vec<u32> {
    let mut bound: Vec<u32> = Vec::new();
    xor_bind_cpu_into(a, b, &mut bound);
    bound
}
/// CPU reference for XOR binding that writes into a caller-provided buffer.
///
/// `out` is cleared first, then receives `a[i] ^ b[i]` for every index shared by both
/// inputs. The longer input is truncated to the shorter one's length. The existing
/// allocation of `out` is reused when its capacity suffices.
pub fn xor_bind_cpu_into(a: &[u32], b: &[u32], out: &mut Vec<u32>) {
    let shared = a.len().min(b.len());
    let (lhs, rhs) = (&a[..shared], &b[..shared]);
    out.clear();
    out.reserve(shared);
    for (x, y) in lhs.iter().zip(rhs) {
        out.push(x ^ y);
    }
}
/// CPU reference for majority bundling. Returns a new vector in which each bit is set when
/// strictly more than half of `hvs` have it set, so ties resolve to 0. Inputs are truncated
/// to the shortest vector. An empty `hvs`, or one with an empty vector, yields an empty result.
#[must_use]
pub fn majority_bundle_cpu(hvs: &[Vec<u32>]) -> Vec<u32> {
    let mut bundled: Vec<u32> = Vec::new();
    majority_bundle_cpu_into(hvs, &mut bundled);
    bundled
}
/// CPU reference for majority bundling that writes into a caller-provided buffer.
///
/// `out` is cleared first. If `hvs` is empty, or any vector in it is empty, `out` stays empty.
/// Otherwise `out` gets one word per index shared by all inputs (the shortest length). Bit
/// `b` of word `w` is set when strictly more than `hvs.len() / 2` inputs have that bit set,
/// so ties (even counts split evenly) resolve to 0.
pub fn majority_bundle_cpu_into(hvs: &[Vec<u32>], out: &mut Vec<u32>) {
    out.clear();
    let dim_words = match hvs.iter().map(Vec::len).min() {
        None | Some(0) => return,
        Some(len) => len,
    };
    let threshold = hvs.len() / 2;
    out.extend((0..dim_words).map(|w| {
        (0..32u32).fold(0u32, |word, bit| {
            let votes = hvs.iter().filter(|hv| ((hv[w] >> bit) & 1) == 1).count();
            if votes > threshold {
                word | (1u32 << bit)
            } else {
                word
            }
        })
    }));
}
/// Similarity of two packed binary hypervectors, derived from their Hamming distance.
///
/// Only the words shared by both inputs are compared; the longer input is truncated. The
/// result is `1.0 - 2.0 * differing_bits / compared_bits`:
/// - `1.0` for identical inputs, and also when no words are shared.
/// - `-1.0` for bitwise complements.
/// - `0.0` when exactly half of the compared bits differ.
#[must_use]
pub fn hamming_similarity(a: &[u32], b: &[u32]) -> f32 {
    let shared = a.len().min(b.len());
    if shared == 0 {
        return 1.0;
    }
    let (lhs, rhs) = (&a[..shared], &b[..shared]);
    let mut differing_bits: u32 = 0;
    for (x, y) in lhs.iter().zip(rhs) {
        differing_bits += (x ^ y).count_ones();
    }
    let total_bits = (shared * 32) as f32;
    1.0 - 2.0 * (differing_bits as f32) / total_bits
}
#[cfg(test)]
mod tests {
    use super::*;

    // --- CPU reference: XOR bind ---

    #[test]
    fn xor_bind_self_cancels() {
        let a = vec![0xDEAD_BEEFu32, 0x0BAD_F00D];
        let b = vec![0x1234_5678, 0x90AB_CDEF];
        let bound = xor_bind_cpu(&a, &b);
        let unbound = xor_bind_cpu(&bound, &b);
        assert_eq!(unbound, a);
    }

    #[test]
    fn xor_bind_zero_is_identity() {
        let a = vec![0x1234, 0x5678, 0xABCD];
        let zero = vec![0u32; a.len()];
        assert_eq!(xor_bind_cpu(&a, &zero), a);
    }

    #[test]
    fn xor_bind_cpu_into_reuses_output() {
        let a = vec![0x1234, 0x5678, 0xABCD];
        let b = vec![0xFFFF, 0x0000, 0x1111];
        let mut out = Vec::with_capacity(8);
        let ptr = out.as_ptr();
        xor_bind_cpu_into(&a, &b, &mut out);
        assert_eq!(out, vec![0xEDCB, 0x5678, 0xBADC]);
        // Capacity was sufficient, so the original allocation must be reused.
        assert_eq!(out.as_ptr(), ptr);
    }

    #[test]
    fn xor_bind_cpu_truncates_mismatched_inputs() {
        let a = vec![0x1234, 0x5678, 0xABCD];
        let b = vec![0xFFFF];
        assert_eq!(xor_bind_cpu(&a, &b), vec![0xEDCB]);
    }

    // --- CPU reference: majority bundle ---

    #[test]
    fn majority_bundle_three_vectors() {
        let hvs = vec![vec![0b001], vec![0b001], vec![0b010]];
        let out = majority_bundle_cpu(&hvs);
        assert_eq!(out, vec![0b001]);
    }

    #[test]
    fn majority_bundle_unanimous() {
        let hvs = vec![vec![0xFF], vec![0xFF], vec![0xFF]];
        let out = majority_bundle_cpu(&hvs);
        assert_eq!(out, vec![0xFF]);
    }

    #[test]
    fn majority_bundle_keeps_high_bit() {
        // Bit 31 is the last position of the 0..32 bit loop; it must be voted like any other.
        let hvs = vec![vec![1u32 << 31], vec![1u32 << 31], vec![0]];
        assert_eq!(majority_bundle_cpu(&hvs), vec![1u32 << 31]);
    }

    #[test]
    fn majority_bundle_cpu_into_reuses_output() {
        let hvs = vec![vec![0b001], vec![0b001], vec![0b010]];
        let mut out = Vec::with_capacity(8);
        let ptr = out.as_ptr();
        majority_bundle_cpu_into(&hvs, &mut out);
        assert_eq!(out, vec![0b001]);
        assert_eq!(out.as_ptr(), ptr);
    }

    #[test]
    fn majority_bundle_tie_rounds_to_zero() {
        // With an even number of inputs, an exact split is not a strict majority.
        let hvs = vec![vec![0b1], vec![0b0]];
        let out = majority_bundle_cpu(&hvs);
        assert_eq!(out, vec![0b0]);
    }

    #[test]
    fn majority_bundle_cpu_handles_empty_and_mismatched_inputs() {
        let empty: Vec<Vec<u32>> = Vec::new();
        assert!(majority_bundle_cpu(&empty).is_empty());
        let hvs = vec![vec![0b001, 0b111], vec![0b001]];
        assert_eq!(majority_bundle_cpu(&hvs), vec![0b001]);
    }

    // --- Hamming similarity ---

    #[test]
    fn hamming_similarity_self_is_one() {
        let a = vec![0xDEAD_BEEFu32; 8];
        assert!((hamming_similarity(&a, &a) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn hamming_similarity_complement_is_minus_one() {
        let a = vec![0xFFFF_FFFFu32; 4];
        let b = vec![0x0000_0000u32; 4];
        assert!((hamming_similarity(&a, &b) - (-1.0)).abs() < 1e-6);
    }

    #[test]
    fn hamming_similarity_half_flipped_is_zero() {
        let a = vec![0xFFFF_0000u32];
        let b = vec![0u32];
        assert!(hamming_similarity(&a, &b).abs() < 1e-6);
    }

    #[test]
    fn hamming_similarity_handles_empty_and_mismatched_inputs() {
        assert_eq!(hamming_similarity(&[], &[]), 1.0);
        let a = vec![0xFFFF_FFFFu32, 0];
        let b = vec![0];
        assert!((hamming_similarity(&a, &b) - (-1.0)).abs() < 1e-6);
    }

    // --- IR program builders ---

    #[test]
    fn ir_program_xor_bind_buffer_layout() {
        let p = hypervector_xor_bind("a", "b", "out", 64);
        assert_eq!(p.workgroup_size, [256, 1, 1]);
        let names: Vec<&str> = p.buffers.iter().map(|b| b.name()).collect();
        assert_eq!(names, vec!["a", "b", "out"]);
        for buf in p.buffers.iter() {
            assert_eq!(buf.count(), 64);
        }
    }

    #[test]
    fn ir_program_xor_bind_zero_dim_is_trap() {
        let p = hypervector_xor_bind("a", "b", "out", 0);
        assert_eq!(p.buffers.len(), 1);
        assert_eq!(p.buffers[0].name(), "out");
    }

    #[test]
    fn ir_program_bundle_buffer_layout() {
        let p = hypervector_majority_bundle("stack", "out", 8, 5);
        assert_eq!(p.workgroup_size, [256, 1, 1]);
        let names: Vec<&str> = p.buffers.iter().map(|b| b.name()).collect();
        assert_eq!(names, vec!["stack", "out"]);
        assert_eq!(p.buffers[0].count(), 5 * 8);
        assert_eq!(p.buffers[1].count(), 8);
    }

    #[test]
    fn ir_program_bundle_zero_dim_is_trap() {
        let p = hypervector_majority_bundle("stack", "out", 0, 5);
        assert_eq!(p.buffers.len(), 1);
        assert_eq!(p.buffers[0].name(), "out");
    }

    #[test]
    fn ir_program_bundle_zero_k_is_trap() {
        let p = hypervector_majority_bundle("stack", "out", 8, 0);
        assert_eq!(p.buffers.len(), 1);
        assert_eq!(p.buffers[0].name(), "out");
    }

    #[test]
    fn standard_dim_constants() {
        assert_eq!(STANDARD_DIM_BITS, STANDARD_DIM_WORDS * 32);
        const _: () = assert!(STANDARD_DIM_BITS >= 8192);
    }
}