rain-engine-cognition 0.1.0

Planning and task orchestration for RainEngine
Documentation
use rain_engine_core::{EmbeddingProvider, ProviderError};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use tokio::sync::RwLock;

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DocumentChunk {
    pub chunk_id: String,
    pub text: String,
    pub embedding: Vec<f32>,
    pub metadata: serde_json::Value,
}

pub struct CognitiveStore {
    embedding_provider: Arc<dyn EmbeddingProvider>,
    chunks: RwLock<Vec<DocumentChunk>>,
}

impl CognitiveStore {
    pub fn new(embedding_provider: Arc<dyn EmbeddingProvider>) -> Self {
        Self {
            embedding_provider,
            chunks: RwLock::new(Vec::new()),
        }
    }

    pub async fn ingest(
        &self,
        text: String,
        metadata: serde_json::Value,
    ) -> Result<(), ProviderError> {
        let chunks = chunk_text(&text, 1000);
        let embeddings = self
            .embedding_provider
            .generate_embeddings(chunks.clone())
            .await?;

        let mut store = self.chunks.write().await;
        for (chunk_text, embedding) in chunks.into_iter().zip(embeddings) {
            store.push(DocumentChunk {
                chunk_id: uuid::Uuid::new_v4().to_string(),
                text: chunk_text,
                embedding,
                metadata: metadata.clone(),
            });
        }
        Ok(())
    }

    pub async fn search(
        &self,
        query: String,
        limit: usize,
    ) -> Result<Vec<(DocumentChunk, f32)>, ProviderError> {
        let query_embedding = self
            .embedding_provider
            .generate_embeddings(vec![query])
            .await?
            .pop()
            .ok_or_else(|| ProviderError::internal("no embedding generated for query"))?;

        let chunks = self.chunks.read().await;
        let mut results: Vec<(DocumentChunk, f32)> = chunks
            .iter()
            .map(|chunk| {
                let score = cosine_similarity(&query_embedding, &chunk.embedding);
                (chunk.clone(), score)
            })
            .collect();

        results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        results.truncate(limit);

        Ok(results)
    }
}

fn chunk_text(text: &str, size: usize) -> Vec<String> {
    text.chars()
        .collect::<Vec<_>>()
        .chunks(size)
        .map(|c| c.iter().collect::<String>())
        .collect()
}

fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }
    let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
    let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0;
    }
    dot / (norm_a * norm_b)
}