// lumen_rag/stores/mongo.rs
use std::collections::hash_map::Entry;
use std::collections::HashMap;
use std::str::FromStr;
use std::sync::Arc;

use anyhow::Result;
use async_trait::async_trait;
use futures::stream::{FuturesUnordered, StreamExt};
use mongodb::bson::{doc, oid::ObjectId, Bson};
use mongodb::{Client, Collection};
use rayon::prelude::*;

use crate::store::VectorStore;
use crate::types::Passage;
use crate::utils::compute_hash;
12
13pub struct MongoStore {
14 client: Client,
15 db_name: String,
16 collection_name: String,
17 fetch_limit: i64,
18}
19
20impl MongoStore {
21 pub fn new(client: Client, db_name: String, collection_name: String) -> Self {
22 Self {
23 client,
24 db_name,
25 collection_name,
26 fetch_limit: 2000,
27 }
28 }
29
30 fn get_collection(&self) -> Collection<Passage> {
31 self.client
32 .database(&self.db_name)
33 .collection(&self.collection_name)
34 }
35
36 fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
37 let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
38 let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
39 let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
40 if norm_a == 0.0 || norm_b == 0.0 {
41 0.0
42 } else {
43 dot / (norm_a * norm_b)
44 }
45 }
46}
47
48#[async_trait]
49impl VectorStore for MongoStore {
50 async fn add_passages(&self, passages: Vec<Passage>) -> Result<Vec<String>> {
51 let collection = self.get_collection();
52 let mut inserted_ids = Vec::new();
53 let mut tasks = FuturesUnordered::new();
54 let coll_arc = Arc::new(collection);
55
56 for mut p in passages {
57 let coll = Arc::clone(&coll_arc);
58 tasks.push(async move {
59 let hash = compute_hash(&p.text);
60 p.hash = Some(hash as i64);
61
62 if let Ok(Some(existing)) = coll.find_one(doc! { "hash": hash as i64 }).await {
63 if let Some(id_str) = existing.id {
64 return Ok(id_str);
65 }
66 }
67
68 let _id_filter = if let Some(ref id_str) = p.id {
69 if let Ok(oid) = ObjectId::from_str(id_str) {
70 Some(doc! { "_id": oid })
71 } else {
72 None
73 }
74 } else {
75 None
76 };
77
78 match coll.insert_one(p).await {
79 Ok(res) => match res.inserted_id {
80 Bson::ObjectId(oid) => Ok(oid.to_string()),
81 _ => Ok("unknown_id".to_string()),
82 },
83 Err(e) => Err(anyhow::anyhow!("DB Error: {}", e)),
84 }
85 });
86 }
87
88 while let Some(res) = tasks.next().await {
89 if let Ok(id) = res {
90 inserted_ids.push(id);
91 }
92 }
93 Ok(inserted_ids)
94 }
95
96 async fn search(&self, query_embedding: &[f32], limit: usize) -> Result<Vec<Passage>> {
97 let collection = self.get_collection();
98
99 let find_opts = mongodb::options::FindOptions::builder()
100 .limit(self.fetch_limit)
101 .projection(doc! {
102 "text": 1,
103 "embedding": 1,
104 "metadata": 1,
105 "_id": 0
106 })
107 .build();
108
109 let mut cursor = collection.find(doc! {}).with_options(find_opts).await?;
110 let mut candidates = Vec::new();
111
112 while cursor.advance().await? {
113 candidates.push(cursor.deserialize_current()?);
114 }
115
116 let mut scored_passages: Vec<_> = candidates
117 .par_iter()
118 .map(|p| {
119 let sim = Self::cosine_similarity(query_embedding, &p.embedding);
120 (p.clone(), sim)
121 })
122 .collect();
123
124 scored_passages
125 .sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
126
127 let result: Vec<Passage> = scored_passages
128 .into_iter()
129 .take(limit)
130 .map(|(p, _)| p)
131 .collect();
132
133 Ok(result)
134 }
135}