/// Default number of components in an embedding vector.
pub const DEFAULT_EMBEDDING_DIM: usize = 1536;
/// Bytes for one default-dimension embedding stored as f16 (2 bytes per component).
pub const EMBEDDING_F16_BYTES: usize = DEFAULT_EMBEDDING_DIM * 2;
/// Bytes for one default-dimension embedding stored as f32 (4 bytes per component).
pub const EMBEDDING_F32_BYTES: usize = DEFAULT_EMBEDDING_DIM * 4;
/// Bytes for one default-dimension embedding in binary-quantized (BQ) form:
/// `DEFAULT_EMBEDDING_DIM / 8`, i.e. one bit per component.
pub const EMBEDDING_BQ_BYTES: usize = DEFAULT_EMBEDDING_DIM / 8;
/// [`DEFAULT_EMBEDDING_DIM`] as a `u64`; presumably for APIs that express vector
/// sizes as `u64` (confirm against callers). The `as` cast is lossless on every
/// target where `usize` is at most 64 bits wide.
pub const DEFAULT_VECTOR_SIZE_U64: u64 = DEFAULT_EMBEDDING_DIM as u64;
/// Default verification threshold.
///
/// NOTE(review): the metric and scale this threshold is compared against are
/// defined by callers, not here — confirm before changing.
pub const DEFAULT_VERIFICATION_THRESHOLD: f32 = 0.70;
/// Default maximum sequence length.
///
/// NOTE(review): the unit (tokens?) is not established in this module — confirm
/// against callers.
pub const DEFAULT_MAX_SEQ_LEN: usize = 8192;

// Compile-time guard: `EMBEDDING_BQ_BYTES` uses integer division, so a default
// dimension that is zero or not a multiple of 8 would silently produce a wrong
// (truncated) BQ size and a default config that fails validation. Fail the
// build instead.
const _: () = assert!(
    DEFAULT_EMBEDDING_DIM > 0 && DEFAULT_EMBEDDING_DIM % 8 == 0,
    "DEFAULT_EMBEDDING_DIM must be a non-zero multiple of 8"
);
/// Runtime embedding-dimension configuration.
///
/// Construction does not validate the dimension; call [`DimConfig::validate`]
/// before trusting the byte-size helpers for a dimension that did not come from
/// [`DEFAULT_EMBEDDING_DIM`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct DimConfig {
    /// Number of components in each embedding vector.
    pub embedding_dim: usize,
}

impl Default for DimConfig {
    /// Returns a config using [`DEFAULT_EMBEDDING_DIM`].
    fn default() -> Self {
        Self::new(DEFAULT_EMBEDDING_DIM)
    }
}

impl DimConfig {
    /// Creates a config for `embedding_dim` components without validating it.
    #[must_use]
    pub const fn new(embedding_dim: usize) -> Self {
        Self { embedding_dim }
    }

    /// Checks that the dimension is usable for every storage format.
    ///
    /// # Errors
    ///
    /// - [`DimValidationError::ZeroDimension`] if `embedding_dim` is 0.
    /// - [`DimValidationError::NotDivisibleBy8`] if `embedding_dim` is not a
    ///   multiple of 8, which binary quantization (BQ) requires.
    pub fn validate(&self) -> Result<(), DimValidationError> {
        if self.embedding_dim == 0 {
            return Err(DimValidationError::ZeroDimension);
        }
        if !self.embedding_dim.is_multiple_of(8) {
            return Err(DimValidationError::NotDivisibleBy8 {
                dim: self.embedding_dim,
            });
        }
        Ok(())
    }

    /// Bytes needed for one embedding stored as f16 (2 bytes per component).
    #[must_use]
    pub const fn f16_bytes(&self) -> usize {
        self.embedding_dim * 2
    }

    /// Bytes needed for one embedding stored as f32 (4 bytes per component).
    #[must_use]
    pub const fn f32_bytes(&self) -> usize {
        self.embedding_dim * 4
    }

    /// Bytes needed for one embedding packed at one bit per component (BQ).
    ///
    /// Rounds up, so an unvalidated dimension that is not a multiple of 8 still
    /// gets room for every bit instead of a truncated buffer. For any config that
    /// passes [`DimConfig::validate`] this equals `embedding_dim / 8`.
    #[must_use]
    pub const fn bq_bytes(&self) -> usize {
        self.embedding_dim.div_ceil(8)
    }
}
/// Reasons an embedding dimension can be rejected.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DimValidationError {
    /// The configured dimension was 0.
    ZeroDimension,
    /// The dimension is not a multiple of 8, which binary quantization requires.
    NotDivisibleBy8 {
        /// The rejected dimension.
        dim: usize,
    },
    /// A supplied dimension differed from the one that was required.
    DimensionMismatch {
        /// The dimension that was required.
        expected: usize,
        /// The dimension that was actually supplied.
        actual: usize,
    },
}

impl std::fmt::Display for DimValidationError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // All payloads are `usize` (Copy), so matching on the dereferenced value
        // binds them by copy.
        match *self {
            Self::ZeroDimension => f.write_str("embedding dimension cannot be zero"),
            Self::NotDivisibleBy8 { dim } => write!(
                f,
                "embedding dimension {} is not divisible by 8 (required for BQ)",
                dim
            ),
            Self::DimensionMismatch { expected, actual } => write!(
                f,
                "dimension mismatch: expected {}, got {}",
                expected, actual
            ),
        }
    }
}

/// Lets the error travel through `Box<dyn std::error::Error>` and `?` chains;
/// there is no underlying source error.
impl std::error::Error for DimValidationError {}
/// Checks that an observed embedding dimension equals the expected one.
///
/// # Errors
///
/// Returns [`DimValidationError::DimensionMismatch`] carrying both values when
/// `actual` differs from `expected`.
pub fn validate_embedding_dim(actual: usize, expected: usize) -> Result<(), DimValidationError> {
    if actual == expected {
        Ok(())
    } else {
        Err(DimValidationError::DimensionMismatch { expected, actual })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_dim_config_default() {
        let config = DimConfig::default();
        assert_eq!(config.embedding_dim, DEFAULT_EMBEDDING_DIM);
    }

    // The default configuration must itself pass validation; otherwise every
    // caller that uses `DimConfig::default()` would be rejected.
    #[test]
    fn test_dim_config_default_is_valid() {
        assert!(DimConfig::default().validate().is_ok());
    }

    #[test]
    fn test_dim_config_validate_success() {
        let config = DimConfig::new(1536);
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_dim_config_validate_zero() {
        let config = DimConfig::new(0);
        assert_eq!(config.validate(), Err(DimValidationError::ZeroDimension));
    }

    #[test]
    fn test_dim_config_validate_not_divisible_by_8() {
        let config = DimConfig::new(1537);
        assert_eq!(
            config.validate(),
            Err(DimValidationError::NotDivisibleBy8 { dim: 1537 })
        );
    }

    #[test]
    fn test_dim_config_byte_calculations() {
        let config = DimConfig::new(1536);
        assert_eq!(config.f16_bytes(), EMBEDDING_F16_BYTES);
        assert_eq!(config.f32_bytes(), EMBEDDING_F32_BYTES);
        assert_eq!(config.bq_bytes(), EMBEDDING_BQ_BYTES);
    }

    // The derived constants must stay consistent with the default dimension.
    #[test]
    fn test_default_constants_consistent() {
        assert_eq!(EMBEDDING_BQ_BYTES * 8, DEFAULT_EMBEDDING_DIM);
        assert_eq!(EMBEDDING_F16_BYTES, DEFAULT_EMBEDDING_DIM * 2);
        assert_eq!(EMBEDDING_F32_BYTES, DEFAULT_EMBEDDING_DIM * 4);
        assert_eq!(DEFAULT_VECTOR_SIZE_U64, DEFAULT_EMBEDDING_DIM as u64);
    }

    #[test]
    fn test_validate_embedding_dim_match() {
        assert!(validate_embedding_dim(1536, 1536).is_ok());
    }

    #[test]
    fn test_validate_embedding_dim_mismatch() {
        assert_eq!(
            validate_embedding_dim(768, 1536),
            Err(DimValidationError::DimensionMismatch {
                expected: 1536,
                actual: 768
            })
        );
    }

    // Pin the exact user-facing messages so accidental wording changes are caught.
    #[test]
    fn test_error_display() {
        let err = DimValidationError::ZeroDimension;
        assert_eq!(err.to_string(), "embedding dimension cannot be zero");
        let err = DimValidationError::NotDivisibleBy8 { dim: 1537 };
        assert_eq!(
            err.to_string(),
            "embedding dimension 1537 is not divisible by 8 (required for BQ)"
        );
        let err = DimValidationError::DimensionMismatch {
            expected: 1536,
            actual: 768,
        };
        assert_eq!(err.to_string(), "dimension mismatch: expected 1536, got 768");
    }

    // The error type must be usable as a trait object (e.g. with `?` into
    // `Box<dyn Error>`).
    #[test]
    fn test_error_is_std_error() {
        let err: Box<dyn std::error::Error> = Box::new(DimValidationError::ZeroDimension);
        assert_eq!(err.to_string(), "embedding dimension cannot be zero");
        assert!(err.source().is_none());
    }
}