#![allow(dead_code)]
use std::collections::HashMap;

/// Width of the default embedding vectors.
pub const EMBEDDING_DIM: usize = 1536;

pub trait EmbeddingProvider: Send + Sync {
    fn embed(&self, text: &str) -> Result<Vec<f32>, EmbeddingError>;
    fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>, EmbeddingError>;
    fn dimension(&self) -> usize;
}

#[derive(Debug)]
pub enum EmbeddingError {
    ApiError(String),
    InvalidDimension,
    NetworkError,
}

pub struct EmbeddingCache {
    embeddings: HashMap<String, Vec<f32>>,
    dimension: usize,
}

impl EmbeddingCache {
    pub fn new(dimension: usize) -> Self {
        Self {
            embeddings: HashMap::new(),
            dimension,
        }
    }

    /// Stores an embedding, silently dropping vectors whose length does not
    /// match the cache's configured dimension.
    pub fn insert(&mut self, text: String, embedding: Vec<f32>) {
        if embedding.len() == self.dimension {
            self.embeddings.insert(text, embedding);
        }
    }

    pub fn get(&self, text: &str) -> Option<&Vec<f32>> {
        self.embeddings.get(text)
    }

    pub fn contains(&self, text: &str) -> bool {
        self.embeddings.contains_key(text)
    }
}

#[cfg(target_arch = "x86_64")]
pub fn cosine_similarity_simd(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }

    // The intrinsics below require AVX and FMA; fall back to the scalar path
    // on CPUs that lack them instead of executing unsupported instructions.
    if is_x86_feature_detected!("avx") && is_x86_feature_detected!("fma") {
        unsafe { cosine_similarity_avx(a, b) }
    } else {
        cosine_similarity_scalar(a, b)
    }
}

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx,fma")]
unsafe fn cosine_similarity_avx(a: &[f32], b: &[f32]) -> f32 {
    use std::arch::x86_64::*;

    let mut dot_sum = _mm256_setzero_ps();
    let mut norm_a = _mm256_setzero_ps();
    let mut norm_b = _mm256_setzero_ps();

    // Accumulate the dot product and both squared norms eight lanes at a time.
    let chunks = a.len() / 8;
    for i in 0..chunks {
        let va = _mm256_loadu_ps(a.as_ptr().add(i * 8));
        let vb = _mm256_loadu_ps(b.as_ptr().add(i * 8));

        dot_sum = _mm256_fmadd_ps(va, vb, dot_sum);
        norm_a = _mm256_fmadd_ps(va, va, norm_a);
        norm_b = _mm256_fmadd_ps(vb, vb, norm_b);
    }

    let mut dot = horizontal_sum_ps256(dot_sum);
    let mut na_sq = horizontal_sum_ps256(norm_a);
    let mut nb_sq = horizontal_sum_ps256(norm_b);

    // Scalar tail for lengths that are not a multiple of eight.
    for i in (chunks * 8)..a.len() {
        dot += a[i] * b[i];
        na_sq += a[i] * a[i];
        nb_sq += b[i] * b[i];
    }

    // Sum the squared norms *before* taking the square roots: adding partial
    // roots (sqrt(x) + sqrt(y)) is not sqrt(x + y).
    let denom = na_sq.sqrt() * nb_sq.sqrt();
    if denom == 0.0 {
        0.0
    } else {
        dot / denom
    }
}

#[cfg(not(target_arch = "x86_64"))]
pub fn cosine_similarity_simd(a: &[f32], b: &[f32]) -> f32 {
    cosine_similarity_scalar(a, b)
}

pub fn cosine_similarity_scalar(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }

    let mut dot = 0.0;
    let mut norm_a = 0.0;
    let mut norm_b = 0.0;

    for i in 0..a.len() {
        dot += a[i] * b[i];
        norm_a += a[i] * a[i];
        norm_b += b[i] * b[i];
    }

    // Guard against zero vectors, which would otherwise yield NaN.
    let denom = norm_a.sqrt() * norm_b.sqrt();
    if denom == 0.0 {
        0.0
    } else {
        dot / denom
    }
}

#[cfg(target_arch = "x86_64")]
unsafe fn horizontal_sum_ps256(v: std::arch::x86_64::__m256) -> f32 {
    use std::arch::x86_64::*;

    // Fold the upper 128-bit lane onto the lower one.
    let high = _mm256_extractf128_ps(v, 1);
    let low = _mm256_castps256_ps128(v);
    let sum = _mm_add_ps(high, low);
    // Add elements 2 and 3 onto elements 0 and 1.
    let shuf = _mm_shuffle_ps(sum, sum, 0x0E);
    let sums = _mm_add_ps(sum, shuf);
    // Finally add element 1 onto element 0 to finish the reduction.
    let shuf2 = _mm_shuffle_ps(sums, sums, 0x01);
    let result = _mm_add_ss(sums, shuf2);
    _mm_cvtss_f32(result)
}

pub struct MockEmbeddingProvider {
    dimension: usize,
}

impl MockEmbeddingProvider {
    pub fn new(dimension: usize) -> Self {
        Self { dimension }
    }
}

impl EmbeddingProvider for MockEmbeddingProvider {
    fn embed(&self, text: &str) -> Result<Vec<f32>, EmbeddingError> {
        let mut embedding = vec![0.0; self.dimension];
        // Deterministic pseudo-embedding derived from a simple string hash.
        let hash = text
            .bytes()
            .fold(0u64, |acc, b| acc.wrapping_mul(31).wrapping_add(b as u64));

        for (i, slot) in embedding.iter_mut().enumerate() {
            let val = ((hash.wrapping_mul(i as u64 + 1)) % 1000) as f32 / 1000.0;
            *slot = val * 2.0 - 1.0;
        }

        // L2-normalize so cosine similarity behaves like a dot product.
        let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm > 0.0 {
            for val in &mut embedding {
                *val /= norm;
            }
        }

        Ok(embedding)
    }

    fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>, EmbeddingError> {
        texts.iter().map(|t| self.embed(t)).collect()
    }

    fn dimension(&self) -> usize {
        self.dimension
    }
}

#[repr(C, packed)]
pub struct EmbeddedBlockHeader {
    pub x: f32,
    pub y: f32,
    pub z: f32,
    pub zoom: f32,
    pub depth: u8,
    pub layer_id: u8,
    pub data_offset: u32,
    pub data_len: u16,
    pub parent_idx: u32,
    pub child_count: u16,
    pub crc16: [u8; 2],
    pub embedding_offset: u32,
    pub has_embedding: bool,
}
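
// Layout sanity check (an added guard, assuming downstream code relies on
// this exact on-disk size): with repr(C, packed) the header occupies
// 4*4 + 2 + 4 + 2 + 4 + 2 + 2 + 4 + 1 = 37 bytes, with no padding.
const _: () = assert!(std::mem::size_of::<EmbeddedBlockHeader>() == 37);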

#[cfg(feature = "embeddings")]
pub struct CandleEmbeddingProvider {
    model: candle_transformers::models::bert::BertModel,
    tokenizer: tokenizers::Tokenizer,
    dim: usize,
    device: candle_core::Device,
}

#[cfg(feature = "embeddings")]
impl CandleEmbeddingProvider {
    pub fn new(model_id: &str) -> Result<Self, EmbeddingError> {
        use candle_core::Device;
        use hf_hub::api::sync::Api;

        let device = Device::Cpu;
        let api = Api::new().map_err(|e| EmbeddingError::ApiError(e.to_string()))?;
        let repo = api.model(model_id.to_string());

        let tokenizer_path = repo
            .get("tokenizer.json")
            .map_err(|e| EmbeddingError::ApiError(format!("tokenizer download: {}", e)))?;
        let tokenizer = tokenizers::Tokenizer::from_file(tokenizer_path)
            .map_err(|e| EmbeddingError::ApiError(format!("tokenizer load: {}", e)))?;

        let weights_path = repo
            .get("model.safetensors")
            .map_err(|e| EmbeddingError::ApiError(format!("weights download: {}", e)))?;
        // Safety: memory-maps the safetensors file; the file must not be
        // mutated while mapped.
        let vb = unsafe {
            candle_nn::VarBuilder::from_mmaped_safetensors(
                &[weights_path],
                candle_core::DType::F32,
                &device,
            )
        }
        .map_err(|e| EmbeddingError::ApiError(format!("varbuilder: {}", e)))?;

        // The model hyperparameters come from the checkpoint's own
        // config.json; candle's bert::Config derives Deserialize, so it can
        // be parsed directly. Its fields are private, so the embedding width
        // is read from the raw JSON (assumes serde_json is a dependency).
        let config_path = repo
            .get("config.json")
            .map_err(|e| EmbeddingError::ApiError(format!("config download: {}", e)))?;
        let config_json = std::fs::read_to_string(config_path)
            .map_err(|e| EmbeddingError::ApiError(format!("config read: {}", e)))?;
        let config: candle_transformers::models::bert::Config =
            serde_json::from_str(&config_json)
                .map_err(|e| EmbeddingError::ApiError(format!("config parse: {}", e)))?;
        let dim = serde_json::from_str::<serde_json::Value>(&config_json)
            .ok()
            .and_then(|v| v.get("hidden_size").and_then(|h| h.as_u64()))
            .unwrap_or(768) as usize;

        let model = candle_transformers::models::bert::BertModel::load(vb, &config)
            .map_err(|e| EmbeddingError::ApiError(format!("model load: {}", e)))?;

        Ok(Self {
            model,
            tokenizer,
            dim,
            device,
        })
    }

    fn embed_inner(&self, text: &str) -> Result<Vec<f32>, EmbeddingError> {
        use candle_core::Tensor;

        let encoding = self
            .tokenizer
            .encode(text, true)
            .map_err(|e| EmbeddingError::ApiError(format!("tokenize: {}", e)))?;

        let ids = encoding.get_ids();
        let type_ids = encoding.get_type_ids();
        let len = ids.len();

        // Shape both tensors as a batch of one: (1, seq_len).
        let input_ids = Tensor::new(ids, &self.device)
            .map_err(|e| EmbeddingError::ApiError(e.to_string()))?
            .reshape((1, len))
            .map_err(|e| EmbeddingError::ApiError(e.to_string()))?;
        let token_type_ids = Tensor::new(type_ids, &self.device)
            .map_err(|e| EmbeddingError::ApiError(e.to_string()))?
            .reshape((1, len))
            .map_err(|e| EmbeddingError::ApiError(e.to_string()))?;

        let output = self
            .model
            .forward(&input_ids, &token_type_ids, None)
            .map_err(|e| EmbeddingError::ApiError(format!("forward: {}", e)))?;

        // Mean-pool over the sequence dimension, then drop the batch axis.
        let pooled = output
            .mean(1)
            .map_err(|e| EmbeddingError::ApiError(format!("mean pool: {}", e)))?
            .squeeze(0)
            .map_err(|e| EmbeddingError::ApiError(format!("squeeze: {}", e)))?;

        let mut embedding: Vec<f32> = pooled
            .to_vec1()
            .map_err(|e| EmbeddingError::ApiError(format!("to_vec: {}", e)))?;

        // L2-normalize so cosine similarity reduces to a dot product.
        let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm > 0.0 {
            for val in &mut embedding {
                *val /= norm;
            }
        }

        Ok(embedding)
    }
}

#[cfg(feature = "embeddings")]
impl EmbeddingProvider for CandleEmbeddingProvider {
    fn embed(&self, text: &str) -> Result<Vec<f32>, EmbeddingError> {
        self.embed_inner(text)
    }

    fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>, EmbeddingError> {
        texts.iter().map(|t| self.embed_inner(t)).collect()
    }

    fn dimension(&self) -> usize {
        self.dim
    }
}
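
// Usage sketch (the model id is illustrative; any BERT-architecture
// checkpoint with tokenizer.json and model.safetensors on the Hugging Face
// hub should work):
//
//     let provider = CandleEmbeddingProvider::new("sentence-transformers/all-MiniLM-L6-v2")?;
//     let vec = provider.embed("hello world")?;
//     assert_eq!(vec.len(), provider.dimension());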

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cosine_similarity() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        assert!((cosine_similarity_scalar(&a, &b) - 1.0).abs() < 0.001);

        let c = vec![0.0, 1.0, 0.0];
        assert!((cosine_similarity_scalar(&a, &c) - 0.0).abs() < 0.001);

        let d = vec![-1.0, 0.0, 0.0];
        assert!((cosine_similarity_scalar(&a, &d) + 1.0).abs() < 0.001);
    }
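
    // The SIMD path should agree with the scalar path, including the scalar
    // tail handling for lengths that are not a multiple of 8.
    #[test]
    fn test_simd_matches_scalar() {
        let a: Vec<f32> = (0..19).map(|i| (i as f32 * 0.37).sin()).collect();
        let b: Vec<f32> = (0..19).map(|i| (i as f32 * 0.73).cos()).collect();
        let scalar = cosine_similarity_scalar(&a, &b);
        let simd = cosine_similarity_simd(&a, &b);
        assert!((scalar - simd).abs() < 1e-5);
    }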

    #[test]
    fn test_mock_embeddings() {
        let provider = MockEmbeddingProvider::new(128);
        let embedding = provider.embed("test text").unwrap();

        assert_eq!(embedding.len(), 128);

        let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 0.001);
    }
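
    // The mock provider hashes its input, so it should be deterministic for
    // the same text and (in general) differ across texts.
    #[test]
    fn test_mock_embeddings_deterministic() {
        let provider = MockEmbeddingProvider::new(64);
        let a = provider.embed("same text").unwrap();
        let b = provider.embed("same text").unwrap();
        let c = provider.embed("other text").unwrap();
        assert_eq!(a, b);
        assert_ne!(a, c);
    }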

    #[test]
    fn test_embedding_cache() {
        let mut cache = EmbeddingCache::new(3);
        let embedding = vec![1.0, 0.0, 0.0];

        cache.insert("test".to_string(), embedding.clone());
        assert!(cache.contains("test"));
        assert_eq!(cache.get("test"), Some(&embedding));
    }
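
    // insert() silently drops embeddings whose length does not match the
    // cache's configured dimension.
    #[test]
    fn test_embedding_cache_rejects_wrong_dimension() {
        let mut cache = EmbeddingCache::new(3);
        cache.insert("bad".to_string(), vec![1.0, 0.0]);
        assert!(!cache.contains("bad"));
    }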
}