1use std::{collections::HashMap, path::Path};
3
4use sqlx::{
5 sqlite::{SqliteConnectOptions, SqlitePoolOptions},
6 QueryBuilder, Sqlite, SqlitePool,
7};
8use uuid::Uuid;
9
10use crate::{
11 error::{VectorError, VectorResult},
12 types::{Collection, CollectionStats, DistanceMetric, IndexType, VectorRecord},
13};
14
/// SQLite-backed store for collection definitions and vector record metadata.
///
/// Only record metadata (id, internal id, text, JSON metadata, timestamp) is written
/// by this type; records read back from it carry an empty `vector`.
pub struct VectorStore {
    /// Connection pool shared by every query issued through this store.
    pool: SqlitePool,
}

/// Alternate name for [`VectorStore`].
pub type SqliteStore = VectorStore;
22
23impl VectorStore {
24 pub async fn new(db_path: &Path) -> VectorResult<Self> {
26 if let Some(parent) = db_path.parent() {
27 std::fs::create_dir_all(parent)?;
28 }
29
30 let options = SqliteConnectOptions::new()
31 .filename(db_path)
32 .create_if_missing(true)
33 .foreign_keys(true);
34
35 let pool = SqlitePoolOptions::new()
36 .max_connections(8)
37 .connect_with(options)
38 .await?;
39
40 sqlx::query("PRAGMA journal_mode = WAL")
41 .execute(&pool)
42 .await?;
43 sqlx::query("PRAGMA synchronous = NORMAL")
44 .execute(&pool)
45 .await?;
46 sqlx::query("PRAGMA temp_store = MEMORY")
47 .execute(&pool)
48 .await?;
49
50 sqlx::migrate!()
51 .run(&pool)
52 .await
53 .map_err(|err| VectorError::Index(format!("failed to run SQLite migrations: {err}")))?;
54
55 Ok(VectorStore { pool })
56 }
57
58 pub async fn open(path: &Path) -> VectorResult<Self> {
60 Self::new(path).await
61 }
62
    /// Borrows the underlying SQLite connection pool.
    pub fn pool(&self) -> &SqlitePool {
        &self.pool
    }
67
68 pub async fn create_collection(
70 &self,
71 workspace_id: &str,
72 col: &Collection,
73 ) -> VectorResult<()> {
74 sqlx::query(
75 r#"INSERT INTO collections
76 (workspace_id, name, dimensions, distance, index_type, ef_construction, m_connections,
77 created_at, vector_count, metadata)
78 VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)"#,
79 )
80 .bind(workspace_id)
81 .bind(&col.name)
82 .bind(col.dimensions as i64)
83 .bind(distance_to_db(col.distance))
84 .bind(index_type_to_db(col.index_type))
85 .bind(col.ef_construction as i64)
86 .bind(col.m_connections as i64)
87 .bind(col.created_at.to_rfc3339())
88 .bind(col.vector_count as i64)
89 .bind(normalize_metadata(&col.metadata)?)
90 .execute(&self.pool)
91 .await?;
92 Ok(())
93 }
94
    /// Persists `col` under `workspace_id` by delegating to [`VectorStore::create_collection`].
    ///
    /// NOTE(review): this is a plain INSERT, not an upsert. Re-saving an existing
    /// collection only succeeds if the schema has no unique key on
    /// (workspace_id, name) — confirm against the migrations.
    pub async fn save_collection(&self, workspace_id: &str, col: &Collection) -> VectorResult<()> {
        self.create_collection(workspace_id, col).await
    }
99
100 pub async fn get_collection(&self, workspace_id: &str, name: &str) -> VectorResult<Collection> {
102 let row = sqlx::query_as::<_, CollectionRow>(
103 "SELECT workspace_id, name, dimensions, distance, index_type, ef_construction, m_connections, \
104 created_at, vector_count, metadata FROM collections WHERE workspace_id = ? AND name = ?",
105 )
106 .bind(workspace_id)
107 .bind(name)
108 .fetch_optional(&self.pool)
109 .await?;
110
111 match row {
112 Some(row) => collection_from_row(row),
113 None => Err(VectorError::NotFound {
114 entity: "collection".into(),
115 id: name.to_string(),
116 }),
117 }
118 }
119
120 pub async fn delete_collection(&self, workspace_id: &str, name: &str) -> VectorResult<()> {
122 let mut tx = self.pool.begin().await?;
123 sqlx::query("DELETE FROM vector_records WHERE workspace_id = ? AND collection = ?")
124 .bind(workspace_id)
125 .bind(name)
126 .execute(&mut *tx)
127 .await?;
128 sqlx::query("DELETE FROM collections WHERE workspace_id = ? AND name = ?")
129 .bind(workspace_id)
130 .bind(name)
131 .execute(&mut *tx)
132 .await?;
133 tx.commit().await?;
134 Ok(())
135 }
136
137 pub async fn list_collections(&self, workspace_id: &str) -> VectorResult<Vec<Collection>> {
139 let rows = sqlx::query_as::<_, CollectionRow>(
140 "SELECT workspace_id, name, dimensions, distance, index_type, ef_construction, m_connections, \
141 created_at, vector_count, metadata FROM collections WHERE workspace_id = ? ORDER BY name",
142 )
143 .bind(workspace_id)
144 .fetch_all(&self.pool)
145 .await?;
146
147 rows.into_iter().map(collection_from_row).collect()
148 }
149
150 pub async fn list_all_collections(&self) -> VectorResult<Vec<Collection>> {
152 let rows = sqlx::query_as::<_, CollectionRow>(
153 "SELECT workspace_id, name, dimensions, distance, index_type, ef_construction, m_connections, \
154 created_at, vector_count, metadata FROM collections ORDER BY workspace_id, name",
155 )
156 .fetch_all(&self.pool)
157 .await?;
158
159 rows.into_iter().map(collection_from_row).collect()
160 }
161
162 pub async fn insert_record(
164 &self,
165 workspace_id: &str,
166 record: &VectorRecord,
167 internal_id: usize,
168 ) -> VectorResult<()> {
169 sqlx::query(
170 r#"INSERT INTO vector_records
171 (id, internal_id, workspace_id, collection, text, metadata, created_at)
172 VALUES (?, ?, ?, ?, ?, ?, ?)"#,
173 )
174 .bind(record.id.to_string())
175 .bind(internal_id as i64)
176 .bind(workspace_id)
177 .bind(&record.collection)
178 .bind(&record.text)
179 .bind(normalize_metadata(&record.metadata)?)
180 .bind(record.created_at.to_rfc3339())
181 .execute(&self.pool)
182 .await?;
183 Ok(())
184 }
185
    /// Persists `record` by delegating to [`VectorStore::insert_record`].
    ///
    /// NOTE(review): this is a plain INSERT, not an upsert. Saving an id that is
    /// already stored will fail if `vector_records.id` is unique — confirm against
    /// the schema.
    pub async fn save_record(
        &self,
        workspace_id: &str,
        record: &VectorRecord,
        internal_id: usize,
    ) -> VectorResult<()> {
        self.insert_record(workspace_id, record, internal_id).await
    }
195
196 pub async fn get_record(
198 &self,
199 workspace_id: &str,
200 id: Uuid,
201 ) -> VectorResult<(VectorRecord, usize)> {
202 let row = sqlx::query_as::<_, RecordRow>(
203 "SELECT id, internal_id, workspace_id, collection, text, metadata, created_at \
204 FROM vector_records WHERE workspace_id = ? AND id = ?",
205 )
206 .bind(workspace_id)
207 .bind(id.to_string())
208 .fetch_optional(&self.pool)
209 .await?;
210
211 match row {
212 Some(row) => record_from_row(row),
213 None => Err(VectorError::NotFound {
214 entity: "record".into(),
215 id: id.to_string(),
216 }),
217 }
218 }
219
220 pub async fn delete_record(&self, workspace_id: &str, id: Uuid) -> VectorResult<Option<usize>> {
222 let mut tx = self.pool.begin().await?;
223 let internal_id = sqlx::query_scalar::<_, i64>(
224 "SELECT internal_id FROM vector_records WHERE workspace_id = ? AND id = ?",
225 )
226 .bind(workspace_id)
227 .bind(id.to_string())
228 .fetch_optional(&mut *tx)
229 .await?
230 .map(|value| value as usize);
231
232 if internal_id.is_some() {
233 sqlx::query("DELETE FROM vector_records WHERE workspace_id = ? AND id = ?")
234 .bind(workspace_id)
235 .bind(id.to_string())
236 .execute(&mut *tx)
237 .await?;
238 }
239
240 tx.commit().await?;
241 Ok(internal_id)
242 }
243
244 pub async fn batch_insert_records(
246 &self,
247 workspace_id: &str,
248 records: &[(VectorRecord, usize)],
249 ) -> VectorResult<()> {
250 let mut tx = self.pool.begin().await?;
251 for (record, internal_id) in records {
252 sqlx::query(
253 r#"INSERT INTO vector_records
254 (id, internal_id, workspace_id, collection, text, metadata, created_at)
255 VALUES (?, ?, ?, ?, ?, ?, ?)"#,
256 )
257 .bind(record.id.to_string())
258 .bind(*internal_id as i64)
259 .bind(workspace_id)
260 .bind(&record.collection)
261 .bind(&record.text)
262 .bind(normalize_metadata(&record.metadata)?)
263 .bind(record.created_at.to_rfc3339())
264 .execute(&mut *tx)
265 .await?;
266 }
267 tx.commit().await?;
268 Ok(())
269 }
270
271 pub async fn uuid_to_internal(&self, workspace_id: &str, id: Uuid) -> VectorResult<usize> {
273 let internal_id = sqlx::query_scalar::<_, i64>(
274 "SELECT internal_id FROM vector_records WHERE workspace_id = ? AND id = ?",
275 )
276 .bind(workspace_id)
277 .bind(id.to_string())
278 .fetch_optional(&self.pool)
279 .await?
280 .ok_or_else(|| VectorError::NotFound {
281 entity: "record".into(),
282 id: id.to_string(),
283 })?;
284 Ok(internal_id as usize)
285 }
286
287 pub async fn internal_to_uuid(
289 &self,
290 workspace_id: &str,
291 collection: &str,
292 internal_id: usize,
293 ) -> VectorResult<Uuid> {
294 let id = sqlx::query_scalar::<_, String>(
295 "SELECT id FROM vector_records WHERE workspace_id = ? AND collection = ? AND internal_id = ?",
296 )
297 .bind(workspace_id)
298 .bind(collection)
299 .bind(internal_id as i64)
300 .fetch_optional(&self.pool)
301 .await?
302 .ok_or_else(|| VectorError::NotFound {
303 entity: "record".into(),
304 id: format!("{collection}:{internal_id}"),
305 })?;
306 Uuid::parse_str(&id)
307 .map_err(|err| VectorError::Index(format!("invalid UUID stored in SQLite: {err}")))
308 }
309
310 pub async fn bulk_internal_to_uuid(
312 &self,
313 workspace_id: &str,
314 collection: &str,
315 ids: &[usize],
316 ) -> VectorResult<Vec<(usize, VectorRecord)>> {
317 if ids.is_empty() {
318 return Ok(Vec::new());
319 }
320
321 let mut builder = QueryBuilder::<Sqlite>::new(
322 "SELECT id, internal_id, workspace_id, collection, text, metadata, created_at FROM vector_records WHERE workspace_id = ",
323 );
324 builder.push_bind(workspace_id);
325 builder.push(" AND collection = ");
326 builder.push_bind(collection);
327 builder.push(" AND internal_id IN (");
328 let mut separated = builder.separated(", ");
329 for id in ids {
330 separated.push_bind(*id as i64);
331 }
332 separated.push_unseparated(") ORDER BY internal_id ASC");
333
334 let rows = builder
335 .build_query_as::<RecordRow>()
336 .fetch_all(&self.pool)
337 .await?;
338
339 let resolved = rows
340 .into_iter()
341 .map(record_from_row)
342 .collect::<VectorResult<Vec<_>>>()?;
343
344 let mut by_id = HashMap::with_capacity(resolved.len());
345 for (record, internal_id) in resolved {
346 by_id.insert(internal_id, record);
347 }
348
349 Ok(ids
350 .iter()
351 .filter_map(|id| by_id.remove(id).map(|record| (*id, record)))
352 .collect())
353 }
354
355 pub async fn increment_vector_count(
357 &self,
358 workspace_id: &str,
359 collection: &str,
360 delta: i64,
361 ) -> VectorResult<()> {
362 sqlx::query(
363 "UPDATE collections SET vector_count = MAX(vector_count + ?, 0) WHERE workspace_id = ? AND name = ?",
364 )
365 .bind(delta)
366 .bind(workspace_id)
367 .bind(collection)
368 .execute(&self.pool)
369 .await?;
370 Ok(())
371 }
372
373 pub async fn update_collection_index_type(
375 &self,
376 workspace_id: &str,
377 collection: &str,
378 index_type: IndexType,
379 ) -> VectorResult<()> {
380 sqlx::query("UPDATE collections SET index_type = ? WHERE workspace_id = ? AND name = ?")
381 .bind(index_type_to_db(index_type))
382 .bind(workspace_id)
383 .bind(collection)
384 .execute(&self.pool)
385 .await?;
386 Ok(())
387 }
388
389 pub async fn collection_stats(
391 &self,
392 workspace_id: &str,
393 name: &str,
394 ) -> VectorResult<CollectionStats> {
395 let vector_count = sqlx::query_scalar::<_, i64>(
396 "SELECT vector_count FROM collections WHERE workspace_id = ? AND name = ?",
397 )
398 .bind(workspace_id)
399 .bind(name)
400 .fetch_optional(&self.pool)
401 .await?
402 .ok_or_else(|| VectorError::NotFound {
403 entity: "collection".into(),
404 id: name.to_string(),
405 })?;
406
407 let record_bytes = sqlx::query_scalar::<_, i64>(
408 "SELECT COALESCE(SUM(LENGTH(id) + LENGTH(IFNULL(text, '')) + LENGTH(metadata) + LENGTH(created_at) + 8), 0) FROM vector_records WHERE workspace_id = ? AND collection = ?",
409 )
410 .bind(workspace_id)
411 .bind(name)
412 .fetch_one(&self.pool)
413 .await?;
414
415 let collection_bytes = sqlx::query_scalar::<_, i64>(
416 "SELECT LENGTH(name) + LENGTH(distance) + LENGTH(index_type) + LENGTH(created_at) + LENGTH(metadata) + 32 FROM collections WHERE workspace_id = ? AND name = ?",
417 )
418 .bind(workspace_id)
419 .bind(name)
420 .fetch_one(&self.pool)
421 .await?;
422
423 Ok(CollectionStats {
424 vector_count: vector_count as u64,
425 size_bytes: (record_bytes + collection_bytes.max(0)) as u64,
426 })
427 }
428
429 pub async fn next_internal_id(
431 &self,
432 workspace_id: &str,
433 collection: &str,
434 ) -> VectorResult<usize> {
435 let max_internal_id = sqlx::query_scalar::<_, Option<i64>>(
436 "SELECT MAX(internal_id) FROM vector_records WHERE workspace_id = ? AND collection = ?",
437 )
438 .bind(workspace_id)
439 .bind(collection)
440 .fetch_one(&self.pool)
441 .await?;
442 Ok(max_internal_id.map(|value| value as usize + 1).unwrap_or(0))
443 }
444
445 pub async fn list_records_for_collection(
447 &self,
448 workspace_id: &str,
449 collection: &str,
450 ) -> VectorResult<Vec<(VectorRecord, usize)>> {
451 let rows = sqlx::query_as::<_, RecordRow>(
452 "SELECT id, internal_id, workspace_id, collection, text, metadata, created_at FROM vector_records WHERE workspace_id = ? AND collection = ? ORDER BY internal_id ASC",
453 )
454 .bind(workspace_id)
455 .bind(collection)
456 .fetch_all(&self.pool)
457 .await?;
458
459 rows.into_iter().map(record_from_row).collect()
460 }
461
462 pub async fn keyword_search(
464 &self,
465 workspace_id: &str,
466 collection: &str,
467 query: &str,
468 limit: usize,
469 ) -> VectorResult<Vec<(usize, VectorRecord, f32)>> {
470 if query.trim().is_empty() || limit == 0 {
471 return Ok(Vec::new());
472 }
473
474 let rows = sqlx::query_as::<_, KeywordRow>(
475 r#"
476 SELECT vr.id, vr.internal_id, vr.workspace_id, vr.collection, vr.text, vr.metadata, vr.created_at,
477 CAST(bm25(vector_records_fts) AS REAL) AS rank
478 FROM vector_records_fts
479 JOIN vector_records AS vr ON vr.rowid = vector_records_fts.rowid
480 WHERE vr.workspace_id = ? AND vr.collection = ? AND vector_records_fts MATCH ?
481 ORDER BY rank ASC
482 LIMIT ?
483 "#,
484 )
485 .bind(workspace_id)
486 .bind(collection)
487 .bind(query)
488 .bind(limit as i64)
489 .fetch_all(&self.pool)
490 .await?;
491
492 rows.into_iter()
493 .map(|row| {
494 let rank = row.rank.unwrap_or(0.0);
495 let record_row = RecordRow {
496 id: row.id,
497 internal_id: row.internal_id,
498 workspace_id: row.workspace_id,
499 collection: row.collection,
500 text: row.text,
501 metadata: row.metadata,
502 created_at: row.created_at,
503 };
504 let (record, internal_id) = record_from_row(record_row)?;
505 Ok((internal_id, record, rank))
506 })
507 .collect()
508 }
509
    /// Closes the connection pool, waiting for `SqlitePool::close` to complete.
    pub async fn close(&self) {
        self.pool.close().await;
    }
514}
515
/// Raw `collections` row as selected by the collection queries.
///
/// Converted to [`Collection`] by `collection_from_row`.
#[derive(Debug, sqlx::FromRow)]
struct CollectionRow {
    workspace_id: String,
    name: String,
    dimensions: i64,
    // Stored label such as "cosine"; parsed by `distance_from_db`.
    distance: String,
    // Stored label such as "hnsw"; parsed by `index_type_from_db`.
    index_type: String,
    ef_construction: i64,
    m_connections: i64,
    // RFC 3339 timestamp text.
    created_at: String,
    vector_count: i64,
    // JSON text; an empty string is read back as `{}`.
    metadata: String,
}
529
/// Raw `vector_records` row as selected by the record queries.
///
/// Converted to `(VectorRecord, usize)` by `record_from_row`.
#[derive(Debug, sqlx::FromRow)]
struct RecordRow {
    // UUID stored as text.
    id: String,
    internal_id: i64,
    // Selected so every record query shares one column list, but `record_from_row`
    // never reads it; hence the dead-code allowance.
    #[allow(dead_code)]
    workspace_id: String,
    collection: String,
    text: Option<String>,
    // JSON text; an empty string is read back as `{}`.
    metadata: String,
    // RFC 3339 timestamp text.
    created_at: String,
}
541
/// Row returned by `keyword_search`: the record columns plus the FTS5 score.
#[derive(Debug, sqlx::FromRow)]
struct KeywordRow {
    id: String,
    internal_id: i64,
    workspace_id: String,
    collection: String,
    text: Option<String>,
    metadata: String,
    created_at: String,
    // `bm25(...)` cast to REAL. Optional so a NULL score decodes; the caller maps
    // `None` to `0.0`.
    rank: Option<f32>,
}
553
554fn collection_from_row(row: CollectionRow) -> VectorResult<Collection> {
556 Ok(Collection {
557 workspace_id: row.workspace_id,
558 name: row.name,
559 dimensions: row.dimensions as usize,
560 distance: distance_from_db(&row.distance)?,
561 index_type: index_type_from_db(&row.index_type)?,
562 ef_construction: row.ef_construction as usize,
563 m_connections: row.m_connections as usize,
564 created_at: chrono::DateTime::parse_from_rfc3339(&row.created_at)
565 .map_err(|e| VectorError::Index(format!("invalid timestamp in DB: {e}")))?
566 .with_timezone(&chrono::Utc),
567 vector_count: row.vector_count as u64,
568 metadata: parse_metadata(&row.metadata)?,
569 })
570}
571
572fn record_from_row(row: RecordRow) -> VectorResult<(VectorRecord, usize)> {
573 let id = Uuid::parse_str(&row.id).map_err(|err| {
574 VectorError::Index(format!(
575 "invalid UUID stored in vector_records table: {err}"
576 ))
577 })?;
578 let record = VectorRecord {
579 id,
580 collection: row.collection,
581 vector: Vec::new(),
582 metadata: parse_metadata(&row.metadata)?,
583 text: row.text,
584 created_at: chrono::DateTime::parse_from_rfc3339(&row.created_at)
585 .map_err(|e| VectorError::Index(format!("invalid timestamp in DB: {e}")))?
586 .with_timezone(&chrono::Utc),
587 };
588 Ok((record, row.internal_id as usize))
589}
590
591fn normalize_metadata(metadata: &serde_json::Value) -> VectorResult<String> {
592 if metadata.is_null() {
593 Ok("{}".to_string())
594 } else {
595 serde_json::to_string(metadata).map_err(Into::into)
596 }
597}
598
599fn parse_metadata(metadata: &str) -> VectorResult<serde_json::Value> {
600 if metadata.trim().is_empty() {
601 Ok(serde_json::json!({}))
602 } else {
603 Ok(serde_json::from_str(metadata)?)
604 }
605}
606
607fn distance_to_db(distance: DistanceMetric) -> &'static str {
608 match distance {
609 DistanceMetric::Cosine => "cosine",
610 DistanceMetric::Euclidean => "euclidean",
611 DistanceMetric::DotProduct => "dot_product",
612 }
613}
614
615fn distance_from_db(distance: &str) -> VectorResult<DistanceMetric> {
616 match distance.trim_matches('"') {
617 "cosine" => Ok(DistanceMetric::Cosine),
618 "euclidean" => Ok(DistanceMetric::Euclidean),
619 "dot_product" => Ok(DistanceMetric::DotProduct),
620 other => Err(VectorError::Index(format!(
621 "unsupported distance metric '{other}'"
622 ))),
623 }
624}
625
626fn index_type_to_db(index_type: IndexType) -> &'static str {
627 match index_type {
628 IndexType::HNSW => "hnsw",
629 IndexType::Flat => "flat",
630 }
631}
632
633fn index_type_from_db(index_type: &str) -> VectorResult<IndexType> {
634 match index_type.trim_matches('"') {
635 "hnsw" => Ok(IndexType::HNSW),
636 "flat" => Ok(IndexType::Flat),
637 other => Err(VectorError::Index(format!(
638 "unsupported index type '{other}'"
639 ))),
640 }
641}