use std::sync::Arc;
use vyre_foundation::ir::model::expr::Ident;
use vyre_foundation::ir::{BufferAccess, BufferDecl, DataType, Expr, Node, Program};
/// Identifier of the 1-D convolution primitive. It is used as the `generator`
/// of the region built by [`conv1d_node`] and as the op id passed to the
/// harness registration further down in this file.
pub const OP_ID: &str = "vyre-primitives::math::conv1d";
/// Largest kernel radius the host-side helpers accept. [`conv1d_program`],
/// [`gaussian_weights`] and [`pack_params`] all clamp their `radius` argument
/// to this value with `radius.min(MAX_RADIUS)`.
pub const MAX_RADIUS: u32 = 64;
/// Builds an IR expression that picks between `a` and `b` using the
/// comparison `a < b`: `select(a < b, a, b)`.
///
/// Both operands are cloned once because they appear in the comparison as well
/// as in the selected branches.
fn expr_min(a: Expr, b: Expr) -> Expr {
    let a_is_smaller = Expr::lt(a.clone(), b.clone());
    Expr::select(a_is_smaller, a, b)
}
#[must_use]
pub fn conv1d_node(input: &str, output: &str, weights: &str, params: &str) -> Node {
Node::Region {
generator: Ident::from(OP_ID),
source_region: None,
body: Arc::new(vec![
Node::let_bind("count", Expr::load(params, Expr::u32(0))),
Node::let_bind("stride", Expr::load(params, Expr::u32(1))),
Node::let_bind("radius", Expr::load(params, Expr::u32(2))),
Node::let_bind("idx", Expr::gid_x()),
Node::if_then(
Expr::lt(Expr::var("idx"), Expr::var("count")),
vec![
Node::let_bind(
"diameter",
Expr::add(Expr::mul(Expr::var("radius"), Expr::u32(2)), Expr::u32(1)),
),
Node::let_bind("acc", Expr::u32(0)),
Node::loop_for(
"k",
Expr::u32(0),
Expr::var("diameter"),
vec![
Node::let_bind(
"src_idx",
Expr::select(
Expr::ge(Expr::var("k"), Expr::var("radius")),
expr_min(
Expr::add(
Expr::var("idx"),
Expr::mul(
Expr::sub(Expr::var("k"), Expr::var("radius")),
Expr::var("stride"),
),
),
Expr::sub(Expr::var("count"), Expr::u32(1)),
),
Expr::select(
Expr::ge(
Expr::var("idx"),
Expr::mul(
Expr::sub(Expr::var("radius"), Expr::var("k")),
Expr::var("stride"),
),
),
Expr::sub(
Expr::var("idx"),
Expr::mul(
Expr::sub(Expr::var("radius"), Expr::var("k")),
Expr::var("stride"),
),
),
Expr::u32(0),
),
),
),
Node::let_bind("val", Expr::load(input, Expr::var("src_idx"))),
Node::let_bind("w", Expr::load(weights, Expr::var("k"))),
Node::assign(
"acc",
Expr::add(
Expr::var("acc"),
Expr::mul(Expr::var("val"), Expr::var("w")),
),
),
],
),
Node::store(output, Expr::var("idx"), Expr::var("acc")),
],
),
]),
}
}
/// Wraps [`conv1d_node`] in a complete [`Program`] with its four storage
/// buffers declared.
///
/// * `count`  - element count of the `input` and `output` buffers.
/// * `radius` - kernel radius; clamped to [`MAX_RADIUS`]. The `weights` buffer
///   is sized to `2 * radius + 1` elements after clamping.
///
/// Buffers are bound as `input` = 0 (read-only), `output` = 1 (read-write),
/// `weights` = 2 (read-only) and `params` = 3 (read-only, 4 elements), all of
/// type `U32`.
#[must_use]
pub fn conv1d_program(count: u32, radius: u32) -> Program {
    let tap_count = 2 * radius.min(MAX_RADIUS) + 1;

    let input_decl =
        BufferDecl::storage("input", 0, BufferAccess::ReadOnly, DataType::U32).with_count(count);
    let output_decl =
        BufferDecl::storage("output", 1, BufferAccess::ReadWrite, DataType::U32).with_count(count);
    let weights_decl = BufferDecl::storage("weights", 2, BufferAccess::ReadOnly, DataType::U32)
        .with_count(tap_count);
    let params_decl =
        BufferDecl::storage("params", 3, BufferAccess::ReadOnly, DataType::U32).with_count(4);

    let buffers = vec![input_decl, output_decl, weights_decl, params_decl];
    let entry = vec![conv1d_node("input", "output", "weights", "params")];

    // NOTE(review): `[256, 1, 1]` is presumably the workgroup size - confirm
    // against `Program::wrapped`.
    Program::wrapped(buffers, [256, 1, 1], entry)
}
/// Computes a normalised Gaussian kernel as integer weights scaled by 65536.
///
/// * `radius` - kernel radius, clamped to [`MAX_RADIUS`]. The result always
///   has `2 * radius + 1` entries (after clamping).
/// * `sigma`  - standard deviation of the Gaussian, in taps.
///
/// Each tap is `exp(-x^2 / (2 * sigma^2))` for `x` in `-radius..=radius`,
/// divided by the sum of all taps, multiplied by 65536 and rounded to the
/// nearest integer. Taps are rounded independently, so their sum can differ
/// from 65536 by a few units.
///
/// A `sigma` of zero (or NaN) has no usable Gaussian: evaluating the formula
/// produces NaN for the centre tap and would collapse every weight to 0. In
/// that case the function returns a unit impulse instead - the centre weight
/// is 65536 and all others are 0 - which is the limit of the kernel as `sigma`
/// approaches zero.
#[must_use]
pub fn gaussian_weights(radius: u32, sigma: f32) -> Vec<u32> {
    // Fixed-point scale applied to the normalised weights.
    const SCALE: f64 = 65536.0;

    let clamped = radius.min(MAX_RADIUS);
    let center = clamped as usize;
    let diameter = 2 * center + 1;

    // Degenerate sigma: avoid the 0/0 in the exponent and emit an impulse.
    if sigma == 0.0 || sigma.is_nan() {
        let mut impulse = vec![0u32; diameter];
        impulse[center] = 65_536;
        return impulse;
    }

    let sigma = f64::from(sigma);
    let two_sigma_sq = 2.0 * sigma * sigma;
    let center_f = f64::from(clamped);

    // Unnormalised Gaussian samples, evaluated in f64 for accuracy.
    let raw: Vec<f64> = (0..diameter)
        .map(|i| {
            let x = i as f64 - center_f;
            (-x * x / two_sigma_sq).exp()
        })
        .collect();
    let total: f64 = raw.iter().sum();

    raw.iter()
        .map(|w| ((w / total) * SCALE).round() as u32)
        .collect()
}
/// Packs the kernel parameters into the four-word layout of the `params`
/// buffer: `[count, stride, radius, 0]`.
///
/// `radius` is clamped to [`MAX_RADIUS`]. The fourth word is always `0`;
/// [`conv1d_node`] only reads the first three.
#[must_use]
pub fn pack_params(count: u32, stride: u32, radius: u32) -> Vec<u32> {
    let clamped_radius = radius.min(MAX_RADIUS);
    Vec::from([count, stride, clamped_radius, 0])
}
// Registers the conv1d op with the harness when the `inventory-registry`
// feature is enabled. The entry carries the op id, a closure building an
// 8-element / radius-1 program, a closure producing the fixture buffers and a
// closure producing the expected output bytes.
#[cfg(feature = "inventory-registry")]
inventory::submit! {
    crate::harness::OpEntry::new(
        OP_ID,
        || conv1d_program(8, 1),
        Some(|| {
            let input: Vec<u32> = vec![100, 200, 300, 400, 500, 600, 700, 800];
            let params = pack_params(8, 1, 1);
            // Taps 0.25 / 0.5 / 0.25 scaled by 65536.
            let weights: Vec<u32> = vec![16384, 32768, 16384];
            let to_bytes = |v: &[u32]| crate::wire::pack_u32_slice(v);
            // Listed in the same order as the bindings declared by
            // `conv1d_program`: input, output, weights, params.
            vec![vec![
                to_bytes(&input),
                // Zeroed output buffer: 8 `u32` elements = 32 bytes.
                vec![0u8; 32],
                to_bytes(&weights),
                to_bytes(&params),
            ]]
        }),
        Some(|| {
            let to_bytes = |v: &[u32]| crate::wire::pack_u32_slice(v);
            // Same values as the `cpu_conv1d_averaging_kernel_matches_inventory`
            // unit test in this file.
            vec![vec![to_bytes(&[
                8_192_000, 13_107_200, 19_660_800, 26_214_400, 32_768_000, 39_321_600,
                45_875_200, 50_790_400,
            ])]]
        }),
    )
}
/// CPU reference for the conv1d primitive that returns a freshly allocated
/// result.
///
/// Delegates to [`cpu_conv1d_into`] with an empty vector; see that function
/// for how `input`, `weights` and `stride` are interpreted.
#[must_use]
#[cfg(any(test, feature = "cpu-parity"))]
pub fn cpu_conv1d(input: &[u32], weights: &[u32], stride: u32) -> Vec<u32> {
    let mut convolved: Vec<u32> = Vec::new();
    cpu_conv1d_into(input, weights, stride, &mut convolved);
    convolved
}
/// CPU reference for the conv1d primitive, writing into a caller-owned buffer.
///
/// * `input`   - source samples.
/// * `weights` - kernel taps; the radius is taken as `weights.len() / 2`.
/// * `stride`  - distance between neighbouring taps, in elements.
/// * `output`  - cleared, then filled with one value per input element. Its
///   existing allocation is reused when it is large enough.
///
/// For each element `idx` and tap `k` the source index is
/// `min(idx + (k - radius) * stride, len - 1)` for `k >= radius`, and
/// `idx - (radius - k) * stride` pinned at `0` otherwise. Products and the
/// running sum use wrapping `u32` arithmetic.
///
/// The tap-offset arithmetic is wrapping as well. With the plain `*` / `+`
/// operators a large `stride` panicked in debug builds and wrapped silently in
/// release builds; it now wraps in both.
/// NOTE(review): wrapping is presumably what the kernel's `u32` arithmetic
/// does too - confirm against the backend.
///
/// Indices are computed in `u32`; inputs longer than `u32::MAX` elements are
/// not supported.
#[cfg(any(test, feature = "cpu-parity"))]
pub fn cpu_conv1d_into(input: &[u32], weights: &[u32], stride: u32, output: &mut Vec<u32>) {
    output.clear();
    let count = input.len();
    if count == 0 {
        return;
    }
    let radius = weights.len() / 2;
    // `count >= 1` here, so the subtraction cannot underflow.
    let last = count as u32 - 1;
    output.reserve(count);

    for idx in 0..count {
        let idx = idx as u32;
        let acc = weights.iter().enumerate().fold(0u32, |acc, (k, &w)| {
            let src_idx = if k >= radius {
                // Centre and right-hand taps: step forward, cap at the end.
                let offset = ((k - radius) as u32).wrapping_mul(stride);
                idx.wrapping_add(offset).min(last)
            } else {
                // Left-hand taps: step backward, pin at the first element.
                let offset = ((radius - k) as u32).wrapping_mul(stride);
                idx.saturating_sub(offset)
            };
            acc.wrapping_add(input[src_idx as usize].wrapping_mul(w))
        });
        output.push(acc);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A kernel whose only non-zero tap is the centre one (65536) scales every
    /// sample by 65536.
    #[test]
    fn cpu_conv1d_identity_kernel() {
        let samples: Vec<u32> = vec![10, 20, 30, 40, 50];
        let kernel: Vec<u32> = vec![0, 65536, 0];
        let actual = cpu_conv1d(&samples, &kernel, 1);
        let scaled: Vec<u32> = samples.iter().map(|&sample| sample * 65536).collect();
        assert_eq!(actual, scaled);
    }

    /// The CPU reference reproduces the expected values used by the harness
    /// registration in this file.
    #[test]
    fn cpu_conv1d_averaging_kernel_matches_inventory() {
        let samples = vec![100u32, 200, 300, 400, 500, 600, 700, 800];
        let kernel = vec![16384u32, 32768, 16384];
        let actual = cpu_conv1d(&samples, &kernel, 1);
        let wanted = vec![
            8_192_000, 13_107_200, 19_660_800, 26_214_400, 32_768_000, 39_321_600, 45_875_200,
            50_790_400,
        ];
        assert_eq!(actual, wanted);
    }

    /// Empty input produces empty output.
    #[test]
    fn cpu_conv1d_empty() {
        let actual = cpu_conv1d(&[], &[65536], 1);
        assert!(actual.is_empty());
    }

    /// With a single sample every tap reads that sample, so the weights sum.
    #[test]
    fn cpu_conv1d_single_element() {
        let actual = cpu_conv1d(&[42], &[16384, 32768, 16384], 1);
        assert_eq!(actual, vec![42 * 65536]);
    }

    /// The `_into` variant drops stale contents and keeps the allocation.
    #[test]
    fn cpu_conv1d_into_reuses_output_and_removes_stale_tail() {
        let mut buffer = Vec::with_capacity(8);
        buffer.extend_from_slice(&[1, 2, 3, 4, 5, 6]);
        let initial_capacity = buffer.capacity();

        cpu_conv1d_into(&[10, 20, 30], &[0, 65536, 0], 1, &mut buffer);
        assert_eq!(buffer, vec![655_360, 1_310_720, 1_966_080]);
        assert_eq!(buffer.capacity(), initial_capacity);

        cpu_conv1d_into(&[7], &[65536], 1, &mut buffer);
        assert_eq!(buffer, vec![458_752]);
        assert_eq!(buffer.capacity(), initial_capacity);
    }
}