use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::Arc;
use async_trait::async_trait;
use tokio_postgres::{Client, NoTls, SimpleQueryMessage};
use vil_log::{mq_log, types::MqPayload};
use vil_log::dict::register_str;
use vil_trigger_core::traits::{TriggerSource, EventCallback};
use vil_trigger_core::types::{TriggerEvent, TriggerFault};
use crate::config::CdcConfig;
use crate::error::CdcFault;
const PG_INSERT: u8 = b'I';
const PG_UPDATE: u8 = b'U';
const PG_DELETE: u8 = b'D';
pub struct CdcTrigger {
config: CdcConfig,
paused: Arc<AtomicBool>,
sequence: Arc<AtomicU64>,
slot_hash: u32,
kind_hash: u32,
}
impl CdcTrigger {
pub fn new(config: CdcConfig) -> Self {
let slot_hash = register_str(&config.slot_name);
let kind_hash = register_str("cdc");
Self {
config,
paused: Arc::new(AtomicBool::new(false)),
sequence: Arc::new(AtomicU64::new(0)),
slot_hash,
kind_hash,
}
}
fn map_fault(f: CdcFault, kind_hash: u32) -> TriggerFault {
TriggerFault::SourceUnavailable {
kind_hash,
reason_code: f.as_error_code(),
}
}
async fn connect_replication(&self) -> Result<Client, CdcFault> {
let conn_hash = register_str(&self.config.conn_string);
let repl_conn = format!("{} replication=database", self.config.conn_string);
let (client, connection) = tokio_postgres::connect(&repl_conn, NoTls)
.await
.map_err(|e| CdcFault::ConnectionFailed {
conn_hash,
reason_code: e.as_db_error().map(|d| d.code().code().len() as u32).unwrap_or(0),
})?;
tokio::spawn(async move {
let _ = connection.await;
});
Ok(client)
}
async fn consume_stream(&self, client: &Client, on_event: &EventCallback) -> Result<(), CdcFault> {
let pub_hash = register_str(&self.config.publication);
let slot_hash = self.slot_hash;
let kind_hash = self.kind_hash;
let start_sql = format!(
"START_REPLICATION SLOT {} LOGICAL 0/0 (proto_version '1', publication_names '{}')",
self.config.slot_name, self.config.publication
);
let rows = client
.simple_query(&start_sql)
.await
.map_err(|_| CdcFault::ReplicationStartFailed {
slot_hash,
pg_error_code: 0,
})?;
for msg in rows {
if self.paused.load(Ordering::Relaxed) {
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
continue;
}
if let SimpleQueryMessage::Row(row) = msg {
let start = std::time::Instant::now();
let seq = self.sequence.fetch_add(1, Ordering::Relaxed);
let op_byte: u8 = row
.get(0)
.and_then(|s| s.as_bytes().first().copied())
.unwrap_or(0);
let mq_op: u8 = match op_byte {
PG_INSERT => 0, PG_UPDATE => 1, PG_DELETE => 2, _ => 0,
};
let table_str = row.get(1).unwrap_or("unknown");
let table_hash = register_str(table_str);
let elapsed = start.elapsed();
mq_log!(Info, MqPayload {
broker_hash: slot_hash,
topic_hash: table_hash,
group_hash: pub_hash,
offset: seq,
message_bytes: 0,
e2e_latency_us: elapsed.as_micros() as u32,
op_type: mq_op,
partition: 0,
retries: 0,
compression: 0,
..MqPayload::default()
});
let ts = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_nanos() as u64;
on_event(TriggerEvent {
kind_hash,
source_hash: slot_hash,
sequence: seq,
timestamp_ns: ts,
payload_bytes: 0,
op: 0,
_pad: [0; 3],
});
}
}
Ok(())
}
}
#[async_trait]
impl TriggerSource for CdcTrigger {
fn kind(&self) -> &'static str {
"cdc"
}
async fn start(&self, on_event: EventCallback) -> Result<(), TriggerFault> {
let client = self
.connect_replication()
.await
.map_err(|e| Self::map_fault(e, self.kind_hash))?;
self.consume_stream(&client, &on_event)
.await
.map_err(|e| Self::map_fault(e, self.kind_hash))
}
async fn pause(&self) -> Result<(), TriggerFault> {
self.paused.store(true, Ordering::Relaxed);
Ok(())
}
async fn resume(&self) -> Result<(), TriggerFault> {
self.paused.store(false, Ordering::Relaxed);
Ok(())
}
async fn stop(&self) -> Result<(), TriggerFault> {
self.paused.store(true, Ordering::Relaxed);
Ok(())
}
}