use std::sync::Arc;
use serde::{Deserialize, Serialize};
use tokio::io::{AsyncWrite, AsyncWriteExt, Stdout};
use tokio::sync::Mutex;
use tracing::{info, warn};
/// The stage a migration run is currently in.
///
/// Unit variants serialise by name under serde's default external
/// representation (e.g. `"Validate"`), as exercised by the round-trip
/// test below. Variants are declared in the order the stages appear to
/// occur in a run — NOTE(review): confirm against the migration driver.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum MigrationStage {
    Validate,
    PrepareSnapshot,
    Dump,
    Restore,
    StreamApply,
    Lag,
    CaughtUp,
    Cutover,
    Complete,
}
/// A single progress notification emitted during a migration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProgressEvent {
    /// Stage the migration was in when the event was emitted.
    pub stage: MigrationStage,
    /// Human-readable description of what happened.
    pub message: String,
    /// Optional structured payload (e.g. `{"lag_bytes": 42}` in the tests);
    /// serialises as `null` when absent.
    pub detail: Option<serde_json::Value>,
}
impl ProgressEvent {
pub fn new(stage: MigrationStage, message: impl Into<String>) -> Self {
Self {
stage,
message: message.into(),
detail: None,
}
}
pub fn with_detail(mut self, detail: serde_json::Value) -> Self {
self.detail = Some(detail);
self
}
}
/// Sink for migration progress events, shareable across tasks
/// (`Send + Sync`).
#[async_trait::async_trait]
pub trait ProgressReporter: Send + Sync + std::fmt::Debug {
    /// Delivers one progress event to the sink. Implementations in this
    /// module handle delivery failures internally (logging a warning)
    /// rather than returning an error.
    async fn report(&self, event: ProgressEvent);
}
/// Reporter that forwards events to the `tracing` subscriber at `info`
/// level. Stateless, so it is cheap to clone.
#[derive(Debug, Default, Clone)]
pub struct TracingReporter;
#[async_trait::async_trait]
impl ProgressReporter for TracingReporter {
    /// Logs the event at `info` level with the stage as a structured field.
    ///
    /// Fix: previously the `detail` payload was silently dropped; it is now
    /// recorded as a `detail` field (compact JSON via `Value`'s `Display`)
    /// when present, so structured information is not lost in the logs.
    async fn report(&self, event: ProgressEvent) {
        match event.detail {
            Some(ref detail) => {
                info!(stage = ?event.stage, detail = %detail, "{}", event.message)
            }
            None => info!(stage = ?event.stage, "{}", event.message),
        }
    }
}
/// Reporter that writes each event as one line of JSON (NDJSON) to an
/// async writer, defaulting to stdout.
///
/// `Debug` cannot be derived because `W` need not be `Debug`; a manual
/// impl is provided instead, hence the allow.
#[allow(missing_debug_implementations)]
pub struct JsonReporter<W: AsyncWrite + Send + Unpin = Stdout> {
    // Shared, mutex-guarded writer; tokio's async Mutex is used because
    // the guard is held across `.await` points in `report`.
    writer: Arc<Mutex<W>>,
}
impl<W> std::fmt::Debug for JsonReporter<W>
where
    W: AsyncWrite + Send + Unpin,
{
    /// Renders as `JsonReporter { .. }`; the writer itself is opaque.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("JsonReporter").finish_non_exhaustive()
    }
}
impl Default for JsonReporter<Stdout> {
    /// Builds a reporter that writes NDJSON lines to the process's stdout.
    fn default() -> Self {
        Self::new(tokio::io::stdout())
    }
}
impl<W: AsyncWrite + Send + Unpin> JsonReporter<W> {
    /// Wraps `writer` in shared, mutex-guarded storage so the reporter
    /// can be used concurrently.
    pub fn new(writer: W) -> Self {
        let writer = Arc::new(Mutex::new(writer));
        Self { writer }
    }
}
#[async_trait::async_trait]
impl<W: AsyncWrite + Send + Unpin + 'static> ProgressReporter for JsonReporter<W> {
    /// Serialises `event` as a single newline-terminated JSON record and
    /// writes it to the shared writer, flushing afterwards. Failures are
    /// logged as warnings rather than propagated.
    async fn report(&self, event: ProgressEvent) {
        // Serialise before taking the lock; bail out with a warning on failure.
        let mut payload = match serde_json::to_vec(&event) {
            Ok(bytes) => bytes,
            Err(e) => {
                warn!(error = %e, "JsonReporter: failed to serialise event");
                return;
            }
        };
        payload.push(b'\n');
        let mut guard = self.writer.lock().await;
        if let Err(e) = guard.write_all(&payload).await {
            warn!(error = %e, "JsonReporter: failed to write event");
            return;
        }
        if let Err(e) = guard.flush().await {
            warn!(error = %e, "JsonReporter: failed to flush event");
        }
    }
}
/// Reporter that stores events in memory for later inspection (used by
/// the tests below). Clones share the same underlying buffer.
#[derive(Debug, Default, Clone)]
pub struct CollectingReporter {
    // Arc<Mutex<..>> so all clones observe the same event list.
    inner: Arc<Mutex<Vec<ProgressEvent>>>,
}
impl CollectingReporter {
    /// Creates an empty collector.
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns a snapshot (clone) of every event recorded so far.
    pub async fn events(&self) -> Vec<ProgressEvent> {
        let guard = self.inner.lock().await;
        guard.clone()
    }

    /// Number of events recorded so far.
    pub async fn len(&self) -> usize {
        let guard = self.inner.lock().await;
        guard.len()
    }

    /// True when no events have been recorded yet.
    pub async fn is_empty(&self) -> bool {
        self.len().await == 0
    }
}
#[async_trait::async_trait]
impl ProgressReporter for CollectingReporter {
    /// Appends the event to the shared in-memory log.
    async fn report(&self, event: ProgressEvent) {
        let mut events = self.inner.lock().await;
        events.push(event);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // CollectingReporter records every reported event, in order, and
    // exposes detail payloads for inspection.
    #[tokio::test]
    async fn collecting_reporter_records_events() {
        let r = CollectingReporter::new();
        assert!(r.is_empty().await);
        r.report(ProgressEvent::new(MigrationStage::Validate, "hello"))
            .await;
        r.report(
            ProgressEvent::new(MigrationStage::Dump, "dump")
                .with_detail(serde_json::json!({"jobs": 4})),
        )
        .await;
        assert_eq!(r.len().await, 2);
        let events = r.events().await;
        assert_eq!(events[0].stage, MigrationStage::Validate);
        assert_eq!(events[1].detail.as_ref().unwrap()["jobs"], 4);
    }

    // Serialised events contain the stage name and message text.
    #[test]
    fn progress_event_serializes() {
        let ev = ProgressEvent::new(MigrationStage::Complete, "done");
        let json = serde_json::to_string(&ev).unwrap();
        assert!(json.contains("Complete"));
        assert!(json.contains("done"));
    }

    // JsonReporter emits exactly one parseable NDJSON record per event.
    // Accessing the private `writer` field directly is fine here since
    // the tests live in the same module.
    #[tokio::test]
    async fn json_reporter_writes_one_ndjson_record_per_event() {
        let buf = Vec::<u8>::new();
        let r = JsonReporter::new(buf);
        r.report(ProgressEvent::new(MigrationStage::Validate, "ok"))
            .await;
        r.report(
            ProgressEvent::new(MigrationStage::Lag, "lag")
                .with_detail(serde_json::json!({"lag_bytes": 42})),
        )
        .await;
        // Recover the Vec<u8> buffer; try_unwrap succeeds because no
        // other Arc clone exists at this point.
        let writer = Arc::try_unwrap(r.writer).unwrap().into_inner();
        let out = String::from_utf8(writer).unwrap();
        let lines: Vec<&str> = out.lines().collect();
        assert_eq!(lines.len(), 2);
        let v0: serde_json::Value = serde_json::from_str(lines[0]).unwrap();
        assert_eq!(v0["stage"], "Validate");
        assert_eq!(v0["message"], "ok");
        let v1: serde_json::Value = serde_json::from_str(lines[1]).unwrap();
        assert_eq!(v1["detail"]["lag_bytes"], 42);
    }

    // Smoke test: reporting through TracingReporter (with and without
    // detail) must not panic even with no subscriber installed.
    #[tokio::test]
    async fn tracing_reporter_does_not_panic() {
        let r = TracingReporter;
        r.report(ProgressEvent::new(MigrationStage::Validate, "test"))
            .await;
        r.report(
            ProgressEvent::new(MigrationStage::CaughtUp, "caught up")
                .with_detail(serde_json::json!({"lag_bytes": 0})),
        )
        .await;
    }

    // Every stage variant survives a serde round trip unchanged.
    #[test]
    fn migration_stage_serde_roundtrip() {
        let stages = [
            MigrationStage::Validate,
            MigrationStage::PrepareSnapshot,
            MigrationStage::Dump,
            MigrationStage::Restore,
            MigrationStage::StreamApply,
            MigrationStage::Lag,
            MigrationStage::CaughtUp,
            MigrationStage::Cutover,
            MigrationStage::Complete,
        ];
        for stage in stages {
            let json = serde_json::to_string(&stage).unwrap();
            let back: MigrationStage = serde_json::from_str(&json).unwrap();
            assert_eq!(back, stage);
        }
    }

    // `new` leaves detail unset.
    #[test]
    fn progress_event_without_detail_has_none() {
        let ev = ProgressEvent::new(MigrationStage::Dump, "running");
        assert!(ev.detail.is_none());
    }

    // `with_detail` attaches the JSON payload.
    #[test]
    fn progress_event_with_detail_attaches_json() {
        let ev = ProgressEvent::new(MigrationStage::Lag, "lag")
            .with_detail(serde_json::json!({"lag_bytes": 1024, "source_lsn": "0/1234"}));
        assert!(ev.detail.is_some());
        assert_eq!(ev.detail.unwrap()["lag_bytes"], 1024);
    }

    // Clones of CollectingReporter share one buffer: events reported via
    // either clone are visible through both.
    #[tokio::test]
    async fn collecting_reporter_clone_shares_state() {
        let r1 = CollectingReporter::new();
        let r2 = r1.clone();
        r1.report(ProgressEvent::new(MigrationStage::Validate, "a"))
            .await;
        r2.report(ProgressEvent::new(MigrationStage::Dump, "b"))
            .await;
        assert_eq!(r1.len().await, 2);
        assert_eq!(r2.len().await, 2);
    }

    // The manual Debug impl renders the type name.
    #[tokio::test]
    async fn json_reporter_debug_does_not_panic() {
        let r = JsonReporter::new(Vec::<u8>::new());
        let dbg = format!("{:?}", r);
        assert!(dbg.contains("JsonReporter"));
    }

    // Deserialisation accepts an explicit `detail: null`.
    #[test]
    fn progress_event_deserializes_from_json() {
        let json = r#"{"stage":"Dump","message":"running","detail":null}"#;
        let ev: ProgressEvent = serde_json::from_str(json).unwrap();
        assert_eq!(ev.stage, MigrationStage::Dump);
        assert_eq!(ev.message, "running");
        assert!(ev.detail.is_none());
    }

    // Deserialisation preserves a structured detail object.
    #[test]
    fn progress_event_deserializes_with_detail() {
        let json = r#"{"stage":"Lag","message":"lag report","detail":{"lag_bytes":1024}}"#;
        let ev: ProgressEvent = serde_json::from_str(json).unwrap();
        assert_eq!(ev.stage, MigrationStage::Lag);
        assert_eq!(ev.detail.unwrap()["lag_bytes"], 1024);
    }

    // Derived Debug for the unit struct prints its name.
    #[test]
    fn tracing_reporter_debug() {
        let r = TracingReporter;
        let dbg = format!("{:?}", r);
        assert!(dbg.contains("TracingReporter"));
    }

    // Derived Debug for the collector prints its name.
    #[test]
    fn collecting_reporter_debug() {
        let r = CollectingReporter::new();
        let dbg = format!("{:?}", r);
        assert!(dbg.contains("CollectingReporter"));
    }

    // A burst of events yields one line each — no interleaving or loss.
    #[tokio::test]
    async fn json_reporter_handles_multiple_events() {
        let buf = Vec::<u8>::new();
        let r = JsonReporter::new(buf);
        for i in 0..5 {
            r.report(ProgressEvent::new(
                MigrationStage::Validate,
                format!("event {i}"),
            ))
            .await;
        }
        let writer = Arc::try_unwrap(r.writer).unwrap().into_inner();
        let out = String::from_utf8(writer).unwrap();
        let lines: Vec<&str> = out.lines().collect();
        assert_eq!(lines.len(), 5);
    }
}