use std::sync::Arc;
use vyre_foundation::ir::model::expr::Ident;
use vyre_foundation::ir::{BufferAccess, BufferDecl, DataType, Expr, Node, Program};
/// Operation identifier for [`count_sketch_update`].
///
/// It is used as the `generator` of the region emitted by the builder and is
/// passed to `crate::invalid_output_program` when the builder rejects its
/// parameters.
pub const UPDATE_OP_ID: &str = "vyre-primitives::hash::count_sketch_update";
/// Builds the IR program that applies one signed update to a `d x w`
/// count-sketch table stored row-major in the buffer named `table`.
///
/// For every invocation index `t` on axis 0 with `t < d`, the program loads
/// `col = hashes[t]` and, when `col < w`, adds `signs[t]` to
/// `table[t * w + col]`. Rows whose hash is out of range are left untouched,
/// which matches what [`count_sketch_update_cpu`] does for such rows and keeps
/// every invocation inside its own row.
///
/// # Parameters
/// * `table`  - name of the read-write `U32` buffer holding `d * w` cells.
/// * `hashes` - name of the read-only `U32` buffer holding `d` column indices.
/// * `signs`  - name of the read-only `U32` buffer holding `d` sign words.
/// * `d`      - number of rows (hash functions); must be non-zero.
/// * `w`      - number of columns per row; must be non-zero.
///
/// # Returns
/// The wrapped program with a `[256, 1, 1]` workgroup size, or the program
/// produced by `crate::invalid_output_program` when `d == 0`, `w == 0`, or
/// `d * w` does not fit in `u32`.
#[must_use]
pub fn count_sketch_update(table: &str, hashes: &str, signs: &str, d: u32, w: u32) -> Program {
    if d == 0 {
        return crate::invalid_output_program(
            UPDATE_OP_ID,
            table,
            DataType::U32,
            format!("Fix: count_sketch_update requires d > 0, got {d}."),
        );
    }
    if w == 0 {
        return crate::invalid_output_program(
            UPDATE_OP_ID,
            table,
            DataType::U32,
            format!("Fix: count_sketch_update requires w > 0, got {w}."),
        );
    }
    // The table element count and every cell address are u32 values, so the
    // full `d * w` product has to be representable; a plain `d * w` would
    // panic in debug builds and silently wrap in release builds.
    let cells = match d.checked_mul(w) {
        Some(cells) => cells,
        None => {
            return crate::invalid_output_program(
                UPDATE_OP_ID,
                table,
                DataType::U32,
                format!(
                    "Fix: count_sketch_update requires d * w to fit in u32, got d={d}, w={w}."
                ),
            );
        }
    };

    let t = Expr::InvocationId { axis: 0 };
    let body = vec![Node::if_then(
        // Only the first `d` invocations own a row.
        Expr::lt(t.clone(), Expr::u32(d)),
        vec![
            Node::let_bind("col", Expr::load(hashes, t.clone())),
            // Skip rows whose hash falls outside the row: without this guard
            // the computed address would land in a different row (or past the
            // end of the table) and collide with another invocation's cell.
            Node::if_then(
                Expr::lt(Expr::var("col"), Expr::u32(w)),
                vec![
                    // NOTE(review): signs are loaded as U32 words and combined
                    // with `Expr::add`; presumably a sign of -1 is encoded as
                    // 0xFFFF_FFFF and the add wraps - confirm against the IR
                    // arithmetic semantics.
                    Node::let_bind("sgn", Expr::load(signs, t.clone())),
                    Node::let_bind("row_base", Expr::mul(t, Expr::u32(w))),
                    Node::let_bind("addr", Expr::add(Expr::var("row_base"), Expr::var("col"))),
                    Node::store(
                        table,
                        Expr::var("addr"),
                        Expr::add(Expr::load(table, Expr::var("addr")), Expr::var("sgn")),
                    ),
                ],
            ),
        ],
    )];

    Program::wrapped(
        vec![
            BufferDecl::storage(table, 0, BufferAccess::ReadWrite, DataType::U32)
                .with_count(cells),
            BufferDecl::storage(hashes, 1, BufferAccess::ReadOnly, DataType::U32).with_count(d),
            BufferDecl::storage(signs, 2, BufferAccess::ReadOnly, DataType::U32).with_count(d),
        ],
        [256, 1, 1],
        vec![Node::Region {
            generator: Ident::from(UPDATE_OP_ID),
            source_region: None,
            body: Arc::new(body),
        }],
    )
}
/// CPU reference for the count-sketch update: adds `signs[r]` to
/// `table[r * w + hashes[r]]` for every row `r` in `0..d`.
///
/// Cells are stored as `u32` words; the signed update is applied with
/// wrapping arithmetic, so a cell holding `0` updated with `-1` becomes
/// `0xFFFF_FFFF` and an update can never panic on overflow.
///
/// The call is a silent no-op when the inputs are malformed:
/// * `w == 0`,
/// * `d * w` does not fit in `u32`,
/// * `table.len() != d * w`,
/// * `hashes` or `signs` holds fewer than `d` entries.
///
/// A row whose hash is `>= w` is skipped; the remaining rows are still
/// updated.
pub fn count_sketch_update_cpu(table: &mut [u32], hashes: &[u32], signs: &[i32], d: u32, w: u32) {
    let rows = d as usize;
    let width = w as usize;
    // Reject dimension overflow outright instead of saturating: a saturated
    // product could match a table of length `u32::MAX` whose real row layout
    // is larger, which would then be indexed out of bounds.
    let cells = match d.checked_mul(w) {
        Some(cells) => cells as usize,
        None => return,
    };
    if width == 0 || table.len() != cells || hashes.len() < rows || signs.len() < rows {
        return;
    }
    // `table.len() == rows * width`, so `chunks_exact_mut` yields exactly one
    // chunk per row and the zip stops after `rows` items.
    for ((row, &hash), &sign) in table.chunks_exact_mut(width).zip(hashes).zip(signs) {
        if let Some(cell) = row.get_mut(hash as usize) {
            // Two's-complement: adding `sign as u32` with wrap-around equals
            // a wrapping signed add on the cell reinterpreted as `i32`.
            *cell = cell.wrapping_add(sign as u32);
        }
    }
}
/// Convenience form of [`count_sketch_query_cpu_into`] that supplies its own
/// temporary scratch vector.
///
/// Takes the same `table`, `hashes`, `signs`, `d` and `w` arguments and
/// returns whatever `count_sketch_query_cpu_into` returns for them; the
/// scratch vector is dropped when the call finishes.
#[must_use]
pub fn count_sketch_query_cpu(table: &[u32], hashes: &[u32], signs: &[i32], d: u32, w: u32) -> i32 {
    let mut scratch: Vec<i32> = Vec::new();
    let estimate = count_sketch_query_cpu_into(table, hashes, signs, d, w, &mut scratch);
    estimate
}
/// CPU reference for the count-sketch point query, using a caller-provided
/// scratch buffer.
///
/// For every row `r` in `0..d` the per-row estimate is
/// `table[r * w + hashes[r]]` (reinterpreted as `i32`) multiplied by
/// `signs[r]`, using wrapping arithmetic so extreme cell values cannot panic.
/// The estimates are sorted and the element at index `len / 2` is returned
/// (the upper median when `d` is even).
///
/// # Parameters
/// * `table`     - row-major sketch of exactly `d * w` cells.
/// * `hashes`    - at least `d` column indices, each expected to be `< w`.
/// * `signs`     - at least `d` sign values.
/// * `d`, `w`    - sketch dimensions.
/// * `estimates` - scratch buffer; it is cleared first. On success it holds
///   the `d` sorted per-row estimates, on rejection it is left empty.
///
/// # Returns
/// The selected estimate, or `0` when the inputs are rejected: `d == 0`,
/// `w == 0`, `d * w` does not fit in `u32`, `table.len() != d * w`, `hashes`
/// or `signs` shorter than `d`, or any of the first `d` hashes is `>= w`.
pub fn count_sketch_query_cpu_into(
    table: &[u32],
    hashes: &[u32],
    signs: &[i32],
    d: u32,
    w: u32,
    estimates: &mut Vec<i32>,
) -> i32 {
    estimates.clear();

    let rows = d as usize;
    let width = w as usize;
    // Reject dimension overflow outright instead of saturating, so a table of
    // length `u32::MAX` can never be mistaken for a larger `d x w` layout.
    let cells = match d.checked_mul(w) {
        Some(cells) => cells as usize,
        None => return 0,
    };
    if rows == 0
        || width == 0
        || table.len() != cells
        || hashes.len() < rows
        || signs.len() < rows
    {
        return 0;
    }

    // Validate every column before touching the scratch buffer, so a rejected
    // query neither allocates nor leaves partial estimates behind.
    let hashes = &hashes[..rows];
    if hashes.iter().any(|&hash| hash >= w) {
        return 0;
    }

    // `table.len() == rows * width`, so there is exactly one chunk per row.
    estimates.extend(
        table
            .chunks_exact(width)
            .zip(hashes)
            .zip(signs)
            .map(|((row, &hash), &sign)| (row[hash as usize] as i32).wrapping_mul(sign)),
    );
    estimates.sort_unstable();
    // `rows >= 1` here, so the index is always in bounds.
    estimates[estimates.len() / 2]
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Allocates an all-zero sketch table with `d * w` cells.
    fn zeroed_table(d: u32, w: u32) -> Vec<u32> {
        vec![0u32; (d * w) as usize]
    }

    /// Applies the same CPU update `times` times in a row.
    fn insert_times(
        table: &mut [u32],
        hashes: &[u32],
        signs: &[i32],
        d: u32,
        w: u32,
        times: usize,
    ) {
        (0..times).for_each(|_| count_sketch_update_cpu(table, hashes, signs, d, w));
    }

    // A single insert must be read back as a count of one.
    #[test]
    fn cpu_single_item_round_trip() {
        let (d, w) = (5u32, 16u32);
        let hashes = [3u32, 11, 2, 7, 14];
        let signs = [1i32, -1, 1, -1, 1];
        let mut table = zeroed_table(d, w);
        insert_times(&mut table, &hashes, &signs, d, w, 1);
        assert_eq!(count_sketch_query_cpu(&table, &hashes, &signs, d, w), 1);
    }

    // Seven identical inserts must be read back as a count of seven.
    #[test]
    fn cpu_repeated_inserts_count() {
        let (d, w) = (5u32, 16u32);
        let hashes = [3u32, 11, 2, 7, 14];
        let signs = [1i32, -1, 1, -1, 1];
        let mut table = zeroed_table(d, w);
        insert_times(&mut table, &hashes, &signs, d, w, 7);
        assert_eq!(count_sketch_query_cpu(&table, &hashes, &signs, d, w), 7);
    }

    // Querying an item that shares no cell with the inserted one yields zero.
    #[test]
    fn cpu_unrelated_query_returns_zero_or_small() {
        let (d, w) = (5u32, 16u32);
        let mut table = zeroed_table(d, w);

        let inserted_hashes = [3u32, 11, 2, 7, 14];
        let inserted_signs = [1i32, -1, 1, -1, 1];
        insert_times(&mut table, &inserted_hashes, &inserted_signs, d, w, 1);

        let other_hashes = [5u32, 9, 0, 4, 12];
        let other_signs = [-1i32, 1, -1, 1, -1];
        assert_eq!(
            count_sketch_query_cpu(&table, &other_hashes, &other_signs, d, w),
            0
        );
    }

    // Two items with disjoint cells keep independent counts.
    #[test]
    fn cpu_two_items_independent_estimates() {
        let (d, w) = (7u32, 32u32);
        let mut table = zeroed_table(d, w);

        let first_hashes = [1u32, 2, 3, 4, 5, 6, 7];
        let first_signs = [1i32, 1, -1, 1, -1, 1, 1];
        let second_hashes = [10u32, 20, 30, 11, 21, 0, 25];
        let second_signs = [-1i32, 1, 1, -1, 1, 1, -1];

        insert_times(&mut table, &first_hashes, &first_signs, d, w, 3);
        insert_times(&mut table, &second_hashes, &second_signs, d, w, 5);

        assert_eq!(
            count_sketch_query_cpu(&table, &first_hashes, &first_signs, d, w),
            3
        );
        assert_eq!(
            count_sketch_query_cpu(&table, &second_hashes, &second_signs, d, w),
            5
        );
    }

    // Malformed inputs must neither panic, modify the table, nor reallocate
    // the caller's scratch buffer.
    #[test]
    fn cpu_helpers_reject_malformed_inputs_without_panicking() {
        let mut table = vec![9u32; 4];
        count_sketch_update_cpu(&mut table, &[9], &[1], 2, 2);
        assert_eq!(table, vec![9u32; 4]);

        let mut scratch = Vec::<i32>::with_capacity(8);
        let allocation = scratch.as_ptr();
        let result = count_sketch_query_cpu_into(&table, &[99, 1], &[1, 1], 2, 2, &mut scratch);
        assert_eq!(result, 0);
        assert_eq!(scratch.as_ptr(), allocation);
    }

    // The emitted program declares table, hashes and signs in that order with
    // the expected element counts and workgroup size.
    #[test]
    fn ir_program_buffer_layout() {
        let program = count_sketch_update("t", "h", "s", 5, 16);
        assert_eq!(program.workgroup_size, [256, 1, 1]);

        let names: Vec<&str> = program.buffers.iter().map(|b| b.name()).collect();
        assert_eq!(names, vec!["t", "h", "s"]);

        assert_eq!(program.buffers[0].count(), 5 * 16);
        assert_eq!(program.buffers[1].count(), 5);
        assert_eq!(program.buffers[2].count(), 5);
    }

    // A zero row count produces a trapping program.
    #[test]
    fn zero_d_traps() {
        assert!(count_sketch_update("t", "h", "s", 0, 16).stats().trap());
    }

    // A zero column count produces a trapping program.
    #[test]
    fn zero_w_traps() {
        assert!(count_sketch_update("t", "h", "s", 5, 0).stats().trap());
    }
}