mini_rag/
naive.rs

1use std::{collections::HashMap, path::PathBuf, time::Instant};
2
3use anyhow::Result;
4use glob::glob;
5use rayon::prelude::*;
6use serde::{Deserialize, Serialize};
7
8use crate::{import, metrics, Embedder};
9
10use super::{Configuration, Document, Embeddings};
11
12#[derive(Serialize, Deserialize)]
13struct Store {
14    documents: HashMap<String, Document>,
15    embeddings: HashMap<String, Embeddings>,
16}
17
18impl Store {
19    fn new() -> Self {
20        let documents = HashMap::new();
21        let embeddings = HashMap::new();
22        Self {
23            documents,
24            embeddings,
25        }
26    }
27
28    fn from_data_path(path: &str) -> Result<Self> {
29        let path = PathBuf::from(path).join("rag.bin");
30        if path.exists() {
31            let raw = std::fs::read(&path)?;
32            Ok(bitcode::deserialize(&raw)?)
33        } else {
34            Ok(Store::new())
35        }
36    }
37
38    fn to_data_path(&self, path: &str) -> Result<()> {
39        let path = PathBuf::from(path).join("rag.bin");
40        let raw = bitcode::serialize(&self)?;
41
42        std::fs::write(path, raw)?;
43
44        Ok(())
45    }
46}
47
48pub struct VectorStore {
49    config: Configuration,
50    embedder: Box<dyn Embedder>,
51    store: Store,
52}
53
54impl VectorStore {
55    pub fn new(embedder: Box<dyn Embedder>, config: Configuration) -> Result<Self> {
56        let store = Store::from_data_path(&config.data_path)?;
57        Ok(Self {
58            config,
59            embedder,
60            store,
61        })
62    }
63
64    pub async fn import_new_documents(&mut self) -> Result<()> {
65        let path = std::fs::canonicalize(&self.config.source_path)?
66            .display()
67            .to_string();
68
69        let expr = format!("{}/**/*.*", path);
70        let start = Instant::now();
71        let mut new = 0;
72
73        for path in (glob(&expr)?).flatten() {
74            match import::import_document_from(&path) {
75                Ok(doc) => {
76                    let docs = if let Some(chunk_size) = self.config.chunk_size {
77                        doc.chunks(chunk_size)?
78                    } else {
79                        vec![doc]
80                    };
81
82                    for doc in docs {
83                        match self.add(doc).await {
84                            Err(err) => log::error!("storing {}: {}", path.display(), err),
85                            Ok(added) => {
86                                if added {
87                                    new += 1
88                                }
89                            }
90                        }
91                    }
92                }
93                Err(err) => log::warn!("{} {err}", path.display()),
94            }
95        }
96
97        if new > 0 {
98            log::info!("{} new documents indexed in {:?}\n", new, start.elapsed(),);
99        }
100
101        Ok(())
102    }
103
104    pub async fn add(&mut self, mut document: Document) -> Result<bool> {
105        let doc_id = document.get_ident().to_string();
106        let doc_path = document.get_path().to_string();
107
108        if self.store.documents.contains_key(&doc_id) {
109            log::debug!("document with id '{}' already indexed", &doc_id);
110            return Ok(false);
111        }
112
113        log::info!(
114            "indexing new document '{}' ({} bytes) ...",
115            doc_path,
116            document.get_byte_size()?
117        );
118
119        let start = Instant::now();
120        let embeddings: Vec<f64> = self.embedder.embed(document.get_data()?).await?;
121        let size = embeddings.len();
122
123        // get rid of the contents once indexed
124        document.drop_data();
125
126        self.store.documents.insert(doc_id.to_string(), document);
127        self.store.embeddings.insert(doc_id, embeddings);
128
129        self.store.to_data_path(&self.config.data_path)?;
130
131        log::debug!("time={:?} embedding_size={}", start.elapsed(), size);
132
133        Ok(true)
134    }
135
136    pub async fn retrieve(&self, query: &str, top_k: usize) -> Result<Vec<(Document, f64)>> {
137        log::debug!("{} (top {})", query, top_k);
138
139        let query_vector = self.embedder.embed(query).await?;
140        let mut results = vec![];
141
142        let distances: Vec<(&String, f64)> = {
143            let mut distances: Vec<(&String, f64)> = self
144                .store
145                .embeddings
146                .par_iter()
147                .map(|(doc_id, doc_embedding)| {
148                    (doc_id, metrics::cosine(&query_vector, doc_embedding))
149                })
150                .collect();
151            distances.par_sort_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap());
152            distances
153        };
154
155        for (doc_id, score) in distances {
156            let document = self.store.documents.get(doc_id).unwrap();
157            results.push((document.clone(), score));
158            if results.len() >= top_k {
159                break;
160            }
161        }
162
163        Ok(results)
164    }
165}