use vyre::ir::{BufferAccess, BufferDecl, DataType, Expr, Node, Program};
use crate::region::wrap_anonymous;
use vyre_primitives::nn::f32_stability::{finite_or, positive_finite_or_min as positive_scale};
const ROUND_OP_ID: &str = "vyre-libs::quant::gptq_round";
const SDCLIP_OP_ID: &str = "vyre-libs::quant::gptq_sdclip";
fn clamp_f32(value: Expr, lo: f32, hi: f32) -> Expr {
let finite = finite_or(value, Expr::f32(lo));
let lower = Expr::select(
Expr::lt(finite.clone(), Expr::f32(lo)),
Expr::f32(lo),
finite,
);
Expr::select(Expr::gt(lower.clone(), Expr::f32(hi)), Expr::f32(hi), lower)
}
#[must_use]
pub fn gptq_round(input: &str, scale: &str, output: &str, n: u32, max_val: f32) -> Program {
let i = Expr::var("i");
let x = finite_or(Expr::load(input, i.clone()), Expr::f32(0.0));
let s = positive_scale(Expr::load(scale, i.clone()));
let divided = Expr::select(
Expr::eq(x.clone(), s.clone()),
Expr::f32(1.0),
Expr::div(x, s),
);
let clamped = clamp_f32(divided, 0.0, max_val);
let body = vec![
Node::let_bind("i", Expr::InvocationId { axis: 0 }),
Node::if_then(
Expr::lt(i.clone(), Expr::u32(n)),
vec![Node::Store {
buffer: output.into(),
index: i,
value: clamped,
}],
),
];
Program::wrapped(
vec![
BufferDecl::storage(input, 0, BufferAccess::ReadOnly, DataType::F32).with_count(n),
BufferDecl::storage(scale, 1, BufferAccess::ReadOnly, DataType::F32).with_count(n),
BufferDecl::output(output, 2, DataType::F32).with_count(n),
],
[64, 1, 1],
vec![wrap_anonymous(ROUND_OP_ID, body)],
)
}
#[must_use]
pub fn gptq_sdclip(input: &str, output: &str, n: u32, k: f32) -> Program {
let i = Expr::var("i");
let x = finite_or(Expr::load(input, i.clone()), Expr::f32(0.0));
let clamped = clamp_f32(x, -k, k);
let body = vec![
Node::let_bind("i", Expr::InvocationId { axis: 0 }),
Node::if_then(
Expr::lt(i.clone(), Expr::u32(n)),
vec![Node::Store {
buffer: output.into(),
index: i,
value: clamped,
}],
),
];
Program::wrapped(
vec![
BufferDecl::storage(input, 0, BufferAccess::ReadOnly, DataType::F32).with_count(n),
BufferDecl::output(output, 1, DataType::F32).with_count(n),
],
[64, 1, 1],
vec![wrap_anonymous(SDCLIP_OP_ID, body)],
)
}
inventory::submit! {
crate::harness::OpEntry {
id: ROUND_OP_ID,
build: || gptq_round("input", "scale", "output", 4, 63.0),
test_inputs: Some(|| {
let to_f32 = |w: &[f32]| vyre_primitives::wire::pack_f32_slice(w);
vec![vec![
to_f32(&[100.0, 200.0, 50.0, 10.0]),
to_f32(&[2.0, 3.0, 1.0, 5.0]),
]]
}),
expected_output: Some(|| {
let out = [50.0_f32, 63.0, 50.0, 2.0];
let bytes = vyre_primitives::wire::pack_f32_slice(&out);
vec![vec![bytes]]
}),
category: Some("nn"),
}
}
inventory::submit! {
crate::harness::OpEntry {
id: SDCLIP_OP_ID,
build: || gptq_sdclip("input", "output", 4, 30.0),
test_inputs: Some(|| {
let to_f32 = |w: &[f32]| vyre_primitives::wire::pack_f32_slice(w);
vec![vec![
to_f32(&[10.0, 50.0, -40.0, 25.0]),
]]
}),
expected_output: Some(|| {
let out = [10.0_f32, 30.0, -30.0, 25.0];
let bytes = vyre_primitives::wire::pack_f32_slice(&out);
vec![vec![bytes]]
}),
category: Some("nn"),
}
}