oxide-agent 0.1.0

Type-safe, high-performance Rust crate for building agentic systems on Ollama
Documentation
use std::collections::HashMap;
use std::path::Path;
use std::sync::Arc;

use crate::client::OllamaClient;
use crate::error::OxideError;
use crate::types::{EmbedInput, EmbedRequest};

// ── Math helpers ──────────────────────────────────────────────────────────────

/// Cosine similarity between `a` and `b`: `dot(a, b) / (|a| * |b|)`.
///
/// Returns `0.0` when either vector has zero magnitude (a zero vector has no
/// direction). If the lengths differ, the dot product stops at the shorter
/// vector while each magnitude uses the whole vector, which is equivalent to
/// zero-padding the shorter one.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    // Euclidean (L2) length of a vector.
    let magnitude = |v: &[f32]| -> f32 { v.iter().map(|c| c * c).sum::<f32>().sqrt() };

    let (mag_a, mag_b) = (magnitude(a), magnitude(b));
    if mag_a == 0.0 || mag_b == 0.0 {
        return 0.0;
    }

    let dot: f32 = a.iter().zip(b).map(|(p, q)| p * q).sum();
    dot / (mag_a * mag_b)
}

// ── Types ─────────────────────────────────────────────────────────────────────

/// A piece of text stored in a [`VectorStore`] together with its embedding.
///
/// Equality compares every field. Because `embedding` holds `f32`s, a
/// document whose embedding contains NaN is never equal to anything,
/// including itself.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct Document {
    /// The original text that was embedded.
    pub content: String,
    /// Embedding vector returned by the embedding model for `content`.
    pub embedding: Vec<f32>,
    /// Free-form key/value annotations (for example, documents added from a
    /// file carry `"source"` and `"chunk"` entries).
    pub metadata: HashMap<String, String>,
}

/// A single hit returned by [`VectorStore::query`].
#[derive(Debug, Clone, Default, PartialEq)]
pub struct SearchResult {
    /// Text of the matching document.
    pub content: String,
    /// Cosine similarity between the query and document embeddings; higher
    /// means more similar, and `1.0` means the vectors point the same way.
    pub score: f32,
    /// Copy of the matching document's metadata.
    pub metadata: HashMap<String, String>,
}

// ── VectorStore ───────────────────────────────────────────────────────────────

/// In-memory vector store for retrieval-augmented generation.
///
/// Uses the Ollama `/api/embed` endpoint to compute embeddings and cosine
/// similarity for nearest-neighbour search.
///
/// Documents are kept in memory in insertion order. Every query scores all
/// stored documents, so query cost grows linearly with the store's size.
/// Documents and queries are all embedded with the single model passed to
/// [`VectorStore::new`].
///
/// ```no_run
/// use std::sync::Arc;
/// use oxide_agent::rag::VectorStore;
/// use oxide_agent::client::HttpOllamaClient;
///
/// # async fn example() -> anyhow::Result<()> {
/// let client = Arc::new(HttpOllamaClient::new("http://localhost:11434"));
/// let mut store = VectorStore::new(client, "nomic-embed-text");
///
/// store.add_text("Rust ownership means one owner at a time.", Default::default()).await?;
/// let results = store.query("Who owns memory in Rust?", 3).await?;
/// println!("{}", results[0].content);
/// # Ok(())
/// # }
/// ```
pub struct VectorStore {
    // Client used for every embedding request (stored documents and queries).
    client: Arc<dyn OllamaClient>,
    // Model name sent with each embedding request.
    embed_model: String,
    // Stored documents, in insertion order.
    documents: Vec<Document>,
}

impl VectorStore {
    /// Create an empty store that embeds text with `embed_model` via `client`.
    ///
    /// The concrete client type is erased to `Arc<dyn OllamaClient>`, so the
    /// store itself is not generic over the client.
    pub fn new<C: OllamaClient + 'static>(client: Arc<C>, embed_model: impl Into<String>) -> Self {
        let client: Arc<dyn OllamaClient> = client;
        Self {
            client,
            embed_model: embed_model.into(),
            documents: Vec::new(),
        }
    }

    /// Embed a single string and add it to the store.
    ///
    /// # Errors
    ///
    /// Returns any error from the embedding call, or [`OxideError::Other`] if
    /// the response contained no vectors. The store is unchanged on error.
    pub async fn add_text(
        &mut self,
        text: impl Into<String>,
        metadata: HashMap<String, String>,
    ) -> Result<(), OxideError> {
        let content = text.into();
        let embedding = self.embed_one(&content).await?;
        self.documents.push(Document { content, embedding, metadata });
        Ok(())
    }

    /// Read a UTF-8 text file and add each paragraph as a separate document.
    ///
    /// Paragraphs are separated by a blank line (`"\n\n"`). Windows `"\r\n"`
    /// line endings are normalised to `"\n"` first. Each chunk is trimmed,
    /// and empty chunks are skipped. Every stored document gets two metadata
    /// entries: `"source"` (the file name) and `"chunk"` (the chunk's
    /// zero-based index among the non-empty chunks).
    ///
    /// Returns the number of documents added.
    ///
    /// # Errors
    ///
    /// Returns [`OxideError::Other`] if the file cannot be read as UTF-8.
    /// Also propagates any embedding error. All chunks are embedded before any
    /// is stored, so on error the store is left unchanged and the call can be
    /// retried without creating duplicate documents.
    pub async fn add_file(&mut self, path: &Path) -> Result<usize, OxideError> {
        let raw = tokio::fs::read_to_string(path)
            .await
            .map_err(|e| OxideError::Other(format!("read file {}: {e}", path.display())))?;
        // CRLF files contain no literal "\n\n" and would otherwise collapse
        // into a single chunk.
        let raw = raw.replace("\r\n", "\n");

        // Lossy conversion so a non-UTF-8 file name still identifies the
        // source instead of becoming an empty string.
        let source = path
            .file_name()
            .map(|name| name.to_string_lossy().into_owned())
            .unwrap_or_default();

        let chunks = raw.split("\n\n").map(str::trim).filter(|s| !s.is_empty());

        // Stage every chunk first so a mid-file embedding failure does not
        // leave a partially ingested file behind.
        let mut staged = Vec::new();
        for (i, chunk) in chunks.enumerate() {
            let embedding = self.embed_one(chunk).await?;
            let mut metadata = HashMap::new();
            metadata.insert("source".to_string(), source.clone());
            metadata.insert("chunk".to_string(), i.to_string());
            staged.push(Document { content: chunk.to_string(), embedding, metadata });
        }

        let count = staged.len();
        self.documents.extend(staged);
        Ok(count)
    }

    /// Return the `top_k` most similar documents to `query`, ranked by
    /// cosine similarity (highest first).
    ///
    /// Every stored document is scored. Documents with equal scores keep their
    /// insertion order. A NaN score (for example, from an embedding with NaN
    /// or infinite components) ranks after all real scores. Fewer than
    /// `top_k` results are returned when the store holds fewer documents.
    ///
    /// # Errors
    ///
    /// Returns any error from embedding `query`.
    pub async fn query(
        &self,
        query: impl Into<String>,
        top_k: usize,
    ) -> Result<Vec<SearchResult>, OxideError> {
        let q_text = query.into();
        let q_emb = self.embed_one(&q_text).await?;

        let mut scored: Vec<(f32, &Document)> = self
            .documents
            .iter()
            .map(|doc| (cosine_similarity(&q_emb, &doc.embedding), doc))
            .collect();

        // `partial_cmp` with an `Equal` fallback is not a total order once a
        // NaN appears, and `sort_by` may panic on an inconsistent comparator.
        // Placing NaN after every real number restores a total order.
        // `sort_by` is stable, which keeps ties in insertion order.
        scored.sort_by(|a, b| {
            b.0.partial_cmp(&a.0)
                .unwrap_or_else(|| a.0.is_nan().cmp(&b.0.is_nan()))
        });

        Ok(scored
            .into_iter()
            .take(top_k)
            .map(|(score, doc)| SearchResult {
                content: doc.content.clone(),
                score,
                metadata: doc.metadata.clone(),
            })
            .collect())
    }

    /// Number of documents in the store.
    pub fn len(&self) -> usize {
        self.documents.len()
    }

    /// Returns `true` if the store holds no documents.
    pub fn is_empty(&self) -> bool {
        self.documents.is_empty()
    }

    // ── Internals ─────────────────────────────────────────────────────────────

    /// Embed `text` with the configured model and return the first vector of
    /// the response.
    ///
    /// # Errors
    ///
    /// Propagates client errors. Returns [`OxideError::Other`] when the
    /// response contains no embeddings.
    async fn embed_one(&self, text: &str) -> Result<Vec<f32>, OxideError> {
        let resp = self
            .client
            .embed(EmbedRequest {
                model: self.embed_model.clone(),
                input: EmbedInput::Single(text.to_string()),
            })
            .await?;

        resp.embeddings
            .into_iter()
            .next()
            .ok_or_else(|| OxideError::Other("embed returned no vectors".into()))
    }
}

// ── Tests ─────────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;
    use crate::client::{BoxStream, OllamaClient};
    use crate::types::{
        ChatRequest, ChatResponse, EmbedResponse, GenerateRequest, GenerateResponse,
        ListModelsResponse,
    };
    use async_trait::async_trait;

    /// Fake client returning a one-hot 4-D embedding keyed on the first
    /// character (its code point mod 4).
    ///
    /// Strings that share a first letter get identical unit vectors (cosine
    /// exactly 1.0). 'r' (114 % 4 == 2) and 'p' (112 % 4 == 0) land in
    /// orthogonal buckets (cosine 0.0), so rankings are unambiguous even in f32.
    struct FakeEmbedClient;

    #[async_trait]
    impl OllamaClient for FakeEmbedClient {
        async fn generate(&self, _: GenerateRequest) -> Result<GenerateResponse, OxideError> {
            unimplemented!()
        }
        async fn chat(&self, _: ChatRequest) -> Result<ChatResponse, OxideError> {
            unimplemented!()
        }
        async fn embed(&self, req: EmbedRequest) -> Result<EmbedResponse, OxideError> {
            let text = match &req.input {
                EmbedInput::Single(s) => s.clone(),
                EmbedInput::Batch(v) => v.first().cloned().unwrap_or_default(),
            };
            let bucket = text.chars().next().map_or(0, |c| ((c as u32) % 4) as usize);
            let mut embedding = vec![0.0_f32; 4];
            embedding[bucket] = 1.0;
            Ok(EmbedResponse {
                model: req.model,
                embeddings: vec![embedding],
            })
        }
        async fn list_models(&self) -> Result<ListModelsResponse, OxideError> {
            unimplemented!()
        }
        fn stream_generate(&self, _: GenerateRequest) -> BoxStream<GenerateResponse> {
            unimplemented!()
        }
        fn stream_chat(&self, _: ChatRequest) -> BoxStream<ChatResponse> {
            unimplemented!()
        }
    }

    #[tokio::test]
    async fn add_and_query_returns_ranked_results() {
        let client = Arc::new(FakeEmbedClient);
        let mut store = VectorStore::new(client, "test-model");

        store.add_text("rust ownership model", Default::default()).await.unwrap();
        store.add_text("python garbage collector", Default::default()).await.unwrap();
        store.add_text("rustaceans love borrowing", Default::default()).await.unwrap();

        assert_eq!(store.len(), 3);

        // Query starts with 'r', so the two "rust*" docs must outrank "python".
        let results = store.query("rust lifetimes", 2).await.unwrap();
        assert_eq!(results.len(), 2);
        assert!(results.iter().all(|r| r.content.starts_with('r')), "{results:?}");
        assert!(results.iter().all(|r| (r.score - 1.0).abs() < 1e-6), "{results:?}");
    }

    #[tokio::test]
    async fn query_on_empty_store_returns_no_results() {
        let store = VectorStore::new(Arc::new(FakeEmbedClient), "test-model");
        assert!(store.is_empty());

        let results = store.query("anything", 5).await.unwrap();
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn top_k_larger_than_store_returns_everything_with_metadata() {
        let mut store = VectorStore::new(Arc::new(FakeEmbedClient), "test-model");
        let mut meta = HashMap::new();
        meta.insert("source".to_string(), "notes.txt".to_string());

        store.add_text("rust", meta.clone()).await.unwrap();
        store.add_text("python", Default::default()).await.unwrap();

        let results = store.query("rust", 10).await.unwrap();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].content, "rust");
        assert_eq!(results[0].metadata, meta);
        assert!((results[0].score - 1.0).abs() < 1e-6);
        assert_eq!(results[1].content, "python");
        assert!(results[1].score.abs() < 1e-6);
    }

    #[test]
    fn cosine_similarity_identical_vectors() {
        let v = vec![1.0_f32, 2.0, 3.0];
        let sim = cosine_similarity(&v, &v);
        assert!((sim - 1.0).abs() < 1e-6);
    }

    #[test]
    fn cosine_similarity_orthogonal_vectors() {
        let a = vec![1.0_f32, 0.0];
        let b = vec![0.0_f32, 1.0];
        let sim = cosine_similarity(&a, &b);
        assert!(sim.abs() < 1e-6);
    }

    #[test]
    fn cosine_similarity_zero_vector_is_zero() {
        let zero = vec![0.0_f32, 0.0];
        let other = vec![3.0_f32, 4.0];
        assert_eq!(cosine_similarity(&zero, &other), 0.0);
        assert_eq!(cosine_similarity(&other, &zero), 0.0);
    }
}