1mod fields;
45mod query;
46mod resolve;
47pub(crate) mod value;
48
49use std::collections::{HashMap, HashSet};
50use std::sync::{Arc, Mutex, PoisonError};
51
52use async_trait::async_trait;
53use schema_core::{
54 ColumnName, DatabaseSchema, Filter, IndexMapping, IndexName, IndexSchema, SoftDelete, TableName,
55};
56use sources_core::document::{Document, DocumentBuilder, DocumentId, IndexScope};
57use sources_core::{Catalog, ColumnInfo, Result, RowKey, SnapshotTable, SourceError, SourceSpec};
58use sqlx::{PgPool, Row};
59
60use fields::find_paths;
61
62type ColTypeCache = HashMap<(String, String, String), ColumnMeta>;
64
65const BUILD_CHUNK: usize = 512;
69
/// Type and nullability of one table column, as read from `pg_attribute`.
#[derive(Clone, Debug)]
struct ColumnMeta {
    /// Text produced by `format_type(atttypid, atttypmod)`, e.g. `integer`.
    sql_type: String,
    /// `true` unless the column is declared `NOT NULL` (`attnotnull`).
    nullable: bool,
}
79
80#[derive(Debug, Clone)]
85pub struct PgDocumentBuilder {
86 pool: PgPool,
87 spec: Arc<SourceSpec>,
88 pk_cache: Arc<Mutex<HashMap<(String, String), ColumnName>>>,
90 col_type_cache: Arc<Mutex<ColTypeCache>>,
93}
94
95impl PgDocumentBuilder {
96 pub fn new(pool: PgPool, spec: Arc<SourceSpec>) -> Self {
97 Self {
98 pool,
99 spec,
100 pk_cache: Arc::new(Mutex::new(HashMap::new())),
101 col_type_cache: Arc::new(Mutex::new(HashMap::new())),
102 }
103 }
104
105 #[tracing::instrument(name = "pg.connect", skip_all, err)]
106 pub async fn connect(connection_url: &str, spec: Arc<SourceSpec>) -> Result<Self> {
107 let pool = sqlx::postgres::PgPoolOptions::new()
108 .connect(connection_url)
109 .await
110 .map_err(|e| SourceError::Connection(e.to_string()))?;
111 tracing::info!(indexes = spec.indexes().count(), "connected to Postgres");
112 Ok(Self::new(pool, spec))
113 }
114
115 pub(super) async fn table_primary_key(
119 &self,
120 schema: &DatabaseSchema,
121 table: &TableName,
122 ) -> Result<ColumnName> {
123 let cache_key = (schema.to_string(), table.to_string());
124 {
125 let cache = self.pk_cache.lock().unwrap_or_else(PoisonError::into_inner);
126 if let Some(column) = cache.get(&cache_key) {
127 return Ok(column.clone());
128 }
129 }
130 let column = match self.fetch_primary_key(schema, table).await?.as_slice() {
131 [single] => single.clone(),
132 [] => {
133 return Err(SourceError::Query(format!(
134 "table `{schema}.{table}` has no primary key"
135 )));
136 }
137 _ => {
138 return Err(SourceError::Unsupported(format!(
139 "table `{schema}.{table}` has a composite primary key; relations require a single-column key"
140 )));
141 }
142 };
143 self.pk_cache
144 .lock()
145 .unwrap_or_else(PoisonError::into_inner)
146 .insert(cache_key, column.clone());
147 Ok(column)
148 }
149
150 async fn fetch_primary_key(
151 &self,
152 schema: &DatabaseSchema,
153 table: &TableName,
154 ) -> Result<Vec<ColumnName>> {
155 let names = primary_key_column_names(&self.pool, format!("{schema}.{table}")).await?;
156 names
157 .into_iter()
158 .map(|name| {
159 ColumnName::try_new(name)
160 .map_err(|e| SourceError::Query(format!("invalid primary key column: {e}")))
161 })
162 .collect()
163 }
164
165 async fn relation_pks(
168 &self,
169 schema: &schema_core::IndexSchema,
170 ) -> Result<HashMap<String, ColumnName>> {
171 let mut tables = Vec::new();
172 fields::collect_relation_tables(&schema.fields, &mut tables);
173 let unique: HashSet<&TableName> = tables.iter().collect();
174 let mut pks = HashMap::new();
175 for table in unique {
176 pks.insert(
177 table.to_string(),
178 self.table_primary_key(&schema.db_schema, table).await?,
179 );
180 }
181 Ok(pks)
182 }
183
184 pub(super) async fn column_type(
189 &self,
190 schema: &DatabaseSchema,
191 table: &TableName,
192 column: &ColumnName,
193 ) -> Result<String> {
194 Ok(self.column_meta(schema, table, column).await?.sql_type)
195 }
196
197 async fn column_meta(
201 &self,
202 schema: &DatabaseSchema,
203 table: &TableName,
204 column: &ColumnName,
205 ) -> Result<ColumnMeta> {
206 let cache_key = (schema.to_string(), table.to_string(), column.to_string());
207 {
208 let cache = self
209 .col_type_cache
210 .lock()
211 .unwrap_or_else(PoisonError::into_inner);
212 if let Some(meta) = cache.get(&cache_key) {
213 return Ok(meta.clone());
214 }
215 }
216 let sql = "SELECT format_type(a.atttypid, a.atttypmod) AS sql_type, a.attnotnull AS not_null \
221 FROM pg_attribute a \
222 WHERE a.attrelid = $1::regclass AND a.attname = $2 \
223 AND a.attnum > 0 AND NOT a.attisdropped";
224 let row = sqlx::query(sql)
225 .bind(format!("{schema}.{table}"))
226 .bind(column.as_ref().to_owned())
227 .fetch_optional(&self.pool)
228 .await
229 .map_err(query_err)?;
230 let meta = match row {
231 Some(row) => {
232 let sql_type: String = row.try_get("sql_type").map_err(query_err)?;
233 let not_null: bool = row.try_get("not_null").map_err(query_err)?;
234 ColumnMeta {
235 sql_type,
236 nullable: !not_null,
237 }
238 }
239 None => {
240 return Err(SourceError::Query(format!(
241 "references unknown column `{schema}.{table}.{column}`"
242 )));
243 }
244 };
245 self.col_type_cache
246 .lock()
247 .unwrap_or_else(PoisonError::into_inner)
248 .insert(cache_key, meta.clone());
249 Ok(meta)
250 }
251
252 async fn filter_column_types(
257 &self,
258 schema: &IndexSchema,
259 ) -> Result<HashMap<(String, String), String>> {
260 let mut columns = Vec::new();
261 fields::collect_filter_columns(&schema.fields, &mut columns);
262
263 let when = match &schema.soft_delete {
265 Some(SoftDelete::Column(c)) => c.when.as_deref(),
266 Some(SoftDelete::Field(f)) => f.when.as_deref(),
267 None => None,
268 };
269 let root_filters = schema.filters.as_deref().unwrap_or_default();
270 for filter in when.unwrap_or_default().iter().chain(root_filters) {
271 if let Filter::ValueOp(value_op) = filter {
272 columns.push((&schema.table, &value_op.column));
273 }
274 }
275
276 let mut types = HashMap::new();
277 for (table, column) in columns {
278 let key = (table.to_string(), column.to_string());
279 if types.contains_key(&key) {
280 continue;
281 }
282 let sql_type = self.column_type(&schema.db_schema, table, column).await?;
283 types.insert(key, sql_type);
284 }
285 Ok(types)
286 }
287
288 async fn add_key_column_types(
293 &self,
294 schema: &IndexSchema,
295 columns: &[&ColumnName],
296 types: &mut HashMap<(String, String), String>,
297 ) -> Result<()> {
298 for column in columns {
299 let key = (schema.table.to_string(), column.to_string());
300 if types.contains_key(&key) {
301 continue;
302 }
303 let sql_type = self
304 .column_type(&schema.db_schema, &schema.table, column)
305 .await?;
306 types.insert(key, sql_type);
307 }
308 Ok(())
309 }
310}
311
312#[async_trait]
318impl Catalog for PgDocumentBuilder {
319 async fn column(
320 &self,
321 schema: &DatabaseSchema,
322 table: &TableName,
323 column: &ColumnName,
324 ) -> Result<ColumnInfo> {
325 let meta = self.column_meta(schema, table, column).await?;
326 Ok(ColumnInfo {
327 sql_type: meta.sql_type,
328 nullable: meta.nullable,
329 })
330 }
331}
332
333#[async_trait]
334impl DocumentBuilder for PgDocumentBuilder {
335 #[tracing::instrument(
336 name = "pg.resolve",
337 level = "debug",
338 skip_all,
339 fields(table = table.as_ref()),
340 err,
341 )]
342 async fn resolve(&self, table: &TableName, key: &RowKey) -> Result<Vec<DocumentId>> {
343 let mut ids = Vec::new();
344 for (name, schema) in self.spec.indexes() {
345 if schema.table == *table {
346 ids.push(DocumentId {
347 index: name.clone(),
348 key: key.clone(),
349 });
350 continue;
351 }
352
353 let mut paths = Vec::new();
354 let mut prefix = Vec::new();
355 find_paths(&schema.fields, table, &mut prefix, &mut paths);
356 if paths.is_empty() {
357 continue;
358 }
359 let Some(pk_column) = schema.primary_key.clone() else {
360 tracing::warn!(
361 index = %name, table = %table,
362 "cannot reverse-resolve: index has no primary_key",
363 );
364 continue;
365 };
366
367 let mut seen = HashSet::new();
368 for path in &paths {
369 for root in self.resolve_path(schema, table, key, path).await? {
370 if seen.insert(root.clone()) {
371 ids.push(DocumentId {
372 index: name.clone(),
373 key: RowKey(vec![(pk_column.clone(), root)]),
374 });
375 }
376 }
377 }
378 }
379 tracing::trace!(documents = ids.len(), "resolved affected documents");
380 Ok(ids)
381 }
382
383 #[tracing::instrument(
384 name = "pg.build",
385 level = "debug",
386 skip_all,
387 fields(index = id.index.as_ref()),
388 err,
389 )]
390 async fn build(&self, id: &DocumentId) -> Result<Document> {
391 let schema = self
392 .spec
393 .schema(&id.index)
394 .ok_or_else(|| SourceError::Query(format!("unknown index `{}`", id.index)))?;
395
396 let pks = self.relation_pks(schema).await?;
397 let mut col_types = self.filter_column_types(schema).await?;
398 let key_columns: Vec<&ColumnName> = id.key.0.iter().map(|(column, _)| column).collect();
399 self.add_key_column_types(schema, &key_columns, &mut col_types)
400 .await?;
401 let (sql, params) = query::document_query(schema, &id.key.0, &pks, &col_types)?;
402
403 let mut statement = sqlx::query(sql);
404 for param in ¶ms {
405 statement = query::bind_param(statement, param)?;
406 }
407 let row = statement
408 .fetch_optional(&self.pool)
409 .await
410 .map_err(query_err)?;
411
412 match row {
415 None => Ok(Document::Delete { id: id.clone() }),
416 Some(row) => {
417 let document: serde_json::Value = row.try_get("document").map_err(query_err)?;
418 Ok(Document::Upsert {
419 id: id.clone(),
420 body: value::coerce_document(document, &schema.fields),
421 })
422 }
423 }
424 }
425
426 #[tracing::instrument(name = "pg.build_many", level = "debug", skip_all, fields(ids = ids.len()))]
427 async fn build_many(&self, ids: &[DocumentId]) -> Result<Vec<Document>> {
428 let mut by_index: HashMap<&IndexName, Vec<&DocumentId>> = HashMap::new();
429 for id in ids {
430 by_index.entry(&id.index).or_default().push(id);
431 }
432
433 let mut out = Vec::with_capacity(ids.len());
434 for (index_name, group) in by_index {
435 let schema = self
436 .spec
437 .schema(index_name)
438 .ok_or_else(|| SourceError::Query(format!("unknown index `{index_name}`")))?;
439
440 let keyed: Option<Vec<(&schema_core::GenericValue, &DocumentId)>> = group
447 .iter()
448 .map(|id| match id.key.0.as_slice() {
449 [(_, value)] => Some((value, *id)),
450 _ => None,
451 })
452 .collect();
453 let (Some(pk_column), Some(keyed)) = (schema.primary_key.clone(), keyed) else {
454 for id in group {
455 out.push(self.build(id).await?);
456 }
457 continue;
458 };
459
460 let pks = self.relation_pks(schema).await?;
461 let mut col_types = self.filter_column_types(schema).await?;
462 self.add_key_column_types(schema, &[&pk_column], &mut col_types)
463 .await?;
464
465 for chunk in keyed.chunks(BUILD_CHUNK) {
466 let keys: Vec<schema_core::GenericValue> =
467 chunk.iter().map(|(value, _)| (*value).clone()).collect();
468 let (sql, params) =
469 query::documents_query(schema, &pk_column, &keys, &pks, &col_types)?;
470
471 let mut statement = sqlx::query(sql);
472 for param in ¶ms {
473 statement = query::bind_param(statement, param)?;
474 }
475 let rows = statement.fetch_all(&self.pool).await.map_err(query_err)?;
476
477 let mut bodies: HashMap<schema_core::GenericValue, schema_core::GenericValue> =
481 HashMap::with_capacity(rows.len());
482 for row in &rows {
483 let key = value::first_column_to_generic(row);
484 let document: serde_json::Value = row.try_get("document").map_err(query_err)?;
485 bodies.insert(key, value::coerce_document(document, &schema.fields));
486 }
487
488 for (value, id) in chunk {
492 let document = match bodies.remove(*value) {
493 Some(body) => Document::Upsert {
494 id: (*id).clone(),
495 body,
496 },
497 None => Document::Delete { id: (*id).clone() },
498 };
499 out.push(document);
500 }
501 }
502 }
503 Ok(out)
504 }
505
506 fn backfill_scopes(&self) -> Vec<IndexScope> {
507 self.spec
510 .indexes()
511 .map(|(name, schema)| IndexScope {
512 index: name.clone(),
513 root: SnapshotTable {
514 db_schema: schema.db_schema.clone(),
515 table: schema.table.clone(),
516 },
517 })
518 .collect()
519 }
520
521 async fn index_mappings(&self) -> Result<Vec<IndexMapping>> {
522 Ok(self.spec.index_mappings())
525 }
526}
527
528pub(super) fn query_err(error: sqlx::Error) -> SourceError {
529 SourceError::Query(error.to_string())
530}
531
/// Lists the primary-key column names (`name`) of the relation bound as `$1`.
///
/// `$1` is cast with `::regclass`. NOTE(review): there is no `ORDER BY`, so for a composite key
/// the rows come back in whatever order Postgres produces, not necessarily key order.
pub(crate) const PRIMARY_KEY_SQL: &str = concat!(
    "SELECT a.attname AS name ",
    "FROM pg_index i ",
    "JOIN pg_attribute a ON a.attrelid = i.indrelid AND a.attnum = ANY(i.indkey) ",
    "WHERE i.indrelid = $1::regclass AND i.indisprimary",
);
538
539pub(crate) async fn primary_key_column_names(
543 pool: &PgPool,
544 qualified: String,
545) -> Result<Vec<String>> {
546 let rows = sqlx::query(PRIMARY_KEY_SQL)
547 .bind(qualified)
548 .fetch_all(pool)
549 .await
550 .map_err(query_err)?;
551 let mut names = Vec::with_capacity(rows.len());
552 for row in &rows {
553 names.push(row.try_get::<String, _>("name").map_err(query_err)?);
554 }
555 Ok(names)
556}