1use std::{collections::HashMap, path::PathBuf, time::Instant};
2
3use anyhow::Result;
4use glob::glob;
5use rayon::prelude::*;
6use serde::{Deserialize, Serialize};
7
8use crate::{import, metrics, Embedder};
9
10use super::{Configuration, Document, Embeddings};
11
12#[derive(Serialize, Deserialize)]
13struct Store {
14 documents: HashMap<String, Document>,
15 embeddings: HashMap<String, Embeddings>,
16}
17
18impl Store {
19 fn new() -> Self {
20 let documents = HashMap::new();
21 let embeddings = HashMap::new();
22 Self {
23 documents,
24 embeddings,
25 }
26 }
27
28 fn from_data_path(path: &str) -> Result<Self> {
29 let path = PathBuf::from(path).join("rag.bin");
30 if path.exists() {
31 let raw = std::fs::read(&path)?;
32 Ok(bitcode::deserialize(&raw)?)
33 } else {
34 Ok(Store::new())
35 }
36 }
37
38 fn to_data_path(&self, path: &str) -> Result<()> {
39 let path = PathBuf::from(path).join("rag.bin");
40 let raw = bitcode::serialize(&self)?;
41
42 std::fs::write(path, raw)?;
43
44 Ok(())
45 }
46}
47
48pub struct VectorStore {
49 config: Configuration,
50 embedder: Box<dyn Embedder>,
51 store: Store,
52}
53
54impl VectorStore {
55 pub fn new(embedder: Box<dyn Embedder>, config: Configuration) -> Result<Self> {
56 let store = Store::from_data_path(&config.data_path)?;
57 Ok(Self {
58 config,
59 embedder,
60 store,
61 })
62 }
63
64 pub async fn import_new_documents(&mut self) -> Result<()> {
65 let path = std::fs::canonicalize(&self.config.source_path)?
66 .display()
67 .to_string();
68
69 let expr = format!("{}/**/*.*", path);
70 let start = Instant::now();
71 let mut new = 0;
72
73 for path in (glob(&expr)?).flatten() {
74 match import::import_document_from(&path) {
75 Ok(doc) => {
76 let docs = if let Some(chunk_size) = self.config.chunk_size {
77 doc.chunks(chunk_size)?
78 } else {
79 vec![doc]
80 };
81
82 for doc in docs {
83 match self.add(doc).await {
84 Err(err) => log::error!("storing {}: {}", path.display(), err),
85 Ok(added) => {
86 if added {
87 new += 1
88 }
89 }
90 }
91 }
92 }
93 Err(err) => log::warn!("{} {err}", path.display()),
94 }
95 }
96
97 if new > 0 {
98 log::info!("{} new documents indexed in {:?}\n", new, start.elapsed(),);
99 }
100
101 Ok(())
102 }
103
104 pub async fn add(&mut self, mut document: Document) -> Result<bool> {
105 let doc_id = document.get_ident().to_string();
106 let doc_path = document.get_path().to_string();
107
108 if self.store.documents.contains_key(&doc_id) {
109 log::debug!("document with id '{}' already indexed", &doc_id);
110 return Ok(false);
111 }
112
113 log::info!(
114 "indexing new document '{}' ({} bytes) ...",
115 doc_path,
116 document.get_byte_size()?
117 );
118
119 let start = Instant::now();
120 let embeddings: Vec<f64> = self.embedder.embed(document.get_data()?).await?;
121 let size = embeddings.len();
122
123 document.drop_data();
125
126 self.store.documents.insert(doc_id.to_string(), document);
127 self.store.embeddings.insert(doc_id, embeddings);
128
129 self.store.to_data_path(&self.config.data_path)?;
130
131 log::debug!("time={:?} embedding_size={}", start.elapsed(), size);
132
133 Ok(true)
134 }
135
136 pub async fn retrieve(&self, query: &str, top_k: usize) -> Result<Vec<(Document, f64)>> {
137 log::debug!("{} (top {})", query, top_k);
138
139 let query_vector = self.embedder.embed(query).await?;
140 let mut results = vec![];
141
142 let distances: Vec<(&String, f64)> = {
143 let mut distances: Vec<(&String, f64)> = self
144 .store
145 .embeddings
146 .par_iter()
147 .map(|(doc_id, doc_embedding)| {
148 (doc_id, metrics::cosine(&query_vector, doc_embedding))
149 })
150 .collect();
151 distances.par_sort_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap());
152 distances
153 };
154
155 for (doc_id, score) in distances {
156 let document = self.store.documents.get(doc_id).unwrap();
157 results.push((document.clone(), score));
158 if results.len() >= top_k {
159 break;
160 }
161 }
162
163 Ok(results)
164 }
165}