//! Packed 2-bit ternary ("I2S") weights: block pack/unpack, dequantization,
//! and fused f32/int8 matrix-multiply kernels parallelized with rayon.

use rayon::prelude::*;
/// Ternary weights per packed block.
pub const I2S_BLOCK_SIZE: usize = 128;
/// Packed size of one block: 128 trits at 2 bits each = 32 bytes.
pub const I2S_BYTES_PER_BLOCK: usize = 32;
/// Elements per lane group; each packed byte holds one trit from each of a
/// block's four groups.
pub const I2S_GROUP_SIZE: usize = 32;
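/// Decodes a 2-bit code to a trit: 0 -> -1, 1 -> 0, 2 -> +1. The reserved
/// code 3 also decodes to 0 so arbitrary bytes never map out of range.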
#[inline(always)]
pub const fn decode_trit(code: u8) -> i8 {
match code & 0b11 {
0 => -1,
1 => 0,
2 => 1,
_ => 0,
}
}
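/// Encodes a trit as a 2-bit code, saturating any nonzero value to +/-1:
/// negative -> 0, zero -> 1, positive -> 2.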
#[inline(always)]
const fn encode_trit(v: i8) -> u8 {
if v > 0 {
2
} else if v < 0 {
0
} else {
1
}
}
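/// One packed I2S block: 128 ternary weights as 2-bit codes in 32 bytes.
///
/// The layout is group-strided: byte `i` carries the trits at logical
/// positions `i`, `i + 32`, `i + 64`, and `i + 96`, packed from the high
/// bit pair (bits 7-6) down to the low pair (bits 1-0).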
#[derive(Debug, Clone)]
pub struct I2sBlock {
pub data: [u8; I2S_BYTES_PER_BLOCK],
}
impl I2sBlock {
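    /// Packs 128 trits into the group-strided 2-bit layout.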
pub fn pack(values: &[i8; I2S_BLOCK_SIZE]) -> Self {
let mut data = [0u8; I2S_BYTES_PER_BLOCK];
for group_idx in 0..4 {
let shift = 6 - 2 * group_idx;
for group_pos in 0..I2S_GROUP_SIZE {
let code = encode_trit(values[group_idx * I2S_GROUP_SIZE + group_pos]);
data[group_pos] |= code << shift;
}
}
Self { data }
}
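    /// Returns the 32 packed bytes by value.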
pub fn to_bytes(&self) -> [u8; I2S_BYTES_PER_BLOCK] {
self.data
}
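    /// Reads one block from the first 32 bytes; `None` if the slice is too short.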
pub fn from_bytes(bytes: &[u8]) -> Option<Self> {
if bytes.len() < I2S_BYTES_PER_BLOCK {
return None;
}
let mut data = [0u8; I2S_BYTES_PER_BLOCK];
data.copy_from_slice(&bytes[..I2S_BYTES_PER_BLOCK]);
Some(Self { data })
}
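    /// Decodes the block back into 128 trits in logical order.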
pub fn unpack(&self) -> [i8; I2S_BLOCK_SIZE] {
let mut out = [0i8; I2S_BLOCK_SIZE];
for group_idx in 0..4 {
let shift = 6 - 2 * group_idx;
for group_pos in 0..I2S_GROUP_SIZE {
let code = (self.data[group_pos] >> shift) & 0b11;
out[group_idx * I2S_GROUP_SIZE + group_pos] = decode_trit(code);
}
}
out
}
}
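/// Dequantizes one packed 32-byte block into the first 128 slots of `out`,
/// scaling each decoded trit by `scale`.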
pub fn dequantize_i2s_block(bytes: &[u8], scale: f32, out: &mut [f32]) {
debug_assert!(bytes.len() >= I2S_BYTES_PER_BLOCK);
debug_assert!(out.len() >= I2S_BLOCK_SIZE);
let b = &bytes[..I2S_BYTES_PER_BLOCK];
for (group_pos, &byte) in b.iter().enumerate() {
        let c0 = decode_trit((byte >> 6) & 0b11) as f32;
        let c1 = decode_trit((byte >> 4) & 0b11) as f32;
        let c2 = decode_trit((byte >> 2) & 0b11) as f32;
        let c3 = decode_trit(byte & 0b11) as f32;
        out[group_pos] = c0 * scale;
out[group_pos + I2S_GROUP_SIZE] = c1 * scale;
out[group_pos + 2 * I2S_GROUP_SIZE] = c2 * scale;
out[group_pos + 3 * I2S_GROUP_SIZE] = c3 * scale;
}
}
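/// Dequantizes a packed I2S tensor, one 32-byte block per 128 elements of
/// `out`. A trailing partial block of `out`, or a block without a full 32
/// input bytes, is left untouched.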
pub fn dequantize_i2s(bytes: &[u8], scale: f32, out: &mut [f32]) {
let n_blocks = out.len() / I2S_BLOCK_SIZE;
out.par_chunks_mut(I2S_BLOCK_SIZE)
.take(n_blocks)
.zip(bytes.par_chunks(I2S_BYTES_PER_BLOCK).take(n_blocks))
.for_each(|(out_block, in_block)| {
if in_block.len() >= I2S_BYTES_PER_BLOCK {
dequantize_i2s_block(in_block, scale, out_block);
}
});
}
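/// Fused ternary matmul over f32 activations:
/// `output[i*n + j] = dot(row i of activations, weight row j) * scale`.
/// Weights are `n` rows of `k/128` packed I2S blocks; the `n` outputs of
/// each activation row are computed in parallel.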
pub fn matmul_i2s(
activations: &[f32],
m: usize,
k: usize,
weight_bytes: &[u8],
n: usize,
scale: f32,
output: &mut [f32],
) {
assert!(
k % I2S_BLOCK_SIZE == 0,
"matmul_i2s: k ({k}) must be a multiple of {I2S_BLOCK_SIZE}",
);
assert_eq!(activations.len(), m * k, "activations shape mismatch");
assert_eq!(output.len(), m * n, "output shape mismatch");
let blocks_per_row = k / I2S_BLOCK_SIZE;
let bytes_per_row = blocks_per_row * I2S_BYTES_PER_BLOCK;
assert_eq!(
weight_bytes.len(),
n * bytes_per_row,
"weight_bytes shape mismatch",
);
for i in 0..m {
let act_row = &activations[i * k..(i + 1) * k];
output[i * n..(i + 1) * n]
.par_iter_mut()
.enumerate()
.for_each(|(j, out_slot)| {
let wrow = &weight_bytes[j * bytes_per_row..(j + 1) * bytes_per_row];
*out_slot = dot_row_ternary(act_row, wrow, blocks_per_row) * scale;
});
}
}
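/// Dot product of an f32 activation row with one packed ternary weight row,
/// decoding trits on the fly. The four per-group accumulators keep the
/// floating-point adds independent across a block's lane groups.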
#[inline(always)]
fn dot_row_ternary(act_row: &[f32], wrow: &[u8], blocks_per_row: usize) -> f32 {
let mut acc = 0.0f32;
for block_idx in 0..blocks_per_row {
let block_off = block_idx * I2S_BYTES_PER_BLOCK;
let block = &wrow[block_off..block_off + I2S_BYTES_PER_BLOCK];
let k_base = block_idx * I2S_BLOCK_SIZE;
let mut a0 = 0.0f32;
let mut a1 = 0.0f32;
let mut a2 = 0.0f32;
let mut a3 = 0.0f32;
for (group_pos, &byte) in block.iter().enumerate() {
let t0 = decode_trit((byte >> 6) & 0b11) as f32;
let t1 = decode_trit((byte >> 4) & 0b11) as f32;
let t2 = decode_trit((byte >> 2) & 0b11) as f32;
let t3 = decode_trit(byte & 0b11) as f32;
let base = k_base + group_pos;
a0 += act_row[base] * t0;
a1 += act_row[base + I2S_GROUP_SIZE] * t1;
a2 += act_row[base + 2 * I2S_GROUP_SIZE] * t2;
a3 += act_row[base + 3 * I2S_GROUP_SIZE] * t3;
}
acc += a0 + a1 + a2 + a3;
}
acc
}
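/// Symmetric per-row int8 quantization: maps values into `[-127, 127]` with
/// `scale = absmax / 127` and returns that scale (0.0 for an all-zero row).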
pub fn quantize_row_to_int8(input: &[f32], output: &mut [i8]) -> f32 {
debug_assert_eq!(input.len(), output.len());
let absmax = input.iter().fold(0.0f32, |m, &v| m.max(v.abs()));
if absmax == 0.0 {
for o in output.iter_mut() {
*o = 0;
}
return 0.0;
}
let scale = absmax / 127.0;
let inv_scale = 1.0 / scale;
for (o, &v) in output.iter_mut().zip(input.iter()) {
let q = (v * inv_scale).round();
*o = q.clamp(-127.0, 127.0) as i8;
}
scale
}
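/// Int8-activation variant of [`matmul_i2s`]: each activation row carries
/// the scale returned by [`quantize_row_to_int8`]. Dispatches to an
/// AVX2/AVX-VNNI path when the CPU supports it, otherwise runs the scalar
/// kernel.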
pub fn matmul_i2s_i8(
acts_int8: &[i8],
act_scales: &[f32],
m: usize,
k: usize,
weight_bytes: &[u8],
n: usize,
weight_scale: f32,
output: &mut [f32],
) {
assert!(
k % I2S_BLOCK_SIZE == 0,
"matmul_i2s_i8: k ({k}) must be a multiple of {I2S_BLOCK_SIZE}",
);
assert_eq!(acts_int8.len(), m * k, "acts_int8 shape mismatch");
assert_eq!(act_scales.len(), m, "act_scales length mismatch");
assert_eq!(output.len(), m * n, "output shape mismatch");
let blocks_per_row = k / I2S_BLOCK_SIZE;
let bytes_per_row = blocks_per_row * I2S_BYTES_PER_BLOCK;
assert_eq!(
weight_bytes.len(),
n * bytes_per_row,
"weight_bytes shape mismatch",
);
#[cfg(target_arch = "x86_64")]
{
if std::is_x86_feature_detected!("avxvnni") && std::is_x86_feature_detected!("avx2") {
unsafe {
matmul_i2s_i8_avxvnni(
acts_int8,
act_scales,
m,
k,
weight_bytes,
n,
weight_scale,
output,
);
}
return;
}
}
matmul_i2s_i8_scalar(
acts_int8,
act_scales,
m,
k,
weight_bytes,
n,
weight_scale,
output,
);
}
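/// Scalar int8 kernel. It accumulates the raw 2-bit codes (0, 1, 2) against
/// the activations, then corrects once per output: since each trit equals
/// `code - 1`, `sum(code * a) - sum(a) = sum(trit * a)`.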
fn matmul_i2s_i8_scalar(
acts_int8: &[i8],
act_scales: &[f32],
m: usize,
k: usize,
weight_bytes: &[u8],
n: usize,
weight_scale: f32,
output: &mut [f32],
) {
let blocks_per_row = k / I2S_BLOCK_SIZE;
let bytes_per_row = blocks_per_row * I2S_BYTES_PER_BLOCK;
for i in 0..m {
let act_row = &acts_int8[i * k..(i + 1) * k];
let act_scale = act_scales[i];
let act_sum: i32 = act_row.iter().map(|&x| x as i32).sum();
let combined_scale = weight_scale * act_scale;
output[i * n..(i + 1) * n]
.par_iter_mut()
.enumerate()
.for_each(|(j, out_slot)| {
let wrow = &weight_bytes[j * bytes_per_row..(j + 1) * bytes_per_row];
let mut code_dot: i32 = 0;
for block_idx in 0..blocks_per_row {
let block_off = block_idx * I2S_BYTES_PER_BLOCK;
let block = &wrow[block_off..block_off + I2S_BYTES_PER_BLOCK];
let k_base = block_idx * I2S_BLOCK_SIZE;
for (group_pos, &byte) in block.iter().enumerate() {
let c0 = ((byte >> 6) & 0b11) as i32;
let c1 = ((byte >> 4) & 0b11) as i32;
let c2 = ((byte >> 2) & 0b11) as i32;
let c3 = (byte & 0b11) as i32;
let base = k_base + group_pos;
code_dot += c0 * act_row[base] as i32;
code_dot += c1 * act_row[base + I2S_GROUP_SIZE] as i32;
code_dot += c2 * act_row[base + 2 * I2S_GROUP_SIZE] as i32;
code_dot += c3 * act_row[base + 3 * I2S_GROUP_SIZE] as i32;
}
}
let trit_dot = code_dot - act_sum;
*out_slot = (trit_dot as f32) * combined_scale;
});
}
}
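/// AVX2/AVX-VNNI entry point. Currently a placeholder that forwards to the
/// scalar kernel; the `#[target_feature]` attribute and the runtime dispatch
/// above keep the plumbing in place for a vectorized implementation.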
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2,avxvnni")]
unsafe fn matmul_i2s_i8_avxvnni(
acts_int8: &[i8],
act_scales: &[f32],
m: usize,
k: usize,
weight_bytes: &[u8],
n: usize,
weight_scale: f32,
output: &mut [f32],
) {
matmul_i2s_i8_scalar(
acts_int8,
act_scales,
m,
k,
weight_bytes,
n,
weight_scale,
output,
);
}
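/// Packed byte count for `n_elements` ternary weights. The integer division
/// truncates, so callers should pass a multiple of `I2S_BLOCK_SIZE`.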
pub fn bytes_for_elements(n_elements: usize) -> usize {
(n_elements / I2S_BLOCK_SIZE) * I2S_BYTES_PER_BLOCK
}
#[cfg(test)]
mod tests {
use super::*;
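    /// Fills a block-sized trit array by cycling `pattern`.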
fn make_trits(pattern: &[i8]) -> [i8; I2S_BLOCK_SIZE] {
let mut out = [0i8; I2S_BLOCK_SIZE];
for (i, v) in pattern.iter().cycle().take(I2S_BLOCK_SIZE).enumerate() {
out[i] = *v;
}
out
}
#[test]
fn trit_encode_decode_roundtrip() {
assert_eq!(decode_trit(0b00), -1);
assert_eq!(decode_trit(0b01), 0);
assert_eq!(decode_trit(0b10), 1);
assert_eq!(decode_trit(0b11), 0);
assert_eq!(encode_trit(-1), 0);
assert_eq!(encode_trit(0), 1);
assert_eq!(encode_trit(1), 2);
assert_eq!(encode_trit(42), 2);
assert_eq!(encode_trit(-42), 0);
}
#[test]
fn block_pack_unpack_roundtrip() {
let values = make_trits(&[-1, 0, 1, 0, 1, -1, 0, 0]);
let block = I2sBlock::pack(&values);
let decoded = block.unpack();
assert_eq!(&values[..], &decoded[..]);
}
#[test]
fn block_bytes_roundtrip() {
let values = make_trits(&[1, -1, 0]);
let block = I2sBlock::pack(&values);
let bytes = block.to_bytes();
let parsed = I2sBlock::from_bytes(&bytes).unwrap();
assert_eq!(parsed.data, block.data);
assert_eq!(parsed.unpack(), values);
}
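    #[test]
    fn from_bytes_rejects_short_slice() {
        // Exercises the documented failure mode of `from_bytes`: slices
        // shorter than one packed block are rejected.
        assert!(I2sBlock::from_bytes(&[0u8; I2S_BYTES_PER_BLOCK - 1]).is_none());
        assert!(I2sBlock::from_bytes(&[0u8; I2S_BYTES_PER_BLOCK]).is_some());
    }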
#[test]
fn dequantize_single_block() {
let values = make_trits(&[1, -1]);
let block = I2sBlock::pack(&values);
let bytes = block.to_bytes();
let mut out = [0.0f32; I2S_BLOCK_SIZE];
dequantize_i2s_block(&bytes, 2.5, &mut out);
for (i, v) in out.iter().enumerate() {
let expected = if i % 2 == 0 { 2.5 } else { -2.5 };
assert!(
(v - expected).abs() < 1e-6,
"idx {i}: got {v}, expected {expected}",
);
}
}
#[test]
fn group_strided_layout_is_correct() {
let mut values = [0i8; I2S_BLOCK_SIZE];
        values[0] = 1;
        let block = I2sBlock::pack(&values);
assert_eq!(
block.data[0] & 0b1100_0000,
0b1000_0000,
"expected code=2 (+1) in byte 0 bits 6-7",
);
assert_eq!(
(block.data[0] >> 4) & 0b11,
1,
"expected code=1 (0) in byte 0 bits 4-5",
);
}
#[test]
fn dequantize_multi_block_tensor() {
let n_blocks = 3;
let n_elem = n_blocks * I2S_BLOCK_SIZE;
let mut bytes = Vec::with_capacity(n_blocks * I2S_BYTES_PER_BLOCK);
let patterns: &[&[i8]] = &[&[1, 0, -1], &[-1, 1, 0], &[0, 0, 1, -1]];
for b in 0..n_blocks {
let block = I2sBlock::pack(&make_trits(patterns[b]));
bytes.extend_from_slice(&block.to_bytes());
}
let mut out = vec![0.0f32; n_elem];
dequantize_i2s(&bytes, 1.0, &mut out);
assert_eq!(out[0], 1.0);
assert_eq!(out[I2S_BLOCK_SIZE], -1.0);
assert_eq!(out[2 * I2S_BLOCK_SIZE], 0.0);
}
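    #[test]
    fn dequantize_skips_trailing_partial_block() {
        // Edge case implied by the `take(n_blocks)` guard in `dequantize_i2s`:
        // elements of `out` past the last whole block are left untouched.
        let block = I2sBlock::pack(&make_trits(&[1]));
        let mut out = vec![7.0f32; I2S_BLOCK_SIZE + 5];
        dequantize_i2s(&block.to_bytes(), 1.0, &mut out);
        assert!(out[..I2S_BLOCK_SIZE].iter().all(|&v| v == 1.0));
        assert!(out[I2S_BLOCK_SIZE..].iter().all(|&v| v == 7.0));
    }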
    /// Dequantize-then-multiply oracle used to cross-check the fused kernel.
    fn reference_matmul(
activations: &[f32],
m: usize,
k: usize,
weight_bytes: &[u8],
n: usize,
scale: f32,
output: &mut [f32],
) {
let mut w = vec![0.0f32; n * k];
dequantize_i2s(weight_bytes, scale, &mut w);
for i in 0..m {
for j in 0..n {
let mut s = 0.0f32;
for kk in 0..k {
s += activations[i * k + kk] * w[j * k + kk];
}
output[i * n + j] = s;
}
}
}
#[test]
fn matmul_matches_reference_small() {
let m = 2;
let k = I2S_BLOCK_SIZE;
let n = 4;
let scale = 0.125f32;
let mut weight_bytes = Vec::new();
let patterns: &[&[i8]] = &[&[1, 0, -1], &[-1, 1, 0], &[0, -1, 1], &[1, 1, -1, -1]];
for j in 0..n {
let vals = make_trits(patterns[j]);
let block = I2sBlock::pack(&vals);
weight_bytes.extend_from_slice(&block.to_bytes());
}
let mut activations = vec![0.0f32; m * k];
for i in 0..m {
for kk in 0..k {
activations[i * k + kk] = (i as f32 + 1.0) * (kk as f32 / k as f32);
}
}
let mut fused_out = vec![0.0f32; m * n];
let mut ref_out = vec![0.0f32; m * n];
matmul_i2s(&activations, m, k, &weight_bytes, n, scale, &mut fused_out);
reference_matmul(&activations, m, k, &weight_bytes, n, scale, &mut ref_out);
for (i, (f, r)) in fused_out.iter().zip(ref_out.iter()).enumerate() {
assert!((f - r).abs() < 1e-5, "mismatch at {i}: fused={f}, ref={r}",);
}
}
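    #[test]
    fn matmul_all_plus_one_is_scaled_row_sum() {
        // Sanity check independent of the reference kernel: an all-(+1)
        // ternary weight row reduces the matmul to a scaled row sum.
        let k = I2S_BLOCK_SIZE;
        let weight_bytes = I2sBlock::pack(&[1i8; I2S_BLOCK_SIZE]).to_bytes();
        let activations: Vec<f32> = (0..k).map(|x| x as f32 * 0.01).collect();
        let mut out = [0.0f32];
        matmul_i2s(&activations, 1, k, &weight_bytes, 1, 0.5, &mut out);
        let expected: f32 = activations.iter().sum::<f32>() * 0.5;
        assert!((out[0] - expected).abs() < 1e-3, "got {}, expected {expected}", out[0]);
    }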
#[test]
fn matmul_matches_reference_multi_block() {
let m = 3;
let k = 3 * I2S_BLOCK_SIZE;
let n = 5;
let scale = 0.25f32;
let mut weight_bytes = Vec::new();
for j in 0..n {
for b in 0..(k / I2S_BLOCK_SIZE) {
let pattern = if (j + b) % 2 == 0 {
&[1, 0, -1, 1, -1][..]
} else {
&[-1, -1, 1, 0, 1][..]
};
let block = I2sBlock::pack(&make_trits(pattern));
weight_bytes.extend_from_slice(&block.to_bytes());
}
}
let mut activations = vec![0.0f32; m * k];
for i in 0..m {
for kk in 0..k {
activations[i * k + kk] = ((i + 1) as f32) * ((kk as f32).sin());
}
}
let mut fused_out = vec![0.0f32; m * n];
let mut ref_out = vec![0.0f32; m * n];
matmul_i2s(&activations, m, k, &weight_bytes, n, scale, &mut fused_out);
reference_matmul(&activations, m, k, &weight_bytes, n, scale, &mut ref_out);
for (i, (f, r)) in fused_out.iter().zip(ref_out.iter()).enumerate() {
assert!((f - r).abs() < 1e-4, "mismatch at {i}: fused={f}, ref={r}",);
}
}
#[test]
fn bytes_for_elements_calculation() {
assert_eq!(bytes_for_elements(128), 32);
assert_eq!(bytes_for_elements(256), 64);
assert_eq!(bytes_for_elements(1024), 256);
assert_eq!(bytes_for_elements(0), 0);
}
#[test]
fn int8_matmul_matches_f32_within_quant_error() {
let m = 2;
let k = 2 * I2S_BLOCK_SIZE;
let n = 6;
let weight_scale = 0.1f32;
let mut weight_bytes = Vec::new();
for j in 0..n {
for b in 0..(k / I2S_BLOCK_SIZE) {
let pattern: &[i8] = if (j + b) % 2 == 0 {
&[1, 0, -1, 1]
} else {
&[-1, 1, 0, -1]
};
let block = I2sBlock::pack(&make_trits(pattern));
weight_bytes.extend_from_slice(&block.to_bytes());
}
}
let mut activations = vec![0.0f32; m * k];
for i in 0..m {
for kk in 0..k {
activations[i * k + kk] = ((kk as f32) * 0.13 - 2.0).sin() * (1.0 + i as f32 * 0.1);
}
}
let mut ref_out = vec![0.0f32; m * n];
matmul_i2s(
&activations,
m,
k,
&weight_bytes,
n,
weight_scale,
&mut ref_out,
);
let mut acts_i8 = vec![0i8; m * k];
let mut act_scales = vec![0.0f32; m];
for i in 0..m {
act_scales[i] = quantize_row_to_int8(
&activations[i * k..(i + 1) * k],
&mut acts_i8[i * k..(i + 1) * k],
);
}
let mut i8_out = vec![0.0f32; m * n];
matmul_i2s_i8(
&acts_i8,
&act_scales,
m,
k,
&weight_bytes,
n,
weight_scale,
&mut i8_out,
);
for (i, (&r, &q)) in ref_out.iter().zip(i8_out.iter()).enumerate() {
let abs_err = (r - q).abs();
let rel_err = abs_err / r.abs().max(1e-6);
assert!(
rel_err < 0.05 || abs_err < 1e-3,
"idx {i}: f32 ref = {r}, int8 quantized = {q}, rel_err = {rel_err}",
);
}
}
#[test]
fn quantize_row_to_int8_roundtrip() {
let input = [1.0f32, -2.0, 0.5, -0.5, 0.0, 2.0, -1.5];
let mut output = [0i8; 7];
let scale = quantize_row_to_int8(&input, &mut output);
assert!(scale > 0.0);
let max_idx = input
.iter()
.enumerate()
.max_by(|a, b| a.1.abs().partial_cmp(&b.1.abs()).unwrap())
.unwrap()
.0;
assert_eq!(output[max_idx].unsigned_abs(), 127);
for (i, &v) in input.iter().enumerate() {
let recovered = output[i] as f32 * scale;
assert!(
(recovered - v).abs() < scale,
"idx {i}: {v} → {} (scale={scale})",
recovered
);
}
}
#[test]
fn quantize_row_to_int8_zero_input() {
let input = [0.0f32; 8];
let mut output = [0i8; 8];
let scale = quantize_row_to_int8(&input, &mut output);
assert_eq!(scale, 0.0);
assert!(output.iter().all(|&x| x == 0));
}
#[test]
    #[should_panic(expected = "must be a multiple of")]
fn matmul_rejects_misaligned_k() {
let m = 1;
let k = 100;
let n = 1;
let acts = vec![0.0; m * k];
let weight_bytes = vec![0u8; I2S_BYTES_PER_BLOCK];
let mut out = vec![0.0; m * n];
matmul_i2s(&acts, m, k, &weight_bytes, n, 1.0, &mut out);
}
}