use std::collections::BTreeSet;
use std::sync::Arc;
use async_trait::async_trait;
use futures::stream::BoxStream;
use pgwire_replication::{ReplicationClient, ReplicationConfig};
use sources_core::cdc::{AckSink, Change, ChangeCapture};
use sources_core::{
CaptureProvisioning, CoverageReport, QualifiedTable, Result, SnapshotTable, SourceError,
};
use sqlx::{PgPool, Row};
use tokio::sync::OnceCell;
use super::ack::{AckShared, WalAckSink};
use super::{backfill, publication, stream};
/// Change capture backed by a replication slot, with an auxiliary `sqlx` pool for
/// catalog queries (slot lookup, publication coverage, lag).
///
/// Cloning is cheap with respect to the pool: `admin_pool` sits behind an `Arc`, so
/// every clone shares the same lazily-initialised cell.
#[derive(Debug, Clone)]
pub struct WalChangeCapture {
/// Replication settings. This file reads `slot`, `publication` and `start_lsn` from
/// it and hands a clone to `ReplicationClient::connect` in `live`.
config: ReplicationConfig,
/// Connection string used to open the admin pool and passed to `backfill::snapshot`.
// NOTE(review): the derived `Debug` prints this value verbatim — confirm it never
// embeds credentials, or redact it.
connection_url: String,
/// Pool for catalog queries, created on first use by `admin_pool()`.
admin_pool: Arc<OnceCell<PgPool>>,
/// Tables whose publication coverage `live` checks before streaming. Empty unless
/// set through `with_publication_management`.
required_tables: BTreeSet<QualifiedTable>,
/// Passed as the `manage` flag to `ensure_coverage` from `live`. `false` unless set
/// through `with_publication_management`.
manage_publication: bool,
}
impl WalChangeCapture {
    /// Creates a capture for the given replication settings.
    ///
    /// `connection_url` is stored as the connection string for the admin pool and for
    /// snapshots; no connection is opened here. The required-table set starts empty and
    /// publication management starts disabled — see
    /// [`with_publication_management`](Self::with_publication_management).
    pub fn new(config: ReplicationConfig, connection_url: impl Into<String>) -> Self {
        Self {
            config,
            connection_url: connection_url.into(),
            admin_pool: Arc::new(OnceCell::new()),
            required_tables: BTreeSet::new(),
            manage_publication: false,
        }
    }

    /// Sets the tables whose publication coverage is checked by `live`, and whether
    /// that check is allowed to manage the publication (`manage`).
    ///
    /// Consumes and returns `self` so it can be chained after [`new`](Self::new).
    pub fn with_publication_management(
        mut self,
        required: BTreeSet<QualifiedTable>,
        manage: bool,
    ) -> Self {
        self.required_tables = required;
        self.manage_publication = manage;
        self
    }

    /// Returns the shared admin pool, connecting it (at most 2 connections) on first use.
    ///
    /// # Errors
    ///
    /// Returns [`SourceError::Connection`] when the pool cannot be connected.
    async fn admin_pool(&self) -> Result<&PgPool> {
        self.admin_pool
            .get_or_try_init(|| async {
                sqlx::postgres::PgPoolOptions::new()
                    .max_connections(2)
                    .connect(&self.connection_url)
                    .await
                    .map_err(|e| SourceError::Connection(e.to_string()))
            })
            .await
    }

    /// Makes sure the configured replication slot exists and uses the `pgoutput` plugin,
    /// creating it as a logical `pgoutput` slot when no slot with that name is found.
    ///
    /// # Errors
    ///
    /// * [`SourceError::Query`] when the catalog lookup or decoding its result fails.
    /// * [`SourceError::Connection`] when a slot with the configured name exists but is
    ///   a physical slot or uses another plugin, or when creating the slot fails.
    async fn ensure_slot(&self) -> Result<()> {
        let pool = self.admin_pool().await?;
        let row = sqlx::query("SELECT plugin FROM pg_replication_slots WHERE slot_name = $1")
            .bind(&self.config.slot)
            .fetch_optional(pool)
            .await
            .map_err(|e| SourceError::Query(e.to_string()))?;
        match row {
            Some(row) => {
                // `plugin` is NULL for physical slots, so it must be decoded as optional;
                // decoding straight into `String` would surface an opaque decode error
                // instead of telling the operator what is wrong with the slot.
                let plugin: Option<String> = row
                    .try_get("plugin")
                    .map_err(|e| SourceError::Query(e.to_string()))?;
                match plugin.as_deref() {
                    Some("pgoutput") => {
                        tracing::debug!(
                            slot = %self.config.slot,
                            "replication slot already exists"
                        );
                    }
                    Some(other) => {
                        return Err(SourceError::Connection(format!(
                            "replication slot '{}' exists but uses plugin '{}', expected 'pgoutput'",
                            self.config.slot, other,
                        )));
                    }
                    None => {
                        return Err(SourceError::Connection(format!(
                            "replication slot '{}' exists but is a physical slot, \
                             expected a logical slot using 'pgoutput'",
                            self.config.slot,
                        )));
                    }
                }
            }
            None => {
                sqlx::query("SELECT pg_create_logical_replication_slot($1, 'pgoutput')")
                    .bind(&self.config.slot)
                    .execute(pool)
                    .await
                    .map_err(|e| {
                        SourceError::Connection(format!(
                            "failed to create replication slot '{}': {e}",
                            self.config.slot,
                        ))
                    })?;
                tracing::info!(slot = %self.config.slot, "created replication slot");
            }
        }
        Ok(())
    }
}
#[async_trait]
impl ChangeCapture for WalChangeCapture {
    /// Opens the live change stream.
    ///
    /// Steps, in order: ensure the replication slot exists, run the publication
    /// coverage check for the configured required tables, connect a replication client
    /// with a clone of the config, then hand the client and a fresh ack state (seeded
    /// with the configured start LSN) to `stream::build`.
    ///
    /// # Errors
    ///
    /// Propagates failures from the slot and coverage checks, and returns
    /// [`SourceError::Connection`] when the replication client cannot connect.
    #[tracing::instrument(name = "wal.live", skip_all, err)]
    async fn live(&self) -> Result<BoxStream<'static, Result<Change>>> {
        self.ensure_slot().await?;
        self.ensure_coverage(&self.required_tables, self.manage_publication)
            .await?;
        let client = ReplicationClient::connect(self.config.clone())
            .await
            .map_err(|e| SourceError::Connection(e.to_string()))?;
        let ack = Arc::new(AckShared::new(self.config.start_lsn.as_u64()));
        let sink: Arc<dyn AckSink> = Arc::new(WalAckSink::new(Arc::clone(&ack)));
        tracing::info!(
            start_lsn = self.config.start_lsn.as_u64(),
            "opened replication stream"
        );
        Ok(stream::build(client, ack, sink))
    }

    /// Starts a snapshot of `tables` by delegating to `backfill::snapshot` with the
    /// stored connection string.
    #[tracing::instrument(name = "wal.snapshot", skip_all, fields(tables = tables.len()), err)]
    async fn snapshot(
        &self,
        tables: &[SnapshotTable],
    ) -> Result<BoxStream<'static, Result<Change>>> {
        tracing::info!(tables = tables.len(), "starting snapshot");
        backfill::snapshot(&self.connection_url, tables).await
    }

    /// Returns how many bytes of WAL lie between the server's current write position
    /// and the slot's `confirmed_flush_lsn`.
    ///
    /// Returns `Ok(None)` when no slot with the configured name exists, or when the
    /// server reports no value for the difference (NULL — e.g. the slot has no
    /// `confirmed_flush_lsn`). A negative difference is reported as `0`.
    ///
    /// # Errors
    ///
    /// Returns [`SourceError::Query`] when the query or decoding its result fails.
    #[tracing::instrument(name = "wal.lag", skip_all, err)]
    async fn lag(&self) -> Result<Option<u64>> {
        let pool = self.admin_pool().await?;
        let row = sqlx::query(
            "SELECT pg_wal_lsn_diff(pg_current_wal_lsn(), confirmed_flush_lsn)::bigint AS lag \
             FROM pg_replication_slots WHERE slot_name = $1",
        )
        .bind(&self.config.slot)
        .fetch_optional(pool)
        .await
        .map_err(|e| SourceError::Query(e.to_string()))?;
        let lag = match row {
            Some(row) => {
                // The diff is NULL whenever `confirmed_flush_lsn` is NULL; decode it as
                // optional so that case reads as "lag unknown" rather than a query error.
                let bytes: Option<i64> = row
                    .try_get("lag")
                    .map_err(|e| SourceError::Query(e.to_string()))?;
                // Clamp before casting so the `as` conversion can never wrap.
                bytes.map(|b| b.max(0) as u64)
            }
            None => None,
        };
        Ok(lag)
    }
}
#[async_trait]
impl CaptureProvisioning for WalChangeCapture {
    /// Produces a coverage report for `required` against the configured publication by
    /// delegating to `publication::inspect_publication` on the admin pool.
    async fn inspect_coverage(
        &self,
        required: &BTreeSet<QualifiedTable>,
    ) -> Result<CoverageReport> {
        let admin = self.admin_pool().await?;
        publication::inspect_publication(admin, &self.config.publication, required).await
    }

    /// Inspects publication coverage for `required` and, when tables are missing,
    /// either provisions them or logs a warning explaining why it will not.
    ///
    /// Provisioning (`publication::apply_publication`) happens only when `manage` is
    /// `true` and the report says the publication is manageable. Otherwise a warning
    /// with the reason and the report's remediation text is emitted and the call still
    /// succeeds.
    ///
    /// The returned report is always the one obtained *before* any provisioning.
    // NOTE(review): after a successful apply the returned report still lists the tables
    // as missing — confirm callers expect the pre-apply view.
    #[tracing::instrument(name = "wal.ensure_coverage", skip_all, err)]
    async fn ensure_coverage(
        &self,
        required: &BTreeSet<QualifiedTable>,
        manage: bool,
    ) -> Result<CoverageReport> {
        let admin = self.admin_pool().await?;
        let report =
            publication::inspect_publication(admin, &self.config.publication, required).await?;

        // Nothing to do when every required table is already covered.
        if report.satisfied {
            tracing::debug!(
                publication = %self.config.publication,
                "publication covers every required table",
            );
            return Ok(report);
        }

        // Human-readable list of the uncovered tables, shared by both log messages.
        let names: Vec<String> = report.missing.iter().map(|t| t.to_string()).collect();
        let missing = names.join(", ");

        let may_provision = manage && report.manageable;
        if !may_provision {
            let reason = if manage {
                report.blockers.join("; ")
            } else {
                "automatic publication management is disabled".to_owned()
            };
            tracing::warn!(
                publication = %self.config.publication,
                missing = %missing,
                reason = %reason,
                remediation = %report.remediation.join(" "),
                "publication is missing tables and flusso will not create them automatically; \
                 run the printed SQL to stream every table (changes to missing tables are dropped)",
            );
            return Ok(report);
        }

        publication::apply_publication(admin, &self.config.publication, &report.missing).await?;
        tracing::info!(
            publication = %self.config.publication,
            tables = %missing,
            "provisioned publication for missing tables",
        );
        Ok(report)
    }
}