1use async_trait::async_trait;
2use rusqlite::{params, Connection};
3use serde::{Deserialize, Serialize};
4use serde_json::Value;
5use std::collections::{HashMap, HashSet};
6use std::path::Path;
7use std::sync::Mutex;
8
/// Errors returned by [`VectorStore`] operations.
#[derive(thiserror::Error, Debug)]
pub enum VectorStoreError {
    /// SQLite failure (converted via `From<rusqlite::Error>`, keeping only the message) or a
    /// poisoned connection mutex.
    #[error("db error: {0}")]
    Db(String),
    /// JSON encoding of chunk metadata or ACL lists failed.
    #[error("serialization error: {0}")]
    Serialization(String),
    /// The request cannot be served as given (e.g. an empty query embedding).
    #[error("invalid input: {0}")]
    InvalidInput(String),
}
18
19impl From<rusqlite::Error> for VectorStoreError {
20 fn from(err: rusqlite::Error) -> Self {
21 VectorStoreError::Db(err.to_string())
22 }
23}
24
/// Per-chunk access-control lists; `SqliteVectorStore::upsert` stores each list as a JSON
/// array in its own column.
///
/// As enforced by `SqliteVectorStore::query`: when both allow lists are empty the chunk is
/// visible to every requester; otherwise the requester must match at least one allowed user
/// or group. A match on either deny list hides the chunk even when an allow list matched.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct AccessControl {
    /// Users granted access.
    pub allow_users: Vec<String>,
    /// Groups granted access.
    pub allow_groups: Vec<String>,
    /// Users refused access (takes precedence over the allow lists).
    pub deny_users: Vec<String>,
    /// Groups refused access (takes precedence over the allow lists).
    pub deny_groups: Vec<String>,
}
32
/// One chunk of a source document together with its embedding, metadata and ACL.
///
/// `SqliteVectorStore` persists one `rag_chunks` row per chunk, keyed by `id`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddedChunk {
    /// Row key (table primary key); `upsert` replaces any existing row with the same id.
    pub id: String,
    /// Tenant scope; queries only consider chunks of the requested tenant.
    pub tenant_id: String,
    /// Workspace scope; queries also filter on it.
    pub workspace_id: String,
    /// Source the document came from; queries may restrict to a set of sources.
    pub source_id: String,
    /// Document identifier; `delete_by_doc` removes all chunks sharing it within a scope.
    pub doc_id: String,
    /// Ordinal of the chunk within its document (per the name; stored as-is, not interpreted).
    pub chunk_index: i64,
    /// Chunk text, returned with search results.
    pub text: String,
    /// Embedding vector, stored as consecutive little-endian `f32` bytes. Only chunks whose
    /// embedding length equals the query embedding length are scored.
    pub embedding: Vec<f32>,
    /// Free-form string metadata, stored as a JSON object.
    pub metadata: HashMap<String, String>,
    /// Access-control lists evaluated at query time.
    pub acl: AccessControl,
}
46
/// Parameters for a similarity search via [`VectorStore::query`].
///
/// Field semantics below describe how `SqliteVectorStore` interprets them.
#[derive(Debug, Clone)]
pub struct VectorQuery {
    /// Query vector; must be non-empty. Only stored embeddings of the same length are scored.
    pub query_embedding: Vec<f32>,
    /// Maximum number of results returned, highest scores first.
    pub top_k: usize,
    /// Tenant to search (exact match).
    pub tenant_id: String,
    /// Workspace to search (exact match).
    pub workspace_id: String,
    /// Restricts results to these sources; `None` or an empty list applies no source filter.
    pub source_ids: Option<Vec<String>>,
    /// The requester's user identities, matched against chunk allow/deny user lists.
    pub allowed_users: Vec<String>,
    /// The requester's groups, matched against chunk allow/deny group lists.
    pub allowed_groups: Vec<String>,
    /// Minimum cosine similarity a chunk must reach to be returned.
    pub min_score: f32,
}
58
/// A search hit: the matching chunk and its similarity to the query.
#[derive(Debug, Clone)]
pub struct VectorSearchResult {
    /// The matching chunk, including its stored embedding, metadata and ACL.
    pub chunk: EmbeddedChunk,
    /// Cosine similarity between the query embedding and the chunk embedding
    /// (as computed by `SqliteVectorStore`).
    pub score: f32,
}
64
/// Storage backend for embedded chunks with scoped, ACL-filtered similarity search.
///
/// Implementations must be `Send + Sync`.
#[async_trait]
pub trait VectorStore: Send + Sync {
    /// Inserts `chunks`, replacing existing entries (the SQLite implementation replaces by `id`).
    async fn upsert(&self, chunks: Vec<EmbeddedChunk>) -> Result<(), VectorStoreError>;
    /// Removes every chunk of `doc_id` within the given tenant, workspace and source.
    async fn delete_by_doc(
        &self,
        tenant_id: &str,
        workspace_id: &str,
        source_id: &str,
        doc_id: &str,
    ) -> Result<(), VectorStoreError>;
    /// Returns the chunks most similar to `request.query_embedding` that the requester may
    /// see, best match first.
    async fn query(
        &self,
        request: VectorQuery,
    ) -> Result<Vec<VectorSearchResult>, VectorStoreError>;
}
80
/// [`VectorStore`] backed by a single SQLite connection and the `rag_chunks` table.
///
/// Every operation locks `conn`, so database access is serialized across callers.
/// Similarity search is a brute-force scan over all rows of the requested scope.
pub struct SqliteVectorStore {
    conn: Mutex<Connection>,
}
84
85impl SqliteVectorStore {
86 pub fn open(path: &Path) -> Result<Self, VectorStoreError> {
87 let conn = Connection::open(path)?;
88 let store = Self {
89 conn: Mutex::new(conn),
90 };
91 store.init()?;
92 Ok(store)
93 }
94
95 fn init(&self) -> Result<(), VectorStoreError> {
96 let conn = self
97 .conn
98 .lock()
99 .map_err(|_| VectorStoreError::Db("vector store mutex poisoned".to_string()))?;
100 conn.execute_batch(
101 "CREATE TABLE IF NOT EXISTS rag_chunks (
102 id TEXT PRIMARY KEY,
103 tenant_id TEXT NOT NULL,
104 workspace_id TEXT NOT NULL,
105 source_id TEXT NOT NULL,
106 doc_id TEXT NOT NULL,
107 chunk_index INTEGER NOT NULL,
108 text TEXT NOT NULL,
109 embedding BLOB NOT NULL,
110 metadata_json TEXT,
111 acl_allow_users TEXT,
112 acl_allow_groups TEXT,
113 acl_deny_users TEXT,
114 acl_deny_groups TEXT,
115 updated_at INTEGER
116 );
117 CREATE INDEX IF NOT EXISTS idx_rag_chunks_scope
118 ON rag_chunks (tenant_id, workspace_id, source_id, doc_id);
119 ",
120 )?;
121 Ok(())
122 }
123}
124
125#[async_trait]
126impl VectorStore for SqliteVectorStore {
127 async fn upsert(&self, chunks: Vec<EmbeddedChunk>) -> Result<(), VectorStoreError> {
128 let mut conn = self
129 .conn
130 .lock()
131 .map_err(|_| VectorStoreError::Db("vector store mutex poisoned".to_string()))?;
132 let tx = conn.transaction()?;
133 for chunk in chunks {
134 let metadata_json = serde_json::to_string(&chunk.metadata)
135 .map_err(|e| VectorStoreError::Serialization(e.to_string()))?;
136 let acl_allow_users = serde_json::to_string(&chunk.acl.allow_users)
137 .map_err(|e| VectorStoreError::Serialization(e.to_string()))?;
138 let acl_allow_groups = serde_json::to_string(&chunk.acl.allow_groups)
139 .map_err(|e| VectorStoreError::Serialization(e.to_string()))?;
140 let acl_deny_users = serde_json::to_string(&chunk.acl.deny_users)
141 .map_err(|e| VectorStoreError::Serialization(e.to_string()))?;
142 let acl_deny_groups = serde_json::to_string(&chunk.acl.deny_groups)
143 .map_err(|e| VectorStoreError::Serialization(e.to_string()))?;
144 let embedding_blob = serialize_embedding(&chunk.embedding);
145 tx.execute(
146 "INSERT OR REPLACE INTO rag_chunks (
147 id, tenant_id, workspace_id, source_id, doc_id, chunk_index,
148 text, embedding, metadata_json, acl_allow_users, acl_allow_groups,
149 acl_deny_users, acl_deny_groups, updated_at
150 ) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13, strftime('%s','now'))",
151 params![
152 chunk.id,
153 chunk.tenant_id,
154 chunk.workspace_id,
155 chunk.source_id,
156 chunk.doc_id,
157 chunk.chunk_index,
158 chunk.text,
159 embedding_blob,
160 metadata_json,
161 acl_allow_users,
162 acl_allow_groups,
163 acl_deny_users,
164 acl_deny_groups,
165 ],
166 )?;
167 }
168 tx.commit()?;
169 Ok(())
170 }
171
172 async fn delete_by_doc(
173 &self,
174 tenant_id: &str,
175 workspace_id: &str,
176 source_id: &str,
177 doc_id: &str,
178 ) -> Result<(), VectorStoreError> {
179 let conn = self
180 .conn
181 .lock()
182 .map_err(|_| VectorStoreError::Db("vector store mutex poisoned".to_string()))?;
183 conn.execute(
184 "DELETE FROM rag_chunks WHERE tenant_id = ?1 AND workspace_id = ?2 AND source_id = ?3 AND doc_id = ?4",
185 params![tenant_id, workspace_id, source_id, doc_id],
186 )?;
187 Ok(())
188 }
189
190 async fn query(
191 &self,
192 request: VectorQuery,
193 ) -> Result<Vec<VectorSearchResult>, VectorStoreError> {
194 if request.query_embedding.is_empty() {
195 return Err(VectorStoreError::InvalidInput(
196 "query embedding is empty".to_string(),
197 ));
198 }
199 let conn = self
200 .conn
201 .lock()
202 .map_err(|_| VectorStoreError::Db("vector store mutex poisoned".to_string()))?;
203
204 let mut sql = String::from(
205 "SELECT id, tenant_id, workspace_id, source_id, doc_id, chunk_index, text, embedding,
206 metadata_json, acl_allow_users, acl_allow_groups, acl_deny_users, acl_deny_groups
207 FROM rag_chunks
208 WHERE tenant_id = ? AND workspace_id = ?",
209 );
210 if let Some(source_ids) = &request.source_ids {
211 if !source_ids.is_empty() {
212 let placeholders = vec!["?"; source_ids.len()].join(", ");
213 sql.push_str(" AND source_id IN (");
214 sql.push_str(&placeholders);
215 sql.push(')');
216 }
217 }
218
219 let mut stmt = conn.prepare(&sql)?;
220 let mut params_vec: Vec<&dyn rusqlite::ToSql> = Vec::new();
221 params_vec.push(&request.tenant_id as &dyn rusqlite::ToSql);
222 params_vec.push(&request.workspace_id as &dyn rusqlite::ToSql);
223 if let Some(source_ids) = &request.source_ids {
224 for source_id in source_ids {
225 params_vec.push(source_id as &dyn rusqlite::ToSql);
226 }
227 }
228
229 let mut rows = stmt.query(params_vec.as_slice())?;
230 let mut results = Vec::new();
231 let allowed_users: HashSet<String> = request.allowed_users.iter().cloned().collect();
232 let allowed_groups: HashSet<String> = request.allowed_groups.iter().cloned().collect();
233
234 while let Some(row) = rows.next()? {
235 let embedding_blob: Vec<u8> = row.get(7)?;
236 let embedding = deserialize_embedding(&embedding_blob);
237 if embedding.len() != request.query_embedding.len() {
238 continue;
239 }
240 let acl_allow_users: String = row.get(9)?;
241 let acl_allow_groups: String = row.get(10)?;
242 let acl_deny_users: String = row.get(11)?;
243 let acl_deny_groups: String = row.get(12)?;
244
245 let allow_users: Vec<String> = parse_json_list(&acl_allow_users);
246 let allow_groups: Vec<String> = parse_json_list(&acl_allow_groups);
247 let deny_users: Vec<String> = parse_json_list(&acl_deny_users);
248 let deny_groups: Vec<String> = parse_json_list(&acl_deny_groups);
249
250 if !is_allowed(&allowed_users, &allowed_groups, &allow_users, &allow_groups) {
251 continue;
252 }
253 if is_denied(&allowed_users, &allowed_groups, &deny_users, &deny_groups) {
254 continue;
255 }
256
257 let score = cosine_similarity(&request.query_embedding, &embedding);
259 if score < request.min_score {
260 continue;
261 }
262
263 let metadata_json: String = row.get(8)?;
264 let metadata: HashMap<String, String> =
265 serde_json::from_str(&metadata_json).unwrap_or_else(|_| HashMap::new());
266
267 let chunk = EmbeddedChunk {
268 id: row.get(0)?,
269 tenant_id: row.get(1)?,
270 workspace_id: row.get(2)?,
271 source_id: row.get(3)?,
272 doc_id: row.get(4)?,
273 chunk_index: row.get(5)?,
274 text: row.get(6)?,
275 embedding,
276 metadata,
277 acl: AccessControl {
278 allow_users,
279 allow_groups,
280 deny_users,
281 deny_groups,
282 },
283 };
284
285 results.push(VectorSearchResult { chunk, score });
286 }
287
288 results.sort_by(|a, b| {
289 b.score
290 .partial_cmp(&a.score)
291 .unwrap_or(std::cmp::Ordering::Equal)
292 });
293 results.truncate(request.top_k);
294 Ok(results)
295 }
296}
297
/// Encodes an embedding as a flat byte buffer: the 4-byte little-endian representation of
/// each component, in order. [`deserialize_embedding`] is the inverse.
fn serialize_embedding(embedding: &[f32]) -> Vec<u8> {
    embedding
        .iter()
        .flat_map(|component| component.to_le_bytes())
        .collect()
}
305
/// Decodes consecutive 4-byte little-endian `f32` values from `bytes`.
///
/// Trailing bytes that do not form a full 4-byte group are ignored.
fn deserialize_embedding(bytes: &[u8]) -> Vec<f32> {
    let mut values = Vec::with_capacity(bytes.len() / 4);
    for word in bytes.chunks_exact(4) {
        let mut raw = [0u8; 4];
        raw.copy_from_slice(word);
        values.push(f32::from_le_bytes(raw));
    }
    values
}
312
313fn parse_json_list(value: &str) -> Vec<String> {
314 match serde_json::from_str::<Value>(value) {
315 Ok(Value::Array(items)) => items
316 .into_iter()
317 .filter_map(|v| v.as_str().map(|s| s.to_string()))
318 .collect(),
319 _ => Vec::new(),
320 }
321}
322
/// Returns `true` when a chunk's allow lists admit the requester.
///
/// `allowed_users` / `allowed_groups` are the requester's identities; `acl_users` /
/// `acl_groups` are the chunk's allow lists. A chunk with both allow lists empty is open to
/// everyone. Otherwise at least one user or group must appear in the matching allow list.
fn is_allowed(
    allowed_users: &HashSet<String>,
    allowed_groups: &HashSet<String>,
    acl_users: &[String],
    acl_groups: &[String],
) -> bool {
    let unrestricted = acl_users.is_empty() && acl_groups.is_empty();
    unrestricted
        || acl_users.iter().any(|user| allowed_users.contains(user))
        || acl_groups.iter().any(|group| allowed_groups.contains(group))
}
335
/// Returns `true` when the requester appears on either of a chunk's deny lists.
///
/// `allowed_users` / `allowed_groups` are the requester's identities; empty deny lists
/// never deny.
fn is_denied(
    allowed_users: &HashSet<String>,
    allowed_groups: &HashSet<String>,
    deny_users: &[String],
    deny_groups: &[String],
) -> bool {
    if deny_users.iter().any(|user| allowed_users.contains(user)) {
        return true;
    }
    deny_groups.iter().any(|group| allowed_groups.contains(group))
}
345
/// Cosine similarity of `a` and `b`; approximately in `[-1, 1]` for finite, non-zero inputs.
///
/// Only the overlapping prefix is compared (`zip` stops at the shorter slice), so callers
/// should check lengths first. Returns `0.0` when either vector has zero norm, and NaN
/// propagates from NaN inputs.
///
/// Sums are accumulated in `f64`. In `f32`, `x * x` overflows once `|x|` exceeds ~1.8e19,
/// which turns the result into NaN. Tiny components also underflow to a spurious zero norm.
/// The final ratio is rounded back to `f32`.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    let (mut dot, mut norm_a, mut norm_b) = (0.0f64, 0.0f64, 0.0f64);
    for (&x, &y) in a.iter().zip(b) {
        let (x, y) = (f64::from(x), f64::from(y));
        dot += x * y;
        norm_a += x * x;
        norm_b += y * y;
    }
    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0;
    }
    (dot / (norm_a.sqrt() * norm_b.sqrt())) as f32
}