1use std::collections::HashMap;
8use std::sync::{Arc, Mutex};
9
10use chrono::{DateTime, Utc};
11use rusqlite::Connection;
12use serde::{Deserialize, Serialize};
13use tracing::debug;
14
15use punch_types::{PunchError, PunchResult};
16
/// A stored text embedding: the source text, its vector representation,
/// and arbitrary string metadata.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Embedding {
    // Unique identifier; `EmbeddingStore::store` assigns a UUID v4 string.
    pub id: String,
    // The original text that was embedded.
    pub text: String,
    // The embedding vector (length = the producing embedder's dimensions).
    pub vector: Vec<f32>,
    // Arbitrary key/value metadata attached at store time.
    pub metadata: HashMap<String, String>,
    // Creation timestamp in UTC.
    pub created_at: DateTime<Utc>,
}
30
/// Configuration selecting and sizing an embedding backend.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingConfig {
    // Which backend produces the vectors.
    pub provider: EmbeddingProvider,
    // Expected vector dimensionality for the chosen provider.
    pub dimensions: usize,
    // Number of texts to embed per batch.
    pub batch_size: usize,
}
38
/// Available embedding backends.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum EmbeddingProvider {
    // OpenAI embeddings API, e.g. model "text-embedding-3-small".
    OpenAi { api_key: String, model: String },
    // A self-hosted embedding service reachable at `endpoint`.
    Local { endpoint: String },
    // The in-process TF-IDF embedder; no network or API key required.
    BuiltIn,
}
49
/// A synchronous text-to-vector embedding backend.
pub trait Embedder: Send + Sync {
    /// Embeds a single text into a vector of [`Embedder::dimensions`] floats.
    fn embed(&self, text: &str) -> PunchResult<Vec<f32>>;

    /// Embeds several texts. The default implementation embeds one at a
    /// time and short-circuits on the first error.
    fn embed_batch(&self, texts: &[&str]) -> PunchResult<Vec<Vec<f32>>> {
        texts.iter().map(|t| self.embed(t)).collect()
    }

    /// Dimensionality of the vectors this embedder produces.
    fn dimensions(&self) -> usize;
}
67
/// Cosine similarity between two vectors, in [-1, 1].
///
/// Returns 0.0 when either vector has zero magnitude.
///
/// # Panics
/// Panics if the two slices differ in length.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len(), "vectors must have equal length");

    // Accumulate dot product and both squared norms in a single pass.
    let (dot, sq_a, sq_b) = a
        .iter()
        .zip(b.iter())
        .fold((0.0_f32, 0.0_f32, 0.0_f32), |(d, na, nb), (x, y)| {
            (d + x * y, na + x * x, nb + y * y)
        });

    let denom = sq_a.sqrt() * sq_b.sqrt();
    if denom == 0.0 {
        0.0
    } else {
        dot / denom
    }
}
96
97pub fn top_k_similar<'a>(
100 query_vec: &[f32],
101 embeddings: &'a [Embedding],
102 k: usize,
103) -> Vec<(f32, &'a Embedding)> {
104 let mut scored: Vec<(f32, &Embedding)> = embeddings
105 .iter()
106 .map(|e| (cosine_similarity(query_vec, &e.vector), e))
107 .collect();
108 scored.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
109 scored.truncate(k);
110 scored
111}
112
/// A dependency-free TF-IDF embedder, trained on a local corpus via `fit`.
pub struct BuiltInEmbedder {
    // term -> slot index into the output vector and the `idf` table.
    vocab: HashMap<String, usize>,
    // IDF weight per vocabulary slot, aligned with `vocab` indices.
    idf: Vec<f32>,
    // Output vector length (== vocab size; 0 until `fit` has been called).
    dims: usize,
}
128
// Manual Debug: the vocab/idf tables can be large, so only summary fields
// (dimension count and vocabulary size) are printed.
impl std::fmt::Debug for BuiltInEmbedder {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("BuiltInEmbedder")
            .field("dims", &self.dims)
            .field("vocab_size", &self.vocab.len())
            .finish()
    }
}
137
138impl BuiltInEmbedder {
139 pub fn new() -> Self {
141 Self {
142 vocab: HashMap::new(),
143 idf: Vec::new(),
144 dims: 0,
145 }
146 }
147
148 pub fn fit(&mut self, documents: &[&str]) {
153 let total_docs = documents.len() as f32;
154 if total_docs == 0.0 {
155 self.vocab.clear();
156 self.idf.clear();
157 self.dims = 0;
158 return;
159 }
160
161 let mut doc_freq: HashMap<String, usize> = HashMap::new();
163 for doc in documents {
164 let unique_words: std::collections::HashSet<String> =
165 tokenize(doc).into_iter().collect();
166 for word in unique_words {
167 *doc_freq.entry(word).or_insert(0) += 1;
168 }
169 }
170
171 let mut terms: Vec<(String, usize)> = doc_freq.into_iter().collect();
174 terms.sort_by(|a, b| b.1.cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
175 terms.truncate(1024);
176
177 self.vocab.clear();
178 self.idf = Vec::with_capacity(terms.len());
179 for (i, (term, df)) in terms.iter().enumerate() {
180 self.vocab.insert(term.clone(), i);
181 self.idf.push((total_docs / *df as f32).ln());
183 }
184 self.dims = self.vocab.len();
185 }
186
187 fn compute_tfidf(&self, text: &str) -> Vec<f32> {
189 if self.dims == 0 {
190 return Vec::new();
191 }
192
193 let tokens = tokenize(text);
194 let total_tokens = tokens.len() as f32;
195 if total_tokens == 0.0 {
196 return vec![0.0; self.dims];
197 }
198
199 let mut tf_counts: HashMap<&str, usize> = HashMap::new();
201 for t in &tokens {
202 *tf_counts.entry(t.as_str()).or_insert(0) += 1;
203 }
204
205 let mut vec = vec![0.0_f32; self.dims];
206 for (term, count) in &tf_counts {
207 if let Some(&idx) = self.vocab.get(*term) {
208 let tf = *count as f32 / total_tokens;
209 vec[idx] = tf * self.idf[idx];
210 }
211 }
212
213 l2_normalize(&mut vec);
214 vec
215 }
216}
217
218impl Default for BuiltInEmbedder {
219 fn default() -> Self {
220 Self::new()
221 }
222}
223
224impl Embedder for BuiltInEmbedder {
225 fn embed(&self, text: &str) -> PunchResult<Vec<f32>> {
226 Ok(self.compute_tfidf(text))
227 }
228
229 fn embed_batch(&self, texts: &[&str]) -> PunchResult<Vec<Vec<f32>>> {
230 Ok(texts.iter().map(|t| self.compute_tfidf(t)).collect())
231 }
232
233 fn dimensions(&self) -> usize {
234 self.dims
235 }
236}
237
/// Configuration holder for the OpenAI embeddings API.
///
/// This type only builds request bodies and parses responses; the HTTP
/// call itself must be made by an async caller (see the `Embedder` impl
/// below, which always errors).
pub struct OpenAiEmbedder {
    // API key; the manual Debug impl omits it so it never reaches logs.
    api_key: String,
    // Model name, e.g. "text-embedding-3-small".
    model: String,
    // Expected vector dimensionality for `model`.
    dimensions: usize,
}
250
// Manual Debug so the API key is never printed in logs or error output.
impl std::fmt::Debug for OpenAiEmbedder {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OpenAiEmbedder")
            .field("model", &self.model)
            .field("dimensions", &self.dimensions)
            .finish()
    }
}
259
260impl OpenAiEmbedder {
261 pub fn new(api_key: String, model: String, dimensions: usize) -> Self {
265 Self {
266 api_key,
267 model,
268 dimensions,
269 }
270 }
271
272 pub fn build_request_body(&self, input: &[&str]) -> serde_json::Value {
274 if input.len() == 1 {
275 serde_json::json!({
276 "input": input[0],
277 "model": self.model,
278 })
279 } else {
280 serde_json::json!({
281 "input": input,
282 "model": self.model,
283 })
284 }
285 }
286
287 pub fn parse_response(body: &serde_json::Value) -> PunchResult<Vec<Vec<f32>>> {
289 let data = body
290 .get("data")
291 .and_then(|d| d.as_array())
292 .ok_or_else(|| PunchError::Memory("missing 'data' array in response".into()))?;
293
294 let mut results = Vec::with_capacity(data.len());
295 for item in data {
296 let embedding = item
297 .get("embedding")
298 .and_then(|e| e.as_array())
299 .ok_or_else(|| PunchError::Memory("missing 'embedding' in data item".into()))?;
300 let vec: Vec<f32> = embedding
301 .iter()
302 .map(|v| v.as_f64().unwrap_or(0.0) as f32)
303 .collect();
304 results.push(vec);
305 }
306 Ok(results)
307 }
308}
309
// The `Embedder` trait is synchronous, but calling OpenAI requires async
// HTTP, so these methods always return a descriptive error. The type is
// still useful synchronously for `dimensions()` and for request building /
// response parsing via its inherent methods.
impl Embedder for OpenAiEmbedder {
    fn embed(&self, text: &str) -> PunchResult<Vec<f32>> {
        // Only the key *length* is reported, never the key itself.
        Err(PunchError::Memory(format!(
            "OpenAI embedding requires async runtime; use embed_batch or call the API directly. \
             model={}, key_len={}, text_len={}",
            self.model,
            self.api_key.len(),
            text.len()
        )))
    }

    fn embed_batch(&self, texts: &[&str]) -> PunchResult<Vec<Vec<f32>>> {
        Err(PunchError::Memory(format!(
            "OpenAI embedding requires async runtime; call the API directly. \
             model={}, batch_size={}",
            self.model,
            texts.len()
        )))
    }

    fn dimensions(&self) -> usize {
        self.dimensions
    }
}
336
/// SQLite-backed embedding store, pairing a shared connection with the
/// embedder used to vectorize stored texts and search queries.
pub struct EmbeddingStore {
    // Shared SQLite handle; each operation holds the mutex only briefly.
    conn: Arc<Mutex<Connection>>,
    // Backend used for both ingestion and query embedding.
    embedder: Box<dyn Embedder>,
}
346
impl EmbeddingStore {
    /// Creates the store over a shared connection and ensures the
    /// `embeddings` table exists.
    ///
    /// # Errors
    /// Returns `PunchError::Memory` if the connection lock is poisoned or
    /// the table creation fails.
    pub fn new(conn: Arc<Mutex<Connection>>, embedder: Box<dyn Embedder>) -> PunchResult<Self> {
        // Inner scope releases the lock before `conn` is moved into Self.
        {
            let c = conn
                .lock()
                .map_err(|e| PunchError::Memory(format!("lock failed: {e}")))?;
            c.execute_batch(
                "CREATE TABLE IF NOT EXISTS embeddings (
                    id TEXT PRIMARY KEY,
                    text TEXT NOT NULL,
                    vector BLOB NOT NULL,
                    metadata TEXT NOT NULL DEFAULT '{}',
                    created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ', 'now'))
                );",
            )
            .map_err(|e| PunchError::Memory(format!("failed to create embeddings table: {e}")))?;
        }
        Ok(Self { conn, embedder })
    }

    /// Embeds `text`, assigns a fresh UUID, and inserts the row with its
    /// JSON-serialized metadata. Returns the new row's id.
    pub fn store(&self, text: &str, metadata: HashMap<String, String>) -> PunchResult<String> {
        // Embedding happens before the DB lock is taken, so a slow
        // embedder does not block other store operations.
        let vector = self.embedder.embed(text)?;
        let id = uuid::Uuid::new_v4().to_string();
        let blob = vec_to_bytes(&vector);
        let meta_json = serde_json::to_string(&metadata)
            .map_err(|e| PunchError::Memory(format!("metadata serialization failed: {e}")))?;
        // Same textual format that parse_ts accepts as its fallback.
        let now = Utc::now().format("%Y-%m-%dT%H:%M:%SZ").to_string();

        let conn = self
            .conn
            .lock()
            .map_err(|e| PunchError::Memory(format!("lock failed: {e}")))?;
        conn.execute(
            "INSERT INTO embeddings (id, text, vector, metadata, created_at)
             VALUES (?1, ?2, ?3, ?4, ?5)",
            rusqlite::params![id, text, blob, meta_json, now],
        )
        .map_err(|e| PunchError::Memory(format!("failed to store embedding: {e}")))?;

        debug!(id = %id, text_len = text.len(), "embedding stored");
        Ok(id)
    }

    /// Embeds `query` and returns the `k` most similar stored embeddings
    /// by cosine similarity, best first, as (score, embedding) pairs.
    ///
    /// NOTE(review): loads every row into memory and scans linearly —
    /// fine for small stores, worth revisiting if the table grows large.
    pub fn search(&self, query: &str, k: usize) -> PunchResult<Vec<(f32, Embedding)>> {
        let query_vec = self.embedder.embed(query)?;
        let all = self.load_all()?;
        let results = top_k_similar(&query_vec, &all, k);
        Ok(results.into_iter().map(|(s, e)| (s, e.clone())).collect())
    }

    /// Deletes the embedding with `id`. Succeeds (no error) even when no
    /// row matches.
    pub fn delete(&self, id: &str) -> PunchResult<()> {
        let conn = self
            .conn
            .lock()
            .map_err(|e| PunchError::Memory(format!("lock failed: {e}")))?;
        conn.execute("DELETE FROM embeddings WHERE id = ?1", [id])
            .map_err(|e| PunchError::Memory(format!("failed to delete embedding: {e}")))?;
        debug!(id = %id, "embedding deleted");
        Ok(())
    }

    /// Returns the number of stored embeddings.
    pub fn count(&self) -> PunchResult<usize> {
        let conn = self
            .conn
            .lock()
            .map_err(|e| PunchError::Memory(format!("lock failed: {e}")))?;
        let count: i64 = conn
            .query_row("SELECT COUNT(*) FROM embeddings", [], |row| row.get(0))
            .map_err(|e| PunchError::Memory(format!("failed to count embeddings: {e}")))?;
        Ok(count as usize)
    }

    /// Re-embeds every stored text with the current embedder and rewrites
    /// each row's vector in place. Returns the number of rows updated.
    ///
    /// NOTE(review): the updates are not wrapped in a transaction, so an
    /// embed/update failure part-way leaves earlier rows already rewritten.
    pub fn rebuild_index(&self) -> PunchResult<usize> {
        // load_all takes and releases the lock before we re-acquire it here.
        let all = self.load_all()?;
        let conn = self
            .conn
            .lock()
            .map_err(|e| PunchError::Memory(format!("lock failed: {e}")))?;

        let mut count = 0usize;
        for emb in &all {
            let new_vec = self.embedder.embed(&emb.text)?;
            let blob = vec_to_bytes(&new_vec);
            conn.execute(
                "UPDATE embeddings SET vector = ?1 WHERE id = ?2",
                rusqlite::params![blob, emb.id],
            )
            .map_err(|e| PunchError::Memory(format!("failed to update embedding: {e}")))?;
            count += 1;
        }
        debug!(count, "embedding index rebuilt");
        Ok(count)
    }

    /// Borrows the underlying embedder.
    pub fn embedder(&self) -> &dyn Embedder {
        self.embedder.as_ref()
    }

    /// Loads every row, decoding vectors from little-endian blobs and
    /// metadata from JSON. Malformed metadata degrades to an empty map;
    /// a malformed timestamp is an error.
    fn load_all(&self) -> PunchResult<Vec<Embedding>> {
        let conn = self
            .conn
            .lock()
            .map_err(|e| PunchError::Memory(format!("lock failed: {e}")))?;

        let mut stmt = conn
            .prepare("SELECT id, text, vector, metadata, created_at FROM embeddings")
            .map_err(|e| PunchError::Memory(format!("failed to query embeddings: {e}")))?;

        let rows = stmt
            .query_map([], |row| {
                let id: String = row.get(0)?;
                let text: String = row.get(1)?;
                let blob: Vec<u8> = row.get(2)?;
                let meta_json: String = row.get(3)?;
                let created_at: String = row.get(4)?;
                Ok((id, text, blob, meta_json, created_at))
            })
            .map_err(|e| PunchError::Memory(format!("failed to query embeddings: {e}")))?;

        let mut embeddings = Vec::new();
        for row in rows {
            let (id, text, blob, meta_json, created_at_str) =
                row.map_err(|e| PunchError::Memory(format!("failed to read row: {e}")))?;

            let vector = bytes_to_vec(&blob);
            // Best-effort metadata decode: bad JSON yields an empty map.
            let metadata: HashMap<String, String> =
                serde_json::from_str(&meta_json).unwrap_or_default();
            let created_at = parse_ts(&created_at_str)?;

            embeddings.push(Embedding {
                id,
                text,
                vector,
                metadata,
                created_at,
            });
        }
        Ok(embeddings)
    }
}
500
/// Serializes a float vector as little-endian bytes, 4 bytes per element.
pub fn vec_to_bytes(vec: &[f32]) -> Vec<u8> {
    vec.iter().flat_map(|v| v.to_le_bytes()).collect()
}
513
/// Deserializes little-endian bytes back into f32s. A trailing partial
/// chunk (fewer than 4 bytes) is silently ignored.
pub fn bytes_to_vec(bytes: &[u8]) -> Vec<f32> {
    let mut out = Vec::with_capacity(bytes.len() / 4);
    for chunk in bytes.chunks_exact(4) {
        out.push(f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]));
    }
    out
}
524
/// Splits `text` on whitespace, lowercases it, and strips every
/// non-alphanumeric character from each word; empty results are dropped.
fn tokenize(text: &str) -> Vec<String> {
    let lowered = text.to_lowercase();
    let mut words = Vec::new();
    for raw in lowered.split_whitespace() {
        let cleaned: String = raw.chars().filter(|c| c.is_alphanumeric()).collect();
        if !cleaned.is_empty() {
            words.push(cleaned);
        }
    }
    words
}
542
/// Scales `vec` in place to unit L2 norm. Zero vectors are left untouched
/// (and so is any vector whose norm is not a positive number).
fn l2_normalize(vec: &mut [f32]) {
    let squared_sum = vec.iter().fold(0.0_f32, |acc, &v| acc + v * v);
    let norm = squared_sum.sqrt();
    // `> 0.0` (not `!= 0.0`) also skips a NaN norm.
    if norm > 0.0 {
        vec.iter_mut().for_each(|v| *v /= norm);
    }
}
553
554fn parse_ts(s: &str) -> PunchResult<DateTime<Utc>> {
555 DateTime::parse_from_rfc3339(s)
556 .map(|dt| dt.with_timezone(&Utc))
557 .or_else(|_| {
558 chrono::NaiveDateTime::parse_from_str(s, "%Y-%m-%dT%H:%M:%SZ").map(|ndt| ndt.and_utc())
559 })
560 .map_err(|e| PunchError::Memory(format!("invalid timestamp '{s}': {e}")))
561}
562
// Unit tests: cosine math, the built-in TF-IDF embedder, byte round-trips,
// config serde, OpenAI request/response helpers, and the SQLite store.
#[cfg(test)]
mod tests {
    use super::*;

    // --- cosine_similarity ---

    #[test]
    fn test_cosine_identical_vectors() {
        let v = vec![1.0, 2.0, 3.0];
        let sim = cosine_similarity(&v, &v);
        assert!(
            (sim - 1.0).abs() < 1e-6,
            "identical vectors should have similarity 1.0"
        );
    }

    #[test]
    fn test_cosine_orthogonal_vectors() {
        let a = vec![1.0, 0.0];
        let b = vec![0.0, 1.0];
        let sim = cosine_similarity(&a, &b);
        assert!(
            sim.abs() < 1e-6,
            "orthogonal vectors should have similarity ~0.0"
        );
    }

    #[test]
    fn test_cosine_opposite_vectors() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![-1.0, -2.0, -3.0];
        let sim = cosine_similarity(&a, &b);
        assert!(
            (sim + 1.0).abs() < 1e-6,
            "opposite vectors should have similarity -1.0"
        );
    }

    #[test]
    fn test_cosine_zero_vector() {
        let a = vec![1.0, 2.0];
        let b = vec![0.0, 0.0];
        let sim = cosine_similarity(&a, &b);
        assert!(sim.abs() < 1e-6, "zero vector should yield 0.0");
    }

    // --- BuiltInEmbedder ---

    #[test]
    fn test_builtin_fit_and_embed_nonzero() {
        let mut embedder = BuiltInEmbedder::new();
        embedder.fit(&["the cat sat on the mat", "the dog chased the ball"]);

        let vec = embedder.embed("cat sat on mat").unwrap();
        assert!(!vec.is_empty());
        assert!(vec.iter().any(|&v| v != 0.0), "vector should be non-zero");
    }

    #[test]
    fn test_builtin_similar_texts_higher_similarity() {
        let mut embedder = BuiltInEmbedder::new();
        embedder.fit(&[
            "rust programming language",
            "python programming language",
            "cooking recipes for dinner",
            "baking bread at home",
        ]);

        let v_rust = embedder.embed("rust programming").unwrap();
        let v_python = embedder.embed("python programming").unwrap();
        let v_cooking = embedder.embed("cooking dinner recipes").unwrap();

        let sim_related = cosine_similarity(&v_rust, &v_python);
        let sim_unrelated = cosine_similarity(&v_rust, &v_cooking);

        assert!(
            sim_related > sim_unrelated,
            "related texts should have higher similarity ({sim_related} > {sim_unrelated})"
        );
    }

    #[test]
    fn test_builtin_l2_normalization() {
        let mut embedder = BuiltInEmbedder::new();
        embedder.fit(&["hello world", "foo bar baz"]);

        let vec = embedder.embed("hello world foo").unwrap();
        if !vec.is_empty() && vec.iter().any(|&v| v != 0.0) {
            let norm: f32 = vec.iter().map(|v| v * v).sum::<f32>().sqrt();
            assert!(
                (norm - 1.0).abs() < 1e-5,
                "vector should be L2-normalized, got norm={norm}"
            );
        }
    }

    #[test]
    fn test_builtin_empty_corpus() {
        let mut embedder = BuiltInEmbedder::new();
        embedder.fit(&[]);
        let vec = embedder.embed("anything").unwrap();
        assert!(vec.is_empty(), "empty corpus should produce empty vector");
        assert_eq!(embedder.dimensions(), 0);
    }

    #[test]
    fn test_builtin_single_document_corpus() {
        let mut embedder = BuiltInEmbedder::new();
        embedder.fit(&["the only document in the corpus"]);

        let vec = embedder.embed("the only document").unwrap();
        assert!(!vec.is_empty());
        // With one document, every term's idf is ln(1/1) = 0.
        assert!(
            vec.iter().all(|&v| v == 0.0),
            "single-doc corpus yields zero IDF, so vector is zero"
        );
    }

    #[test]
    fn test_builtin_batch_embedding() {
        let mut embedder = BuiltInEmbedder::new();
        embedder.fit(&["hello world", "foo bar"]);

        let batch = embedder.embed_batch(&["hello", "foo"]).unwrap();
        assert_eq!(batch.len(), 2);
        assert_eq!(batch[0].len(), embedder.dimensions());
        assert_eq!(batch[1].len(), embedder.dimensions());
    }

    // --- vector <-> bytes serialization ---

    #[test]
    fn test_vec_bytes_roundtrip() {
        let original = vec![1.0_f32, -2.5, 3.14, 0.0, f32::MAX, f32::MIN];
        let bytes = vec_to_bytes(&original);
        let restored = bytes_to_vec(&bytes);
        assert_eq!(original, restored);
    }

    #[test]
    fn test_vec_bytes_empty() {
        let empty: Vec<f32> = Vec::new();
        let bytes = vec_to_bytes(&empty);
        assert!(bytes.is_empty());
        let restored = bytes_to_vec(&bytes);
        assert!(restored.is_empty());
    }

    // --- config serde ---

    #[test]
    fn test_embedding_config_serde() {
        let config = EmbeddingConfig {
            provider: EmbeddingProvider::BuiltIn,
            dimensions: 1024,
            batch_size: 32,
        };
        let json = serde_json::to_string(&config).unwrap();
        let restored: EmbeddingConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(restored.dimensions, 1024);
        assert_eq!(restored.batch_size, 32);
    }

    #[test]
    fn test_embedding_config_openai_serde() {
        let config = EmbeddingConfig {
            provider: EmbeddingProvider::OpenAi {
                api_key: "sk-test".into(),
                model: "text-embedding-3-small".into(),
            },
            dimensions: 1536,
            batch_size: 100,
        };
        let json = serde_json::to_string(&config).unwrap();
        let restored: EmbeddingConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(restored.dimensions, 1536);
    }

    // --- OpenAI request building / response parsing ---

    #[test]
    fn test_openai_request_single() {
        let embedder =
            OpenAiEmbedder::new("sk-test-key".into(), "text-embedding-3-small".into(), 1536);
        let body = embedder.build_request_body(&["hello world"]);
        assert_eq!(body["input"], "hello world");
        assert_eq!(body["model"], "text-embedding-3-small");
    }

    #[test]
    fn test_openai_request_batch() {
        let embedder =
            OpenAiEmbedder::new("sk-test-key".into(), "text-embedding-3-small".into(), 1536);
        let body = embedder.build_request_body(&["hello", "world"]);
        let input = body["input"].as_array().unwrap();
        assert_eq!(input.len(), 2);
        assert_eq!(input[0], "hello");
        assert_eq!(input[1], "world");
    }

    #[test]
    fn test_openai_parse_response() {
        let response = serde_json::json!({
            "data": [
                {"embedding": [0.1, 0.2, 0.3], "index": 0},
                {"embedding": [0.4, 0.5, 0.6], "index": 1}
            ]
        });
        let vecs = OpenAiEmbedder::parse_response(&response).unwrap();
        assert_eq!(vecs.len(), 2);
        assert_eq!(vecs[0], vec![0.1_f32, 0.2, 0.3]);
        assert_eq!(vecs[1], vec![0.4_f32, 0.5, 0.6]);
    }

    // --- EmbeddingStore (in-memory SQLite) ---

    // Helper (not a test): a fresh in-memory store with a fitted built-in
    // embedder.
    fn test_store() -> EmbeddingStore {
        let conn = Connection::open_in_memory().unwrap();
        conn.execute_batch("PRAGMA foreign_keys = ON;").unwrap();
        let arc = Arc::new(Mutex::new(conn));
        let mut embedder = BuiltInEmbedder::new();
        embedder.fit(&[
            "rust programming language systems",
            "python scripting language web",
            "cooking recipes kitchen food",
            "machine learning neural networks",
        ]);
        EmbeddingStore::new(arc, Box::new(embedder)).unwrap()
    }

    #[test]
    fn test_store_and_search() {
        let store = test_store();
        store
            .store("rust systems programming", HashMap::new())
            .unwrap();
        store
            .store("python web development", HashMap::new())
            .unwrap();
        store
            .store("cooking recipes for pasta", HashMap::new())
            .unwrap();

        let results = store.search("rust programming", 2).unwrap();
        assert!(!results.is_empty());
        assert!(
            results[0].1.text.contains("rust"),
            "top result should match 'rust', got: {}",
            results[0].1.text
        );
    }

    #[test]
    fn test_store_top_k_count() {
        let store = test_store();
        store.store("alpha", HashMap::new()).unwrap();
        store.store("beta", HashMap::new()).unwrap();
        store.store("gamma", HashMap::new()).unwrap();
        store.store("delta", HashMap::new()).unwrap();

        let results = store.search("alpha", 2).unwrap();
        assert_eq!(results.len(), 2, "should return exactly k results");
    }

    #[test]
    fn test_store_delete() {
        let store = test_store();
        let id = store.store("to be deleted", HashMap::new()).unwrap();
        assert_eq!(store.count().unwrap(), 1);

        store.delete(&id).unwrap();
        assert_eq!(store.count().unwrap(), 0);
    }

    #[test]
    fn test_store_count() {
        let store = test_store();
        assert_eq!(store.count().unwrap(), 0);

        store.store("one", HashMap::new()).unwrap();
        assert_eq!(store.count().unwrap(), 1);

        store.store("two", HashMap::new()).unwrap();
        assert_eq!(store.count().unwrap(), 2);
    }

    #[test]
    fn test_store_rebuild_index() {
        let conn = Connection::open_in_memory().unwrap();
        conn.execute_batch("PRAGMA foreign_keys = ON;").unwrap();
        let arc = Arc::new(Mutex::new(conn));

        let mut embedder = BuiltInEmbedder::new();
        embedder.fit(&["hello world", "foo bar"]);
        let store = EmbeddingStore::new(Arc::clone(&arc), Box::new(embedder)).unwrap();

        store.store("hello world test", HashMap::new()).unwrap();
        store.store("foo bar baz", HashMap::new()).unwrap();
        assert_eq!(store.count().unwrap(), 2);

        let rebuilt = store.rebuild_index().unwrap();
        assert_eq!(rebuilt, 2);
        assert_eq!(store.count().unwrap(), 2);
    }

    // --- additional edge cases ---

    #[test]
    fn test_cosine_similarity_single_dimension() {
        let a = vec![3.0];
        let b = vec![5.0];
        let sim = cosine_similarity(&a, &b);
        assert!(
            (sim - 1.0).abs() < 1e-6,
            "same direction in 1D should be 1.0"
        );
    }

    #[test]
    fn test_cosine_similarity_negative_values() {
        let a = vec![-1.0, -2.0];
        let b = vec![-3.0, -6.0];
        let sim = cosine_similarity(&a, &b);
        assert!(
            (sim - 1.0).abs() < 1e-6,
            "parallel negative vectors are similar"
        );
    }

    #[test]
    fn test_builtin_embedder_default() {
        let embedder = BuiltInEmbedder::default();
        assert_eq!(embedder.dimensions(), 0);
    }

    #[test]
    fn test_builtin_embed_empty_text() {
        let mut embedder = BuiltInEmbedder::new();
        embedder.fit(&["hello world", "foo bar"]);
        let vec = embedder.embed("").unwrap();
        assert_eq!(vec.len(), embedder.dimensions());
        assert!(
            vec.iter().all(|&v| v == 0.0),
            "empty text yields zero vector"
        );
    }

    #[test]
    fn test_builtin_dimensions_matches_vocab() {
        let mut embedder = BuiltInEmbedder::new();
        embedder.fit(&["alpha beta gamma", "delta epsilon"]);
        assert!(embedder.dimensions() > 0);
        let vec = embedder.embed("alpha").unwrap();
        assert_eq!(vec.len(), embedder.dimensions());
    }

    #[test]
    fn test_openai_embedder_dimensions() {
        let embedder = OpenAiEmbedder::new("key".into(), "model".into(), 768);
        assert_eq!(embedder.dimensions(), 768);
    }

    #[test]
    fn test_openai_embed_returns_error() {
        let embedder = OpenAiEmbedder::new("key".into(), "model".into(), 768);
        assert!(embedder.embed("test").is_err());
    }

    #[test]
    fn test_openai_embed_batch_returns_error() {
        let embedder = OpenAiEmbedder::new("key".into(), "model".into(), 768);
        assert!(embedder.embed_batch(&["a", "b"]).is_err());
    }

    #[test]
    fn test_openai_parse_response_missing_data() {
        let resp = serde_json::json!({"no_data": true});
        assert!(OpenAiEmbedder::parse_response(&resp).is_err());
    }

    #[test]
    fn test_vec_bytes_single_value() {
        let original = vec![42.0_f32];
        let bytes = vec_to_bytes(&original);
        assert_eq!(bytes.len(), 4);
        let restored = bytes_to_vec(&bytes);
        assert_eq!(original, restored);
    }

    #[test]
    fn test_store_with_metadata() {
        let store = test_store();
        let mut meta = HashMap::new();
        meta.insert("source".to_string(), "test".to_string());
        let id = store.store("text with metadata", meta).unwrap();
        assert!(!id.is_empty());
        assert_eq!(store.count().unwrap(), 1);
    }

    #[test]
    fn test_store_delete_nonexistent() {
        let store = test_store();
        // Deleting a missing id is a no-op, not an error.
        store.delete("nonexistent-id").unwrap();
        assert_eq!(store.count().unwrap(), 0);
    }

    #[test]
    fn test_top_k_similar_empty_list() {
        let query = vec![1.0, 0.0];
        let results = top_k_similar(&query, &[], 5);
        assert!(results.is_empty());
    }

    #[test]
    fn test_top_k_similar_k_larger_than_list() {
        let embeddings = vec![Embedding {
            id: "only".into(),
            text: "one".into(),
            vector: vec![1.0, 0.0],
            metadata: HashMap::new(),
            created_at: Utc::now(),
        }];
        let query = vec![1.0, 0.0];
        let results = top_k_similar(&query, &embeddings, 10);
        assert_eq!(results.len(), 1);
    }

    #[test]
    fn test_top_k_similar_ordering() {
        let embeddings = vec![
            Embedding {
                id: "a".into(),
                text: "close".into(),
                vector: vec![0.9, 0.1],
                metadata: HashMap::new(),
                created_at: Utc::now(),
            },
            Embedding {
                id: "b".into(),
                text: "far".into(),
                vector: vec![0.0, 1.0],
                metadata: HashMap::new(),
                created_at: Utc::now(),
            },
        ];
        let query = vec![1.0, 0.0];
        let results = top_k_similar(&query, &embeddings, 2);
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].1.id, "a", "closer vector should come first");
        assert!(results[0].0 > results[1].0, "scores should be descending");
    }
}