use candle_core::{Result, Tensor};
pub fn bitnet_threshold(weights: &Tensor) -> Result<f64> {
let abs_mean = weights.abs()?.mean_all()?.to_scalar::<f32>()?;
Ok(0.5 * abs_mean as f64)
}
pub fn ternary_project(weights: &Tensor) -> Result<Tensor> {
let threshold = bitnet_threshold(weights)?;
let t = Tensor::new(threshold as f32, weights.device())?
.broadcast_as(weights.shape())?;
let neg_t = Tensor::new(-(threshold as f32), weights.device())?
.broadcast_as(weights.shape())?;
let pos_mask = weights.gt(&t)?.to_dtype(candle_core::DType::F32)?;
let neg_mask = weights.lt(&neg_t)?.to_dtype(candle_core::DType::F32)?;
(pos_mask - neg_mask)
}
/// Straight-through-estimator forward pass.
///
/// Computes `weights + (detach(q) - detach(weights))` where
/// `q = ternary_project(weights)`. Numerically the result matches the ternary
/// projection (up to floating-point rounding), while `weights` itself stays
/// in the expression un-detached.
///
/// # Errors
///
/// Propagates any error raised by the projection or the tensor arithmetic.
pub fn ste_forward(weights: &Tensor) -> Result<Tensor> {
    let ternary = ternary_project(weights)?.detach();
    // Both operands of the correction term are detached; only the bare
    // `weights` operand of the final addition is left attached.
    let residual = (ternary - weights.detach())?;
    weights + residual
}
/// Returns the fraction of elements of `weights` whose magnitude is at most
/// the ternarization threshold, i.e. the elements that `ternary_project`
/// maps to `0`.
///
/// The comparison is inclusive (`|w| <= t`), which complements the strict
/// `>` / `<` comparisons used by `ternary_project`.
///
/// Input that is not `F32` is cast to `F32` first, because the threshold
/// tensor is `F32` and the comparison operands must share a dtype.
///
/// # Errors
///
/// Propagates any error raised by the underlying tensor operations.
pub fn compute_sparsity(weights: &Tensor) -> Result<f64> {
    let f32_dtype = candle_core::DType::F32;
    // For F32 input the cast leaves the values unchanged.
    let w = weights.to_dtype(f32_dtype)?;
    let threshold = bitnet_threshold(&w)? as f32;
    let bound = Tensor::new(threshold, w.device())?.broadcast_as(w.shape())?;

    // NOTE(review): the count is accumulated in f32; for very large tensors
    // this may not represent the exact integer count — confirm whether that
    // matters for callers.
    let zero_count = w
        .abs()?
        .le(&bound)?
        .to_dtype(f32_dtype)?
        .sum_all()?
        .to_scalar::<f32>()?;

    let total = weights.elem_count() as f64;
    Ok(f64::from(zero_count) / total)
}
#[cfg(test)]
mod tests {
    use super::*;
    use candle_core::{Device, Tensor};

    /// Large-magnitude weights map to ±1; everything inside the dead zone maps to 0.
    #[test]
    fn ternary_project_correct() {
        let device = Device::Cpu;
        let w = Tensor::new(&[-2.0f32, -0.01, 0.0, 0.01, 2.0], &device).unwrap();
        let q = ternary_project(&w).unwrap();
        let v = q.to_vec1::<f32>().unwrap();
        assert_eq!(v[0], -1.0, "large negative → -1");
        assert_eq!(v[4], 1.0, "large positive → +1");
        assert_eq!(v[1], 0.0, "near-zero → 0 (hold)");
        assert_eq!(v[2], 0.0, "zero → 0 (hold)");
        assert_eq!(v[3], 0.0, "near-zero → 0 (hold)");
    }

    /// The STE forward value must agree element-wise with the ternary projection.
    #[test]
    fn ste_forward_matches_quantized() {
        let device = Device::Cpu;
        let w = Tensor::new(&[-1.5f32, 0.0, 0.05, 1.5], &device).unwrap();
        let ste = ste_forward(&w).unwrap();
        let q = ternary_project(&w).unwrap();
        let ste_v = ste.to_vec1::<f32>().unwrap();
        let q_v = q.to_vec1::<f32>().unwrap();
        assert_eq!(ste_v.len(), q_v.len(), "STE output length differs from projection");
        for (s, p) in ste_v.iter().zip(q_v.iter()) {
            assert!((s - p).abs() < 1e-6, "STE forward ≠ quantized: {} ≠ {}", s, p);
        }
    }

    /// The threshold is half of the mean absolute weight.
    #[test]
    fn bitnet_threshold_is_half_mean_abs() {
        let device = Device::Cpu;
        let w = Tensor::new(&[-2.0f32, -0.01, 0.0, 0.01, 2.0], &device).unwrap();
        let t = bitnet_threshold(&w).unwrap();
        // mean(|w|) = (2 + 0.01 + 0 + 0.01 + 2) / 5 = 0.804, halved = 0.402.
        assert!((t - 0.402).abs() < 1e-5, "unexpected threshold: {}", t);
    }

    /// Sparsity is the fraction of weights that fall inside the zero zone.
    #[test]
    fn compute_sparsity_counts_zero_zone() {
        let device = Device::Cpu;
        let w = Tensor::new(&[-2.0f32, -0.01, 0.0, 0.01, 2.0], &device).unwrap();
        let s = compute_sparsity(&w).unwrap();
        // Threshold is ~0.402, so -0.01, 0.0 and 0.01 are in the zero zone: 3 of 5.
        assert!((s - 0.6).abs() < 1e-9, "unexpected sparsity: {}", s);
    }
}