use candle_core::{DType, Device, Result, Tensor};
/// Applies the inverse butterfly Walsh–Hadamard rotation to each row of `dequant` on the CPU.
///
/// `dequant` must be 2-D with shape `(m, block_size)`; each row is processed as one block `y`.
/// Row 0 of `rotation_fwd` (converted to `F32`) supplies a per-lane scale vector `r`. For each
/// block the function computes:
///
/// 1. `y[j] *= r[j] * sqrt(block_size)`
/// 2. an unnormalized in-place fast Walsh–Hadamard butterfly
/// 3. every lane `*= 1 / sqrt(block_size)`
///
/// NOTE(review): the `* sqrt(block_size)` rescale suggests row 0 holds `±1 / sqrt(block_size)`
/// (i.e. `rotation_fwd = H·diag(d)/sqrt(n)`), making this the inverse of a row-vector product
/// `x @ rotation_fwd`. Confirm against the forward path.
///
/// The result is always an `F32` tensor of shape `(m, block_size)` on `Device::Cpu`, regardless
/// of the input dtype or device.
///
/// # Errors
/// Returns an error when:
/// - `dequant` is not 2-D, or its second dimension differs from `block_size`;
/// - `block_size` is zero or not a power of two (the butterfly needs `2^k` lanes; previously
///   these inputs panicked instead of returning an error);
/// - row 0 of `rotation_fwd` has fewer than `block_size` entries;
/// - any underlying candle operation (narrow / squeeze / dtype conversion / copy-out) fails.
pub fn butterfly_wht_inverse_cpu(
    dequant: &Tensor,
    rotation_fwd: &Tensor,
    block_size: usize,
) -> Result<Tensor> {
    let (m, bs) = dequant.dims2()?;
    if bs != block_size {
        candle_core::bail!(
            "butterfly_wht_inverse_cpu: block_size mismatch (tensor={bs}, expected={block_size})"
        );
    }
    // `is_power_of_two` is false for 0, so this also prevents the `chunks_exact_mut(0)` panic
    // and the out-of-bounds butterfly indexing for sizes like 3 or 6.
    if !block_size.is_power_of_two() {
        candle_core::bail!(
            "butterfly_wht_inverse_cpu: block_size must be a non-zero power of two (got {block_size})"
        );
    }

    let sqrt_n = (block_size as f32).sqrt();
    let inv_sqrt_n = 1.0 / sqrt_n;

    let signs: Vec<f32> = rotation_fwd
        .narrow(0, 0, 1)?
        .squeeze(0)?
        .to_dtype(DType::F32)?
        .to_vec1()?;
    if signs.len() < block_size {
        candle_core::bail!(
            "butterfly_wht_inverse_cpu: rotation_fwd row 0 has {} entries, need at least {block_size}",
            signs.len()
        );
    }
    // Hoisted out of the per-block loop: same `signs[j] * sqrt_n` product as before.
    let lane_scale: Vec<f32> = signs[..block_size].iter().map(|&s| s * sqrt_n).collect();

    let mut data: Vec<f32> = dequant.to_dtype(DType::F32)?.flatten_all()?.to_vec1()?;
    // `data.len() == m * block_size`, so `chunks_exact_mut` leaves no remainder.
    for block in data.chunks_exact_mut(block_size) {
        for (v, &scale) in block.iter_mut().zip(&lane_scale) {
            *v *= scale;
        }
        hadamard_butterfly_in_place(block);
        for v in block.iter_mut() {
            *v *= inv_sqrt_n;
        }
    }
    Tensor::from_vec(data, (m, block_size), &Device::Cpu)
}

/// Unnormalized in-place fast Walsh–Hadamard transform; `block.len()` must be a power of two.
///
/// At each stage with half-width `h`, every `2h`-wide span is split into `(lo, hi)` halves and
/// each pair is replaced by `(lo + hi, lo - hi)`.
fn hadamard_butterfly_in_place(block: &mut [f32]) {
    let n = block.len();
    let mut h = 1;
    while h < n {
        for span in block.chunks_exact_mut(h << 1) {
            let (lo, hi) = span.split_at_mut(h);
            for (a, b) in lo.iter_mut().zip(hi.iter_mut()) {
                let (x, y) = (*a, *b);
                *a = x + y;
                *b = x - y;
            }
        }
        h <<= 1;
    }
}