use std::sync::Arc;
use serde::{Deserialize, Serialize};
use tokio::io::{AsyncWrite, AsyncWriteExt, Stdout};
use tokio::sync::Mutex;
use tracing::{info, warn};
/// Stage of a migration that a [`ProgressEvent`] refers to.
///
/// With the derived serde impls each variant serializes as its bare name
/// (e.g. `"Validate"`), which is what `JsonReporter` emits in the `stage` key.
/// NOTE(review): variants appear to be listed in pipeline order — confirm before
/// relying on declaration order (e.g. via `as` casts).
#[derive(Clone, Copy, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub enum MigrationStage {
    Validate,
    PrepareSnapshot,
    Dump,
    Restore,
    StreamApply,
    Lag,
    CaughtUp,
    Cutover,
    Complete,
}
/// A single progress notification delivered to a [`ProgressReporter`].
///
/// Serialized by serde in field declaration order: `stage`, `message`, `detail`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ProgressEvent {
    /// Stage the event belongs to.
    pub stage: MigrationStage,
    /// Human-readable description; `TracingReporter` uses it as the log message.
    pub message: String,
    /// Optional structured payload; serialized as JSON `null` when absent.
    pub detail: Option<serde_json::Value>,
}
impl ProgressEvent {
    /// Builds an event for `stage` carrying `message` and no detail payload.
    pub fn new(stage: MigrationStage, message: impl Into<String>) -> Self {
        let message: String = message.into();
        Self {
            stage,
            message,
            detail: None,
        }
    }

    /// Returns the event with `detail` attached, replacing any previous payload.
    pub fn with_detail(self, detail: serde_json::Value) -> Self {
        Self {
            detail: Some(detail),
            ..self
        }
    }
}
/// Destination for migration [`ProgressEvent`]s.
///
/// Implementors must be `Debug + Send + Sync`. The async method is provided
/// through the `async_trait` attribute macro. `report` returns nothing, so
/// implementations handle their own failures; the ones in this module log
/// them with `warn!` or cannot fail.
#[async_trait::async_trait]
pub trait ProgressReporter: std::fmt::Debug + Send + Sync {
    /// Delivers one event, taking ownership of it.
    async fn report(&self, event: ProgressEvent);
}
/// Stateless reporter that forwards events to the `tracing` subscriber.
#[derive(Clone, Debug, Default)]
pub struct TracingReporter;
#[async_trait::async_trait]
impl ProgressReporter for TracingReporter {
    /// Emits `event` as an `INFO`-level tracing record.
    ///
    /// The stage goes into a Debug-formatted `stage` field and the message becomes
    /// the record's message. A `detail` payload, when present, is attached as a
    /// `detail` field rendered as compact JSON; previously it was silently dropped.
    async fn report(&self, event: ProgressEvent) {
        match &event.detail {
            Some(detail) => {
                info!(stage = ?event.stage, detail = %detail, "{}", event.message);
            }
            None => {
                info!(stage = ?event.stage, "{}", event.message);
            }
        }
    }
}
#[allow(missing_debug_implementations)]
pub struct JsonReporter<W: AsyncWrite + Send + Unpin = Stdout> {
writer: Arc<Mutex<W>>,
}
impl<W> std::fmt::Debug for JsonReporter<W>
where
    W: AsyncWrite + Send + Unpin,
{
    /// Prints `JsonReporter { .. }`. The writer is omitted because `W` is not
    /// required to implement `Debug`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("JsonReporter");
        builder.finish_non_exhaustive()
    }
}
impl Default for JsonReporter<Stdout> {
fn default() -> Self {
Self::new(tokio::io::stdout())
}
}
impl<W> JsonReporter<W>
where
    W: AsyncWrite + Send + Unpin,
{
    /// Creates a reporter that takes ownership of `writer` and guards it with an
    /// async mutex.
    pub fn new(writer: W) -> Self {
        let guarded = Mutex::new(writer);
        Self {
            writer: Arc::new(guarded),
        }
    }
}
#[async_trait::async_trait]
impl<W: AsyncWrite + Send + Unpin + 'static> ProgressReporter for JsonReporter<W> {
async fn report(&self, event: ProgressEvent) {
let mut line = match serde_json::to_string(&event) {
Ok(s) => s,
Err(e) => {
warn!(error = %e, "JsonReporter: failed to serialise event");
return;
}
};
line.push('\n');
let mut w = self.writer.lock().await;
if let Err(e) = w.write_all(line.as_bytes()).await {
warn!(error = %e, "JsonReporter: failed to write event");
return;
}
if let Err(e) = w.flush().await {
warn!(error = %e, "JsonReporter: failed to flush event");
}
}
}
/// In-memory reporter that stores every event it receives, in arrival order.
///
/// The buffer is held in an `Arc`, so clones share the same recorded events.
#[derive(Clone, Debug, Default)]
pub struct CollectingReporter {
    inner: Arc<Mutex<Vec<ProgressEvent>>>,
}
impl CollectingReporter {
    /// Creates a collector with an empty buffer (same as `Default::default()`).
    pub fn new() -> Self {
        Default::default()
    }

    /// Returns a copy of all events recorded so far, oldest first.
    pub async fn events(&self) -> Vec<ProgressEvent> {
        let recorded = self.inner.lock().await;
        recorded.to_vec()
    }

    /// Number of events recorded so far.
    pub async fn len(&self) -> usize {
        let recorded = self.inner.lock().await;
        recorded.len()
    }

    /// `true` when no event has been recorded yet.
    pub async fn is_empty(&self) -> bool {
        self.len().await == 0
    }
}
#[async_trait::async_trait]
impl ProgressReporter for CollectingReporter {
    /// Appends `event` to the shared buffer.
    async fn report(&self, event: ProgressEvent) {
        let mut recorded = self.inner.lock().await;
        recorded.push(event);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn collecting_reporter_records_events() {
        let r = CollectingReporter::new();
        assert!(r.is_empty().await);
        r.report(ProgressEvent::new(MigrationStage::Validate, "hello"))
            .await;
        r.report(
            ProgressEvent::new(MigrationStage::Dump, "dump")
                .with_detail(serde_json::json!({"jobs": 4})),
        )
        .await;
        assert_eq!(r.len().await, 2);
        let events = r.events().await;
        assert_eq!(events[0].stage, MigrationStage::Validate);
        assert_eq!(events[1].detail.as_ref().unwrap()["jobs"], 4);
    }

    /// Checks the JSON shape field by field rather than by substring, so a stage
    /// name leaking into another field cannot produce a false pass.
    #[test]
    fn progress_event_serializes() {
        let ev = ProgressEvent::new(MigrationStage::Complete, "done");
        let v = serde_json::to_value(&ev).unwrap();
        assert_eq!(v["stage"], "Complete");
        assert_eq!(v["message"], "done");
        assert!(v["detail"].is_null());
    }

    /// `ProgressEvent` derives `Deserialize`; make sure what we emit can be read back.
    #[test]
    fn progress_event_round_trips_through_json() {
        let ev = ProgressEvent::new(MigrationStage::Cutover, "switching")
            .with_detail(serde_json::json!({"lag_bytes": 0}));
        let json = serde_json::to_string(&ev).unwrap();
        let back: ProgressEvent = serde_json::from_str(&json).unwrap();
        assert_eq!(back.stage, ev.stage);
        assert_eq!(back.message, ev.message);
        assert_eq!(back.detail, ev.detail);
    }

    /// Smoke test: reporting with and without a detail payload must not panic,
    /// even when no tracing subscriber is installed.
    #[tokio::test]
    async fn tracing_reporter_accepts_events_with_and_without_detail() {
        let r = TracingReporter;
        r.report(ProgressEvent::new(MigrationStage::Restore, "restoring"))
            .await;
        r.report(
            ProgressEvent::new(MigrationStage::Lag, "lag")
                .with_detail(serde_json::json!({"lag_bytes": 7})),
        )
        .await;
    }

    #[tokio::test]
    async fn json_reporter_writes_one_ndjson_record_per_event() {
        let buf = Vec::<u8>::new();
        let r = JsonReporter::new(buf);
        r.report(ProgressEvent::new(MigrationStage::Validate, "ok"))
            .await;
        r.report(
            ProgressEvent::new(MigrationStage::Lag, "lag")
                .with_detail(serde_json::json!({"lag_bytes": 42})),
        )
        .await;
        let writer = Arc::try_unwrap(r.writer).unwrap().into_inner();
        let out = String::from_utf8(writer).unwrap();
        let lines: Vec<&str> = out.lines().collect();
        assert_eq!(lines.len(), 2);
        let v0: serde_json::Value = serde_json::from_str(lines[0]).unwrap();
        assert_eq!(v0["stage"], "Validate");
        assert_eq!(v0["message"], "ok");
        let v1: serde_json::Value = serde_json::from_str(lines[1]).unwrap();
        assert_eq!(v1["detail"]["lag_bytes"], 42);
    }
}