rig_cat/vector_store/
mod.rs1use comp_cat_rs::effect::io::Io;
4
5use crate::embedding::{Embedding, EmbeddingModel, EmbeddingRequest};
6use crate::error::Error;
7
8#[derive(Debug, Clone)]
10pub struct Document {
11 id: String,
12 content: String,
13 embedding: Embedding,
14}
15
16impl Document {
17 #[must_use]
18 pub fn new(id: String, content: String, embedding: Embedding) -> Self {
19 Self { id, content, embedding }
20 }
21
22 #[must_use]
23 pub fn id(&self) -> &str { &self.id }
24
25 #[must_use]
26 pub fn content(&self) -> &str { &self.content }
27
28 #[must_use]
29 pub fn embedding(&self) -> &Embedding { &self.embedding }
30}
31
32#[derive(Debug, Clone)]
34pub struct SearchResult {
35 document: Document,
36 score: f64,
37}
38
39impl SearchResult {
40 #[must_use]
41 pub fn new(document: Document, score: f64) -> Self {
42 Self { document, score }
43 }
44
45 #[must_use]
46 pub fn document(&self) -> &Document { &self.document }
47
48 #[must_use]
49 pub fn score(&self) -> f64 { self.score }
50}
51
52pub trait VectorStoreIndex {
54 fn search(&self, query: &Embedding, top_k: usize) -> Io<Error, Vec<SearchResult>>;
56}
57
58pub struct InMemoryVectorStore {
61 documents: Vec<Document>,
62}
63
64impl InMemoryVectorStore {
65 #[must_use]
67 pub fn new() -> Self { Self { documents: Vec::new() } }
68
69 #[must_use]
71 pub fn with_documents(self, docs: Vec<Document>) -> Self {
72 Self {
73 documents: self.documents.into_iter().chain(docs).collect(),
74 }
75 }
76
77 pub fn ingest<M: EmbeddingModel>(
79 texts: &[(String, String)],
80 model: &M,
81 ) -> Io<Error, Self> {
82 let contents: Vec<String> = texts.iter().map(|(_, c)| c.clone()).collect();
83 let ids: Vec<String> = texts.iter().map(|(id, _)| id.clone()).collect();
84 model.embed(EmbeddingRequest::new(contents.clone())).map(move |embeddings| {
85 let docs = ids.into_iter()
86 .zip(contents)
87 .zip(embeddings)
88 .map(|((id, content), emb)| Document::new(id, content, emb))
89 .collect();
90 Self { documents: Vec::new() }.with_documents(docs)
91 })
92 }
93}
94
95impl Default for InMemoryVectorStore {
96 fn default() -> Self { Self::new() }
97}
98
99impl VectorStoreIndex for InMemoryVectorStore {
100 fn search(&self, query: &Embedding, top_k: usize) -> Io<Error, Vec<SearchResult>> {
101 let results: Result<Vec<SearchResult>, Error> = self.documents.iter()
102 .map(|doc| {
103 doc.embedding().cosine_similarity(query)
104 .map(|score| SearchResult::new(doc.clone(), score))
105 })
106 .collect::<Result<Vec<_>, _>>()
107 .map(|scored| {
108 scored.into_iter()
110 .fold(Vec::<SearchResult>::new(), |acc, result| {
111 let score = result.score();
112 let pos = acc.iter()
113 .position(|r| r.score() < score)
114 .unwrap_or(acc.len());
115 let (head, tail) = (
116 acc.iter().take(pos).cloned().collect::<Vec<_>>(),
117 acc.iter().skip(pos).cloned().collect::<Vec<_>>(),
118 );
119 head.into_iter()
120 .chain(std::iter::once(result))
121 .chain(tail)
122 .take(top_k)
123 .collect()
124 })
125 });
126 Io::suspend(move || results)
127 }
128}
129
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a document whose content is derived from its id.
    fn make_doc(id: &str, emb: Vec<f64>) -> Document {
        let content = format!("content of {id}");
        Document::new(id.to_owned(), content, Embedding::new(emb))
    }

    /// The document pointing closest to the query direction must come first.
    #[test]
    fn search_returns_most_similar_first() -> Result<(), Error> {
        let docs = vec![
            make_doc("far", vec![0.0, 1.0]),
            make_doc("close", vec![1.0, 0.1]),
            make_doc("mid", vec![0.7, 0.7]),
        ];
        let store = InMemoryVectorStore::new().with_documents(docs);
        let query = Embedding::new(vec![1.0, 0.0]);

        let results = store.search(&query, 3).run()?;

        let best_id = results.first().map(|hit| hit.document().id());
        assert_eq!(best_id, Some("close"));
        Ok(())
    }

    /// Asking for one result from three documents yields exactly one.
    #[test]
    fn search_respects_top_k() -> Result<(), Error> {
        let docs = vec![
            make_doc("a", vec![1.0, 0.0]),
            make_doc("b", vec![0.9, 0.1]),
            make_doc("c", vec![0.0, 1.0]),
        ];
        let store = InMemoryVectorStore::new().with_documents(docs);
        let query = Embedding::new(vec![1.0, 0.0]);

        let results = store.search(&query, 1).run()?;

        assert_eq!(results.len(), 1);
        Ok(())
    }

    /// Searching a store without documents produces no results.
    #[test]
    fn search_empty_store_returns_empty() -> Result<(), Error> {
        let store = InMemoryVectorStore::new();
        let query = Embedding::new(vec![1.0, 0.0]);

        let results = store.search(&query, 5).run()?;

        assert!(results.is_empty());
        Ok(())
    }
}