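//! `rag` 0.1.0 — a Rust library and CLI for Retrieval-Augmented Generation.
//!
//! This module exposes the RAG retriever to MCP (Model Context Protocol)
//! clients as a JSON-RPC 2.0 server offering document add/query/list/count tools.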
use serde::{Deserialize, Serialize};
use serde_json::json;

use crate::chunker::ParagraphChunker;
use crate::embeddings::{EmbeddingModel, OllamaEmbeddingModel, OpenAIEmbeddingModel};
use crate::retriever::Retriever;
use crate::vector_store::{InMemoryVectorStore, VectorStore};

/// An incoming JSON-RPC 2.0 request from an MCP client.
///
/// The method name is matched by `McpServer::handle_request`; `params` is kept
/// as raw JSON and interpreted per method.
#[derive(Serialize, Deserialize, Debug)]
pub struct McpRequest {
    /// Protocol version string sent by the client (expected `"2.0"`; not validated here).
    pub jsonrpc: String,
    /// Correlation id echoed back in the response; `None` when the message has
    /// no id (or a `null` one), which JSON-RPC treats as a notification.
    pub id: Option<serde_json::Value>,
    /// Name of the method to invoke, e.g. `"tools/call"`.
    pub method: String,
    /// Method-specific parameters, if the client sent any.
    pub params: Option<serde_json::Value>,
}

#[derive(Debug, Serialize, Deserialize)]
pub struct McpResponse {
    pub jsonrpc: String,
    pub id: Option<serde_json::Value>,
    pub result: Option<serde_json::Value>,
    pub error: Option<McpError>,
}

/// The `error` member of a JSON-RPC 2.0 response.
#[derive(Serialize, Deserialize, Debug)]
pub struct McpError {
    /// JSON-RPC error code, e.g. `-32601` (method not found), `-32602`
    /// (invalid params) or `-32603` (internal error).
    pub code: i32,
    /// Human-readable description of what went wrong.
    pub message: String,
}

/// MCP server exposing RAG tools, backed by one of two embedding providers.
///
/// Both variants wrap a [`Retriever`] over an [`InMemoryVectorStore`]; they
/// differ only in the embedding model used to vectorize documents and queries.
pub enum McpServer {
    /// Retriever that embeds text with an [`OpenAIEmbeddingModel`].
    OpenAI(Retriever<OpenAIEmbeddingModel, InMemoryVectorStore>),
    /// Retriever that embeds text with an [`OllamaEmbeddingModel`].
    Ollama(Retriever<OllamaEmbeddingModel, InMemoryVectorStore>),
}

impl McpServer {
    pub fn new_openai(api_key: String) -> Self {
        let embedding_model = OpenAIEmbeddingModel::new(api_key);
        let vector_store = InMemoryVectorStore::new();
        let retriever = Retriever::new(embedding_model, vector_store)
            .with_chunker(Box::new(ParagraphChunker))
            .with_top_k(5);
        Self::OpenAI(retriever)
    }

    pub fn new_ollama() -> Self {
        let embedding_model = OllamaEmbeddingModel::new("nomic-embed-text".to_string());
        let vector_store = InMemoryVectorStore::new();
        let retriever = Retriever::new(embedding_model, vector_store)
            .with_chunker(Box::new(ParagraphChunker))
            .with_top_k(5);
        Self::Ollama(retriever)
    }

    pub async fn handle_request(&self, request: McpRequest) -> McpResponse {
        let result: std::result::Result<serde_json::Value, McpError> = match request.method.as_str() {
            "initialize" => self.initialize(request.params).await,
            "tools/list" => self.list_tools().await,
            "tools/call" => self.call_tool(request.params).await,
            "ping" => Ok(json!({"status": "ok"})),
            _ => Err(McpError {
                code: -32601,
                message: format!("Method not found: {}", request.method),
            }),
        };

        match result {
            Ok(data) => McpResponse {
                jsonrpc: "2.0".to_string(),
                id: request.id,
                result: Some(data),
                error: None,
            },
            Err(error) => McpResponse {
                jsonrpc: "2.0".to_string(),
                id: request.id,
                result: None,
                error: Some(error),
            },
        }
    }

    async fn initialize(&self, _params: Option<serde_json::Value>) -> std::result::Result<serde_json::Value, McpError> {
        Ok(json!({
            "protocolVersion": "2024-11-05",
            "serverInfo": {
                "name": "rag-mcp-server",
                "version": "0.1.0"
            },
            "capabilities": {
                "tools": {
                    "listChanged": false
                }
            }
        }))
    }

    async fn list_tools(&self) -> std::result::Result<serde_json::Value, McpError> {
        Ok(json!({
            "tools": [
                {
                    "name": "rag_add_document",
                    "description": "Add a document to the RAG vector store",
                    "inputSchema": {
                        "type": "object",
                        "properties": {
                            "content": {
                                "type": "string",
                                "description": "The document content to add"
                            },
                            "source": {
                                "type": "string",
                                "description": "Optional source identifier for the document"
                            }
                        },
                        "required": ["content"]
                    }
                },
                {
                    "name": "rag_query",
                    "description": "Query the RAG vector store for relevant documents",
                    "inputSchema": {
                        "type": "object",
                        "properties": {
                            "query": {
                                "type": "string",
                                "description": "The search query"
                            },
                            "top_k": {
                                "type": "number",
                                "description": "Number of results to return (default: 5)"
                            }
                        },
                        "required": ["query"]
                    }
                },
                {
                    "name": "rag_list_documents",
                    "description": "List documents in the vector store",
                    "inputSchema": {
                        "type": "object",
                        "properties": {
                            "limit": {
                                "type": "number",
                                "description": "Maximum number of documents to return"
                            },
                            "offset": {
                                "type": "number",
                                "description": "Number of documents to skip"
                            }
                        }
                    }
                },
                {
                    "name": "rag_count",
                    "description": "Count total documents in the vector store",
                    "inputSchema": {
                        "type": "object",
                        "properties": {}
                    }
                }
            ]
        }))
    }

    async fn call_tool(&self, params: Option<serde_json::Value>) -> std::result::Result<serde_json::Value, McpError> {
        let params = params.ok_or_else(|| McpError {
            code: -32602,
            message: "Missing params".to_string(),
        })?;

        let tool_name = params
            .get("name")
            .and_then(|v| v.as_str())
            .ok_or_else(|| McpError {
                code: -32602,
                message: "Missing tool name".to_string(),
            })?;

        let arguments = params.get("arguments");

        match tool_name {
            "rag_add_document" => self.tool_add_document(arguments).await,
            "rag_query" => self.tool_query(arguments).await,
            "rag_list_documents" => self.tool_list_documents(arguments).await,
            "rag_count" => self.tool_count().await,
            _ => Err(McpError {
                code: -32601,
                message: format!("Unknown tool: {}", tool_name),
            }),
        }
    }

    async fn tool_add_document(&self, args: Option<&serde_json::Value>) -> std::result::Result<serde_json::Value, McpError> {
        let args = args.ok_or_else(|| McpError {
            code: -32602,
            message: "Missing arguments".to_string(),
        })?;

        let content = args
            .get("content")
            .and_then(|v| v.as_str())
            .ok_or_else(|| McpError {
                code: -32602,
                message: "Missing content".to_string(),
            })?;

        let source = args.get("source").and_then(|v| v.as_str()).unwrap_or("unknown");

        let metadata = vec![("source".to_string(), source.to_string())];

        let doc_ids = match self {
            Self::OpenAI(retriever) => {
                retriever
                    .add_document_with_metadata(content.to_string(), metadata)
                    .await
            }
            Self::Ollama(retriever) => {
                retriever
                    .add_document_with_metadata(content.to_string(), metadata)
                    .await
            }
        }
        .map_err(|e| McpError {
            code: -32603,
            message: format!("Failed to add document: {}", e),
        })?;

        Ok(json!({
            "success": true,
            "message": "Document added successfully",
            "chunk_ids": doc_ids
        }))
    }

    async fn tool_query(&self, args: Option<&serde_json::Value>) -> std::result::Result<serde_json::Value, McpError> {
        let args = args.ok_or_else(|| McpError {
            code: -32602,
            message: "Missing arguments".to_string(),
        })?;

        let query = args
            .get("query")
            .and_then(|v| v.as_str())
            .ok_or_else(|| McpError {
                code: -32602,
                message: "Missing query".to_string(),
            })?;

        let top_k = args
            .get("top_k")
            .and_then(|v| v.as_u64())
            .unwrap_or(5) as usize;

        let embedding = match self {
            Self::OpenAI(retriever) => {
                retriever.embedding_model().embed_single(query).await
            }
            Self::Ollama(retriever) => {
                retriever.embedding_model().embed_single(query).await
            }
        }
        .map_err(|e| McpError {
            code: -32603,
            message: format!("Embedding generation failed: {}", e),
        })?;

        let results_vec: Vec<crate::vector_store::Similarity> = match self {
            Self::OpenAI(retriever) => {
                retriever.vector_store().search(&embedding, top_k).await
            }
            Self::Ollama(retriever) => {
                retriever.vector_store().search(&embedding, top_k).await
            }
        }
        .map_err(|e| McpError {
            code: -32603,
            message: format!("Search failed: {}", e),
        })?;

        let results_json: Vec<_> = results_vec
            .into_iter()
            .enumerate()
            .map(|(i, similarity)| {
                json!({
                    "rank": i + 1,
                    "content": similarity.document.content,
                    "score": similarity.score
                })
            })
            .collect();

        Ok(json!({
            "query": query,
            "results": results_json
        }))
    }

    async fn tool_list_documents(&self, args: Option<&serde_json::Value>) -> std::result::Result<serde_json::Value, McpError> {
        let limit = args
            .and_then(|a| a.get("limit"))
            .and_then(|v| v.as_u64())
            .unwrap_or(10) as usize;

        let offset = args
            .and_then(|a| a.get("offset"))
            .and_then(|v| v.as_u64())
            .unwrap_or(0) as usize;

        let documents = match self {
            Self::OpenAI(retriever) => {
                retriever.vector_store().list(limit, offset).await
            }
            Self::Ollama(retriever) => {
                retriever.vector_store().list(limit, offset).await
            }
        }
        .map_err(|e| McpError {
            code: -32603,
            message: format!("Failed to list documents: {}", e),
        })?;

        let docs_json: Vec<_> = documents
            .into_iter()
            .map(|doc| {
                json!({
                    "id": doc.id,
                    "content": doc.content.chars().take(200).collect::<String>() + "...",
                    "metadata": doc.metadata
                })
            })
            .collect();

        Ok(json!({
            "documents": docs_json
        }))
    }

    async fn tool_count(&self) -> std::result::Result<serde_json::Value, McpError> {
        let count = match self {
            Self::OpenAI(retriever) => {
                retriever.vector_store().count().await
            }
            Self::Ollama(retriever) => {
                retriever.vector_store().count().await
            }
        }
        .map_err(|e| McpError {
            code: -32603,
            message: format!("Failed to count documents: {}", e),
        })?;

        Ok(json!({
            "total_documents": count
        }))
    }
}

/// Maps any library [`crate::errors::RagError`] to a JSON-RPC "Internal error"
/// (`-32603`), using the error's `Display` text as the message.
impl From<crate::errors::RagError> for McpError {
    fn from(err: crate::errors::RagError) -> Self {
        let message = err.to_string();
        Self {
            code: -32603,
            message,
        }
    }
}