helios_engine/
rag.rs

//! # RAG (Retrieval-Augmented Generation) Module
//!
//! This module provides a flexible RAG system with:
//! - Multiple vector store backends (in-memory, Qdrant)
//! - Embedding generation (OpenAI API, local models)
//! - Document chunking and preprocessing
//! - Semantic search and retrieval
//! - Reranking capabilities
9
10use crate::error::{HeliosError, Result};
11use async_trait::async_trait;
12use reqwest::Client;
13use serde::{Deserialize, Serialize};
14use std::collections::HashMap;
15use uuid::Uuid;
16
17// ============================================================================
18// Core Types and Traits
19// ============================================================================
20
21/// Represents a document in the RAG system
22#[derive(Debug, Clone, Serialize, Deserialize)]
23pub struct Document {
24    /// Unique document identifier
25    pub id: String,
26    /// The text content of the document
27    pub text: String,
28    /// Optional metadata associated with the document
29    pub metadata: HashMap<String, serde_json::Value>,
30    /// Timestamp when the document was added
31    pub timestamp: String,
32}
33
34/// Represents a search result from the RAG system
35#[derive(Debug, Clone, Serialize, Deserialize)]
36pub struct SearchResult {
37    /// Document ID
38    pub id: String,
39    /// Similarity score (0.0 to 1.0, higher is better)
40    pub score: f64,
41    /// The document text
42    pub text: String,
43    /// Optional metadata
44    pub metadata: Option<HashMap<String, serde_json::Value>>,
45}
46
47// ============================================================================
48// Embedding Provider Trait
49// ============================================================================
50
51/// Trait for embedding generation
52#[async_trait]
53pub trait EmbeddingProvider: Send + Sync {
54    /// Generate embeddings for the given text
55    async fn embed(&self, text: &str) -> Result<Vec<f32>>;
56
57    /// Get the dimension of embeddings produced by this provider
58    fn dimension(&self) -> usize;
59}
60
61// ============================================================================
62// Vector Store Trait
63// ============================================================================
64
65/// Trait for vector storage backends
66#[async_trait]
67pub trait VectorStore: Send + Sync {
68    /// Initialize the vector store (create collections, etc.)
69    async fn initialize(&self, dimension: usize) -> Result<()>;
70
71    /// Add a document with its embedding
72    async fn add(
73        &self,
74        id: &str,
75        embedding: Vec<f32>,
76        text: &str,
77        metadata: HashMap<String, serde_json::Value>,
78    ) -> Result<()>;
79
80    /// Search for similar documents
81    async fn search(&self, query_embedding: Vec<f32>, limit: usize) -> Result<Vec<SearchResult>>;
82
83    /// Delete a document by ID
84    async fn delete(&self, id: &str) -> Result<()>;
85
86    /// Clear all documents
87    async fn clear(&self) -> Result<()>;
88
89    /// Get document count
90    async fn count(&self) -> Result<usize>;
91}
92
93// ============================================================================
94// OpenAI Embedding Provider
95// ============================================================================
96
97/// OpenAI embedding provider using text-embedding-ada-002 or text-embedding-3-small
98pub struct OpenAIEmbeddings {
99    api_url: String,
100    api_key: String,
101    model: String,
102    client: Client,
103}
104
105#[derive(Debug, Serialize)]
106struct OpenAIEmbeddingRequest {
107    input: String,
108    model: String,
109}
110
111#[derive(Debug, Deserialize)]
112struct OpenAIEmbeddingResponse {
113    data: Vec<OpenAIEmbeddingData>,
114}
115
116#[derive(Debug, Deserialize)]
117struct OpenAIEmbeddingData {
118    embedding: Vec<f32>,
119}
120
121impl OpenAIEmbeddings {
122    /// Create a new OpenAI embeddings provider
123    pub fn new(api_url: impl Into<String>, api_key: impl Into<String>) -> Self {
124        Self {
125            api_url: api_url.into(),
126            api_key: api_key.into(),
127            model: "text-embedding-ada-002".to_string(),
128            client: Client::new(),
129        }
130    }
131
132    /// Create with a specific model
133    pub fn with_model(
134        api_url: impl Into<String>,
135        api_key: impl Into<String>,
136        model: impl Into<String>,
137    ) -> Self {
138        Self {
139            api_url: api_url.into(),
140            api_key: api_key.into(),
141            model: model.into(),
142            client: Client::new(),
143        }
144    }
145}
146
147#[async_trait]
148impl EmbeddingProvider for OpenAIEmbeddings {
149    async fn embed(&self, text: &str) -> Result<Vec<f32>> {
150        let request = OpenAIEmbeddingRequest {
151            input: text.to_string(),
152            model: self.model.clone(),
153        };
154
155        let response = self
156            .client
157            .post(&self.api_url)
158            .header("Authorization", format!("Bearer {}", self.api_key))
159            .json(&request)
160            .send()
161            .await
162            .map_err(|e| HeliosError::ToolError(format!("Embedding API error: {}", e)))?;
163
164        if !response.status().is_success() {
165            let error_text = response
166                .text()
167                .await
168                .unwrap_or_else(|_| "Unknown error".to_string());
169            return Err(HeliosError::ToolError(format!(
170                "Embedding API failed: {}",
171                error_text
172            )));
173        }
174
175        let embedding_response: OpenAIEmbeddingResponse = response.json().await.map_err(|e| {
176            HeliosError::ToolError(format!("Failed to parse embedding response: {}", e))
177        })?;
178
179        embedding_response
180            .data
181            .into_iter()
182            .next()
183            .map(|d| d.embedding)
184            .ok_or_else(|| HeliosError::ToolError("No embedding returned".to_string()))
185    }
186
187    fn dimension(&self) -> usize {
188        // text-embedding-ada-002 produces 1536-dimensional embeddings
189        // text-embedding-3-small produces 1536 by default
190        // text-embedding-3-large produces 3072 by default
191        if self.model == "text-embedding-3-large" {
192            3072
193        } else {
194            1536
195        }
196    }
197}
198
199// ============================================================================
200// In-Memory Vector Store
201// ============================================================================
202
203/// In-memory vector store using cosine similarity
204pub struct InMemoryVectorStore {
205    documents:
206        std::sync::Arc<tokio::sync::RwLock<std::collections::HashMap<String, StoredDocument>>>,
207}
208
209#[derive(Debug, Clone)]
210struct StoredDocument {
211    id: String,
212    embedding: Vec<f32>,
213    text: String,
214    metadata: HashMap<String, serde_json::Value>,
215}
216
217impl InMemoryVectorStore {
218    /// Create a new in-memory vector store
219    pub fn new() -> Self {
220        Self {
221            documents: std::sync::Arc::new(tokio::sync::RwLock::new(HashMap::new())),
222        }
223    }
224}
225
226impl Default for InMemoryVectorStore {
227    fn default() -> Self {
228        Self::new()
229    }
230}
231
/// Computes the cosine similarity of `a` and `b`.
///
/// The result lies in `[-1.0, 1.0]`. `0.0` is returned when the slices differ
/// in length or when either vector has zero magnitude (which includes empty
/// slices). NaN components propagate to a NaN result.
///
/// All sums are accumulated in `f64`: squaring in `f32` overflows to infinity
/// for large components and underflows to zero for tiny ones, which would
/// otherwise yield NaN or a spurious `0.0`.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f64 {
    if a.len() != b.len() {
        return 0.0;
    }

    // One pass over both vectors: dot product and both squared norms.
    let (dot, norm_a_sq, norm_b_sq) = a.iter().zip(b.iter()).fold(
        (0.0_f64, 0.0_f64, 0.0_f64),
        |(dot, na, nb), (&x, &y)| {
            let (x, y) = (f64::from(x), f64::from(y));
            (dot + x * y, na + x * x, nb + y * y)
        },
    );

    if norm_a_sq == 0.0 || norm_b_sq == 0.0 {
        return 0.0;
    }

    // Rounding can push the quotient a hair outside the mathematical range.
    (dot / (norm_a_sq.sqrt() * norm_b_sq.sqrt())).clamp(-1.0, 1.0)
}
248
249#[async_trait]
250impl VectorStore for InMemoryVectorStore {
251    async fn initialize(&self, _dimension: usize) -> Result<()> {
252        // No initialization needed for in-memory store
253        Ok(())
254    }
255
256    async fn add(
257        &self,
258        id: &str,
259        embedding: Vec<f32>,
260        text: &str,
261        metadata: HashMap<String, serde_json::Value>,
262    ) -> Result<()> {
263        let mut docs = self.documents.write().await;
264
265        // Insert or update document with same ID
266        docs.insert(
267            id.to_string(),
268            StoredDocument {
269                id: id.to_string(),
270                embedding,
271                text: text.to_string(),
272                metadata,
273            },
274        );
275
276        Ok(())
277    }
278
279    async fn search(&self, query_embedding: Vec<f32>, limit: usize) -> Result<Vec<SearchResult>> {
280        let docs = self.documents.read().await;
281
282        if docs.is_empty() {
283            return Ok(Vec::new());
284        }
285
286        // Calculate similarities for all documents
287        let mut results: Vec<(String, f64)> = docs
288            .iter()
289            .map(|(id, doc)| {
290                let similarity = cosine_similarity(&query_embedding, &doc.embedding);
291                (id.clone(), similarity)
292            })
293            .collect();
294
295        // Sort by similarity (descending)
296        results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
297
298        // Take top results
299        let top_results: Vec<SearchResult> = results
300            .into_iter()
301            .take(limit)
302            .filter_map(|(id, score)| {
303                docs.get(&id).map(|doc| SearchResult {
304                    id: doc.id.clone(),
305                    score,
306                    text: doc.text.clone(),
307                    metadata: Some(doc.metadata.clone()),
308                })
309            })
310            .collect();
311
312        Ok(top_results)
313    }
314
315    async fn delete(&self, id: &str) -> Result<()> {
316        let mut docs = self.documents.write().await;
317        docs.remove(id);
318        Ok(())
319    }
320
321    async fn clear(&self) -> Result<()> {
322        let mut docs = self.documents.write().await;
323        docs.clear();
324        Ok(())
325    }
326
327    async fn count(&self) -> Result<usize> {
328        let docs = self.documents.read().await;
329        Ok(docs.len())
330    }
331}
332
333// ============================================================================
334// Qdrant Vector Store
335// ============================================================================
336
337/// Qdrant vector store implementation
338pub struct QdrantVectorStore {
339    qdrant_url: String,
340    collection_name: String,
341    client: Client,
342}
343
344#[derive(Debug, Serialize, Deserialize)]
345struct QdrantPoint {
346    id: String,
347    vector: Vec<f32>,
348    payload: HashMap<String, serde_json::Value>,
349}
350
351#[derive(Debug, Serialize, Deserialize)]
352struct QdrantSearchRequest {
353    vector: Vec<f32>,
354    limit: usize,
355    with_payload: bool,
356    with_vector: bool,
357}
358
359#[derive(Debug, Serialize, Deserialize)]
360struct QdrantSearchResponse {
361    result: Vec<QdrantSearchResult>,
362}
363
364#[derive(Debug, Serialize, Deserialize)]
365struct QdrantSearchResult {
366    id: String,
367    score: f64,
368    payload: Option<HashMap<String, serde_json::Value>>,
369}
370
371impl QdrantVectorStore {
372    /// Create a new Qdrant vector store
373    pub fn new(qdrant_url: impl Into<String>, collection_name: impl Into<String>) -> Self {
374        Self {
375            qdrant_url: qdrant_url.into(),
376            collection_name: collection_name.into(),
377            client: Client::new(),
378        }
379    }
380}
381
382#[async_trait]
383impl VectorStore for QdrantVectorStore {
384    async fn initialize(&self, dimension: usize) -> Result<()> {
385        let collection_url = format!("{}/collections/{}", self.qdrant_url, self.collection_name);
386
387        // Check if collection exists
388        let response = self.client.get(&collection_url).send().await;
389
390        if response.is_ok() && response.unwrap().status().is_success() {
391            return Ok(()); // Collection exists
392        }
393
394        // Create collection
395        let create_payload = serde_json::json!({
396            "vectors": {
397                "size": dimension,
398                "distance": "Cosine"
399            }
400        });
401
402        let response = self
403            .client
404            .put(&collection_url)
405            .json(&create_payload)
406            .send()
407            .await
408            .map_err(|e| HeliosError::ToolError(format!("Failed to create collection: {}", e)))?;
409
410        if !response.status().is_success() {
411            let error_text = response
412                .text()
413                .await
414                .unwrap_or_else(|_| "Unknown error".to_string());
415            return Err(HeliosError::ToolError(format!(
416                "Collection creation failed: {}",
417                error_text
418            )));
419        }
420
421        Ok(())
422    }
423
424    async fn add(
425        &self,
426        id: &str,
427        embedding: Vec<f32>,
428        text: &str,
429        metadata: HashMap<String, serde_json::Value>,
430    ) -> Result<()> {
431        let mut payload = metadata;
432        payload.insert("text".to_string(), serde_json::json!(text));
433        payload.insert(
434            "timestamp".to_string(),
435            serde_json::json!(chrono::Utc::now().to_rfc3339()),
436        );
437
438        let point = QdrantPoint {
439            id: id.to_string(),
440            vector: embedding,
441            payload,
442        };
443
444        let upsert_url = format!(
445            "{}/collections/{}/points",
446            self.qdrant_url, self.collection_name
447        );
448        let upsert_payload = serde_json::json!({
449            "points": [point]
450        });
451
452        let response = self
453            .client
454            .put(&upsert_url)
455            .json(&upsert_payload)
456            .send()
457            .await
458            .map_err(|e| HeliosError::ToolError(format!("Failed to upload document: {}", e)))?;
459
460        if !response.status().is_success() {
461            let error_text = response
462                .text()
463                .await
464                .unwrap_or_else(|_| "Unknown error".to_string());
465            return Err(HeliosError::ToolError(format!(
466                "Document upload failed: {}",
467                error_text
468            )));
469        }
470
471        Ok(())
472    }
473
474    async fn search(&self, query_embedding: Vec<f32>, limit: usize) -> Result<Vec<SearchResult>> {
475        let search_url = format!(
476            "{}/collections/{}/points/search",
477            self.qdrant_url, self.collection_name
478        );
479        let search_request = QdrantSearchRequest {
480            vector: query_embedding,
481            limit,
482            with_payload: true,
483            with_vector: false,
484        };
485
486        let response = self
487            .client
488            .post(&search_url)
489            .json(&search_request)
490            .send()
491            .await
492            .map_err(|e| HeliosError::ToolError(format!("Search failed: {}", e)))?;
493
494        if !response.status().is_success() {
495            let error_text = response
496                .text()
497                .await
498                .unwrap_or_else(|_| "Unknown error".to_string());
499            return Err(HeliosError::ToolError(format!(
500                "Search request failed: {}",
501                error_text
502            )));
503        }
504
505        let search_response: QdrantSearchResponse = response.json().await.map_err(|e| {
506            HeliosError::ToolError(format!("Failed to parse search response: {}", e))
507        })?;
508
509        let results: Vec<SearchResult> = search_response
510            .result
511            .into_iter()
512            .filter_map(|r| {
513                r.payload.and_then(|p| {
514                    p.get("text").and_then(|t| t.as_str()).map(|text| {
515                        let mut metadata = p.clone();
516                        metadata.remove("text");
517                        SearchResult {
518                            id: r.id,
519                            score: r.score,
520                            text: text.to_string(),
521                            metadata: Some(metadata),
522                        }
523                    })
524                })
525            })
526            .collect();
527
528        Ok(results)
529    }
530
531    async fn delete(&self, id: &str) -> Result<()> {
532        let delete_url = format!(
533            "{}/collections/{}/points/delete",
534            self.qdrant_url, self.collection_name
535        );
536        let delete_payload = serde_json::json!({
537            "points": [id]
538        });
539
540        let response = self
541            .client
542            .post(&delete_url)
543            .json(&delete_payload)
544            .send()
545            .await
546            .map_err(|e| HeliosError::ToolError(format!("Delete failed: {}", e)))?;
547
548        if !response.status().is_success() {
549            let error_text = response
550                .text()
551                .await
552                .unwrap_or_else(|_| "Unknown error".to_string());
553            return Err(HeliosError::ToolError(format!(
554                "Delete request failed: {}",
555                error_text
556            )));
557        }
558
559        Ok(())
560    }
561
562    async fn clear(&self) -> Result<()> {
563        let delete_url = format!("{}/collections/{}", self.qdrant_url, self.collection_name);
564
565        let response = self
566            .client
567            .delete(&delete_url)
568            .send()
569            .await
570            .map_err(|e| HeliosError::ToolError(format!("Clear failed: {}", e)))?;
571
572        if !response.status().is_success() {
573            let error_text = response
574                .text()
575                .await
576                .unwrap_or_else(|_| "Unknown error".to_string());
577            return Err(HeliosError::ToolError(format!(
578                "Clear collection failed: {}",
579                error_text
580            )));
581        }
582
583        Ok(())
584    }
585
586    async fn count(&self) -> Result<usize> {
587        let count_url = format!("{}/collections/{}", self.qdrant_url, self.collection_name);
588
589        let response = self
590            .client
591            .get(&count_url)
592            .send()
593            .await
594            .map_err(|e| HeliosError::ToolError(format!("Count failed: {}", e)))?;
595
596        if !response.status().is_success() {
597            return Ok(0);
598        }
599
600        // Parse collection info to get count
601        let info: serde_json::Value = response.json().await.map_err(|e| {
602            HeliosError::ToolError(format!("Failed to parse collection info: {}", e))
603        })?;
604
605        let count = info
606            .get("result")
607            .and_then(|r| r.get("points_count"))
608            .and_then(|c| c.as_u64())
609            .unwrap_or(0) as usize;
610
611        Ok(count)
612    }
613}
614
615// ============================================================================
616// RAG System
617// ============================================================================
618
619/// Main RAG system that combines embedding provider and vector store
620pub struct RAGSystem {
621    embedding_provider: Box<dyn EmbeddingProvider>,
622    vector_store: Box<dyn VectorStore>,
623    initialized: std::sync::Arc<tokio::sync::RwLock<bool>>,
624}
625
626impl RAGSystem {
627    /// Create a new RAG system
628    pub fn new(
629        embedding_provider: Box<dyn EmbeddingProvider>,
630        vector_store: Box<dyn VectorStore>,
631    ) -> Self {
632        Self {
633            embedding_provider,
634            vector_store,
635            initialized: std::sync::Arc::new(tokio::sync::RwLock::new(false)),
636        }
637    }
638
639    /// Ensure the system is initialized
640    async fn ensure_initialized(&self) -> Result<()> {
641        let is_initialized = *self.initialized.read().await;
642        if !is_initialized {
643            let mut init_guard = self.initialized.write().await;
644            if !*init_guard {
645                let dimension = self.embedding_provider.dimension();
646                self.vector_store.initialize(dimension).await?;
647                *init_guard = true;
648            }
649        }
650        Ok(())
651    }
652
653    /// Add a document to the RAG system
654    pub async fn add_document(
655        &self,
656        text: &str,
657        metadata: Option<HashMap<String, serde_json::Value>>,
658    ) -> Result<String> {
659        self.ensure_initialized().await?;
660
661        let id = Uuid::new_v4().to_string();
662        let embedding = self.embedding_provider.embed(text).await?;
663
664        let mut meta = metadata.unwrap_or_default();
665        meta.insert(
666            "timestamp".to_string(),
667            serde_json::json!(chrono::Utc::now().to_rfc3339()),
668        );
669
670        self.vector_store.add(&id, embedding, text, meta).await?;
671
672        Ok(id)
673    }
674
675    /// Search for similar documents
676    pub async fn search(&self, query: &str, limit: usize) -> Result<Vec<SearchResult>> {
677        self.ensure_initialized().await?;
678
679        let query_embedding = self.embedding_provider.embed(query).await?;
680        self.vector_store.search(query_embedding, limit).await
681    }
682
683    /// Delete a document by ID
684    pub async fn delete_document(&self, id: &str) -> Result<()> {
685        self.vector_store.delete(id).await
686    }
687
688    /// Clear all documents
689    pub async fn clear(&self) -> Result<()> {
690        self.vector_store.clear().await
691    }
692
693    /// Get document count
694    pub async fn count(&self) -> Result<usize> {
695        self.vector_store.count().await
696    }
697}