use std::cmp::Ordering;

/// An int8-quantized vector with a per-vector scale factor.
///
/// Each stored element approximates `original[i] / scale`; dequantization is
/// `data[i] as f32 * scale`. The L2 norm of the original f32 vector is kept
/// so cosine distance can be computed without dequantizing.
#[derive(Clone, Debug)]
pub struct Int8Vector {
    /// Quantized values in [-127, 127].
    data: Vec<i8>,
    /// Dequantization scale: original value ≈ data[i] as f32 * scale.
    scale: f32,
    /// L2 norm of the original f32 vector.
    norm: f32,
}

impl Int8Vector {
    /// Quantizes an f32 slice with symmetric max-abs scaling: the largest
    /// magnitude maps to ±127 and everything else scales linearly.
    pub fn from_f32(values: &[f32]) -> Self {
        if values.is_empty() {
            return Self {
                data: Vec::new(),
                scale: 1.0,
                norm: 0.0,
            };
        }

        let max_abs = values
            .iter()
            .map(|v| v.abs())
            .max_by(|a, b| a.partial_cmp(b).unwrap_or(Ordering::Equal))
            .unwrap_or(1.0);

        let scale = if max_abs > 0.0 { max_abs / 127.0 } else { 1.0 };

        let data: Vec<i8> = values
            .iter()
            .map(|&v| {
                let quantized = (v / scale).round();
                quantized.clamp(-127.0, 127.0) as i8
            })
            .collect();

        let norm = values.iter().map(|v| v * v).sum::<f32>().sqrt();

        Self { data, scale, norm }
    }

    /// Quantizes with a caller-provided scale, e.g. one shared across a batch
    /// so that raw int8 values from different vectors are directly comparable.
    pub fn from_f32_with_scale(values: &[f32], scale: f32) -> Self {
        debug_assert!(scale > 0.0, "Scale must be positive");
        let data: Vec<i8> = values
            .iter()
            .map(|&v| {
                let quantized = (v / scale).round();
                quantized.clamp(-127.0, 127.0) as i8
            })
            .collect();

        let norm = values.iter().map(|v| v * v).sum::<f32>().sqrt();

        Self { data, scale, norm }
    }

    /// Reassembles a vector from already-quantized parts.
    pub fn from_raw(data: Vec<i8>, scale: f32, norm: f32) -> Self {
        Self { data, scale, norm }
    }

    #[inline]
    pub fn dim(&self) -> usize {
        self.data.len()
    }

    #[inline]
    pub fn data(&self) -> &[i8] {
        &self.data
    }

    #[inline]
    pub fn scale(&self) -> f32 {
        self.scale
    }

    /// Storage footprint: one byte per element plus 4 bytes each for
    /// `scale` and `norm`.
    #[inline]
    pub fn size_bytes(&self) -> usize {
        self.data.len() + 8
    }

    /// Dequantizes back to f32 (lossy: round-trip error is at most
    /// `scale / 2` per element).
    pub fn to_f32(&self) -> Vec<f32> {
        self.data.iter().map(|&v| v as f32 * self.scale).collect()
    }

    /// Dot product against another int8 vector, rescaled into f32 space.
    #[inline]
    pub fn dot_product(&self, other: &Self) -> f32 {
        debug_assert_eq!(self.data.len(), other.data.len(), "Dimensions must match");

        let raw_dot = dot_product_i8_simd(&self.data, &other.data);
        raw_dot as f32 * self.scale * other.scale
    }

    /// Asymmetric dot product: this quantized vector against an unquantized
    /// f32 query, so only this vector's quantization error applies.
    #[inline]
    pub fn dot_product_f32(&self, query: &[f32]) -> f32 {
        debug_assert_eq!(self.data.len(), query.len(), "Dimensions must match");

        dot_product_i8_f32_simd(&self.data, query) * self.scale
    }

    /// Squared L2 distance. The rescaling by `self.scale * other.scale` is
    /// exact only when both vectors share the same scale; otherwise it is an
    /// approximation.
    #[inline]
    pub fn l2_squared(&self, other: &Self) -> f32 {
        debug_assert_eq!(self.data.len(), other.data.len(), "Dimensions must match");

        let raw_dist = l2_squared_i8_simd(&self.data, &other.data);
        raw_dist as f32 * self.scale * other.scale
    }

    /// Cosine distance (1 - cosine similarity), using the stored f32 norms.
    /// Returns 1.0 if either vector has zero norm.
    #[inline]
    pub fn cosine_distance(&self, other: &Self) -> f32 {
        let dot = self.dot_product(other);
        let denom = self.norm * other.norm;
        if denom > 0.0 {
            1.0 - (dot / denom)
        } else {
            1.0
        }
    }
}

/// i8 × i8 dot product with runtime SIMD dispatch: AVX2, then SSE4.1, then
/// a scalar fallback on other targets.
#[inline]
pub fn dot_product_i8_simd(a: &[i8], b: &[i8]) -> i32 {
    debug_assert_eq!(a.len(), b.len(), "Vectors must have same length");

    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx2") {
            return unsafe { dot_product_i8_avx2(a, b) };
        }
        if is_x86_feature_detected!("sse4.1") {
            return unsafe { dot_product_i8_sse4(a, b) };
        }
    }

    dot_product_i8_scalar(a, b)
}

#[inline]
fn dot_product_i8_scalar(a: &[i8], b: &[i8]) -> i32 {
    let mut sum = 0i32;
    for (x, y) in a.iter().zip(b.iter()) {
        sum += (*x as i32) * (*y as i32);
    }
    sum
}

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
#[inline]
unsafe fn dot_product_i8_avx2(a: &[i8], b: &[i8]) -> i32 {
    use std::arch::x86_64::*;

    let len = a.len();
    let mut sum = _mm256_setzero_si256();

    // Process 32 i8 elements per iteration.
    let chunks = len / 32;
    for i in 0..chunks {
        let idx = i * 32;
        let va = _mm256_loadu_si256(a.as_ptr().add(idx) as *const __m256i);
        let vb = _mm256_loadu_si256(b.as_ptr().add(idx) as *const __m256i);

        // Sign-extend each 16-byte half to i16 lanes.
        let va_lo = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(va));
        let va_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(va, 1));
        let vb_lo = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(vb));
        let vb_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vb, 1));

        // madd multiplies i16 lanes and sums adjacent product pairs into i32.
        let prod_lo = _mm256_madd_epi16(va_lo, vb_lo);
        let prod_hi = _mm256_madd_epi16(va_hi, vb_hi);

        sum = _mm256_add_epi32(sum, prod_lo);
        sum = _mm256_add_epi32(sum, prod_hi);
    }

    // Horizontal reduction of the eight i32 lanes.
    let sum128 = _mm_add_epi32(
        _mm256_castsi256_si128(sum),
        _mm256_extracti128_si256(sum, 1),
    );
    let sum64 = _mm_add_epi32(sum128, _mm_srli_si128(sum128, 8));
    let sum32 = _mm_add_epi32(sum64, _mm_srli_si128(sum64, 4));
    let mut result = _mm_cvtsi128_si32(sum32);

    // Scalar tail for lengths not divisible by 32.
    for i in (chunks * 32)..len {
        result += (a[i] as i32) * (b[i] as i32);
    }

    result
}

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse4.1")]
#[inline]
unsafe fn dot_product_i8_sse4(a: &[i8], b: &[i8]) -> i32 {
    use std::arch::x86_64::*;

    let len = a.len();
    let mut sum = _mm_setzero_si128();

    // Process 16 i8 elements per iteration.
    let chunks = len / 16;
    for i in 0..chunks {
        let idx = i * 16;
        let va = _mm_loadu_si128(a.as_ptr().add(idx) as *const __m128i);
        let vb = _mm_loadu_si128(b.as_ptr().add(idx) as *const __m128i);

        // Sign-extend each 8-byte half to i16 lanes.
        let va_lo = _mm_cvtepi8_epi16(va);
        let va_hi = _mm_cvtepi8_epi16(_mm_srli_si128(va, 8));
        let vb_lo = _mm_cvtepi8_epi16(vb);
        let vb_hi = _mm_cvtepi8_epi16(_mm_srli_si128(vb, 8));

        // Multiply i16 lanes and sum adjacent product pairs into i32.
        let prod_lo = _mm_madd_epi16(va_lo, vb_lo);
        let prod_hi = _mm_madd_epi16(va_hi, vb_hi);

        sum = _mm_add_epi32(sum, prod_lo);
        sum = _mm_add_epi32(sum, prod_hi);
    }

    // Horizontal reduction of the four i32 lanes.
    let sum64 = _mm_add_epi32(sum, _mm_srli_si128(sum, 8));
    let sum32 = _mm_add_epi32(sum64, _mm_srli_si128(sum64, 4));
    let mut result = _mm_cvtsi128_si32(sum32);

    // Scalar tail for lengths not divisible by 16.
    for i in (chunks * 16)..len {
        result += (a[i] as i32) * (b[i] as i32);
    }

    result
}

/// Mixed i8 × f32 dot product with runtime SIMD dispatch. The AVX2 path uses
/// a fused multiply-add, so both AVX2 and FMA must be detected at runtime.
#[inline]
pub fn dot_product_i8_f32_simd(a: &[i8], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len(), "Vectors must have same length");

    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
            return unsafe { dot_product_i8_f32_avx2(a, b) };
        }
    }

    dot_product_i8_f32_scalar(a, b)
}

#[inline]
fn dot_product_i8_f32_scalar(a: &[i8], b: &[f32]) -> f32 {
    let mut sum = 0.0f32;
    for (x, y) in a.iter().zip(b.iter()) {
        sum += (*x as f32) * y;
    }
    sum
}

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2,fma")]
#[inline]
unsafe fn dot_product_i8_f32_avx2(a: &[i8], b: &[f32]) -> f32 {
    use std::arch::x86_64::*;

    let len = a.len();
    let mut sum = _mm256_setzero_ps();

    // Process 8 elements per iteration; the widening i8 -> f32 conversion
    // limits us to 8 lanes.
    let chunks = len / 8;
    for i in 0..chunks {
        let idx = i * 8;

        // Load 8 i8 values and widen: i8 -> i16 -> i32 -> f32.
        let va_i8 = _mm_loadl_epi64(a.as_ptr().add(idx) as *const __m128i);
        let va_i16 = _mm_cvtepi8_epi16(va_i8);
        let va_i32 = _mm256_cvtepi16_epi32(va_i16);
        let va_f32 = _mm256_cvtepi32_ps(va_i32);

        let vb = _mm256_loadu_ps(b.as_ptr().add(idx));

        sum = _mm256_fmadd_ps(va_f32, vb, sum);
    }

    // Horizontal reduction of the eight f32 lanes.
    let sum128 = _mm_add_ps(_mm256_castps256_ps128(sum), _mm256_extractf128_ps(sum, 1));
    let sum64 = _mm_add_ps(sum128, _mm_movehl_ps(sum128, sum128));
    let sum32 = _mm_add_ss(sum64, _mm_shuffle_ps(sum64, sum64, 1));
    let mut result = _mm_cvtss_f32(sum32);

    // Scalar tail for lengths not divisible by 8.
    for i in (chunks * 8)..len {
        result += (a[i] as f32) * b[i];
    }

    result
}

/// Squared L2 distance between raw i8 slices. Currently a plain scalar loop,
/// a pattern the compiler auto-vectorizes well.
#[inline]
pub fn l2_squared_i8_simd(a: &[i8], b: &[i8]) -> i32 {
    debug_assert_eq!(a.len(), b.len(), "Vectors must have same length");

    let mut sum = 0i32;
    for (x, y) in a.iter().zip(b.iter()) {
        let d = (*x as i32) - (*y as i32);
        sum += d * d;
    }
    sum
}

/// A flat, structure-of-arrays store of int8-quantized vectors, intended for
/// rescoring candidates produced by a coarser (e.g. binary) index.
#[derive(Clone)]
pub struct Int8Index {
    /// All quantized vectors, concatenated (`n_vectors * dim` bytes).
    vectors: Vec<i8>,
    /// Per-vector dequantization scales.
    scales: Vec<f32>,
    /// Per-vector f32 norms.
    norms: Vec<f32>,
    /// Dimensionality of every vector in the index.
    dim: usize,
    /// Number of vectors stored.
    n_vectors: usize,
}

impl Int8Index {
    pub fn new(dim: usize) -> Self {
        Self {
            vectors: Vec::new(),
            scales: Vec::new(),
            norms: Vec::new(),
            dim,
            n_vectors: 0,
        }
    }

    pub fn with_capacity(dim: usize, capacity: usize) -> Self {
        Self {
            vectors: Vec::with_capacity(capacity * dim),
            scales: Vec::with_capacity(capacity),
            norms: Vec::with_capacity(capacity),
            dim,
            n_vectors: 0,
        }
    }

    /// Appends an already-quantized vector.
    pub fn add(&mut self, vector: &Int8Vector) {
        debug_assert_eq!(vector.dim(), self.dim, "Dimension mismatch");
        self.vectors.extend_from_slice(&vector.data);
        self.scales.push(vector.scale);
        self.norms.push(vector.norm);
        self.n_vectors += 1;
    }

    /// Quantizes an f32 vector and appends it.
    pub fn add_f32(&mut self, vector: &[f32]) {
        let int8 = Int8Vector::from_f32(vector);
        self.add(&int8);
    }

    #[inline]
    pub fn len(&self) -> usize {
        self.n_vectors
    }

    #[inline]
    pub fn is_empty(&self) -> bool {
        self.n_vectors == 0
    }

    /// Total payload size: one byte per element, plus 4 bytes each per vector
    /// for the scale and the norm.
    pub fn memory_bytes(&self) -> usize {
        self.vectors.len() + self.scales.len() * 4 + self.norms.len() * 4
    }

    /// Copies out the vector at `idx`, or returns `None` if out of bounds.
    pub fn get(&self, idx: usize) -> Option<Int8Vector> {
        if idx >= self.n_vectors {
            return None;
        }
        let start = idx * self.dim;
        let end = start + self.dim;
        Some(Int8Vector::from_raw(
            self.vectors[start..end].to_vec(),
            self.scales[idx],
            self.norms[idx],
        ))
    }

    /// Borrows the raw quantized data of the vector at `idx` without copying.
    #[inline]
    pub fn get_data(&self, idx: usize) -> &[i8] {
        let start = idx * self.dim;
        let end = start + self.dim;
        &self.vectors[start..end]
    }

    /// Dot product between the stored vector at `idx` and an f32 query.
    #[inline]
    pub fn dot_product_f32(&self, idx: usize, query: &[f32]) -> f32 {
        let data = self.get_data(idx);
        let scale = self.scales[idx];
        dot_product_i8_f32_simd(data, query) * scale
    }

    /// Rescores `(index, coarse_score)` candidates against an f32 query using
    /// int8 dot products, dropping out-of-range indices. The coarse scores
    /// (e.g. Hamming distances from a binary prescreen) are ignored. Results
    /// are returned best-first by dot product.
    pub fn rescore_candidates(
        &self,
        candidates: &[(usize, u32)],
        query: &[f32],
    ) -> Vec<(usize, f32)> {
        let mut results: Vec<(usize, f32)> = candidates
            .iter()
            .filter_map(|&(idx, _)| {
                if idx < self.n_vectors {
                    // Negate so an ascending sort puts the highest dot
                    // products first.
                    let dot = self.dot_product_f32(idx, query);
                    Some((idx, -dot))
                } else {
                    None
                }
            })
            .collect();

        results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
        results
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_int8_quantization() {
        let values = vec![1.0, -1.0, 0.5, -0.5, 0.0];
        let int8 = Int8Vector::from_f32(&values);

        assert_eq!(int8.data[0], 127);
        assert_eq!(int8.data[1], -127);
        assert_eq!(int8.data[4], 0);
    }

    #[test]
    fn test_dot_product_identical() {
        let v1 = Int8Vector::from_f32(&[1.0, 2.0, 3.0, 4.0]);
        let v2 = Int8Vector::from_f32(&[1.0, 2.0, 3.0, 4.0]);

        let dot = v1.dot_product(&v2);
        // Exact value is 30; allow for quantization error.
        let expected = 1.0 + 4.0 + 9.0 + 16.0;
        assert!((dot - expected).abs() < 1.0);
    }
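
    // l2_squared's rescaling is exact only when both vectors share a scale
    // (see the note on that method). This sketch checks the shared-scale case
    // against the true f32 squared distance, up to quantization error.
    #[test]
    fn test_l2_squared_shared_scale() {
        let scale = 4.0 / 127.0;
        let v1 = Int8Vector::from_f32_with_scale(&[1.0, 2.0, 3.0, 4.0], scale);
        let v2 = Int8Vector::from_f32_with_scale(&[2.0, 2.0, 2.0, 2.0], scale);

        // True squared distance: 1 + 0 + 1 + 4 = 6.
        let dist = v1.l2_squared(&v2);
        assert!((dist - 6.0).abs() < 0.2);
    }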

    #[test]
    fn test_dot_product_f32() {
        let int8 = Int8Vector::from_f32(&[1.0, 0.0, -1.0, 0.5]);
        let query = vec![1.0, 1.0, 1.0, 1.0];

        let dot = int8.dot_product_f32(&query);
        assert!((dot - 0.5).abs() < 0.1);
    }

    #[test]
    fn test_compression_ratio() {
        // 1024 dims at 4 bytes each for fp32.
        let fp32_size = 1024 * 4;
        let int8 = Int8Vector::from_f32(&vec![1.0; 1024]);
        let int8_size = int8.size_bytes();

        // 1024 bytes of data plus 8 bytes for scale and norm.
        assert_eq!(int8_size, 1032);
        assert!(fp32_size / int8_size >= 3);
    }
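
    // Round-trip sketch: dequantizing with to_f32 should reproduce each value
    // to within half a quantization step (scale / 2), since from_f32 rounds to
    // the nearest level and none of these values gets clamped.
    #[test]
    fn test_to_f32_round_trip() {
        let values = vec![0.3, -0.7, 1.0, -0.05];
        let int8 = Int8Vector::from_f32(&values);
        let restored = int8.to_f32();

        let half_step = int8.scale() / 2.0;
        for (orig, back) in values.iter().zip(restored.iter()) {
            assert!((orig - back).abs() <= half_step + f32::EPSILON);
        }
    }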

    #[test]
    fn test_index_rescore() {
        let mut index = Int8Index::new(4);

        index.add_f32(&[1.0, 0.0, 0.0, 0.0]);
        index.add_f32(&[0.0, 1.0, 0.0, 0.0]);
        index.add_f32(&[0.0, 0.0, 1.0, 0.0]);

        let query = vec![1.0, 0.0, 0.0, 0.0];

        // (index, coarse_score) pairs as a binary prescreen would produce
        // them; the coarse scores themselves are ignored.
        let binary_candidates = vec![(0, 10), (1, 20), (2, 30)];

        let rescored = index.rescore_candidates(&binary_candidates, &query);

        // Vector 0 matches the query exactly, so it must rank first.
        assert_eq!(rescored[0].0, 0);
    }
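
    // A sketch of the cosine_distance contract: near 0 for identical
    // directions, near 1 for orthogonal ones, and exactly 1.0 when either
    // norm is zero.
    #[test]
    fn test_cosine_distance() {
        let a = Int8Vector::from_f32(&[1.0, 0.0, 0.0, 0.0]);
        let b = Int8Vector::from_f32(&[0.0, 1.0, 0.0, 0.0]);
        let zero = Int8Vector::from_f32(&[0.0, 0.0, 0.0, 0.0]);

        assert!(a.cosine_distance(&a).abs() < 0.01);
        assert!((a.cosine_distance(&b) - 1.0).abs() < 0.01);
        assert_eq!(a.cosine_distance(&zero), 1.0);
    }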

    #[test]
    fn test_simd_vs_scalar() {
        let a: Vec<i8> = (0..128).map(|i| (i % 127) as i8).collect();
        let b: Vec<i8> = (0..128).map(|i| ((127 - i) % 127) as i8).collect();

        let scalar = dot_product_i8_scalar(&a, &b);

        #[cfg(target_arch = "x86_64")]
        {
            let simd = dot_product_i8_simd(&a, &b);
            assert_eq!(scalar, simd);
        }
    }
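
    // Parity sketch for the mixed i8 × f32 kernel: the SIMD path accumulates
    // in a different order than the scalar loop, so compare with a small
    // floating-point tolerance rather than exact equality.
    #[test]
    fn test_i8_f32_simd_vs_scalar() {
        let a: Vec<i8> = (0..100).map(|i| (i % 127) as i8).collect();
        let b: Vec<f32> = (0..100).map(|i| (i as f32) * 0.25 - 10.0).collect();

        let scalar = dot_product_i8_f32_scalar(&a, &b);
        let dispatched = dot_product_i8_f32_simd(&a, &b);

        assert!((scalar - dispatched).abs() < 1e-2 * scalar.abs().max(1.0));
    }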
}