Skip to main content

rig_cat/vector_store/
mod.rs

1//! Vector store trait and in-memory implementation.
2
3use comp_cat_rs::effect::io::Io;
4
5use crate::embedding::{Embedding, EmbeddingModel, EmbeddingRequest};
6use crate::error::Error;
7
/// A document stored in a vector store.
///
/// Pairs a caller-supplied identifier and the raw text with an embedding.
/// Fields are private; read them through the accessor methods on `Document`.
#[derive(Debug, Clone)]
pub struct Document {
    /// Caller-supplied identifier (uniqueness is not checked anywhere in this module).
    id: String,
    /// The document's text content.
    content: String,
    /// Embedding associated with this document; `InMemoryVectorStore::search`
    /// compares it against the query embedding.
    embedding: Embedding,
}
15
16impl Document {
17    #[must_use]
18    pub fn new(id: String, content: String, embedding: Embedding) -> Self {
19        Self { id, content, embedding }
20    }
21
22    #[must_use]
23    pub fn id(&self) -> &str { &self.id }
24
25    #[must_use]
26    pub fn content(&self) -> &str { &self.content }
27
28    #[must_use]
29    pub fn embedding(&self) -> &Embedding { &self.embedding }
30}
31
/// A search result: a document with its similarity score.
///
/// Holds an owned copy of the matched [`Document`] alongside the score it
/// achieved for a particular query.
#[derive(Debug, Clone)]
pub struct SearchResult {
    /// The matched document.
    document: Document,
    /// Similarity between the document's embedding and the query; for
    /// `InMemoryVectorStore` this is the cosine similarity.
    score: f64,
}
38
39impl SearchResult {
40    #[must_use]
41    pub fn new(document: Document, score: f64) -> Self {
42        Self { document, score }
43    }
44
45    #[must_use]
46    pub fn document(&self) -> &Document { &self.document }
47
48    #[must_use]
49    pub fn score(&self) -> f64 { self.score }
50}
51
/// A vector store index: store documents, search by similarity.
pub trait VectorStoreIndex {
    /// Search for the top-k most similar documents to the query.
    ///
    /// Returns an [`Io`] effect that yields the matching [`SearchResult`]s or
    /// an [`Error`].
    ///
    /// NOTE(review): the trait itself does not pin down result ordering or
    /// whether fewer than `top_k` results may be returned. `InMemoryVectorStore`
    /// returns at most `top_k` results, highest score first. Confirm whether
    /// other implementations are expected to match.
    fn search(&self, query: &Embedding, top_k: usize) -> Io<Error, Vec<SearchResult>>;
}
57
58/// In-memory vector store: stores documents in a Vec,
59/// searches by cosine similarity.
60pub struct InMemoryVectorStore {
61    documents: Vec<Document>,
62}
63
64impl InMemoryVectorStore {
65    /// Create an empty store.
66    #[must_use]
67    pub fn new() -> Self { Self { documents: Vec::new() } }
68
69    /// Add documents to the store.
70    #[must_use]
71    pub fn with_documents(self, docs: Vec<Document>) -> Self {
72        Self {
73            documents: self.documents.into_iter().chain(docs).collect(),
74        }
75    }
76
77    /// Ingest raw texts: embed them and store.
78    pub fn ingest<M: EmbeddingModel>(
79        texts: &[(String, String)],
80        model: &M,
81    ) -> Io<Error, Self> {
82        let contents: Vec<String> = texts.iter().map(|(_, c)| c.clone()).collect();
83        let ids: Vec<String> = texts.iter().map(|(id, _)| id.clone()).collect();
84        model.embed(EmbeddingRequest::new(contents.clone())).map(move |embeddings| {
85            let docs = ids.into_iter()
86                .zip(contents)
87                .zip(embeddings)
88                .map(|((id, content), emb)| Document::new(id, content, emb))
89                .collect();
90            Self { documents: Vec::new() }.with_documents(docs)
91        })
92    }
93}
94
95impl Default for InMemoryVectorStore {
96    fn default() -> Self { Self::new() }
97}
98
99impl VectorStoreIndex for InMemoryVectorStore {
100    fn search(&self, query: &Embedding, top_k: usize) -> Io<Error, Vec<SearchResult>> {
101        let results: Result<Vec<SearchResult>, Error> = self.documents.iter()
102            .map(|doc| {
103                doc.embedding().cosine_similarity(query)
104                    .map(|score| SearchResult::new(doc.clone(), score))
105            })
106            .collect::<Result<Vec<_>, _>>()
107            .map(|scored| {
108                // Insert into a sorted vec via fold (no mut needed)
109                scored.into_iter()
110                    .fold(Vec::<SearchResult>::new(), |acc, result| {
111                        let score = result.score();
112                        let pos = acc.iter()
113                            .position(|r| r.score() < score)
114                            .unwrap_or(acc.len());
115                        let (head, tail) = (
116                            acc.iter().take(pos).cloned().collect::<Vec<_>>(),
117                            acc.iter().skip(pos).cloned().collect::<Vec<_>>(),
118                        );
119                        head.into_iter()
120                            .chain(std::iter::once(result))
121                            .chain(tail)
122                            .take(top_k)
123                            .collect()
124                    })
125            });
126        Io::suspend(move || results)
127    }
128}
129
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a document whose content is derived from its id.
    fn make_doc(id: &str, emb: Vec<f64>) -> Document {
        Document::new(id.into(), format!("content of {id}"), Embedding::new(emb))
    }

    /// Result ids in ranked order, for compact assertions.
    fn ids(results: &[SearchResult]) -> Vec<&str> {
        results.iter().map(|r| r.document().id()).collect()
    }

    #[test]
    fn search_returns_most_similar_first() -> Result<(), Error> {
        let store = InMemoryVectorStore::new().with_documents(vec![
            make_doc("far", vec![0.0, 1.0]),
            make_doc("close", vec![1.0, 0.1]),
            make_doc("mid", vec![0.7, 0.7]),
        ]);
        let query = Embedding::new(vec![1.0, 0.0]);
        let results = store.search(&query, 3).run()?;
        // cos ≈ 0.995 (close) > cos ≈ 0.707 (mid) > cos = 0 (far)
        assert_eq!(ids(&results), vec!["close", "mid", "far"]);
        assert!(results.windows(2).all(|w| w[0].score() >= w[1].score()));
        Ok(())
    }

    #[test]
    fn search_respects_top_k() -> Result<(), Error> {
        let store = InMemoryVectorStore::new().with_documents(vec![
            make_doc("a", vec![1.0, 0.0]),
            make_doc("b", vec![0.9, 0.1]),
            make_doc("c", vec![0.0, 1.0]),
        ]);
        let query = Embedding::new(vec![1.0, 0.0]);
        let results = store.search(&query, 1).run()?;
        assert_eq!(ids(&results), vec!["a"]);
        Ok(())
    }

    #[test]
    fn search_with_zero_top_k_returns_empty() -> Result<(), Error> {
        let store = InMemoryVectorStore::new().with_documents(vec![
            make_doc("a", vec![1.0, 0.0]),
            make_doc("b", vec![0.0, 1.0]),
        ]);
        let query = Embedding::new(vec![1.0, 0.0]);
        let results = store.search(&query, 0).run()?;
        assert!(results.is_empty());
        Ok(())
    }

    #[test]
    fn search_with_top_k_larger_than_store_returns_all() -> Result<(), Error> {
        let store = InMemoryVectorStore::new().with_documents(vec![
            make_doc("y", vec![0.0, 1.0]),
            make_doc("x", vec![1.0, 0.0]),
        ]);
        let query = Embedding::new(vec![1.0, 0.0]);
        let results = store.search(&query, 10).run()?;
        assert_eq!(ids(&results), vec!["x", "y"]);
        Ok(())
    }

    #[test]
    fn equal_scores_keep_insertion_order() -> Result<(), Error> {
        let store = InMemoryVectorStore::new().with_documents(vec![
            make_doc("first", vec![1.0, 1.0]),
            make_doc("second", vec![1.0, 1.0]),
        ]);
        let query = Embedding::new(vec![1.0, 0.0]);
        let results = store.search(&query, 2).run()?;
        assert_eq!(ids(&results), vec!["first", "second"]);
        Ok(())
    }

    #[test]
    fn search_empty_store_returns_empty() -> Result<(), Error> {
        let store = InMemoryVectorStore::new();
        let query = Embedding::new(vec![1.0, 0.0]);
        let results = store.search(&query, 5).run()?;
        assert!(results.is_empty());
        Ok(())
    }
}