Skip to main content

rain_engine_cognition/
rag.rs

1use rain_engine_core::{EmbeddingProvider, ProviderError};
2use serde::{Deserialize, Serialize};
3use std::sync::Arc;
4use tokio::sync::RwLock;
5
6#[derive(Debug, Clone, Serialize, Deserialize)]
7pub struct DocumentChunk {
8    pub chunk_id: String,
9    pub text: String,
10    pub embedding: Vec<f32>,
11    pub metadata: serde_json::Value,
12}
13
14pub struct CognitiveStore {
15    embedding_provider: Arc<dyn EmbeddingProvider>,
16    chunks: RwLock<Vec<DocumentChunk>>,
17}
18
19impl CognitiveStore {
20    pub fn new(embedding_provider: Arc<dyn EmbeddingProvider>) -> Self {
21        Self {
22            embedding_provider,
23            chunks: RwLock::new(Vec::new()),
24        }
25    }
26
27    pub async fn ingest(
28        &self,
29        text: String,
30        metadata: serde_json::Value,
31    ) -> Result<(), ProviderError> {
32        let chunks = chunk_text(&text, 1000);
33        let embeddings = self
34            .embedding_provider
35            .generate_embeddings(chunks.clone())
36            .await?;
37
38        let mut store = self.chunks.write().await;
39        for (chunk_text, embedding) in chunks.into_iter().zip(embeddings) {
40            store.push(DocumentChunk {
41                chunk_id: uuid::Uuid::new_v4().to_string(),
42                text: chunk_text,
43                embedding,
44                metadata: metadata.clone(),
45            });
46        }
47        Ok(())
48    }
49
50    pub async fn search(
51        &self,
52        query: String,
53        limit: usize,
54    ) -> Result<Vec<(DocumentChunk, f32)>, ProviderError> {
55        let query_embedding = self
56            .embedding_provider
57            .generate_embeddings(vec![query])
58            .await?
59            .pop()
60            .ok_or_else(|| ProviderError::internal("no embedding generated for query"))?;
61
62        let chunks = self.chunks.read().await;
63        let mut results: Vec<(DocumentChunk, f32)> = chunks
64            .iter()
65            .map(|chunk| {
66                let score = cosine_similarity(&query_embedding, &chunk.embedding);
67                (chunk.clone(), score)
68            })
69            .collect();
70
71        results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
72        results.truncate(limit);
73
74        Ok(results)
75    }
76}
77
/// Splits `text` into consecutive chunks of at most `size` characters
/// (Unicode scalar values, not bytes).
///
/// Every chunk except possibly the last contains exactly `size` characters.
/// Empty input yields an empty vector. A `size` of zero is treated as one so
/// the function never panics and always makes progress.
fn chunk_text(text: &str, size: usize) -> Vec<String> {
    let size = size.max(1);
    let mut chunks = Vec::new();
    let mut rest = text;

    while !rest.is_empty() {
        // Byte offset of the character that starts the next chunk, or the end
        // of the string when fewer than `size` characters remain.
        let split_at = rest
            .char_indices()
            .nth(size)
            .map_or(rest.len(), |(idx, _)| idx);
        let (head, tail) = rest.split_at(split_at);
        chunks.push(head.to_owned());
        rest = tail;
    }

    chunks
}
85
/// Computes the cosine similarity of `a` and `b`, in the range `[-1.0, 1.0]`.
///
/// Returns `0.0` for every degenerate case: slices of different lengths, a
/// zero-magnitude (or empty) vector, or inputs containing NaN or infinite
/// components.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }

    // One pass over both slices, accumulating in f64 so that large components
    // cannot overflow the sums and rounding error stays small.
    let (dot, sq_a, sq_b) = a.iter().zip(b).fold(
        (0.0_f64, 0.0_f64, 0.0_f64),
        |(dot, sq_a, sq_b), (&x, &y)| {
            let (x, y) = (f64::from(x), f64::from(y));
            (dot + x * y, sq_a + x * x, sq_b + y * y)
        },
    );

    let denom = sq_a.sqrt() * sq_b.sqrt();
    if denom == 0.0 {
        return 0.0;
    }

    let similarity = dot / denom;
    if !similarity.is_finite() {
        // NaN or infinite inputs: report "no similarity" rather than leaking
        // a value that cannot be ranked.
        return 0.0;
    }

    // Rounding can push the ratio marginally past +/-1; the value is within
    // [-1, 1] after clamping, so narrowing to f32 only rounds.
    similarity.clamp(-1.0, 1.0) as f32
}