lumen_rag/stores/mongo.rs

1use crate::store::VectorStore;
2use crate::types::Passage;
3use crate::utils::compute_hash;
4use anyhow::Result;
5use async_trait::async_trait;
6use futures::stream::{FuturesUnordered, StreamExt};
7use mongodb::bson::{doc, oid::ObjectId, Bson};
8use mongodb::{Client, Collection};
9use rayon::prelude::*;
10use std::str::FromStr;
11use std::sync::Arc;
12
13pub struct MongoStore {
14    client: Client,
15    db_name: String,
16    collection_name: String,
17    fetch_limit: i64,
18}
19
20impl MongoStore {
21    pub fn new(client: Client, db_name: String, collection_name: String) -> Self {
22        Self {
23            client,
24            db_name,
25            collection_name,
26            fetch_limit: 2000,
27        }
28    }
29
30    fn get_collection(&self) -> Collection<Passage> {
31        self.client
32            .database(&self.db_name)
33            .collection(&self.collection_name)
34    }
35
36    fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
37        let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
38        let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
39        let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
40        if norm_a == 0.0 || norm_b == 0.0 {
41            0.0
42        } else {
43            dot / (norm_a * norm_b)
44        }
45    }
46}
47
48#[async_trait]
49impl VectorStore for MongoStore {
50    async fn add_passages(&self, passages: Vec<Passage>) -> Result<Vec<String>> {
51        let collection = self.get_collection();
52        let mut inserted_ids = Vec::new();
53        let mut tasks = FuturesUnordered::new();
54        let coll_arc = Arc::new(collection);
55
56        for mut p in passages {
57            let coll = Arc::clone(&coll_arc);
58            tasks.push(async move {
59                let hash = compute_hash(&p.text);
60                p.hash = Some(hash as i64);
61
62                if let Ok(Some(existing)) = coll.find_one(doc! { "hash": hash as i64 }).await {
63                    if let Some(id_str) = existing.id {
64                        return Ok(id_str);
65                    }
66                }
67
68                let _id_filter = if let Some(ref id_str) = p.id {
69                    if let Ok(oid) = ObjectId::from_str(id_str) {
70                        Some(doc! { "_id": oid })
71                    } else {
72                        None
73                    }
74                } else {
75                    None
76                };
77
78                match coll.insert_one(p).await {
79                    Ok(res) => match res.inserted_id {
80                        Bson::ObjectId(oid) => Ok(oid.to_string()),
81                        _ => Ok("unknown_id".to_string()),
82                    },
83                    Err(e) => Err(anyhow::anyhow!("DB Error: {}", e)),
84                }
85            });
86        }
87
88        while let Some(res) = tasks.next().await {
89            if let Ok(id) = res {
90                inserted_ids.push(id);
91            }
92        }
93        Ok(inserted_ids)
94    }
95
96    async fn search(&self, query_embedding: &[f32], limit: usize) -> Result<Vec<Passage>> {
97        let collection = self.get_collection();
98
99        let find_opts = mongodb::options::FindOptions::builder()
100            .limit(self.fetch_limit)
101            .projection(doc! {
102                "text": 1,
103                "embedding": 1,
104                "metadata": 1,
105                "_id": 0
106            })
107            .build();
108
109        let mut cursor = collection.find(doc! {}).with_options(find_opts).await?;
110        let mut candidates = Vec::new();
111
112        while cursor.advance().await? {
113            candidates.push(cursor.deserialize_current()?);
114        }
115
116        let mut scored_passages: Vec<_> = candidates
117            .par_iter()
118            .map(|p| {
119                let sim = Self::cosine_similarity(query_embedding, &p.embedding);
120                (p.clone(), sim)
121            })
122            .collect();
123
124        scored_passages
125            .sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
126
127        let result: Vec<Passage> = scored_passages
128            .into_iter()
129            .take(limit)
130            .map(|(p, _)| p)
131            .collect();
132
133        Ok(result)
134    }
135}