use vyre::ir::{BufferAccess, BufferDecl, DataType, Expr, Node, Program};
use crate::region::wrap_anonymous;
/// Number of elements in one Q4_K super-block
/// (equal to `Q4_K_BLOCK_SIZE * Q4_K_BLOCKS_PER_SUPER`).
pub const Q4_K_SUPER_BLOCK_SIZE: u32 = 256;
/// Number of elements in one Q4_K block. The Q4_K builders below divide an
/// element index by this value to pick the `scales`/`mins` entry and the group
/// of four packed `u32` words that the element belongs to.
pub const Q4_K_BLOCK_SIZE: u32 = 32;
/// Number of Q4_K blocks that make up one super-block.
pub const Q4_K_BLOCKS_PER_SUPER: u32 = 8;
pub fn q4_k_unpack(
packed: &str,
scales: &str,
mins: &str,
output: &str,
n: u32,
) -> Result<Program, String> {
if n == 0 {
return Err("Fix: q4_k_unpack n=0 is invalid".to_string());
}
let n_blocks = n.div_ceil(Q4_K_BLOCK_SIZE);
let i = Expr::var("i");
let body = vec![
Node::let_bind("i", Expr::InvocationId { axis: 0 }),
Node::if_then(
Expr::lt(i.clone(), Expr::u32(n)),
vec![
Node::let_bind(
"block_idx",
Expr::div(i.clone(), Expr::u32(Q4_K_BLOCK_SIZE)),
),
Node::let_bind(
"within_block",
Expr::rem(i.clone(), Expr::u32(Q4_K_BLOCK_SIZE)),
),
Node::let_bind(
"byte_idx",
Expr::div(Expr::var("within_block"), Expr::u32(2)),
),
Node::let_bind(
"shift",
Expr::mul(
Expr::rem(Expr::var("within_block"), Expr::u32(2)),
Expr::u32(4),
),
),
Node::let_bind(
"word_idx",
Expr::add(
Expr::mul(Expr::var("block_idx"), Expr::u32(4)),
Expr::div(Expr::var("byte_idx"), Expr::u32(4)),
),
),
Node::let_bind(
"word_shift",
Expr::mul(Expr::rem(Expr::var("byte_idx"), Expr::u32(4)), Expr::u32(8)),
),
Node::let_bind("packed_word", Expr::load(packed, Expr::var("word_idx"))),
Node::let_bind(
"byte_val",
Expr::bitand(
Expr::shr(Expr::var("packed_word"), Expr::var("word_shift")),
Expr::u32(0xFF),
),
),
Node::let_bind(
"nibble",
Expr::bitand(
Expr::shr(Expr::var("byte_val"), Expr::var("shift")),
Expr::u32(0xF),
),
),
Node::let_bind("scale", Expr::load(scales, Expr::var("block_idx"))),
Node::let_bind("min", Expr::load(mins, Expr::var("block_idx"))),
Node::Store {
buffer: output.into(),
index: i,
value: Expr::add(
Expr::mul(
Expr::cast(DataType::F32, Expr::var("nibble")),
Expr::var("scale"),
),
Expr::var("min"),
),
},
],
),
];
let packed_count = n_blocks * 4;
Ok(Program::wrapped(
vec![
BufferDecl::storage(packed, 0, BufferAccess::ReadOnly, DataType::U32)
.with_count(packed_count),
BufferDecl::storage(scales, 1, BufferAccess::ReadOnly, DataType::F32)
.with_count(n_blocks),
BufferDecl::storage(mins, 2, BufferAccess::ReadOnly, DataType::F32)
.with_count(n_blocks),
BufferDecl::output(output, 3, DataType::F32).with_count(n),
],
[256, 1, 1],
vec![wrap_anonymous("vyre-libs::quant::q4_k_unpack", body)],
))
}
/// Number of elements in one Q2_K super-block
/// (equal to `Q2_K_BLOCK_SIZE * Q2_K_BLOCKS_PER_SUPER`).
pub const Q2_K_SUPER_BLOCK_SIZE: u32 = 256;
/// Number of elements in one Q2_K block. The Q2_K builders below divide an
/// element index by this value to pick the `scales`/`mins` entry and the single
/// packed `u32` word that the element belongs to.
pub const Q2_K_BLOCK_SIZE: u32 = 16;
/// Number of Q2_K blocks that make up one super-block.
pub const Q2_K_BLOCKS_PER_SUPER: u32 = 16;
/// Builds a [`Program`] that expands 2-bit packed values into `f32` outputs.
///
/// Each invocation along axis 0 handles one element `i`, guarded by `i < n`.
/// It extracts the 2-bit field for `i` from `packed`, casts it to `f32`, and
/// stores `q2 * scales[block] + mins[block]` into `output[i]`, where
/// `block = i / Q2_K_BLOCK_SIZE`.
///
/// Packing read by the generated code: every block owns exactly one `u32`
/// word; bytes are taken from the least-significant end of the word, and each
/// byte holds four 2-bit fields starting at its least-significant bits.
///
/// Declared buffers, in order: `packed` (read-only `u32`), `scales` and `mins`
/// (read-only `f32`), each with one entry per block, and `output` (`f32`, `n`
/// entries). The `[256, 1, 1]` triple is forwarded to `Program::wrapped`
/// (presumably the workgroup size).
///
/// # Errors
///
/// Returns an error string when `n` is zero.
pub fn q2_k_unpack(
    packed: &str,
    scales: &str,
    mins: &str,
    output: &str,
    n: u32,
) -> Result<Program, String> {
    if n == 0 {
        return Err("Fix: q2_k_unpack n=0 is invalid".to_string());
    }

    let block_count = n.div_ceil(Q2_K_BLOCK_SIZE);

    let elem = Expr::var("i");
    let in_range = Expr::lt(elem.clone(), Expr::u32(n));

    // Locate the element: block, position inside the block, byte, bit offset.
    let block_idx = Expr::div(elem.clone(), Expr::u32(Q2_K_BLOCK_SIZE));
    let within_block = Expr::rem(elem.clone(), Expr::u32(Q2_K_BLOCK_SIZE));
    let byte_idx = Expr::div(Expr::var("within_block"), Expr::u32(4));
    let field_shift = Expr::mul(
        Expr::rem(Expr::var("within_block"), Expr::u32(4)),
        Expr::u32(2),
    );

    // One word per block: select the byte, then the 2-bit field inside it.
    let word = Expr::load(packed, Expr::var("block_idx"));
    let byte_val = Expr::bitand(
        Expr::shr(
            Expr::var("word"),
            Expr::mul(Expr::var("byte_idx"), Expr::u32(8)),
        ),
        Expr::u32(0xFF),
    );
    let q2 = Expr::bitand(
        Expr::shr(Expr::var("byte_val"), Expr::var("shift")),
        Expr::u32(0x3),
    );

    // Affine dequantization with the per-block scale and minimum.
    let scale = Expr::load(scales, Expr::var("block_idx"));
    let min = Expr::load(mins, Expr::var("block_idx"));
    let dequantized = Expr::add(
        Expr::mul(
            Expr::cast(DataType::F32, Expr::var("q2")),
            Expr::var("scale"),
        ),
        Expr::var("min"),
    );

    let guarded = vec![
        Node::let_bind("block_idx", block_idx),
        Node::let_bind("within_block", within_block),
        Node::let_bind("byte_idx", byte_idx),
        Node::let_bind("shift", field_shift),
        Node::let_bind("word", word),
        Node::let_bind("byte_val", byte_val),
        Node::let_bind("q2", q2),
        Node::let_bind("scale", scale),
        Node::let_bind("min", min),
        Node::Store {
            buffer: output.into(),
            index: elem,
            value: dequantized,
        },
    ];

    let body = vec![
        Node::let_bind("i", Expr::InvocationId { axis: 0 }),
        Node::if_then(in_range, guarded),
    ];

    let buffers = vec![
        BufferDecl::storage(packed, 0, BufferAccess::ReadOnly, DataType::U32)
            .with_count(block_count),
        BufferDecl::storage(scales, 1, BufferAccess::ReadOnly, DataType::F32)
            .with_count(block_count),
        BufferDecl::storage(mins, 2, BufferAccess::ReadOnly, DataType::F32)
            .with_count(block_count),
        BufferDecl::output(output, 3, DataType::F32).with_count(n),
    ];

    Ok(Program::wrapped(
        buffers,
        [256, 1, 1],
        vec![wrap_anonymous("vyre-libs::quant::q2_k_unpack", body)],
    ))
}
/// Builds a [`Program`] computing a linear layer over 4-bit packed weights.
///
/// Each invocation along axis 0 handles one output index `i`, guarded by
/// `i < out_dim`. It starts from `b[i]`, then for every `k` in the loop bounded
/// by `0` and `in_dim` it dequantizes the weight at flat index
/// `k * out_dim + i` and accumulates `x[k] * weight`. The sum is stored in
/// `out[i]`.
///
/// Weight dequantization matches `q4_k_unpack`: the flat index selects a block
/// of `Q4_K_BLOCK_SIZE` elements (four `u32` words), a byte inside it, and the
/// low nibble for even positions or the high nibble for odd positions. The
/// result is `nibble * w_scales[block] + w_mins[block]`.
///
/// Declared buffers, in order: `x` (`in_dim` `f32`), `w_packed` (four `u32`
/// words per block), `w_scales` and `w_mins` (one `f32` per block), `b`
/// (`out_dim` `f32`) and `out` (`out_dim` `f32`). The block count covers all
/// `in_dim * out_dim` weights.
///
/// # Errors
///
/// Returns an error string when either dimension is zero, or when
/// `in_dim * out_dim` does not fit in `u32`.
pub fn q4_k_linear(
    x: &str,
    w_packed: &str,
    w_scales: &str,
    w_mins: &str,
    b: &str,
    out: &str,
    in_dim: u32,
    out_dim: u32,
) -> Result<Program, String> {
    if in_dim == 0 || out_dim == 0 {
        return Err("Fix: q4_k_linear all dims must be > 0".to_string());
    }
    // The kernel indexes weights by `k * out_dim + i`, so the quantized buffers
    // must cover the whole `in_dim * out_dim` matrix, not just `in_dim`
    // elements. The checked multiply also guarantees the flat index computed
    // inside the kernel cannot wrap in `u32`.
    let weight_elems = in_dim
        .checked_mul(out_dim)
        .ok_or_else(|| "Fix: q4_k_linear in_dim * out_dim overflows u32".to_string())?;
    let n_blocks = weight_elems.div_ceil(Q4_K_BLOCK_SIZE);
    let i = Expr::var("i");
    let body = vec![
        Node::let_bind("i", Expr::InvocationId { axis: 0 }),
        Node::if_then(
            Expr::lt(i.clone(), Expr::u32(out_dim)),
            vec![
                // The accumulator starts at the bias for this output.
                Node::let_bind("acc", Expr::load(b, i.clone())),
                Node::loop_for(
                    "k",
                    Expr::u32(0),
                    Expr::u32(in_dim),
                    vec![
                        // Flat weight index for (k, i).
                        Node::let_bind(
                            "linear_idx",
                            Expr::add(Expr::mul(Expr::var("k"), Expr::u32(out_dim)), i.clone()),
                        ),
                        Node::let_bind(
                            "block_idx",
                            Expr::div(Expr::var("linear_idx"), Expr::u32(Q4_K_BLOCK_SIZE)),
                        ),
                        Node::let_bind(
                            "within_block",
                            Expr::rem(Expr::var("linear_idx"), Expr::u32(Q4_K_BLOCK_SIZE)),
                        ),
                        // Two 4-bit values per byte.
                        Node::let_bind(
                            "byte_idx",
                            Expr::div(Expr::var("within_block"), Expr::u32(2)),
                        ),
                        // 0 for the low nibble, 4 for the high nibble.
                        Node::let_bind(
                            "shift",
                            Expr::mul(
                                Expr::rem(Expr::var("within_block"), Expr::u32(2)),
                                Expr::u32(4),
                            ),
                        ),
                        // Four u32 words per block, four bytes per word.
                        Node::let_bind(
                            "word_idx",
                            Expr::add(
                                Expr::mul(Expr::var("block_idx"), Expr::u32(4)),
                                Expr::div(Expr::var("byte_idx"), Expr::u32(4)),
                            ),
                        ),
                        Node::let_bind(
                            "word_shift",
                            Expr::mul(
                                Expr::rem(Expr::var("byte_idx"), Expr::u32(4)),
                                Expr::u32(8),
                            ),
                        ),
                        Node::let_bind(
                            "packed_word",
                            Expr::load(w_packed, Expr::var("word_idx")),
                        ),
                        Node::let_bind(
                            "byte_val",
                            Expr::bitand(
                                Expr::shr(Expr::var("packed_word"), Expr::var("word_shift")),
                                Expr::u32(0xFF),
                            ),
                        ),
                        Node::let_bind(
                            "nibble",
                            Expr::bitand(
                                Expr::shr(Expr::var("byte_val"), Expr::var("shift")),
                                Expr::u32(0xF),
                            ),
                        ),
                        Node::let_bind("scale", Expr::load(w_scales, Expr::var("block_idx"))),
                        Node::let_bind("min", Expr::load(w_mins, Expr::var("block_idx"))),
                        // weight = nibble * scale + min
                        Node::let_bind(
                            "weight",
                            Expr::add(
                                Expr::mul(
                                    Expr::cast(DataType::F32, Expr::var("nibble")),
                                    Expr::var("scale"),
                                ),
                                Expr::var("min"),
                            ),
                        ),
                        Node::assign(
                            "acc",
                            Expr::add(
                                Expr::var("acc"),
                                Expr::mul(Expr::load(x, Expr::var("k")), Expr::var("weight")),
                            ),
                        ),
                    ],
                ),
                Node::Store {
                    buffer: out.into(),
                    index: i,
                    value: Expr::var("acc"),
                },
            ],
        ),
    ];
    Ok(Program::wrapped(
        vec![
            BufferDecl::storage(x, 0, BufferAccess::ReadOnly, DataType::F32).with_count(in_dim),
            BufferDecl::storage(w_packed, 1, BufferAccess::ReadOnly, DataType::U32)
                .with_count(n_blocks * 4),
            BufferDecl::storage(w_scales, 2, BufferAccess::ReadOnly, DataType::F32)
                .with_count(n_blocks),
            BufferDecl::storage(w_mins, 3, BufferAccess::ReadOnly, DataType::F32)
                .with_count(n_blocks),
            BufferDecl::storage(b, 4, BufferAccess::ReadOnly, DataType::F32).with_count(out_dim),
            BufferDecl::output(out, 5, DataType::F32).with_count(out_dim),
        ],
        [64, 1, 1],
        vec![wrap_anonymous("vyre-libs::quant::q4_k_linear", body)],
    ))
}
/// Builds a [`Program`] computing a linear layer over 2-bit packed weights.
///
/// Each invocation along axis 0 handles one output index `i`, guarded by
/// `i < out_dim`. It starts from `b[i]`, then for every `k` in the loop bounded
/// by `0` and `in_dim` it dequantizes the weight at flat index
/// `k * out_dim + i` and accumulates `x[k] * weight`. The sum is stored in
/// `out[i]`.
///
/// Weight dequantization: the flat index selects a block of `Q2_K_BLOCK_SIZE`
/// elements (one `u32` word), a byte inside that word, and a 2-bit field
/// inside the byte. The result is `q2 * w_scales[block] + w_mins[block]`.
///
/// Declared buffers, in order: `x` (`in_dim` `f32`), `w_packed` (one `u32` per
/// block), `w_scales` and `w_mins` (one `f32` per block), `b` (`out_dim`
/// `f32`) and `out` (`out_dim` `f32`). The block count covers all
/// `in_dim * out_dim` weights.
///
/// # Errors
///
/// Returns an error string when either dimension is zero, or when
/// `in_dim * out_dim` does not fit in `u32`.
pub fn q2_k_linear(
    x: &str,
    w_packed: &str,
    w_scales: &str,
    w_mins: &str,
    b: &str,
    out: &str,
    in_dim: u32,
    out_dim: u32,
) -> Result<Program, String> {
    if in_dim == 0 || out_dim == 0 {
        return Err("Fix: q2_k_linear all dims must be > 0".to_string());
    }

    // Blocks are counted over the full weight matrix.
    let weight_elems = in_dim.checked_mul(out_dim).ok_or("overflow")?;
    let block_count = weight_elems.div_ceil(Q2_K_BLOCK_SIZE);

    let lane = Expr::var("i");

    // Flat weight index for (k, i) and its position in the packed stream.
    let linear_idx = Expr::add(
        Expr::mul(Expr::var("k"), Expr::u32(out_dim)),
        lane.clone(),
    );
    let block_idx = Expr::div(Expr::var("linear_idx"), Expr::u32(Q2_K_BLOCK_SIZE));
    let within_block = Expr::rem(Expr::var("linear_idx"), Expr::u32(Q2_K_BLOCK_SIZE));
    let byte_idx = Expr::div(Expr::var("within_block"), Expr::u32(4));
    let field_shift = Expr::mul(
        Expr::rem(Expr::var("within_block"), Expr::u32(4)),
        Expr::u32(2),
    );

    // One word per block: select the byte, then the 2-bit field inside it.
    let word = Expr::load(w_packed, Expr::var("block_idx"));
    let byte_val = Expr::bitand(
        Expr::shr(
            Expr::var("word"),
            Expr::mul(Expr::var("byte_idx"), Expr::u32(8)),
        ),
        Expr::u32(0xFF),
    );
    let q2 = Expr::bitand(
        Expr::shr(Expr::var("byte_val"), Expr::var("shift")),
        Expr::u32(0x3),
    );

    // weight = q2 * scale + min, then acc += x[k] * weight.
    let scale = Expr::load(w_scales, Expr::var("block_idx"));
    let min = Expr::load(w_mins, Expr::var("block_idx"));
    let weight = Expr::add(
        Expr::mul(
            Expr::cast(DataType::F32, Expr::var("q2")),
            Expr::var("scale"),
        ),
        Expr::var("min"),
    );
    let accumulated = Expr::add(
        Expr::var("acc"),
        Expr::mul(Expr::load(x, Expr::var("k")), Expr::var("weight")),
    );

    let per_k = vec![
        Node::let_bind("linear_idx", linear_idx),
        Node::let_bind("block_idx", block_idx),
        Node::let_bind("within_block", within_block),
        Node::let_bind("byte_idx", byte_idx),
        Node::let_bind("shift", field_shift),
        Node::let_bind("word", word),
        Node::let_bind("byte_val", byte_val),
        Node::let_bind("q2", q2),
        Node::let_bind("scale", scale),
        Node::let_bind("min", min),
        Node::let_bind("weight", weight),
        Node::assign("acc", accumulated),
    ];

    let guarded = vec![
        // The accumulator starts at the bias for this output.
        Node::let_bind("acc", Expr::load(b, lane.clone())),
        Node::loop_for("k", Expr::u32(0), Expr::u32(in_dim), per_k),
        Node::Store {
            buffer: out.into(),
            index: lane.clone(),
            value: Expr::var("acc"),
        },
    ];

    let body = vec![
        Node::let_bind("i", Expr::InvocationId { axis: 0 }),
        Node::if_then(Expr::lt(lane, Expr::u32(out_dim)), guarded),
    ];

    let buffers = vec![
        BufferDecl::storage(x, 0, BufferAccess::ReadOnly, DataType::F32).with_count(in_dim),
        BufferDecl::storage(w_packed, 1, BufferAccess::ReadOnly, DataType::U32)
            .with_count(block_count),
        BufferDecl::storage(w_scales, 2, BufferAccess::ReadOnly, DataType::F32)
            .with_count(block_count),
        BufferDecl::storage(w_mins, 3, BufferAccess::ReadOnly, DataType::F32)
            .with_count(block_count),
        BufferDecl::storage(b, 4, BufferAccess::ReadOnly, DataType::F32).with_count(out_dim),
        BufferDecl::output(out, 5, DataType::F32).with_count(out_dim),
    ];

    Ok(Program::wrapped(
        buffers,
        [64, 1, 1],
        vec![wrap_anonymous("vyre-libs::quant::q2_k_linear", body)],
    ))
}
#[cfg(test)]
mod tests {
    //! Reference-evaluator smoke tests for the quantized unpack/linear builders.
    //!
    //! Every test uses a single block with `scale = 1` and `min = 0`, so the
    //! expected outputs are simply the raw packed field values (or exact sums
    //! of them).

    use super::*;
    use crate::test_support::byte_pack::decode_f32;
    use crate::test_support::byte_pack::f32_bytes;
    use crate::test_support::byte_pack::u32_bytes;
    use vyre_reference::value::Value;

    /// Nibbles 0x0..=0xF packed low-nibble-first must unpack to 0.0..=15.0.
    #[test]
    fn q4_k_unpack_simple() {
        let scales = vec![1.0f32];
        let mins = vec![0.0f32];
        let packed = vec![0x7654_3210u32, 0xFEDC_BA98, 0x0, 0x0];
        let program = q4_k_unpack("packed", "scales", "mins", "out", 16).unwrap();
        let outputs = vyre_reference::reference_eval(
            &program,
            &[
                Value::from(u32_bytes(&packed)),
                Value::from(f32_bytes(&scales)),
                Value::from(f32_bytes(&mins)),
                Value::from(vec![0u8; 64]),
            ],
        )
        .expect("Fix: q4_k_unpack must execute");
        let out = decode_f32(&outputs[0].to_bytes());
        assert_eq!(out[0], 0.0);
        assert_eq!(out[1], 1.0);
        assert_eq!(out[2], 2.0);
        assert_eq!(out[15], 15.0);
    }

    /// Byte 0xE4 is `11 10 01 00`, i.e. the fields 0, 1, 2, 3 from the low bits up.
    #[test]
    fn q2_k_unpack_simple() {
        let scales = vec![1.0f32];
        let mins = vec![0.0f32];
        let packed = vec![0xE4E4_E4E4u32];
        let program = q2_k_unpack("packed", "scales", "mins", "out", 16).unwrap();
        let outputs = vyre_reference::reference_eval(
            &program,
            &[
                Value::from(u32_bytes(&packed)),
                Value::from(f32_bytes(&scales)),
                Value::from(f32_bytes(&mins)),
                Value::from(vec![0u8; 64]),
            ],
        )
        .expect("Fix: q2_k_unpack must execute");
        let out = decode_f32(&outputs[0].to_bytes());
        assert_eq!(out[0], 0.0);
        assert_eq!(out[1], 1.0);
        assert_eq!(out[2], 2.0);
        assert_eq!(out[3], 3.0);
    }

    /// 2x2 layer with weights 0, 1, 2, 3 at flat indices 0..4 and `x = [1, 0]`.
    #[test]
    fn q4_k_linear_simple() {
        let x = vec![1.0f32, 0.0];
        let b = vec![0.0f32, 0.0];
        let packed = vec![0x0000_3210u32, 0, 0, 0];
        let scales = vec![1.0f32];
        let mins = vec![0.0f32];
        let program = q4_k_linear("x", "packed", "scales", "mins", "b", "out", 2, 2).unwrap();
        let outputs = vyre_reference::reference_eval(
            &program,
            &[
                Value::from(f32_bytes(&x)),
                Value::from(u32_bytes(&packed)),
                Value::from(f32_bytes(&scales)),
                Value::from(f32_bytes(&mins)),
                Value::from(f32_bytes(&b)),
                Value::from(vec![0u8; 8]),
            ],
        )
        .expect("Fix: q4_k_linear must execute");
        let out = decode_f32(&outputs[0].to_bytes());
        assert_eq!(out[0], 0.0);
        assert_eq!(out[1], 1.0);
    }

    /// 2x2 layer with weights 0, 1, 2, 3 at flat indices 0..4 (byte 0xE4),
    /// `x = [1, 2]` and a non-zero bias, so both loop iterations and the bias
    /// contribute:
    /// `out[0] = 0.5 + 1*0 + 2*2 = 4.5`, `out[1] = 0.25 + 1*1 + 2*3 = 7.25`.
    #[test]
    fn q2_k_linear_simple() {
        let x = vec![1.0f32, 2.0];
        let b = vec![0.5f32, 0.25];
        let packed = vec![0x0000_00E4u32];
        let scales = vec![1.0f32];
        let mins = vec![0.0f32];
        let program = q2_k_linear("x", "packed", "scales", "mins", "b", "out", 2, 2).unwrap();
        let outputs = vyre_reference::reference_eval(
            &program,
            &[
                Value::from(f32_bytes(&x)),
                Value::from(u32_bytes(&packed)),
                Value::from(f32_bytes(&scales)),
                Value::from(f32_bytes(&mins)),
                Value::from(f32_bytes(&b)),
                Value::from(vec![0u8; 8]),
            ],
        )
        .expect("Fix: q2_k_linear must execute");
        let out = decode_f32(&outputs[0].to_bytes());
        assert_eq!(out[0], 4.5);
        assert_eq!(out[1], 7.25);
    }
}