Skip to main content

chunkshop/backends/
postgres.rs

1//! Postgres backend — sqlx-based connection pool + dialect helpers.
2//!
3//! Mirrors `python/src/chunkshop/backends/postgres.py`. Identifier safety
4//! is two-layer: regex allowlist enforced at config-load (in config.rs)
5//! plus quote-doubling here (defense-in-depth — even if the regex were
6//! widened, embedded `"` characters can't break out).
7
8use std::future::Future;
9
10use anyhow::{Context, Result};
11use sqlx::{postgres::PgPoolOptions, PgPool, Postgres, Transaction};
12
13use crate::backends::base::{BackendConn, BackendDialect, ColSpec};
14
15pub struct PostgresBackend {
16    dsn_env: String,
17    pool: tokio::sync::OnceCell<PgPool>,
18}
19
20impl PostgresBackend {
21    pub fn new(dsn_env: String) -> Self {
22        Self {
23            dsn_env,
24            pool: tokio::sync::OnceCell::new(),
25        }
26    }
27
28    /// Lazily-initialized pool. Idempotent.
29    pub async fn pool(&self) -> Result<&PgPool> {
30        self.pool
31            .get_or_try_init(|| async {
32                let dsn = std::env::var(&self.dsn_env)
33                    .with_context(|| format!("DSN env var {} not set", self.dsn_env))?;
34                // max_connections(1) mirrors the Python implementation's
35                // short-lived per-document connection discipline (see
36                // CLAUDE.md). PgSink opens one short transaction per
37                // write_document, so concurrent throughput comes from
38                // running multiple cells as separate processes (orchestrator),
39                // not from pooling within a single process. Revisit if the
40                // sink layer ever wants intra-process write concurrency.
41                PgPoolOptions::new()
42                    .max_connections(1)
43                    .connect(&dsn)
44                    .await
45                    .with_context(|| format!("connecting to {}", self.dsn_env))
46            })
47            .await
48    }
49
50    pub fn vector_metric_sql(metric: &str) -> Result<(&'static str, &'static str)> {
51        match metric {
52            "cosine" => Ok(("<=>", "vector_cosine_ops")),
53            "inner_product" => Ok(("<#>", "vector_ip_ops")),
54            "l2" => Ok(("<->", "vector_l2_ops")),
55            other => anyhow::bail!(
56                "vector_metric must be one of 'cosine', 'inner_product', or 'l2', got {other:?}"
57            ),
58        }
59    }
60}
61
62impl BackendDialect for PostgresBackend {
63    const NAME: &'static str = "postgres";
64    const SUPPORTS_UPSERT: bool = true;
65
66    fn quote_ident(&self, name: &str) -> String {
67        // Defense-in-depth: even with the regex allowlist at config-load
68        // refusing characters outside [a-z0-9_], we still double-quote any
69        // embedded `"` in case the regex is ever widened.
70        format!("\"{}\"", name.replace('"', "\"\""))
71    }
72
73    fn fq_table(&self, db: &str, table: &str) -> String {
74        format!("{}.{}", self.quote_ident(db), self.quote_ident(table))
75    }
76
77    // --- remaining methods land in Tasks 5–9. Stubs below to keep crate compiling. ---
78
79    fn vector_type_ddl(&self, dim: usize) -> String {
80        format!("vector({dim})")
81    }
82    fn json_type_ddl(&self) -> String {
83        "jsonb".to_string()
84    }
85    fn tags_array_type_ddl(&self) -> String {
86        "text[]".to_string()
87    }
88    fn text_pk_type_ddl(&self) -> String {
89        "text".to_string()
90    }
91    fn timestamp_now_default_ddl(&self) -> String {
92        "timestamptz NOT NULL DEFAULT now()".to_string()
93    }
94    fn vector_literal(&self, arr: &[f32]) -> String {
95        let parts: Vec<String> = arr.iter().map(|x| format!("{x:.6}")).collect();
96        format!("[{}]", parts.join(","))
97    }
98
99    fn json_literal(&self, obj: &serde_json::Value) -> String {
100        serde_json::to_string(obj).unwrap_or_else(|_| "null".to_string())
101    }
102    fn json_path_sql(&self, col_expr: &str, dotted_path: &str) -> String {
103        let segs: Vec<&str> = dotted_path.split('.').collect();
104        if segs.len() == 1 {
105            return format!("{col_expr}->>'{}'", segs[0]);
106        }
107        let mut s = String::from(col_expr);
108        for seg in &segs[..segs.len() - 1] {
109            s.push_str(&format!("->'{seg}'"));
110        }
111        s.push_str(&format!("->>'{}'", segs[segs.len() - 1]));
112        s
113    }
114
115    fn upsert_clause(&self, key_cols: &[&str], update_cols: &[&str]) -> String {
116        let keys: Vec<String> = key_cols.iter().map(|c| self.quote_ident(c)).collect();
117        let keys_sql = keys.join(", ");
118        if update_cols.is_empty() {
119            return format!("ON CONFLICT ({keys_sql}) DO NOTHING");
120        }
121        let sets: Vec<String> = update_cols
122            .iter()
123            .map(|c| format!("{q} = EXCLUDED.{q}", q = self.quote_ident(c)))
124            .collect();
125        format!("ON CONFLICT ({keys_sql}) DO UPDATE SET {}", sets.join(", "))
126    }
127    fn create_database_sql(&self, name: &str) -> String {
128        format!("CREATE SCHEMA IF NOT EXISTS {}", self.quote_ident(name))
129    }
130
131    fn add_column_if_not_exists_sql(&self, fq: &str, col: &str, type_ddl: &str) -> String {
132        format!(
133            "ALTER TABLE {fq} ADD COLUMN IF NOT EXISTS {} {type_ddl}",
134            self.quote_ident(col)
135        )
136    }
137
138    fn drop_table_sql(&self, fq: &str) -> String {
139        format!("DROP TABLE {fq}")
140    }
141    fn emit_chunks_table_ddl(
142        &self,
143        fq: &str,
144        cols: &[ColSpec],
145        hnsw: bool,
146        _dim: usize,           // dim is encoded in the embedding column's type_ddl
147        _engine: Option<&str>, // engine clause is not applicable on PG
148        vector_metric: Option<&str>,
149    ) -> Vec<String> {
150        let mut col_lines: Vec<String> = Vec::with_capacity(cols.len());
151        let mut pk_cols: Vec<&str> = Vec::new();
152        for c in cols {
153            let mut line = format!("  {} {}", self.quote_ident(c.name), c.type_ddl);
154            if let Some(default) = c.default {
155                line.push_str(&format!(" DEFAULT {default}"));
156            }
157            if !c.nullable {
158                line.push_str(" NOT NULL");
159            }
160            col_lines.push(line);
161            if c.is_primary_key {
162                pk_cols.push(c.name);
163            }
164        }
165        let mut body = col_lines.join(",\n");
166        if !pk_cols.is_empty() {
167            let pk: Vec<String> = pk_cols.iter().map(|c| self.quote_ident(c)).collect();
168            body.push_str(&format!(",\n  PRIMARY KEY ({})", pk.join(", ")));
169        }
170        let create = format!("CREATE TABLE IF NOT EXISTS {fq} (\n{body}\n)");
171
172        // Strip schema prefix from fq for index naming: "db"."t" → t
173        let bare = fq
174            .rsplit('.')
175            .next()
176            .unwrap_or(fq)
177            .trim_matches('"')
178            .to_string();
179
180        let mut stmts = vec![create];
181        stmts.push(format!(
182            "CREATE INDEX IF NOT EXISTS {} ON {fq} (\"doc_id\", \"seq_num\")",
183            self.quote_ident(&format!("{bare}_doc_seq_idx"))
184        ));
185        if hnsw {
186            let metric = vector_metric.unwrap_or("cosine");
187            let (_op, opclass) = Self::vector_metric_sql(metric).expect("validated vector_metric");
188            let idx_suffix = if metric == "cosine" {
189                "_emb_hnsw_idx".to_string()
190            } else {
191                format!("_emb_hnsw_{metric}_idx")
192            };
193            stmts.push(format!(
194                "CREATE INDEX IF NOT EXISTS {} ON {fq} USING hnsw (\"embedding\" {opclass})",
195                self.quote_ident(&format!("{bare}{idx_suffix}"))
196            ));
197        }
198        stmts
199    }
200}
201
202impl BackendConn for PostgresBackend {
203    type Db = sqlx::Postgres;
204
205    fn connect(&self) -> impl Future<Output = Result<()>> + Send {
206        async move {
207            let _ = self.pool().await?;
208            Ok(())
209        }
210    }
211
212    fn acquire_create_lock(
213        &self,
214        tx: &mut Transaction<'_, Postgres>,
215        key: &str,
216    ) -> impl Future<Output = Result<()>> + Send {
217        async move {
218            // Deterministic 64-bit signed int from BLAKE2b-8 of the schema name.
219            // Mirrors Python's PostgresBackend._advisory_lock_key.
220            use blake2::{digest::consts::U8, Blake2b, Digest};
221            let mut hasher = Blake2b::<U8>::new();
222            hasher.update(key.as_bytes());
223            let digest = hasher.finalize();
224            let lock_key = i64::from_be_bytes(digest.into());
225            sqlx::query("SELECT pg_advisory_xact_lock($1)")
226                .bind(lock_key)
227                .execute(&mut **tx)
228                .await
229                .with_context(|| format!("acquire advisory lock for {key}"))?;
230            Ok(())
231        }
232    }
233
234    fn table_exists(
235        &self,
236        tx: &mut Transaction<'_, Postgres>,
237        db: &str,
238        table: &str,
239    ) -> impl Future<Output = Result<bool>> + Send {
240        async move {
241            use sqlx::Row;
242            let row = sqlx::query(
243                "SELECT EXISTS (SELECT 1 FROM pg_tables WHERE schemaname=$1 AND tablename=$2)",
244            )
245            .bind(db)
246            .bind(table)
247            .fetch_one(&mut **tx)
248            .await?;
249            Ok(row.get::<bool, _>(0))
250        }
251    }
252
253    fn embedding_dim(
254        &self,
255        tx: &mut Transaction<'_, Postgres>,
256        db: &str,
257        table: &str,
258    ) -> impl Future<Output = Result<Option<usize>>> + Send {
259        async move {
260            use sqlx::Row;
261            let row = sqlx::query(
262                r#"
263                SELECT format_type(atttypid, atttypmod) AS t
264                FROM pg_attribute
265                WHERE attrelid = (
266                    SELECT c.oid FROM pg_class c JOIN pg_namespace n ON n.oid = c.relnamespace
267                    WHERE c.relname = $1 AND n.nspname = $2
268                ) AND attname = 'embedding'
269                "#,
270            )
271            .bind(table)
272            .bind(db)
273            .fetch_optional(&mut **tx)
274            .await?;
275            let Some(r) = row else { return Ok(None) };
276            let s: String = r.get("t");
277            let re = regex::Regex::new(r"^vector\((\d+)\)$").unwrap();
278            Ok(re
279                .captures(&s)
280                .and_then(|c| c.get(1))
281                .and_then(|m| m.as_str().parse().ok()))
282        }
283    }
284}
285
#[cfg(test)]
mod tests {
    use super::*;

    /// Quoted, schema-qualified table reference shared by the DDL tests.
    const FQ: &str = r#""db"."t""#;

    /// Dialect tests never connect, so the env var name is a placeholder.
    fn backend() -> PostgresBackend {
        PostgresBackend::new(String::from("UNUSED_FOR_DIALECT_TESTS"))
    }

    /// Minimal column set: `id` (primary key), `doc_id`, `seq_num`, and an
    /// `embedding` of the given dimension. All are non-nullable with no
    /// default.
    fn canonical_cols(dim: usize) -> Vec<ColSpec> {
        let column = |name: &'static str, type_ddl: String, is_primary_key: bool| ColSpec {
            name,
            type_ddl,
            nullable: false,
            default: None,
            is_primary_key,
        };
        vec![
            column("id", "text".into(), true),
            column("doc_id", "text".into(), false),
            column("seq_num", "int".into(), false),
            column("embedding", format!("vector({dim})"), false),
        ]
    }

    #[test]
    fn emit_chunks_table_ddl_no_hnsw() {
        let pg = backend();
        let stmts = pg.emit_chunks_table_ddl(FQ, &canonical_cols(384), false, 384, None, None);
        assert_eq!(stmts.len(), 2);

        let create = &stmts[0];
        assert!(create.starts_with(r#"CREATE TABLE IF NOT EXISTS "db"."t""#));
        assert!(create.contains(r#""id" text NOT NULL"#));
        assert!(create.contains(r#"PRIMARY KEY ("id")"#));

        let index = &stmts[1];
        assert!(index.contains(r#"CREATE INDEX IF NOT EXISTS "t_doc_seq_idx""#));
        assert!(index.contains(r#"ON "db"."t" ("doc_id", "seq_num")"#));
    }

    #[test]
    fn emit_chunks_table_ddl_with_hnsw() {
        let pg = backend();
        let stmts = pg.emit_chunks_table_ddl(FQ, &canonical_cols(384), true, 384, None, None);
        assert_eq!(stmts.len(), 3);

        let hnsw = &stmts[2];
        assert!(hnsw.contains(r#"USING hnsw ("embedding" vector_cosine_ops)"#));
        assert!(hnsw.contains(r#""t_emb_hnsw_idx""#));
    }

    #[test]
    fn vector_metric_sql_maps_pgvector_operators() {
        let expected = [
            ("cosine", ("<=>", "vector_cosine_ops")),
            ("inner_product", ("<#>", "vector_ip_ops")),
            ("l2", ("<->", "vector_l2_ops")),
        ];
        for (metric, pair) in expected {
            assert_eq!(PostgresBackend::vector_metric_sql(metric).unwrap(), pair);
        }
        assert!(PostgresBackend::vector_metric_sql("manhattan").is_err());
    }

    #[test]
    fn emit_chunks_table_ddl_hnsw_uses_metric_opclass() {
        let pg = backend();
        let cols = canonical_cols(384);
        let stmts = pg.emit_chunks_table_ddl(FQ, &cols, true, 384, None, Some("inner_product"));

        let hnsw = &stmts[2];
        assert!(hnsw.contains(r#"USING hnsw ("embedding" vector_ip_ops)"#));
        assert!(hnsw.contains(r#""t_emb_hnsw_inner_product_idx""#));
    }

    #[test]
    fn quote_ident_wraps_in_double_quotes() {
        assert_eq!(backend().quote_ident("my_table"), r#""my_table""#);
    }

    #[test]
    fn quote_ident_doubles_embedded_double_quote() {
        // Defense-in-depth: the config-load regex disallows `"`, yet the
        // dialect must still escape it.
        assert_eq!(backend().quote_ident("a\"b"), "\"a\"\"b\"");
    }

    #[test]
    fn fq_table_quotes_both_segments() {
        assert_eq!(
            backend().fq_table("my_db", "my_table"),
            r#""my_db"."my_table""#
        );
    }

    #[test]
    fn vector_type_ddl() {
        let pg = backend();
        for (dim, expected) in [(384, "vector(384)"), (1024, "vector(1024)")] {
            assert_eq!(pg.vector_type_ddl(dim), expected);
        }
    }

    #[test]
    fn json_type_ddl_is_jsonb() {
        assert_eq!(backend().json_type_ddl(), "jsonb");
    }

    #[test]
    fn tags_array_type_ddl_is_text_array() {
        assert_eq!(backend().tags_array_type_ddl(), "text[]");
    }

    #[test]
    fn text_pk_type_ddl_is_text() {
        assert_eq!(backend().text_pk_type_ddl(), "text");
    }

    #[test]
    fn timestamp_now_default_ddl() {
        let ddl = backend().timestamp_now_default_ddl();
        assert_eq!(ddl, "timestamptz NOT NULL DEFAULT now()");
    }

    #[test]
    fn vector_literal_format_matches_python() {
        // Mirrors Python's PostgresBackend.vector_literal:
        //   "[" + ",".join(f"{x:.6f}" for x in arr) + "]"
        let values = [0.1_f32, 0.2_f32, -0.3_f32];
        assert_eq!(
            backend().vector_literal(&values),
            "[0.100000,0.200000,-0.300000]"
        );
    }

    #[test]
    fn vector_literal_empty() {
        assert_eq!(backend().vector_literal(&[]), "[]");
    }

    #[test]
    fn json_literal_canonical_form() {
        let value = serde_json::json!({"k": "v", "n": 1});
        let text = backend().json_literal(&value);
        // Key order is implementation-defined, so check structure by
        // parsing the literal back.
        let reparsed: serde_json::Value = serde_json::from_str(&text).unwrap();
        assert_eq!(reparsed["k"], "v");
        assert_eq!(reparsed["n"], 1);
    }

    #[test]
    fn json_path_sql_single_segment() {
        assert_eq!(backend().json_path_sql("metadata", "a"), "metadata->>'a'");
    }

    #[test]
    fn json_path_sql_two_segments() {
        assert_eq!(
            backend().json_path_sql("metadata", "a.b"),
            "metadata->'a'->>'b'"
        );
    }

    #[test]
    fn json_path_sql_three_segments() {
        assert_eq!(
            backend().json_path_sql("metadata", "a.b.c"),
            "metadata->'a'->'b'->>'c'"
        );
    }

    #[test]
    fn upsert_clause_do_nothing_when_no_update_cols() {
        let clause = backend().upsert_clause(&["id"], &[]);
        assert_eq!(clause, r#"ON CONFLICT ("id") DO NOTHING"#);
    }

    #[test]
    fn upsert_clause_do_update_set() {
        let clause = backend().upsert_clause(&["id"], &["content", "metadata"]);
        let expected = concat!(
            r#"ON CONFLICT ("id") DO UPDATE SET "content" = EXCLUDED."content", "#,
            r#""metadata" = EXCLUDED."metadata""#,
        );
        assert_eq!(clause, expected);
    }

    #[test]
    fn upsert_clause_composite_key() {
        let clause = backend().upsert_clause(&["a", "b"], &["c"]);
        assert_eq!(
            clause,
            r#"ON CONFLICT ("a", "b") DO UPDATE SET "c" = EXCLUDED."c""#
        );
    }

    #[test]
    fn create_database_sql_uses_schema_for_postgres() {
        assert_eq!(
            backend().create_database_sql("chunkshop"),
            r#"CREATE SCHEMA IF NOT EXISTS "chunkshop""#
        );
    }

    #[test]
    fn add_column_if_not_exists_sql_format() {
        let sql = backend().add_column_if_not_exists_sql(FQ, "source", "text");
        assert_eq!(
            sql,
            r#"ALTER TABLE "db"."t" ADD COLUMN IF NOT EXISTS "source" text"#
        );
    }

    #[test]
    fn drop_table_sql_format() {
        assert_eq!(backend().drop_table_sql(FQ), r#"DROP TABLE "db"."t""#);
    }
}