use std::collections::HashMap;
use std::sync::Arc;
use anyhow::{Result, anyhow};
use ronn_core::{CompiledKernel, KernelStats, MemoryUsage, SubGraph, Tensor};
use super::quantization::{BinaryTensor, BitNetQuantizer, QuantizationMethod, TernaryTensor};
#[derive(Debug, Clone)]
pub enum BitNetOperation {
BinaryMatMul { m: usize, n: usize, k: usize },
TernaryMatMul { m: usize, n: usize, k: usize },
BinaryElementWise {
op_type: String,
element_count: usize,
},
QuantizedBatchNorm {
channels: usize,
spatial_size: usize,
},
}
#[derive(Debug)]
pub struct BitNetKernel {
operation: BitNetOperation,
quantizer: Arc<BitNetQuantizer>,
stats: KernelStats,
memory_usage: MemoryUsage,
cache: HashMap<String, Vec<u8>>,
}
impl BitNetKernel {
pub fn new(operation: BitNetOperation, quantization_method: QuantizationMethod) -> Self {
let quantizer = Arc::new(BitNetQuantizer::new(quantization_method));
Self {
operation,
quantizer,
stats: KernelStats {
execution_count: 0,
average_time_us: 0.0,
min_time_us: 0.0,
max_time_us: 0.0,
},
memory_usage: MemoryUsage {
peak_bytes: 0,
current_bytes: 0,
allocation_count: 0,
},
cache: HashMap::new(),
}
}
fn execute_binary_matmul(
a: &BinaryTensor,
b: &BinaryTensor,
m: usize,
n: usize,
k: usize,
) -> Result<Tensor> {
if a.shape != [m, k] || b.shape != [k, n] {
return Err(anyhow!(
"Matrix dimension mismatch: A={:?}, B={:?}, expected A=[{},{}], B=[{},{}]",
a.shape,
b.shape,
m,
k,
k,
n
));
}
let mut result_data = vec![0.0f32; m * n];
for i in 0..m {
for j in 0..n {
let mut dot_product = 0i32;
let chunks_per_row = (k + 63) / 64;
for chunk in 0..chunks_per_row {
let start_bit = chunk * 64;
let end_bit = (start_bit + 64).min(k);
let bits_in_chunk = end_bit - start_bit;
let a_chunk = Self::extract_bits_as_u64(a, i * k + start_bit, bits_in_chunk);
let b_chunk = Self::extract_bits_as_u64(b, start_bit * n + j, bits_in_chunk);
let xnor_result = !(a_chunk ^ b_chunk);
dot_product += xnor_result.count_ones() as i32;
dot_product -= (bits_in_chunk as i32 - xnor_result.count_ones() as i32);
}
result_data[i * n + j] = dot_product as f32 * a.scale * b.scale;
}
}
Tensor::from_data(
result_data,
vec![m, n],
ronn_core::DataType::F32,
ronn_core::TensorLayout::RowMajor,
)
}
fn execute_ternary_matmul(
a: &TernaryTensor,
b: &TernaryTensor,
m: usize,
n: usize,
k: usize,
) -> Result<Tensor> {
if a.shape != [m, k] || b.shape != [k, n] {
return Err(anyhow!(
"Matrix dimension mismatch: A={:?}, B={:?}, expected A=[{},{}], B=[{},{}]",
a.shape,
b.shape,
m,
k,
k,
n
));
}
let mut result_data = vec![0.0f32; m * n];
for i in 0..m {
for j in 0..n {
let mut dot_product = 0i32;
for l in 0..k {
let a_val = a.get_value(i * k + l) as i32;
let b_val = b.get_value(l * n + j) as i32;
dot_product += a_val * b_val;
}
result_data[i * n + j] = dot_product as f32 * a.scale * b.scale;
}
}
Tensor::from_data(
result_data,
vec![m, n],
ronn_core::DataType::F32,
ronn_core::TensorLayout::RowMajor,
)
}
fn extract_bits_as_u64(tensor: &BinaryTensor, start_bit: usize, bit_count: usize) -> u64 {
let mut result = 0u64;
for i in 0..bit_count.min(64) {
if tensor.get_bit(start_bit + i) {
result |= 1u64 << i;
}
}
result
}
fn execute_binary_elementwise(inputs: &[BinaryTensor], op_type: &str) -> Result<BinaryTensor> {
if inputs.len() != 2 {
return Err(anyhow!(
"Binary element-wise operations require exactly 2 inputs"
));
}
let a = &inputs[0];
let b = &inputs[1];
if a.shape != b.shape {
return Err(anyhow!(
"Shape mismatch for element-wise operation: {:?} vs {:?}",
a.shape,
b.shape
));
}
let byte_count = a.packed_data.len();
let mut result_data = vec![0u8; byte_count];
for i in 0..byte_count {
result_data[i] = match op_type {
"And" => a.packed_data[i] & b.packed_data[i],
"Or" => a.packed_data[i] | b.packed_data[i],
"Xor" => a.packed_data[i] ^ b.packed_data[i],
"Nand" => !(a.packed_data[i] & b.packed_data[i]),
_ => return Err(anyhow!("Unsupported binary operation: {}", op_type)),
};
}
Ok(BinaryTensor {
packed_data: result_data,
shape: a.shape.clone(),
scale: (a.scale + b.scale) / 2.0, element_count: a.element_count,
})
}
fn update_stats(&mut self, execution_time_us: f64) {
self.stats.execution_count += 1;
if self.stats.execution_count == 1 {
self.stats.min_time_us = execution_time_us;
self.stats.max_time_us = execution_time_us;
self.stats.average_time_us = execution_time_us;
} else {
self.stats.min_time_us = self.stats.min_time_us.min(execution_time_us);
self.stats.max_time_us = self.stats.max_time_us.max(execution_time_us);
let n = self.stats.execution_count as f64;
self.stats.average_time_us =
((n - 1.0) * self.stats.average_time_us + execution_time_us) / n;
}
}
}
impl CompiledKernel for BitNetKernel {
fn execute(&self, inputs: &[Tensor]) -> Result<Vec<Tensor>> {
let start_time = std::time::Instant::now();
let results = match &self.operation {
BitNetOperation::BinaryMatMul { m, n, k } => {
if inputs.len() != 2 {
return Err(anyhow!("Binary MatMul requires exactly 2 inputs"));
}
let a_binary = self.quantizer.quantize_binary(&inputs[0])?;
let b_binary = self.quantizer.quantize_binary(&inputs[1])?;
let result = BitNetKernel::execute_binary_matmul(&a_binary, &b_binary, *m, *n, *k)?;
vec![result]
}
BitNetOperation::TernaryMatMul { m, n, k } => {
if inputs.len() != 2 {
return Err(anyhow!("Ternary MatMul requires exactly 2 inputs"));
}
let a_ternary = self.quantizer.quantize_ternary(&inputs[0])?;
let b_ternary = self.quantizer.quantize_ternary(&inputs[1])?;
let result =
BitNetKernel::execute_ternary_matmul(&a_ternary, &b_ternary, *m, *n, *k)?;
vec![result]
}
BitNetOperation::BinaryElementWise { op_type, .. } => {
if inputs.len() != 2 {
return Err(anyhow!(
"Binary element-wise operations require exactly 2 inputs"
));
}
let a_binary = self.quantizer.quantize_binary(&inputs[0])?;
let b_binary = self.quantizer.quantize_binary(&inputs[1])?;
let binary_result =
BitNetKernel::execute_binary_elementwise(&[a_binary, b_binary], op_type)?;
let result = self.quantizer.dequantize_binary(&binary_result)?;
vec![result]
}
BitNetOperation::QuantizedBatchNorm {
channels,
spatial_size,
} => {
if inputs.len() < 3 {
return Err(anyhow!("Quantized BatchNorm requires at least 3 inputs"));
}
vec![inputs[0].clone()]
}
};
let execution_time = start_time.elapsed().as_micros() as f64;
Ok(results)
}
fn get_memory_usage(&self) -> MemoryUsage {
self.memory_usage.clone()
}
fn get_performance_stats(&self) -> KernelStats {
self.stats.clone()
}
}
pub fn create_binary_matmul_kernel(m: usize, n: usize, k: usize) -> BitNetKernel {
BitNetKernel::new(
BitNetOperation::BinaryMatMul { m, n, k },
QuantizationMethod::Binary,
)
}
pub fn create_ternary_matmul_kernel(m: usize, n: usize, k: usize) -> BitNetKernel {
BitNetKernel::new(
BitNetOperation::TernaryMatMul { m, n, k },
QuantizationMethod::Ternary,
)
}
pub fn create_binary_elementwise_kernel(op_type: &str, element_count: usize) -> BitNetKernel {
BitNetKernel::new(
BitNetOperation::BinaryElementWise {
op_type: op_type.to_string(),
element_count,
},
QuantizationMethod::Binary,
)
}
#[cfg(test)]
mod tests {
use super::*;
use ronn_core::{DataType, TensorLayout};
#[test]
fn test_binary_matmul_kernel() -> Result<()> {
let mut kernel = create_binary_matmul_kernel(2, 3, 4);
let a_data = vec![1.0, -1.0, 1.0, 1.0, -1.0, 1.0, -1.0, -1.0]; let b_data = vec![
1.0, -1.0, 1.0, 1.0, -1.0, 1.0, -1.0, 1.0, 1.0, -1.0, -1.0, 1.0,
];
let a = Tensor::from_data(a_data, vec![2, 4], DataType::F32, TensorLayout::RowMajor)?;
let b = Tensor::from_data(b_data, vec![4, 3], DataType::F32, TensorLayout::RowMajor)?;
let results = kernel.execute(&[a, b])?;
assert_eq!(results.len(), 1);
assert_eq!(results[0].shape(), &[2, 3]);
Ok(())
}
#[test]
fn test_ternary_matmul_kernel() -> Result<()> {
let mut kernel = create_ternary_matmul_kernel(2, 2, 3);
let a_data = vec![2.0, -1.5, 0.1, -0.05, 1.8, -2.1]; let b_data = vec![1.0, -2.0, 0.05, -1.5, 2.5, 0.1];
let a = Tensor::from_data(a_data, vec![2, 3], DataType::F32, TensorLayout::RowMajor)?;
let b = Tensor::from_data(b_data, vec![3, 2], DataType::F32, TensorLayout::RowMajor)?;
let results = kernel.execute(&[a, b])?;
assert_eq!(results.len(), 1);
assert_eq!(results[0].shape(), &[2, 2]);
Ok(())
}
#[test]
fn test_binary_elementwise_kernel() -> Result<()> {
let mut kernel = create_binary_elementwise_kernel("Xor", 4);
let a_data = vec![1.0, -1.0, 1.0, -1.0];
let b_data = vec![1.0, 1.0, -1.0, -1.0];
let a = Tensor::from_data(a_data, vec![4], DataType::F32, TensorLayout::RowMajor)?;
let b = Tensor::from_data(b_data, vec![4], DataType::F32, TensorLayout::RowMajor)?;
let results = kernel.execute(&[a, b])?;
assert_eq!(results.len(), 1);
assert_eq!(results[0].shape(), &[4]);
Ok(())
}
#[test]
fn test_kernel_stats_tracking() -> Result<()> {
let kernel = create_binary_matmul_kernel(2, 2, 2);
let stats = kernel.get_performance_stats();
assert_eq!(stats.execution_count, 0);
assert_eq!(stats.average_time_us, 0.0);
let memory_usage = kernel.get_memory_usage();
assert_eq!(memory_usage.current_bytes, 0);
Ok(())
}
}