use std::sync::Arc;
use vyre_foundation::ir::model::expr::Ident;
use vyre_foundation::ir::{BufferAccess, BufferDecl, DataType, Expr, Node, Program};
/// Operation identifier for this primitive: used as the `generator` tag of the IR
/// region built by `conv1d_node` and as the id of the harness `OpEntry`.
pub const OP_ID: &str = "vyre-primitives::math::conv1d";
/// Largest radius honoured by the host-side helpers (`conv1d_program`,
/// `gaussian_weights`, `pack_params`), which clamp larger values down to it.
/// Note that `conv1d_node` itself does not clamp the radius it reads from `params`.
pub const MAX_RADIUS: u32 = 64;
/// Returns an IR expression that selects the smaller of `lhs` and `rhs`.
///
/// Lowered as `select(lhs < rhs, lhs, rhs)`, so when the operands compare equal
/// the result is `rhs`. Both operands are cloned into the comparison, so each
/// appears twice in the resulting expression tree.
fn expr_min(lhs: Expr, rhs: Expr) -> Expr {
    let lhs_is_smaller = Expr::lt(lhs.clone(), rhs.clone());
    Expr::select(lhs_is_smaller, lhs, rhs)
}
#[must_use]
pub fn conv1d_node(input: &str, output: &str, weights: &str, params: &str) -> Node {
Node::Region {
generator: Ident::from(OP_ID),
source_region: None,
body: Arc::new(vec![
Node::let_bind("count", Expr::load(params, Expr::u32(0))),
Node::let_bind("stride", Expr::load(params, Expr::u32(1))),
Node::let_bind("radius", Expr::load(params, Expr::u32(2))),
Node::let_bind("idx", Expr::gid_x()),
Node::if_then(
Expr::lt(Expr::var("idx"), Expr::var("count")),
vec![
Node::let_bind(
"diameter",
Expr::add(Expr::mul(Expr::var("radius"), Expr::u32(2)), Expr::u32(1)),
),
Node::let_bind("acc", Expr::u32(0)),
Node::loop_for(
"k",
Expr::u32(0),
Expr::var("diameter"),
vec![
Node::let_bind(
"src_idx",
Expr::select(
Expr::ge(Expr::var("k"), Expr::var("radius")),
expr_min(
Expr::add(
Expr::var("idx"),
Expr::mul(
Expr::sub(Expr::var("k"), Expr::var("radius")),
Expr::var("stride"),
),
),
Expr::sub(Expr::var("count"), Expr::u32(1)),
),
Expr::select(
Expr::ge(
Expr::var("idx"),
Expr::mul(
Expr::sub(Expr::var("radius"), Expr::var("k")),
Expr::var("stride"),
),
),
Expr::sub(
Expr::var("idx"),
Expr::mul(
Expr::sub(Expr::var("radius"), Expr::var("k")),
Expr::var("stride"),
),
),
Expr::u32(0),
),
),
),
Node::let_bind("val", Expr::load(input, Expr::var("src_idx"))),
Node::let_bind("w", Expr::load(weights, Expr::var("k"))),
Node::assign(
"acc",
Expr::add(
Expr::var("acc"),
Expr::mul(Expr::var("val"), Expr::var("w")),
),
),
],
),
Node::store(output, Expr::var("idx"), Expr::var("acc")),
],
),
]),
}
}
/// Assembles a complete conv1d [`Program`] around `conv1d_node`.
///
/// Buffers (binding, name, access, element count):
/// - `0` `input`   — read-only,  `count` u32 samples.
/// - `1` `output`  — read-write, `count` u32 results.
/// - `2` `weights` — read-only,  `2 * min(radius, MAX_RADIUS) + 1` u32 taps.
/// - `3` `params`  — read-only,  4 u32 words (see `pack_params`).
///
/// `[256, 1, 1]` is passed as the launch-shape argument of `Program::wrapped`
/// (presumably the workgroup size; confirm).
/// The radius written into `params` must not exceed the clamped radius used
/// here. Otherwise the kernel's tap loop reads past the end of `weights`.
#[must_use]
pub fn conv1d_program(count: u32, radius: u32) -> Program {
    let taps = radius.min(MAX_RADIUS) * 2 + 1;

    let buffers = vec![
        BufferDecl::storage("input", 0, BufferAccess::ReadOnly, DataType::U32).with_count(count),
        BufferDecl::storage("output", 1, BufferAccess::ReadWrite, DataType::U32).with_count(count),
        BufferDecl::storage("weights", 2, BufferAccess::ReadOnly, DataType::U32).with_count(taps),
        BufferDecl::storage("params", 3, BufferAccess::ReadOnly, DataType::U32).with_count(4),
    ];
    let workgroup = [256, 1, 1];
    let entry = conv1d_node("input", "output", "weights", "params");

    Program::wrapped(buffers, workgroup, vec![entry])
}
/// Builds a normalized 1-D Gaussian kernel quantized to 16.16 fixed point.
///
/// The kernel has `2 * min(radius, MAX_RADIUS) + 1` taps. Tap `i` sits at offset
/// `x = i - min(radius, MAX_RADIUS)` and is proportional to `exp(-x² / (2σ²))`.
/// The taps always sum to exactly `65_536` (1.0 in 16.16), so the kernel has
/// unit DC gain: a constant signal convolved with it comes out as that constant
/// times 65 536.
///
/// A `sigma` of zero or NaN has no meaningful Gaussian and yields the identity
/// kernel, with all weight on the centre tap. A negative `sigma` behaves like
/// `|sigma|`. An infinite `sigma` yields a near-uniform box kernel.
#[must_use]
pub fn gaussian_weights(radius: u32, sigma: f32) -> Vec<u32> {
    // 1.0 in 16.16 fixed point.
    const ONE: i64 = 1 << 16;

    let clamped = radius.min(MAX_RADIUS);
    let diameter = (2 * clamped + 1) as usize;
    let center = clamped as usize;
    let s2 = 2.0 * (sigma as f64) * (sigma as f64);

    // Without this guard, sigma == 0 evaluates 0/0 = NaN at the centre tap. The
    // sum then becomes NaN and every weight silently casts to 0.
    if s2.is_nan() || s2 <= 0.0 {
        let mut identity = vec![0u32; diameter];
        identity[center] = ONE as u32;
        return identity;
    }

    let raw: Vec<f64> = (0..diameter)
        .map(|i| {
            let x = i as f64 - clamped as f64;
            (-x * x / s2).exp()
        })
        .collect();
    // The centre tap is exp(0) = 1, so `sum >= 1` and the division below is safe.
    let sum: f64 = raw.iter().sum();

    let mut quantized: Vec<i64> = raw
        .iter()
        .map(|w| ((w / sum) * ONE as f64).round() as i64)
        .collect();

    // Rounding each tap independently can leave the total a few units away from
    // ONE (e.g. radius 1, sigma 1 sums to 65_535). Fold that residual into the
    // centre tap, which is the largest one, so the kernel stays symmetric and
    // sums exactly to ONE.
    let residual = ONE - quantized.iter().sum::<i64>();
    quantized[center] += residual;

    quantized.into_iter().map(|w| w as u32).collect()
}
/// Packs the four-word parameter block `[count, stride, radius, 0]` that
/// `conv1d_node` reads from its `params` buffer.
///
/// `radius` is clamped to `MAX_RADIUS`, the same clamp `conv1d_program` applies
/// when sizing the weights buffer. The fourth word is not read by `conv1d_node`.
/// It keeps the block at the four words `conv1d_program` declares for `params`.
#[must_use]
pub fn pack_params(count: u32, stride: u32, radius: u32) -> Vec<u32> {
    let radius = radius.min(MAX_RADIUS);
    let unused = 0;
    Vec::from([count, stride, radius, unused])
}
// Registers conv1d with the harness: a program builder (8 samples, radius 1),
// one input case, and the expected output bytes for that case.
#[cfg(feature = "inventory-registry")]
inventory::submit! {
    crate::harness::OpEntry::new(
        OP_ID,
        || conv1d_program(8, 1),
        Some(|| {
            let input: Vec<u32> = vec![100, 200, 300, 400, 500, 600, 700, 800];
            let params = pack_params(8, 1, 1);
            // [0.25, 0.5, 0.25] in 16.16 fixed point (sums to 65536).
            let weights: Vec<u32> = vec![16384, 32768, 16384];
            let to_bytes = |v: &[u32]| v.iter().flat_map(|w| w.to_le_bytes()).collect::<Vec<u8>>();
            // One entry per buffer, in conv1d_program's binding order:
            // input (0), output (1, 8 zeroed u32s = 32 bytes), weights (2), params (3).
            vec![vec![
                to_bytes(&input),
                vec![0u8; 32],
                to_bytes(&weights),
                to_bytes(&params),
            ]]
        }),
        Some(|| {
            // Edge-clamped 3-tap smoothing, stored unscaled (65536 × the smoothed
            // sample). For example, index 0 is
            // 100·16384 + 100·32768 + 200·16384 = 8_192_000 = 125 × 65536.
            let to_bytes = |v: &[u32]| v.iter().flat_map(|w| w.to_le_bytes()).collect::<Vec<u8>>();
            vec![vec![to_bytes(&[
                8_192_000, 13_107_200, 19_660_800, 26_214_400, 32_768_000, 39_321_600,
                45_875_200, 50_790_400,
            ])]]
        }),
    )
}