1use crate::error::{HeliosError, Result};
6use crate::rag::{
7 InMemoryVectorStore, OpenAIEmbeddings, QdrantVectorStore, RAGSystem, SearchResult,
8};
9use crate::tools::{Tool, ToolParameter, ToolResult};
10use async_trait::async_trait;
11use serde_json::Value;
12use std::collections::HashMap;
13
14#[derive(Clone)]
16pub struct RAGTool {
17 rag_system: std::sync::Arc<RAGSystem>,
18 backend_type: String,
19}
20
21impl RAGTool {
22 pub fn new_in_memory(
24 embedding_api_url: impl Into<String>,
25 embedding_api_key: impl Into<String>,
26 ) -> Self {
27 let embeddings = OpenAIEmbeddings::new(embedding_api_url, embedding_api_key);
28 let vector_store = InMemoryVectorStore::new();
29 let rag_system = RAGSystem::new(Box::new(embeddings), Box::new(vector_store));
30
31 Self {
32 rag_system: std::sync::Arc::new(rag_system),
33 backend_type: "in-memory".to_string(),
34 }
35 }
36
37 pub fn new_qdrant(
39 qdrant_url: impl Into<String>,
40 collection_name: impl Into<String>,
41 embedding_api_url: impl Into<String>,
42 embedding_api_key: impl Into<String>,
43 ) -> Self {
44 let embeddings = OpenAIEmbeddings::new(embedding_api_url, embedding_api_key);
45 let vector_store = QdrantVectorStore::new(qdrant_url, collection_name);
46 let rag_system = RAGSystem::new(Box::new(embeddings), Box::new(vector_store));
47
48 Self {
49 rag_system: std::sync::Arc::new(rag_system),
50 backend_type: "qdrant".to_string(),
51 }
52 }
53
54 pub fn with_rag_system(rag_system: RAGSystem, backend_type: impl Into<String>) -> Self {
56 Self {
57 rag_system: std::sync::Arc::new(rag_system),
58 backend_type: backend_type.into(),
59 }
60 }
61
62 fn format_results(&self, results: &[SearchResult]) -> String {
64 if results.is_empty() {
65 return "No matching documents found".to_string();
66 }
67
68 let formatted_results: Vec<String> = results
69 .iter()
70 .enumerate()
71 .map(|(i, result)| {
72 let preview = if result.text.len() > 200 {
73 format!("{}...", &result.text[..200])
74 } else {
75 result.text.clone()
76 };
77
78 format!(
79 "{}. [Score: {:.4}] {}\n ID: {}",
80 i + 1,
81 result.score,
82 preview,
83 result.id
84 )
85 })
86 .collect();
87
88 format!(
89 "Found {} result(s):\n\n{}",
90 results.len(),
91 formatted_results.join("\n\n")
92 )
93 }
94}
95
96#[async_trait]
97impl Tool for RAGTool {
98 fn name(&self) -> &str {
99 "rag"
100 }
101
102 fn description(&self) -> &str {
103 "RAG (Retrieval-Augmented Generation) tool for document storage and semantic search. \
104 Operations: add_document, search, delete, clear, count"
105 }
106
107 fn parameters(&self) -> HashMap<String, ToolParameter> {
108 let mut params = HashMap::new();
109 params.insert(
110 "operation".to_string(),
111 ToolParameter {
112 param_type: "string".to_string(),
113 description: "Operation: 'add_document', 'search', 'delete', 'clear', 'count'"
114 .to_string(),
115 required: Some(true),
116 },
117 );
118 params.insert(
119 "text".to_string(),
120 ToolParameter {
121 param_type: "string".to_string(),
122 description: "Text content for add_document or search query".to_string(),
123 required: Some(false),
124 },
125 );
126 params.insert(
127 "doc_id".to_string(),
128 ToolParameter {
129 param_type: "string".to_string(),
130 description: "Document ID for delete operation".to_string(),
131 required: Some(false),
132 },
133 );
134 params.insert(
135 "limit".to_string(),
136 ToolParameter {
137 param_type: "number".to_string(),
138 description: "Number of results for search (default: 5)".to_string(),
139 required: Some(false),
140 },
141 );
142 params.insert(
143 "metadata".to_string(),
144 ToolParameter {
145 param_type: "object".to_string(),
146 description: "Additional metadata for the document (JSON object)".to_string(),
147 required: Some(false),
148 },
149 );
150 params
151 }
152
153 async fn execute(&self, args: Value) -> Result<ToolResult> {
154 let operation = args
155 .get("operation")
156 .and_then(|v| v.as_str())
157 .ok_or_else(|| HeliosError::ToolError("Missing 'operation' parameter".to_string()))?;
158
159 match operation {
160 "add_document" => {
161 let text = args.get("text").and_then(|v| v.as_str()).ok_or_else(|| {
162 HeliosError::ToolError("Missing 'text' for add_document".to_string())
163 })?;
164
165 let metadata: Option<HashMap<String, serde_json::Value>> = args
166 .get("metadata")
167 .and_then(|v| serde_json::from_value(v.clone()).ok());
168
169 let doc_id = self.rag_system.add_document(text, metadata).await?;
170
171 let preview = if text.len() > 100 {
172 format!("{}...", &text[..100])
173 } else {
174 text.to_string()
175 };
176
177 Ok(ToolResult::success(format!(
178 "✓ Document added successfully (backend: {})\nID: {}\nText preview: {}",
179 self.backend_type, doc_id, preview
180 )))
181 }
182 "search" => {
183 let query = args.get("text").and_then(|v| v.as_str()).ok_or_else(|| {
184 HeliosError::ToolError("Missing 'text' for search".to_string())
185 })?;
186
187 let limit = args.get("limit").and_then(|v| v.as_u64()).unwrap_or(5) as usize;
188
189 let results = self.rag_system.search(query, limit).await?;
190 Ok(ToolResult::success(self.format_results(&results)))
191 }
192 "delete" => {
193 let doc_id = args.get("doc_id").and_then(|v| v.as_str()).ok_or_else(|| {
194 HeliosError::ToolError("Missing 'doc_id' for delete".to_string())
195 })?;
196
197 self.rag_system.delete_document(doc_id).await?;
198 Ok(ToolResult::success(format!(
199 "✓ Document '{}' deleted",
200 doc_id
201 )))
202 }
203 "clear" => {
204 self.rag_system.clear().await?;
205 Ok(ToolResult::success(
206 "✓ All documents cleared from collection".to_string(),
207 ))
208 }
209 "count" => {
210 let count = self.rag_system.count().await?;
211 Ok(ToolResult::success(format!(
212 "Document count: {} (backend: {})",
213 count, self.backend_type
214 )))
215 }
216 _ => Err(HeliosError::ToolError(format!(
217 "Unknown operation '{}'. Valid: add_document, search, delete, clear, count",
218 operation
219 ))),
220 }
221 }
222}