1use std::sync::Arc;
11
12use async_trait::async_trait;
13use pgvector::Vector;
14use serde_json::Value;
15use sqlx::{PgPool, Postgres, QueryBuilder, Row};
16use uuid::Uuid;
17
18use entelix_core::context::ExecutionContext;
19use entelix_core::error::{Error, Result};
20use entelix_memory::{Document, Namespace, VectorFilter, VectorStore};
21
22use crate::error::{PgVectorStoreError, PgVectorStoreResult};
23use crate::filter::append_where;
24use crate::migration;
25use crate::tenant::set_tenant_session;
26
/// Distance function used to compare a query embedding with stored embeddings.
///
/// Each variant corresponds to one pgvector operator (chosen in
/// `PgVectorStore::distance_op`) and to one distance-to-score conversion
/// (applied in `PgVectorStore::distance_to_score`).
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum DistanceMetric {
    /// Cosine distance, operator `<=>`. This is the default metric.
    /// Scores are reported as `1 - distance`.
    #[default]
    Cosine,
    /// Euclidean distance, operator `<->`.
    /// Scores are reported as `1 / (1 + distance)`.
    L2,
    /// Inner product, operator `<#>`.
    /// Scores are reported as the negated operator result.
    InnerProduct,
}
45
/// Kind of vector index requested for the backing table.
///
/// The builder forwards this value to `migration::bootstrap`; the variant
/// names presumably correspond to pgvector's index access methods
/// (NOTE(review): confirm against the migration module).
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum IndexKind {
    /// HNSW index. This is the default.
    #[default]
    Hnsw,
    /// IVFFlat index.
    IvfFlat,
}
61
62#[derive(Clone)]
66pub struct PgVectorStore {
67 pool: PgPool,
68 table: Arc<str>,
69 dimension: usize,
70 distance: DistanceMetric,
71}
72
73impl std::fmt::Debug for PgVectorStore {
74 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
75 f.debug_struct("PgVectorStore")
76 .field("table", &self.table)
77 .field("dimension", &self.dimension)
78 .field("distance", &self.distance)
79 .finish_non_exhaustive()
80 }
81}
82
83impl PgVectorStore {
84 pub fn builder(dimension: usize) -> PgVectorStoreBuilder {
86 PgVectorStoreBuilder::new(dimension)
87 }
88
89 fn distance_op(&self) -> &'static str {
90 match self.distance {
91 DistanceMetric::Cosine => "<=>",
92 DistanceMetric::L2 => "<->",
93 DistanceMetric::InnerProduct => "<#>",
94 }
95 }
96
97 fn distance_to_score(&self, distance: f64) -> f32 {
102 let s = match self.distance {
103 DistanceMetric::Cosine => 1.0 - distance,
104 DistanceMetric::L2 => 1.0 / (1.0 + distance),
105 DistanceMetric::InnerProduct => -distance,
108 };
109 s as f32
110 }
111}
112
113#[must_use]
115pub struct PgVectorStoreBuilder {
116 table: String,
117 dimension: usize,
118 distance: DistanceMetric,
119 index_kind: IndexKind,
120 auto_migrate: bool,
121 connection_string: Option<String>,
122 pool: Option<PgPool>,
123 max_connections: u32,
124}
125
126impl PgVectorStoreBuilder {
127 fn new(dimension: usize) -> Self {
128 Self {
129 table: "entelix_vectors".into(),
130 dimension,
131 distance: DistanceMetric::default(),
132 index_kind: IndexKind::default(),
133 auto_migrate: true,
134 connection_string: None,
135 pool: None,
136 max_connections: 10,
137 }
138 }
139
140 pub fn with_table(mut self, table: impl Into<String>) -> Self {
144 self.table = table.into();
145 self
146 }
147
148 pub const fn with_distance(mut self, distance: DistanceMetric) -> Self {
151 self.distance = distance;
152 self
153 }
154
155 pub const fn with_index_kind(mut self, kind: IndexKind) -> Self {
158 self.index_kind = kind;
159 self
160 }
161
162 pub const fn with_auto_migrate(mut self, auto: bool) -> Self {
169 self.auto_migrate = auto;
170 self
171 }
172
173 pub fn with_connection_string(mut self, url: impl Into<String>) -> Self {
176 self.connection_string = Some(url.into());
177 self
178 }
179
180 pub fn with_pool(mut self, pool: PgPool) -> Self {
183 self.pool = Some(pool);
184 self
185 }
186
187 pub const fn with_max_connections(mut self, max: u32) -> Self {
191 self.max_connections = max;
192 self
193 }
194
195 pub async fn build(self) -> PgVectorStoreResult<PgVectorStore> {
199 let pool = match (self.pool, self.connection_string) {
200 (Some(p), None) => p,
201 (None, Some(url)) => {
202 sqlx::postgres::PgPoolOptions::new()
203 .max_connections(self.max_connections)
204 .connect(&url)
205 .await?
206 }
207 (None, None) => {
208 return Err(PgVectorStoreError::Config(
209 "either with_pool or with_connection_string is required".into(),
210 ));
211 }
212 (Some(_), Some(_)) => {
213 return Err(PgVectorStoreError::Config(
214 "with_pool and with_connection_string are mutually exclusive".into(),
215 ));
216 }
217 };
218
219 if self.auto_migrate {
220 migration::bootstrap(
221 &pool,
222 &self.table,
223 self.dimension,
224 self.distance,
225 self.index_kind,
226 )
227 .await?;
228 }
229
230 Ok(PgVectorStore {
231 pool,
232 table: self.table.into(),
233 dimension: self.dimension,
234 distance: self.distance,
235 })
236 }
237}
238
239#[async_trait]
240impl VectorStore for PgVectorStore {
241 fn dimension(&self) -> usize {
242 self.dimension
243 }
244
245 async fn add(
246 &self,
247 ctx: &ExecutionContext,
248 ns: &Namespace,
249 document: Document,
250 vector: Vec<f32>,
251 ) -> Result<()> {
252 if ctx.is_cancelled() {
253 return Err(Error::Cancelled);
254 }
255 if vector.len() != self.dimension {
256 return Err(Error::invalid_request(format!(
257 "PgVectorStore: vector dimension {} does not match \
258 index dimension {}",
259 vector.len(),
260 self.dimension
261 )));
262 }
263 let ns_key = ns.render();
264 let doc_id = document
265 .doc_id
266 .clone()
267 .unwrap_or_else(|| Uuid::new_v4().to_string());
268 let metadata = if document.metadata.is_null() {
269 Value::Object(serde_json::Map::new())
270 } else {
271 document.metadata
272 };
273 let stmt = format!(
274 "INSERT INTO {table} (tenant_id, namespace_key, doc_id, content, metadata, embedding) \
275 VALUES ($1, $2, $3, $4, $5, $6) \
276 ON CONFLICT (namespace_key, doc_id) DO UPDATE SET \
277 content = EXCLUDED.content, \
278 metadata = EXCLUDED.metadata, \
279 embedding = EXCLUDED.embedding",
280 table = self.table
281 );
282 let mut tx = self
283 .pool
284 .begin()
285 .await
286 .map_err(|e| Error::from(PgVectorStoreError::from(e)))?;
287 set_tenant_session(&mut *tx, ns.tenant_id()).await?;
288 sqlx::query(&stmt)
289 .bind(ns.tenant_id().as_str())
290 .bind(ns_key)
291 .bind(doc_id)
292 .bind(document.content)
293 .bind(sqlx::types::Json(metadata))
294 .bind(Vector::from(vector))
295 .execute(&mut *tx)
296 .await
297 .map_err(|e| Error::from(PgVectorStoreError::from(e)))?;
298 tx.commit()
299 .await
300 .map_err(|e| Error::from(PgVectorStoreError::from(e)))?;
301 Ok(())
302 }
303
304 async fn add_batch(
305 &self,
306 ctx: &ExecutionContext,
307 ns: &Namespace,
308 items: Vec<(Document, Vec<f32>)>,
309 ) -> Result<()> {
310 if ctx.is_cancelled() {
311 return Err(Error::Cancelled);
312 }
313 if items.is_empty() {
314 return Ok(());
315 }
316 let ns_key = ns.render();
317 for (_, vector) in &items {
318 if vector.len() != self.dimension {
319 return Err(Error::invalid_request(format!(
320 "PgVectorStore: vector dimension {} does not match \
321 index dimension {}",
322 vector.len(),
323 self.dimension
324 )));
325 }
326 }
327 let tenant_id = ns.tenant_id().as_str().to_owned();
329 let mut qb: QueryBuilder<'_, Postgres> = QueryBuilder::new(format!(
330 "INSERT INTO {table} \
331 (tenant_id, namespace_key, doc_id, content, metadata, embedding) ",
332 table = self.table
333 ));
334 qb.push_values(items, |mut b, (mut document, vector)| {
335 let doc_id = document
336 .doc_id
337 .take()
338 .unwrap_or_else(|| Uuid::new_v4().to_string());
339 let metadata = if document.metadata.is_null() {
340 Value::Object(serde_json::Map::new())
341 } else {
342 document.metadata
343 };
344 b.push_bind(tenant_id.clone())
345 .push_bind(ns_key.clone())
346 .push_bind(doc_id)
347 .push_bind(document.content)
348 .push_bind(sqlx::types::Json(metadata))
349 .push_bind(Vector::from(vector));
350 });
351 qb.push(
352 " ON CONFLICT (namespace_key, doc_id) DO UPDATE SET \
353 content = EXCLUDED.content, \
354 metadata = EXCLUDED.metadata, \
355 embedding = EXCLUDED.embedding",
356 );
357 let mut tx = self
358 .pool
359 .begin()
360 .await
361 .map_err(|e| Error::from(PgVectorStoreError::from(e)))?;
362 set_tenant_session(&mut *tx, ns.tenant_id()).await?;
363 qb.build()
364 .execute(&mut *tx)
365 .await
366 .map_err(|e| Error::from(PgVectorStoreError::from(e)))?;
367 tx.commit()
368 .await
369 .map_err(|e| Error::from(PgVectorStoreError::from(e)))?;
370 Ok(())
371 }
372
373 async fn search(
374 &self,
375 ctx: &ExecutionContext,
376 ns: &Namespace,
377 query_vector: &[f32],
378 top_k: usize,
379 ) -> Result<Vec<Document>> {
380 self.search_filtered(ctx, ns, query_vector, top_k, &VectorFilter::All)
381 .await
382 }
383
384 async fn search_filtered(
385 &self,
386 ctx: &ExecutionContext,
387 ns: &Namespace,
388 query_vector: &[f32],
389 top_k: usize,
390 filter: &VectorFilter,
391 ) -> Result<Vec<Document>> {
392 if ctx.is_cancelled() {
393 return Err(Error::Cancelled);
394 }
395 if query_vector.len() != self.dimension {
396 return Err(Error::invalid_request(format!(
397 "PgVectorStore: query dimension {} does not match \
398 index dimension {}",
399 query_vector.len(),
400 self.dimension
401 )));
402 }
403 let ns_key = ns.render();
404
405 let mut qb: QueryBuilder<'_, Postgres> = QueryBuilder::new(format!(
410 "SELECT doc_id, content, metadata, embedding {op} ",
411 op = self.distance_op(),
412 ));
413 qb.push_bind(Vector::from(query_vector.to_vec()));
414 qb.push(format!(" AS distance FROM {table}", table = self.table));
415 append_where(&mut qb, &ns_key, Some(filter)).map_err(Error::from)?;
416 qb.push(" ORDER BY distance LIMIT ");
417 qb.push_bind(top_k as i64);
418
419 let mut tx = self
420 .pool
421 .begin()
422 .await
423 .map_err(|e| Error::from(PgVectorStoreError::from(e)))?;
424 set_tenant_session(&mut *tx, ns.tenant_id()).await?;
425 let rows = qb
426 .build()
427 .fetch_all(&mut *tx)
428 .await
429 .map_err(|e| Error::from(PgVectorStoreError::from(e)))?;
430 tx.commit()
431 .await
432 .map_err(|e| Error::from(PgVectorStoreError::from(e)))?;
433 rows.into_iter()
434 .map(|row| self.row_to_document(&row, true))
435 .collect()
436 }
437
438 async fn delete(&self, ctx: &ExecutionContext, ns: &Namespace, doc_id: &str) -> Result<()> {
439 if ctx.is_cancelled() {
440 return Err(Error::Cancelled);
441 }
442 let stmt = format!(
443 "DELETE FROM {table} WHERE namespace_key = $1 AND doc_id = $2",
444 table = self.table
445 );
446 let mut tx = self
447 .pool
448 .begin()
449 .await
450 .map_err(|e| Error::from(PgVectorStoreError::from(e)))?;
451 set_tenant_session(&mut *tx, ns.tenant_id()).await?;
452 sqlx::query(&stmt)
453 .bind(ns.render())
454 .bind(doc_id.to_owned())
455 .execute(&mut *tx)
456 .await
457 .map_err(|e| Error::from(PgVectorStoreError::from(e)))?;
458 tx.commit()
459 .await
460 .map_err(|e| Error::from(PgVectorStoreError::from(e)))?;
461 Ok(())
462 }
463
464 async fn update(
465 &self,
466 ctx: &ExecutionContext,
467 ns: &Namespace,
468 doc_id: &str,
469 document: Document,
470 vector: Vec<f32>,
471 ) -> Result<()> {
472 let stored = Document {
476 doc_id: Some(doc_id.to_owned()),
477 ..document
478 };
479 self.add(ctx, ns, stored, vector).await
480 }
481
482 async fn count(
483 &self,
484 ctx: &ExecutionContext,
485 ns: &Namespace,
486 filter: Option<&VectorFilter>,
487 ) -> Result<usize> {
488 if ctx.is_cancelled() {
489 return Err(Error::Cancelled);
490 }
491 let ns_key = ns.render();
492 let mut qb: QueryBuilder<'_, Postgres> =
493 QueryBuilder::new(format!("SELECT COUNT(*) FROM {table}", table = self.table));
494 append_where(&mut qb, &ns_key, filter).map_err(Error::from)?;
495 let mut tx = self
496 .pool
497 .begin()
498 .await
499 .map_err(|e| Error::from(PgVectorStoreError::from(e)))?;
500 set_tenant_session(&mut *tx, ns.tenant_id()).await?;
501 let row = qb
502 .build()
503 .fetch_one(&mut *tx)
504 .await
505 .map_err(|e| Error::from(PgVectorStoreError::from(e)))?;
506 tx.commit()
507 .await
508 .map_err(|e| Error::from(PgVectorStoreError::from(e)))?;
509 let count: i64 = row.try_get::<i64, _>(0).map_err(|e| {
510 Error::from(PgVectorStoreError::Malformed(format!(
511 "COUNT(*) row missing expected column: {e}"
512 )))
513 })?;
514 Ok(count.max(0) as usize)
515 }
516
517 async fn list(
518 &self,
519 ctx: &ExecutionContext,
520 ns: &Namespace,
521 filter: Option<&VectorFilter>,
522 limit: usize,
523 offset: usize,
524 ) -> Result<Vec<Document>> {
525 if ctx.is_cancelled() {
526 return Err(Error::Cancelled);
527 }
528 let ns_key = ns.render();
529 let mut qb: QueryBuilder<'_, Postgres> = QueryBuilder::new(format!(
530 "SELECT doc_id, content, metadata FROM {table}",
531 table = self.table
532 ));
533 append_where(&mut qb, &ns_key, filter).map_err(Error::from)?;
534 qb.push(" ORDER BY doc_id");
537 qb.push(" LIMIT ");
538 qb.push_bind(limit as i64);
539 qb.push(" OFFSET ");
540 qb.push_bind(offset as i64);
541 let mut tx = self
542 .pool
543 .begin()
544 .await
545 .map_err(|e| Error::from(PgVectorStoreError::from(e)))?;
546 set_tenant_session(&mut *tx, ns.tenant_id()).await?;
547 let rows = qb
548 .build()
549 .fetch_all(&mut *tx)
550 .await
551 .map_err(|e| Error::from(PgVectorStoreError::from(e)))?;
552 tx.commit()
553 .await
554 .map_err(|e| Error::from(PgVectorStoreError::from(e)))?;
555 rows.into_iter()
556 .map(|row| self.row_to_document(&row, false))
557 .collect()
558 }
559}
560
561impl PgVectorStore {
562 fn row_to_document(
563 &self,
564 row: &sqlx::postgres::PgRow,
565 with_distance: bool,
566 ) -> Result<Document> {
567 let doc_id: String = row.try_get("doc_id").map_err(|e| {
568 Error::from(PgVectorStoreError::Malformed(format!(
569 "row missing doc_id: {e}"
570 )))
571 })?;
572 let content: String = row.try_get("content").map_err(|e| {
573 Error::from(PgVectorStoreError::Malformed(format!(
574 "row missing content: {e}"
575 )))
576 })?;
577 let metadata: sqlx::types::Json<Value> = row.try_get("metadata").map_err(|e| {
578 Error::from(PgVectorStoreError::Malformed(format!(
579 "row missing metadata: {e}"
580 )))
581 })?;
582 let score = if with_distance {
583 let distance: f64 = row.try_get("distance").map_err(|e| {
584 Error::from(PgVectorStoreError::Malformed(format!(
585 "row missing distance: {e}"
586 )))
587 })?;
588 Some(self.distance_to_score(distance))
589 } else {
590 None
591 };
592 Ok(Document {
593 doc_id: Some(doc_id),
594 content,
595 metadata: metadata.0,
596 score,
597 })
598 }
599}