Skip to main content

entelix_memory_pgvector/
store.rs

//! `PgVectorStore` — concrete `VectorStore` over Postgres + pgvector.
//!
//! Single-table design: `(namespace_key, doc_id)` composite primary
//! key + `embedding VECTOR(N)` + `metadata JSONB`. The composite PK
//! doubles as the namespace anchor index, so every read / write /
//! count / list rides a B-tree probe before the vector / GIN index
//! ever sees a row. Cross-tenant data leaks are structurally
//! impossible — the namespace anchor is mandatory in every query.
9
10use std::sync::Arc;
11
12use async_trait::async_trait;
13use pgvector::Vector;
14use serde_json::Value;
15use sqlx::{PgPool, Postgres, QueryBuilder, Row};
16use uuid::Uuid;
17
18use entelix_core::context::ExecutionContext;
19use entelix_core::error::{Error, Result};
20use entelix_memory::{Document, Namespace, VectorFilter, VectorStore};
21
22use crate::error::{PgVectorStoreError, PgVectorStoreResult};
23use crate::filter::append_where;
24use crate::migration;
25use crate::tenant::set_tenant_session;
26
/// Similarity metric applied when comparing embeddings.
///
/// The variants map one-to-one onto pgvector's distance operators
/// (`<=>`, `<->`, `<#>`), so anyone who already knows those operators
/// can pick the matching variant directly.
#[non_exhaustive]
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum DistanceMetric {
    /// Cosine distance — pgvector's `<=>` operator. This is the
    /// default and suits normalized embeddings
    /// (`text-embedding-3-*`, etc.).
    #[default]
    Cosine,
    /// Euclidean (L2) distance — pgvector's `<->` operator.
    L2,
    /// Inner product — pgvector's `<#>` operator. pgvector reports
    /// the *negated* inner product (so that smaller means "closer");
    /// the store flips the sign back on read, which keeps
    /// caller-facing scores in "higher = better" form.
    InnerProduct,
}
45
/// Which approximate-nearest-neighbour index backs the table.
///
/// HNSW is the production default; IVFFlat is the choice when index
/// build time matters more than query latency.
#[non_exhaustive]
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum IndexKind {
    /// pgvector's HNSW (Hierarchical Navigable Small World) index.
    /// Offers the best recall / throughput trade-off for ≤10M
    /// vectors per namespace.
    #[default]
    Hnsw,
    /// IVF-Flat: quick to build and lighter on memory, paid for with
    /// lower recall. Query-time recall is tuned per session by the
    /// operator via `SET ivfflat.probes = N`.
    IvfFlat,
}
61
62/// Concrete [`VectorStore`] backed by Postgres + pgvector.
63///
64/// Cloning is cheap — the pool is `Arc`-shared internally.
65#[derive(Clone)]
66pub struct PgVectorStore {
67    pool: PgPool,
68    table: Arc<str>,
69    dimension: usize,
70    distance: DistanceMetric,
71}
72
73impl std::fmt::Debug for PgVectorStore {
74    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
75        f.debug_struct("PgVectorStore")
76            .field("table", &self.table)
77            .field("dimension", &self.dimension)
78            .field("distance", &self.distance)
79            .finish_non_exhaustive()
80    }
81}
82
83impl PgVectorStore {
84    /// Begin building a [`PgVectorStore`].
85    pub fn builder(dimension: usize) -> PgVectorStoreBuilder {
86        PgVectorStoreBuilder::new(dimension)
87    }
88
89    fn distance_op(&self) -> &'static str {
90        match self.distance {
91            DistanceMetric::Cosine => "<=>",
92            DistanceMetric::L2 => "<->",
93            DistanceMetric::InnerProduct => "<#>",
94        }
95    }
96
97    /// Convert pgvector's distance into a "higher = better"
98    /// similarity score in `[0.0, 1.0]` for cosine / L2 (best
99    /// effort) and the negated inner product for ip metric.
100    /// Comparable only within a single query result set.
101    fn distance_to_score(&self, distance: f64) -> f32 {
102        let s = match self.distance {
103            DistanceMetric::Cosine => 1.0 - distance,
104            DistanceMetric::L2 => 1.0 / (1.0 + distance),
105            // pgvector's `<#>` returns negative inner product;
106            // `-distance` recovers the operator-facing similarity.
107            DistanceMetric::InnerProduct => -distance,
108        };
109        s as f32
110    }
111}
112
113/// Builder for [`PgVectorStore`].
114#[must_use]
115pub struct PgVectorStoreBuilder {
116    table: String,
117    dimension: usize,
118    distance: DistanceMetric,
119    index_kind: IndexKind,
120    auto_migrate: bool,
121    connection_string: Option<String>,
122    pool: Option<PgPool>,
123    max_connections: u32,
124}
125
126impl PgVectorStoreBuilder {
127    fn new(dimension: usize) -> Self {
128        Self {
129            table: "entelix_vectors".into(),
130            dimension,
131            distance: DistanceMetric::default(),
132            index_kind: IndexKind::default(),
133            auto_migrate: true,
134            connection_string: None,
135            pool: None,
136            max_connections: 10,
137        }
138    }
139
140    /// Override the table name. Defaults to `entelix_vectors`.
141    /// Must satisfy SQL-identifier rules
142    /// (`[a-zA-Z_][a-zA-Z0-9_]{0,62}`).
143    pub fn with_table(mut self, table: impl Into<String>) -> Self {
144        self.table = table.into();
145        self
146    }
147
148    /// Override the distance metric. Defaults to
149    /// [`DistanceMetric::Cosine`].
150    pub const fn with_distance(mut self, distance: DistanceMetric) -> Self {
151        self.distance = distance;
152        self
153    }
154
155    /// Override the ANN index kind. Defaults to
156    /// [`IndexKind::Hnsw`].
157    pub const fn with_index_kind(mut self, kind: IndexKind) -> Self {
158        self.index_kind = kind;
159        self
160    }
161
162    /// Disable the automatic schema bootstrap.
163    ///
164    /// Pass `false` when the table + extension + indexes are
165    /// provisioned externally (DBA-managed, IaC, migration
166    /// pipeline) and the store should consume an existing schema.
167    /// Defaults to `true`.
168    pub const fn with_auto_migrate(mut self, auto: bool) -> Self {
169        self.auto_migrate = auto;
170        self
171    }
172
173    /// Connect with a libpq-style connection string. Mutually
174    /// exclusive with [`Self::with_pool`].
175    pub fn with_connection_string(mut self, url: impl Into<String>) -> Self {
176        self.connection_string = Some(url.into());
177        self
178    }
179
180    /// Reuse an existing `PgPool`. Mutually exclusive with
181    /// [`Self::with_connection_string`].
182    pub fn with_pool(mut self, pool: PgPool) -> Self {
183        self.pool = Some(pool);
184        self
185    }
186
187    /// Override the pool's `max_connections` (when the builder
188    /// constructs the pool). Ignored when [`Self::with_pool`]
189    /// supplies a pre-built pool.
190    pub const fn with_max_connections(mut self, max: u32) -> Self {
191        self.max_connections = max;
192        self
193    }
194
195    /// Finalize the builder. Connects (or adopts the supplied
196    /// pool) and runs the schema bootstrap when
197    /// `auto_migrate=true`.
198    pub async fn build(self) -> PgVectorStoreResult<PgVectorStore> {
199        let pool = match (self.pool, self.connection_string) {
200            (Some(p), None) => p,
201            (None, Some(url)) => {
202                sqlx::postgres::PgPoolOptions::new()
203                    .max_connections(self.max_connections)
204                    .connect(&url)
205                    .await?
206            }
207            (None, None) => {
208                return Err(PgVectorStoreError::Config(
209                    "either with_pool or with_connection_string is required".into(),
210                ));
211            }
212            (Some(_), Some(_)) => {
213                return Err(PgVectorStoreError::Config(
214                    "with_pool and with_connection_string are mutually exclusive".into(),
215                ));
216            }
217        };
218
219        if self.auto_migrate {
220            migration::bootstrap(
221                &pool,
222                &self.table,
223                self.dimension,
224                self.distance,
225                self.index_kind,
226            )
227            .await?;
228        }
229
230        Ok(PgVectorStore {
231            pool,
232            table: self.table.into(),
233            dimension: self.dimension,
234            distance: self.distance,
235        })
236    }
237}
238
239#[async_trait]
240impl VectorStore for PgVectorStore {
241    fn dimension(&self) -> usize {
242        self.dimension
243    }
244
245    async fn add(
246        &self,
247        ctx: &ExecutionContext,
248        ns: &Namespace,
249        document: Document,
250        vector: Vec<f32>,
251    ) -> Result<()> {
252        if ctx.is_cancelled() {
253            return Err(Error::Cancelled);
254        }
255        if vector.len() != self.dimension {
256            return Err(Error::invalid_request(format!(
257                "PgVectorStore: vector dimension {} does not match \
258                 index dimension {}",
259                vector.len(),
260                self.dimension
261            )));
262        }
263        let ns_key = ns.render();
264        let doc_id = document
265            .doc_id
266            .clone()
267            .unwrap_or_else(|| Uuid::new_v4().to_string());
268        let metadata = if document.metadata.is_null() {
269            Value::Object(serde_json::Map::new())
270        } else {
271            document.metadata
272        };
273        let stmt = format!(
274            "INSERT INTO {table} (tenant_id, namespace_key, doc_id, content, metadata, embedding) \
275             VALUES ($1, $2, $3, $4, $5, $6) \
276             ON CONFLICT (namespace_key, doc_id) DO UPDATE SET \
277                 content = EXCLUDED.content, \
278                 metadata = EXCLUDED.metadata, \
279                 embedding = EXCLUDED.embedding",
280            table = self.table
281        );
282        let mut tx = self
283            .pool
284            .begin()
285            .await
286            .map_err(|e| Error::from(PgVectorStoreError::from(e)))?;
287        set_tenant_session(&mut *tx, ns.tenant_id()).await?;
288        sqlx::query(&stmt)
289            .bind(ns.tenant_id().as_str())
290            .bind(ns_key)
291            .bind(doc_id)
292            .bind(document.content)
293            .bind(sqlx::types::Json(metadata))
294            .bind(Vector::from(vector))
295            .execute(&mut *tx)
296            .await
297            .map_err(|e| Error::from(PgVectorStoreError::from(e)))?;
298        tx.commit()
299            .await
300            .map_err(|e| Error::from(PgVectorStoreError::from(e)))?;
301        Ok(())
302    }
303
304    async fn add_batch(
305        &self,
306        ctx: &ExecutionContext,
307        ns: &Namespace,
308        items: Vec<(Document, Vec<f32>)>,
309    ) -> Result<()> {
310        if ctx.is_cancelled() {
311            return Err(Error::Cancelled);
312        }
313        if items.is_empty() {
314            return Ok(());
315        }
316        let ns_key = ns.render();
317        for (_, vector) in &items {
318            if vector.len() != self.dimension {
319                return Err(Error::invalid_request(format!(
320                    "PgVectorStore: vector dimension {} does not match \
321                     index dimension {}",
322                    vector.len(),
323                    self.dimension
324                )));
325            }
326        }
327        // Bulk insert via QueryBuilder::push_values — single round-trip.
328        let tenant_id = ns.tenant_id().as_str().to_owned();
329        let mut qb: QueryBuilder<'_, Postgres> = QueryBuilder::new(format!(
330            "INSERT INTO {table} \
331             (tenant_id, namespace_key, doc_id, content, metadata, embedding) ",
332            table = self.table
333        ));
334        qb.push_values(items, |mut b, (mut document, vector)| {
335            let doc_id = document
336                .doc_id
337                .take()
338                .unwrap_or_else(|| Uuid::new_v4().to_string());
339            let metadata = if document.metadata.is_null() {
340                Value::Object(serde_json::Map::new())
341            } else {
342                document.metadata
343            };
344            b.push_bind(tenant_id.clone())
345                .push_bind(ns_key.clone())
346                .push_bind(doc_id)
347                .push_bind(document.content)
348                .push_bind(sqlx::types::Json(metadata))
349                .push_bind(Vector::from(vector));
350        });
351        qb.push(
352            " ON CONFLICT (namespace_key, doc_id) DO UPDATE SET \
353                 content = EXCLUDED.content, \
354                 metadata = EXCLUDED.metadata, \
355                 embedding = EXCLUDED.embedding",
356        );
357        let mut tx = self
358            .pool
359            .begin()
360            .await
361            .map_err(|e| Error::from(PgVectorStoreError::from(e)))?;
362        set_tenant_session(&mut *tx, ns.tenant_id()).await?;
363        qb.build()
364            .execute(&mut *tx)
365            .await
366            .map_err(|e| Error::from(PgVectorStoreError::from(e)))?;
367        tx.commit()
368            .await
369            .map_err(|e| Error::from(PgVectorStoreError::from(e)))?;
370        Ok(())
371    }
372
373    async fn search(
374        &self,
375        ctx: &ExecutionContext,
376        ns: &Namespace,
377        query_vector: &[f32],
378        top_k: usize,
379    ) -> Result<Vec<Document>> {
380        self.search_filtered(ctx, ns, query_vector, top_k, &VectorFilter::All)
381            .await
382    }
383
384    async fn search_filtered(
385        &self,
386        ctx: &ExecutionContext,
387        ns: &Namespace,
388        query_vector: &[f32],
389        top_k: usize,
390        filter: &VectorFilter,
391    ) -> Result<Vec<Document>> {
392        if ctx.is_cancelled() {
393            return Err(Error::Cancelled);
394        }
395        if query_vector.len() != self.dimension {
396            return Err(Error::invalid_request(format!(
397                "PgVectorStore: query dimension {} does not match \
398                 index dimension {}",
399                query_vector.len(),
400                self.dimension
401            )));
402        }
403        let ns_key = ns.render();
404
405        // Postgres lets `ORDER BY <alias>` reference the SELECT
406        // alias directly, so the query vector binds exactly once
407        // — emitted into the SELECT distance expression and
408        // reused by the ORDER BY through the `distance` alias.
409        let mut qb: QueryBuilder<'_, Postgres> = QueryBuilder::new(format!(
410            "SELECT doc_id, content, metadata, embedding {op} ",
411            op = self.distance_op(),
412        ));
413        qb.push_bind(Vector::from(query_vector.to_vec()));
414        qb.push(format!(" AS distance FROM {table}", table = self.table));
415        append_where(&mut qb, &ns_key, Some(filter)).map_err(Error::from)?;
416        qb.push(" ORDER BY distance LIMIT ");
417        qb.push_bind(top_k as i64);
418
419        let mut tx = self
420            .pool
421            .begin()
422            .await
423            .map_err(|e| Error::from(PgVectorStoreError::from(e)))?;
424        set_tenant_session(&mut *tx, ns.tenant_id()).await?;
425        let rows = qb
426            .build()
427            .fetch_all(&mut *tx)
428            .await
429            .map_err(|e| Error::from(PgVectorStoreError::from(e)))?;
430        tx.commit()
431            .await
432            .map_err(|e| Error::from(PgVectorStoreError::from(e)))?;
433        rows.into_iter()
434            .map(|row| self.row_to_document(&row, true))
435            .collect()
436    }
437
438    async fn delete(&self, ctx: &ExecutionContext, ns: &Namespace, doc_id: &str) -> Result<()> {
439        if ctx.is_cancelled() {
440            return Err(Error::Cancelled);
441        }
442        let stmt = format!(
443            "DELETE FROM {table} WHERE namespace_key = $1 AND doc_id = $2",
444            table = self.table
445        );
446        let mut tx = self
447            .pool
448            .begin()
449            .await
450            .map_err(|e| Error::from(PgVectorStoreError::from(e)))?;
451        set_tenant_session(&mut *tx, ns.tenant_id()).await?;
452        sqlx::query(&stmt)
453            .bind(ns.render())
454            .bind(doc_id.to_owned())
455            .execute(&mut *tx)
456            .await
457            .map_err(|e| Error::from(PgVectorStoreError::from(e)))?;
458        tx.commit()
459            .await
460            .map_err(|e| Error::from(PgVectorStoreError::from(e)))?;
461        Ok(())
462    }
463
464    async fn update(
465        &self,
466        ctx: &ExecutionContext,
467        ns: &Namespace,
468        doc_id: &str,
469        document: Document,
470        vector: Vec<f32>,
471    ) -> Result<()> {
472        // `INSERT … ON CONFLICT … DO UPDATE` is atomic per-row, so
473        // we override the trait's non-atomic delete-then-add
474        // default via the same code path as `add`.
475        let stored = Document {
476            doc_id: Some(doc_id.to_owned()),
477            ..document
478        };
479        self.add(ctx, ns, stored, vector).await
480    }
481
482    async fn count(
483        &self,
484        ctx: &ExecutionContext,
485        ns: &Namespace,
486        filter: Option<&VectorFilter>,
487    ) -> Result<usize> {
488        if ctx.is_cancelled() {
489            return Err(Error::Cancelled);
490        }
491        let ns_key = ns.render();
492        let mut qb: QueryBuilder<'_, Postgres> =
493            QueryBuilder::new(format!("SELECT COUNT(*) FROM {table}", table = self.table));
494        append_where(&mut qb, &ns_key, filter).map_err(Error::from)?;
495        let mut tx = self
496            .pool
497            .begin()
498            .await
499            .map_err(|e| Error::from(PgVectorStoreError::from(e)))?;
500        set_tenant_session(&mut *tx, ns.tenant_id()).await?;
501        let row = qb
502            .build()
503            .fetch_one(&mut *tx)
504            .await
505            .map_err(|e| Error::from(PgVectorStoreError::from(e)))?;
506        tx.commit()
507            .await
508            .map_err(|e| Error::from(PgVectorStoreError::from(e)))?;
509        let count: i64 = row.try_get::<i64, _>(0).map_err(|e| {
510            Error::from(PgVectorStoreError::Malformed(format!(
511                "COUNT(*) row missing expected column: {e}"
512            )))
513        })?;
514        Ok(count.max(0) as usize)
515    }
516
517    async fn list(
518        &self,
519        ctx: &ExecutionContext,
520        ns: &Namespace,
521        filter: Option<&VectorFilter>,
522        limit: usize,
523        offset: usize,
524    ) -> Result<Vec<Document>> {
525        if ctx.is_cancelled() {
526            return Err(Error::Cancelled);
527        }
528        let ns_key = ns.render();
529        let mut qb: QueryBuilder<'_, Postgres> = QueryBuilder::new(format!(
530            "SELECT doc_id, content, metadata FROM {table}",
531            table = self.table
532        ));
533        append_where(&mut qb, &ns_key, filter).map_err(Error::from)?;
534        // Stable iteration order — `(namespace_key, doc_id)` is
535        // the PK so the ordering is deterministic across calls.
536        qb.push(" ORDER BY doc_id");
537        qb.push(" LIMIT ");
538        qb.push_bind(limit as i64);
539        qb.push(" OFFSET ");
540        qb.push_bind(offset as i64);
541        let mut tx = self
542            .pool
543            .begin()
544            .await
545            .map_err(|e| Error::from(PgVectorStoreError::from(e)))?;
546        set_tenant_session(&mut *tx, ns.tenant_id()).await?;
547        let rows = qb
548            .build()
549            .fetch_all(&mut *tx)
550            .await
551            .map_err(|e| Error::from(PgVectorStoreError::from(e)))?;
552        tx.commit()
553            .await
554            .map_err(|e| Error::from(PgVectorStoreError::from(e)))?;
555        rows.into_iter()
556            .map(|row| self.row_to_document(&row, false))
557            .collect()
558    }
559}
560
561impl PgVectorStore {
562    fn row_to_document(
563        &self,
564        row: &sqlx::postgres::PgRow,
565        with_distance: bool,
566    ) -> Result<Document> {
567        let doc_id: String = row.try_get("doc_id").map_err(|e| {
568            Error::from(PgVectorStoreError::Malformed(format!(
569                "row missing doc_id: {e}"
570            )))
571        })?;
572        let content: String = row.try_get("content").map_err(|e| {
573            Error::from(PgVectorStoreError::Malformed(format!(
574                "row missing content: {e}"
575            )))
576        })?;
577        let metadata: sqlx::types::Json<Value> = row.try_get("metadata").map_err(|e| {
578            Error::from(PgVectorStoreError::Malformed(format!(
579                "row missing metadata: {e}"
580            )))
581        })?;
582        let score = if with_distance {
583            let distance: f64 = row.try_get("distance").map_err(|e| {
584                Error::from(PgVectorStoreError::Malformed(format!(
585                    "row missing distance: {e}"
586                )))
587            })?;
588            Some(self.distance_to_score(distance))
589        } else {
590            None
591        };
592        Ok(Document {
593            doc_id: Some(doc_id),
594            content,
595            metadata: metadata.0,
596            score,
597        })
598    }
599}