1use std::collections::HashMap;
2use std::path::Path;
3use std::sync::Arc;
4
5use crate::client::OllamaClient;
6use crate::error::OxideError;
7use crate::types::{EmbedInput, EmbedRequest};
8
/// Returns the cosine of the angle between `a` and `b`.
///
/// The result is `0.0` when either vector has zero magnitude (including the
/// empty slice), so callers never divide by zero.
///
/// All arithmetic is carried out in `f64`: squaring an `f32` underflows to
/// zero for magnitudes around `1e-30` and overflows to infinity around
/// `1e30`, which would otherwise turn perfectly valid vectors into a `0.0`
/// or `NaN` score.
///
/// If the slices differ in length, the dot product only covers the shared
/// prefix while each norm uses its full slice, which is equivalent to
/// zero-padding the shorter vector.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    let dot: f64 = a
        .iter()
        .zip(b)
        .map(|(&x, &y)| f64::from(x) * f64::from(y))
        .sum();
    let norm_a = a
        .iter()
        .map(|&x| f64::from(x) * f64::from(x))
        .sum::<f64>()
        .sqrt();
    let norm_b = b
        .iter()
        .map(|&x| f64::from(x) * f64::from(x))
        .sum::<f64>()
        .sqrt();

    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0;
    }

    // Narrowing back to f32 only rounds; the quotient is already within [-1, 1]
    // up to rounding error.
    (dot / (norm_a * norm_b)) as f32
}
20
/// A piece of text held by the store together with its embedding vector.
///
/// `PartialEq` is derived so documents can be compared directly, e.g. in
/// `assert_eq!`.
#[derive(Debug, Clone, PartialEq)]
pub struct Document {
    /// The text that was embedded.
    pub content: String,
    /// The embedding vector produced for `content`.
    pub embedding: Vec<f32>,
    /// Free-form key/value annotations attached to the text.
    pub metadata: HashMap<String, String>,
}
29
/// One hit produced by a similarity search.
///
/// `PartialEq` is derived so results can be compared directly, e.g. in
/// `assert_eq!`.
#[derive(Debug, Clone, PartialEq)]
pub struct SearchResult {
    /// Text of the matching document.
    pub content: String,
    /// Cosine similarity between the query and the document embeddings.
    pub score: f32,
    /// Metadata copied from the matching document.
    pub metadata: HashMap<String, String>,
}
36
37pub struct VectorStore {
60 client: Arc<dyn OllamaClient>,
61 embed_model: String,
62 documents: Vec<Document>,
63}
64
65impl VectorStore {
66 pub fn new<C: OllamaClient + 'static>(client: Arc<C>, embed_model: impl Into<String>) -> Self {
67 let client: Arc<dyn OllamaClient> = client;
68 Self {
69 client,
70 embed_model: embed_model.into(),
71 documents: Vec::new(),
72 }
73 }
74
75 pub async fn add_text(
77 &mut self,
78 text: impl Into<String>,
79 metadata: HashMap<String, String>,
80 ) -> Result<(), OxideError> {
81 let content = text.into();
82 let embedding = self.embed_one(&content).await?;
83 self.documents.push(Document { content, embedding, metadata });
84 Ok(())
85 }
86
87 pub async fn add_file(&mut self, path: &Path) -> Result<usize, OxideError> {
89 let raw = tokio::fs::read_to_string(path)
90 .await
91 .map_err(|e| OxideError::Other(format!("read file: {e}")))?;
92
93 let file_name = path
94 .file_name()
95 .and_then(|s| s.to_str())
96 .unwrap_or("")
97 .to_string();
98
99 let chunks: Vec<&str> = raw.split("\n\n").map(str::trim).filter(|s| !s.is_empty()).collect();
101 let count = chunks.len();
102
103 for (i, chunk) in chunks.into_iter().enumerate() {
104 let mut meta = HashMap::new();
105 meta.insert("source".into(), file_name.clone());
106 meta.insert("chunk".into(), i.to_string());
107 self.add_text(chunk, meta).await?;
108 }
109
110 Ok(count)
111 }
112
113 pub async fn query(
116 &self,
117 query: impl Into<String>,
118 top_k: usize,
119 ) -> Result<Vec<SearchResult>, OxideError> {
120 let q_text = query.into();
121 let q_emb = self.embed_one(&q_text).await?;
122
123 let mut scored: Vec<(f32, &Document)> = self
124 .documents
125 .iter()
126 .map(|doc| (cosine_similarity(&q_emb, &doc.embedding), doc))
127 .collect();
128
129 scored.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
130
131 Ok(scored
132 .into_iter()
133 .take(top_k)
134 .map(|(score, doc)| SearchResult {
135 content: doc.content.clone(),
136 score,
137 metadata: doc.metadata.clone(),
138 })
139 .collect())
140 }
141
142 pub fn len(&self) -> usize {
144 self.documents.len()
145 }
146
147 pub fn is_empty(&self) -> bool {
148 self.documents.is_empty()
149 }
150
151 async fn embed_one(&self, text: &str) -> Result<Vec<f32>, OxideError> {
154 let resp = self
155 .client
156 .embed(EmbedRequest {
157 model: self.embed_model.clone(),
158 input: EmbedInput::Single(text.to_string()),
159 })
160 .await?;
161
162 resp.embeddings
163 .into_iter()
164 .next()
165 .ok_or_else(|| OxideError::Other("embed returned no vectors".into()))
166 }
167}
168
#[cfg(test)]
mod tests {
    use super::*;
    use crate::client::{BoxStream, OllamaClient};
    use crate::types::{
        ChatRequest, ChatResponse, EmbedResponse, GenerateRequest, GenerateResponse,
        ListModelsResponse,
    };
    use async_trait::async_trait;

    /// Dimension of the fake embeddings.
    const FAKE_DIM: usize = 8;

    /// Test double that only implements `embed`.
    ///
    /// The embedding is a one-hot vector selected by the first character of
    /// the text, so texts whose first characters land in different slots are
    /// orthogonal (score 0) and texts sharing a first character are identical
    /// (score 1). With `FAKE_DIM == 8`, 'r' (114) maps to slot 2 and 'p' (112)
    /// to slot 0, which keeps the ranking assertions well separated instead of
    /// relying on differences near floating-point precision.
    struct FakeEmbedClient;

    #[async_trait]
    impl OllamaClient for FakeEmbedClient {
        async fn generate(&self, _: GenerateRequest) -> Result<GenerateResponse, OxideError> {
            unimplemented!()
        }
        async fn chat(&self, _: ChatRequest) -> Result<ChatResponse, OxideError> {
            unimplemented!()
        }
        async fn embed(&self, req: EmbedRequest) -> Result<EmbedResponse, OxideError> {
            let text = match &req.input {
                EmbedInput::Single(s) => s.clone(),
                // An empty batch embeds as the empty text instead of panicking.
                EmbedInput::Batch(v) => v.first().cloned().unwrap_or_default(),
            };
            let mut vector = vec![0.0_f32; FAKE_DIM];
            if let Some(c) = text.chars().next() {
                vector[c as usize % FAKE_DIM] = 1.0;
            }
            Ok(EmbedResponse {
                model: req.model,
                embeddings: vec![vector],
            })
        }
        async fn list_models(&self) -> Result<ListModelsResponse, OxideError> {
            unimplemented!()
        }
        fn stream_generate(&self, _: GenerateRequest) -> BoxStream<GenerateResponse> {
            unimplemented!()
        }
        fn stream_chat(&self, _: ChatRequest) -> BoxStream<ChatResponse> {
            unimplemented!()
        }
    }

    #[tokio::test]
    async fn add_and_query_returns_ranked_results() {
        let client = Arc::new(FakeEmbedClient);
        let mut store = VectorStore::new(client, "test-model");

        // The non-matching document goes in first so that insertion order
        // alone cannot make the ranking assertions pass.
        store.add_text("python garbage collector", Default::default()).await.unwrap();
        store.add_text("rust ownership model", Default::default()).await.unwrap();
        store.add_text("rustaceans love borrowing", Default::default()).await.unwrap();

        assert_eq!(store.len(), 3);
        assert!(!store.is_empty());

        let results = store.query("rust lifetimes", 2).await.unwrap();
        assert_eq!(results.len(), 2);
        assert!(results.iter().all(|r| r.content.starts_with('r')));
        assert!(results[0].score >= results[1].score);
    }

    #[tokio::test]
    async fn query_with_zero_top_k_returns_nothing() {
        let mut store = VectorStore::new(Arc::new(FakeEmbedClient), "test-model");
        store.add_text("rust ownership model", Default::default()).await.unwrap();

        let results = store.query("rust lifetimes", 0).await.unwrap();
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn query_returns_at_most_all_documents() {
        let mut store = VectorStore::new(Arc::new(FakeEmbedClient), "test-model");
        store.add_text("rust ownership model", Default::default()).await.unwrap();
        store.add_text("python garbage collector", Default::default()).await.unwrap();

        let results = store.query("rust lifetimes", 10).await.unwrap();
        assert_eq!(results.len(), 2);
        assert!(results[0].content.starts_with('r'));
    }

    #[test]
    fn cosine_similarity_identical_vectors() {
        let v = vec![1.0_f32, 2.0, 3.0];
        let sim = cosine_similarity(&v, &v);
        assert!((sim - 1.0).abs() < 1e-6);
    }

    #[test]
    fn cosine_similarity_orthogonal_vectors() {
        let a = vec![1.0_f32, 0.0];
        let b = vec![0.0_f32, 1.0];
        let sim = cosine_similarity(&a, &b);
        assert!(sim.abs() < 1e-6);
    }

    #[test]
    fn cosine_similarity_zero_vector_is_zero() {
        let zero = vec![0.0_f32, 0.0];
        let other = vec![1.0_f32, 2.0];
        assert_eq!(cosine_similarity(&zero, &other), 0.0);
    }
}