use super::encoding::decode_sortable;
use super::error::ElidError;
use serde::{Deserialize, Serialize};
use std::fmt;
/// Storage precision used for every component of a [`Profile::FullVector`] payload.
///
/// Serialized as an internally tagged object keyed by `"type"` with snake_case variant
/// names, e.g. `{"type":"half16"}` or `{"type":"bits","bits":4}`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum VectorPrecision {
    /// 32 bits per component; this is the `Default`.
    #[default]
    Full32,
    /// 16 bits per component.
    Half16,
    /// 8 bits per component.
    Quant8,
    /// A custom width of `bits` per component. Valid widths are `1..=32`
    /// (checked by [`VectorPrecision::validate`], not by construction).
    Bits { bits: u8 },
}
impl VectorPrecision {
    /// Number of bits used to store one vector component at this precision.
    ///
    /// For [`VectorPrecision::Bits`] the stored width is returned unchanged, even if it is
    /// outside the valid range; call [`VectorPrecision::validate`] to check it.
    #[must_use]
    pub fn bits_per_dim(&self) -> u8 {
        match *self {
            Self::Full32 => 32,
            Self::Half16 => 16,
            Self::Quant8 => 8,
            Self::Bits { bits } => bits,
        }
    }

    /// Checks that a custom bit width lies in `1..=32`.
    ///
    /// The fixed precisions (`Full32`, `Half16`, `Quant8`) always validate.
    ///
    /// # Errors
    ///
    /// Returns [`ElidError::InvalidPrecision`] for `Bits { bits }` with `bits == 0` or
    /// `bits > 32`.
    pub fn validate(&self) -> Result<(), ElidError> {
        let Self::Bits { bits } = *self else {
            return Ok(());
        };
        if (1..=32).contains(&bits) {
            Ok(())
        } else {
            Err(ElidError::InvalidPrecision(format!(
                "Bits must be 1-32, got {}",
                bits
            )))
        }
    }
}
/// How a [`Profile::FullVector`] treats the dimensionality of its input vector.
///
/// Serialized as an internally tagged object keyed by `"mode"` with snake_case variant
/// names, e.g. `{"mode":"preserve"}` or `{"mode":"reduce","target_dims":256}`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
#[serde(tag = "mode", rename_all = "snake_case")]
pub enum DimensionMode {
    /// Keep every input dimension; this is the `Default`.
    #[default]
    Preserve,
    /// Emit `target_dims` dimensions, which must be fewer than the input's.
    Reduce { target_dims: u16 },
    /// Emit a fixed number of dimensions (`dims`) regardless of the input size.
    Common { dims: u16 },
}
impl DimensionMode {
    /// Number of dimensions produced from an `input_dims`-dimensional input.
    ///
    /// `Preserve` passes `input_dims` through; `Reduce` and `Common` return their stored
    /// target and ignore `input_dims`.
    #[must_use]
    pub fn output_dims(&self, input_dims: u16) -> u16 {
        match *self {
            Self::Preserve => input_dims,
            Self::Reduce { target_dims: n } | Self::Common { dims: n } => n,
        }
    }

    /// Checks this mode against an input of `input_dims` dimensions.
    ///
    /// # Errors
    ///
    /// - `Reduce { target_dims: 0 }` → [`ElidError::InvalidDimension`] with expected range
    ///   `(1, input_dims)`.
    /// - `Reduce` with `target_dims >= input_dims` → [`ElidError::ProjectionError`].
    /// - `Common { dims: 0 }` → [`ElidError::InvalidDimension`] with expected range
    ///   `(1, 2048)`. The upper bound is only reported; larger `dims` are accepted.
    pub fn validate(&self, input_dims: u16) -> Result<(), ElidError> {
        match *self {
            Self::Preserve => Ok(()),
            Self::Reduce { target_dims: 0 } => Err(ElidError::InvalidDimension {
                got: 0,
                expected_range: (1, usize::from(input_dims)),
            }),
            Self::Reduce { target_dims } if target_dims >= input_dims => {
                Err(ElidError::ProjectionError(format!(
                    "Target dims {} must be less than input dims {}",
                    target_dims, input_dims
                )))
            }
            Self::Common { dims: 0 } => Err(ElidError::InvalidDimension {
                got: 0,
                expected_range: (1, 2048),
            }),
            Self::Reduce { .. } | Self::Common { .. } => Ok(()),
        }
    }
}
/// An identifier string whose characters are limited to `0-9` and `a-v`.
///
/// Ordering, equality and hashing are those of the underlying `String` (byte-wise
/// lexicographic order).
///
/// NOTE(review): the derived `Deserialize` builds the value directly from the inner string
/// and does not run [`Elid::from_string`]'s alphabet check.
#[derive(Clone, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub struct Elid(String);
impl Elid {
    /// Wraps `s` as an [`Elid`] after checking that every character is in `0-9` or `a-v`.
    ///
    /// That character set is the lowercase RFC 4648 base32hex alphabet. Only the alphabet is
    /// checked, not the length, so the empty string is accepted.
    ///
    /// # Errors
    ///
    /// Returns [`ElidError::InvalidEncoding`] if any character falls outside the alphabet.
    pub fn from_string(s: String) -> Result<Self, ElidError> {
        if s.chars().all(|c| matches!(c, '0'..='9' | 'a'..='v')) {
            Ok(Elid(s))
        } else {
            Err(ElidError::InvalidEncoding)
        }
    }

    /// Borrows the identifier text.
    #[must_use]
    pub fn as_str(&self) -> &str {
        &self.0
    }

    /// Decodes the identifier text into raw bytes via `decode_sortable`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `decode_sortable`.
    pub fn to_bytes(&self) -> Result<Vec<u8>, ElidError> {
        decode_sortable(self.as_str())
    }

    /// Parses the profile header at the start of the decoded bytes.
    ///
    /// The header length depends on the profile type stored in the low nibble of byte 0:
    /// 12 bytes for FullVector (type `0x04`) and 2 bytes otherwise. These match the 96 and
    /// 16 header bits assumed by `Profile::string_length`. Only the header is handed to
    /// `ProfileInfo::from_header`, so payload bytes are never read as header fields. As a
    /// consequence, `transform_id` is always `None` for non-FullVector profiles.
    ///
    /// # Errors
    ///
    /// - Any error from [`Elid::to_bytes`].
    /// - [`ElidError::InvalidHeader`] if the decoded bytes are shorter than the header
    ///   required by their profile type.
    /// - Any error from `ProfileInfo::from_header`.
    pub fn profile(&self) -> Result<ProfileInfo, ElidError> {
        /// Profile type id of `Profile::FullVector`.
        const FULL_VECTOR_TYPE: u8 = 0x04;
        /// Header sizes in bytes (96 and 16 header bits respectively).
        const FULL_VECTOR_HEADER_LEN: usize = 12;
        const BASIC_HEADER_LEN: usize = 2;

        let bytes = self.to_bytes()?;
        let first = *bytes.first().ok_or(ElidError::InvalidHeader)?;
        // A FullVector header needs 12 bytes; slicing only 2 here would make
        // `from_header` reject every FullVector id.
        let header_len = if (first & 0x0F) == FULL_VECTOR_TYPE {
            FULL_VECTOR_HEADER_LEN
        } else {
            BASIC_HEADER_LEN
        };
        let header = bytes.get(..header_len).ok_or(ElidError::InvalidHeader)?;
        ProfileInfo::from_header(header)
    }
}
impl fmt::Debug for Elid {
    // Renders as `Elid("<text>")`. The inner text is written verbatim (not escaped) and the
    // output is identical in `{:?}` and `{:#?}`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("Elid(\"")?;
        f.write_str(&self.0)?;
        f.write_str("\")")
    }
}
impl fmt::Display for Elid {
    // Emits the raw identifier text; width/fill/alignment flags on the formatter are ignored.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(&self.0)
    }
}
/// Encoding profile describing how an embedding is turned into an ELID.
///
/// Serialized as an internally tagged object keyed by `"type"` with snake_case variant
/// names (e.g. `"mini128"`, `"full_vector"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum Profile {
    /// Type id `0x01`: fixed 128-bit payload derived with `seed`.
    Mini128 { seed: u64 },
    /// Type id `0x02`: payload of `dims * bits_per_dim` bits.
    Morton10x10 {
        dims: u8,
        bits_per_dim: u8,
        /// Left out of the serialized form when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        transform_id: Option<u16>,
    },
    /// Type id `0x03`: payload of `dims * bits_per_dim` bits.
    Hilbert10x10 {
        dims: u8,
        bits_per_dim: u8,
        /// Left out of the serialized form when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        transform_id: Option<u16>,
    },
    /// Type id `0x04`: payload sized by `precision` bits for each of the output dimensions
    /// chosen by `dimensions`.
    FullVector {
        precision: VectorPrecision,
        dimensions: DimensionMode,
        seed: u64,
    },
}
impl Profile {
#[must_use]
pub fn bit_length(&self) -> usize {
match self {
Profile::Mini128 { .. } => 128,
Profile::Morton10x10 {
dims, bits_per_dim, ..
}
| Profile::Hilbert10x10 {
dims, bits_per_dim, ..
} => (*dims as usize) * (*bits_per_dim as usize),
Profile::FullVector {
precision,
dimensions,
..
} => {
let output_dims = dimensions.output_dims(768);
(output_dims as usize) * (precision.bits_per_dim() as usize)
}
}
}
#[must_use]
pub fn bit_length_for_dims(&self, input_dims: u16) -> usize {
match self {
Profile::Mini128 { .. } => 128,
Profile::Morton10x10 {
dims, bits_per_dim, ..
}
| Profile::Hilbert10x10 {
dims, bits_per_dim, ..
} => (*dims as usize) * (*bits_per_dim as usize),
Profile::FullVector {
precision,
dimensions,
..
} => {
let output_dims = dimensions.output_dims(input_dims);
(output_dims as usize) * (precision.bits_per_dim() as usize)
}
}
}
#[must_use]
pub fn string_length(&self) -> usize {
let header_bits = match self {
Profile::FullVector { .. } => 96, _ => 16,
};
(self.bit_length() + header_bits).div_ceil(5)
}
#[must_use]
pub fn string_length_for_dims(&self, input_dims: u16) -> usize {
let header_bits = match self {
Profile::FullVector { .. } => 96, _ => 16,
};
(self.bit_length_for_dims(input_dims) + header_bits).div_ceil(5)
}
#[must_use]
pub fn type_id(&self) -> u8 {
match self {
Profile::Mini128 { .. } => 0x01,
Profile::Morton10x10 { .. } => 0x02,
Profile::Hilbert10x10 { .. } => 0x03,
Profile::FullVector { .. } => 0x04,
}
}
#[must_use]
pub fn is_reversible(&self) -> bool {
matches!(self, Profile::FullVector { .. })
}
#[must_use]
pub fn lossless() -> Self {
Profile::FullVector {
precision: VectorPrecision::Full32,
dimensions: DimensionMode::Preserve,
seed: 0x454c4944_46554c4c, }
}
#[must_use]
pub fn compressed(retention_pct: f32, original_dims: u16) -> Self {
let retention = retention_pct.clamp(0.01, 1.0);
let full_bits = 32.0 * original_dims as f32;
let target_bits = full_bits * retention;
let full32_target_dims = (target_bits / 32.0).round() as u16;
if full32_target_dims >= original_dims {
return Profile::lossless();
}
if full32_target_dims >= original_dims / 4 {
return Profile::FullVector {
precision: VectorPrecision::Full32,
dimensions: DimensionMode::Reduce {
target_dims: full32_target_dims.max(1),
},
seed: 0x454c4944_434f4d50, };
}
let half16_target_dims = (target_bits / 16.0).round() as u16;
if half16_target_dims >= original_dims / 4 {
return Profile::FullVector {
precision: VectorPrecision::Half16,
dimensions: if half16_target_dims >= original_dims {
DimensionMode::Preserve
} else {
DimensionMode::Reduce {
target_dims: half16_target_dims.max(1),
}
},
seed: 0x454c4944_434f4d50,
};
}
let quant8_target_dims = (target_bits / 8.0).round() as u16;
Profile::FullVector {
precision: VectorPrecision::Quant8,
dimensions: if quant8_target_dims >= original_dims {
DimensionMode::Preserve
} else {
DimensionMode::Reduce {
target_dims: quant8_target_dims.max(1),
}
},
seed: 0x454c4944_434f4d50,
}
}
#[must_use]
pub fn max_length(max_chars: usize, original_dims: u16) -> Self {
let header_chars = 20;
let payload_chars = max_chars.saturating_sub(header_chars);
let payload_bits = payload_chars * 5;
if payload_bits == 0 {
return Profile::FullVector {
precision: VectorPrecision::Bits { bits: 1 },
dimensions: DimensionMode::Reduce { target_dims: 1 },
seed: 0x454c4944_4d41584c, };
}
let bits_per_dim_full32 = payload_bits / original_dims as usize;
if bits_per_dim_full32 >= 32 {
return Profile::lossless();
}
let precisions = [
(VectorPrecision::Full32, 32),
(VectorPrecision::Half16, 16),
(VectorPrecision::Quant8, 8),
(VectorPrecision::Bits { bits: 4 }, 4),
(VectorPrecision::Bits { bits: 2 }, 2),
(VectorPrecision::Bits { bits: 1 }, 1),
];
for (precision, bits) in precisions {
let dims_that_fit = payload_bits / bits;
if dims_that_fit >= original_dims as usize {
return Profile::FullVector {
precision,
dimensions: DimensionMode::Preserve,
seed: 0x454c4944_4d41584c,
};
} else if dims_that_fit >= 16 {
return Profile::FullVector {
precision,
dimensions: DimensionMode::Reduce {
target_dims: dims_that_fit as u16,
},
seed: 0x454c4944_4d41584c,
};
}
}
Profile::FullVector {
precision: VectorPrecision::Bits { bits: 1 },
dimensions: DimensionMode::Reduce {
target_dims: (payload_bits as u16).max(1),
},
seed: 0x454c4944_4d41584c,
}
}
#[must_use]
pub fn cross_dimensional(common_dims: u16) -> Self {
Profile::FullVector {
precision: VectorPrecision::Half16,
dimensions: DimensionMode::Common { dims: common_dims },
seed: 0x454c4944_58444949, }
}
#[must_use]
pub fn cross_dimensional_with_precision(common_dims: u16, precision: VectorPrecision) -> Self {
Profile::FullVector {
precision,
dimensions: DimensionMode::Common { dims: common_dims },
seed: 0x454c4944_58444949,
}
}
}
impl Default for Profile {
    /// A [`Profile::Mini128`] whose seed bytes spell ASCII `"ELIDSIMH"`.
    fn default() -> Self {
        let seed: u64 = 0x454c_4944_5349_4d48;
        Profile::Mini128 { seed }
    }
}
/// Metadata decoded from (or encodable into) an ELID header.
///
/// Byte 0 packs `version` (high nibble) and `profile_type` (low nibble). Every other
/// field only applies to some profile types and is `None` when unused.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub struct ProfileInfo {
    /// Header format version (high nibble of byte 0).
    pub version: u8,
    /// Profile type id (low nibble of byte 0); see `Profile::type_id`.
    pub profile_type: u8,
    /// Non-FullVector headers only: optional big-endian id stored in bytes 2..4.
    pub transform_id: Option<u16>,
    /// Model id; left as `None` by the constructors and decoders in this module.
    pub model_id: Option<u16>,
    /// FullVector only: dimensionality of the source vector.
    pub original_dims: Option<u16>,
    /// FullVector only: per-component storage precision.
    pub precision: Option<VectorPrecision>,
    /// FullVector only: dimension handling mode.
    pub dimension_mode: Option<DimensionMode>,
    /// FullVector only: seed. A decoded header carries only the low 32 bits.
    pub seed: Option<u64>,
}
impl ProfileInfo {
    /// Profile type id of FullVector headers.
    const FULL_VECTOR_TYPE: u8 = 0x04;
    /// Length in bytes of a FullVector header.
    const FULL_VECTOR_HEADER_LEN: usize = 12;

    /// Decodes a header from `header`.
    ///
    /// Layout:
    /// - Byte 0 is `version << 4 | profile_type`; byte 1 is reserved.
    /// - For FullVector (`profile_type == 0x04`) the 12-byte layout documented on
    ///   [`ProfileInfo::to_header`] is parsed.
    /// - For other types, bytes 2..4 (if present) are a big-endian `transform_id`.
    ///
    /// # Errors
    ///
    /// - [`ElidError::InvalidHeader`] if `header` is shorter than 2 bytes, or shorter than
    ///   12 bytes for FullVector.
    /// - [`ElidError::InvalidMetadata`] for an unknown FullVector precision or dimension
    ///   mode code.
    pub fn from_header(header: &[u8]) -> Result<Self, ElidError> {
        if header.len() < 2 {
            return Err(ElidError::InvalidHeader);
        }
        let version = header[0] >> 4;
        let profile_type = header[0] & 0x0F;
        if profile_type == Self::FULL_VECTOR_TYPE {
            return Self::from_full_vector_header(version, header);
        }
        let transform_id = header
            .get(2..4)
            .map(|tid| u16::from_be_bytes([tid[0], tid[1]]));
        Ok(ProfileInfo {
            version,
            profile_type,
            transform_id,
            model_id: None,
            original_dims: None,
            precision: None,
            dimension_mode: None,
            seed: None,
        })
    }

    /// Parses the 12-byte FullVector header (see [`ProfileInfo::to_header`] for the layout).
    fn from_full_vector_header(version: u8, header: &[u8]) -> Result<Self, ElidError> {
        if header.len() < Self::FULL_VECTOR_HEADER_LEN {
            return Err(ElidError::InvalidHeader);
        }
        let original_dims = u16::from_be_bytes([header[2], header[3]]);
        let precision = match header[4] {
            0 => VectorPrecision::Full32,
            1 => VectorPrecision::Half16,
            2 => VectorPrecision::Quant8,
            // `Bits { bits }` is stored as `bits + 2`. Only 1..=32 bits are valid, so only
            // codes 3..=34 can be produced by `to_header` for a valid precision.
            code @ 3..=34 => VectorPrecision::Bits { bits: code - 2 },
            _ => {
                return Err(ElidError::InvalidMetadata(
                    "Invalid precision type".to_string(),
                ))
            }
        };
        let target_dims = u16::from_be_bytes([header[6], header[7]]);
        let dimension_mode = match header[5] {
            0 => DimensionMode::Preserve,
            1 => DimensionMode::Reduce { target_dims },
            2 => DimensionMode::Common { dims: target_dims },
            _ => {
                return Err(ElidError::InvalidMetadata(
                    "Invalid dimension mode".to_string(),
                ))
            }
        };
        let seed_low = u32::from_be_bytes([header[8], header[9], header[10], header[11]]);
        Ok(ProfileInfo {
            version,
            profile_type: Self::FULL_VECTOR_TYPE,
            transform_id: None,
            model_id: None,
            original_dims: Some(original_dims),
            precision: Some(precision),
            dimension_mode: Some(dimension_mode),
            seed: Some(u64::from(seed_low)),
        })
    }

    /// Wire code for a precision.
    ///
    /// - `Full32`/`Half16`/`Quant8` map to 0/1/2, and a missing precision maps to 0.
    /// - `Bits { bits }` with `bits` in 1..=32 maps to `bits + 2` (3..=34).
    /// - Out-of-range widths map to `u8::MAX`, which the decoder rejects. This avoids
    ///   arithmetic overflow and avoids `Bits { bits: 0 }` colliding with `Quant8`.
    fn precision_code(precision: Option<VectorPrecision>) -> u8 {
        match precision {
            Some(VectorPrecision::Full32) | None => 0,
            Some(VectorPrecision::Half16) => 1,
            Some(VectorPrecision::Quant8) => 2,
            Some(VectorPrecision::Bits { bits: b @ 1..=32 }) => b + 2,
            Some(VectorPrecision::Bits { .. }) => u8::MAX,
        }
    }

    /// Encodes this metadata as header bytes.
    ///
    /// Byte 0 is `version << 4 | (profile_type & 0x0F)`; only the low 4 bits of
    /// `version` survive the shift. Byte 1 is a reserved `0x00`.
    ///
    /// For `profile_type == 0x04` (FullVector) ten more bytes follow:
    /// - `original_dims` (big-endian u16);
    /// - the precision code;
    /// - the dimension mode code (0 Preserve, 1 Reduce, 2 Common);
    /// - the target dims (big-endian u16);
    /// - the low 32 bits of `seed` (big-endian).
    ///
    /// Missing FullVector fields encode as zero. For other profile types, `transform_id`
    /// is appended as a big-endian u16 when present.
    #[must_use]
    pub fn to_header(&self) -> Vec<u8> {
        let mut bytes = Vec::with_capacity(Self::FULL_VECTOR_HEADER_LEN);
        bytes.push((self.version << 4) | (self.profile_type & 0x0F));
        bytes.push(0x00); // reserved

        if self.profile_type == Self::FULL_VECTOR_TYPE {
            bytes.extend_from_slice(&self.original_dims.unwrap_or(0).to_be_bytes());
            bytes.push(Self::precision_code(self.precision));
            let (mode_byte, target_dims) = match self.dimension_mode {
                Some(DimensionMode::Preserve) | None => (0u8, 0u16),
                Some(DimensionMode::Reduce { target_dims }) => (1, target_dims),
                Some(DimensionMode::Common { dims }) => (2, dims),
            };
            bytes.push(mode_byte);
            bytes.extend_from_slice(&target_dims.to_be_bytes());
            // Only the low 32 bits of the seed fit in the header.
            let seed_low = (self.seed.unwrap_or(0) & 0xFFFF_FFFF) as u32;
            bytes.extend_from_slice(&seed_low.to_be_bytes());
        } else if let Some(tid) = self.transform_id {
            bytes.extend_from_slice(&tid.to_be_bytes());
        }
        bytes
    }

    /// Builds FullVector metadata (version 0, type `0x04`) from its components.
    #[must_use]
    pub fn from_full_vector(
        original_dims: u16,
        precision: VectorPrecision,
        dimensions: DimensionMode,
        seed: u64,
    ) -> Self {
        ProfileInfo {
            version: 0,
            profile_type: Self::FULL_VECTOR_TYPE,
            transform_id: None,
            model_id: None,
            original_dims: Some(original_dims),
            precision: Some(precision),
            dimension_mode: Some(dimensions),
            seed: Some(seed),
        }
    }
}
/// A dense vector of `f32` components.
///
/// Built through `Embedding::new` or `Embedding::from_f64`, which require a length of
/// 64..=2048 and only finite components.
#[derive(Debug, Clone)]
pub struct Embedding {
    values: Vec<f32>,
}
impl Embedding {
    /// Smallest accepted number of components.
    const MIN_DIMS: usize = 64;
    /// Largest accepted number of components.
    const MAX_DIMS: usize = 2048;

    /// Creates an embedding from `f32` components after validating them.
    ///
    /// # Errors
    ///
    /// - [`ElidError::InvalidDimension`] if `values.len()` is outside `64..=2048`.
    /// - [`ElidError::InvalidValue`] if any component is NaN or infinite.
    pub fn new(values: Vec<f32>) -> Result<Self, ElidError> {
        Self::validate(&values)?;
        Ok(Embedding { values })
    }

    /// Creates an embedding from `f64` components, narrowing each to `f32`.
    ///
    /// Values that overflow `f32` become infinite when narrowed and are then rejected.
    ///
    /// # Errors
    ///
    /// Same as [`Embedding::new`].
    pub fn from_f64(values: Vec<f64>) -> Result<Self, ElidError> {
        Self::new(values.into_iter().map(|v| v as f32).collect())
    }

    /// Checks the length bounds and that every component is finite.
    fn validate(values: &[f32]) -> Result<(), ElidError> {
        if !(Self::MIN_DIMS..=Self::MAX_DIMS).contains(&values.len()) {
            return Err(ElidError::InvalidDimension {
                got: values.len(),
                expected_range: (Self::MIN_DIMS, Self::MAX_DIMS),
            });
        }
        if values.iter().any(|v| !v.is_finite()) {
            return Err(ElidError::InvalidValue);
        }
        // An all-zero vector is accepted; `normalize` leaves it unchanged.
        Ok(())
    }

    /// Borrows the components.
    #[must_use]
    pub fn as_slice(&self) -> &[f32] {
        &self.values
    }

    /// Scales the vector in place to unit Euclidean (L2) length.
    ///
    /// A vector whose norm is zero is left unchanged.
    pub fn normalize(&mut self) {
        let norm: f32 = self.values.iter().map(|v| v * v).sum::<f32>().sqrt();
        if norm > 0.0 {
            self.values.iter_mut().for_each(|v| *v /= norm);
        }
    }

    /// Number of components.
    #[must_use]
    pub fn dim(&self) -> usize {
        self.values.len()
    }
}
/// Embedding components mapped onto an unsigned integer grid with a fixed bit width.
#[derive(Debug, Clone)]
pub struct QuantizedCoords {
    /// One quantized value per selected dimension, in `0..=(1 << bits) - 1`.
    coords: Vec<u16>,
    /// Bits per coordinate.
    bits: u8,
}
impl QuantizedCoords {
    /// Quantizes the first `dims` components of `embedding` to `bits_per_dim` bits each.
    ///
    /// Each component is mapped linearly from `[-1, 1]` onto `0..=2^bits_per_dim - 1`.
    /// The result is clamped to that range and rounded to the nearest integer, with ties
    /// rounding away from zero.
    ///
    /// # Errors
    ///
    /// Returns [`ElidError::InvalidDimension`] in each of these cases:
    /// - `dims` is outside `1..=32`;
    /// - `bits_per_dim` is outside `1..=16` (the same variant is reused for the bit width);
    /// - `dims` exceeds the embedding's length.
    pub fn from_embedding(
        embedding: &Embedding,
        dims: u8,
        bits_per_dim: u8,
    ) -> Result<Self, ElidError> {
        let wanted = usize::from(dims);
        if !(1..=32).contains(&dims) {
            return Err(ElidError::InvalidDimension {
                got: wanted,
                expected_range: (1, 32),
            });
        }
        if !(1..=16).contains(&bits_per_dim) {
            return Err(ElidError::InvalidDimension {
                got: usize::from(bits_per_dim),
                expected_range: (1, 16),
            });
        }
        let available = embedding.dim();
        if wanted > available {
            return Err(ElidError::InvalidDimension {
                got: wanted,
                expected_range: (1, available),
            });
        }

        let top = ((1u32 << bits_per_dim) - 1) as f32;
        let mut coords = Vec::with_capacity(wanted);
        for &component in &embedding.as_slice()[..wanted] {
            let scaled = (component + 1.0) / 2.0 * top;
            coords.push(scaled.clamp(0.0, top).round() as u16);
        }
        Ok(Self {
            coords,
            bits: bits_per_dim,
        })
    }

    /// Borrows the quantized coordinates.
    #[must_use]
    pub fn as_slice(&self) -> &[u16] {
        self.coords.as_slice()
    }

    /// Bit width of each coordinate.
    #[must_use]
    pub fn bits_per_dim(&self) -> u8 {
        self.bits
    }

    /// Number of coordinates.
    #[must_use]
    pub fn len(&self) -> usize {
        self.coords.len()
    }

    /// `true` when there are no coordinates.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// 128 copies of `fill`: a length accepted by `Embedding::new`.
    fn values_of(fill: f32) -> Vec<f32> {
        vec![fill; 128]
    }

    #[test]
    fn test_elid_valid_string() {
        let text = "0123456789abcdefghijklmnopqrstuv";
        let parsed = Elid::from_string(String::from(text));
        assert!(parsed.is_ok());
        assert_eq!(parsed.unwrap().as_str(), text);
    }

    #[test]
    fn test_elid_invalid_chars() {
        // 'w' is one past the end of the accepted alphabet.
        let parsed = Elid::from_string(String::from("0123456789abcdefghijklmnopqrstuvw"));
        assert!(matches!(parsed, Err(ElidError::InvalidEncoding)));
    }

    #[test]
    fn test_elid_ordering() {
        let parse = |s: &str| Elid::from_string(s.to_owned()).unwrap();
        let (low, high) = (parse("00000000"), parse("00000001"));
        assert!(low < high);
    }

    #[test]
    fn test_profile_default() {
        let expected = Profile::Mini128 {
            seed: 0x454c4944_53494d48,
        };
        assert_eq!(Profile::default(), expected);
    }

    #[test]
    fn test_profile_bit_length() {
        let cases = [
            (Profile::Mini128 { seed: 0 }, 128),
            (
                Profile::Morton10x10 {
                    dims: 10,
                    bits_per_dim: 10,
                    transform_id: None,
                },
                100,
            ),
            (
                Profile::Hilbert10x10 {
                    dims: 8,
                    bits_per_dim: 12,
                    transform_id: None,
                },
                96,
            ),
        ];
        for (profile, bits) in cases {
            assert_eq!(profile.bit_length(), bits, "bit length of {profile:?}");
        }
    }

    #[test]
    fn test_profile_string_length() {
        let cases = [
            (Profile::Mini128 { seed: 0 }, 29),
            (
                Profile::Morton10x10 {
                    dims: 10,
                    bits_per_dim: 10,
                    transform_id: None,
                },
                24,
            ),
        ];
        for (profile, chars) in cases {
            assert_eq!(profile.string_length(), chars, "string length of {profile:?}");
        }
    }

    #[test]
    fn test_profile_type_id() {
        let cases = [
            (Profile::Mini128 { seed: 0 }, 0x01),
            (
                Profile::Morton10x10 {
                    dims: 10,
                    bits_per_dim: 10,
                    transform_id: None,
                },
                0x02,
            ),
            (
                Profile::Hilbert10x10 {
                    dims: 10,
                    bits_per_dim: 10,
                    transform_id: None,
                },
                0x03,
            ),
        ];
        for (profile, id) in cases {
            assert_eq!(profile.type_id(), id, "type id of {profile:?}");
        }
    }

    #[test]
    fn test_profile_info_basic_header() {
        // Version 0, profile type 1, no transform id bytes.
        let info = ProfileInfo::from_header(&[0x01, 0x00]).unwrap();
        assert_eq!(
            (info.version, info.profile_type, info.transform_id),
            (0, 1, None)
        );
    }

    #[test]
    fn test_profile_info_extended_header() {
        // Version 1, profile type 2, transform id 0x00FF.
        let info = ProfileInfo::from_header(&[0x12, 0x00, 0x00, 0xFF]).unwrap();
        assert_eq!(
            (info.version, info.profile_type, info.transform_id),
            (1, 2, Some(255))
        );
    }

    #[test]
    fn test_profile_info_to_header() {
        let info = ProfileInfo {
            version: 1,
            profile_type: 3,
            transform_id: Some(0x1234),
            model_id: None,
            original_dims: None,
            precision: None,
            dimension_mode: None,
            seed: None,
        };
        let header = info.to_header();
        // Packed version/type nibble, reserved byte, then the big-endian transform id.
        assert_eq!(header[..4], [0x13, 0x00, 0x12, 0x34]);
    }

    #[test]
    fn test_profile_info_roundtrip() {
        let original = ProfileInfo {
            version: 2,
            profile_type: 1,
            transform_id: Some(42),
            model_id: None,
            original_dims: None,
            precision: None,
            dimension_mode: None,
            seed: None,
        };
        let decoded = ProfileInfo::from_header(&original.to_header()).unwrap();
        assert_eq!(
            (decoded.version, decoded.profile_type, decoded.transform_id),
            (original.version, original.profile_type, original.transform_id)
        );
    }

    #[test]
    fn test_profile_info_full_vector_roundtrip() {
        let reduce = DimensionMode::Reduce { target_dims: 256 };
        let header =
            ProfileInfo::from_full_vector(768, VectorPrecision::Half16, reduce, 0x12345678)
                .to_header();
        assert_eq!(header.len(), 12);

        let decoded = ProfileInfo::from_header(&header).unwrap();
        assert_eq!(decoded.version, 0);
        assert_eq!(decoded.profile_type, 0x04);
        assert_eq!(decoded.original_dims, Some(768));
        assert_eq!(decoded.precision, Some(VectorPrecision::Half16));
        assert_eq!(decoded.dimension_mode, Some(reduce));
        assert_eq!(decoded.seed, Some(0x12345678));
    }

    #[test]
    fn test_profile_info_invalid_header() {
        let outcome = ProfileInfo::from_header(&[0x01]);
        assert!(matches!(outcome, Err(ElidError::InvalidHeader)));
    }

    #[test]
    fn test_profile_info_full_vector_short_header() {
        // FullVector type but only 4 of the required 12 bytes.
        let outcome = ProfileInfo::from_header(&[0x04, 0x00, 0x00, 0x00]);
        assert!(matches!(outcome, Err(ElidError::InvalidHeader)));
    }

    #[test]
    fn test_embedding_valid() {
        let outcome = Embedding::new(values_of(0.1));
        assert!(outcome.is_ok());
        assert_eq!(outcome.unwrap().dim(), 128);
    }

    #[test]
    fn test_embedding_too_small() {
        let outcome = Embedding::new(vec![0.1; 32]);
        assert!(matches!(outcome, Err(ElidError::InvalidDimension { .. })));
    }

    #[test]
    fn test_embedding_too_large() {
        let outcome = Embedding::new(vec![0.1; 4096]);
        assert!(matches!(outcome, Err(ElidError::InvalidDimension { .. })));
    }

    #[test]
    fn test_embedding_nan() {
        let mut values = values_of(0.1);
        values[64] = f32::NAN;
        assert!(matches!(
            Embedding::new(values),
            Err(ElidError::InvalidValue)
        ));
    }

    #[test]
    fn test_embedding_inf() {
        let mut values = values_of(0.1);
        values[64] = f32::INFINITY;
        assert!(matches!(
            Embedding::new(values),
            Err(ElidError::InvalidValue)
        ));
    }

    #[test]
    fn test_embedding_from_f64() {
        assert!(Embedding::from_f64(vec![0.1_f64; 128]).is_ok());
    }

    #[test]
    fn test_embedding_normalize() {
        // Alternating 3, 4, 3, 4, ... components.
        let values: Vec<f32> = (0..128)
            .map(|i| if i % 2 == 0 { 3.0 } else { 4.0 })
            .collect();
        let mut embedding = Embedding::new(values).unwrap();
        embedding.normalize();
        let length = embedding
            .as_slice()
            .iter()
            .map(|v| v * v)
            .sum::<f32>()
            .sqrt();
        assert!((length - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_embedding_normalize_zero() {
        let mut embedding = Embedding::new(values_of(0.0)).unwrap();
        embedding.normalize();
        assert!(embedding.as_slice().iter().all(|&v| v == 0.0));
    }
}
#[cfg(test)]
mod quantized_coords_tests {
    use super::*;

    /// A 128-dimensional embedding with every component equal to `fill`.
    fn embedding_of(fill: f32) -> Embedding {
        Embedding::new(vec![fill; 128]).unwrap()
    }

    /// Asserts that quantizing a zero embedding with these parameters is rejected.
    fn assert_rejected(dims: u8, bits_per_dim: u8, reason: &str) {
        let result = QuantizedCoords::from_embedding(&embedding_of(0.0), dims, bits_per_dim);
        assert!(
            matches!(result, Err(ElidError::InvalidDimension { .. })),
            "{}",
            reason
        );
    }

    #[test]
    fn test_quantized_coords_basic() {
        let coords = QuantizedCoords::from_embedding(&embedding_of(0.0), 10, 10).unwrap();
        assert_eq!(coords.len(), 10);
        assert_eq!(coords.bits_per_dim(), 10);
        assert!(!coords.is_empty());
        // 0.0 maps to 0.5 * 1023 = 511.5, which rounds away from zero to 512.
        let midpoint = 512u16;
        assert!(
            coords.as_slice().iter().all(|&c| c == midpoint),
            "All coords should be at midpoint, got {:?}",
            coords.as_slice()
        );
    }

    #[test]
    fn test_quantized_coords_min_value() {
        let coords = QuantizedCoords::from_embedding(&embedding_of(-1.0), 8, 10).unwrap();
        assert!(
            coords.as_slice().iter().all(|&c| c == 0),
            "Min values should quantize to 0, got {:?}",
            coords.as_slice()
        );
    }

    #[test]
    fn test_quantized_coords_max_value() {
        let coords = QuantizedCoords::from_embedding(&embedding_of(1.0), 8, 10).unwrap();
        let ceiling: u16 = (1 << 10) - 1;
        assert!(
            coords.as_slice().iter().all(|&c| c == ceiling),
            "Max values should quantize to {}, got {:?}",
            ceiling,
            coords.as_slice()
        );
    }

    #[test]
    fn test_quantized_coords_zero_dims() {
        assert_rejected(0, 10, "Should reject 0 dimensions");
    }

    #[test]
    fn test_quantized_coords_too_many_dims() {
        assert_rejected(33, 10, "Should reject > 32 dimensions");
    }

    #[test]
    fn test_quantized_coords_zero_bits() {
        assert_rejected(10, 0, "Should reject 0 bits per dimension");
    }

    #[test]
    fn test_quantized_coords_too_many_bits() {
        assert_rejected(10, 17, "Should reject > 16 bits per dimension");
    }
}