rain_engine_cognition/
rag.rs1use rain_engine_core::{EmbeddingProvider, ProviderError};
2use serde::{Deserialize, Serialize};
3use std::sync::Arc;
4use tokio::sync::RwLock;
5
6#[derive(Debug, Clone, Serialize, Deserialize)]
7pub struct DocumentChunk {
8 pub chunk_id: String,
9 pub text: String,
10 pub embedding: Vec<f32>,
11 pub metadata: serde_json::Value,
12}
13
14pub struct CognitiveStore {
15 embedding_provider: Arc<dyn EmbeddingProvider>,
16 chunks: RwLock<Vec<DocumentChunk>>,
17}
18
19impl CognitiveStore {
20 pub fn new(embedding_provider: Arc<dyn EmbeddingProvider>) -> Self {
21 Self {
22 embedding_provider,
23 chunks: RwLock::new(Vec::new()),
24 }
25 }
26
27 pub async fn ingest(
28 &self,
29 text: String,
30 metadata: serde_json::Value,
31 ) -> Result<(), ProviderError> {
32 let chunks = chunk_text(&text, 1000);
33 let embeddings = self
34 .embedding_provider
35 .generate_embeddings(chunks.clone())
36 .await?;
37
38 let mut store = self.chunks.write().await;
39 for (chunk_text, embedding) in chunks.into_iter().zip(embeddings) {
40 store.push(DocumentChunk {
41 chunk_id: uuid::Uuid::new_v4().to_string(),
42 text: chunk_text,
43 embedding,
44 metadata: metadata.clone(),
45 });
46 }
47 Ok(())
48 }
49
50 pub async fn search(
51 &self,
52 query: String,
53 limit: usize,
54 ) -> Result<Vec<(DocumentChunk, f32)>, ProviderError> {
55 let query_embedding = self
56 .embedding_provider
57 .generate_embeddings(vec![query])
58 .await?
59 .pop()
60 .ok_or_else(|| ProviderError::internal("no embedding generated for query"))?;
61
62 let chunks = self.chunks.read().await;
63 let mut results: Vec<(DocumentChunk, f32)> = chunks
64 .iter()
65 .map(|chunk| {
66 let score = cosine_similarity(&query_embedding, &chunk.embedding);
67 (chunk.clone(), score)
68 })
69 .collect();
70
71 results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
72 results.truncate(limit);
73
74 Ok(results)
75 }
76}
77
/// Splits `text` into consecutive chunks of at most `size` characters.
///
/// Sizes are counted in `char`s (Unicode scalar values), not bytes, so a
/// multi-byte character is never split. Every chunk except possibly the last
/// holds exactly `size` characters. Empty input yields an empty vector.
///
/// A `size` of 0 is treated as 1 rather than panicking.
fn chunk_text(text: &str, size: usize) -> Vec<String> {
    // A zero-width chunk can never make progress; clamp to one char.
    let size = size.max(1);

    let mut chunks = Vec::new();
    let mut rest = text;
    while !rest.is_empty() {
        // Byte offset where the (size + 1)-th char starts, or the end of the
        // string when fewer than that many chars remain.
        let split_at = rest
            .char_indices()
            .nth(size)
            .map_or(rest.len(), |(idx, _)| idx);
        let (head, tail) = rest.split_at(split_at);
        chunks.push(head.to_owned());
        rest = tail;
    }
    chunks
}
85
/// Computes the cosine similarity of `a` and `b`, in the range `[-1.0, 1.0]`.
///
/// Returns `0.0` when the similarity is undefined: the slices differ in
/// length, either vector has zero magnitude (including empty slices), or the
/// inputs contain NaN or infinite components.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }

    // Accumulate in f64: squaring very small or very large f32 components
    // underflows to zero or overflows to infinity in f32.
    let (dot, sq_a, sq_b) = a.iter().zip(b).fold(
        (0.0_f64, 0.0_f64, 0.0_f64),
        |(dot, sq_a, sq_b), (&x, &y)| {
            let (x, y) = (f64::from(x), f64::from(y));
            (dot + x * y, sq_a + x * x, sq_b + y * y)
        },
    );

    let denom = sq_a.sqrt() * sq_b.sqrt();
    if denom == 0.0 {
        return 0.0;
    }

    let cosine = dot / denom;
    if cosine.is_finite() {
        // Rounding can land a hair outside [-1, 1]; the clamped value always
        // fits in f32, so the narrowing cast is safe.
        cosine.clamp(-1.0, 1.0) as f32
    } else {
        // NaN or infinite inputs: report "no similarity" rather than leaking
        // NaN to callers that sort by score.
        0.0
    }
}