// punch_memory/embeddings.rs
1//! Vector embedding support for semantic recall in the Punch memory system.
2//!
3//! Provides a [`BuiltInEmbedder`] (TF-IDF based, no external deps), an
4//! [`OpenAiEmbedder`] (calls the OpenAI embeddings API), and an
5//! [`EmbeddingStore`] backed by SQLite for persistence and similarity search.
6
7use std::collections::HashMap;
8use std::sync::{Arc, Mutex};
9
10use chrono::{DateTime, Utc};
11use rusqlite::Connection;
12use serde::{Deserialize, Serialize};
13use tracing::debug;
14
15use punch_types::{PunchError, PunchResult};
16
17// ---------------------------------------------------------------------------
18// Types
19// ---------------------------------------------------------------------------
20
/// A vector embedding with its source text and metadata.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Embedding {
    /// Unique identifier (a UUID string when created via `EmbeddingStore::store`).
    pub id: String,
    /// The original text that was embedded.
    pub text: String,
    /// The embedding vector itself.
    pub vector: Vec<f32>,
    /// Arbitrary string key/value metadata attached at store time.
    pub metadata: HashMap<String, String>,
    /// Creation timestamp (UTC).
    pub created_at: DateTime<Utc>,
}
30
/// Configuration for the embedding engine.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingConfig {
    /// Which backend computes the vectors.
    pub provider: EmbeddingProvider,
    /// Expected vector dimensionality.
    pub dimensions: usize,
    /// How many texts to embed per batch request.
    pub batch_size: usize,
}
38
/// Which backend to use for computing embeddings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum EmbeddingProvider {
    /// OpenAI text-embedding-3-small/large.
    OpenAi { api_key: String, model: String },
    /// Local sentence-transformers via HTTP (e.g., running on localhost).
    Local { endpoint: String },
    /// Simple TF-IDF bag-of-words (no external dependency, works offline).
    BuiltIn,
}
49
50// ---------------------------------------------------------------------------
51// Embedder trait
52// ---------------------------------------------------------------------------
53
/// Trait for computing vector embeddings from text.
pub trait Embedder: Send + Sync {
    /// Compute an embedding vector for a single piece of text.
    fn embed(&self, text: &str) -> PunchResult<Vec<f32>>;

    /// Compute embedding vectors for a batch of texts.
    ///
    /// The default implementation embeds each text individually via
    /// [`Embedder::embed`], short-circuiting on the first error.
    fn embed_batch(&self, texts: &[&str]) -> PunchResult<Vec<Vec<f32>>> {
        texts.iter().map(|t| self.embed(t)).collect()
    }

    /// The dimensionality of vectors produced by this embedder.
    fn dimensions(&self) -> usize;
}
67
68// ---------------------------------------------------------------------------
69// Cosine similarity
70// ---------------------------------------------------------------------------
71
/// Compute the cosine similarity between two vectors.
///
/// Returns a value in \[-1.0, 1.0\]. Identical directions yield 1.0,
/// orthogonal vectors yield 0.0, and opposite directions yield -1.0.
/// Returns 0.0 if either vector has zero magnitude.
///
/// # Panics
///
/// Panics if `a` and `b` have different lengths.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len(), "vectors must have equal length");

    let mut dot = 0.0_f32;
    let mut norm_a = 0.0_f32;
    let mut norm_b = 0.0_f32;

    for (ai, bi) in a.iter().zip(b.iter()) {
        dot += ai * bi;
        norm_a += ai * ai;
        norm_b += bi * bi;
    }

    // Guard against division by zero when either vector is all-zero.
    let denom = norm_a.sqrt() * norm_b.sqrt();
    if denom == 0.0 {
        return 0.0;
    }
    dot / denom
}
96
97/// Return the top-k most similar embeddings to `query_vec`, sorted by
98/// descending similarity.
99pub fn top_k_similar<'a>(
100    query_vec: &[f32],
101    embeddings: &'a [Embedding],
102    k: usize,
103) -> Vec<(f32, &'a Embedding)> {
104    let mut scored: Vec<(f32, &Embedding)> = embeddings
105        .iter()
106        .map(|e| (cosine_similarity(query_vec, &e.vector), e))
107        .collect();
108    scored.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
109    scored.truncate(k);
110    scored
111}
112
113// ---------------------------------------------------------------------------
114// Built-in TF-IDF embedder
115// ---------------------------------------------------------------------------
116
/// A TF-IDF vectorizer that works entirely offline with no external
/// dependencies. Call [`BuiltInEmbedder::fit`] with a corpus to build the
/// vocabulary, then use [`Embedder::embed`] to compute vectors.
pub struct BuiltInEmbedder {
    /// Ordered vocabulary (word → index).
    vocab: HashMap<String, usize>,
    /// IDF weight for each vocabulary term (same indexing as `vocab`).
    idf: Vec<f32>,
    /// Number of dimensions (= vocabulary size, capped at 1024).
    dims: usize,
}
128
129impl std::fmt::Debug for BuiltInEmbedder {
130    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
131        f.debug_struct("BuiltInEmbedder")
132            .field("dims", &self.dims)
133            .field("vocab_size", &self.vocab.len())
134            .finish()
135    }
136}
137
138impl BuiltInEmbedder {
139    /// Create an empty embedder. You must call [`Self::fit`] before embedding.
140    pub fn new() -> Self {
141        Self {
142            vocab: HashMap::new(),
143            idf: Vec::new(),
144            dims: 0,
145        }
146    }
147
148    /// Build the vocabulary and IDF weights from a corpus of documents.
149    ///
150    /// The vocabulary is capped at 1024 terms. Terms are selected by document
151    /// frequency (the most widely occurring terms across documents come first).
152    pub fn fit(&mut self, documents: &[&str]) {
153        let total_docs = documents.len() as f32;
154        if total_docs == 0.0 {
155            self.vocab.clear();
156            self.idf.clear();
157            self.dims = 0;
158            return;
159        }
160
161        // Collect document frequency for each term.
162        let mut doc_freq: HashMap<String, usize> = HashMap::new();
163        for doc in documents {
164            let unique_words: std::collections::HashSet<String> =
165                tokenize(doc).into_iter().collect();
166            for word in unique_words {
167                *doc_freq.entry(word).or_insert(0) += 1;
168            }
169        }
170
171        // Sort terms by document frequency descending, then alphabetically for
172        // determinism, and cap at 1024.
173        let mut terms: Vec<(String, usize)> = doc_freq.into_iter().collect();
174        terms.sort_by(|a, b| b.1.cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
175        terms.truncate(1024);
176
177        self.vocab.clear();
178        self.idf = Vec::with_capacity(terms.len());
179        for (i, (term, df)) in terms.iter().enumerate() {
180            self.vocab.insert(term.clone(), i);
181            // Standard IDF: log(N / df). Use ln for smoothness.
182            self.idf.push((total_docs / *df as f32).ln());
183        }
184        self.dims = self.vocab.len();
185    }
186
187    /// Tokenize text and compute the TF-IDF vector, then L2-normalize it.
188    fn compute_tfidf(&self, text: &str) -> Vec<f32> {
189        if self.dims == 0 {
190            return Vec::new();
191        }
192
193        let tokens = tokenize(text);
194        let total_tokens = tokens.len() as f32;
195        if total_tokens == 0.0 {
196            return vec![0.0; self.dims];
197        }
198
199        // Term frequency counts.
200        let mut tf_counts: HashMap<&str, usize> = HashMap::new();
201        for t in &tokens {
202            *tf_counts.entry(t.as_str()).or_insert(0) += 1;
203        }
204
205        let mut vec = vec![0.0_f32; self.dims];
206        for (term, count) in &tf_counts {
207            if let Some(&idx) = self.vocab.get(*term) {
208                let tf = *count as f32 / total_tokens;
209                vec[idx] = tf * self.idf[idx];
210            }
211        }
212
213        l2_normalize(&mut vec);
214        vec
215    }
216}
217
218impl Default for BuiltInEmbedder {
219    fn default() -> Self {
220        Self::new()
221    }
222}
223
224impl Embedder for BuiltInEmbedder {
225    fn embed(&self, text: &str) -> PunchResult<Vec<f32>> {
226        Ok(self.compute_tfidf(text))
227    }
228
229    fn embed_batch(&self, texts: &[&str]) -> PunchResult<Vec<Vec<f32>>> {
230        Ok(texts.iter().map(|t| self.compute_tfidf(t)).collect())
231    }
232
233    fn dimensions(&self) -> usize {
234        self.dims
235    }
236}
237
238// ---------------------------------------------------------------------------
239// OpenAI embedder
240// ---------------------------------------------------------------------------
241
/// An embedder that calls the OpenAI embeddings API.
///
/// Requires the `reqwest` crate (already a workspace dependency).
pub struct OpenAiEmbedder {
    // Secret; deliberately excluded from the manual Debug impl.
    api_key: String,
    /// Model name, e.g. `"text-embedding-3-small"`.
    model: String,
    /// Dimensionality reported via [`Embedder::dimensions`].
    dimensions: usize,
}
250
251impl std::fmt::Debug for OpenAiEmbedder {
252    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
253        f.debug_struct("OpenAiEmbedder")
254            .field("model", &self.model)
255            .field("dimensions", &self.dimensions)
256            .finish()
257    }
258}
259
260impl OpenAiEmbedder {
261    /// Create a new OpenAI embedder.
262    ///
263    /// `model` should be something like `"text-embedding-3-small"`.
264    pub fn new(api_key: String, model: String, dimensions: usize) -> Self {
265        Self {
266            api_key,
267            model,
268            dimensions,
269        }
270    }
271
272    /// Build the JSON request body for the OpenAI embeddings endpoint.
273    pub fn build_request_body(&self, input: &[&str]) -> serde_json::Value {
274        if input.len() == 1 {
275            serde_json::json!({
276                "input": input[0],
277                "model": self.model,
278            })
279        } else {
280            serde_json::json!({
281                "input": input,
282                "model": self.model,
283            })
284        }
285    }
286
287    /// Parse the embedding vectors from an OpenAI API response.
288    pub fn parse_response(body: &serde_json::Value) -> PunchResult<Vec<Vec<f32>>> {
289        let data = body
290            .get("data")
291            .and_then(|d| d.as_array())
292            .ok_or_else(|| PunchError::Memory("missing 'data' array in response".into()))?;
293
294        let mut results = Vec::with_capacity(data.len());
295        for item in data {
296            let embedding = item
297                .get("embedding")
298                .and_then(|e| e.as_array())
299                .ok_or_else(|| PunchError::Memory("missing 'embedding' in data item".into()))?;
300            let vec: Vec<f32> = embedding
301                .iter()
302                .map(|v| v.as_f64().unwrap_or(0.0) as f32)
303                .collect();
304            results.push(vec);
305        }
306        Ok(results)
307    }
308}
309
impl Embedder for OpenAiEmbedder {
    /// Always fails: the synchronous trait cannot drive the HTTP call.
    ///
    /// In a real implementation this would use reqwest to call the API.
    /// For now, we return an error indicating that the runtime needs async.
    fn embed(&self, text: &str) -> PunchResult<Vec<f32>> {
        Err(PunchError::Memory(format!(
            "OpenAI embedding requires async runtime; use embed_batch or call the API directly. \
             model={}, key_len={}, text_len={}",
            self.model,
            // Only the key's length is reported, never the key itself.
            self.api_key.len(),
            text.len()
        )))
    }

    /// Always fails for the same reason as [`Embedder::embed`] above.
    fn embed_batch(&self, texts: &[&str]) -> PunchResult<Vec<Vec<f32>>> {
        Err(PunchError::Memory(format!(
            "OpenAI embedding requires async runtime; call the API directly. \
             model={}, batch_size={}",
            self.model,
            texts.len()
        )))
    }

    /// The dimensionality configured at construction time.
    fn dimensions(&self) -> usize {
        self.dimensions
    }
}
336
337// ---------------------------------------------------------------------------
338// EmbeddingStore (SQLite-backed)
339// ---------------------------------------------------------------------------
340
/// Persistent store for embeddings, backed by SQLite.
pub struct EmbeddingStore {
    /// Shared SQLite connection, serialized through a mutex.
    conn: Arc<Mutex<Connection>>,
    /// Embedder used for `store`, `search`, and `rebuild_index`.
    embedder: Box<dyn Embedder>,
}
346
impl EmbeddingStore {
    /// Create a new embedding store, creating the `embeddings` table if it
    /// does not already exist.
    ///
    /// # Errors
    ///
    /// Returns [`PunchError::Memory`] if the connection mutex is poisoned or
    /// the `CREATE TABLE` statement fails.
    pub fn new(conn: Arc<Mutex<Connection>>, embedder: Box<dyn Embedder>) -> PunchResult<Self> {
        {
            // Inner scope drops the lock guard before `conn` is moved into
            // the returned struct below.
            let c = conn
                .lock()
                .map_err(|e| PunchError::Memory(format!("lock failed: {e}")))?;
            c.execute_batch(
                "CREATE TABLE IF NOT EXISTS embeddings (
                    id         TEXT PRIMARY KEY,
                    text       TEXT NOT NULL,
                    vector     BLOB NOT NULL,
                    metadata   TEXT NOT NULL DEFAULT '{}',
                    created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ', 'now'))
                );",
            )
            .map_err(|e| PunchError::Memory(format!("failed to create embeddings table: {e}")))?;
        }
        Ok(Self { conn, embedder })
    }

    /// Store text with its embedding vector.
    ///
    /// Returns the generated UUID of the new row. The vector is serialized as
    /// a little-endian f32 blob, metadata as a JSON object string.
    pub fn store(&self, text: &str, metadata: HashMap<String, String>) -> PunchResult<String> {
        // Embed before taking the lock, so the (possibly slow) embedder does
        // not block other connection users.
        let vector = self.embedder.embed(text)?;
        let id = uuid::Uuid::new_v4().to_string();
        let blob = vec_to_bytes(&vector);
        let meta_json = serde_json::to_string(&metadata)
            .map_err(|e| PunchError::Memory(format!("metadata serialization failed: {e}")))?;
        // Same timestamp format as the table's SQL default, so parse_ts can
        // read either.
        let now = Utc::now().format("%Y-%m-%dT%H:%M:%SZ").to_string();

        let conn = self
            .conn
            .lock()
            .map_err(|e| PunchError::Memory(format!("lock failed: {e}")))?;
        conn.execute(
            "INSERT INTO embeddings (id, text, vector, metadata, created_at)
             VALUES (?1, ?2, ?3, ?4, ?5)",
            rusqlite::params![id, text, blob, meta_json, now],
        )
        .map_err(|e| PunchError::Memory(format!("failed to store embedding: {e}")))?;

        debug!(id = %id, text_len = text.len(), "embedding stored");
        Ok(id)
    }

    /// Search for the top-k most similar embeddings to `query`.
    ///
    /// Loads every stored embedding and scans linearly (O(n) in store size),
    /// cloning the k best matches for the caller.
    pub fn search(&self, query: &str, k: usize) -> PunchResult<Vec<(f32, Embedding)>> {
        let query_vec = self.embedder.embed(query)?;
        let all = self.load_all()?;
        let results = top_k_similar(&query_vec, &all, k);
        Ok(results.into_iter().map(|(s, e)| (s, e.clone())).collect())
    }

    /// Delete an embedding by ID.
    ///
    /// Deleting a nonexistent ID is not an error (0 rows affected).
    pub fn delete(&self, id: &str) -> PunchResult<()> {
        let conn = self
            .conn
            .lock()
            .map_err(|e| PunchError::Memory(format!("lock failed: {e}")))?;
        conn.execute("DELETE FROM embeddings WHERE id = ?1", [id])
            .map_err(|e| PunchError::Memory(format!("failed to delete embedding: {e}")))?;
        debug!(id = %id, "embedding deleted");
        Ok(())
    }

    /// Return the total number of stored embeddings.
    pub fn count(&self) -> PunchResult<usize> {
        let conn = self
            .conn
            .lock()
            .map_err(|e| PunchError::Memory(format!("lock failed: {e}")))?;
        let count: i64 = conn
            .query_row("SELECT COUNT(*) FROM embeddings", [], |row| row.get(0))
            .map_err(|e| PunchError::Memory(format!("failed to count embeddings: {e}")))?;
        // COUNT(*) is never negative, so the cast is safe.
        Ok(count as usize)
    }

    /// Re-embed all stored texts. Useful when the embedder changes (e.g.,
    /// after re-fitting the built-in TF-IDF vocabulary).
    ///
    /// Returns the number of rows updated. NOTE(review): the connection lock
    /// is held across every embed + UPDATE, so other users of the shared
    /// connection block for the duration of the rebuild.
    pub fn rebuild_index(&self) -> PunchResult<usize> {
        // load_all releases its lock guard before returning, so relocking
        // below does not deadlock.
        let all = self.load_all()?;
        let conn = self
            .conn
            .lock()
            .map_err(|e| PunchError::Memory(format!("lock failed: {e}")))?;

        let mut count = 0usize;
        for emb in &all {
            let new_vec = self.embedder.embed(&emb.text)?;
            let blob = vec_to_bytes(&new_vec);
            conn.execute(
                "UPDATE embeddings SET vector = ?1 WHERE id = ?2",
                rusqlite::params![blob, emb.id],
            )
            .map_err(|e| PunchError::Memory(format!("failed to update embedding: {e}")))?;
            count += 1;
        }
        debug!(count, "embedding index rebuilt");
        Ok(count)
    }

    /// Provide a reference to the current embedder.
    pub fn embedder(&self) -> &dyn Embedder {
        self.embedder.as_ref()
    }

    // -----------------------------------------------------------------------
    // Internal helpers
    // -----------------------------------------------------------------------

    /// Load every stored embedding into memory.
    ///
    /// Rows with unparseable metadata JSON fall back to an empty map; rows
    /// with an invalid timestamp abort the whole load with an error.
    fn load_all(&self) -> PunchResult<Vec<Embedding>> {
        let conn = self
            .conn
            .lock()
            .map_err(|e| PunchError::Memory(format!("lock failed: {e}")))?;

        let mut stmt = conn
            .prepare("SELECT id, text, vector, metadata, created_at FROM embeddings")
            .map_err(|e| PunchError::Memory(format!("failed to query embeddings: {e}")))?;

        let rows = stmt
            .query_map([], |row| {
                let id: String = row.get(0)?;
                let text: String = row.get(1)?;
                let blob: Vec<u8> = row.get(2)?;
                let meta_json: String = row.get(3)?;
                let created_at: String = row.get(4)?;
                Ok((id, text, blob, meta_json, created_at))
            })
            .map_err(|e| PunchError::Memory(format!("failed to query embeddings: {e}")))?;

        let mut embeddings = Vec::new();
        for row in rows {
            let (id, text, blob, meta_json, created_at_str) =
                row.map_err(|e| PunchError::Memory(format!("failed to read row: {e}")))?;

            let vector = bytes_to_vec(&blob);
            // Corrupt metadata degrades to an empty map instead of failing.
            let metadata: HashMap<String, String> =
                serde_json::from_str(&meta_json).unwrap_or_default();
            let created_at = parse_ts(&created_at_str)?;

            embeddings.push(Embedding {
                id,
                text,
                vector,
                metadata,
                created_at,
            });
        }
        Ok(embeddings)
    }
}
500
501// ---------------------------------------------------------------------------
502// Serialization helpers for f32 vectors ↔ byte blobs
503// ---------------------------------------------------------------------------
504
/// Serialize a `Vec<f32>` to little-endian bytes (4 bytes per element).
pub fn vec_to_bytes(vec: &[f32]) -> Vec<u8> {
    vec.iter().flat_map(|v| v.to_le_bytes()).collect()
}
513
/// Deserialize little-endian bytes back to a `Vec<f32>`.
///
/// Bytes are consumed in 4-byte chunks; any trailing partial chunk (1-3
/// leftover bytes, indicating a corrupt or truncated blob) is silently
/// ignored rather than treated as an error.
pub fn bytes_to_vec(bytes: &[u8]) -> Vec<f32> {
    bytes
        .chunks_exact(4)
        .map(|chunk| {
            let arr: [u8; 4] = chunk.try_into().expect("chunks_exact yields 4-byte chunks");
            f32::from_le_bytes(arr)
        })
        .collect()
}
524
525// ---------------------------------------------------------------------------
526// Utility functions
527// ---------------------------------------------------------------------------
528
/// Simple whitespace-and-punctuation tokenizer. Lowercases, strips
/// non-alphanumeric chars, and splits on whitespace.
fn tokenize(text: &str) -> Vec<String> {
    let lowered = text.to_lowercase();
    lowered
        .split_whitespace()
        .filter_map(|word| {
            let cleaned: String = word.chars().filter(|c| c.is_alphanumeric()).collect();
            // Words that were pure punctuation vanish entirely.
            if cleaned.is_empty() {
                None
            } else {
                Some(cleaned)
            }
        })
        .collect()
}
542
/// L2-normalize a vector in place. If the norm is zero, the vector is
/// left unchanged.
fn l2_normalize(vec: &mut [f32]) {
    let sum_of_squares: f32 = vec.iter().map(|v| v * v).sum();
    let norm = sum_of_squares.sqrt();
    // A zero norm means an all-zero vector: nothing to scale.
    if norm > 0.0 {
        vec.iter_mut().for_each(|v| *v /= norm);
    }
}
553
/// Parse a stored timestamp string into a UTC `DateTime`.
///
/// Tries RFC 3339 first, then falls back to the bare
/// `%Y-%m-%dT%H:%M:%SZ` pattern used by `store` and the table's SQL
/// default. NOTE(review): that pattern looks like valid RFC 3339 already,
/// so the fallback appears to be defensive — confirm before removing it.
fn parse_ts(s: &str) -> PunchResult<DateTime<Utc>> {
    DateTime::parse_from_rfc3339(s)
        .map(|dt| dt.with_timezone(&Utc))
        .or_else(|_| {
            chrono::NaiveDateTime::parse_from_str(s, "%Y-%m-%dT%H:%M:%SZ").map(|ndt| ndt.and_utc())
        })
        .map_err(|e| PunchError::Memory(format!("invalid timestamp '{s}': {e}")))
}
562
563// ---------------------------------------------------------------------------
564// Tests
565// ---------------------------------------------------------------------------
566
567#[cfg(test)]
568mod tests {
569    use super::*;
570
571    // -- Cosine similarity ---------------------------------------------------
572
573    #[test]
574    fn test_cosine_identical_vectors() {
575        let v = vec![1.0, 2.0, 3.0];
576        let sim = cosine_similarity(&v, &v);
577        assert!(
578            (sim - 1.0).abs() < 1e-6,
579            "identical vectors should have similarity 1.0"
580        );
581    }
582
583    #[test]
584    fn test_cosine_orthogonal_vectors() {
585        let a = vec![1.0, 0.0];
586        let b = vec![0.0, 1.0];
587        let sim = cosine_similarity(&a, &b);
588        assert!(
589            sim.abs() < 1e-6,
590            "orthogonal vectors should have similarity ~0.0"
591        );
592    }
593
594    #[test]
595    fn test_cosine_opposite_vectors() {
596        let a = vec![1.0, 2.0, 3.0];
597        let b = vec![-1.0, -2.0, -3.0];
598        let sim = cosine_similarity(&a, &b);
599        assert!(
600            (sim + 1.0).abs() < 1e-6,
601            "opposite vectors should have similarity -1.0"
602        );
603    }
604
605    #[test]
606    fn test_cosine_zero_vector() {
607        let a = vec![1.0, 2.0];
608        let b = vec![0.0, 0.0];
609        let sim = cosine_similarity(&a, &b);
610        assert!(sim.abs() < 1e-6, "zero vector should yield 0.0");
611    }
612
613    // -- BuiltInEmbedder -----------------------------------------------------
614
615    #[test]
616    fn test_builtin_fit_and_embed_nonzero() {
617        let mut embedder = BuiltInEmbedder::new();
618        embedder.fit(&["the cat sat on the mat", "the dog chased the ball"]);
619
620        let vec = embedder.embed("cat sat on mat").unwrap();
621        assert!(!vec.is_empty());
622        assert!(vec.iter().any(|&v| v != 0.0), "vector should be non-zero");
623    }
624
625    #[test]
626    fn test_builtin_similar_texts_higher_similarity() {
627        let mut embedder = BuiltInEmbedder::new();
628        embedder.fit(&[
629            "rust programming language",
630            "python programming language",
631            "cooking recipes for dinner",
632            "baking bread at home",
633        ]);
634
635        let v_rust = embedder.embed("rust programming").unwrap();
636        let v_python = embedder.embed("python programming").unwrap();
637        let v_cooking = embedder.embed("cooking dinner recipes").unwrap();
638
639        let sim_related = cosine_similarity(&v_rust, &v_python);
640        let sim_unrelated = cosine_similarity(&v_rust, &v_cooking);
641
642        assert!(
643            sim_related > sim_unrelated,
644            "related texts should have higher similarity ({sim_related} > {sim_unrelated})"
645        );
646    }
647
648    #[test]
649    fn test_builtin_l2_normalization() {
650        let mut embedder = BuiltInEmbedder::new();
651        embedder.fit(&["hello world", "foo bar baz"]);
652
653        let vec = embedder.embed("hello world foo").unwrap();
654        if !vec.is_empty() && vec.iter().any(|&v| v != 0.0) {
655            let norm: f32 = vec.iter().map(|v| v * v).sum::<f32>().sqrt();
656            assert!(
657                (norm - 1.0).abs() < 1e-5,
658                "vector should be L2-normalized, got norm={norm}"
659            );
660        }
661    }
662
663    #[test]
664    fn test_builtin_empty_corpus() {
665        let mut embedder = BuiltInEmbedder::new();
666        embedder.fit(&[]);
667        let vec = embedder.embed("anything").unwrap();
668        assert!(vec.is_empty(), "empty corpus should produce empty vector");
669        assert_eq!(embedder.dimensions(), 0);
670    }
671
672    #[test]
673    fn test_builtin_single_document_corpus() {
674        let mut embedder = BuiltInEmbedder::new();
675        embedder.fit(&["the only document in the corpus"]);
676
677        let vec = embedder.embed("the only document").unwrap();
678        assert!(!vec.is_empty());
679        // With a single document, IDF for all terms is log(1/1) = 0.
680        // All terms appear in the only document, so IDF = ln(1) = 0.
681        // The vector will be all zeros.
682        assert!(
683            vec.iter().all(|&v| v == 0.0),
684            "single-doc corpus yields zero IDF, so vector is zero"
685        );
686    }
687
688    #[test]
689    fn test_builtin_batch_embedding() {
690        let mut embedder = BuiltInEmbedder::new();
691        embedder.fit(&["hello world", "foo bar"]);
692
693        let batch = embedder.embed_batch(&["hello", "foo"]).unwrap();
694        assert_eq!(batch.len(), 2);
695        assert_eq!(batch[0].len(), embedder.dimensions());
696        assert_eq!(batch[1].len(), embedder.dimensions());
697    }
698
699    // -- Vector serialization ------------------------------------------------
700
701    #[test]
702    fn test_vec_bytes_roundtrip() {
703        let original = vec![1.0_f32, -2.5, 3.14, 0.0, f32::MAX, f32::MIN];
704        let bytes = vec_to_bytes(&original);
705        let restored = bytes_to_vec(&bytes);
706        assert_eq!(original, restored);
707    }
708
709    #[test]
710    fn test_vec_bytes_empty() {
711        let empty: Vec<f32> = Vec::new();
712        let bytes = vec_to_bytes(&empty);
713        assert!(bytes.is_empty());
714        let restored = bytes_to_vec(&bytes);
715        assert!(restored.is_empty());
716    }
717
718    // -- EmbeddingConfig serialization ---------------------------------------
719
720    #[test]
721    fn test_embedding_config_serde() {
722        let config = EmbeddingConfig {
723            provider: EmbeddingProvider::BuiltIn,
724            dimensions: 1024,
725            batch_size: 32,
726        };
727        let json = serde_json::to_string(&config).unwrap();
728        let restored: EmbeddingConfig = serde_json::from_str(&json).unwrap();
729        assert_eq!(restored.dimensions, 1024);
730        assert_eq!(restored.batch_size, 32);
731    }
732
733    #[test]
734    fn test_embedding_config_openai_serde() {
735        let config = EmbeddingConfig {
736            provider: EmbeddingProvider::OpenAi {
737                api_key: "sk-test".into(),
738                model: "text-embedding-3-small".into(),
739            },
740            dimensions: 1536,
741            batch_size: 100,
742        };
743        let json = serde_json::to_string(&config).unwrap();
744        let restored: EmbeddingConfig = serde_json::from_str(&json).unwrap();
745        assert_eq!(restored.dimensions, 1536);
746    }
747
748    // -- OpenAiEmbedder request formatting -----------------------------------
749
750    #[test]
751    fn test_openai_request_single() {
752        let embedder =
753            OpenAiEmbedder::new("sk-test-key".into(), "text-embedding-3-small".into(), 1536);
754        let body = embedder.build_request_body(&["hello world"]);
755        assert_eq!(body["input"], "hello world");
756        assert_eq!(body["model"], "text-embedding-3-small");
757    }
758
759    #[test]
760    fn test_openai_request_batch() {
761        let embedder =
762            OpenAiEmbedder::new("sk-test-key".into(), "text-embedding-3-small".into(), 1536);
763        let body = embedder.build_request_body(&["hello", "world"]);
764        let input = body["input"].as_array().unwrap();
765        assert_eq!(input.len(), 2);
766        assert_eq!(input[0], "hello");
767        assert_eq!(input[1], "world");
768    }
769
770    #[test]
771    fn test_openai_parse_response() {
772        let response = serde_json::json!({
773            "data": [
774                {"embedding": [0.1, 0.2, 0.3], "index": 0},
775                {"embedding": [0.4, 0.5, 0.6], "index": 1}
776            ]
777        });
778        let vecs = OpenAiEmbedder::parse_response(&response).unwrap();
779        assert_eq!(vecs.len(), 2);
780        assert_eq!(vecs[0], vec![0.1_f32, 0.2, 0.3]);
781        assert_eq!(vecs[1], vec![0.4_f32, 0.5, 0.6]);
782    }
783
784    // -- EmbeddingStore ------------------------------------------------------
785
786    fn test_store() -> EmbeddingStore {
787        let conn = Connection::open_in_memory().unwrap();
788        conn.execute_batch("PRAGMA foreign_keys = ON;").unwrap();
789        let arc = Arc::new(Mutex::new(conn));
790        let mut embedder = BuiltInEmbedder::new();
791        embedder.fit(&[
792            "rust programming language systems",
793            "python scripting language web",
794            "cooking recipes kitchen food",
795            "machine learning neural networks",
796        ]);
797        EmbeddingStore::new(arc, Box::new(embedder)).unwrap()
798    }
799
800    #[test]
801    fn test_store_and_search() {
802        let store = test_store();
803        store
804            .store("rust systems programming", HashMap::new())
805            .unwrap();
806        store
807            .store("python web development", HashMap::new())
808            .unwrap();
809        store
810            .store("cooking recipes for pasta", HashMap::new())
811            .unwrap();
812
813        let results = store.search("rust programming", 2).unwrap();
814        assert!(!results.is_empty());
815        // The top result should be about rust.
816        assert!(
817            results[0].1.text.contains("rust"),
818            "top result should match 'rust', got: {}",
819            results[0].1.text
820        );
821    }
822
823    #[test]
824    fn test_store_top_k_count() {
825        let store = test_store();
826        store.store("alpha", HashMap::new()).unwrap();
827        store.store("beta", HashMap::new()).unwrap();
828        store.store("gamma", HashMap::new()).unwrap();
829        store.store("delta", HashMap::new()).unwrap();
830
831        let results = store.search("alpha", 2).unwrap();
832        assert_eq!(results.len(), 2, "should return exactly k results");
833    }
834
835    #[test]
836    fn test_store_delete() {
837        let store = test_store();
838        let id = store.store("to be deleted", HashMap::new()).unwrap();
839        assert_eq!(store.count().unwrap(), 1);
840
841        store.delete(&id).unwrap();
842        assert_eq!(store.count().unwrap(), 0);
843    }
844
845    #[test]
846    fn test_store_count() {
847        let store = test_store();
848        assert_eq!(store.count().unwrap(), 0);
849
850        store.store("one", HashMap::new()).unwrap();
851        assert_eq!(store.count().unwrap(), 1);
852
853        store.store("two", HashMap::new()).unwrap();
854        assert_eq!(store.count().unwrap(), 2);
855    }
856
857    #[test]
858    fn test_store_rebuild_index() {
859        let conn = Connection::open_in_memory().unwrap();
860        conn.execute_batch("PRAGMA foreign_keys = ON;").unwrap();
861        let arc = Arc::new(Mutex::new(conn));
862
863        let mut embedder = BuiltInEmbedder::new();
864        embedder.fit(&["hello world", "foo bar"]);
865        let store = EmbeddingStore::new(Arc::clone(&arc), Box::new(embedder)).unwrap();
866
867        store.store("hello world test", HashMap::new()).unwrap();
868        store.store("foo bar baz", HashMap::new()).unwrap();
869        assert_eq!(store.count().unwrap(), 2);
870
871        let rebuilt = store.rebuild_index().unwrap();
872        assert_eq!(rebuilt, 2);
873        assert_eq!(store.count().unwrap(), 2);
874    }
875
    // -- similarity, embedder & serialization edge cases ---------------------
877
878    #[test]
879    fn test_cosine_similarity_single_dimension() {
880        let a = vec![3.0];
881        let b = vec![5.0];
882        let sim = cosine_similarity(&a, &b);
883        assert!(
884            (sim - 1.0).abs() < 1e-6,
885            "same direction in 1D should be 1.0"
886        );
887    }
888
889    #[test]
890    fn test_cosine_similarity_negative_values() {
891        let a = vec![-1.0, -2.0];
892        let b = vec![-3.0, -6.0];
893        let sim = cosine_similarity(&a, &b);
894        assert!(
895            (sim - 1.0).abs() < 1e-6,
896            "parallel negative vectors are similar"
897        );
898    }
899
900    #[test]
901    fn test_builtin_embedder_default() {
902        let embedder = BuiltInEmbedder::default();
903        assert_eq!(embedder.dimensions(), 0);
904    }
905
906    #[test]
907    fn test_builtin_embed_empty_text() {
908        let mut embedder = BuiltInEmbedder::new();
909        embedder.fit(&["hello world", "foo bar"]);
910        let vec = embedder.embed("").unwrap();
911        assert_eq!(vec.len(), embedder.dimensions());
912        assert!(
913            vec.iter().all(|&v| v == 0.0),
914            "empty text yields zero vector"
915        );
916    }
917
918    #[test]
919    fn test_builtin_dimensions_matches_vocab() {
920        let mut embedder = BuiltInEmbedder::new();
921        embedder.fit(&["alpha beta gamma", "delta epsilon"]);
922        assert!(embedder.dimensions() > 0);
923        let vec = embedder.embed("alpha").unwrap();
924        assert_eq!(vec.len(), embedder.dimensions());
925    }
926
927    #[test]
928    fn test_openai_embedder_dimensions() {
929        let embedder = OpenAiEmbedder::new("key".into(), "model".into(), 768);
930        assert_eq!(embedder.dimensions(), 768);
931    }
932
933    #[test]
934    fn test_openai_embed_returns_error() {
935        let embedder = OpenAiEmbedder::new("key".into(), "model".into(), 768);
936        assert!(embedder.embed("test").is_err());
937    }
938
939    #[test]
940    fn test_openai_embed_batch_returns_error() {
941        let embedder = OpenAiEmbedder::new("key".into(), "model".into(), 768);
942        assert!(embedder.embed_batch(&["a", "b"]).is_err());
943    }
944
945    #[test]
946    fn test_openai_parse_response_missing_data() {
947        let resp = serde_json::json!({"no_data": true});
948        assert!(OpenAiEmbedder::parse_response(&resp).is_err());
949    }
950
951    #[test]
952    fn test_vec_bytes_single_value() {
953        let original = vec![42.0_f32];
954        let bytes = vec_to_bytes(&original);
955        assert_eq!(bytes.len(), 4);
956        let restored = bytes_to_vec(&bytes);
957        assert_eq!(original, restored);
958    }
959
960    #[test]
961    fn test_store_with_metadata() {
962        let store = test_store();
963        let mut meta = HashMap::new();
964        meta.insert("source".to_string(), "test".to_string());
965        let id = store.store("text with metadata", meta).unwrap();
966        assert!(!id.is_empty());
967        assert_eq!(store.count().unwrap(), 1);
968    }
969
970    #[test]
971    fn test_store_delete_nonexistent() {
972        let store = test_store();
973        // Deleting a non-existent ID should not error
974        store.delete("nonexistent-id").unwrap();
975        assert_eq!(store.count().unwrap(), 0);
976    }
977
978    #[test]
979    fn test_top_k_similar_empty_list() {
980        let query = vec![1.0, 0.0];
981        let results = top_k_similar(&query, &[], 5);
982        assert!(results.is_empty());
983    }
984
985    #[test]
986    fn test_top_k_similar_k_larger_than_list() {
987        let embeddings = vec![Embedding {
988            id: "only".into(),
989            text: "one".into(),
990            vector: vec![1.0, 0.0],
991            metadata: HashMap::new(),
992            created_at: Utc::now(),
993        }];
994        let query = vec![1.0, 0.0];
995        let results = top_k_similar(&query, &embeddings, 10);
996        assert_eq!(results.len(), 1);
997    }
998
999    #[test]
1000    fn test_top_k_similar_ordering() {
1001        let embeddings = vec![
1002            Embedding {
1003                id: "a".into(),
1004                text: "close".into(),
1005                vector: vec![0.9, 0.1],
1006                metadata: HashMap::new(),
1007                created_at: Utc::now(),
1008            },
1009            Embedding {
1010                id: "b".into(),
1011                text: "far".into(),
1012                vector: vec![0.0, 1.0],
1013                metadata: HashMap::new(),
1014                created_at: Utc::now(),
1015            },
1016        ];
1017        let query = vec![1.0, 0.0];
1018        let results = top_k_similar(&query, &embeddings, 2);
1019        assert_eq!(results.len(), 2);
1020        assert_eq!(results[0].1.id, "a", "closer vector should come first");
1021        assert!(results[0].0 > results[1].0, "scores should be descending");
1022    }
1023}