helios_engine/
rag_tool.rs

1//! # RAG Tool Implementation
2//!
3//! Provides a Tool implementation that wraps the RAG system for agent use.
4
5use crate::error::{HeliosError, Result};
6use crate::rag::{
7    InMemoryVectorStore, OpenAIEmbeddings, QdrantVectorStore, RAGSystem, SearchResult,
8};
9use crate::tools::{Tool, ToolParameter, ToolResult};
10use async_trait::async_trait;
11use serde_json::Value;
12use std::collections::HashMap;
13
14/// RAG Tool with flexible backend support
15#[derive(Clone)]
16pub struct RAGTool {
17    rag_system: std::sync::Arc<RAGSystem>,
18    backend_type: String,
19}
20
21impl RAGTool {
22    /// Create a new RAG tool with in-memory vector store
23    pub fn new_in_memory(
24        embedding_api_url: impl Into<String>,
25        embedding_api_key: impl Into<String>,
26    ) -> Self {
27        let embeddings = OpenAIEmbeddings::new(embedding_api_url, embedding_api_key);
28        let vector_store = InMemoryVectorStore::new();
29        let rag_system = RAGSystem::new(Box::new(embeddings), Box::new(vector_store));
30
31        Self {
32            rag_system: std::sync::Arc::new(rag_system),
33            backend_type: "in-memory".to_string(),
34        }
35    }
36
37    /// Create a new RAG tool with Qdrant vector store
38    pub fn new_qdrant(
39        qdrant_url: impl Into<String>,
40        collection_name: impl Into<String>,
41        embedding_api_url: impl Into<String>,
42        embedding_api_key: impl Into<String>,
43    ) -> Self {
44        let embeddings = OpenAIEmbeddings::new(embedding_api_url, embedding_api_key);
45        let vector_store = QdrantVectorStore::new(qdrant_url, collection_name);
46        let rag_system = RAGSystem::new(Box::new(embeddings), Box::new(vector_store));
47
48        Self {
49            rag_system: std::sync::Arc::new(rag_system),
50            backend_type: "qdrant".to_string(),
51        }
52    }
53
54    /// Create with a custom RAG system
55    pub fn with_rag_system(rag_system: RAGSystem, backend_type: impl Into<String>) -> Self {
56        Self {
57            rag_system: std::sync::Arc::new(rag_system),
58            backend_type: backend_type.into(),
59        }
60    }
61
62    /// Format search results for display
63    fn format_results(&self, results: &[SearchResult]) -> String {
64        if results.is_empty() {
65            return "No matching documents found".to_string();
66        }
67
68        let formatted_results: Vec<String> = results
69            .iter()
70            .enumerate()
71            .map(|(i, result)| {
72                let preview = if result.text.len() > 200 {
73                    format!("{}...", &result.text[..200])
74                } else {
75                    result.text.clone()
76                };
77
78                format!(
79                    "{}. [Score: {:.4}] {}\n   ID: {}",
80                    i + 1,
81                    result.score,
82                    preview,
83                    result.id
84                )
85            })
86            .collect();
87
88        format!(
89            "Found {} result(s):\n\n{}",
90            results.len(),
91            formatted_results.join("\n\n")
92        )
93    }
94}
95
96#[async_trait]
97impl Tool for RAGTool {
98    fn name(&self) -> &str {
99        "rag"
100    }
101
102    fn description(&self) -> &str {
103        "RAG (Retrieval-Augmented Generation) tool for document storage and semantic search. \
104         Operations: add_document, search, delete, clear, count"
105    }
106
107    fn parameters(&self) -> HashMap<String, ToolParameter> {
108        let mut params = HashMap::new();
109        params.insert(
110            "operation".to_string(),
111            ToolParameter {
112                param_type: "string".to_string(),
113                description: "Operation: 'add_document', 'search', 'delete', 'clear', 'count'"
114                    .to_string(),
115                required: Some(true),
116            },
117        );
118        params.insert(
119            "text".to_string(),
120            ToolParameter {
121                param_type: "string".to_string(),
122                description: "Text content for add_document or search query".to_string(),
123                required: Some(false),
124            },
125        );
126        params.insert(
127            "doc_id".to_string(),
128            ToolParameter {
129                param_type: "string".to_string(),
130                description: "Document ID for delete operation".to_string(),
131                required: Some(false),
132            },
133        );
134        params.insert(
135            "limit".to_string(),
136            ToolParameter {
137                param_type: "number".to_string(),
138                description: "Number of results for search (default: 5)".to_string(),
139                required: Some(false),
140            },
141        );
142        params.insert(
143            "metadata".to_string(),
144            ToolParameter {
145                param_type: "object".to_string(),
146                description: "Additional metadata for the document (JSON object)".to_string(),
147                required: Some(false),
148            },
149        );
150        params
151    }
152
153    async fn execute(&self, args: Value) -> Result<ToolResult> {
154        let operation = args
155            .get("operation")
156            .and_then(|v| v.as_str())
157            .ok_or_else(|| HeliosError::ToolError("Missing 'operation' parameter".to_string()))?;
158
159        match operation {
160            "add_document" => {
161                let text = args.get("text").and_then(|v| v.as_str()).ok_or_else(|| {
162                    HeliosError::ToolError("Missing 'text' for add_document".to_string())
163                })?;
164
165                let metadata: Option<HashMap<String, serde_json::Value>> = args
166                    .get("metadata")
167                    .and_then(|v| serde_json::from_value(v.clone()).ok());
168
169                let doc_id = self.rag_system.add_document(text, metadata).await?;
170
171                let preview = if text.len() > 100 {
172                    format!("{}...", &text[..100])
173                } else {
174                    text.to_string()
175                };
176
177                Ok(ToolResult::success(format!(
178                    "✓ Document added successfully (backend: {})\nID: {}\nText preview: {}",
179                    self.backend_type, doc_id, preview
180                )))
181            }
182            "search" => {
183                let query = args.get("text").and_then(|v| v.as_str()).ok_or_else(|| {
184                    HeliosError::ToolError("Missing 'text' for search".to_string())
185                })?;
186
187                let limit = args.get("limit").and_then(|v| v.as_u64()).unwrap_or(5) as usize;
188
189                let results = self.rag_system.search(query, limit).await?;
190                Ok(ToolResult::success(self.format_results(&results)))
191            }
192            "delete" => {
193                let doc_id = args.get("doc_id").and_then(|v| v.as_str()).ok_or_else(|| {
194                    HeliosError::ToolError("Missing 'doc_id' for delete".to_string())
195                })?;
196
197                self.rag_system.delete_document(doc_id).await?;
198                Ok(ToolResult::success(format!(
199                    "✓ Document '{}' deleted",
200                    doc_id
201                )))
202            }
203            "clear" => {
204                self.rag_system.clear().await?;
205                Ok(ToolResult::success(
206                    "✓ All documents cleared from collection".to_string(),
207                ))
208            }
209            "count" => {
210                let count = self.rag_system.count().await?;
211                Ok(ToolResult::success(format!(
212                    "Document count: {} (backend: {})",
213                    count, self.backend_type
214                )))
215            }
216            _ => Err(HeliosError::ToolError(format!(
217                "Unknown operation '{}'. Valid: add_document, search, delete, clear, count",
218                operation
219            ))),
220        }
221    }
222}