use crate::{CodecId, CompressionError, DecompressError, ExactFallbackAdapter};
use quant_governor::{evaluate, GovernancePolicy, GovernanceRequest};
#[cfg(feature = "fib")]
use fib_quant::{FibCodeV1, FibQuantProfileV1, FibQuantizer};
#[cfg(feature = "turbo")]
use turbo_quant::TurboQuantizer;
/// Describes how the codec for a payload is to be chosen.
///
/// NOTE(review): [`build_adapter`] currently ignores its `CodecDispatch`
/// argument; confirm whether callers expect it to be honoured.
#[derive(Debug, Clone)]
pub enum CodecDispatch<'a> {
    /// Defer to the governor. Presumably `request` is evaluated against
    /// `policy` the way [`select_codec`] does — confirm with consumers.
    Governed {
        /// Governance policy the request is checked against.
        policy: &'a GovernancePolicy,
        /// The governance request describing the payload.
        request: GovernanceRequest,
    },
    /// Use exactly this codec (the name suggests governance is bypassed;
    /// no consumer of this variant is visible in this module).
    Force(CodecId),
}
/// Boxed, thread-shareable decoder handed to [`ExactFallbackAdapter::new`]:
/// maps a codec id plus its compressed bytes to a decoded `T`.
type FallbackDecoder<T> = Box<dyn Fn(CodecId, &[u8]) -> Result<T, DecompressError> + Send + Sync>;
pub fn build_adapter<T>(_dispatch: CodecDispatch) -> ExactFallbackAdapter<T>
where
T: From<Vec<u8>> + Send + Sync + 'static,
{
let fallback_decoder: FallbackDecoder<T> = Box::new(move |codec_id, data| {
match codec_id {
CodecId::Uncompressed => Ok(T::from(data.to_vec())),
#[cfg(feature = "turbo")]
CodecId::TurboQuant => turbo_quant_decode(data).map(T::from),
#[cfg(feature = "fib")]
CodecId::FibQuant => fib_quant_decode(data).map(T::from),
#[cfg(feature = "polar")]
CodecId::Polar => Ok(T::from(data.to_vec())),
#[cfg(feature = "qjl")]
CodecId::Qjl => Ok(T::from(data.to_vec())),
#[cfg(not(any(feature = "turbo", feature = "fib", feature = "polar", feature = "qjl")))]
_ => Err(DecompressError::DecodeFailed(
"No codec features enabled".to_string(),
)),
}
});
ExactFallbackAdapter::new(fallback_decoder)
}
pub fn select_codec(
policy: &GovernancePolicy,
request: GovernanceRequest,
) -> Result<CodecId, quant_governor::error::GovernorError> {
let decision = evaluate(request, policy)?;
Ok(match decision.codec {
quant_governor::CodecProfile::Raw => CodecId::Uncompressed,
quant_governor::CodecProfile::Q8 => CodecId::Uncompressed, quant_governor::CodecProfile::Q4 => CodecId::Uncompressed, quant_governor::CodecProfile::Turbo => CodecId::TurboQuant,
quant_governor::CodecProfile::Fib => CodecId::FibQuant,
quant_governor::CodecProfile::Polar => CodecId::Polar,
quant_governor::CodecProfile::Qjl => CodecId::Qjl,
})
}
/// Builds the FibQuant profile used by this module for a `dim`-dimensional
/// vector, via `FibQuantProfileV1::paper_default` with fixed `k = 4`, `n = 32`.
///
/// NOTE(review): the meaning of `k` and `n` is defined by `fib_quant`; they are
/// passed through unchanged here.
///
/// # Errors
///
/// Returns whatever `FibQuantError` `paper_default` reports.
#[cfg(feature = "fib")]
pub fn fib_quant_profile(
    dim: usize,
    seed: u64,
) -> std::result::Result<FibQuantProfileV1, fib_quant::FibQuantError> {
    const K: usize = 4;
    const N: usize = 32;

    FibQuantProfileV1::paper_default(dim, K, N, seed)
}
/// Builds the TurboQuant quantizer used for every TurboQuant encode in this
/// module, with fixed settings `8` and `32` and the caller's `seed`.
///
/// NOTE(review): by analogy with `new_with_mode` in the decoder, these are
/// presumably the bit width and QJL projection count — confirm against
/// `TurboQuantizer::new`.
///
/// # Errors
///
/// Returns whatever `TurboQuantError` `TurboQuantizer::new` reports.
#[cfg(feature = "turbo")]
pub fn turbo_quant_quantizer(
    dim: usize,
    seed: u64,
) -> std::result::Result<TurboQuantizer, turbo_quant::TurboQuantError> {
    let bits = 8;
    let projections = 32;
    TurboQuantizer::new(dim, bits, projections, seed)
}
/// Encodes `vector` with the codec identified by `codec_id`.
///
/// [`CodecId::Uncompressed`] stores the vector's raw native-endian `f32` bytes
/// and ignores `seed`. The feature-gated codecs (FibQuant, TurboQuant, Polar,
/// QJL) forward `seed` to their quantizer.
///
/// # Errors
///
/// Returns [`CompressionError::EncodeFailed`] if the requested codec's cargo
/// feature is not enabled in this build, or if the underlying quantizer fails.
#[cfg_attr(
    not(any(feature = "turbo", feature = "fib", feature = "polar", feature = "qjl")),
    allow(unused_variables) // `seed` is only consumed by feature-gated codecs.
)]
pub fn encode(codec_id: CodecId, vector: &[f32], seed: u64) -> Result<Vec<u8>, CompressionError> {
    match codec_id {
        CodecId::Uncompressed => Ok(bytemuck::cast_slice::<f32, u8>(vector).to_vec()),
        #[cfg(feature = "fib")]
        CodecId::FibQuant => fib_quant_encode(vector, seed),
        #[cfg(feature = "turbo")]
        CodecId::TurboQuant => turbo_quant_encode(vector, seed),
        #[cfg(feature = "polar")]
        CodecId::Polar => polar_quant_encode(vector, seed),
        #[cfg(feature = "qjl")]
        CodecId::Qjl => qjl_sketch_encode(vector, seed),
        // Codecs whose feature is compiled out land here. This arm must be
        // unconditional: gating it on "no features at all" made the match
        // non-exhaustive for partial feature sets. With every feature enabled
        // it is unreachable, hence the allow.
        #[allow(unreachable_patterns)]
        _ => Err(CompressionError::EncodeFailed(format!(
            "codec {codec_id:?} is not enabled in this build"
        ))),
    }
}
/// Decodes `compressed` bytes produced by [`encode`] with the same `codec_id`.
///
/// - [`CodecId::Uncompressed`] returns the bytes unchanged.
/// - FibQuant and TurboQuant reconstruct an approximate vector and return its
///   native-endian `f32` bytes.
/// - Polar and QJL, when enabled, are pass-through: the serialized code is
///   returned as-is and no `f32` reconstruction is performed.
///
/// # Errors
///
/// Returns [`DecompressError::DecodeFailed`] if the requested codec's cargo
/// feature is not enabled in this build, or if the underlying decoder fails.
pub fn decode(codec_id: CodecId, compressed: &[u8]) -> Result<Vec<u8>, DecompressError> {
    match codec_id {
        CodecId::Uncompressed => Ok(compressed.to_vec()),
        #[cfg(feature = "fib")]
        CodecId::FibQuant => fib_quant_decode(compressed),
        #[cfg(feature = "turbo")]
        CodecId::TurboQuant => turbo_quant_decode(compressed),
        #[cfg(feature = "polar")]
        CodecId::Polar => Ok(compressed.to_vec()),
        #[cfg(feature = "qjl")]
        CodecId::Qjl => Ok(compressed.to_vec()),
        // Codecs whose feature is compiled out land here. This arm must be
        // unconditional: gating it on "no features at all" made the match
        // non-exhaustive for partial feature sets. With every feature enabled
        // it is unreachable, hence the allow.
        #[allow(unreachable_patterns)]
        _ => Err(DecompressError::DecodeFailed(format!(
            "codec {codec_id:?} is not enabled in this build"
        ))),
    }
}
/// Quantizes `vector` with FibQuant and serializes the resulting code as JSON.
///
/// The profile is derived from the vector's length and `seed` through
/// [`fib_quant_profile`].
///
/// # Errors
///
/// Returns [`CompressionError::EncodeFailed`], prefixed with the failing stage,
/// if building the profile or quantizer, encoding, or serializing fails.
#[cfg(feature = "fib")]
fn fib_quant_encode(vector: &[f32], seed: u64) -> Result<Vec<u8>, CompressionError> {
    fn fail(stage: &str, cause: impl std::fmt::Display) -> CompressionError {
        CompressionError::EncodeFailed(format!("fib_quant {stage}: {cause}"))
    }

    let profile = fib_quant_profile(vector.len(), seed).map_err(|e| fail("profile build", e))?;
    let quantizer = FibQuantizer::new(profile).map_err(|e| fail("quantizer build", e))?;
    let code = quantizer.encode(vector).map_err(|e| fail("encode", e))?;
    serde_json::to_vec(&code).map_err(|e| fail("serialize", e))
}
/// Deserializes a JSON FibQuant code and reconstructs the vector, returning its
/// native-endian `f32` bytes.
///
/// The profile is rebuilt from the code's `ambient_dim` and a fixed seed.
///
/// # Errors
///
/// Returns [`DecompressError::DecodeFailed`], prefixed with the failing stage,
/// if deserializing, rebuilding the profile or quantizer, or decoding fails.
#[cfg(feature = "fib")]
fn fib_quant_decode(compressed: &[u8]) -> Result<Vec<u8>, DecompressError> {
    // NOTE(review): `fib_quant_encode` builds its profile from the caller's
    // seed, but decoding always uses 42. Codes encoded with any other seed are
    // decoded against a different profile, which either errors or yields a
    // wrong reconstruction. The encode seed (or profile) presumably needs to
    // travel with the code — confirm what `FibCodeV1` carries.
    const PROFILE_SEED: u64 = 42;

    fn fail(stage: &str, cause: impl std::fmt::Display) -> DecompressError {
        DecompressError::DecodeFailed(format!("fib_quant {stage}: {cause}"))
    }

    let code: FibCodeV1 = serde_json::from_slice(compressed).map_err(|e| fail("deserialize", e))?;
    let dim = code.ambient_dim as usize;
    let profile = fib_quant_profile(dim, PROFILE_SEED).map_err(|e| fail("profile build", e))?;
    let quantizer = FibQuantizer::new(profile).map_err(|e| fail("quantizer build", e))?;
    let reconstructed = quantizer.decode(&code).map_err(|e| fail("decode", e))?;
    Ok(bytemuck::cast_slice::<f32, u8>(&reconstructed).to_vec())
}
/// Encodes `vector` with TurboQuant into its self-describing wire format.
///
/// # Errors
///
/// Returns [`CompressionError::EncodeFailed`], prefixed with the failing stage,
/// if the quantizer cannot be built or encoding fails.
#[cfg(feature = "turbo")]
fn turbo_quant_encode(vector: &[f32], seed: u64) -> Result<Vec<u8>, CompressionError> {
    fn fail(stage: &str, cause: impl std::fmt::Display) -> CompressionError {
        CompressionError::EncodeFailed(format!("turbo_quant {stage}: {cause}"))
    }

    let quantizer =
        turbo_quant_quantizer(vector.len(), seed).map_err(|e| fail("quantizer build", e))?;
    quantizer.encode_to_bytes(vector).map_err(|e| fail("encode", e))
}
/// Decodes a TurboQuant wire payload back into native-endian `f32` bytes.
///
/// The quantizer is rebuilt entirely from the wire header (dimension, bit
/// width, QJL projection count, seed, mode), so the caller does not need to
/// supply the seed used at encode time.
///
/// # Errors
///
/// Returns [`DecompressError::DecodeFailed`] if the header cannot be parsed or
/// carries an out-of-range bit width, the quantizer cannot be rebuilt, or the
/// payload fails to decode.
#[cfg(feature = "turbo")]
fn turbo_quant_decode(compressed: &[u8]) -> Result<Vec<u8>, DecompressError> {
    use turbo_quant::{TurboCodeWireV1, TurboMode, TurboQuantizer};

    let header = TurboCodeWireV1::parse_header(compressed).map_err(|e| {
        DecompressError::DecodeFailed(format!("turbo_quant header parse: {e}"))
    })?;

    // A payload with QJL sign bits was produced in PolarWithQjl mode, whose
    // quantizer is built with one more bit than the header's `polar_bits`.
    let (mode, bits) = if header.qjl_sign_count > 0 {
        // `polar_bits` is read from untrusted bytes. A plain `+ 1` could overflow,
        // panicking in debug builds and wrapping in release builds.
        let bits = header.polar_bits.checked_add(1).ok_or_else(|| {
            DecompressError::DecodeFailed(format!(
                "turbo_quant header: polar_bits {} out of range",
                header.polar_bits
            ))
        })?;
        (TurboMode::PolarWithQjl, bits)
    } else {
        (TurboMode::PolarOnly, header.polar_bits)
    };

    let quantizer = TurboQuantizer::new_with_mode(
        header.dim,
        bits,
        header.qjl_projections,
        header.seed,
        mode,
    )
    .map_err(|e| DecompressError::DecodeFailed(format!("turbo_quant quantizer rebuild: {e}")))?;

    let code = TurboCodeWireV1::decode(compressed, &quantizer)
        .map_err(|e| DecompressError::DecodeFailed(format!("turbo_quant wire decode: {e}")))?;
    let decoded = quantizer
        .decode_approximate(&code)
        .map_err(|e| DecompressError::DecodeFailed(format!("turbo_quant decode: {e}")))?;
    Ok(bytemuck::cast_slice::<f32, u8>(&decoded).to_vec())
}
/// Encodes `vector` with an 8-bit `PolarQuantizer` (stored-rotation variant)
/// and serializes the code as JSON.
///
/// # Errors
///
/// Returns [`CompressionError::EncodeFailed`], prefixed with the failing stage,
/// if building the quantizer, encoding, or serializing fails.
#[cfg(feature = "polar")]
fn polar_quant_encode(vector: &[f32], seed: u64) -> Result<Vec<u8>, CompressionError> {
    use turbo_quant::PolarQuantizer;

    const BITS: u8 = 8;

    fn fail(stage: &str, cause: impl std::fmt::Display) -> CompressionError {
        CompressionError::EncodeFailed(format!("polar_quant {stage}: {cause}"))
    }

    let quantizer = PolarQuantizer::new_with_stored_rotation(vector.len(), BITS, seed)
        .map_err(|e| fail("build", e))?;
    let code = quantizer.encode(vector).map_err(|e| fail("encode", e))?;
    serde_json::to_vec(&code).map_err(|e| fail("serialize", e))
}
/// Computes a 32-projection QJL sketch of `vector` and serializes it as JSON.
///
/// # Errors
///
/// Returns [`CompressionError::EncodeFailed`], prefixed with the failing stage,
/// if building the quantizer, sketching, or serializing fails.
#[cfg(feature = "qjl")]
fn qjl_sketch_encode(vector: &[f32], seed: u64) -> Result<Vec<u8>, CompressionError> {
    use turbo_quant::QjlQuantizer;

    const PROJECTIONS: usize = 32;

    fn fail(stage: &str, cause: impl std::fmt::Display) -> CompressionError {
        CompressionError::EncodeFailed(format!("qjl_quant {stage}: {cause}"))
    }

    let quantizer =
        QjlQuantizer::new(vector.len(), PROJECTIONS, seed).map_err(|e| fail("build", e))?;
    let sketch = quantizer.sketch(vector).map_err(|e| fail("sketch", e))?;
    serde_json::to_vec(&sketch).map_err(|e| fail("serialize", e))
}
#[cfg(test)]
#[allow(clippy::expect_used)] // `expect` gives readable failure messages in tests.
mod tests {
    use super::*;

    /// Deterministic pseudo-random vector with components roughly in
    /// `[-0.5, 0.5]`, driven by a 64-bit linear congruential generator.
    fn make_vector(dim: usize, seed: u64) -> Vec<f32> {
        let mut s = seed;
        (0..dim)
            .map(|_| {
                s = s.wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407);
                ((s >> 32) as f32 / u32::MAX as f32) - 0.5
            })
            .collect()
    }

    #[test]
    fn uncompressed_round_trip_is_exact() {
        let v = make_vector(128, 42);
        let encoded = encode(CodecId::Uncompressed, &v, 0).unwrap();
        let decoded_bytes = decode(CodecId::Uncompressed, &encoded).unwrap();
        let decoded: &[f32] = bytemuck::cast_slice(&decoded_bytes);
        assert_eq!(v, decoded);
    }

    #[test]
    #[cfg(feature = "fib")]
    fn fib_quant_round_trip_digest_stable() {
        let v = make_vector(128, 42);
        let encoded_a = encode(CodecId::FibQuant, &v, 42).unwrap();
        let encoded_b = encode(CodecId::FibQuant, &v, 42).unwrap();
        assert_eq!(
            encoded_a, encoded_b,
            "fib_quant encode must be deterministic at the same seed"
        );
        let decoded = decode(CodecId::FibQuant, &encoded_a).unwrap();
        let decoded_vec: Vec<f32> = bytemuck::cast_slice(&decoded).to_vec();
        assert_eq!(decoded_vec.len(), v.len());
        assert!(decoded_vec.iter().all(|x| x.is_finite()));
    }

    #[test]
    #[cfg(feature = "turbo")]
    fn turbo_quant_round_trip_reconstructs_approximate_vector() {
        let v = make_vector(128, 7);
        let encoded = encode(CodecId::TurboQuant, &v, 7).expect("turbo encode failed");
        let decoded_bytes = decode(CodecId::TurboQuant, &encoded).expect("turbo decode failed");
        let decoded_vec: Vec<f32> = bytemuck::cast_slice(&decoded_bytes).to_vec();
        assert_eq!(decoded_vec.len(), v.len());
        assert!(decoded_vec.iter().all(|x| x.is_finite()));
    }

    #[test]
    #[cfg(feature = "turbo")]
    fn turbo_quant_round_trip_uses_wire_embedded_profile() {
        let v = make_vector(64, 1);
        let encoded_seed1 = encode(CodecId::TurboQuant, &v, 1).expect("encode seed=1");
        let _decoded = decode(CodecId::TurboQuant, &encoded_seed1)
            .expect("decode with wire-embedded seed must succeed");
        let v_seed99 = make_vector(64, 99);
        let encoded_seed99 = encode(CodecId::TurboQuant, &v_seed99, 99).expect("encode seed=99");
        let _decoded_99 = decode(CodecId::TurboQuant, &encoded_seed99)
            .expect("decode with wire-embedded seed must succeed");
    }

    #[test]
    #[cfg(feature = "fib")]
    fn fib_quant_different_seeds_produce_different_codes() {
        let v = make_vector(128, 42);
        let a = encode(CodecId::FibQuant, &v, 1).unwrap();
        let b = encode(CodecId::FibQuant, &v, 2).unwrap();
        assert_ne!(a, b, "different seeds must produce different codes");
    }

    // NOTE(review): despite its name, this test also accepts `Ok(_)`, so it
    // cannot fail on a successful decode. `fib_quant_decode` rebuilds the
    // profile with a fixed seed (42) while this encodes with seed 1; tighten
    // the assertion once the intended mismatch behaviour is settled.
    #[test]
    #[cfg(feature = "fib")]
    fn fib_quant_profile_digest_mismatch_is_an_error() {
        let v = make_vector(128, 1);
        let encoded = encode(CodecId::FibQuant, &v, 1).unwrap();
        let result = decode(CodecId::FibQuant, &encoded);
        match result {
            Ok(_) => {}
            Err(DecompressError::DecodeFailed(msg)) => {
                assert!(
                    msg.contains("profile digest") || msg.contains("decode"),
                    "unexpected error: {msg}"
                );
            }
            Err(e) => panic!("unexpected error variant: {e:?}"),
        }
    }

    #[test]
    fn encode_uncompressed_forces_identity() {
        let v = make_vector(64, 7);
        let encoded = encode(CodecId::Uncompressed, &v, 99).unwrap();
        let expected: Vec<u8> = bytemuck::cast_slice(&v).to_vec();
        assert_eq!(encoded, expected);
    }

    /// Encoding with a codec whose cargo feature is compiled out must fail
    /// with `EncodeFailed` rather than succeed or fall back silently.
    #[test]
    #[cfg(not(feature = "qjl"))]
    fn encode_unsupported_codec_errors() {
        let v = make_vector(64, 0);
        let result = encode(CodecId::Qjl, &v, 0);
        assert!(
            matches!(result, Err(CompressionError::EncodeFailed(_))),
            "encoding with a disabled codec must return EncodeFailed"
        );
    }

    #[test]
    #[cfg(feature = "polar")]
    fn polar_quant_encode_is_deterministic() {
        let v = make_vector(128, 42);
        let a = encode(CodecId::Polar, &v, 42).unwrap();
        let b = encode(CodecId::Polar, &v, 42).unwrap();
        assert_eq!(a, b, "polar encode must be deterministic at the same seed");
    }

    #[test]
    #[cfg(feature = "polar")]
    fn polar_quant_different_seeds_produce_different_codes() {
        let v = make_vector(128, 42);
        let a = encode(CodecId::Polar, &v, 1).unwrap();
        let b = encode(CodecId::Polar, &v, 2).unwrap();
        assert_ne!(a, b, "different seeds must produce different polar codes");
    }

    #[test]
    #[cfg(feature = "polar")]
    fn polar_quant_decode_is_passthrough() {
        let v = make_vector(64, 7);
        let encoded = encode(CodecId::Polar, &v, 7).unwrap();
        let decoded = decode(CodecId::Polar, &encoded).unwrap();
        assert_eq!(encoded, decoded, "polar decode must be identity");
    }

    #[test]
    #[cfg(feature = "qjl")]
    fn qjl_sketch_encode_is_deterministic() {
        let v = make_vector(128, 42);
        let a = encode(CodecId::Qjl, &v, 42).unwrap();
        let b = encode(CodecId::Qjl, &v, 42).unwrap();
        assert_eq!(a, b, "qjl sketch must be deterministic at the same seed");
        assert!(
            a.len() < 512,
            "qjl sketch ({} bytes) should be smaller than raw (512 bytes)",
            a.len()
        );
    }

    #[test]
    #[cfg(feature = "qjl")]
    fn qjl_sketch_different_seeds_produce_different_codes() {
        let v = make_vector(128, 42);
        let a = encode(CodecId::Qjl, &v, 1).unwrap();
        let b = encode(CodecId::Qjl, &v, 2).unwrap();
        assert_ne!(a, b, "different seeds must produce different qjl sketches");
    }

    #[test]
    #[cfg(feature = "qjl")]
    fn qjl_sketch_decode_is_passthrough() {
        let v = make_vector(64, 7);
        let encoded = encode(CodecId::Qjl, &v, 7).unwrap();
        let decoded = decode(CodecId::Qjl, &encoded).unwrap();
        assert_eq!(encoded, decoded, "qjl decode must be identity");
    }
}