lean_ctx/core/
embedding_quant.rs1use serde::{Deserialize, Serialize};
16
/// Largest magnitude a quantized level may take, as an `f32`.
///
/// The range is kept symmetric (`-127..=127`), so `i8::MIN` is never produced.
const I8_ABS_MAX: f32 = i8::MAX as f32;

/// Number of independent partial sums used by the dot-product kernels; slices
/// are processed in chunks of this many elements.
const LANES: usize = 8;
25
26#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
29pub struct QuantizedVector {
30 pub code: Vec<i8>,
31 pub scale: f32,
32}
33
34impl QuantizedVector {
35 #[must_use]
36 pub fn dim(&self) -> usize {
37 self.code.len()
38 }
39
40 #[must_use]
43 pub fn dequantize(&self) -> Vec<f32> {
44 self.code
45 .iter()
46 .map(|&c| f32::from(c) * self.scale)
47 .collect()
48 }
49}
50
51#[must_use]
56pub fn quantize(v: &[f32]) -> QuantizedVector {
57 let max_abs = v.iter().fold(0.0f32, |m, &x| m.max(x.abs()));
58 if max_abs == 0.0 {
59 return QuantizedVector {
60 code: vec![0; v.len()],
61 scale: 0.0,
62 };
63 }
64 let scale = max_abs / I8_ABS_MAX;
65 let inv = 1.0 / scale;
66 let code = v
67 .iter()
68 .map(|&x| {
69 let q = (x * inv).round().clamp(-I8_ABS_MAX, I8_ABS_MAX);
71 q as i8
72 })
73 .collect();
74 QuantizedVector { code, scale }
75}
76
77#[must_use]
83pub fn dot_quant(query: &[f32], doc: &QuantizedVector) -> f32 {
84 debug_assert_eq!(query.len(), doc.code.len(), "dim mismatch");
85 if doc.scale == 0.0 {
86 return 0.0;
87 }
88
89 let mut lanes = [0.0f32; LANES];
90 let mut q_chunks = query.chunks_exact(LANES);
91 let mut c_chunks = doc.code.chunks_exact(LANES);
92
93 for (q, c) in q_chunks.by_ref().zip(c_chunks.by_ref()) {
94 for i in 0..LANES {
95 lanes[i] += q[i] * f32::from(c[i]);
96 }
97 }
98
99 let mut tail = 0.0f32;
100 for (q, c) in q_chunks.remainder().iter().zip(c_chunks.remainder()) {
101 tail += q * f32::from(*c);
102 }
103
104 (lanes.iter().sum::<f32>() + tail) * doc.scale
105}
106
107#[must_use]
113pub fn dot_f32(a: &[f32], b: &[f32]) -> f32 {
114 debug_assert_eq!(a.len(), b.len(), "dim mismatch");
115
116 let mut lanes = [0.0f32; LANES];
117 let mut a_chunks = a.chunks_exact(LANES);
118 let mut b_chunks = b.chunks_exact(LANES);
119
120 for (x, y) in a_chunks.by_ref().zip(b_chunks.by_ref()) {
121 for i in 0..LANES {
122 lanes[i] += x[i] * y[i];
123 }
124 }
125
126 let mut tail = 0.0f32;
127 for (x, y) in a_chunks.remainder().iter().zip(b_chunks.remainder()) {
128 tail += x * y;
129 }
130
131 lanes.iter().sum::<f32>() + tail
132}
133
#[cfg(test)]
mod tests {
    use super::*;

    /// Reference dot product: a plain left-to-right sum with no lane splitting.
    fn naive_dot(a: &[f32], b: &[f32]) -> f32 {
        let products = a.iter().zip(b).map(|(lhs, rhs)| lhs * rhs);
        products.sum()
    }

    /// Scales `v` in place to unit L2 norm; a zero vector is left untouched.
    fn l2_normalize(v: &mut [f32]) {
        let norm = v.iter().map(|c| c * c).sum::<f32>().sqrt();
        if norm > 0.0 {
            v.iter_mut().for_each(|c| *c /= norm);
        }
    }

    /// By-value convenience wrapper around [`l2_normalize`].
    fn normalized(mut v: Vec<f32>) -> Vec<f32> {
        l2_normalize(&mut v);
        v
    }

    #[test]
    fn dot_f32_matches_naive_within_tolerance() {
        let a: Vec<f32> = (0..384).map(|k| (k as f32 * 0.013).sin()).collect();
        let b: Vec<f32> = (0..384).map(|k| (k as f32 * 0.017).cos()).collect();
        let naive = naive_dot(&a, &b);
        let chunked = dot_f32(&a, &b);
        let diff = (chunked - naive).abs();
        assert!(diff < 1e-3, "chunked={chunked} naive={naive}");
    }

    #[test]
    fn dot_f32_handles_non_multiple_of_lane_width() {
        // 13 = one full chunk of 8 plus a tail of 5.
        let a: Vec<f32> = (0..13).map(|k| k as f32).collect();
        let b: Vec<f32> = (0..13).map(|k| (k * 2) as f32).collect();
        let diff = (dot_f32(&a, &b) - naive_dot(&a, &b)).abs();
        assert!(diff < 1e-4);
    }

    #[test]
    fn quantize_zero_vector_is_zero_scale() {
        let zero = quantize(&[0.0, 0.0, 0.0]);
        assert_eq!(zero.scale, 0.0);
        assert_eq!(zero.code, vec![0, 0, 0]);
        assert_eq!(dot_quant(&[1.0, 1.0, 1.0], &zero), 0.0);
    }

    #[test]
    fn quantize_preserves_max_magnitude_at_full_scale() {
        let unit = quantize(&[1.0, 0.0, 0.0]);
        assert_eq!(unit.code[0], 127);
        assert_eq!(unit.code[1], 0);
        // The largest component must survive the round trip almost exactly.
        let restored = unit.dequantize();
        assert!((restored[0] - 1.0).abs() < 1e-6);
    }

    #[test]
    fn dot_quant_approximates_cosine_for_normalized_vectors() {
        // Two nearly parallel vectors, both brought to unit length.
        let mut a: Vec<f32> = (0..384).map(|k| (k as f32 * 0.011).sin() + 0.3).collect();
        let mut b: Vec<f32> = a.iter().map(|c| c + 0.02).collect();
        l2_normalize(&mut a);
        l2_normalize(&mut b);

        let exact = naive_dot(&a, &b);
        let approx = dot_quant(&a, &quantize(&b));
        let error = (exact - approx).abs();
        assert!(error < 5e-3, "exact={exact} approx={approx}");

        // A vector compared against its own quantization should score close to 1.
        let self_sim = dot_quant(&a, &quantize(&a));
        assert!(self_sim > 0.99, "self_sim={self_sim}");
    }

    #[test]
    fn dot_quant_preserves_ranking() {
        let query = normalized(vec![1.0, 0.2, 0.0, 0.1]);
        let near = quantize(&normalized(vec![0.9, 0.3, 0.0, 0.1]));
        let far = quantize(&normalized(vec![-0.5, 0.8, 0.2, 0.0]));

        let near_score = dot_quant(&query, &near);
        let far_score = dot_quant(&query, &far);
        assert!(near_score > far_score);
    }
}