use std::collections::{HashMap, VecDeque};
use std::sync::Arc;
use futures::stream::{self, BoxStream};
use pgwire_replication::{Lsn, ReplicationClient, ReplicationEvent};
use sources_core::cdc::{Ack, AckSink, Change, ChangeEvent};
use sources_core::{Result, SourceError};
use super::ack::AckShared;
use super::pgoutput::{self, Decoded, Relation};
/// Mutable state threaded through the `stream::unfold` loop built by `build`.
struct State {
    /// Replication connection: events are read from it, and the confirmed LSN is pushed to it.
    client: ReplicationClient,
    /// Relation metadata keyed by relation OID, filled in from decoded Relation messages.
    relations: HashMap<u32, Relation>,
    /// Changes decoded since the last Begin; cleared on Begin and moved to `pending` on Commit.
    open_txn: Vec<ChangeEvent>,
    /// Committed changes not yet emitted, each paired with its transaction's commit `end_lsn`.
    pending: VecDeque<(ChangeEvent, u64)>,
    /// Shared ack tracker: emitted LSNs are registered here and the confirmed LSN is read back.
    ack: Arc<AckShared>,
    /// Sink handed to every emitted `Ack`.
    sink: Arc<dyn AckSink>,
    /// Once set, no further events are read; already-pending changes are still emitted.
    done: bool,
}
/// Builds a stream of committed changes read from a logical-replication connection.
///
/// Changes are buffered per transaction and only yielded after the transaction's Commit.
/// Each yielded change carries an `Ack` whose sequence number comes from registering the
/// commit LSN with `ack`. Before each step, the currently confirmed LSN is pushed to the
/// client via `update_applied_lsn`.
///
/// The stream ends after the server stops sending events (`Ok(None)` or `StoppedAt`) and all
/// pending changes have been yielded. A receive or decode error is yielded once, and then
/// the stream ends.
pub(crate) fn build(
    client: ReplicationClient,
    ack: Arc<AckShared>,
    sink: Arc<dyn AckSink>,
) -> BoxStream<'static, Result<Change>> {
    let initial = State {
        client,
        relations: HashMap::new(),
        open_txn: Vec::new(),
        pending: VecDeque::new(),
        ack,
        sink,
        done: false,
    };

    Box::pin(stream::unfold(initial, |mut st| async move {
        loop {
            // Keep the client's view of the applied position in step with what consumers
            // have acknowledged so far.
            let confirmed = Lsn::from_u64(st.ack.confirmed_lsn());
            st.client.update_applied_lsn(confirmed);

            // Drain already-committed changes before reading anything new.
            if let Some((event, lsn)) = st.pending.pop_front() {
                let ack = Ack::new(st.ack.register(lsn), Arc::clone(&st.sink));
                return Some((Ok(Change { event, ack }), st));
            }

            if st.done {
                return None;
            }

            let outcome = match st.client.recv().await {
                Ok(Some(event)) => handle(&mut st, event),
                Ok(None) => {
                    st.done = true;
                    Ok(())
                }
                Err(e) => Err(map_pgwire(e)),
            };

            // Any failure is terminal: surface it once, then stop reading.
            if let Err(e) = outcome {
                st.done = true;
                return Some((Err(e), st));
            }
        }
    }))
}
/// Applies one replication event to `state`.
///
/// - `Begin` discards any buffered, uncommitted changes.
/// - `Commit` moves every buffered change into `pending`, tagged with the commit `end_lsn`.
/// - `XLogData` is decoded by `handle_xlog`.
/// - `StoppedAt` marks the stream as done.
/// - Keep-alives and logical messages are ignored.
///
/// # Errors
/// Returns whatever `handle_xlog` returns when decoding an `XLogData` payload fails.
fn handle(state: &mut State, event: ReplicationEvent) -> std::result::Result<(), SourceError> {
    match event {
        ReplicationEvent::XLogData { data, .. } => return handle_xlog(state, data.as_ref()),
        ReplicationEvent::Begin { .. } => state.open_txn.clear(),
        ReplicationEvent::Commit { end_lsn, .. } => {
            // NOTE(review): a commit with no buffered changes adds nothing to `pending`, so
            // its LSN is never registered with the ack tracker. Confirm that the confirmed
            // LSN still advances past idle or filtered-out transactions.
            let commit_lsn = end_lsn.as_u64();
            state
                .pending
                .extend(state.open_txn.drain(..).map(|change| (change, commit_lsn)));
        }
        ReplicationEvent::StoppedAt { .. } => state.done = true,
        ReplicationEvent::KeepAlive { .. } | ReplicationEvent::Message { .. } => {}
    }
    Ok(())
}
/// Decodes one pgoutput payload and updates `state`.
///
/// - Relation messages are cached by OID.
/// - Inserts become `Upsert`s and deletes become `Delete`s, keyed by `pgoutput::row_key`.
/// - An update whose old-tuple key differs from its new key becomes a `Delete` of the old key
///   followed by an `Upsert` of the new one. Otherwise it is just an `Upsert`.
/// - Truncates are only logged as warnings.
///
/// All row changes are buffered in `state.open_txn` until the surrounding Commit.
///
/// # Errors
/// Fails if the payload cannot be decoded, a change references an OID with no cached
/// relation, or a row key cannot be extracted.
fn handle_xlog(state: &mut State, data: &[u8]) -> std::result::Result<(), SourceError> {
    let decoded = pgoutput::decode(data)?;
    match decoded {
        Decoded::Relation(relation) => {
            state.relations.insert(relation.oid, relation);
        }
        Decoded::Insert { rel, new } => {
            let relation = lookup_relation(state, rel)?;
            let upsert = ChangeEvent::Upsert {
                table: relation.table.clone(),
                key: pgoutput::row_key(relation, &new)?,
            };
            state.open_txn.push(upsert);
        }
        Decoded::Update { rel, old, new } => {
            let relation = lookup_relation(state, rel)?;
            let table = relation.table.clone();
            let new_key = pgoutput::row_key(relation, &new)?;
            let old_key = old
                .as_ref()
                .map(|tuple| pgoutput::row_key(relation, tuple))
                .transpose()?;

            // A changed key means the row moved: retract the old key before upserting the new.
            if let Some(stale) = old_key.filter(|prev| prev.0 != new_key.0) {
                state.open_txn.push(ChangeEvent::Delete {
                    table: table.clone(),
                    key: stale,
                });
            }
            state.open_txn.push(ChangeEvent::Upsert { table, key: new_key });
        }
        Decoded::Delete { rel, old } => {
            let relation = lookup_relation(state, rel)?;
            let delete = ChangeEvent::Delete {
                table: relation.table.clone(),
                key: pgoutput::row_key(relation, &old)?,
            };
            state.open_txn.push(delete);
        }
        Decoded::Truncate { rels } => {
            for oid in rels {
                // Use the cached table name if the relation is known, otherwise the raw OID.
                let table = match state.relations.get(&oid) {
                    Some(r) => r.table.to_string(),
                    None => format!("oid {oid}"),
                };
                tracing::warn!(%table, "TRUNCATE received but not propagated; index may be stale");
            }
        }
        Decoded::Other => {}
    }
    Ok(())
}
/// Returns the cached relation for `oid`.
///
/// # Errors
/// Returns `SourceError::Decode` if no Relation message for `oid` has been seen yet.
fn lookup_relation(state: &State, oid: u32) -> std::result::Result<&Relation, SourceError> {
    match state.relations.get(&oid) {
        Some(relation) => Ok(relation),
        None => Err(SourceError::Decode(format!(
            "pgoutput: change for unknown relation oid {oid}"
        ))),
    }
}
/// Converts a replication-client error into a `SourceError`, keeping its display text.
///
/// The checks run in this order:
/// 1. Errors the client reports as transient map to `Connection`.
/// 2. Server, auth and TLS errors map to `Setup`.
/// 3. Protocol errors map to `Decode`.
/// 4. Everything else maps to `Connection`.
fn map_pgwire(error: pgwire_replication::PgWireError) -> SourceError {
    use pgwire_replication::PgWireError;
    let message = error.to_string();
    if error.is_transient() {
        SourceError::Connection(message)
    } else if matches!(
        error,
        PgWireError::Server(_) | PgWireError::Auth(_) | PgWireError::Tls(_)
    ) {
        SourceError::Setup(message)
    } else if matches!(error, PgWireError::Protocol(_)) {
        SourceError::Decode(message)
    } else {
        SourceError::Connection(message)
    }
}