use std::sync::Arc;
use async_trait::async_trait;
use futures::stream::BoxStream;
use pgwire_replication::{ReplicationClient, ReplicationConfig};
use sources_core::cdc::{AckSink, Change, ChangeCapture};
use sources_core::{Result, SnapshotTable, SourceError};
use sqlx::{PgPool, Row};
use tokio::sync::OnceCell;
use super::ack::{AckShared, WalAckSink};
use super::{backfill, stream};
/// Change-data-capture source backed by PostgreSQL logical replication.
///
/// Cloning shares the lazily created admin pool, because it lives behind an `Arc`.
#[derive(Debug, Clone)]
pub struct WalChangeCapture {
    /// Replication settings used to open the replication stream (slot name, start LSN, ...).
    config: ReplicationConfig,
    /// Regular (non-replication) connection string. It is used for the admin pool
    /// and is passed to the snapshot backfill.
    connection_url: String,
    /// Admin connection pool. It is initialised on first use and shared by all clones.
    admin_pool: Arc<OnceCell<PgPool>>,
}
impl WalChangeCapture {
    /// Creates a WAL change-capture source.
    ///
    /// - `config` drives the replication connection (slot name, start LSN).
    /// - `connection_url` is a regular Postgres connection string. It is used for
    ///   slot management, lag queries and snapshots.
    ///
    /// No connection is opened here; the admin pool is created on first use.
    pub fn new(config: ReplicationConfig, connection_url: impl Into<String>) -> Self {
        Self {
            config,
            connection_url: connection_url.into(),
            admin_pool: Arc::new(OnceCell::new()),
        }
    }

    /// Returns the admin connection pool, connecting on the first call.
    ///
    /// The pool is capped at two connections and is shared with every clone of
    /// this value.
    ///
    /// # Errors
    /// Returns `SourceError::Connection` if the pool cannot connect. The cell stays
    /// empty in that case, so a later call will try again.
    async fn admin_pool(&self) -> Result<&PgPool> {
        self.admin_pool
            .get_or_try_init(|| async {
                sqlx::postgres::PgPoolOptions::new()
                    .max_connections(2)
                    .connect(&self.connection_url)
                    .await
                    .map_err(|e| SourceError::Connection(e.to_string()))
            })
            .await
    }

    /// Ensures the configured replication slot exists and can be used with `pgoutput`.
    ///
    /// - A missing slot is created with the `pgoutput` plugin.
    /// - An existing slot must be a logical slot using `pgoutput`. A physical slot,
    ///   or a slot bound to another plugin, is rejected rather than reused.
    ///
    /// # Errors
    /// - `SourceError::Query` if the slot lookup or row decoding fails.
    /// - `SourceError::Connection` if the admin pool cannot connect, if the existing
    ///   slot is unsuitable, or if creating the slot fails.
    async fn ensure_slot(&self) -> Result<()> {
        let pool = self.admin_pool().await?;
        let slot = &self.config.slot;

        let row = sqlx::query("SELECT plugin FROM pg_replication_slots WHERE slot_name = $1")
            .bind(slot)
            .fetch_optional(pool)
            .await
            .map_err(|e| SourceError::Query(e.to_string()))?;

        let Some(row) = row else {
            sqlx::query("SELECT pg_create_logical_replication_slot($1, 'pgoutput')")
                .bind(slot)
                .execute(pool)
                .await
                .map_err(|e| {
                    SourceError::Connection(format!(
                        "failed to create replication slot '{}': {e}",
                        slot,
                    ))
                })?;
            tracing::info!(slot = %slot, "created replication slot");
            return Ok(());
        };

        // `plugin` is NULL for physical slots. Decoding it as a plain `String`
        // would fail with an opaque "unexpected null" error instead of explaining
        // that the slot is of the wrong kind.
        let plugin: Option<String> = row
            .try_get("plugin")
            .map_err(|e| SourceError::Query(e.to_string()))?;

        match plugin.as_deref() {
            Some("pgoutput") => {
                tracing::debug!(slot = %slot, "replication slot already exists");
                Ok(())
            }
            Some(other) => Err(SourceError::Connection(format!(
                "replication slot '{}' exists but uses plugin '{}', expected 'pgoutput'",
                slot, other,
            ))),
            None => Err(SourceError::Connection(format!(
                "replication slot '{}' exists but is a physical slot, expected a logical 'pgoutput' slot",
                slot,
            ))),
        }
    }
}
#[async_trait]
impl ChangeCapture for WalChangeCapture {
    /// Opens the live replication stream.
    ///
    /// The method:
    /// 1. Makes sure the replication slot exists.
    /// 2. Connects a replication client with the configured settings.
    /// 3. Builds the change stream.
    ///
    /// The stream and the returned ack sink share one `AckShared`, seeded with the
    /// configured start LSN. (Presumably this is so acknowledgements made through
    /// the sink are visible to the stream; confirm in `stream::build`.)
    ///
    /// # Errors
    /// Propagates slot-setup errors. Returns `SourceError::Connection` if the
    /// replication connection cannot be established.
    #[tracing::instrument(name = "wal.live", skip_all, err)]
    async fn live(&self) -> Result<BoxStream<'static, Result<Change>>> {
        self.ensure_slot().await?;
        let client = ReplicationClient::connect(self.config.clone())
            .await
            .map_err(|e| SourceError::Connection(e.to_string()))?;
        let ack = Arc::new(AckShared::new(self.config.start_lsn.as_u64()));
        let sink: Arc<dyn AckSink> = Arc::new(WalAckSink::new(Arc::clone(&ack)));
        tracing::info!(
            start_lsn = self.config.start_lsn.as_u64(),
            "opened replication stream"
        );
        Ok(stream::build(client, ack, sink))
    }

    /// Snapshots `tables` by delegating to `backfill::snapshot`.
    ///
    /// The snapshot uses the regular connection URL, not the replication
    /// connection.
    #[tracing::instrument(name = "wal.snapshot", skip_all, fields(tables = tables.len()), err)]
    async fn snapshot(
        &self,
        tables: &[SnapshotTable],
    ) -> Result<BoxStream<'static, Result<Change>>> {
        tracing::info!(tables = tables.len(), "starting snapshot");
        backfill::snapshot(&self.connection_url, tables).await
    }

    /// Reports how many bytes of WAL the slot's `confirmed_flush_lsn` trails the
    /// server's WAL position.
    ///
    /// The reference position depends on the server role:
    /// - On a primary it is `pg_current_wal_lsn()`.
    /// - On a standby (`pg_is_in_recovery()`) it is `pg_last_wal_replay_lsn()`,
    ///   because `pg_current_wal_lsn()` raises an error during recovery.
    ///
    /// Returns `Ok(None)` when:
    /// - the slot does not exist, or
    /// - the lag cannot be computed because one side is NULL (for example,
    ///   `confirmed_flush_lsn` is NULL for physical slots).
    ///
    /// A negative difference is clamped to zero.
    ///
    /// # Errors
    /// - `SourceError::Connection` if the admin pool cannot connect.
    /// - `SourceError::Query` if the query or decoding fails.
    #[tracing::instrument(name = "wal.lag", skip_all, err)]
    async fn lag(&self) -> Result<Option<u64>> {
        let pool = self.admin_pool().await?;
        let row = sqlx::query(
            "SELECT pg_wal_lsn_diff(\
                 CASE WHEN pg_is_in_recovery() THEN pg_last_wal_replay_lsn() \
                 ELSE pg_current_wal_lsn() END, \
                 confirmed_flush_lsn)::bigint AS lag \
             FROM pg_replication_slots WHERE slot_name = $1",
        )
        .bind(&self.config.slot)
        .fetch_optional(pool)
        .await
        .map_err(|e| SourceError::Query(e.to_string()))?;

        let Some(row) = row else {
            return Ok(None);
        };

        // `lag` is NULL whenever either LSN is NULL, so decode it as an Option
        // instead of failing on "unexpected null".
        let bytes: Option<i64> = row
            .try_get("lag")
            .map_err(|e| SourceError::Query(e.to_string()))?;
        Ok(bytes.map(|b| u64::try_from(b).unwrap_or(0)))
    }
}