use anyhow::Result;
use chrono::{DateTime, Utc};
use crossbeam_channel::{Receiver, Sender};
use indicatif::{ProgressBar, ProgressStyle};
use serde::{Deserialize, Serialize};
use std::sync::{
atomic::{AtomicBool, AtomicU64, Ordering},
Arc,
};
use std::time::{Duration, Instant};
use tokio::sync::broadcast;
use uuid::Uuid;
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ProgressUpdate {
pub operation_id: Uuid,
pub operation_name: String,
pub current: u64,
pub total: Option<u64>,
pub message: Option<String>,
pub timestamp: DateTime<Utc>,
pub status: ProgressStatus,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum ProgressStatus {
Started,
InProgress,
Completed,
Failed,
Cancelled,
}
#[derive(Debug)]
pub struct ProgressTracker {
operation_id: Uuid,
operation_name: String,
current: Arc<AtomicU64>,
total: Option<u64>,
is_cancelled: Arc<AtomicBool>,
progress_bar: Option<ProgressBar>,
sender: Sender<ProgressUpdate>,
start_time: Instant,
}
impl ProgressTracker {
#[must_use]
pub fn new(
operation_name: &str,
total: Option<u64>,
sender: Sender<ProgressUpdate>,
show_progress_bar: bool,
) -> Self {
let operation_id = Uuid::new_v4();
let current = Arc::new(AtomicU64::new(0));
let is_cancelled = Arc::new(AtomicBool::new(false));
let progress_bar = if show_progress_bar {
let pb = if let Some(total) = total {
ProgressBar::new(total)
} else {
ProgressBar::new_spinner()
};
let style = if total.is_some() {
ProgressStyle::default_bar()
.template("{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} {msg}")
.unwrap()
.progress_chars("#>-")
} else {
ProgressStyle::default_spinner()
.template("{spinner:.green} [{elapsed_precise}] {msg}")
.unwrap()
};
pb.set_style(style);
Some(pb)
} else {
None
};
let tracker = Self {
operation_id,
operation_name: operation_name.to_string(),
current,
total,
is_cancelled,
progress_bar,
sender,
start_time: Instant::now(),
};
tracker.send_update(ProgressStatus::Started, None);
tracker
}
pub fn inc(&self, amount: u64) {
if self.is_cancelled.load(Ordering::Relaxed) {
return;
}
let _new_current = self.current.fetch_add(amount, Ordering::Relaxed) + amount;
if let Some(pb) = &self.progress_bar {
pb.inc(amount);
}
self.send_update(ProgressStatus::InProgress, None);
}
pub fn set_current(&self, current: u64) {
if self.is_cancelled.load(Ordering::Relaxed) {
return;
}
self.current.store(current, Ordering::Relaxed);
if let Some(pb) = &self.progress_bar {
pb.set_position(current);
}
self.send_update(ProgressStatus::InProgress, None);
}
pub fn set_message(&self, message: String) {
if self.is_cancelled.load(Ordering::Relaxed) {
return;
}
if let Some(pb) = &self.progress_bar {
pb.set_message(message.clone());
}
self.send_update(ProgressStatus::InProgress, Some(message));
}
pub fn complete(&self) {
if self.is_cancelled.load(Ordering::Relaxed) {
return;
}
if let Some(pb) = &self.progress_bar {
pb.finish_with_message("Completed");
}
self.send_update(ProgressStatus::Completed, None);
}
pub fn fail(&self, error_message: String) {
if self.is_cancelled.load(Ordering::Relaxed) {
return;
}
self.is_cancelled.store(true, Ordering::Relaxed);
if let Some(pb) = &self.progress_bar {
pb.finish();
}
self.send_update(ProgressStatus::Failed, Some(error_message));
}
pub fn cancel(&self) {
self.is_cancelled.store(true, Ordering::Relaxed);
if let Some(pb) = &self.progress_bar {
pb.finish_with_message("Cancelled");
}
self.send_update(ProgressStatus::Cancelled, None);
}
#[must_use]
pub fn is_cancelled(&self) -> bool {
self.is_cancelled.load(Ordering::Relaxed)
}
#[must_use]
pub fn current(&self) -> u64 {
self.current.load(Ordering::Relaxed)
}
#[must_use]
pub fn total(&self) -> Option<u64> {
self.total
}
#[must_use]
pub fn operation_id(&self) -> Uuid {
self.operation_id
}
#[must_use]
pub fn operation_name(&self) -> &str {
&self.operation_name
}
#[must_use]
pub fn elapsed(&self) -> Duration {
self.start_time.elapsed()
}
fn send_update(&self, status: ProgressStatus, message: Option<String>) {
let update = ProgressUpdate {
operation_id: self.operation_id,
operation_name: self.operation_name.clone(),
current: self.current.load(Ordering::Relaxed),
total: self.total,
message,
timestamp: Utc::now(),
status,
};
let _ = self.sender.try_send(update);
}
}
#[derive(Clone, Debug)]
pub struct ProgressManager {
sender: Sender<ProgressUpdate>,
receiver: Receiver<ProgressUpdate>,
broadcast_sender: broadcast::Sender<ProgressUpdate>,
}
impl ProgressManager {
#[must_use]
pub fn new() -> Self {
let (sender, receiver) = crossbeam_channel::unbounded();
let (broadcast_sender, _) = broadcast::channel(1000);
Self {
sender,
receiver,
broadcast_sender,
}
}
#[must_use]
pub fn create_tracker(
&self,
operation_name: &str,
total: Option<u64>,
show_progress_bar: bool,
) -> ProgressTracker {
ProgressTracker::new(
operation_name,
total,
self.sender.clone(),
show_progress_bar,
)
}
#[must_use]
pub fn subscribe(&self) -> broadcast::Receiver<ProgressUpdate> {
self.broadcast_sender.subscribe()
}
pub fn run(&self) -> Result<()> {
while let Ok(update) = self.receiver.recv() {
let _ = self.broadcast_sender.send(update);
}
Ok(())
}
#[must_use]
pub fn sender(&self) -> Sender<ProgressUpdate> {
self.sender.clone()
}
}
impl Default for ProgressManager {
fn default() -> Self {
Self::new()
}
}
pub trait TrackableOperation {
fn execute_with_progress(&self, tracker: &ProgressTracker) -> Result<()>;
}
#[macro_export]
macro_rules! trackable_operation {
($name:expr, $total:expr, $operation:block) => {{
let progress_manager = ProgressManager::new();
let tracker = progress_manager.create_tracker($name, $total, true);
let manager = progress_manager.clone();
tokio::spawn(async move {
let _ = manager.run();
});
let result = $operation;
if result.is_ok() {
tracker.complete();
} else {
tracker.fail(format!("{:?}", result.as_ref().unwrap_err()));
}
result
}};
}
#[cfg(test)]
mod tests {
use super::*;
use std::time::Duration as StdDuration;
#[test]
fn test_progress_tracker_creation() {
let (sender, _receiver) = crossbeam_channel::unbounded();
let tracker = ProgressTracker::new("test_operation", Some(100), sender, false);
assert_eq!(tracker.operation_name(), "test_operation");
assert_eq!(tracker.total(), Some(100));
assert_eq!(tracker.current(), 0);
assert!(!tracker.is_cancelled());
}
#[test]
fn test_progress_tracker_increment() {
let (sender, _receiver) = crossbeam_channel::unbounded();
let tracker = ProgressTracker::new("test_operation", Some(100), sender, false);
tracker.inc(10);
assert_eq!(tracker.current(), 10);
tracker.inc(5);
assert_eq!(tracker.current(), 15);
}
#[test]
fn test_progress_tracker_set_current() {
let (sender, _receiver) = crossbeam_channel::unbounded();
let tracker = ProgressTracker::new("test_operation", Some(100), sender, false);
tracker.set_current(50);
assert_eq!(tracker.current(), 50);
}
#[test]
fn test_progress_tracker_cancellation() {
let (sender, _receiver) = crossbeam_channel::unbounded();
let tracker = ProgressTracker::new("test_operation", Some(100), sender, false);
assert!(!tracker.is_cancelled());
tracker.cancel();
assert!(tracker.is_cancelled());
}
#[test]
fn test_progress_manager() {
let manager = ProgressManager::new();
let tracker = manager.create_tracker("test_operation", Some(100), false);
assert_eq!(tracker.operation_name(), "test_operation");
assert_eq!(tracker.total(), Some(100));
}
#[tokio::test]
#[ignore = "This test is flaky due to async timing issues"]
async fn test_progress_manager_subscription() {
let manager = ProgressManager::new();
let mut subscriber = manager.subscribe();
let tracker = manager.create_tracker("test_operation", Some(100), false);
let manager_clone = manager.clone();
let manager_handle = tokio::spawn(async move {
let _ = manager_clone.run();
});
tokio::time::sleep(StdDuration::from_millis(10)).await;
tracker.inc(10);
tokio::time::sleep(StdDuration::from_millis(10)).await;
let _update_result =
tokio::time::timeout(StdDuration::from_millis(50), subscriber.recv()).await;
manager_handle.abort();
}
#[test]
fn test_trackable_operation_macro() {
let manager = ProgressManager::new();
let tracker = manager.create_tracker("test", Some(10), false);
tracker.inc(5);
assert_eq!(tracker.current(), 5);
tracker.complete();
}
#[test]
fn test_progress_tracker_edge_cases() {
let manager = ProgressManager::new();
let tracker = manager.create_tracker("edge_case_test", Some(100), false);
tracker.inc(0);
assert_eq!(tracker.current(), 0);
tracker.inc(1000);
assert_eq!(tracker.current(), 1000);
tracker.set_current(50);
assert_eq!(tracker.current(), 50);
tracker.set_current(0);
assert_eq!(tracker.current(), 0);
tracker.set_current(100);
assert_eq!(tracker.current(), 100);
}
#[test]
fn test_progress_tracker_without_total() {
let manager = ProgressManager::new();
let tracker = manager.create_tracker("no_total_test", None, false);
tracker.inc(10);
assert_eq!(tracker.current(), 10);
assert_eq!(tracker.total(), None);
tracker.set_current(50);
assert_eq!(tracker.current(), 50);
tracker.complete();
}
#[test]
fn test_progress_tracker_failure() {
let manager = ProgressManager::new();
let tracker = manager.create_tracker("failure_test", Some(100), false);
tracker.fail("Test failure message".to_string());
assert!(tracker.is_cancelled());
}
#[test]
fn test_progress_tracker_elapsed_time() {
let manager = ProgressManager::new();
let tracker = manager.create_tracker("elapsed_test", Some(100), false);
let elapsed = tracker.elapsed();
std::thread::sleep(std::time::Duration::from_millis(10));
let elapsed_after = tracker.elapsed();
assert!(elapsed_after >= elapsed);
}
#[test]
fn test_progress_tracker_operation_info() {
let manager = ProgressManager::new();
let tracker = manager.create_tracker("info_test", Some(100), false);
assert_eq!(tracker.operation_id(), tracker.operation_id());
assert_eq!(tracker.operation_name(), "info_test");
assert_eq!(tracker.total(), Some(100));
}
#[test]
fn test_progress_manager_multiple_trackers() {
let manager = ProgressManager::new();
let tracker1 = manager.create_tracker("operation1", Some(100), false);
let tracker2 = manager.create_tracker("operation2", Some(200), false);
let tracker3 = manager.create_tracker("operation3", None, false);
assert_ne!(tracker1.operation_id(), tracker2.operation_id());
assert_ne!(tracker1.operation_id(), tracker3.operation_id());
assert_ne!(tracker2.operation_id(), tracker3.operation_id());
tracker1.inc(10);
tracker2.inc(20);
tracker3.inc(30);
assert_eq!(tracker1.current(), 10);
assert_eq!(tracker2.current(), 20);
assert_eq!(tracker3.current(), 30);
}
#[test]
fn test_progress_tracker_completion() {
let manager = ProgressManager::new();
let tracker = manager.create_tracker("completion_test", Some(100), false);
tracker.set_current(100);
tracker.complete();
assert_eq!(tracker.current(), 100);
assert_eq!(tracker.total(), Some(100));
}
#[test]
fn test_progress_tracker_large_values() {
let manager = ProgressManager::new();
let tracker = manager.create_tracker("large_values_test", Some(u64::MAX), false);
tracker.set_current(u64::MAX / 2);
assert_eq!(tracker.current(), u64::MAX / 2);
tracker.inc(1000);
assert_eq!(tracker.current(), u64::MAX / 2 + 1000);
}
#[test]
fn test_progress_tracker_negative_operations() {
let manager = ProgressManager::new();
let tracker = manager.create_tracker("negative_test", Some(100), false);
tracker.inc(50);
assert_eq!(tracker.current(), 50);
tracker.set_current(25);
assert_eq!(tracker.current(), 25);
}
#[test]
fn test_progress_manager_sender_access() {
let manager = ProgressManager::new();
let _sender = manager.sender();
}
#[test]
fn test_progress_tracker_debug_formatting() {
let manager = ProgressManager::new();
let tracker = manager.create_tracker("debug_test", Some(100), false);
let debug_str = format!("{tracker:?}");
assert!(debug_str.contains("debug_test"));
assert!(debug_str.contains("ProgressTracker"));
}
#[test]
fn test_progress_manager_debug_formatting() {
let manager = ProgressManager::new();
let debug_str = format!("{manager:?}");
assert!(debug_str.contains("ProgressManager"));
}
#[test]
fn test_progress_update_creation() {
let update = ProgressUpdate {
operation_id: Uuid::new_v4(),
operation_name: "test_operation".to_string(),
current: 50,
total: Some(100),
message: Some("Test message".to_string()),
timestamp: Utc::now(),
status: ProgressStatus::InProgress,
};
assert_eq!(update.operation_name, "test_operation");
assert_eq!(update.current, 50);
assert_eq!(update.total, Some(100));
assert_eq!(update.message, Some("Test message".to_string()));
}
#[test]
fn test_progress_update_serialization() {
let update = ProgressUpdate {
operation_id: Uuid::new_v4(),
operation_name: "serialization_test".to_string(),
current: 75,
total: Some(150),
message: Some("Serialization test".to_string()),
timestamp: Utc::now(),
status: ProgressStatus::InProgress,
};
let json = serde_json::to_string(&update).unwrap();
let deserialized: ProgressUpdate = serde_json::from_str(&json).unwrap();
assert_eq!(update.operation_id, deserialized.operation_id);
assert_eq!(update.operation_name, deserialized.operation_name);
assert_eq!(update.current, deserialized.current);
assert_eq!(update.total, deserialized.total);
assert_eq!(update.message, deserialized.message);
}
#[test]
fn test_progress_update_edge_cases() {
let update_none = ProgressUpdate {
operation_id: Uuid::new_v4(),
operation_name: String::new(),
current: 0,
total: None,
message: None,
timestamp: Utc::now(),
status: ProgressStatus::Started,
};
assert_eq!(update_none.operation_name, "");
assert_eq!(update_none.current, 0);
assert_eq!(update_none.total, None);
assert_eq!(update_none.message, None);
let update_max = ProgressUpdate {
operation_id: Uuid::new_v4(),
operation_name: "A".repeat(1000),
current: u64::MAX,
total: Some(u64::MAX),
message: Some("B".repeat(1000)),
timestamp: Utc::now(),
status: ProgressStatus::Completed,
};
assert_eq!(update_max.current, u64::MAX);
assert_eq!(update_max.total, Some(u64::MAX));
}
}