Skip to main content

oxide_agent/rag/
mod.rs

1use std::collections::HashMap;
2use std::path::Path;
3use std::sync::Arc;
4
5use crate::client::OllamaClient;
6use crate::error::OxideError;
7use crate::types::{EmbedInput, EmbedRequest};
8
// ── Math helpers ──────────────────────────────────────────────────────────────
10
/// Cosine similarity between two vectors: roughly `1.0` for vectors pointing
/// the same way, `0.0` for orthogonal ones and `-1.0` for opposite ones.
///
/// Returns `0.0` when either vector has zero magnitude (the angle is
/// undefined) or when the vectors have different lengths. Without the length
/// check, `zip` would silently drop the tail of the longer vector while the
/// norms still include it, yielding a meaningless score (e.g. when documents
/// and the query were embedded by models with different dimensions).
///
/// The result may be NaN if an input contains NaN, or if the norms overflow.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }
    let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0;
    }
    dot / (norm_a * norm_b)
}
20
// ── Types ─────────────────────────────────────────────────────────────────────
22
/// A stored piece of text together with its embedding and free-form metadata.
#[derive(Debug, Clone, PartialEq)]
pub struct Document {
    /// The original text that was embedded.
    pub content: String,
    /// Embedding vector returned by the embedding model for `content`.
    pub embedding: Vec<f32>,
    /// Caller-supplied key/value metadata (documents added from a file carry
    /// `source` and `chunk` entries).
    pub metadata: HashMap<String, String>,
}
29
/// One hit returned by [`VectorStore::query`].
#[derive(Debug, Clone, PartialEq)]
pub struct SearchResult {
    /// Text of the matching document.
    pub content: String,
    /// Cosine similarity between the query and document embeddings; higher
    /// means more similar.
    pub score: f32,
    /// Metadata copied from the matching document.
    pub metadata: HashMap<String, String>,
}
36
// ── VectorStore ───────────────────────────────────────────────────────────────
38
39/// In-memory vector store for retrieval-augmented generation.
40///
41/// Uses the Ollama `/api/embed` endpoint to compute embeddings and cosine
42/// similarity for nearest-neighbour search.
43///
44/// ```no_run
45/// use std::sync::Arc;
46/// use oxide_agent::rag::VectorStore;
47/// use oxide_agent::client::HttpOllamaClient;
48///
49/// # async fn example() -> anyhow::Result<()> {
50/// let client = Arc::new(HttpOllamaClient::new("http://localhost:11434"));
51/// let mut store = VectorStore::new(client, "nomic-embed-text");
52///
53/// store.add_text("Rust ownership means one owner at a time.", Default::default()).await?;
54/// let results = store.query("Who owns memory in Rust?", 3).await?;
55/// println!("{}", results[0].content);
56/// # Ok(())
57/// # }
58/// ```
59pub struct VectorStore {
60    client: Arc<dyn OllamaClient>,
61    embed_model: String,
62    documents: Vec<Document>,
63}
64
65impl VectorStore {
66    pub fn new<C: OllamaClient + 'static>(client: Arc<C>, embed_model: impl Into<String>) -> Self {
67        let client: Arc<dyn OllamaClient> = client;
68        Self {
69            client,
70            embed_model: embed_model.into(),
71            documents: Vec::new(),
72        }
73    }
74
75    /// Embed a single string and add it to the store.
76    pub async fn add_text(
77        &mut self,
78        text: impl Into<String>,
79        metadata: HashMap<String, String>,
80    ) -> Result<(), OxideError> {
81        let content = text.into();
82        let embedding = self.embed_one(&content).await?;
83        self.documents.push(Document { content, embedding, metadata });
84        Ok(())
85    }
86
87    /// Read a UTF-8 text file and add each non-empty line as a separate document.
88    pub async fn add_file(&mut self, path: &Path) -> Result<usize, OxideError> {
89        let raw = tokio::fs::read_to_string(path)
90            .await
91            .map_err(|e| OxideError::Other(format!("read file: {e}")))?;
92
93        let file_name = path
94            .file_name()
95            .and_then(|s| s.to_str())
96            .unwrap_or("")
97            .to_string();
98
99        // Chunk by paragraph (double newline) for better context.
100        let chunks: Vec<&str> = raw.split("\n\n").map(str::trim).filter(|s| !s.is_empty()).collect();
101        let count = chunks.len();
102
103        for (i, chunk) in chunks.into_iter().enumerate() {
104            let mut meta = HashMap::new();
105            meta.insert("source".into(), file_name.clone());
106            meta.insert("chunk".into(), i.to_string());
107            self.add_text(chunk, meta).await?;
108        }
109
110        Ok(count)
111    }
112
113    /// Return the `top_k` most similar documents to `query`, ranked by
114    /// cosine similarity (highest first).
115    pub async fn query(
116        &self,
117        query: impl Into<String>,
118        top_k: usize,
119    ) -> Result<Vec<SearchResult>, OxideError> {
120        let q_text = query.into();
121        let q_emb = self.embed_one(&q_text).await?;
122
123        let mut scored: Vec<(f32, &Document)> = self
124            .documents
125            .iter()
126            .map(|doc| (cosine_similarity(&q_emb, &doc.embedding), doc))
127            .collect();
128
129        scored.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
130
131        Ok(scored
132            .into_iter()
133            .take(top_k)
134            .map(|(score, doc)| SearchResult {
135                content: doc.content.clone(),
136                score,
137                metadata: doc.metadata.clone(),
138            })
139            .collect())
140    }
141
142    /// Number of documents in the store.
143    pub fn len(&self) -> usize {
144        self.documents.len()
145    }
146
147    pub fn is_empty(&self) -> bool {
148        self.documents.is_empty()
149    }
150
151    // ── Internals ─────────────────────────────────────────────────────────────
152
153    async fn embed_one(&self, text: &str) -> Result<Vec<f32>, OxideError> {
154        let resp = self
155            .client
156            .embed(EmbedRequest {
157                model: self.embed_model.clone(),
158                input: EmbedInput::Single(text.to_string()),
159            })
160            .await?;
161
162        resp.embeddings
163            .into_iter()
164            .next()
165            .ok_or_else(|| OxideError::Other("embed returned no vectors".into()))
166    }
167}
168
169// ── Tests ─────────────────────────────────────────────────────────────────────
170
#[cfg(test)]
mod tests {
    use super::*;
    use crate::client::{BoxStream, OllamaClient};
    use crate::types::{
        ChatRequest, ChatResponse, EmbedResponse, GenerateRequest, GenerateResponse,
        ListModelsResponse,
    };
    use async_trait::async_trait;

    /// Dimension of the fake embeddings.
    const FAKE_DIM: usize = 4;

    /// Fake client whose embeddings are one-hot vectors chosen by the first
    /// character of the input. Texts whose first characters fall in the same
    /// bucket score exactly 1.0 against each other; other pairs score 0.0.
    struct FakeEmbedClient;

    #[async_trait]
    impl OllamaClient for FakeEmbedClient {
        async fn generate(&self, _: GenerateRequest) -> Result<GenerateResponse, OxideError> {
            unimplemented!()
        }
        async fn chat(&self, _: ChatRequest) -> Result<ChatResponse, OxideError> {
            unimplemented!()
        }
        async fn embed(&self, req: EmbedRequest) -> Result<EmbedResponse, OxideError> {
            let text = match &req.input {
                EmbedInput::Single(s) => s.clone(),
                EmbedInput::Batch(v) => v[0].clone(),
            };
            // One-hot on (first char % FAKE_DIM): 'r' (114) -> dim 2 and
            // 'p' (112) -> dim 0, so "rust…" and "python…" are orthogonal.
            // A value proportional to the char code would make every vector
            // nearly parallel and leave the ranking to float rounding.
            let bucket = text
                .chars()
                .next()
                .map_or(0, |c| (u32::from(c) as usize) % FAKE_DIM);
            let mut embedding = vec![0.0_f32; FAKE_DIM];
            embedding[bucket] = 1.0;
            Ok(EmbedResponse {
                model: req.model,
                embeddings: vec![embedding],
            })
        }
        async fn list_models(&self) -> Result<ListModelsResponse, OxideError> {
            unimplemented!()
        }
        fn stream_generate(&self, _: GenerateRequest) -> BoxStream<GenerateResponse> {
            unimplemented!()
        }
        fn stream_chat(&self, _: ChatRequest) -> BoxStream<ChatResponse> {
            unimplemented!()
        }
    }

    #[tokio::test]
    async fn add_and_query_returns_ranked_results() {
        let client = Arc::new(FakeEmbedClient);
        let mut store = VectorStore::new(client, "test-model");

        store.add_text("rust ownership model", Default::default()).await.unwrap();
        store.add_text("python garbage collector", Default::default()).await.unwrap();
        store.add_text("rustaceans love borrowing", Default::default()).await.unwrap();

        assert_eq!(store.len(), 3);

        // Query starts with 'r' — the two "rust*" docs should rank highest.
        let results = store.query("rust lifetimes", 2).await.unwrap();
        assert_eq!(results.len(), 2);
        assert!(results.iter().all(|r| r.content.starts_with('r')));
        assert!(results.iter().all(|r| (r.score - 1.0).abs() < 1e-6));
    }

    #[tokio::test]
    async fn query_caps_at_store_size_and_keeps_metadata() {
        let mut store = VectorStore::new(Arc::new(FakeEmbedClient), "test-model");
        let mut meta = HashMap::new();
        meta.insert("source".to_string(), "notes.txt".to_string());

        store.add_text("rust", meta.clone()).await.unwrap();
        store.add_text("python", Default::default()).await.unwrap();

        let results = store.query("rust", 10).await.unwrap();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].content, "rust");
        assert_eq!(results[0].metadata, meta);
        assert!(results[1].score.abs() < 1e-6);
    }

    #[tokio::test]
    async fn query_on_empty_store_returns_nothing() {
        let store = VectorStore::new(Arc::new(FakeEmbedClient), "test-model");
        assert!(store.is_empty());
        assert!(store.query("anything", 5).await.unwrap().is_empty());
    }

    #[test]
    fn cosine_similarity_identical_vectors() {
        let v = vec![1.0_f32, 2.0, 3.0];
        let sim = cosine_similarity(&v, &v);
        assert!((sim - 1.0).abs() < 1e-6);
    }

    #[test]
    fn cosine_similarity_orthogonal_vectors() {
        let a = vec![1.0_f32, 0.0];
        let b = vec![0.0_f32, 1.0];
        let sim = cosine_similarity(&a, &b);
        assert!(sim.abs() < 1e-6);
    }

    #[test]
    fn cosine_similarity_zero_vector_is_zero() {
        assert_eq!(cosine_similarity(&[0.0, 0.0], &[1.0, 2.0]), 0.0);
    }
}
245        let b = vec![0.0_f32, 1.0];
246        let sim = cosine_similarity(&a, &b);
247        assert!(sim.abs() < 1e-6);
248    }
249}