Skip to main content

claw_vector/store/
sqlite.rs

// store/sqlite.rs — SQLite persistence layer for VectorRecord and Collection metadata.
use std::{collections::HashMap, path::Path};

use sqlx::{
    sqlite::{SqliteConnectOptions, SqliteJournalMode, SqlitePoolOptions, SqliteSynchronous},
    QueryBuilder, Sqlite, SqlitePool,
};
use uuid::Uuid;

use crate::{
    error::{VectorError, VectorResult},
    types::{Collection, CollectionStats, DistanceMetric, IndexType, VectorRecord},
};

15/// Manages all SQLite read/write operations for collections and vector records.
16pub struct VectorStore {
17    pool: SqlitePool,
18}
19
20/// Backward-compatible alias for [`VectorStore`].
21pub type SqliteStore = VectorStore;
22
23impl VectorStore {
24    /// Open (or create) the SQLite database at `db_path`, applying schema migrations.
25    pub async fn new(db_path: &Path) -> VectorResult<Self> {
26        if let Some(parent) = db_path.parent() {
27            std::fs::create_dir_all(parent)?;
28        }
29
30        let options = SqliteConnectOptions::new()
31            .filename(db_path)
32            .create_if_missing(true)
33            .foreign_keys(true);
34
35        let pool = SqlitePoolOptions::new()
36            .max_connections(8)
37            .connect_with(options)
38            .await?;
39
40        sqlx::query("PRAGMA journal_mode = WAL")
41            .execute(&pool)
42            .await?;
43        sqlx::query("PRAGMA synchronous = NORMAL")
44            .execute(&pool)
45            .await?;
46        sqlx::query("PRAGMA temp_store = MEMORY")
47            .execute(&pool)
48            .await?;
49
50        sqlx::migrate!()
51            .run(&pool)
52            .await
53            .map_err(|err| VectorError::Index(format!("failed to run SQLite migrations: {err}")))?;
54
55        Ok(VectorStore { pool })
56    }
57
58    /// Alias for [`VectorStore::new`].
59    pub async fn open(path: &Path) -> VectorResult<Self> {
60        Self::new(path).await
61    }
62
63    /// Return the underlying SQLx connection pool.
64    pub fn pool(&self) -> &SqlitePool {
65        &self.pool
66    }
67
68    /// Persist a new collection definition (upsert).
69    pub async fn create_collection(&self, workspace_id: &str, col: &Collection) -> VectorResult<()> {
70        sqlx::query(
71            r#"INSERT INTO collections
72               (workspace_id, name, dimensions, distance, index_type, ef_construction, m_connections,
73                created_at, vector_count, metadata)
74               VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)"#,
75        )
76        .bind(workspace_id)
77        .bind(&col.name)
78        .bind(col.dimensions as i64)
79        .bind(distance_to_db(col.distance))
80        .bind(index_type_to_db(col.index_type))
81        .bind(col.ef_construction as i64)
82        .bind(col.m_connections as i64)
83        .bind(col.created_at.to_rfc3339())
84        .bind(col.vector_count as i64)
85        .bind(normalize_metadata(&col.metadata)?)
86        .execute(&self.pool)
87        .await?;
88        Ok(())
89    }
90
91    /// Alias for [`VectorStore::create_collection`].
92    pub async fn save_collection(&self, workspace_id: &str, col: &Collection) -> VectorResult<()> {
93        self.create_collection(workspace_id, col).await
94    }
95
96    /// Retrieve a collection by name.
97    pub async fn get_collection(&self, workspace_id: &str, name: &str) -> VectorResult<Collection> {
98        let row = sqlx::query_as::<_, CollectionRow>(
99            "SELECT workspace_id, name, dimensions, distance, index_type, ef_construction, m_connections, \
100             created_at, vector_count, metadata FROM collections WHERE workspace_id = ? AND name = ?",
101        )
102        .bind(workspace_id)
103        .bind(name)
104        .fetch_optional(&self.pool)
105        .await?;
106
107        match row {
108            Some(row) => collection_from_row(row),
109            None => Err(VectorError::NotFound {
110                entity: "collection".into(),
111                id: name.to_string(),
112            }),
113        }
114    }
115
116    /// Delete a collection by name.
117    pub async fn delete_collection(&self, workspace_id: &str, name: &str) -> VectorResult<()> {
118        let mut tx = self.pool.begin().await?;
119        sqlx::query("DELETE FROM vector_records WHERE workspace_id = ? AND collection = ?")
120            .bind(workspace_id)
121            .bind(name)
122            .execute(&mut *tx)
123            .await?;
124        sqlx::query("DELETE FROM collections WHERE workspace_id = ? AND name = ?")
125            .bind(workspace_id)
126            .bind(name)
127            .execute(&mut *tx)
128            .await?;
129        tx.commit().await?;
130        Ok(())
131    }
132
133    /// List all collections.
134    pub async fn list_collections(&self, workspace_id: &str) -> VectorResult<Vec<Collection>> {
135        let rows = sqlx::query_as::<_, CollectionRow>(
136            "SELECT workspace_id, name, dimensions, distance, index_type, ef_construction, m_connections, \
137             created_at, vector_count, metadata FROM collections WHERE workspace_id = ? ORDER BY name",
138        )
139        .bind(workspace_id)
140        .fetch_all(&self.pool)
141        .await?;
142
143        rows.into_iter().map(collection_from_row).collect()
144    }
145
146    /// List all collections across all workspaces.
147    pub async fn list_all_collections(&self) -> VectorResult<Vec<Collection>> {
148        let rows = sqlx::query_as::<_, CollectionRow>(
149            "SELECT workspace_id, name, dimensions, distance, index_type, ef_construction, m_connections, \
150             created_at, vector_count, metadata FROM collections ORDER BY workspace_id, name",
151        )
152        .fetch_all(&self.pool)
153        .await?;
154
155        rows.into_iter().map(collection_from_row).collect()
156    }
157
158    /// Persist a vector record, linking it to the given `internal_id`.
159    pub async fn insert_record(
160        &self,
161        workspace_id: &str,
162        record: &VectorRecord,
163        internal_id: usize,
164    ) -> VectorResult<()> {
165        sqlx::query(
166            r#"INSERT INTO vector_records
167               (id, internal_id, workspace_id, collection, text, metadata, created_at)
168               VALUES (?, ?, ?, ?, ?, ?, ?)"#,
169        )
170        .bind(record.id.to_string())
171        .bind(internal_id as i64)
172        .bind(workspace_id)
173        .bind(&record.collection)
174        .bind(&record.text)
175        .bind(normalize_metadata(&record.metadata)?)
176        .bind(record.created_at.to_rfc3339())
177        .execute(&self.pool)
178        .await?;
179        Ok(())
180    }
181
182    /// Alias for [`VectorStore::insert_record`].
183    pub async fn save_record(
184        &self,
185        workspace_id: &str,
186        record: &VectorRecord,
187        internal_id: usize,
188    ) -> VectorResult<()> {
189        self.insert_record(workspace_id, record, internal_id).await
190    }
191
192    /// Retrieve a record and its internal identifier by UUID.
193    pub async fn get_record(&self, workspace_id: &str, id: Uuid) -> VectorResult<(VectorRecord, usize)> {
194        let row = sqlx::query_as::<_, RecordRow>(
195            "SELECT id, internal_id, workspace_id, collection, text, metadata, created_at \
196             FROM vector_records WHERE workspace_id = ? AND id = ?",
197        )
198        .bind(workspace_id)
199        .bind(id.to_string())
200        .fetch_optional(&self.pool)
201        .await?;
202
203        match row {
204            Some(row) => record_from_row(row),
205            None => Err(VectorError::NotFound {
206                entity: "record".into(),
207                id: id.to_string(),
208            }),
209        }
210    }
211
212    /// Delete a vector record by id and return its previous internal id when found.
213    pub async fn delete_record(&self, workspace_id: &str, id: Uuid) -> VectorResult<Option<usize>> {
214        let mut tx = self.pool.begin().await?;
215        let internal_id =
216            sqlx::query_scalar::<_, i64>("SELECT internal_id FROM vector_records WHERE workspace_id = ? AND id = ?")
217                .bind(workspace_id)
218                .bind(id.to_string())
219                .fetch_optional(&mut *tx)
220                .await?
221                .map(|value| value as usize);
222
223        if internal_id.is_some() {
224            sqlx::query("DELETE FROM vector_records WHERE workspace_id = ? AND id = ?")
225                .bind(workspace_id)
226                .bind(id.to_string())
227                .execute(&mut *tx)
228                .await?;
229        }
230
231        tx.commit().await?;
232        Ok(internal_id)
233    }
234
235    /// Insert multiple vector records in a single transaction.
236    pub async fn batch_insert_records(
237        &self,
238        workspace_id: &str,
239        records: &[(VectorRecord, usize)],
240    ) -> VectorResult<()> {
241        let mut tx = self.pool.begin().await?;
242        for (record, internal_id) in records {
243            sqlx::query(
244                r#"INSERT INTO vector_records
245                   (id, internal_id, workspace_id, collection, text, metadata, created_at)
246                   VALUES (?, ?, ?, ?, ?, ?, ?)"#,
247            )
248            .bind(record.id.to_string())
249            .bind(*internal_id as i64)
250            .bind(workspace_id)
251            .bind(&record.collection)
252            .bind(&record.text)
253            .bind(normalize_metadata(&record.metadata)?)
254            .bind(record.created_at.to_rfc3339())
255            .execute(&mut *tx)
256            .await?;
257        }
258        tx.commit().await?;
259        Ok(())
260    }
261
262    /// Resolve a record UUID to its internal id.
263    pub async fn uuid_to_internal(&self, workspace_id: &str, id: Uuid) -> VectorResult<usize> {
264        let internal_id =
265            sqlx::query_scalar::<_, i64>("SELECT internal_id FROM vector_records WHERE workspace_id = ? AND id = ?")
266                .bind(workspace_id)
267                .bind(id.to_string())
268                .fetch_optional(&self.pool)
269                .await?
270                .ok_or_else(|| VectorError::NotFound {
271                    entity: "record".into(),
272                    id: id.to_string(),
273                })?;
274        Ok(internal_id as usize)
275    }
276
277    /// Resolve a collection-scoped internal id to its UUID.
278    pub async fn internal_to_uuid(
279        &self,
280        workspace_id: &str,
281        collection: &str,
282        internal_id: usize,
283    ) -> VectorResult<Uuid> {
284        let id = sqlx::query_scalar::<_, String>(
285            "SELECT id FROM vector_records WHERE workspace_id = ? AND collection = ? AND internal_id = ?",
286        )
287        .bind(workspace_id)
288        .bind(collection)
289        .bind(internal_id as i64)
290        .fetch_optional(&self.pool)
291        .await?
292        .ok_or_else(|| VectorError::NotFound {
293            entity: "record".into(),
294            id: format!("{collection}:{internal_id}"),
295        })?;
296        Uuid::parse_str(&id)
297            .map_err(|err| VectorError::Index(format!("invalid UUID stored in SQLite: {err}")))
298    }
299
300    /// Bulk-resolve collection-scoped internal ids to stored vector metadata.
301    pub async fn bulk_internal_to_uuid(
302        &self,
303        workspace_id: &str,
304        collection: &str,
305        ids: &[usize],
306    ) -> VectorResult<Vec<(usize, VectorRecord)>> {
307        if ids.is_empty() {
308            return Ok(Vec::new());
309        }
310
311        let mut builder = QueryBuilder::<Sqlite>::new(
312            "SELECT id, internal_id, workspace_id, collection, text, metadata, created_at FROM vector_records WHERE workspace_id = ",
313        );
314        builder.push_bind(workspace_id);
315        builder.push(" AND collection = ");
316        builder.push_bind(collection);
317        builder.push(" AND internal_id IN (");
318        let mut separated = builder.separated(", ");
319        for id in ids {
320            separated.push_bind(*id as i64);
321        }
322        separated.push_unseparated(") ORDER BY internal_id ASC");
323
324        let rows = builder
325            .build_query_as::<RecordRow>()
326            .fetch_all(&self.pool)
327            .await?;
328
329        let resolved = rows
330            .into_iter()
331            .map(record_from_row)
332            .collect::<VectorResult<Vec<_>>>()?;
333
334        let mut by_id = HashMap::with_capacity(resolved.len());
335        for (record, internal_id) in resolved {
336            by_id.insert(internal_id, record);
337        }
338
339        Ok(ids
340            .iter()
341            .filter_map(|id| by_id.remove(id).map(|record| (*id, record)))
342            .collect())
343    }
344
345    /// Increment a collection's stored vector count.
346    pub async fn increment_vector_count(
347        &self,
348        workspace_id: &str,
349        collection: &str,
350        delta: i64,
351    ) -> VectorResult<()> {
352        sqlx::query(
353            "UPDATE collections SET vector_count = MAX(vector_count + ?, 0) WHERE workspace_id = ? AND name = ?",
354        )
355        .bind(delta)
356        .bind(workspace_id)
357        .bind(collection)
358        .execute(&self.pool)
359        .await?;
360        Ok(())
361    }
362
363    /// Update the persisted index type for a collection.
364    pub async fn update_collection_index_type(
365        &self,
366        workspace_id: &str,
367        collection: &str,
368        index_type: IndexType,
369    ) -> VectorResult<()> {
370        sqlx::query("UPDATE collections SET index_type = ? WHERE workspace_id = ? AND name = ?")
371            .bind(index_type_to_db(index_type))
372            .bind(workspace_id)
373            .bind(collection)
374            .execute(&self.pool)
375            .await?;
376        Ok(())
377    }
378
379    /// Return collection storage statistics as tracked in SQLite.
380    pub async fn collection_stats(&self, workspace_id: &str, name: &str) -> VectorResult<CollectionStats> {
381        let vector_count =
382            sqlx::query_scalar::<_, i64>("SELECT vector_count FROM collections WHERE workspace_id = ? AND name = ?")
383                .bind(workspace_id)
384                .bind(name)
385                .fetch_optional(&self.pool)
386                .await?
387                .ok_or_else(|| VectorError::NotFound {
388                    entity: "collection".into(),
389                    id: name.to_string(),
390                })?;
391
392        let record_bytes = sqlx::query_scalar::<_, i64>(
393            "SELECT COALESCE(SUM(LENGTH(id) + LENGTH(IFNULL(text, '')) + LENGTH(metadata) + LENGTH(created_at) + 8), 0) FROM vector_records WHERE workspace_id = ? AND collection = ?",
394        )
395        .bind(workspace_id)
396        .bind(name)
397        .fetch_one(&self.pool)
398        .await?;
399
400        let collection_bytes = sqlx::query_scalar::<_, i64>(
401            "SELECT LENGTH(name) + LENGTH(distance) + LENGTH(index_type) + LENGTH(created_at) + LENGTH(metadata) + 32 FROM collections WHERE workspace_id = ? AND name = ?",
402        )
403        .bind(workspace_id)
404        .bind(name)
405        .fetch_one(&self.pool)
406        .await?;
407
408        Ok(CollectionStats {
409            vector_count: vector_count as u64,
410            size_bytes: (record_bytes + collection_bytes.max(0)) as u64,
411        })
412    }
413
414    /// Return the next available internal id for a collection.
415    pub async fn next_internal_id(&self, workspace_id: &str, collection: &str) -> VectorResult<usize> {
416        let max_internal_id = sqlx::query_scalar::<_, Option<i64>>(
417            "SELECT MAX(internal_id) FROM vector_records WHERE workspace_id = ? AND collection = ?",
418        )
419        .bind(workspace_id)
420        .bind(collection)
421        .fetch_one(&self.pool)
422        .await?;
423        Ok(max_internal_id.map(|value| value as usize + 1).unwrap_or(0))
424    }
425
426    /// Load all persisted records for a collection, ordered by internal id.
427    pub async fn list_records_for_collection(
428        &self,
429        workspace_id: &str,
430        collection: &str,
431    ) -> VectorResult<Vec<(VectorRecord, usize)>> {
432        let rows = sqlx::query_as::<_, RecordRow>(
433            "SELECT id, internal_id, workspace_id, collection, text, metadata, created_at FROM vector_records WHERE workspace_id = ? AND collection = ? ORDER BY internal_id ASC",
434        )
435        .bind(workspace_id)
436        .bind(collection)
437        .fetch_all(&self.pool)
438        .await?;
439
440        rows.into_iter().map(record_from_row).collect()
441    }
442
443    /// Search full-text records for a collection using SQLite FTS5.
444    pub async fn keyword_search(
445        &self,
446        workspace_id: &str,
447        collection: &str,
448        query: &str,
449        limit: usize,
450    ) -> VectorResult<Vec<(usize, VectorRecord, f32)>> {
451        if query.trim().is_empty() || limit == 0 {
452            return Ok(Vec::new());
453        }
454
455        let rows = sqlx::query_as::<_, KeywordRow>(
456            r#"
457                 SELECT vr.id, vr.internal_id, vr.workspace_id, vr.collection, vr.text, vr.metadata, vr.created_at,
458                   CAST(bm25(vector_records_fts) AS REAL) AS rank
459            FROM vector_records_fts
460            JOIN vector_records AS vr ON vr.rowid = vector_records_fts.rowid
461            WHERE vr.workspace_id = ? AND vr.collection = ? AND vector_records_fts MATCH ?
462            ORDER BY rank ASC
463            LIMIT ?
464            "#,
465        )
466        .bind(workspace_id)
467        .bind(collection)
468        .bind(query)
469        .bind(limit as i64)
470        .fetch_all(&self.pool)
471        .await?;
472
473        rows.into_iter()
474            .map(|row| {
475                let rank = row.rank.unwrap_or(0.0);
476                let record_row = RecordRow {
477                    id: row.id,
478                    internal_id: row.internal_id,
479                    workspace_id: row.workspace_id,
480                    collection: row.collection,
481                    text: row.text,
482                    metadata: row.metadata,
483                    created_at: row.created_at,
484                };
485                let (record, internal_id) = record_from_row(record_row)?;
486                Ok((internal_id, record, rank))
487            })
488            .collect()
489    }
490
491    /// Close the underlying SQLx pool.
492    pub async fn close(&self) {
493        self.pool.close().await;
494    }
495}
496
497#[derive(Debug, sqlx::FromRow)]
498struct CollectionRow {
499    workspace_id: String,
500    name: String,
501    dimensions: i64,
502    distance: String,
503    index_type: String,
504    ef_construction: i64,
505    m_connections: i64,
506    created_at: String,
507    vector_count: i64,
508    metadata: String,
509}
510
511#[derive(Debug, sqlx::FromRow)]
512struct RecordRow {
513    id: String,
514    internal_id: i64,
515    #[allow(dead_code)]
516    workspace_id: String,
517    collection: String,
518    text: Option<String>,
519    metadata: String,
520    created_at: String,
521}
522
523#[derive(Debug, sqlx::FromRow)]
524struct KeywordRow {
525    id: String,
526    internal_id: i64,
527    workspace_id: String,
528    collection: String,
529    text: Option<String>,
530    metadata: String,
531    created_at: String,
532    rank: Option<f32>,
533}
534
535/// Convert a raw database row into a [`Collection`], parsing JSON and RFC-3339 fields.
536fn collection_from_row(row: CollectionRow) -> VectorResult<Collection> {
537    Ok(Collection {
538        workspace_id: row.workspace_id,
539        name: row.name,
540        dimensions: row.dimensions as usize,
541        distance: distance_from_db(&row.distance)?,
542        index_type: index_type_from_db(&row.index_type)?,
543        ef_construction: row.ef_construction as usize,
544        m_connections: row.m_connections as usize,
545        created_at: chrono::DateTime::parse_from_rfc3339(&row.created_at)
546            .map_err(|e| VectorError::Index(format!("invalid timestamp in DB: {e}")))?
547            .with_timezone(&chrono::Utc),
548        vector_count: row.vector_count as u64,
549        metadata: parse_metadata(&row.metadata)?,
550    })
551}
552
553fn record_from_row(row: RecordRow) -> VectorResult<(VectorRecord, usize)> {
554    let id = Uuid::parse_str(&row.id).map_err(|err| {
555        VectorError::Index(format!(
556            "invalid UUID stored in vector_records table: {err}"
557        ))
558    })?;
559    let record = VectorRecord {
560        id,
561        collection: row.collection,
562        vector: Vec::new(),
563        metadata: parse_metadata(&row.metadata)?,
564        text: row.text,
565        created_at: chrono::DateTime::parse_from_rfc3339(&row.created_at)
566            .map_err(|e| VectorError::Index(format!("invalid timestamp in DB: {e}")))?
567            .with_timezone(&chrono::Utc),
568    };
569    Ok((record, row.internal_id as usize))
570}
571
572fn normalize_metadata(metadata: &serde_json::Value) -> VectorResult<String> {
573    if metadata.is_null() {
574        Ok("{}".to_string())
575    } else {
576        serde_json::to_string(metadata).map_err(Into::into)
577    }
578}
579
580fn parse_metadata(metadata: &str) -> VectorResult<serde_json::Value> {
581    if metadata.trim().is_empty() {
582        Ok(serde_json::json!({}))
583    } else {
584        Ok(serde_json::from_str(metadata)?)
585    }
586}
587
588fn distance_to_db(distance: DistanceMetric) -> &'static str {
589    match distance {
590        DistanceMetric::Cosine => "cosine",
591        DistanceMetric::Euclidean => "euclidean",
592        DistanceMetric::DotProduct => "dot_product",
593    }
594}
595
596fn distance_from_db(distance: &str) -> VectorResult<DistanceMetric> {
597    match distance.trim_matches('"') {
598        "cosine" => Ok(DistanceMetric::Cosine),
599        "euclidean" => Ok(DistanceMetric::Euclidean),
600        "dot_product" => Ok(DistanceMetric::DotProduct),
601        other => Err(VectorError::Index(format!(
602            "unsupported distance metric '{other}'"
603        ))),
604    }
605}
606
607fn index_type_to_db(index_type: IndexType) -> &'static str {
608    match index_type {
609        IndexType::HNSW => "hnsw",
610        IndexType::Flat => "flat",
611    }
612}
613
614fn index_type_from_db(index_type: &str) -> VectorResult<IndexType> {
615    match index_type.trim_matches('"') {
616        "hnsw" => Ok(IndexType::HNSW),
617        "flat" => Ok(IndexType::Flat),
618        other => Err(VectorError::Index(format!(
619            "unsupported index type '{other}'"
620        ))),
621    }
622}