Skip to main content

kapsl_rag/vector/mod.rs

1use async_trait::async_trait;
2use rusqlite::{params, Connection};
3use serde::{Deserialize, Serialize};
4use serde_json::Value;
5use std::collections::{HashMap, HashSet};
6use std::path::Path;
7use std::sync::Mutex;
8
9#[derive(thiserror::Error, Debug)]
10pub enum VectorStoreError {
11    #[error("db error: {0}")]
12    Db(String),
13    #[error("serialization error: {0}")]
14    Serialization(String),
15    #[error("invalid input: {0}")]
16    InvalidInput(String),
17}
18
19impl From<rusqlite::Error> for VectorStoreError {
20    fn from(err: rusqlite::Error) -> Self {
21        VectorStoreError::Db(err.to_string())
22    }
23}
24
25#[derive(Debug, Clone, Serialize, Deserialize, Default)]
26pub struct AccessControl {
27    pub allow_users: Vec<String>,
28    pub allow_groups: Vec<String>,
29    pub deny_users: Vec<String>,
30    pub deny_groups: Vec<String>,
31}
32
33#[derive(Debug, Clone, Serialize, Deserialize)]
34pub struct EmbeddedChunk {
35    pub id: String,
36    pub tenant_id: String,
37    pub workspace_id: String,
38    pub source_id: String,
39    pub doc_id: String,
40    pub chunk_index: i64,
41    pub text: String,
42    pub embedding: Vec<f32>,
43    pub metadata: HashMap<String, String>,
44    pub acl: AccessControl,
45}
46
/// Parameters of a similarity search.
#[derive(Clone, Debug)]
pub struct VectorQuery {
    /// Embedding to compare stored chunks against (the SQLite store rejects an empty one).
    pub query_embedding: Vec<f32>,
    /// Maximum number of results to return.
    pub top_k: usize,
    /// Tenant to search within.
    pub tenant_id: String,
    /// Workspace to search within.
    pub workspace_id: String,
    /// Optional source filter; `None` or an empty list searches every source.
    pub source_ids: Option<Vec<String>>,
    /// The caller's user identities, matched against each chunk's ACL user lists.
    pub allowed_users: Vec<String>,
    /// The caller's group memberships, matched against each chunk's ACL group lists.
    pub allowed_groups: Vec<String>,
    /// Minimum cosine similarity a chunk must reach to be returned.
    pub min_score: f32,
}
58
59#[derive(Debug, Clone)]
60pub struct VectorSearchResult {
61    pub chunk: EmbeddedChunk,
62    pub score: f32,
63}
64
65#[async_trait]
66pub trait VectorStore: Send + Sync {
67    async fn upsert(&self, chunks: Vec<EmbeddedChunk>) -> Result<(), VectorStoreError>;
68    async fn delete_by_doc(
69        &self,
70        tenant_id: &str,
71        workspace_id: &str,
72        source_id: &str,
73        doc_id: &str,
74    ) -> Result<(), VectorStoreError>;
75    async fn query(
76        &self,
77        request: VectorQuery,
78    ) -> Result<Vec<VectorSearchResult>, VectorStoreError>;
79}
80
81pub struct SqliteVectorStore {
82    conn: Mutex<Connection>,
83}
84
85impl SqliteVectorStore {
86    pub fn open(path: &Path) -> Result<Self, VectorStoreError> {
87        let conn = Connection::open(path)?;
88        let store = Self {
89            conn: Mutex::new(conn),
90        };
91        store.init()?;
92        Ok(store)
93    }
94
95    fn init(&self) -> Result<(), VectorStoreError> {
96        let conn = self
97            .conn
98            .lock()
99            .map_err(|_| VectorStoreError::Db("vector store mutex poisoned".to_string()))?;
100        conn.execute_batch(
101            "CREATE TABLE IF NOT EXISTS rag_chunks (
102                id TEXT PRIMARY KEY,
103                tenant_id TEXT NOT NULL,
104                workspace_id TEXT NOT NULL,
105                source_id TEXT NOT NULL,
106                doc_id TEXT NOT NULL,
107                chunk_index INTEGER NOT NULL,
108                text TEXT NOT NULL,
109                embedding BLOB NOT NULL,
110                metadata_json TEXT,
111                acl_allow_users TEXT,
112                acl_allow_groups TEXT,
113                acl_deny_users TEXT,
114                acl_deny_groups TEXT,
115                updated_at INTEGER
116            );
117            CREATE INDEX IF NOT EXISTS idx_rag_chunks_scope
118                ON rag_chunks (tenant_id, workspace_id, source_id, doc_id);
119            ",
120        )?;
121        Ok(())
122    }
123}
124
125#[async_trait]
126impl VectorStore for SqliteVectorStore {
127    async fn upsert(&self, chunks: Vec<EmbeddedChunk>) -> Result<(), VectorStoreError> {
128        let mut conn = self
129            .conn
130            .lock()
131            .map_err(|_| VectorStoreError::Db("vector store mutex poisoned".to_string()))?;
132        let tx = conn.transaction()?;
133        for chunk in chunks {
134            let metadata_json = serde_json::to_string(&chunk.metadata)
135                .map_err(|e| VectorStoreError::Serialization(e.to_string()))?;
136            let acl_allow_users = serde_json::to_string(&chunk.acl.allow_users)
137                .map_err(|e| VectorStoreError::Serialization(e.to_string()))?;
138            let acl_allow_groups = serde_json::to_string(&chunk.acl.allow_groups)
139                .map_err(|e| VectorStoreError::Serialization(e.to_string()))?;
140            let acl_deny_users = serde_json::to_string(&chunk.acl.deny_users)
141                .map_err(|e| VectorStoreError::Serialization(e.to_string()))?;
142            let acl_deny_groups = serde_json::to_string(&chunk.acl.deny_groups)
143                .map_err(|e| VectorStoreError::Serialization(e.to_string()))?;
144            let embedding_blob = serialize_embedding(&chunk.embedding);
145            tx.execute(
146                "INSERT OR REPLACE INTO rag_chunks (
147                    id, tenant_id, workspace_id, source_id, doc_id, chunk_index,
148                    text, embedding, metadata_json, acl_allow_users, acl_allow_groups,
149                    acl_deny_users, acl_deny_groups, updated_at
150                ) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13, strftime('%s','now'))",
151                params![
152                    chunk.id,
153                    chunk.tenant_id,
154                    chunk.workspace_id,
155                    chunk.source_id,
156                    chunk.doc_id,
157                    chunk.chunk_index,
158                    chunk.text,
159                    embedding_blob,
160                    metadata_json,
161                    acl_allow_users,
162                    acl_allow_groups,
163                    acl_deny_users,
164                    acl_deny_groups,
165                ],
166            )?;
167        }
168        tx.commit()?;
169        Ok(())
170    }
171
172    async fn delete_by_doc(
173        &self,
174        tenant_id: &str,
175        workspace_id: &str,
176        source_id: &str,
177        doc_id: &str,
178    ) -> Result<(), VectorStoreError> {
179        let conn = self
180            .conn
181            .lock()
182            .map_err(|_| VectorStoreError::Db("vector store mutex poisoned".to_string()))?;
183        conn.execute(
184            "DELETE FROM rag_chunks WHERE tenant_id = ?1 AND workspace_id = ?2 AND source_id = ?3 AND doc_id = ?4",
185            params![tenant_id, workspace_id, source_id, doc_id],
186        )?;
187        Ok(())
188    }
189
190    async fn query(
191        &self,
192        request: VectorQuery,
193    ) -> Result<Vec<VectorSearchResult>, VectorStoreError> {
194        if request.query_embedding.is_empty() {
195            return Err(VectorStoreError::InvalidInput(
196                "query embedding is empty".to_string(),
197            ));
198        }
199        let conn = self
200            .conn
201            .lock()
202            .map_err(|_| VectorStoreError::Db("vector store mutex poisoned".to_string()))?;
203
204        let mut sql = String::from(
205            "SELECT id, tenant_id, workspace_id, source_id, doc_id, chunk_index, text, embedding,
206                    metadata_json, acl_allow_users, acl_allow_groups, acl_deny_users, acl_deny_groups
207             FROM rag_chunks
208             WHERE tenant_id = ? AND workspace_id = ?",
209        );
210        if let Some(source_ids) = &request.source_ids {
211            if !source_ids.is_empty() {
212                let placeholders = vec!["?"; source_ids.len()].join(", ");
213                sql.push_str(" AND source_id IN (");
214                sql.push_str(&placeholders);
215                sql.push(')');
216            }
217        }
218
219        let mut stmt = conn.prepare(&sql)?;
220        let mut params_vec: Vec<&dyn rusqlite::ToSql> = Vec::new();
221        params_vec.push(&request.tenant_id as &dyn rusqlite::ToSql);
222        params_vec.push(&request.workspace_id as &dyn rusqlite::ToSql);
223        if let Some(source_ids) = &request.source_ids {
224            for source_id in source_ids {
225                params_vec.push(source_id as &dyn rusqlite::ToSql);
226            }
227        }
228
229        let mut rows = stmt.query(params_vec.as_slice())?;
230        let mut results = Vec::new();
231        let allowed_users: HashSet<String> = request.allowed_users.iter().cloned().collect();
232        let allowed_groups: HashSet<String> = request.allowed_groups.iter().cloned().collect();
233
234        while let Some(row) = rows.next()? {
235            let embedding_blob: Vec<u8> = row.get(7)?;
236            let embedding = deserialize_embedding(&embedding_blob);
237            if embedding.len() != request.query_embedding.len() {
238                continue;
239            }
240            let acl_allow_users: String = row.get(9)?;
241            let acl_allow_groups: String = row.get(10)?;
242            let acl_deny_users: String = row.get(11)?;
243            let acl_deny_groups: String = row.get(12)?;
244
245            let allow_users: Vec<String> = parse_json_list(&acl_allow_users);
246            let allow_groups: Vec<String> = parse_json_list(&acl_allow_groups);
247            let deny_users: Vec<String> = parse_json_list(&acl_deny_users);
248            let deny_groups: Vec<String> = parse_json_list(&acl_deny_groups);
249
250            if !is_allowed(&allowed_users, &allowed_groups, &allow_users, &allow_groups) {
251                continue;
252            }
253            if is_denied(&allowed_users, &allowed_groups, &deny_users, &deny_groups) {
254                continue;
255            }
256
257            // TODO: replace brute-force scoring with an ANN index (HNSW) for scale.
258            let score = cosine_similarity(&request.query_embedding, &embedding);
259            if score < request.min_score {
260                continue;
261            }
262
263            let metadata_json: String = row.get(8)?;
264            let metadata: HashMap<String, String> =
265                serde_json::from_str(&metadata_json).unwrap_or_else(|_| HashMap::new());
266
267            let chunk = EmbeddedChunk {
268                id: row.get(0)?,
269                tenant_id: row.get(1)?,
270                workspace_id: row.get(2)?,
271                source_id: row.get(3)?,
272                doc_id: row.get(4)?,
273                chunk_index: row.get(5)?,
274                text: row.get(6)?,
275                embedding,
276                metadata,
277                acl: AccessControl {
278                    allow_users,
279                    allow_groups,
280                    deny_users,
281                    deny_groups,
282                },
283            };
284
285            results.push(VectorSearchResult { chunk, score });
286        }
287
288        results.sort_by(|a, b| {
289            b.score
290                .partial_cmp(&a.score)
291                .unwrap_or(std::cmp::Ordering::Equal)
292        });
293        results.truncate(request.top_k);
294        Ok(results)
295    }
296}
297
/// Encodes an embedding as a flat byte buffer: each component as 4 little-endian bytes, in order.
fn serialize_embedding(embedding: &[f32]) -> Vec<u8> {
    embedding
        .iter()
        .fold(Vec::with_capacity(embedding.len() * 4), |mut buf, component| {
            buf.extend_from_slice(&component.to_le_bytes());
            buf
        })
}
305
/// Decodes a buffer of little-endian `f32` values (4 bytes each).
///
/// A trailing partial group of fewer than 4 bytes is ignored.
fn deserialize_embedding(bytes: &[u8]) -> Vec<f32> {
    let mut values = Vec::with_capacity(bytes.len() / 4);
    for word in bytes.chunks_exact(4) {
        let mut raw = [0u8; 4];
        raw.copy_from_slice(word);
        values.push(f32::from_le_bytes(raw));
    }
    values
}
312
313fn parse_json_list(value: &str) -> Vec<String> {
314    match serde_json::from_str::<Value>(value) {
315        Ok(Value::Array(items)) => items
316            .into_iter()
317            .filter_map(|v| v.as_str().map(|s| s.to_string()))
318            .collect(),
319        _ => Vec::new(),
320    }
321}
322
/// Returns `true` when the caller passes a chunk's allow-lists.
///
/// `allowed_users` / `allowed_groups` are the caller's identities; `acl_users` / `acl_groups`
/// are the chunk's allow-lists. A chunk whose allow-lists are both empty is unrestricted.
/// Otherwise at least one user or group must match.
fn is_allowed(
    allowed_users: &HashSet<String>,
    allowed_groups: &HashSet<String>,
    acl_users: &[String],
    acl_groups: &[String],
) -> bool {
    let unrestricted = acl_users.is_empty() && acl_groups.is_empty();
    unrestricted
        || acl_users.iter().any(|user| allowed_users.contains(user))
        || acl_groups.iter().any(|group| allowed_groups.contains(group))
}
335
/// Returns `true` when any of the caller's users or groups appears on a chunk's deny-lists.
///
/// `allowed_users` / `allowed_groups` are the caller's identities. Empty deny-lists never deny.
fn is_denied(
    allowed_users: &HashSet<String>,
    allowed_groups: &HashSet<String>,
    deny_users: &[String],
    deny_groups: &[String],
) -> bool {
    let user_blocked = deny_users.iter().any(|user| allowed_users.contains(user));
    user_blocked || deny_groups.iter().any(|group| allowed_groups.contains(group))
}
345
/// Cosine similarity of `a` and `b`, accumulated in `f32`.
///
/// Only the overlapping prefix is used if the lengths differ. Returns `0.0` when either vector
/// has zero magnitude (including empty input).
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    let (dot, norm_a, norm_b) = a.iter().zip(b).fold(
        (0.0f32, 0.0f32, 0.0f32),
        |(dot, sq_a, sq_b), (&x, &y)| (dot + x * y, sq_a + x * x, sq_b + y * y),
    );
    if norm_a == 0.0 || norm_b == 0.0 {
        0.0
    } else {
        dot / (norm_a.sqrt() * norm_b.sqrt())
    }
}