use std::f32::consts::PI;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use crate::{
bitpack,
error::{Result, TurboQuantError},
rotation::{Rotation, RotationBackend, RotationKind},
};
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct PolarCode {
pub dim: usize,
pub bits: u8,
pub radii: Vec<f32>,
pub angle_indices: Vec<u16>,
}
impl PolarCode {
pub fn pair_count(&self) -> usize {
self.dim / 2
}
pub fn from_parts(
dim: usize,
bits: u8,
radii: Vec<f32>,
angle_indices: &[u16],
) -> Result<Self> {
let code = Self {
dim,
bits,
radii,
angle_indices: angle_indices.to_vec(),
};
code.validate_for(dim, bits)?;
Ok(code)
}
pub fn angle_index(&self, i: usize) -> Result<u16> {
if i >= self.pair_count() {
return Err(TurboQuantError::MalformedCode {
reason: format!(
"angle index {i} is outside pair count {}",
self.pair_count()
),
});
}
Ok(self.angle_indices[i])
}
pub fn angle_indices(&self) -> Result<Vec<u16>> {
self.validate_for(self.dim, self.bits)?;
Ok(self.angle_indices.clone())
}
pub fn dequantize_angle(&self, i: usize) -> Result<f32> {
let levels = 1u32 << self.bits;
let idx = self.angle_index(i)? as f32;
Ok((idx / levels as f32) * (2.0 * PI) - PI)
}
pub fn encoded_bytes(&self) -> usize {
self.radii.len() * std::mem::size_of::<f32>()
+ bitpack::packed_len(self.angle_indices.len(), self.bits).unwrap_or(usize::MAX)
}
pub fn validate_for(&self, dim: usize, bits: u8) -> Result<()> {
if self.dim != dim {
return Err(TurboQuantError::DimensionMismatch {
expected: dim,
got: self.dim,
});
}
if self.bits != bits {
return Err(TurboQuantError::MalformedCode {
reason: format!("code has bits={}, expected {bits}", self.bits),
});
}
if dim == 0 || dim % 2 != 0 {
return Err(TurboQuantError::MalformedCode {
reason: format!("code dimension must be positive and even, got {dim}"),
});
}
let pairs = dim / 2;
if self.radii.len() != pairs {
return Err(TurboQuantError::MalformedCode {
reason: format!("code has {} radii, expected {pairs}", self.radii.len()),
});
}
for (index, radius) in self.radii.iter().enumerate() {
if !radius.is_finite() || *radius < 0.0 {
return Err(TurboQuantError::MalformedCode {
reason: format!("radius {index} is not finite and non-negative"),
});
}
}
if self.angle_indices.len() != pairs {
return Err(TurboQuantError::MalformedCode {
reason: format!(
"code has {} angle indices, expected {pairs}",
self.angle_indices.len()
),
});
}
let levels = 1u32 << bits;
for (index, angle_index) in self.angle_indices.iter().enumerate() {
if u32::from(*angle_index) >= levels {
return Err(TurboQuantError::MalformedCode {
reason: format!(
"angle index {index} value {angle_index} is outside [0, {levels})"
),
});
}
}
Ok(())
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PolarQuantizer {
dim: usize,
bits: u8,
rotation: RotationBackend,
}
#[derive(Debug, Clone, PartialEq)]
pub struct PolarProjectedQuery {
rotated_query: Vec<f32>,
}
impl PolarQuantizer {
pub fn new(dim: usize, bits: u8, seed: u64) -> Result<Self> {
if dim == 0 {
return Err(TurboQuantError::ZeroDimension);
}
if dim % 2 != 0 {
return Err(TurboQuantError::OddDimension { got: dim });
}
if bits == 0 || bits > 16 {
return Err(TurboQuantError::InvalidBitWidth { got: bits });
}
Self::new_with_rotation(dim, bits, seed, RotationKind::Auto)
}
pub fn new_with_rotation(
dim: usize,
bits: u8,
seed: u64,
rotation_kind: RotationKind,
) -> Result<Self> {
if dim == 0 {
return Err(TurboQuantError::ZeroDimension);
}
if dim % 2 != 0 {
return Err(TurboQuantError::OddDimension { got: dim });
}
if bits == 0 || bits > 16 {
return Err(TurboQuantError::InvalidBitWidth { got: bits });
}
let rotation = RotationBackend::new(dim, seed, rotation_kind)?;
Ok(Self {
dim,
bits,
rotation,
})
}
pub fn new_with_stored_rotation(dim: usize, bits: u8, seed: u64) -> Result<Self> {
Self::new_with_rotation(dim, bits, seed, RotationKind::StoredQr)
}
pub fn dim(&self) -> usize {
self.dim
}
pub fn bits(&self) -> u8 {
self.bits
}
pub fn rotation_kind(&self) -> RotationKind {
self.rotation.kind()
}
pub fn rotation_kind_label(&self) -> &'static str {
self.rotation.kind_label()
}
pub fn encode(&self, vector: &[f32]) -> Result<PolarCode> {
self.check_input_dim(vector.len())?;
check_finite_vector(vector)?;
let mut rotated = vec![0.0f32; self.dim];
self.rotation.apply(vector, &mut rotated)?;
let pairs = self.dim / 2;
let mut radii = Vec::with_capacity(pairs);
let mut angle_indices = Vec::with_capacity(pairs);
for i in 0..pairs {
let a = rotated[2 * i];
let b = rotated[2 * i + 1];
let (r, idx) = encode_pair(a, b, self.bits);
radii.push(r);
angle_indices.push(idx);
}
PolarCode::from_parts(self.dim, self.bits, radii, &angle_indices)
}
pub fn decode(&self, code: &PolarCode) -> Result<Vec<f32>> {
self.validate_code(code)?;
let mut rotated = vec![0.0f32; self.dim];
let pairs = self.dim / 2;
for i in 0..pairs {
let theta = code.dequantize_angle(i)?;
let r = code.radii[i];
rotated[2 * i] = r * theta.cos();
rotated[2 * i + 1] = r * theta.sin();
}
let mut output = vec![0.0f32; self.dim];
self.rotation.apply_inverse(&rotated, &mut output)?;
Ok(output)
}
pub fn inner_product_estimate(&self, code: &PolarCode, query: &[f32]) -> Result<f32> {
let projected = self.project_query(query)?;
self.inner_product_estimate_with_projected_query(code, &projected)
}
pub fn project_query(&self, query: &[f32]) -> Result<PolarProjectedQuery> {
self.check_input_dim(query.len())?;
check_finite_vector(query)?;
let mut rotated_query = vec![0.0f32; self.dim];
self.rotation.apply(query, &mut rotated_query)?;
check_finite_vector(&rotated_query)?;
Ok(PolarProjectedQuery { rotated_query })
}
pub fn inner_product_estimate_with_projected_query(
&self,
code: &PolarCode,
query: &PolarProjectedQuery,
) -> Result<f32> {
self.validate_code(code)?;
let pairs = self.dim / 2;
let mut estimate = 0.0f32;
for i in 0..pairs {
let theta = code.dequantize_angle(i)?;
let r = code.radii[i];
let q_a = query.rotated_query[2 * i];
let q_b = query.rotated_query[2 * i + 1];
estimate += r * (q_a * theta.cos() + q_b * theta.sin());
}
if !estimate.is_finite() {
return Err(TurboQuantError::MalformedCode {
reason: "polar score is not finite".into(),
});
}
Ok(estimate)
}
pub fn l2_distance_estimate(&self, code: &PolarCode, query: &[f32]) -> Result<f32> {
let ip = self.inner_product_estimate(code, query)?;
let query_norm_sq: f32 = query.iter().map(|x| x * x).sum();
let code_norm_sq: f32 = code.radii.iter().map(|r| r * r).sum();
Ok((query_norm_sq + code_norm_sq - 2.0 * ip).max(0.0))
}
fn check_input_dim(&self, got: usize) -> Result<()> {
if got != self.dim {
return Err(TurboQuantError::DimensionMismatch {
expected: self.dim,
got,
});
}
Ok(())
}
fn validate_code(&self, code: &PolarCode) -> Result<()> {
code.validate_for(self.dim, self.bits)
}
}
fn check_finite_vector(vector: &[f32]) -> Result<()> {
if let Some((index, _)) = vector
.iter()
.enumerate()
.find(|(_, value)| !value.is_finite())
{
return Err(TurboQuantError::NonFiniteInput { index });
}
Ok(())
}
fn encode_pair(a: f32, b: f32, bits: u8) -> (f32, u16) {
let r = (a * a + b * b).sqrt();
let theta = b.atan2(a); let levels = 1u32 << bits;
let normalized = (theta + PI) / (2.0 * PI);
let idx = (normalized * levels as f32).floor() as u32 % levels;
(r, idx as u16)
}
#[cfg(test)]
mod tests {
use super::*;
fn unit_vector(dim: usize, i: usize) -> Vec<f32> {
let mut v = vec![0.0f32; dim];
v[i] = 1.0;
v
}
fn random_vector(dim: usize, seed: u64) -> Vec<f32> {
use rand::SeedableRng;
use rand_chacha::ChaCha8Rng;
use rand_distr::{Distribution, StandardNormal};
let mut rng = ChaCha8Rng::seed_from_u64(seed);
(0..dim).map(|_| StandardNormal.sample(&mut rng)).collect()
}
#[test]
fn encode_decode_roundtrip_high_bits() {
let q = PolarQuantizer::new(8, 16, 42).unwrap();
let x = vec![1.0f32, 2.0, -1.5, 0.5, 3.0, -2.0, 0.1, -0.8];
let code = q.encode(&x).unwrap();
let decoded = q.decode(&code).unwrap();
for (orig, dec) in x.iter().zip(decoded.iter()) {
assert!(
(orig - dec).abs() < 0.01,
"orig={orig:.4}, decoded={dec:.4}"
);
}
}
#[test]
fn inner_product_estimate_is_close_at_high_bits() {
let q = PolarQuantizer::new(16, 16, 7).unwrap();
let x = random_vector(16, 1);
let y = random_vector(16, 2);
let code = q.encode(&x).unwrap();
let estimated = q.inner_product_estimate(&code, &y).unwrap();
let exact: f32 = x.iter().zip(y.iter()).map(|(a, b)| a * b).sum();
let relative_error = (estimated - exact).abs() / (exact.abs() + 1e-6);
assert!(
relative_error < 0.02,
"relative error {relative_error:.4} too large: estimated={estimated:.4}, exact={exact:.4}"
);
}
#[test]
fn encoding_is_deterministic() {
let q = PolarQuantizer::new(8, 8, 0).unwrap();
let x = vec![1.0f32; 8];
let c1 = q.encode(&x).unwrap();
let c2 = q.encode(&x).unwrap();
assert_eq!(c1.angle_indices, c2.angle_indices);
assert_eq!(c1.radii, c2.radii);
}
#[test]
fn zero_vector_has_zero_radius() {
let q = PolarQuantizer::new(8, 8, 1).unwrap();
let x = vec![0.0f32; 8];
let code = q.encode(&x).unwrap();
for r in &code.radii {
assert!(*r < 1e-7, "expected zero radius, got {r}");
}
}
#[test]
fn unit_vectors_preserve_norm() {
let q = PolarQuantizer::new(8, 16, 3).unwrap();
for i in 0..8 {
let x = unit_vector(8, i);
let code = q.encode(&x).unwrap();
let norm_sq: f32 = code.radii.iter().map(|r| r * r).sum();
assert!((norm_sq - 1.0).abs() < 1e-5, "norm_sq={norm_sq}");
}
}
#[test]
fn nearest_neighbor_ordering_preserved_at_8bits() {
let q = PolarQuantizer::new(16, 8, 42).unwrap();
let query = random_vector(16, 99);
let close = {
let mut v = query.clone();
v.iter_mut().for_each(|x| *x += 0.01);
v
};
let far1 = random_vector(16, 200);
let far2 = random_vector(16, 201);
let code_close = q.encode(&close).unwrap();
let code_far1 = q.encode(&far1).unwrap();
let code_far2 = q.encode(&far2).unwrap();
let ip_close = q.inner_product_estimate(&code_close, &query).unwrap();
let ip_far1 = q.inner_product_estimate(&code_far1, &query).unwrap();
let ip_far2 = q.inner_product_estimate(&code_far2, &query).unwrap();
assert!(
ip_close > ip_far1 && ip_close > ip_far2,
"close={ip_close:.3}, far1={ip_far1:.3}, far2={ip_far2:.3}"
);
}
#[test]
fn dimension_mismatch_is_rejected() {
let q = PolarQuantizer::new(8, 8, 0).unwrap();
let result = q.encode(&[1.0f32; 10]);
assert!(result.is_err());
}
#[test]
fn odd_dimension_is_rejected() {
assert!(PolarQuantizer::new(7, 8, 0).is_err());
}
#[test]
fn zero_bits_rejected() {
assert!(PolarQuantizer::new(8, 0, 0).is_err());
}
#[test]
fn code_serialization_roundtrip() {
let q = PolarQuantizer::new(8, 8, 42).unwrap();
let x = vec![1.0f32, -2.0, 0.5, 1.5, -0.3, 0.8, -1.0, 2.0];
let code = q.encode(&x).unwrap();
let json = serde_json::to_string(&code).unwrap();
let restored: PolarCode = serde_json::from_str(&json).unwrap();
assert_eq!(code, restored);
}
}