use std::{collections::HashMap, path::Path};

use sqlx::{
    sqlite::{SqliteConnectOptions, SqliteJournalMode, SqlitePoolOptions, SqliteSynchronous},
    QueryBuilder, Sqlite, SqlitePool,
};
use uuid::Uuid;

use crate::{
    error::{VectorError, VectorResult},
    types::{Collection, CollectionStats, DistanceMetric, IndexType, VectorRecord},
};
14
15pub struct VectorStore {
17 pool: SqlitePool,
18}
19
20pub type SqliteStore = VectorStore;
22
23impl VectorStore {
24 pub async fn new(db_path: &Path) -> VectorResult<Self> {
26 if let Some(parent) = db_path.parent() {
27 std::fs::create_dir_all(parent)?;
28 }
29
30 let options = SqliteConnectOptions::new()
31 .filename(db_path)
32 .create_if_missing(true)
33 .foreign_keys(true);
34
35 let pool = SqlitePoolOptions::new()
36 .max_connections(8)
37 .connect_with(options)
38 .await?;
39
40 sqlx::query("PRAGMA journal_mode = WAL")
41 .execute(&pool)
42 .await?;
43 sqlx::query("PRAGMA synchronous = NORMAL")
44 .execute(&pool)
45 .await?;
46 sqlx::query("PRAGMA temp_store = MEMORY")
47 .execute(&pool)
48 .await?;
49
50 sqlx::migrate!()
51 .run(&pool)
52 .await
53 .map_err(|err| VectorError::Index(format!("failed to run SQLite migrations: {err}")))?;
54
55 Ok(VectorStore { pool })
56 }
57
58 pub async fn open(path: &Path) -> VectorResult<Self> {
60 Self::new(path).await
61 }
62
63 pub fn pool(&self) -> &SqlitePool {
65 &self.pool
66 }
67
68 pub async fn create_collection(&self, workspace_id: &str, col: &Collection) -> VectorResult<()> {
70 sqlx::query(
71 r#"INSERT INTO collections
72 (workspace_id, name, dimensions, distance, index_type, ef_construction, m_connections,
73 created_at, vector_count, metadata)
74 VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)"#,
75 )
76 .bind(workspace_id)
77 .bind(&col.name)
78 .bind(col.dimensions as i64)
79 .bind(distance_to_db(col.distance))
80 .bind(index_type_to_db(col.index_type))
81 .bind(col.ef_construction as i64)
82 .bind(col.m_connections as i64)
83 .bind(col.created_at.to_rfc3339())
84 .bind(col.vector_count as i64)
85 .bind(normalize_metadata(&col.metadata)?)
86 .execute(&self.pool)
87 .await?;
88 Ok(())
89 }
90
91 pub async fn save_collection(&self, workspace_id: &str, col: &Collection) -> VectorResult<()> {
93 self.create_collection(workspace_id, col).await
94 }
95
96 pub async fn get_collection(&self, workspace_id: &str, name: &str) -> VectorResult<Collection> {
98 let row = sqlx::query_as::<_, CollectionRow>(
99 "SELECT workspace_id, name, dimensions, distance, index_type, ef_construction, m_connections, \
100 created_at, vector_count, metadata FROM collections WHERE workspace_id = ? AND name = ?",
101 )
102 .bind(workspace_id)
103 .bind(name)
104 .fetch_optional(&self.pool)
105 .await?;
106
107 match row {
108 Some(row) => collection_from_row(row),
109 None => Err(VectorError::NotFound {
110 entity: "collection".into(),
111 id: name.to_string(),
112 }),
113 }
114 }
115
116 pub async fn delete_collection(&self, workspace_id: &str, name: &str) -> VectorResult<()> {
118 let mut tx = self.pool.begin().await?;
119 sqlx::query("DELETE FROM vector_records WHERE workspace_id = ? AND collection = ?")
120 .bind(workspace_id)
121 .bind(name)
122 .execute(&mut *tx)
123 .await?;
124 sqlx::query("DELETE FROM collections WHERE workspace_id = ? AND name = ?")
125 .bind(workspace_id)
126 .bind(name)
127 .execute(&mut *tx)
128 .await?;
129 tx.commit().await?;
130 Ok(())
131 }
132
133 pub async fn list_collections(&self, workspace_id: &str) -> VectorResult<Vec<Collection>> {
135 let rows = sqlx::query_as::<_, CollectionRow>(
136 "SELECT workspace_id, name, dimensions, distance, index_type, ef_construction, m_connections, \
137 created_at, vector_count, metadata FROM collections WHERE workspace_id = ? ORDER BY name",
138 )
139 .bind(workspace_id)
140 .fetch_all(&self.pool)
141 .await?;
142
143 rows.into_iter().map(collection_from_row).collect()
144 }
145
146 pub async fn list_all_collections(&self) -> VectorResult<Vec<Collection>> {
148 let rows = sqlx::query_as::<_, CollectionRow>(
149 "SELECT workspace_id, name, dimensions, distance, index_type, ef_construction, m_connections, \
150 created_at, vector_count, metadata FROM collections ORDER BY workspace_id, name",
151 )
152 .fetch_all(&self.pool)
153 .await?;
154
155 rows.into_iter().map(collection_from_row).collect()
156 }
157
158 pub async fn insert_record(
160 &self,
161 workspace_id: &str,
162 record: &VectorRecord,
163 internal_id: usize,
164 ) -> VectorResult<()> {
165 sqlx::query(
166 r#"INSERT INTO vector_records
167 (id, internal_id, workspace_id, collection, text, metadata, created_at)
168 VALUES (?, ?, ?, ?, ?, ?, ?)"#,
169 )
170 .bind(record.id.to_string())
171 .bind(internal_id as i64)
172 .bind(workspace_id)
173 .bind(&record.collection)
174 .bind(&record.text)
175 .bind(normalize_metadata(&record.metadata)?)
176 .bind(record.created_at.to_rfc3339())
177 .execute(&self.pool)
178 .await?;
179 Ok(())
180 }
181
182 pub async fn save_record(
184 &self,
185 workspace_id: &str,
186 record: &VectorRecord,
187 internal_id: usize,
188 ) -> VectorResult<()> {
189 self.insert_record(workspace_id, record, internal_id).await
190 }
191
192 pub async fn get_record(&self, workspace_id: &str, id: Uuid) -> VectorResult<(VectorRecord, usize)> {
194 let row = sqlx::query_as::<_, RecordRow>(
195 "SELECT id, internal_id, workspace_id, collection, text, metadata, created_at \
196 FROM vector_records WHERE workspace_id = ? AND id = ?",
197 )
198 .bind(workspace_id)
199 .bind(id.to_string())
200 .fetch_optional(&self.pool)
201 .await?;
202
203 match row {
204 Some(row) => record_from_row(row),
205 None => Err(VectorError::NotFound {
206 entity: "record".into(),
207 id: id.to_string(),
208 }),
209 }
210 }
211
212 pub async fn delete_record(&self, workspace_id: &str, id: Uuid) -> VectorResult<Option<usize>> {
214 let mut tx = self.pool.begin().await?;
215 let internal_id =
216 sqlx::query_scalar::<_, i64>("SELECT internal_id FROM vector_records WHERE workspace_id = ? AND id = ?")
217 .bind(workspace_id)
218 .bind(id.to_string())
219 .fetch_optional(&mut *tx)
220 .await?
221 .map(|value| value as usize);
222
223 if internal_id.is_some() {
224 sqlx::query("DELETE FROM vector_records WHERE workspace_id = ? AND id = ?")
225 .bind(workspace_id)
226 .bind(id.to_string())
227 .execute(&mut *tx)
228 .await?;
229 }
230
231 tx.commit().await?;
232 Ok(internal_id)
233 }
234
235 pub async fn batch_insert_records(
237 &self,
238 workspace_id: &str,
239 records: &[(VectorRecord, usize)],
240 ) -> VectorResult<()> {
241 let mut tx = self.pool.begin().await?;
242 for (record, internal_id) in records {
243 sqlx::query(
244 r#"INSERT INTO vector_records
245 (id, internal_id, workspace_id, collection, text, metadata, created_at)
246 VALUES (?, ?, ?, ?, ?, ?, ?)"#,
247 )
248 .bind(record.id.to_string())
249 .bind(*internal_id as i64)
250 .bind(workspace_id)
251 .bind(&record.collection)
252 .bind(&record.text)
253 .bind(normalize_metadata(&record.metadata)?)
254 .bind(record.created_at.to_rfc3339())
255 .execute(&mut *tx)
256 .await?;
257 }
258 tx.commit().await?;
259 Ok(())
260 }
261
262 pub async fn uuid_to_internal(&self, workspace_id: &str, id: Uuid) -> VectorResult<usize> {
264 let internal_id =
265 sqlx::query_scalar::<_, i64>("SELECT internal_id FROM vector_records WHERE workspace_id = ? AND id = ?")
266 .bind(workspace_id)
267 .bind(id.to_string())
268 .fetch_optional(&self.pool)
269 .await?
270 .ok_or_else(|| VectorError::NotFound {
271 entity: "record".into(),
272 id: id.to_string(),
273 })?;
274 Ok(internal_id as usize)
275 }
276
277 pub async fn internal_to_uuid(
279 &self,
280 workspace_id: &str,
281 collection: &str,
282 internal_id: usize,
283 ) -> VectorResult<Uuid> {
284 let id = sqlx::query_scalar::<_, String>(
285 "SELECT id FROM vector_records WHERE workspace_id = ? AND collection = ? AND internal_id = ?",
286 )
287 .bind(workspace_id)
288 .bind(collection)
289 .bind(internal_id as i64)
290 .fetch_optional(&self.pool)
291 .await?
292 .ok_or_else(|| VectorError::NotFound {
293 entity: "record".into(),
294 id: format!("{collection}:{internal_id}"),
295 })?;
296 Uuid::parse_str(&id)
297 .map_err(|err| VectorError::Index(format!("invalid UUID stored in SQLite: {err}")))
298 }
299
300 pub async fn bulk_internal_to_uuid(
302 &self,
303 workspace_id: &str,
304 collection: &str,
305 ids: &[usize],
306 ) -> VectorResult<Vec<(usize, VectorRecord)>> {
307 if ids.is_empty() {
308 return Ok(Vec::new());
309 }
310
311 let mut builder = QueryBuilder::<Sqlite>::new(
312 "SELECT id, internal_id, workspace_id, collection, text, metadata, created_at FROM vector_records WHERE workspace_id = ",
313 );
314 builder.push_bind(workspace_id);
315 builder.push(" AND collection = ");
316 builder.push_bind(collection);
317 builder.push(" AND internal_id IN (");
318 let mut separated = builder.separated(", ");
319 for id in ids {
320 separated.push_bind(*id as i64);
321 }
322 separated.push_unseparated(") ORDER BY internal_id ASC");
323
324 let rows = builder
325 .build_query_as::<RecordRow>()
326 .fetch_all(&self.pool)
327 .await?;
328
329 let resolved = rows
330 .into_iter()
331 .map(record_from_row)
332 .collect::<VectorResult<Vec<_>>>()?;
333
334 let mut by_id = HashMap::with_capacity(resolved.len());
335 for (record, internal_id) in resolved {
336 by_id.insert(internal_id, record);
337 }
338
339 Ok(ids
340 .iter()
341 .filter_map(|id| by_id.remove(id).map(|record| (*id, record)))
342 .collect())
343 }
344
345 pub async fn increment_vector_count(
347 &self,
348 workspace_id: &str,
349 collection: &str,
350 delta: i64,
351 ) -> VectorResult<()> {
352 sqlx::query(
353 "UPDATE collections SET vector_count = MAX(vector_count + ?, 0) WHERE workspace_id = ? AND name = ?",
354 )
355 .bind(delta)
356 .bind(workspace_id)
357 .bind(collection)
358 .execute(&self.pool)
359 .await?;
360 Ok(())
361 }
362
363 pub async fn update_collection_index_type(
365 &self,
366 workspace_id: &str,
367 collection: &str,
368 index_type: IndexType,
369 ) -> VectorResult<()> {
370 sqlx::query("UPDATE collections SET index_type = ? WHERE workspace_id = ? AND name = ?")
371 .bind(index_type_to_db(index_type))
372 .bind(workspace_id)
373 .bind(collection)
374 .execute(&self.pool)
375 .await?;
376 Ok(())
377 }
378
379 pub async fn collection_stats(&self, workspace_id: &str, name: &str) -> VectorResult<CollectionStats> {
381 let vector_count =
382 sqlx::query_scalar::<_, i64>("SELECT vector_count FROM collections WHERE workspace_id = ? AND name = ?")
383 .bind(workspace_id)
384 .bind(name)
385 .fetch_optional(&self.pool)
386 .await?
387 .ok_or_else(|| VectorError::NotFound {
388 entity: "collection".into(),
389 id: name.to_string(),
390 })?;
391
392 let record_bytes = sqlx::query_scalar::<_, i64>(
393 "SELECT COALESCE(SUM(LENGTH(id) + LENGTH(IFNULL(text, '')) + LENGTH(metadata) + LENGTH(created_at) + 8), 0) FROM vector_records WHERE workspace_id = ? AND collection = ?",
394 )
395 .bind(workspace_id)
396 .bind(name)
397 .fetch_one(&self.pool)
398 .await?;
399
400 let collection_bytes = sqlx::query_scalar::<_, i64>(
401 "SELECT LENGTH(name) + LENGTH(distance) + LENGTH(index_type) + LENGTH(created_at) + LENGTH(metadata) + 32 FROM collections WHERE workspace_id = ? AND name = ?",
402 )
403 .bind(workspace_id)
404 .bind(name)
405 .fetch_one(&self.pool)
406 .await?;
407
408 Ok(CollectionStats {
409 vector_count: vector_count as u64,
410 size_bytes: (record_bytes + collection_bytes.max(0)) as u64,
411 })
412 }
413
414 pub async fn next_internal_id(&self, workspace_id: &str, collection: &str) -> VectorResult<usize> {
416 let max_internal_id = sqlx::query_scalar::<_, Option<i64>>(
417 "SELECT MAX(internal_id) FROM vector_records WHERE workspace_id = ? AND collection = ?",
418 )
419 .bind(workspace_id)
420 .bind(collection)
421 .fetch_one(&self.pool)
422 .await?;
423 Ok(max_internal_id.map(|value| value as usize + 1).unwrap_or(0))
424 }
425
426 pub async fn list_records_for_collection(
428 &self,
429 workspace_id: &str,
430 collection: &str,
431 ) -> VectorResult<Vec<(VectorRecord, usize)>> {
432 let rows = sqlx::query_as::<_, RecordRow>(
433 "SELECT id, internal_id, workspace_id, collection, text, metadata, created_at FROM vector_records WHERE workspace_id = ? AND collection = ? ORDER BY internal_id ASC",
434 )
435 .bind(workspace_id)
436 .bind(collection)
437 .fetch_all(&self.pool)
438 .await?;
439
440 rows.into_iter().map(record_from_row).collect()
441 }
442
443 pub async fn keyword_search(
445 &self,
446 workspace_id: &str,
447 collection: &str,
448 query: &str,
449 limit: usize,
450 ) -> VectorResult<Vec<(usize, VectorRecord, f32)>> {
451 if query.trim().is_empty() || limit == 0 {
452 return Ok(Vec::new());
453 }
454
455 let rows = sqlx::query_as::<_, KeywordRow>(
456 r#"
457 SELECT vr.id, vr.internal_id, vr.workspace_id, vr.collection, vr.text, vr.metadata, vr.created_at,
458 CAST(bm25(vector_records_fts) AS REAL) AS rank
459 FROM vector_records_fts
460 JOIN vector_records AS vr ON vr.rowid = vector_records_fts.rowid
461 WHERE vr.workspace_id = ? AND vr.collection = ? AND vector_records_fts MATCH ?
462 ORDER BY rank ASC
463 LIMIT ?
464 "#,
465 )
466 .bind(workspace_id)
467 .bind(collection)
468 .bind(query)
469 .bind(limit as i64)
470 .fetch_all(&self.pool)
471 .await?;
472
473 rows.into_iter()
474 .map(|row| {
475 let rank = row.rank.unwrap_or(0.0);
476 let record_row = RecordRow {
477 id: row.id,
478 internal_id: row.internal_id,
479 workspace_id: row.workspace_id,
480 collection: row.collection,
481 text: row.text,
482 metadata: row.metadata,
483 created_at: row.created_at,
484 };
485 let (record, internal_id) = record_from_row(record_row)?;
486 Ok((internal_id, record, rank))
487 })
488 .collect()
489 }
490
491 pub async fn close(&self) {
493 self.pool.close().await;
494 }
495}
496
/// Raw `collections` row as selected by `VectorStore`'s collection queries.
///
/// Columns are decoded by name; `collection_from_row` turns this into a `Collection`.
#[derive(Debug, sqlx::FromRow)]
struct CollectionRow {
    workspace_id: String,
    name: String,
    dimensions: i64,
    // Text tag such as "cosine"; parsed by `distance_from_db`.
    distance: String,
    // Text tag such as "hnsw"; parsed by `index_type_from_db`.
    index_type: String,
    ef_construction: i64,
    m_connections: i64,
    // RFC 3339 timestamp text (written with `to_rfc3339`).
    created_at: String,
    vector_count: i64,
    // JSON text produced by `normalize_metadata`.
    metadata: String,
}
510
/// Raw `vector_records` row (metadata only, no embedding).
///
/// `record_from_row` turns this into a `(VectorRecord, usize)` pair.
#[derive(Debug, sqlx::FromRow)]
struct RecordRow {
    // UUID in text form.
    id: String,
    internal_id: i64,
    // Selected and decoded, but `record_from_row` never reads it, hence the lint allowance.
    #[allow(dead_code)]
    workspace_id: String,
    collection: String,
    text: Option<String>,
    // JSON text produced by `normalize_metadata`.
    metadata: String,
    // RFC 3339 timestamp text.
    created_at: String,
}
522
/// Row returned by `VectorStore::keyword_search`: the record columns plus the FTS5 score.
#[derive(Debug, sqlx::FromRow)]
struct KeywordRow {
    id: String,
    internal_id: i64,
    workspace_id: String,
    collection: String,
    text: Option<String>,
    metadata: String,
    created_at: String,
    // `bm25(vector_records_fts)` cast to REAL. `keyword_search` maps NULL to 0.0.
    rank: Option<f32>,
}
534
535fn collection_from_row(row: CollectionRow) -> VectorResult<Collection> {
537 Ok(Collection {
538 workspace_id: row.workspace_id,
539 name: row.name,
540 dimensions: row.dimensions as usize,
541 distance: distance_from_db(&row.distance)?,
542 index_type: index_type_from_db(&row.index_type)?,
543 ef_construction: row.ef_construction as usize,
544 m_connections: row.m_connections as usize,
545 created_at: chrono::DateTime::parse_from_rfc3339(&row.created_at)
546 .map_err(|e| VectorError::Index(format!("invalid timestamp in DB: {e}")))?
547 .with_timezone(&chrono::Utc),
548 vector_count: row.vector_count as u64,
549 metadata: parse_metadata(&row.metadata)?,
550 })
551}
552
553fn record_from_row(row: RecordRow) -> VectorResult<(VectorRecord, usize)> {
554 let id = Uuid::parse_str(&row.id).map_err(|err| {
555 VectorError::Index(format!(
556 "invalid UUID stored in vector_records table: {err}"
557 ))
558 })?;
559 let record = VectorRecord {
560 id,
561 collection: row.collection,
562 vector: Vec::new(),
563 metadata: parse_metadata(&row.metadata)?,
564 text: row.text,
565 created_at: chrono::DateTime::parse_from_rfc3339(&row.created_at)
566 .map_err(|e| VectorError::Index(format!("invalid timestamp in DB: {e}")))?
567 .with_timezone(&chrono::Utc),
568 };
569 Ok((record, row.internal_id as usize))
570}
571
572fn normalize_metadata(metadata: &serde_json::Value) -> VectorResult<String> {
573 if metadata.is_null() {
574 Ok("{}".to_string())
575 } else {
576 serde_json::to_string(metadata).map_err(Into::into)
577 }
578}
579
580fn parse_metadata(metadata: &str) -> VectorResult<serde_json::Value> {
581 if metadata.trim().is_empty() {
582 Ok(serde_json::json!({}))
583 } else {
584 Ok(serde_json::from_str(metadata)?)
585 }
586}
587
588fn distance_to_db(distance: DistanceMetric) -> &'static str {
589 match distance {
590 DistanceMetric::Cosine => "cosine",
591 DistanceMetric::Euclidean => "euclidean",
592 DistanceMetric::DotProduct => "dot_product",
593 }
594}
595
596fn distance_from_db(distance: &str) -> VectorResult<DistanceMetric> {
597 match distance.trim_matches('"') {
598 "cosine" => Ok(DistanceMetric::Cosine),
599 "euclidean" => Ok(DistanceMetric::Euclidean),
600 "dot_product" => Ok(DistanceMetric::DotProduct),
601 other => Err(VectorError::Index(format!(
602 "unsupported distance metric '{other}'"
603 ))),
604 }
605}
606
607fn index_type_to_db(index_type: IndexType) -> &'static str {
608 match index_type {
609 IndexType::HNSW => "hnsw",
610 IndexType::Flat => "flat",
611 }
612}
613
614fn index_type_from_db(index_type: &str) -> VectorResult<IndexType> {
615 match index_type.trim_matches('"') {
616 "hnsw" => Ok(IndexType::HNSW),
617 "flat" => Ok(IndexType::Flat),
618 other => Err(VectorError::Index(format!(
619 "unsupported index type '{other}'"
620 ))),
621 }
622}