use std::collections::HashMap;
use std::path::Path;
use std::sync::Arc;
use crate::client::OllamaClient;
use crate::error::OxideError;
use crate::types::{EmbedInput, EmbedRequest};
/// Computes the cosine similarity between two vectors.
///
/// Returns `dot(a, b) / (|a| * |b|)`.
///
/// Returns `0.0` when the result would not be meaningful:
/// * the slices have different lengths (vectors of different dimensionality
///   are not comparable), or
/// * either vector has zero magnitude (this also covers empty slices).
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    // `zip` would silently truncate to the shorter slice while the norms would
    // still be taken over the full slices, yielding a number that is not a
    // cosine of anything. Treat a dimension mismatch as "no similarity".
    if a.len() != b.len() {
        return 0.0;
    }

    // Single pass accumulating the dot product and both squared magnitudes.
    let (dot, sq_a, sq_b) = a.iter().zip(b).fold(
        (0.0_f32, 0.0_f32, 0.0_f32),
        |(dot, sq_a, sq_b), (x, y)| (dot + x * y, sq_a + x * x, sq_b + y * y),
    );

    let norm_a = sq_a.sqrt();
    let norm_b = sq_b.sqrt();
    // Guard the division: a zero-length vector has no direction.
    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0;
    }
    dot / (norm_a * norm_b)
}
/// A piece of text held by the store together with its embedding vector and
/// associated key/value metadata.
#[derive(Debug, Clone)]
pub struct Document {
/// The text that was embedded.
pub content: String,
/// Embedding vector for `content`; compared against the query embedding
/// with cosine similarity when searching.
pub embedding: Vec<f32>,
/// Arbitrary string key/value pairs attached to the document. Within this
/// file, `VectorStore::add_file` populates the `"source"` and `"chunk"` keys.
pub metadata: HashMap<String, String>,
}
/// One hit returned by `VectorStore::query`.
#[derive(Debug, Clone)]
pub struct SearchResult {
/// Copy of the matched document's text.
pub content: String,
/// Cosine similarity between the query embedding and the document embedding.
pub score: f32,
/// Copy of the matched document's metadata.
pub metadata: HashMap<String, String>,
}
/// A store of embedded documents kept in a `Vec` and searched by cosine
/// similarity.
pub struct VectorStore {
// Client used to request embeddings.
client: Arc<dyn OllamaClient>,
// Model name sent with every embedding request.
embed_model: String,
// Stored documents, in the order they were added.
documents: Vec<Document>,
}
impl VectorStore {
    /// Creates an empty store that embeds text with `embed_model` through `client`.
    pub fn new<C: OllamaClient + 'static>(client: Arc<C>, embed_model: impl Into<String>) -> Self {
        // Erase the concrete client type so the store itself is not generic.
        let client: Arc<dyn OllamaClient> = client;
        Self {
            client,
            embed_model: embed_model.into(),
            documents: Vec::new(),
        }
    }

    /// Embeds `text` and appends it to the store with the given `metadata`.
    ///
    /// # Errors
    ///
    /// Returns the embedding client's error, or `OxideError::Other` if the
    /// client returned no vectors. The store is left unchanged on error.
    pub async fn add_text(
        &mut self,
        text: impl Into<String>,
        metadata: HashMap<String, String>,
    ) -> Result<(), OxideError> {
        let content = text.into();
        let embedding = self.embed_one(&content).await?;
        self.documents.push(Document { content, embedding, metadata });
        Ok(())
    }

    /// Reads the file at `path`, splits it into blank-line-separated chunks,
    /// embeds each non-empty chunk and adds them all to the store.
    ///
    /// Every chunk gets two metadata entries: `"source"` (the file name, or an
    /// empty string when it is missing or not valid UTF-8) and `"chunk"` (the
    /// zero-based chunk index). Returns the number of chunks added.
    ///
    /// The operation is all-or-nothing: documents are only committed once every
    /// chunk has been embedded successfully.
    ///
    /// # Errors
    ///
    /// Returns `OxideError::Other` if the file cannot be read, or the first
    /// embedding error encountered. The store is left unchanged on error.
    pub async fn add_file(&mut self, path: &Path) -> Result<usize, OxideError> {
        let raw = tokio::fs::read_to_string(path)
            .await
            .map_err(|e| OxideError::Other(format!("read file: {e}")))?;
        // Files with CRLF line endings separate paragraphs with "\r\n\r\n",
        // which contains no "\n\n"; normalise so they are chunked the same way.
        let raw = if raw.contains('\r') {
            raw.replace("\r\n", "\n")
        } else {
            raw
        };
        let file_name = path
            .file_name()
            .and_then(|s| s.to_str())
            .unwrap_or("")
            .to_string();

        let chunks: Vec<&str> = raw
            .split("\n\n")
            .map(str::trim)
            .filter(|s| !s.is_empty())
            .collect();

        // Stage the documents locally so a failure part-way through does not
        // leave the store holding only some of the file's chunks.
        let mut staged = Vec::with_capacity(chunks.len());
        for (i, chunk) in chunks.into_iter().enumerate() {
            let embedding = self.embed_one(chunk).await?;
            let metadata = HashMap::from([
                ("source".to_string(), file_name.clone()),
                ("chunk".to_string(), i.to_string()),
            ]);
            staged.push(Document {
                content: chunk.to_string(),
                embedding,
                metadata,
            });
        }

        let count = staged.len();
        self.documents.extend(staged);
        Ok(count)
    }

    /// Embeds `query` and returns up to `top_k` stored documents ordered by
    /// descending cosine similarity to it.
    ///
    /// Documents with equal scores keep their insertion order. Documents whose
    /// score is NaN are ranked after all others.
    ///
    /// # Errors
    ///
    /// Returns the embedding client's error, or `OxideError::Other` if the
    /// client returned no vectors.
    pub async fn query(
        &self,
        query: impl Into<String>,
        top_k: usize,
    ) -> Result<Vec<SearchResult>, OxideError> {
        let q_text = query.into();
        let q_emb = self.embed_one(&q_text).await?;

        let mut scored: Vec<(f32, &Document)> = self
            .documents
            .iter()
            .map(|doc| (cosine_similarity(&q_emb, &doc.embedding), doc))
            .collect();
        scored.sort_by(|a, b| Self::rank_descending(a.0, b.0));

        Ok(scored
            .into_iter()
            .take(top_k)
            .map(|(score, doc)| SearchResult {
                content: doc.content.clone(),
                score,
                metadata: doc.metadata.clone(),
            })
            .collect())
    }

    /// Returns the number of documents in the store.
    pub fn len(&self) -> usize {
        self.documents.len()
    }

    /// Returns `true` if the store holds no documents.
    pub fn is_empty(&self) -> bool {
        self.documents.is_empty()
    }

    /// Comparator that orders scores from highest to lowest, with NaN last.
    ///
    /// Treating NaN as "equal to everything" is not a total order, and
    /// `sort_by` gives unspecified results (and may panic) for such
    /// comparators, so NaN is given an explicit position instead.
    fn rank_descending(a: f32, b: f32) -> std::cmp::Ordering {
        use std::cmp::Ordering;
        match (a.is_nan(), b.is_nan()) {
            (true, true) => Ordering::Equal,
            (true, false) => Ordering::Greater,
            (false, true) => Ordering::Less,
            // Neither side is NaN, so `partial_cmp` always yields `Some`.
            (false, false) => b.partial_cmp(&a).unwrap_or(Ordering::Equal),
        }
    }

    /// Requests an embedding for a single piece of text and returns the first
    /// vector of the response.
    async fn embed_one(&self, text: &str) -> Result<Vec<f32>, OxideError> {
        let resp = self
            .client
            .embed(EmbedRequest {
                model: self.embed_model.clone(),
                input: EmbedInput::Single(text.to_string()),
            })
            .await?;
        resp.embeddings
            .into_iter()
            .next()
            .ok_or_else(|| OxideError::Other("embed returned no vectors".into()))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::client::{BoxStream, OllamaClient};
    use crate::types::{
        ChatRequest, ChatResponse, EmbedResponse, GenerateRequest, GenerateResponse,
        ListModelsResponse,
    };
    use async_trait::async_trait;

    /// Dimensionality of the vectors produced by [`FakeEmbedClient`].
    const FAKE_DIMS: usize = 4;

    /// Test double that only implements `embed`.
    ///
    /// It returns a one-hot vector whose hot index is derived from the first
    /// character of the input. Texts starting with the same character get
    /// identical embeddings (similarity 1); texts whose first characters fall
    /// into different buckets get orthogonal embeddings (similarity 0). This
    /// keeps ranking assertions independent of floating-point rounding.
    struct FakeEmbedClient;

    #[async_trait]
    impl OllamaClient for FakeEmbedClient {
        async fn generate(&self, _: GenerateRequest) -> Result<GenerateResponse, OxideError> {
            unimplemented!()
        }

        async fn chat(&self, _: ChatRequest) -> Result<ChatResponse, OxideError> {
            unimplemented!()
        }

        async fn embed(&self, req: EmbedRequest) -> Result<EmbedResponse, OxideError> {
            let text = match &req.input {
                EmbedInput::Single(s) => s.clone(),
                EmbedInput::Batch(v) => v[0].clone(),
            };
            // Empty input falls into bucket 0.
            let bucket = text
                .chars()
                .next()
                .map_or(0, |c| u32::from(c) as usize % FAKE_DIMS);
            let mut vector = vec![0.0_f32; FAKE_DIMS];
            vector[bucket] = 1.0;
            Ok(EmbedResponse {
                model: req.model,
                embeddings: vec![vector],
            })
        }

        async fn list_models(&self) -> Result<ListModelsResponse, OxideError> {
            unimplemented!()
        }

        fn stream_generate(&self, _: GenerateRequest) -> BoxStream<GenerateResponse> {
            unimplemented!()
        }

        fn stream_chat(&self, _: ChatRequest) -> BoxStream<ChatResponse> {
            unimplemented!()
        }
    }

    #[tokio::test]
    async fn add_and_query_returns_ranked_results() {
        let client = Arc::new(FakeEmbedClient);
        let mut store = VectorStore::new(client, "test-model");
        // 'r' and 'p' land in different buckets (114 % 4 == 2, 112 % 4 == 0).
        store.add_text("rust ownership model", Default::default()).await.unwrap();
        store.add_text("python garbage collector", Default::default()).await.unwrap();
        store.add_text("rustaceans love borrowing", Default::default()).await.unwrap();
        assert_eq!(store.len(), 3);
        assert!(!store.is_empty());

        let results = store.query("rust lifetimes", 2).await.unwrap();
        assert_eq!(results.len(), 2);
        assert!(results[0].content.starts_with('r'));
        // The non-matching document was inserted second, so it only stays out
        // of the top two if results are actually ranked by score.
        assert!(results.iter().all(|r| r.content.starts_with('r')));

        let all = store.query("rust lifetimes", 3).await.unwrap();
        assert_eq!(all.len(), 3);
        assert!(all[2].content.starts_with('p'));
        assert!(all[0].score > all[2].score);
    }

    #[test]
    fn cosine_similarity_identical_vectors() {
        let v = vec![1.0_f32, 2.0, 3.0];
        let sim = cosine_similarity(&v, &v);
        assert!((sim - 1.0).abs() < 1e-6);
    }

    #[test]
    fn cosine_similarity_orthogonal_vectors() {
        let a = vec![1.0_f32, 0.0];
        let b = vec![0.0_f32, 1.0];
        let sim = cosine_similarity(&a, &b);
        assert!(sim.abs() < 1e-6);
    }
}