1use std::future::Future;
9
10use anyhow::{Context, Result};
11use sqlx::{postgres::PgPoolOptions, PgPool, Postgres, Transaction};
12
13use crate::backends::base::{BackendConn, BackendDialect, ColSpec};
14
15pub struct PostgresBackend {
16 dsn_env: String,
17 pool: tokio::sync::OnceCell<PgPool>,
18}
19
20impl PostgresBackend {
21 pub fn new(dsn_env: String) -> Self {
22 Self {
23 dsn_env,
24 pool: tokio::sync::OnceCell::new(),
25 }
26 }
27
28 pub async fn pool(&self) -> Result<&PgPool> {
30 self.pool
31 .get_or_try_init(|| async {
32 let dsn = std::env::var(&self.dsn_env)
33 .with_context(|| format!("DSN env var {} not set", self.dsn_env))?;
34 PgPoolOptions::new()
42 .max_connections(1)
43 .connect(&dsn)
44 .await
45 .with_context(|| format!("connecting to {}", self.dsn_env))
46 })
47 .await
48 }
49
50 pub fn vector_metric_sql(metric: &str) -> Result<(&'static str, &'static str)> {
51 match metric {
52 "cosine" => Ok(("<=>", "vector_cosine_ops")),
53 "inner_product" => Ok(("<#>", "vector_ip_ops")),
54 "l2" => Ok(("<->", "vector_l2_ops")),
55 other => anyhow::bail!(
56 "vector_metric must be one of 'cosine', 'inner_product', or 'l2', got {other:?}"
57 ),
58 }
59 }
60}
61
62impl BackendDialect for PostgresBackend {
63 const NAME: &'static str = "postgres";
64 const SUPPORTS_UPSERT: bool = true;
65
66 fn quote_ident(&self, name: &str) -> String {
67 format!("\"{}\"", name.replace('"', "\"\""))
71 }
72
73 fn fq_table(&self, db: &str, table: &str) -> String {
74 format!("{}.{}", self.quote_ident(db), self.quote_ident(table))
75 }
76
77 fn vector_type_ddl(&self, dim: usize) -> String {
80 format!("vector({dim})")
81 }
82 fn json_type_ddl(&self) -> String {
83 "jsonb".to_string()
84 }
85 fn tags_array_type_ddl(&self) -> String {
86 "text[]".to_string()
87 }
88 fn text_pk_type_ddl(&self) -> String {
89 "text".to_string()
90 }
91 fn timestamp_now_default_ddl(&self) -> String {
92 "timestamptz NOT NULL DEFAULT now()".to_string()
93 }
94 fn vector_literal(&self, arr: &[f32]) -> String {
95 let parts: Vec<String> = arr.iter().map(|x| format!("{x:.6}")).collect();
96 format!("[{}]", parts.join(","))
97 }
98
99 fn json_literal(&self, obj: &serde_json::Value) -> String {
100 serde_json::to_string(obj).unwrap_or_else(|_| "null".to_string())
101 }
102 fn json_path_sql(&self, col_expr: &str, dotted_path: &str) -> String {
103 let segs: Vec<&str> = dotted_path.split('.').collect();
104 if segs.len() == 1 {
105 return format!("{col_expr}->>'{}'", segs[0]);
106 }
107 let mut s = String::from(col_expr);
108 for seg in &segs[..segs.len() - 1] {
109 s.push_str(&format!("->'{seg}'"));
110 }
111 s.push_str(&format!("->>'{}'", segs[segs.len() - 1]));
112 s
113 }
114
115 fn upsert_clause(&self, key_cols: &[&str], update_cols: &[&str]) -> String {
116 let keys: Vec<String> = key_cols.iter().map(|c| self.quote_ident(c)).collect();
117 let keys_sql = keys.join(", ");
118 if update_cols.is_empty() {
119 return format!("ON CONFLICT ({keys_sql}) DO NOTHING");
120 }
121 let sets: Vec<String> = update_cols
122 .iter()
123 .map(|c| format!("{q} = EXCLUDED.{q}", q = self.quote_ident(c)))
124 .collect();
125 format!("ON CONFLICT ({keys_sql}) DO UPDATE SET {}", sets.join(", "))
126 }
127 fn create_database_sql(&self, name: &str) -> String {
128 format!("CREATE SCHEMA IF NOT EXISTS {}", self.quote_ident(name))
129 }
130
131 fn add_column_if_not_exists_sql(&self, fq: &str, col: &str, type_ddl: &str) -> String {
132 format!(
133 "ALTER TABLE {fq} ADD COLUMN IF NOT EXISTS {} {type_ddl}",
134 self.quote_ident(col)
135 )
136 }
137
138 fn drop_table_sql(&self, fq: &str) -> String {
139 format!("DROP TABLE {fq}")
140 }
141 fn emit_chunks_table_ddl(
142 &self,
143 fq: &str,
144 cols: &[ColSpec],
145 hnsw: bool,
146 _dim: usize, _engine: Option<&str>, vector_metric: Option<&str>,
149 ) -> Vec<String> {
150 let mut col_lines: Vec<String> = Vec::with_capacity(cols.len());
151 let mut pk_cols: Vec<&str> = Vec::new();
152 for c in cols {
153 let mut line = format!(" {} {}", self.quote_ident(c.name), c.type_ddl);
154 if let Some(default) = c.default {
155 line.push_str(&format!(" DEFAULT {default}"));
156 }
157 if !c.nullable {
158 line.push_str(" NOT NULL");
159 }
160 col_lines.push(line);
161 if c.is_primary_key {
162 pk_cols.push(c.name);
163 }
164 }
165 let mut body = col_lines.join(",\n");
166 if !pk_cols.is_empty() {
167 let pk: Vec<String> = pk_cols.iter().map(|c| self.quote_ident(c)).collect();
168 body.push_str(&format!(",\n PRIMARY KEY ({})", pk.join(", ")));
169 }
170 let create = format!("CREATE TABLE IF NOT EXISTS {fq} (\n{body}\n)");
171
172 let bare = fq
174 .rsplit('.')
175 .next()
176 .unwrap_or(fq)
177 .trim_matches('"')
178 .to_string();
179
180 let mut stmts = vec![create];
181 stmts.push(format!(
182 "CREATE INDEX IF NOT EXISTS {} ON {fq} (\"doc_id\", \"seq_num\")",
183 self.quote_ident(&format!("{bare}_doc_seq_idx"))
184 ));
185 if hnsw {
186 let metric = vector_metric.unwrap_or("cosine");
187 let (_op, opclass) = Self::vector_metric_sql(metric).expect("validated vector_metric");
188 let idx_suffix = if metric == "cosine" {
189 "_emb_hnsw_idx".to_string()
190 } else {
191 format!("_emb_hnsw_{metric}_idx")
192 };
193 stmts.push(format!(
194 "CREATE INDEX IF NOT EXISTS {} ON {fq} USING hnsw (\"embedding\" {opclass})",
195 self.quote_ident(&format!("{bare}{idx_suffix}"))
196 ));
197 }
198 stmts
199 }
200}
201
202impl BackendConn for PostgresBackend {
203 type Db = sqlx::Postgres;
204
205 fn connect(&self) -> impl Future<Output = Result<()>> + Send {
206 async move {
207 let _ = self.pool().await?;
208 Ok(())
209 }
210 }
211
212 fn acquire_create_lock(
213 &self,
214 tx: &mut Transaction<'_, Postgres>,
215 key: &str,
216 ) -> impl Future<Output = Result<()>> + Send {
217 async move {
218 use blake2::{digest::consts::U8, Blake2b, Digest};
221 let mut hasher = Blake2b::<U8>::new();
222 hasher.update(key.as_bytes());
223 let digest = hasher.finalize();
224 let lock_key = i64::from_be_bytes(digest.into());
225 sqlx::query("SELECT pg_advisory_xact_lock($1)")
226 .bind(lock_key)
227 .execute(&mut **tx)
228 .await
229 .with_context(|| format!("acquire advisory lock for {key}"))?;
230 Ok(())
231 }
232 }
233
234 fn table_exists(
235 &self,
236 tx: &mut Transaction<'_, Postgres>,
237 db: &str,
238 table: &str,
239 ) -> impl Future<Output = Result<bool>> + Send {
240 async move {
241 use sqlx::Row;
242 let row = sqlx::query(
243 "SELECT EXISTS (SELECT 1 FROM pg_tables WHERE schemaname=$1 AND tablename=$2)",
244 )
245 .bind(db)
246 .bind(table)
247 .fetch_one(&mut **tx)
248 .await?;
249 Ok(row.get::<bool, _>(0))
250 }
251 }
252
253 fn embedding_dim(
254 &self,
255 tx: &mut Transaction<'_, Postgres>,
256 db: &str,
257 table: &str,
258 ) -> impl Future<Output = Result<Option<usize>>> + Send {
259 async move {
260 use sqlx::Row;
261 let row = sqlx::query(
262 r#"
263 SELECT format_type(atttypid, atttypmod) AS t
264 FROM pg_attribute
265 WHERE attrelid = (
266 SELECT c.oid FROM pg_class c JOIN pg_namespace n ON n.oid = c.relnamespace
267 WHERE c.relname = $1 AND n.nspname = $2
268 ) AND attname = 'embedding'
269 "#,
270 )
271 .bind(table)
272 .bind(db)
273 .fetch_optional(&mut **tx)
274 .await?;
275 let Some(r) = row else { return Ok(None) };
276 let s: String = r.get("t");
277 let re = regex::Regex::new(r"^vector\((\d+)\)$").unwrap();
278 Ok(re
279 .captures(&s)
280 .and_then(|c| c.get(1))
281 .and_then(|m| m.as_str().parse().ok()))
282 }
283 }
284}
285
#[cfg(test)]
mod tests {
    use super::*;

    /// Qualified table name shared by the DDL tests.
    const FQ: &str = "\"db\".\"t\"";

    /// Dialect methods never touch the pool, so the DSN variable need not exist.
    fn backend() -> PostgresBackend {
        PostgresBackend::new("UNUSED_FOR_DIALECT_TESTS".to_string())
    }

    /// A NOT NULL column without a default.
    fn required_col(name: &'static str, type_ddl: String, is_primary_key: bool) -> ColSpec {
        ColSpec {
            name,
            type_ddl,
            nullable: false,
            default: None,
            is_primary_key,
        }
    }

    /// Minimal chunks schema: `id` primary key, `doc_id`, `seq_num`, `vector(dim)` embedding.
    fn canonical_cols(dim: usize) -> Vec<ColSpec> {
        vec![
            required_col("id", "text".to_string(), true),
            required_col("doc_id", "text".to_string(), false),
            required_col("seq_num", "int".to_string(), false),
            required_col("embedding", format!("vector({dim})"), false),
        ]
    }

    /// Chunks-table DDL for `FQ` with 384-dimensional embeddings.
    fn chunks_ddl(hnsw: bool, vector_metric: Option<&str>) -> Vec<String> {
        backend().emit_chunks_table_ddl(FQ, &canonical_cols(384), hnsw, 384, None, vector_metric)
    }

    #[test]
    fn emit_chunks_table_ddl_no_hnsw() {
        let stmts = chunks_ddl(false, None);
        assert_eq!(stmts.len(), 2);

        let create = &stmts[0];
        assert!(create.starts_with("CREATE TABLE IF NOT EXISTS \"db\".\"t\""));
        for fragment in ["\"id\" text NOT NULL", "PRIMARY KEY (\"id\")"] {
            assert!(create.contains(fragment), "{create:?} lacks {fragment:?}");
        }

        let doc_seq = &stmts[1];
        for fragment in [
            "CREATE INDEX IF NOT EXISTS \"t_doc_seq_idx\"",
            "ON \"db\".\"t\" (\"doc_id\", \"seq_num\")",
        ] {
            assert!(doc_seq.contains(fragment), "{doc_seq:?} lacks {fragment:?}");
        }
    }

    #[test]
    fn emit_chunks_table_ddl_with_hnsw() {
        let stmts = chunks_ddl(true, None);
        assert_eq!(stmts.len(), 3);
        let hnsw_idx = &stmts[2];
        assert!(hnsw_idx.contains("USING hnsw (\"embedding\" vector_cosine_ops)"));
        assert!(hnsw_idx.contains("\"t_emb_hnsw_idx\""));
    }

    #[test]
    fn vector_metric_sql_maps_pgvector_operators() {
        let expected = [
            ("cosine", ("<=>", "vector_cosine_ops")),
            ("inner_product", ("<#>", "vector_ip_ops")),
            ("l2", ("<->", "vector_l2_ops")),
        ];
        for (metric, pair) in expected {
            assert_eq!(PostgresBackend::vector_metric_sql(metric).unwrap(), pair);
        }
        assert!(PostgresBackend::vector_metric_sql("manhattan").is_err());
    }

    #[test]
    fn emit_chunks_table_ddl_hnsw_uses_metric_opclass() {
        let stmts = chunks_ddl(true, Some("inner_product"));
        let hnsw_idx = &stmts[2];
        assert!(hnsw_idx.contains("USING hnsw (\"embedding\" vector_ip_ops)"));
        assert!(hnsw_idx.contains("\"t_emb_hnsw_inner_product_idx\""));
    }

    #[test]
    fn quote_ident_wraps_in_double_quotes() {
        assert_eq!(backend().quote_ident("my_table"), "\"my_table\"");
    }

    #[test]
    fn quote_ident_doubles_embedded_double_quote() {
        // a"b must become "a""b" so the embedded quote cannot close the identifier.
        assert_eq!(backend().quote_ident("a\"b"), "\"a\"\"b\"");
    }

    #[test]
    fn fq_table_quotes_both_segments() {
        assert_eq!(
            backend().fq_table("my_db", "my_table"),
            "\"my_db\".\"my_table\""
        );
    }

    #[test]
    fn vector_type_ddl() {
        let b = backend();
        for (dim, ddl) in [(384, "vector(384)"), (1024, "vector(1024)")] {
            assert_eq!(b.vector_type_ddl(dim), ddl);
        }
    }

    #[test]
    fn json_type_ddl_is_jsonb() {
        assert_eq!(backend().json_type_ddl(), "jsonb");
    }

    #[test]
    fn tags_array_type_ddl_is_text_array() {
        assert_eq!(backend().tags_array_type_ddl(), "text[]");
    }

    #[test]
    fn text_pk_type_ddl_is_text() {
        assert_eq!(backend().text_pk_type_ddl(), "text");
    }

    #[test]
    fn timestamp_now_default_ddl() {
        let ddl = backend().timestamp_now_default_ddl();
        assert_eq!(ddl, "timestamptz NOT NULL DEFAULT now()");
    }

    #[test]
    fn vector_literal_format_matches_python() {
        let lit = backend().vector_literal(&[0.1_f32, 0.2, -0.3]);
        assert_eq!(lit, "[0.100000,0.200000,-0.300000]");
    }

    #[test]
    fn vector_literal_empty() {
        assert_eq!(backend().vector_literal(&[]), "[]");
    }

    #[test]
    fn json_literal_canonical_form() {
        let value = serde_json::json!({"k": "v", "n": 1});
        let text = backend().json_literal(&value);
        // Key order in the text is not pinned, so compare after reparsing.
        let reparsed: serde_json::Value = serde_json::from_str(&text).unwrap();
        assert_eq!(reparsed["k"], "v");
        assert_eq!(reparsed["n"], 1);
    }

    /// `json_path_sql` over the `metadata` column.
    fn metadata_path(dotted: &str) -> String {
        backend().json_path_sql("metadata", dotted)
    }

    #[test]
    fn json_path_sql_single_segment() {
        assert_eq!(metadata_path("a"), "metadata->>'a'");
    }

    #[test]
    fn json_path_sql_two_segments() {
        assert_eq!(metadata_path("a.b"), "metadata->'a'->>'b'");
    }

    #[test]
    fn json_path_sql_three_segments() {
        assert_eq!(metadata_path("a.b.c"), "metadata->'a'->'b'->>'c'");
    }

    #[test]
    fn upsert_clause_do_nothing_when_no_update_cols() {
        let clause = backend().upsert_clause(&["id"], &[]);
        assert_eq!(clause, "ON CONFLICT (\"id\") DO NOTHING");
    }

    #[test]
    fn upsert_clause_do_update_set() {
        let clause = backend().upsert_clause(&["id"], &["content", "metadata"]);
        let expected = "ON CONFLICT (\"id\") DO UPDATE SET \"content\" = EXCLUDED.\"content\", \
             \"metadata\" = EXCLUDED.\"metadata\"";
        assert_eq!(clause, expected);
    }

    #[test]
    fn upsert_clause_composite_key() {
        let clause = backend().upsert_clause(&["a", "b"], &["c"]);
        let expected = "ON CONFLICT (\"a\", \"b\") DO UPDATE SET \"c\" = EXCLUDED.\"c\"";
        assert_eq!(clause, expected);
    }

    #[test]
    fn create_database_sql_uses_schema_for_postgres() {
        let sql = backend().create_database_sql("chunkshop");
        assert_eq!(sql, "CREATE SCHEMA IF NOT EXISTS \"chunkshop\"");
    }

    #[test]
    fn add_column_if_not_exists_sql_format() {
        let sql = backend().add_column_if_not_exists_sql(FQ, "source", "text");
        let expected = "ALTER TABLE \"db\".\"t\" ADD COLUMN IF NOT EXISTS \"source\" text";
        assert_eq!(sql, expected);
    }

    #[test]
    fn drop_table_sql_format() {
        assert_eq!(backend().drop_table_sql(FQ), "DROP TABLE \"db\".\"t\"");
    }
}