/// Default dimensionality (number of components) of an embedding vector.
pub const DEFAULT_EMBEDDING_DIM: usize = 1536;

/// Size in bytes of one default-dimension embedding stored as f16 (2 bytes per component).
pub const EMBEDDING_F16_BYTES: usize = DEFAULT_EMBEDDING_DIM * 2;

/// Size in bytes of one default-dimension embedding stored as f32 (4 bytes per component).
pub const EMBEDDING_F32_BYTES: usize = DEFAULT_EMBEDDING_DIM * 4;

/// Size in bytes of one default-dimension embedding in bit-packed (BQ) form,
/// i.e. one bit per component.
pub const EMBEDDING_BQ_BYTES: usize = DEFAULT_EMBEDDING_DIM / 8;

// Compile-time guard: `EMBEDDING_BQ_BYTES` uses truncating division, so a default
// dimension that is zero or not a multiple of 8 would silently drop bits. This is
// the same rule `DimConfig::validate` enforces at runtime; failing here turns a
// bad edit to `DEFAULT_EMBEDDING_DIM` into a build error instead of a wrong size.
const _: () = assert!(
    DEFAULT_EMBEDDING_DIM > 0 && DEFAULT_EMBEDDING_DIM % 8 == 0,
    "DEFAULT_EMBEDDING_DIM must be a non-zero multiple of 8"
);

// NOTE(review): presumably exists for APIs that take vector sizes as `u64`; confirm.
/// [`DEFAULT_EMBEDDING_DIM`] expressed as a `u64`.
// The `as` cast is lossless: `usize` is at most 64 bits on all current Rust targets.
pub const DEFAULT_VECTOR_SIZE_U64: u64 = DEFAULT_EMBEDDING_DIM as u64;

// NOTE(review): presumably a similarity-score cutoff; scale and comparison
// direction are not established in this module — confirm against callers.
/// Default threshold used for verification.
pub const DEFAULT_VERIFICATION_THRESHOLD: f32 = 0.70;

// NOTE(review): units (tokens?) are not established in this module — confirm.
/// Default maximum sequence length.
pub const DEFAULT_MAX_SEQ_LEN: usize = 8192;
31
32#[derive(Debug, Clone, Copy, PartialEq, Eq)]
37pub struct DimConfig {
38 pub embedding_dim: usize,
40}
41
42impl Default for DimConfig {
43 fn default() -> Self {
44 Self {
45 embedding_dim: DEFAULT_EMBEDDING_DIM,
46 }
47 }
48}
49
50impl DimConfig {
51 pub fn new(embedding_dim: usize) -> Self {
53 Self { embedding_dim }
54 }
55
56 pub fn validate(&self) -> Result<(), DimValidationError> {
62 if self.embedding_dim == 0 {
63 return Err(DimValidationError::ZeroDimension);
64 }
65 if !self.embedding_dim.is_multiple_of(8) {
66 return Err(DimValidationError::NotDivisibleBy8 {
67 dim: self.embedding_dim,
68 });
69 }
70 Ok(())
71 }
72
73 pub fn f16_bytes(&self) -> usize {
75 self.embedding_dim * 2
76 }
77
78 pub fn f32_bytes(&self) -> usize {
80 self.embedding_dim * 4
81 }
82
83 pub fn bq_bytes(&self) -> usize {
85 self.embedding_dim / 8
86 }
87}
88
/// Errors produced when checking embedding dimensions.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum DimValidationError {
    /// The embedding dimension was zero.
    ZeroDimension,
    /// The embedding dimension is not a multiple of 8, which BQ sizing requires.
    NotDivisibleBy8 {
        /// The rejected dimension.
        dim: usize,
    },
    /// An observed dimension differed from the required one.
    DimensionMismatch {
        /// The dimension that was required.
        expected: usize,
        /// The dimension that was actually observed.
        actual: usize,
    },
}

impl std::fmt::Display for DimValidationError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match *self {
            Self::ZeroDimension => f.write_str("embedding dimension cannot be zero"),
            Self::NotDivisibleBy8 { dim } => write!(
                f,
                "embedding dimension {dim} is not divisible by 8 (required for BQ)"
            ),
            Self::DimensionMismatch { expected, actual } => {
                write!(f, "dimension mismatch: expected {expected}, got {actual}")
            }
        }
    }
}

// Every variant is a leaf error with no underlying cause, so the default
// `source()` (returning `None`) is the right behavior.
impl std::error::Error for DimValidationError {}
131
132pub fn validate_embedding_dim(actual: usize, expected: usize) -> Result<(), DimValidationError> {
147 if actual != expected {
148 return Err(DimValidationError::DimensionMismatch { expected, actual });
149 }
150 Ok(())
151}
152
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_dim_config_default() {
        let config = DimConfig::default();
        assert_eq!(config.embedding_dim, DEFAULT_EMBEDDING_DIM);
    }

    #[test]
    fn test_default_config_is_valid() {
        // The default dimension must itself satisfy the runtime validation rules.
        assert_eq!(DimConfig::default().validate(), Ok(()));
    }

    #[test]
    fn test_default_constants_are_consistent() {
        assert_eq!(EMBEDDING_BQ_BYTES * 8, DEFAULT_EMBEDDING_DIM);
        assert_eq!(EMBEDDING_F32_BYTES, EMBEDDING_F16_BYTES * 2);
        assert_eq!(DEFAULT_VECTOR_SIZE_U64, DEFAULT_EMBEDDING_DIM as u64);
    }

    #[test]
    fn test_dim_config_validate_success() {
        let config = DimConfig::new(1536);
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_dim_config_validate_zero() {
        let config = DimConfig::new(0);
        assert_eq!(config.validate(), Err(DimValidationError::ZeroDimension));
    }

    #[test]
    fn test_dim_config_validate_not_divisible_by_8() {
        let config = DimConfig::new(1537);
        assert_eq!(
            config.validate(),
            Err(DimValidationError::NotDivisibleBy8 { dim: 1537 })
        );
    }

    #[test]
    fn test_dim_config_byte_calculations() {
        let config = DimConfig::new(1536);
        assert_eq!(config.f16_bytes(), EMBEDDING_F16_BYTES);
        assert_eq!(config.f32_bytes(), EMBEDDING_F32_BYTES);
        assert_eq!(config.bq_bytes(), EMBEDDING_BQ_BYTES);
    }

    #[test]
    fn test_dim_config_byte_calculations_non_default() {
        let config = DimConfig::new(768);
        assert!(config.validate().is_ok());
        assert_eq!(config.f16_bytes(), 1536);
        assert_eq!(config.f32_bytes(), 3072);
        assert_eq!(config.bq_bytes(), 96);
    }

    #[test]
    fn test_validate_embedding_dim_match() {
        assert!(validate_embedding_dim(1536, 1536).is_ok());
    }

    #[test]
    fn test_validate_embedding_dim_mismatch() {
        assert_eq!(
            validate_embedding_dim(768, 1536),
            Err(DimValidationError::DimensionMismatch {
                expected: 1536,
                actual: 768
            })
        );
    }

    #[test]
    fn test_error_display() {
        let err = DimValidationError::ZeroDimension;
        assert_eq!(err.to_string(), "embedding dimension cannot be zero");

        let err = DimValidationError::NotDivisibleBy8 { dim: 1537 };
        assert!(err.to_string().contains("1537"));
        assert!(err.to_string().contains("divisible by 8"));

        let err = DimValidationError::DimensionMismatch {
            expected: 1536,
            actual: 768,
        };
        assert!(err.to_string().contains("1536"));
        assert!(err.to_string().contains("768"));
    }

    #[test]
    fn test_error_display_exact() {
        // Pin the full messages so accidental wording changes are caught.
        assert_eq!(
            DimValidationError::NotDivisibleBy8 { dim: 1537 }.to_string(),
            "embedding dimension 1537 is not divisible by 8 (required for BQ)"
        );
        assert_eq!(
            DimValidationError::DimensionMismatch {
                expected: 1536,
                actual: 768,
            }
            .to_string(),
            "dimension mismatch: expected 1536, got 768"
        );
    }
}