//! # rag 0.1.0
//!
//! A Rust library and CLI for Retrieval-Augmented Generation.
//!
//! ## Documentation
use async_trait::async_trait;
use reqwest::Client;
use serde::{Deserialize, Serialize};

use crate::errors::{RagError, Result};

/// A single embedding request: one input text and the model that should embed it.
///
/// NOTE(review): not referenced by the clients in this file; presumably part of
/// the crate's public request/response surface — confirm against callers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingRequest {
    // Input text to be embedded.
    pub text: String,
    // Name of the embedding model to use.
    pub model: String,
}

/// The result of an embedding request: the vector plus the model that produced it.
///
/// NOTE(review): not referenced by the clients in this file; presumably part of
/// the crate's public request/response surface — confirm against callers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingResponse {
    // The embedding vector for the input text.
    pub embedding: Vec<f32>,
    // Name of the model that produced the embedding.
    pub model: String,
}

/// Common interface for embedding backends (OpenAI, Ollama, ...).
///
/// `Send + Sync` so implementations can be shared across async tasks.
#[async_trait]
pub trait EmbeddingModel: Send + Sync {
    /// Embeds a batch of texts, returning one vector per input, in order.
    async fn embed(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>>;

    /// Convenience wrapper that embeds a single text via [`EmbeddingModel::embed`].
    ///
    /// # Errors
    /// Returns [`RagError::EmbeddingError`] if the backend returns an empty batch,
    /// and propagates any error from `embed` itself.
    async fn embed_single(&self, text: &str) -> Result<Vec<f32>> {
        let embeddings = self.embed(vec![text.to_string()]).await?;
        // `ok_or_else` defers the error-String allocation to the failure path
        // (the original `ok_or` allocated it on every call), and returning the
        // Result directly avoids the redundant `Ok(...?)` re-wrap.
        embeddings
            .into_iter()
            .next()
            .ok_or_else(|| RagError::EmbeddingError("No embedding returned".to_string()))
    }
}

/// Client for the OpenAI embeddings HTTP API.
///
/// Construct with `new`/`with_model`; override the endpoint with `with_base_url`.
#[derive(Clone)]
pub struct OpenAIEmbeddingModel {
    // Reused HTTP client (connection pooling).
    client: Client,
    // Sent as a `Bearer` token in the Authorization header.
    api_key: String,
    // Embedding model name, e.g. "text-embedding-ada-002".
    model: String,
    // API root; defaults to "https://api.openai.com/v1".
    base_url: String,
}

impl OpenAIEmbeddingModel {
    /// Creates a client using OpenAI's default `text-embedding-ada-002` model.
    pub fn new(api_key: String) -> Self {
        Self::with_model(api_key, String::from("text-embedding-ada-002"))
    }

    /// Creates a client for a specific embedding model.
    pub fn with_model(api_key: String, model: String) -> Self {
        let base_url = String::from("https://api.openai.com/v1");
        Self {
            client: Client::new(),
            api_key,
            model,
            base_url,
        }
    }

    /// Overrides the API root, e.g. to target a proxy or an OpenAI-compatible
    /// endpoint. Consumes and returns `self` for builder-style chaining.
    pub fn with_base_url(self, base_url: String) -> Self {
        Self { base_url, ..self }
    }
}

/// Wire format for a request to OpenAI's `/embeddings` endpoint.
#[derive(Debug, Serialize)]
struct OpenAIRequest {
    // Batch of input texts to embed.
    input: Vec<String>,
    // Embedding model name.
    model: String,
}

/// Wire format for a successful response from OpenAI's `/embeddings` endpoint.
#[derive(Debug, Deserialize)]
struct OpenAIResponse {
    // One entry per input text.
    data: Vec<OpenAIEmbeddingData>,
    // Deserialized to mirror the API payload; not read anywhere in this file.
    #[allow(dead_code)]
    model: String,
}

/// A single embedding entry inside [`OpenAIResponse::data`].
#[derive(Debug, Deserialize)]
struct OpenAIEmbeddingData {
    // The embedding vector for one input text.
    embedding: Vec<f32>,
}

#[async_trait]
impl EmbeddingModel for OpenAIEmbeddingModel {
    /// Embeds `texts` with a single batched call to `{base_url}/embeddings`.
    ///
    /// # Errors
    /// Returns [`RagError::EmbeddingError`] carrying the HTTP status and the
    /// response body when the API reports failure; transport and JSON errors
    /// are propagated via `?`.
    async fn embed(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
        let request = OpenAIRequest {
            // `texts` is owned and not used again below — move it instead of
            // cloning the whole batch (was `texts.clone()`).
            input: texts,
            model: self.model.clone(),
        };

        let response = self
            .client
            .post(format!("{}/embeddings", self.base_url))
            .header("Authorization", format!("Bearer {}", self.api_key))
            .json(&request)
            .send()
            .await?;

        if !response.status().is_success() {
            // Capture the status before `text()` consumes the response, and
            // include it in the error so callers can distinguish e.g. 401 vs 429.
            let status = response.status();
            let error_text = response.text().await?;
            return Err(RagError::EmbeddingError(format!(
                "OpenAI API error ({status}): {error_text}"
            )));
        }

        let openai_response: OpenAIResponse = response.json().await?;

        Ok(openai_response.data.into_iter().map(|d| d.embedding).collect())
    }
}

/// Client for a local Ollama server's embeddings API.
///
/// No API key is required; the server defaults to `http://localhost:11434`.
#[derive(Clone)]
pub struct OllamaEmbeddingModel {
    // Reused HTTP client (connection pooling).
    client: Client,
    // Embedding model name as known to the Ollama server.
    model: String,
    // Server root; defaults to "http://localhost:11434".
    base_url: String,
}

impl OllamaEmbeddingModel {
    /// Creates a client for `model` against the default local Ollama server.
    pub fn new(model: String) -> Self {
        let base_url = String::from("http://localhost:11434");
        Self {
            client: Client::new(),
            model,
            base_url,
        }
    }

    /// Overrides the server root (e.g. a remote Ollama instance). Consumes and
    /// returns `self` for builder-style chaining.
    pub fn with_base_url(self, base_url: String) -> Self {
        Self { base_url, ..self }
    }
}

/// Wire format for a request to Ollama's `/api/embeddings` endpoint.
/// Ollama embeds one prompt per request, so batches are sent as a loop of calls.
#[derive(Debug, Serialize)]
struct OllamaRequest {
    // Embedding model name.
    model: String,
    // Single input text to embed.
    prompt: String,
}

/// Wire format for a successful response from Ollama's `/api/embeddings` endpoint.
#[derive(Debug, Deserialize)]
struct OllamaResponse {
    // The embedding vector for the request's prompt.
    embedding: Vec<f32>,
}

#[async_trait]
impl EmbeddingModel for OllamaEmbeddingModel {
    /// Embeds `texts` by calling `{base_url}/api/embeddings` once per text
    /// (the Ollama embeddings endpoint takes a single prompt per request).
    ///
    /// # Errors
    /// Returns [`RagError::EmbeddingError`] carrying the HTTP status and the
    /// response body on the first failed request; transport and JSON errors
    /// are propagated via `?`. Texts already embedded before a failure are
    /// discarded.
    async fn embed(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
        // Exactly one result per input — size the output vector up front.
        let mut embeddings = Vec::with_capacity(texts.len());

        for text in texts {
            let request = OllamaRequest {
                model: self.model.clone(),
                prompt: text,
            };

            let response = self
                .client
                .post(format!("{}/api/embeddings", self.base_url))
                .json(&request)
                .send()
                .await?;

            if !response.status().is_success() {
                // Capture the status before `text()` consumes the response,
                // and include it in the error for easier diagnosis.
                let status = response.status();
                let error_text = response.text().await?;
                return Err(RagError::EmbeddingError(format!(
                    "Ollama API error ({status}): {error_text}"
                )));
            }

            let ollama_response: OllamaResponse = response.json().await?;
            embeddings.push(ollama_response.embedding);
        }

        Ok(embeddings)
    }
}