use std::path::PathBuf;
use std::sync::Arc;
use std::time::SystemTime;
use parking_lot::Mutex;
use serde::{Deserialize, Serialize};
use tokio::sync::broadcast;
use uni_plugin::PluginRegistry;
use uni_plugin::traits::cdc::{CdcBatch, CdcLsn, CdcStartContext, CdcStream};
use crate::notifications::CommitNotification;
use crate::shutdown::ShutdownHandle;
use uni_sidecar::VecSidecar;
/// One persisted checkpoint row: a provider name and the last LSN recorded for it.
///
/// Serialized and deserialized through serde.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct PersistedCheckpoint {
/// Name used as the row key when looking up or updating a checkpoint.
pub name: String,
/// Raw `u64` carried by the `CdcLsn` most recently written for `name`.
pub last_lsn: u64,
}
/// Handle over the `VecSidecar` that holds the [`PersistedCheckpoint`] rows.
#[derive(Clone, Debug)]
pub struct CdcCheckpointSidecar {
// Backing store; every read and write of checkpoint rows goes through it.
sidecar: VecSidecar<PersistedCheckpoint>,
}
impl CdcCheckpointSidecar {
    /// Builds the backing `VecSidecar` from `data_path` and the file name
    /// `cdc_checkpoints.json`.
    #[must_use]
    pub fn new(data_path: PathBuf) -> Self {
        let sidecar = VecSidecar::new(data_path, "cdc_checkpoints.json");
        Self { sidecar }
    }

    /// Returns the path reported by the backing sidecar.
    #[must_use]
    pub fn path(&self) -> &std::path::Path {
        self.sidecar.path()
    }

    /// Loads every persisted checkpoint row.
    ///
    /// # Errors
    ///
    /// Returns the backing sidecar's load error rendered as a `String`.
    pub fn load_all(&self) -> Result<Vec<PersistedCheckpoint>, String> {
        match self.sidecar.load() {
            Ok(rows) => Ok(rows),
            Err(err) => Err(err.to_string()),
        }
    }

    /// Stores `rows` through the backing sidecar.
    ///
    /// # Errors
    ///
    /// Returns the backing sidecar's store error rendered as a `String`.
    pub fn write_all(&self, rows: &[PersistedCheckpoint]) -> Result<(), String> {
        match self.sidecar.store(rows) {
            Ok(()) => Ok(()),
            Err(err) => Err(err.to_string()),
        }
    }

    /// Returns the LSN stored for `name`.
    ///
    /// Yields `None` both when no row matches and when loading fails; the
    /// load error is discarded.
    #[must_use]
    pub fn lookup(&self, name: &str) -> Option<CdcLsn> {
        let rows = self.load_all().ok()?;
        let hit = rows.into_iter().find(|row| row.name == name)?;
        Some(CdcLsn(hit.last_lsn))
    }

    /// Records `lsn` for `name`, then writes the full row list back.
    ///
    /// An existing row for `name` is updated in place; otherwise a new row is
    /// appended.
    ///
    /// # Errors
    ///
    /// Propagates a failure from either the load or the store step.
    pub fn write_one(&self, name: &str, lsn: CdcLsn) -> Result<(), String> {
        let mut rows = self.load_all()?;
        match rows.iter_mut().find(|row| row.name == name) {
            Some(existing) => existing.last_lsn = lsn.0,
            None => rows.push(PersistedCheckpoint {
                name: name.to_owned(),
                last_lsn: lsn.0,
            }),
        }
        self.write_all(&rows)
    }
}
/// A started CDC stream paired with the name of the provider that produced it.
struct ActiveStream {
// Provider name; used for de-duplication, log fields, and as the checkpoint key.
name: String,
// Stream handle returned by the provider's `start` call.
stream: Box<dyn CdcStream>,
}
fn start_stream(
checkpoint: Option<&CdcCheckpointSidecar>,
name: &str,
provider: &Arc<dyn uni_plugin::traits::cdc::CdcOutputProvider>,
late: bool,
) -> Option<ActiveStream> {
let from_lsn = checkpoint.and_then(|c| c.lookup(name));
match provider.start(CdcStartContext::new(from_lsn)) {
Ok(stream) => {
if late {
tracing::info!(provider = %name, from_lsn = ?from_lsn, "CdcRuntime: late-registered provider started");
} else {
tracing::info!(provider = %name, from_lsn = ?from_lsn, "CdcRuntime: provider started");
}
Some(ActiveStream {
name: name.to_owned(),
stream,
})
}
Err(e) => {
if late {
tracing::warn!(provider = %name, error = %e, "CdcRuntime: late-registered provider start failed");
} else {
tracing::warn!(provider = %name, error = %e, "CdcRuntime: provider start failed; skipping");
}
None
}
}
}
/// Owns the active CDC streams and delivers commit notifications to them.
pub struct CdcRuntime {
// Streams started so far, behind a mutex.
streams: Arc<Mutex<Vec<ActiveStream>>>,
// Checkpoint store; `None` when no data path was supplied to `spawn`.
checkpoint: Option<CdcCheckpointSidecar>,
// Registry polled for providers registered after start-up.
registry: Arc<PluginRegistry>,
}
impl std::fmt::Debug for CdcRuntime {
    /// Prints the number of active streams and the checkpoint path (if any).
    ///
    /// The stream mutex is locked briefly to read the count; the registry is
    /// not included in the output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let active_streams = self.streams.lock().len();
        let checkpoint_path = self
            .checkpoint
            .as_ref()
            .map(|sidecar| sidecar.path().to_path_buf());

        let mut out = f.debug_struct("CdcRuntime");
        out.field("active_streams", &active_streams);
        out.field("checkpoint_path", &checkpoint_path);
        out.finish()
    }
}
impl CdcRuntime {
    /// Starts a stream for every CDC output provider currently in `registry`
    /// and spawns the background task that forwards commits to those streams.
    ///
    /// * `registry` – source of providers, both now and on every later commit
    ///   (see `discover_new_providers`).
    /// * `commit_rx` – broadcast receiver of commit notifications.
    /// * `data_path` – when `Some`, checkpoints are persisted through a
    ///   [`CdcCheckpointSidecar`] built from this path; when `None`, nothing is
    ///   persisted and every provider starts with no resume LSN.
    /// * `shutdown` – subscribed to for the stop signal; the spawned task is
    ///   also registered with it via `track_task`.
    ///
    /// Providers that fail to start are logged and skipped.
    ///
    /// # Panics
    ///
    /// `tokio::spawn` panics when called outside a Tokio runtime.
    #[must_use]
    pub fn spawn(
        registry: &Arc<PluginRegistry>,
        commit_rx: broadcast::Receiver<Arc<CommitNotification>>,
        data_path: Option<PathBuf>,
        shutdown: &ShutdownHandle,
    ) -> Arc<Self> {
        let checkpoint = data_path.map(CdcCheckpointSidecar::new);

        let mut active: Vec<ActiveStream> = Vec::new();
        for (name, provider) in registry.cdc_outputs_snapshot() {
            if let Some(stream) =
                start_stream(checkpoint.as_ref(), name.as_str(), &provider, false)
            {
                active.push(stream);
            }
        }

        let runtime = Arc::new(Self {
            streams: Arc::new(Mutex::new(active)),
            checkpoint,
            registry: Arc::clone(registry),
        });

        let runtime_clone = Arc::clone(&runtime);
        let mut commit_rx = commit_rx;
        let mut shutdown_rx = shutdown.subscribe();
        let handle = tokio::spawn(async move {
            loop {
                tokio::select! {
                    // `biased` polls the branches in declaration order, so a
                    // ready shutdown signal wins over a ready commit.
                    biased;
                    _ = shutdown_rx.recv() => {
                        runtime_clone.shutdown_streams();
                        break;
                    }
                    next = commit_rx.recv() => match next {
                        Ok(notif) => runtime_clone.deliver_commit(&notif),
                        // The skipped notifications are only counted in the
                        // log; nothing here replays them to the streams.
                        Err(broadcast::error::RecvError::Lagged(n)) => {
                            tracing::warn!(
                                lagged = n,
                                "CdcRuntime: commit broadcaster lagged",
                            );
                        }
                        // NOTE(review): this path exits without calling
                        // `shutdown_streams`, unlike the shutdown branch —
                        // confirm that is intended.
                        Err(broadcast::error::RecvError::Closed) => break,
                    }
                }
            }
        });
        shutdown.track_task(handle);
        runtime
    }

    /// Returns the number of streams currently held by the runtime.
    #[must_use]
    pub fn active_stream_count(&self) -> usize {
        self.streams.lock().len()
    }

    /// Returns the checkpoint sidecar, or `None` when the runtime was spawned
    /// without a data path.
    #[must_use]
    pub fn checkpoint_sidecar(&self) -> Option<&CdcCheckpointSidecar> {
        self.checkpoint.as_ref()
    }

    /// Starts a stream for every registered provider that has no active
    /// stream yet (matched by name).
    ///
    /// The registry snapshot is taken before the stream lock is acquired; the
    /// lock is then held while the new providers are started.
    fn discover_new_providers(&self) {
        let snapshot = self.registry.cdc_outputs_snapshot();
        let mut streams = self.streams.lock();
        for (name, provider) in snapshot {
            if streams.iter().any(|s| s.name == name.as_str()) {
                continue;
            }
            if let Some(stream) =
                start_stream(self.checkpoint.as_ref(), name.as_str(), &provider, true)
            {
                streams.push(stream);
            }
        }
    }

    /// Delivers one commit to every active stream and persists each stream's
    /// checkpoint.
    ///
    /// Newly registered providers are started first. A notification without
    /// mutations is delivered as an empty record batch using the trigger
    /// event-row schema. Failures are logged per stream and never abort
    /// delivery to the remaining streams.
    fn deliver_commit(&self, notif: &CommitNotification) {
        self.discover_new_providers();
        let mutations = notif.mutations.clone().unwrap_or_else(|| {
            Arc::new(arrow_array::RecordBatch::new_empty(
                crate::triggers::event_row_schema(),
            ))
        });
        let batch = CdcBatch {
            lsn_start: CdcLsn(notif.causal_version),
            lsn_end: CdcLsn(notif.version),
            mutations,
            commit_timestamp: SystemTime::now(),
        };
        let mut streams = self.streams.lock();
        for active in streams.iter_mut() {
            if let Err(e) = active.stream.deliver(&batch) {
                tracing::warn!(
                    provider = %active.name,
                    error = %e,
                    "CdcRuntime: deliver failed",
                );
                // Skip the checkpoint step for a stream whose delivery failed.
                continue;
            }
            match active.stream.checkpoint() {
                Ok(lsn) => {
                    if let Some(sidecar) = &self.checkpoint
                        && let Err(e) = sidecar.write_one(&active.name, lsn)
                    {
                        tracing::debug!(
                            provider = %active.name,
                            error = %e,
                            "CdcRuntime: checkpoint write failed",
                        );
                    }
                }
                Err(e) => tracing::warn!(
                    provider = %active.name,
                    error = %e,
                    "CdcRuntime: checkpoint failed",
                ),
            }
        }
    }

    /// Calls `shutdown` on every active stream, logging failures, then drops
    /// all of them.
    fn shutdown_streams(&self) {
        let mut streams = self.streams.lock();
        for active in streams.iter_mut() {
            if let Err(e) = active.stream.shutdown() {
                tracing::warn!(
                    provider = %active.name,
                    error = %e,
                    "CdcRuntime: shutdown failed",
                );
            }
        }
        streams.clear();
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Opens a checkpoint sidecar rooted at the given temporary directory.
    fn open_sidecar(dir: &TempDir) -> CdcCheckpointSidecar {
        CdcCheckpointSidecar::new(dir.path().to_path_buf())
    }

    /// Rows written for distinct providers can all be read back.
    #[test]
    fn checkpoint_sidecar_round_trip() {
        let dir = TempDir::new().unwrap();
        let sidecar = open_sidecar(&dir);
        assert!(sidecar.load_all().unwrap().is_empty());

        for (provider, lsn) in [("kafka", 42), ("pulsar", 7)] {
            sidecar.write_one(provider, CdcLsn(lsn)).unwrap();
        }

        let stored = sidecar.load_all().unwrap();
        assert_eq!(stored.len(), 2);
        assert_eq!(sidecar.lookup("kafka"), Some(CdcLsn(42)));
        assert_eq!(sidecar.lookup("pulsar"), Some(CdcLsn(7)));
    }

    /// A checkpoint written by one handle is visible to a fresh handle on the
    /// same directory.
    #[test]
    fn checkpoint_sidecar_survives_close_reopen() {
        let dir = TempDir::new().unwrap();
        // The first handle is a temporary, dropped at the end of the statement.
        open_sidecar(&dir).write_one("kafka", CdcLsn(99)).unwrap();

        let reopened = open_sidecar(&dir);
        assert_eq!(reopened.lookup("kafka"), Some(CdcLsn(99)));
    }

    /// Repeated writes for one provider update its row instead of adding rows.
    #[test]
    fn checkpoint_sidecar_overwrites_existing_provider() {
        let dir = TempDir::new().unwrap();
        let sidecar = open_sidecar(&dir);
        for lsn in 1..=3 {
            sidecar.write_one("kafka", CdcLsn(lsn)).unwrap();
        }
        assert_eq!(sidecar.lookup("kafka"), Some(CdcLsn(3)));
        assert_eq!(sidecar.load_all().unwrap().len(), 1);
    }
}