use core::cmp::Ordering;
use crate::errors::CodecError;
/// Quantizes each value in `values` to the index of its nearest entry in the codebook
/// `entries`, writing the indices into `indices` (element-wise, same order).
///
/// `entries` must be sorted in ascending order. The nearest entry is found by
/// binary-searching for the first entry not below the value, then comparing it with its
/// left neighbour. For an unsorted codebook the resulting indices are unspecified but
/// always in bounds.
///
/// Behaviour details:
/// - When a value is exactly equidistant from its two neighbours, the higher index wins.
/// - A NaN value maps to index 0, because `f32_cmp` never reports NaN as `Less`.
/// - Only indices `0..=255` fit in a `u8`. A chosen index above 255 saturates to
///   `u8::MAX`, so codebooks longer than 256 entries cannot be fully addressed.
/// - An empty codebook leaves `indices` untouched and returns `Ok(())`. The length check
///   below still runs first.
///
/// # Errors
/// Returns [`CodecError::LengthMismatch`] when `values` and `indices` differ in length.
pub fn scalar_quantize(
    entries: &[f32],
    values: &[f32],
    indices: &mut [u8],
) -> Result<(), CodecError> {
    if values.len() != indices.len() {
        return Err(CodecError::LengthMismatch {
            left: values.len(),
            right: indices.len(),
        });
    }
    let n = entries.len();
    if n == 0 {
        return Ok(());
    }
    let last = n - 1;
    for (value, slot) in values.iter().zip(indices.iter_mut()) {
        // First index whose entry is not strictly below `value`. On a sorted codebook this
        // is the same index a linear `position` scan would return, found in O(log n).
        let insertion = entries.partition_point(|e| f32_cmp(*e, *value) == Ordering::Less);
        // Clamp so a value above every entry compares the last two entries.
        let right = insertion.min(last);
        let left = right.saturating_sub(1);
        // Both lookups are in bounds (left <= right <= last < n). The NaN fallback only
        // keeps this loop free of panicking index operations.
        let right_entry = entries.get(right).copied().unwrap_or(f32::NAN);
        let left_entry = entries.get(left).copied().unwrap_or(f32::NAN);
        let ld = libm::fabsf(*value - left_entry);
        let rd = libm::fabsf(*value - right_entry);
        // Strict `<`: ties (and NaN distances) fall through to the right-hand entry.
        let chosen = if ld < rd { left } else { right };
        *slot = u8::try_from(chosen).unwrap_or(u8::MAX);
    }
    Ok(())
}
/// Reconstructs values from codebook indices: `values[i]` becomes `entries[indices[i]]`.
///
/// # Errors
/// - [`CodecError::LengthMismatch`] when `indices` and `values` differ in length. Nothing
///   is written in that case.
/// - [`CodecError::IndexOutOfRange`] for the first index that is not below
///   `entries.len()`. The reported `bound` is that length, saturated to `u32::MAX`. Slots
///   before the offending position have already been overwritten when this is returned.
pub fn scalar_dequantize(
    entries: &[f32],
    indices: &[u8],
    values: &mut [f32],
) -> Result<(), CodecError> {
    let (index_count, value_count) = (indices.len(), values.len());
    if index_count != value_count {
        return Err(CodecError::LengthMismatch {
            left: index_count,
            right: value_count,
        });
    }
    // Saturate instead of truncating for a (hypothetical) codebook longer than u32::MAX.
    let bound = match u32::try_from(entries.len()) {
        Ok(len) => len,
        Err(_) => u32::MAX,
    };
    for (slot, &index) in values.iter_mut().zip(indices) {
        if u32::from(index) < bound {
            // In bounds because index < bound <= entries.len(). The NaN fallback only
            // avoids a panicking index operation.
            *slot = entries.get(usize::from(index)).copied().unwrap_or(f32::NAN);
        } else {
            return Err(CodecError::IndexOutOfRange { index, bound });
        }
    }
    Ok(())
}
/// Orders two `f32`s via `partial_cmp`.
///
/// When the pair is unordered (at least one operand is NaN) the result is
/// `Ordering::Greater`, whichever side the NaN is on. NaN is therefore never reported as
/// `Less` or `Equal`.
fn f32_cmp(a: f32, b: f32) -> Ordering {
    match a.partial_cmp(&b) {
        Some(ordering) => ordering,
        None => Ordering::Greater,
    }
}