mod fields;
mod query;
mod resolve;
pub(crate) mod value;
use std::collections::{HashMap, HashSet};
use std::sync::{Arc, Mutex, PoisonError};
use async_trait::async_trait;
use schema_core::{
ColumnName, DatabaseSchema, Filter, IndexMapping, IndexName, IndexSchema, SoftDelete, TableName,
};
use sources_core::document::{Document, DocumentBuilder, DocumentId, IndexScope};
use sources_core::{Catalog, ColumnInfo, Result, RowKey, SnapshotTable, SourceError, SourceSpec};
use sqlx::{PgPool, Row};
use fields::find_paths;
/// Cache of column metadata, keyed by the `(schema, table, column)` names rendered as strings.
type ColTypeCache = HashMap<(String, String, String), ColumnMeta>;
/// Maximum number of primary-key values sent in one batched query by `build_many`.
const BUILD_CHUNK: usize = 512;
/// Catalog facts about a single column, as read from `pg_attribute` by `column_meta`.
#[derive(Debug, Clone)]
struct ColumnMeta {
/// Type name as produced by Postgres `format_type(atttypid, atttypmod)`.
sql_type: String,
/// `true` when the column's `attnotnull` flag is not set.
nullable: bool,
}
/// Document builder backed by a Postgres connection pool.
///
/// Both caches sit behind `Arc<Mutex<..>>`, so clones of the builder share them.
#[derive(Debug, Clone)]
pub struct PgDocumentBuilder {
/// Connection pool every catalog and document query runs against.
pool: PgPool,
/// Source specification; provides the index schemas consulted by the builder.
spec: Arc<SourceSpec>,
/// Single-column primary keys already looked up, keyed by `(schema, table)` strings.
pk_cache: Arc<Mutex<HashMap<(String, String), ColumnName>>>,
/// Column metadata already looked up, keyed by `(schema, table, column)` strings.
col_type_cache: Arc<Mutex<ColTypeCache>>,
}
impl PgDocumentBuilder {
pub fn new(pool: PgPool, spec: Arc<SourceSpec>) -> Self {
Self {
pool,
spec,
pk_cache: Arc::new(Mutex::new(HashMap::new())),
col_type_cache: Arc::new(Mutex::new(HashMap::new())),
}
}
#[tracing::instrument(name = "pg.connect", skip_all, err)]
pub async fn connect(connection_url: &str, spec: Arc<SourceSpec>) -> Result<Self> {
let pool = sqlx::postgres::PgPoolOptions::new()
.connect(connection_url)
.await
.map_err(|e| SourceError::Connection(e.to_string()))?;
tracing::info!(indexes = spec.indexes().count(), "connected to Postgres");
Ok(Self::new(pool, spec))
}
/// Returns the single primary-key column of `schema.table`, using the in-memory cache when
/// the table has been looked up before.
///
/// # Errors
/// - `SourceError::Query` when the table has no primary key, or the lookup itself fails.
/// - `SourceError::Unsupported` when the primary key spans more than one column.
pub(super) async fn table_primary_key(
    &self,
    schema: &DatabaseSchema,
    table: &TableName,
) -> Result<ColumnName> {
    let cache_key = (schema.to_string(), table.to_string());
    // The guard is a temporary of this statement, so the lock is released before any `.await`.
    let cached = self
        .pk_cache
        .lock()
        .unwrap_or_else(PoisonError::into_inner)
        .get(&cache_key)
        .cloned();
    if let Some(column) = cached {
        return Ok(column);
    }

    // Peek at the first two key columns: that is enough to tell "none", "one" and "many" apart.
    let mut key_columns = self.fetch_primary_key(schema, table).await?.into_iter();
    let column = match (key_columns.next(), key_columns.next()) {
        (Some(only), None) => only,
        (None, _) => {
            return Err(SourceError::Query(format!(
                "table `{schema}.{table}` has no primary key"
            )));
        }
        (Some(_), Some(_)) => {
            return Err(SourceError::Unsupported(format!(
                "table `{schema}.{table}` has a composite primary key; relations require a single-column key"
            )));
        }
    };

    let mut cache = self.pk_cache.lock().unwrap_or_else(PoisonError::into_inner);
    cache.insert(cache_key, column.clone());
    drop(cache);
    Ok(column)
}
/// Reads the primary-key column names of `schema.table` from the catalog and validates each
/// one as a `ColumnName`.
///
/// # Errors
/// Returns `SourceError::Query` when the catalog query fails or a returned name is rejected
/// by `ColumnName::try_new`; the first rejected name aborts the conversion.
async fn fetch_primary_key(
    &self,
    schema: &DatabaseSchema,
    table: &TableName,
) -> Result<Vec<ColumnName>> {
    let qualified = format!("{schema}.{table}");
    let raw_names = primary_key_column_names(&self.pool, qualified).await?;
    let mut columns = Vec::with_capacity(raw_names.len());
    for raw in raw_names {
        match ColumnName::try_new(raw) {
            Ok(column) => columns.push(column),
            Err(e) => {
                return Err(SourceError::Query(format!(
                    "invalid primary key column: {e}"
                )));
            }
        }
    }
    Ok(columns)
}
/// Resolves the primary-key column of every table collected by
/// `fields::collect_relation_tables` for the index's fields.
///
/// The result maps the table name (as a string) to its primary-key column. Each distinct
/// table is looked up once; lookups go through `table_primary_key` and so share its cache.
///
/// # Errors
/// Propagates the first error returned by `table_primary_key`.
async fn relation_pks(
    &self,
    schema: &schema_core::IndexSchema,
) -> Result<HashMap<String, ColumnName>> {
    let mut relation_tables = Vec::new();
    fields::collect_relation_tables(&schema.fields, &mut relation_tables);

    // Deduplicate so a table referenced by several relations is only resolved once.
    let distinct: HashSet<&TableName> = relation_tables.iter().collect();
    let mut resolved = HashMap::new();
    for table in distinct {
        let table_name = table.to_string();
        let pk_column = self.table_primary_key(&schema.db_schema, table).await?;
        resolved.insert(table_name, pk_column);
    }
    Ok(resolved)
}
pub(super) async fn column_type(
&self,
schema: &DatabaseSchema,
table: &TableName,
column: &ColumnName,
) -> Result<String> {
Ok(self.column_meta(schema, table, column).await?.sql_type)
}
/// Looks up the SQL type and nullability of `schema.table.column`, caching the answer.
///
/// A cache hit returns immediately. Otherwise the column is read from `pg_attribute`
/// (skipping system and dropped columns) and the result is stored before returning.
///
/// # Errors
/// Returns `SourceError::Query` when the catalog query fails, a result column cannot be
/// decoded, or no matching column exists.
async fn column_meta(
    &self,
    schema: &DatabaseSchema,
    table: &TableName,
    column: &ColumnName,
) -> Result<ColumnMeta> {
    let cache_key = (schema.to_string(), table.to_string(), column.to_string());
    // The guard lives only for this statement, so the lock is not held across the query.
    let cached = self
        .col_type_cache
        .lock()
        .unwrap_or_else(PoisonError::into_inner)
        .get(&cache_key)
        .cloned();
    if let Some(meta) = cached {
        return Ok(meta);
    }

    let sql = "SELECT format_type(a.atttypid, a.atttypmod) AS sql_type, a.attnotnull AS not_null \
               FROM pg_attribute a \
               WHERE a.attrelid = $1::regclass AND a.attname = $2 \
               AND a.attnum > 0 AND NOT a.attisdropped";
    // NOTE(review): the qualified name is passed unquoted to `::regclass`; whether the
    // `Display` impls of the schema/table types quote identifiers is not visible here.
    let qualified = format!("{schema}.{table}");
    let maybe_row = sqlx::query(sql)
        .bind(qualified)
        .bind(column.as_ref().to_owned())
        .fetch_optional(&self.pool)
        .await
        .map_err(query_err)?;
    let Some(row) = maybe_row else {
        return Err(SourceError::Query(format!(
            "references unknown column `{schema}.{table}.{column}`"
        )));
    };

    let sql_type: String = row.try_get("sql_type").map_err(query_err)?;
    let not_null: bool = row.try_get("not_null").map_err(query_err)?;
    let meta = ColumnMeta {
        sql_type,
        nullable: !not_null,
    };

    let mut cache = self
        .col_type_cache
        .lock()
        .unwrap_or_else(PoisonError::into_inner);
    cache.insert(cache_key, meta.clone());
    drop(cache);
    Ok(meta)
}
/// Resolves the SQL type of every column referenced by the index's filters.
///
/// Columns come from three places: `fields::collect_filter_columns` over the index fields,
/// the soft-delete `when` filters, and the root-level `filters`. From the latter two only
/// `Filter::ValueOp` entries contribute, and they are attributed to the index's root table.
/// The result is keyed by `(table, column)` names; each pair is resolved once.
///
/// # Errors
/// Propagates the first error returned by `column_type`.
async fn filter_column_types(
    &self,
    schema: &IndexSchema,
) -> Result<HashMap<(String, String), String>> {
    let mut referenced = Vec::new();
    fields::collect_filter_columns(&schema.fields, &mut referenced);

    let soft_delete_filters = match &schema.soft_delete {
        None => None,
        Some(SoftDelete::Column(by_column)) => by_column.when.as_deref(),
        Some(SoftDelete::Field(by_field)) => by_field.when.as_deref(),
    };
    let root_filters = schema.filters.as_deref().unwrap_or_default();
    let root_level = soft_delete_filters
        .unwrap_or_default()
        .iter()
        .chain(root_filters);
    for filter in root_level {
        let Filter::ValueOp(value_op) = filter else {
            continue;
        };
        referenced.push((&schema.table, &value_op.column));
    }

    let mut resolved = HashMap::new();
    for (table, column) in referenced {
        let key = (table.to_string(), column.to_string());
        // The same column may be referenced by several filters; look it up only once.
        if !resolved.contains_key(&key) {
            let sql_type = self.column_type(&schema.db_schema, table, column).await?;
            resolved.insert(key, sql_type);
        }
    }
    Ok(resolved)
}
}
#[async_trait]
impl Catalog for PgDocumentBuilder {
async fn column(
&self,
schema: &DatabaseSchema,
table: &TableName,
column: &ColumnName,
) -> Result<ColumnInfo> {
let meta = self.column_meta(schema, table, column).await?;
Ok(ColumnInfo {
sql_type: meta.sql_type,
nullable: meta.nullable,
})
}
}
#[async_trait]
impl DocumentBuilder for PgDocumentBuilder {
/// Lists the documents affected by row `key` of `table`, across every index in the spec.
///
/// For an index rooted at `table`, the document id is `key` itself. For any other index,
/// `find_paths` locates where `table` appears in the index's fields, and each path is
/// resolved to root values via `resolve_path`. Those become document ids keyed on the
/// index's `primary_key`, deduplicated per index. Indexes without a `primary_key` are
/// skipped with a warning.
///
/// # Errors
/// Propagates the first error returned by `resolve_path`.
#[tracing::instrument(
    name = "pg.resolve",
    level = "debug",
    skip_all,
    fields(table = table.as_ref()),
    err,
)]
async fn resolve(&self, table: &TableName, key: &RowKey) -> Result<Vec<DocumentId>> {
    let mut affected = Vec::new();
    for (index_name, index_schema) in self.spec.indexes() {
        // Direct hit: the changed row is the root row of this index.
        if index_schema.table == *table {
            let direct = DocumentId {
                index: index_name.clone(),
                key: key.clone(),
            };
            affected.push(direct);
            continue;
        }

        let mut relation_paths = Vec::new();
        let mut scratch = Vec::new();
        find_paths(&index_schema.fields, table, &mut scratch, &mut relation_paths);
        if relation_paths.is_empty() {
            continue;
        }

        let Some(pk_column) = index_schema.primary_key.clone() else {
            tracing::warn!(
                index = %index_name, table = %table,
                "cannot reverse-resolve: index has no primary_key",
            );
            continue;
        };

        // Several paths can lead back to the same root row; emit each root once per index.
        let mut seen = HashSet::new();
        for path in &relation_paths {
            let roots = self.resolve_path(index_schema, table, key, path).await?;
            for root in roots {
                if !seen.insert(root.clone()) {
                    continue;
                }
                affected.push(DocumentId {
                    index: index_name.clone(),
                    key: RowKey(vec![(pk_column.clone(), root)]),
                });
            }
        }
    }
    tracing::trace!(documents = affected.len(), "resolved affected documents");
    Ok(affected)
}
/// Builds the document identified by `id` with a single query.
///
/// The SQL and its parameters come from `query::document_query`, which is given the index
/// schema, the id's key columns, the relation primary keys and the filter column types.
///
/// Returns `Document::Delete` when the query yields no row, otherwise `Document::Upsert`
/// whose body is the row's `document` column converted with `value::json_to_generic`.
///
/// # Errors
/// Returns `SourceError::Query` for an unknown index, a failed query, or an undecodable
/// `document` column; also propagates errors from `relation_pks`, `filter_column_types`,
/// `query::document_query` and `query::bind_param`.
#[tracing::instrument(
    name = "pg.build",
    level = "debug",
    skip_all,
    fields(index = id.index.as_ref()),
    err,
)]
async fn build(&self, id: &DocumentId) -> Result<Document> {
    let schema = self
        .spec
        .schema(&id.index)
        .ok_or_else(|| SourceError::Query(format!("unknown index `{}`", id.index)))?;
    let pks = self.relation_pks(schema).await?;
    let col_types = self.filter_column_types(schema).await?;
    let (sql, params) = query::document_query(schema, &id.key.0, &pks, &col_types)?;
    let mut statement = sqlx::query(sql);
    // Bind positionally, in the order the query builder produced the parameters.
    for param in &params {
        statement = query::bind_param(statement, param)?;
    }
    let row = statement
        .fetch_optional(&self.pool)
        .await
        .map_err(query_err)?;
    match row {
        // No row matched the key: the document should no longer exist.
        None => Ok(Document::Delete { id: id.clone() }),
        Some(row) => {
            let document: serde_json::Value = row.try_get("document").map_err(query_err)?;
            Ok(Document::Upsert {
                id: id.clone(),
                body: value::json_to_generic(document),
            })
        }
    }
}
#[tracing::instrument(name = "pg.build_many", level = "debug", skip_all, fields(ids = ids.len()))]
async fn build_many(&self, ids: &[DocumentId]) -> Result<Vec<Document>> {
let mut by_index: HashMap<&IndexName, Vec<&DocumentId>> = HashMap::new();
for id in ids {
by_index.entry(&id.index).or_default().push(id);
}
let mut out = Vec::with_capacity(ids.len());
for (index_name, group) in by_index {
let schema = self
.spec
.schema(index_name)
.ok_or_else(|| SourceError::Query(format!("unknown index `{index_name}`")))?;
let keyed: Option<Vec<(&schema_core::GenericValue, &DocumentId)>> = group
.iter()
.map(|id| match id.key.0.as_slice() {
[(_, value)] => Some((value, *id)),
_ => None,
})
.collect();
let (Some(pk_column), Some(keyed)) = (schema.primary_key.clone(), keyed) else {
for id in group {
out.push(self.build(id).await?);
}
continue;
};
let pks = self.relation_pks(schema).await?;
let col_types = self.filter_column_types(schema).await?;
for chunk in keyed.chunks(BUILD_CHUNK) {
let keys: Vec<schema_core::GenericValue> =
chunk.iter().map(|(value, _)| (*value).clone()).collect();
let (sql, params) =
query::documents_query(schema, &pk_column, &keys, &pks, &col_types)?;
let mut statement = sqlx::query(sql);
for param in ¶ms {
statement = query::bind_param(statement, param)?;
}
let rows = statement.fetch_all(&self.pool).await.map_err(query_err)?;
let mut bodies: HashMap<schema_core::GenericValue, schema_core::GenericValue> =
HashMap::with_capacity(rows.len());
for row in &rows {
let key = value::first_column_to_generic(row);
let document: serde_json::Value = row.try_get("document").map_err(query_err)?;
bodies.insert(key, value::json_to_generic(document));
}
for (value, id) in chunk {
let document = match bodies.remove(*value) {
Some(body) => Document::Upsert {
id: (*id).clone(),
body,
},
None => Document::Delete { id: (*id).clone() },
};
out.push(document);
}
}
}
Ok(out)
}
/// Returns one `IndexScope` per index in the spec, each rooted at that index's own
/// `db_schema` and `table`.
fn backfill_scopes(&self) -> Vec<IndexScope> {
    let mut scopes = Vec::new();
    for (name, schema) in self.spec.indexes() {
        let root = SnapshotTable {
            db_schema: schema.db_schema.clone(),
            table: schema.table.clone(),
        };
        scopes.push(IndexScope {
            index: name.clone(),
            root,
        });
    }
    scopes
}
/// Returns the index mappings as reported by the source spec; this implementation never
/// fails.
async fn index_mappings(&self) -> Result<Vec<IndexMapping>> {
    let mappings = self.spec.index_mappings();
    Ok(mappings)
}
}
/// Converts an sqlx error into `SourceError::Query`, keeping only its display message.
pub(super) fn query_err(error: sqlx::Error) -> SourceError {
    let message = error.to_string();
    SourceError::Query(message)
}
/// Catalog query returning the `name` of each column in a table's primary-key index.
///
/// `$1` is the table name, cast to `regclass`. One row is returned per key column.
/// NOTE(review): there is no `ORDER BY`, so for composite keys the row order is whatever
/// Postgres returns — callers that need key order should not rely on it.
pub(crate) const PRIMARY_KEY_SQL: &str = "SELECT a.attname AS name \
FROM pg_index i \
JOIN pg_attribute a ON a.attrelid = i.indrelid AND a.attnum = ANY(i.indkey) \
WHERE i.indrelid = $1::regclass AND i.indisprimary";
/// Runs `PRIMARY_KEY_SQL` for the table named by `qualified` and returns the primary-key
/// column names, one per returned row.
///
/// An empty vector means the query returned no rows.
///
/// # Errors
/// Returns `SourceError::Query` when the query fails or a `name` value cannot be decoded as
/// a string.
pub(crate) async fn primary_key_column_names(
    pool: &PgPool,
    qualified: String,
) -> Result<Vec<String>> {
    let rows = sqlx::query(PRIMARY_KEY_SQL)
        .bind(qualified)
        .fetch_all(pool)
        .await
        .map_err(query_err)?;
    rows.iter()
        .map(|row| row.try_get::<String, _>("name").map_err(query_err))
        .collect()
}