use crate::cook::execution::errors::MapReduceResult;
use axum::{
extract::{
ws::{Message, WebSocket, WebSocketUpgrade},
State,
},
response::{Html, Json, Response, Sse},
routing::get,
Router,
};
use chrono::{DateTime, Utc};
use futures_util::stream::Stream;
use futures_util::StreamExt;
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use std::collections::HashMap;
use std::convert::Infallible;
use std::net::SocketAddr;
use std::path::PathBuf;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::fs;
use tokio::sync::{mpsc, Mutex, RwLock};
use tokio::time::interval;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tower_http::cors::CorsLayer;
use tracing::{error, info, warn};
use uuid::Uuid;
/// Central progress state for one MapReduce job: per-agent records, aggregate
/// metrics, an event channel feeding dashboard clients, and an optional
/// embedded web server.
///
/// Cloning is cheap and shares state: all mutable fields are behind `Arc`s.
#[derive(Clone)]
pub struct EnhancedProgressTracker {
    /// Identifier of the job being tracked.
    pub job_id: String,
    /// Total number of work items in the job.
    pub total_items: usize,
    /// Monotonic start time used for elapsed-time and throughput math.
    pub start_time: Instant,
    /// Per-agent progress records keyed by agent id.
    pub agents: Arc<RwLock<HashMap<String, AgentProgress>>>,
    /// Aggregate job metrics, recomputed on every state change.
    pub metrics: Arc<RwLock<ProgressMetrics>>,
    /// Producer side of the progress event stream.
    pub event_sender: mpsc::UnboundedSender<ProgressUpdate>,
    /// Consumer side of the progress event stream; locked long-term by the
    /// web server's broadcaster while it forwards events to clients.
    pub event_receiver: Arc<Mutex<mpsc::UnboundedReceiver<ProgressUpdate>>>,
    /// Dashboard web server, set once `start_web_server` has been called.
    pub web_server: Option<Arc<ProgressWebServer>>,
}
/// Progress record for a single agent, stored in the tracker's agent map.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgentProgress {
    /// Unique identifier of the agent.
    pub agent_id: String,
    /// Identifier of the work item this agent is processing.
    pub item_id: String,
    /// Current lifecycle state (see [`AgentState`]).
    pub state: AgentState,
    /// Human-readable name of the step currently executing.
    pub current_step: String,
    /// Number of workflow steps finished so far.
    pub steps_completed: usize,
    /// Total number of workflow steps for this item.
    pub total_steps: usize,
    /// Completion percentage, expected in the 0.0–100.0 range.
    pub progress_percentage: f32,
    /// When this agent began processing.
    pub started_at: DateTime<Utc>,
    /// Timestamp of the most recent update to this record.
    pub last_update: DateTime<Utc>,
    /// Projected finish time, when one has been estimated.
    pub estimated_completion: Option<DateTime<Utc>>,
    /// Number of errors observed for this agent.
    pub error_count: usize,
    /// Number of retry attempts made for this agent.
    pub retry_count: usize,
}
/// Lifecycle states of a map agent.
///
/// Serialized internally tagged, e.g. `{"type": "Running", "step": …}`,
/// for dashboard/JSON consumers.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum AgentState {
    /// Waiting to be scheduled.
    Queued,
    /// Agent is setting up before running its first step.
    Initializing,
    /// Actively executing `step`, `progress` percent complete.
    Running { step: String, progress: f32 },
    /// Merging its results back.
    Merging,
    /// Finished successfully.
    Completed,
    /// Terminated with the given error message.
    Failed { error: String },
    /// Being retried; `attempt` is the retry number.
    Retrying { attempt: u32 },
    /// Routed to the dead-letter queue (presumably after retries were
    /// exhausted — the DLQ policy is not visible in this module).
    DeadLettered,
}
/// Aggregate, job-level progress counters and rates.
///
/// Recomputed by `EnhancedProgressTracker::recalculate_metrics` and served
/// to dashboard clients as JSON.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProgressMetrics {
    /// Items that finished successfully.
    pub completed_items: usize,
    /// Items that terminated with an error.
    pub failed_items: usize,
    /// Items not yet processed.
    pub pending_items: usize,
    /// Agents currently `Running` or `Initializing`.
    pub active_agents: usize,
    /// Instantaneous throughput in items/sec.
    /// NOTE(review): never written by the code in this module — confirm a
    /// caller elsewhere maintains it.
    pub throughput_current: f64,
    /// Average throughput since job start, in items/sec.
    pub throughput_average: f64,
    /// Percentage (0–100) of processed items that succeeded.
    pub success_rate: f64,
    /// Mean per-item processing time in milliseconds.
    /// NOTE(review): not updated in this module.
    pub average_duration_ms: u64,
    /// Projected completion time, when derivable from throughput.
    pub estimated_completion: Option<DateTime<Utc>>,
    /// Process memory usage in MB (not updated in this module).
    pub memory_usage_mb: usize,
    /// Process CPU usage percent (not updated in this module).
    pub cpu_usage_percent: f32,
}
impl Default for ProgressMetrics {
fn default() -> Self {
Self {
completed_items: 0,
failed_items: 0,
pending_items: 0,
active_agents: 0,
throughput_current: 0.0,
throughput_average: 0.0,
success_rate: 100.0,
average_duration_ms: 0,
estimated_completion: None,
memory_usage_mb: 0,
cpu_usage_percent: 0.0,
}
}
}
/// One event on the progress stream, broadcast to dashboard clients.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProgressUpdate {
    /// What kind of event this is.
    pub update_type: UpdateType,
    /// When the event was emitted.
    pub timestamp: DateTime<Utc>,
    /// Event payload; shape depends on `update_type` (e.g. an agent id plus
    /// its `AgentProgress`, or a serialized `ProgressMetrics`).
    pub data: Value,
}
/// Discriminant for [`ProgressUpdate`] events, serialized in snake_case.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum UpdateType {
    /// A single agent's progress record changed.
    AgentProgress,
    /// Aggregate metrics were recomputed.
    MetricsUpdate,
    /// The whole job finished.
    JobCompleted,
    /// An error event.
    Error,
}
/// Embedded HTTP server exposing the progress dashboard, JSON/Prometheus
/// endpoints, and WebSocket/SSE event streams.
pub struct ProgressWebServer {
    /// TCP port to bind on localhost.
    port: u16,
    /// Tracker whose state is served; shares `Arc`-backed state with the
    /// original tracker it was cloned from.
    tracker: Arc<EnhancedProgressTracker>,
    /// Per-client outbound channels (WebSocket and SSE), keyed by a random
    /// connection id.
    connections: Arc<RwLock<HashMap<Uuid, mpsc::UnboundedSender<String>>>>,
}
/// Point-in-time capture of a job's metrics and per-agent states, used for
/// persistence, export, and cached CLI rendering.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProgressSnapshot {
    /// When the snapshot was taken.
    pub timestamp: DateTime<Utc>,
    /// Job the snapshot belongs to.
    pub job_id: String,
    /// Aggregate metrics at snapshot time.
    pub metrics: ProgressMetrics,
    /// Agent states (state only — not full `AgentProgress`) keyed by agent id.
    pub agent_states: HashMap<String, AgentState>,
}
/// Rolling series of snapshots persisted alongside the latest snapshot.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProgressHistory {
    /// Snapshots in chronological order (capped at 1000 entries on append).
    pub snapshots: Vec<ProgressSnapshot>,
    /// Nominal spacing between snapshots, in seconds.
    pub interval_seconds: u32,
}
/// Output formats supported by `export_progress`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ExportFormat {
    /// Pretty-printed JSON snapshot.
    Json,
    /// Header row plus a single data row of key metrics.
    Csv,
    /// Self-contained static HTML report.
    Html,
}
/// Writes snapshots and history to disk under a per-job path.
pub struct ProgressPersistence {
    /// Directory that holds all progress files (`.prodigy/progress`).
    base_path: PathBuf,
    /// Job the files belong to; used in file names.
    job_id: String,
    /// How often the background persistence loop saves (5s by default).
    save_interval: Duration,
}
impl ProgressPersistence {
    /// Creates a persistence helper for `job_id`, rooted at the relative
    /// directory `.prodigy/progress`, saving every 5 seconds.
    pub fn new(job_id: String) -> Self {
        let base_path = PathBuf::from(".prodigy/progress");
        Self {
            base_path,
            job_id,
            save_interval: Duration::from_secs(5),
        }
    }
    /// Latest-snapshot file: `<base>/<job_id>.json`.
    fn snapshot_path(&self) -> PathBuf {
        self.base_path.join(format!("{}.json", self.job_id))
    }
    /// Rolling-history file: `<base>/<job_id>_history.json`.
    fn history_path(&self) -> PathBuf {
        self.base_path.join(format!("{}_history.json", self.job_id))
    }
    /// Writes `snapshot` as pretty-printed JSON, creating the base directory
    /// if needed and overwriting any previous snapshot.
    pub async fn save_snapshot(&self, snapshot: &ProgressSnapshot) -> MapReduceResult<()> {
        fs::create_dir_all(&self.base_path).await?;
        let path = self.snapshot_path();
        let json = serde_json::to_string_pretty(snapshot)?;
        fs::write(&path, json).await?;
        info!("Saved progress snapshot to {:?}", path);
        Ok(())
    }
    /// Loads the last saved snapshot, or `None` when no snapshot file exists.
    pub async fn load_snapshot(&self) -> MapReduceResult<Option<ProgressSnapshot>> {
        let path = self.snapshot_path();
        // NOTE(review): `Path::exists` is a blocking stat inside an async fn;
        // cheap, but worth confirming it's acceptable on this runtime.
        if !path.exists() {
            return Ok(None);
        }
        let json = fs::read_to_string(&path).await?;
        let snapshot: ProgressSnapshot = serde_json::from_str(&json)?;
        info!("Loaded progress snapshot from {:?}", path);
        Ok(Some(snapshot))
    }
    /// Appends `snapshot` to the history file via read-modify-write,
    /// keeping at most the 1000 most recent entries.
    ///
    /// A corrupt or unparsable history file is silently replaced with a
    /// fresh, empty one — deliberate best-effort behavior.
    pub async fn append_to_history(&self, snapshot: &ProgressSnapshot) -> MapReduceResult<()> {
        fs::create_dir_all(&self.base_path).await?;
        let path = self.history_path();
        let mut history = if path.exists() {
            let json = fs::read_to_string(&path).await?;
            serde_json::from_str::<ProgressHistory>(&json).unwrap_or_else(|_| ProgressHistory {
                snapshots: Vec::new(),
                interval_seconds: self.save_interval.as_secs() as u32,
            })
        } else {
            ProgressHistory {
                snapshots: Vec::new(),
                interval_seconds: self.save_interval.as_secs() as u32,
            }
        };
        // Cap the history length, dropping the oldest entry first.
        if history.snapshots.len() >= 1000 {
            history.snapshots.remove(0);
        }
        history.snapshots.push(snapshot.clone());
        let json = serde_json::to_string_pretty(&history)?;
        fs::write(&path, json).await?;
        Ok(())
    }
    /// Deletes the snapshot and history files for this job, if present.
    pub async fn cleanup(&self) -> MapReduceResult<()> {
        let snapshot_path = self.snapshot_path();
        if snapshot_path.exists() {
            fs::remove_file(&snapshot_path).await?;
        }
        let history_path = self.history_path();
        if history_path.exists() {
            fs::remove_file(&history_path).await?;
        }
        info!(
            "Cleaned up progress persistence files for job {}",
            self.job_id
        );
        Ok(())
    }
}
/// Rate-limits expensive full renders by caching the last snapshot/metrics
/// pair and only allowing a refresh once per `sample_rate`.
pub struct ProgressSampler {
    /// Minimum time between full refreshes.
    sample_rate: Duration,
    /// Shared cache of the last sampled state.
    cache: Arc<RwLock<ProgressCache>>,
}
/// Internal cache payload for [`ProgressSampler`].
#[derive(Clone, Debug)]
struct ProgressCache {
    /// When the cache was last refreshed (drives the sampling clock).
    last_update: Instant,
    /// Last full snapshot, `None` until the first `update_cache`.
    snapshot: Option<ProgressSnapshot>,
    /// Metrics captured alongside the snapshot.
    metrics: ProgressMetrics,
}
impl ProgressSampler {
pub fn new(sample_rate: Duration) -> Self {
Self {
sample_rate,
cache: Arc::new(RwLock::new(ProgressCache {
last_update: Instant::now(),
snapshot: None,
metrics: ProgressMetrics::default(),
})),
}
}
pub async fn should_sample(&self) -> bool {
let cache = self.cache.read().await;
cache.last_update.elapsed() >= self.sample_rate
}
pub async fn update_cache(&self, snapshot: ProgressSnapshot, metrics: ProgressMetrics) {
let mut cache = self.cache.write().await;
cache.last_update = Instant::now();
cache.snapshot = Some(snapshot);
cache.metrics = metrics;
}
pub async fn get_cached(&self) -> Option<(ProgressSnapshot, ProgressMetrics)> {
let cache = self.cache.read().await;
cache
.snapshot
.as_ref()
.map(|s| (s.clone(), cache.metrics.clone()))
}
}
impl EnhancedProgressTracker {
    /// Creates a tracker for `total_items` items; the elapsed clock starts
    /// immediately and `pending_items` is seeded with the full item count.
    pub fn new(job_id: String, total_items: usize) -> Self {
        let (event_sender, event_receiver) = mpsc::unbounded_channel();
        Self {
            job_id,
            total_items,
            start_time: Instant::now(),
            agents: Arc::new(RwLock::new(HashMap::new())),
            metrics: Arc::new(RwLock::new(ProgressMetrics {
                pending_items: total_items,
                ..Default::default()
            })),
            event_sender,
            event_receiver: Arc::new(Mutex::new(event_receiver)),
            web_server: None,
        }
    }
    /// Spawns a background task that saves a snapshot and appends it to the
    /// history file every persistence interval (5s). Failures are logged and
    /// do not stop the loop; the task runs until the runtime shuts down.
    pub async fn start_persistence(&self) -> MapReduceResult<()> {
        let persistence = ProgressPersistence::new(self.job_id.clone());
        let tracker = self.clone();
        tokio::spawn(async move {
            let mut interval = interval(persistence.save_interval);
            loop {
                interval.tick().await;
                let snapshot = tracker.create_snapshot().await;
                if let Err(e) = persistence.save_snapshot(&snapshot).await {
                    warn!("Failed to save progress snapshot: {}", e);
                }
                if let Err(e) = persistence.append_to_history(&snapshot).await {
                    warn!("Failed to append to progress history: {}", e);
                }
            }
        });
        info!("Started progress persistence for job {}", self.job_id);
        Ok(())
    }
    /// Restores metrics and agent states from the last on-disk snapshot.
    /// Returns `Ok(true)` when a snapshot was found and applied.
    ///
    /// Snapshots only persist agent *states*, so restored records carry
    /// empty item/step fields and zeroed counters, with both timestamps set
    /// to the snapshot time.
    pub async fn restore_from_disk(&mut self) -> MapReduceResult<bool> {
        let persistence = ProgressPersistence::new(self.job_id.clone());
        if let Some(snapshot) = persistence.load_snapshot().await? {
            let mut metrics = self.metrics.write().await;
            *metrics = snapshot.metrics;
            let mut agents = self.agents.write().await;
            let restored_agents = snapshot.agent_states.into_iter().map(|(agent_id, state)| {
                let progress = AgentProgress {
                    agent_id: agent_id.clone(),
                    item_id: String::new(),
                    state,
                    current_step: String::new(),
                    steps_completed: 0,
                    total_steps: 0,
                    progress_percentage: 0.0,
                    started_at: snapshot.timestamp,
                    last_update: snapshot.timestamp,
                    estimated_completion: None,
                    error_count: 0,
                    retry_count: 0,
                };
                (agent_id, progress)
            });
            agents.extend(restored_agents);
            info!("Restored progress from disk for job {}", self.job_id);
            Ok(true)
        } else {
            Ok(false)
        }
    }
    /// Starts the dashboard web server on `port` in a background task and
    /// records it in `self.web_server`.
    ///
    /// NOTE(review): the server captures `Arc::new(self.clone())` — the
    /// clone shares all `Arc`-backed state with `self`, but the clone's own
    /// `web_server` field remains `None`.
    pub async fn start_web_server(&mut self, port: u16) -> MapReduceResult<()> {
        let server = Arc::new(ProgressWebServer {
            port,
            tracker: Arc::new(self.clone()),
            connections: Arc::new(RwLock::new(HashMap::new())),
        });
        self.web_server = Some(server.clone());
        let server_clone = server.clone();
        tokio::spawn(async move {
            if let Err(e) = server_clone.start().await {
                error!("Failed to start progress web server: {}", e);
            }
        });
        info!("Progress dashboard available at http://localhost:{}", port);
        Ok(())
    }
    /// Stores `progress` for `agent_id`, emits an `AgentProgress` event, and
    /// recomputes aggregate metrics.
    async fn update_agent_progress_impl(
        &self,
        agent_id: &str,
        progress: AgentProgress,
    ) -> MapReduceResult<()> {
        {
            // Scope the write lock so it is released before
            // recalculate_metrics takes its own locks.
            let mut agents = self.agents.write().await;
            agents.insert(agent_id.to_string(), progress.clone());
        }
        let update = ProgressUpdate {
            update_type: UpdateType::AgentProgress,
            timestamp: Utc::now(),
            data: json!({
                "agent_id": agent_id,
                "progress": progress,
            }),
        };
        // Send errors are ignored: no dashboard may be listening.
        let _ = self.event_sender.send(update);
        self.recalculate_metrics().await?;
        Ok(())
    }
    /// Public wrapper over `update_agent_progress_impl` (also exposed via
    /// the `ProgressReporter` trait).
    pub async fn update_agent_progress(
        &self,
        agent_id: &str,
        progress: AgentProgress,
    ) -> MapReduceResult<()> {
        self.update_agent_progress_impl(agent_id, progress).await
    }
    /// Updates only the state (and `last_update`) of `agent_id`, creating a
    /// blank record if the agent is not yet known, then recomputes metrics.
    pub async fn update_agent_state(
        &self,
        agent_id: &str,
        state: AgentState,
    ) -> MapReduceResult<()> {
        {
            let mut agents = self.agents.write().await;
            if let Some(agent) = agents.get_mut(agent_id) {
                agent.state = state;
                agent.last_update = Utc::now();
            } else {
                // First sighting of this agent: record the given state with
                // zeroed counters and empty item/step fields.
                let new_agent = AgentProgress {
                    agent_id: agent_id.to_string(),
                    item_id: String::new(),
                    state,
                    current_step: String::new(),
                    steps_completed: 0,
                    total_steps: 0,
                    progress_percentage: 0.0,
                    started_at: Utc::now(),
                    last_update: Utc::now(),
                    estimated_completion: None,
                    error_count: 0,
                    retry_count: 0,
                };
                agents.insert(agent_id.to_string(), new_agent);
            }
        }
        self.recalculate_metrics().await?;
        Ok(())
    }
    /// Marks `agent_id`'s item as completed: sets the agent state, bumps the
    /// completed counter, decrements pending, then recomputes metrics.
    pub async fn mark_item_completed(&self, agent_id: &str) -> MapReduceResult<()> {
        {
            let mut agents = self.agents.write().await;
            if let Some(agent) = agents.get_mut(agent_id) {
                agent.state = AgentState::Completed;
                agent.last_update = Utc::now();
            }
        }
        {
            let mut metrics = self.metrics.write().await;
            metrics.completed_items += 1;
            // saturating_sub guards against double-counting pushing this
            // below zero.
            metrics.pending_items = metrics.pending_items.saturating_sub(1);
        }
        self.recalculate_metrics().await?;
        Ok(())
    }
    /// Marks `agent_id`'s item as failed with `error`; mirror image of
    /// `mark_item_completed`.
    pub async fn mark_item_failed(&self, agent_id: &str, error: String) -> MapReduceResult<()> {
        {
            let mut agents = self.agents.write().await;
            if let Some(agent) = agents.get_mut(agent_id) {
                agent.state = AgentState::Failed { error };
                agent.last_update = Utc::now();
            }
        }
        {
            let mut metrics = self.metrics.write().await;
            metrics.failed_items += 1;
            metrics.pending_items = metrics.pending_items.saturating_sub(1);
        }
        self.recalculate_metrics().await?;
        Ok(())
    }
    /// Recomputes derived metrics (active agents, average throughput,
    /// success rate, ETA) and broadcasts a `MetricsUpdate` event.
    ///
    /// NOTE(review): `throughput_current`, `average_duration_ms`,
    /// `memory_usage_mb`, and `cpu_usage_percent` are never written here or
    /// anywhere else in this module.
    async fn recalculate_metrics(&self) -> MapReduceResult<()> {
        // Clone the agent map so the agents read lock is not held while the
        // metrics write lock is taken.
        let agents_data = {
            let agents = self.agents.read().await;
            agents.clone()
        };
        let mut metrics = self.metrics.write().await;
        metrics.active_agents = agents_data
            .values()
            .filter(|a| {
                matches!(
                    a.state,
                    AgentState::Running { .. } | AgentState::Initializing
                )
            })
            .count();
        let elapsed = self.start_time.elapsed().as_secs_f64();
        if elapsed > 0.0 {
            metrics.throughput_average = metrics.completed_items as f64 / elapsed;
        }
        let total_processed = metrics.completed_items + metrics.failed_items;
        if total_processed > 0 {
            metrics.success_rate =
                (metrics.completed_items as f64 / total_processed as f64) * 100.0;
        }
        // ETA = pending / average throughput, projected from now.
        if metrics.throughput_average > 0.0 && metrics.pending_items > 0 {
            let remaining_seconds = metrics.pending_items as f64 / metrics.throughput_average;
            metrics.estimated_completion =
                Some(Utc::now() + chrono::Duration::seconds(remaining_seconds as i64));
        }
        let update = ProgressUpdate {
            update_type: UpdateType::MetricsUpdate,
            timestamp: Utc::now(),
            data: serde_json::to_value(&*metrics).unwrap_or(json!({})),
        };
        let _ = self.event_sender.send(update);
        Ok(())
    }
    /// Overall progress percent: (completed + failed) / total * 100.
    /// Returns 0.0 for a job with no items.
    pub async fn get_overall_progress(&self) -> f32 {
        let metrics = self.metrics.read().await;
        let processed = metrics.completed_items + metrics.failed_items;
        if self.total_items > 0 {
            (processed as f32 / self.total_items as f32) * 100.0
        } else {
            0.0
        }
    }
    /// Last ETA computed by `recalculate_metrics`, if any.
    pub async fn get_estimated_completion(&self) -> Option<DateTime<Utc>> {
        let metrics = self.metrics.read().await;
        metrics.estimated_completion
    }
    /// Serializes a point-in-time view of the job in the requested format
    /// (JSON snapshot, one-row CSV of key metrics, or a static HTML report)
    /// and returns the raw bytes.
    pub async fn export_progress(&self, format: ExportFormat) -> MapReduceResult<Vec<u8>> {
        let agents = self.agents.read().await;
        let metrics = self.metrics.read().await;
        let snapshot = ProgressSnapshot {
            timestamp: Utc::now(),
            job_id: self.job_id.clone(),
            metrics: metrics.clone(),
            agent_states: agents
                .iter()
                .map(|(id, agent)| (id.clone(), agent.state.clone()))
                .collect(),
        };
        match format {
            ExportFormat::Json => {
                let json = serde_json::to_vec_pretty(&snapshot)?;
                Ok(json)
            }
            ExportFormat::Csv => {
                use std::fmt::Write;
                // Writes into a String are infallible, so the Results are
                // deliberately ignored.
                let mut csv_data = String::new();
                let _ = writeln!(
                    &mut csv_data,
                    "timestamp,job_id,completed_items,failed_items,pending_items,success_rate,throughput_average"
                );
                let _ = writeln!(
                    &mut csv_data,
                    "{},{},{},{},{},{:.2},{:.2}",
                    snapshot.timestamp.to_rfc3339(),
                    snapshot.job_id,
                    metrics.completed_items,
                    metrics.failed_items,
                    metrics.pending_items,
                    metrics.success_rate,
                    metrics.throughput_average,
                );
                Ok(csv_data.into_bytes())
            }
            ExportFormat::Html => {
                let html = format!(
                    r#"<!DOCTYPE html>
<html>
<head>
<title>Progress Report - {}</title>
<style>
body {{ font-family: sans-serif; margin: 20px; }}
h1 {{ color: #333; }}
.metrics {{ background: #f5f5f5; padding: 15px; border-radius: 5px; }}
.metric {{ margin: 10px 0; }}
.label {{ font-weight: bold; }}
</style>
</head>
<body>
<h1>MapReduce Job Progress Report</h1>
<div class="metrics">
<div class="metric"><span class="label">Job ID:</span> {}</div>
<div class="metric"><span class="label">Timestamp:</span> {}</div>
<div class="metric"><span class="label">Completed:</span> {}/{}</div>
<div class="metric"><span class="label">Failed:</span> {}</div>
<div class="metric"><span class="label">Success Rate:</span> {:.1}%</div>
<div class="metric"><span class="label">Throughput:</span> {:.2} items/sec</div>
</div>
</body>
</html>"#,
                    snapshot.job_id,
                    snapshot.job_id,
                    snapshot.timestamp.to_rfc3339(),
                    metrics.completed_items,
                    self.total_items,
                    metrics.failed_items,
                    metrics.success_rate,
                    metrics.throughput_average,
                );
                Ok(html.into_bytes())
            }
        }
    }
    /// Captures the current metrics and per-agent states as a
    /// [`ProgressSnapshot`] (agent states only, not full records).
    pub async fn create_snapshot(&self) -> ProgressSnapshot {
        let agents = self.agents.read().await;
        let metrics = self.metrics.read().await;
        ProgressSnapshot {
            timestamp: Utc::now(),
            job_id: self.job_id.clone(),
            metrics: metrics.clone(),
            agent_states: agents
                .iter()
                .map(|(id, agent)| (id.clone(), agent.state.clone()))
                .collect(),
        }
    }
}
impl ProgressWebServer {
pub async fn start(self: Arc<Self>) -> MapReduceResult<()> {
let app = Router::new()
.route("/", get(Self::dashboard_html))
.route("/api/progress", get(Self::get_progress))
.route("/api/agents", get(Self::get_agents))
.route("/api/metrics", get(Self::get_metrics))
.route("/ws", get(Self::websocket_handler))
.route("/sse", get(Self::sse_handler))
.route("/api/prometheus", get(Self::prometheus_metrics))
.layer(CorsLayer::permissive())
.with_state(self.clone());
let addr = SocketAddr::from(([127, 0, 0, 1], self.port));
let broadcaster_self = self.clone();
tokio::spawn(async move {
broadcaster_self.broadcast_events().await;
});
info!("Starting progress web server on {}", addr);
let listener = tokio::net::TcpListener::bind(addr).await?;
axum::serve(listener, app).await?;
Ok(())
}
async fn broadcast_events(self: Arc<Self>) {
let mut receiver = self.tracker.event_receiver.lock().await;
while let Some(update) = receiver.recv().await {
let connections = self.connections.read().await;
let message = serde_json::to_string(&update).unwrap_or_default();
connections.values().for_each(|sender| {
let _ = sender.send(message.clone());
});
}
}
async fn dashboard_html() -> Html<&'static str> {
Html(include_str!("progress_dashboard.html"))
}
async fn get_progress(State(server): State<Arc<ProgressWebServer>>) -> Json<Value> {
let progress = server.tracker.get_overall_progress().await;
let metrics = server.tracker.metrics.read().await;
Json(json!({
"job_id": server.tracker.job_id,
"progress": progress,
"total_items": server.tracker.total_items,
"metrics": *metrics,
}))
}
async fn get_agents(State(server): State<Arc<ProgressWebServer>>) -> Json<Value> {
let agents = server.tracker.agents.read().await;
Json(json!({
"agents": agents.clone(),
}))
}
async fn get_metrics(State(server): State<Arc<ProgressWebServer>>) -> Json<ProgressMetrics> {
let metrics = server.tracker.metrics.read().await;
Json(metrics.clone())
}
async fn websocket_handler(
ws: WebSocketUpgrade,
State(server): State<Arc<ProgressWebServer>>,
) -> Response {
ws.on_upgrade(move |socket| Self::handle_socket(socket, server))
}
async fn sse_handler(
State(server): State<Arc<ProgressWebServer>>,
) -> Sse<impl Stream<Item = Result<axum::response::sse::Event, Infallible>>> {
let (tx, rx) = mpsc::unbounded_channel();
let client_id = Uuid::new_v4();
{
let mut connections = server.connections.write().await;
connections.insert(client_id, tx);
}
let stream = UnboundedReceiverStream::new(rx)
.map(|msg| Ok(axum::response::sse::Event::default().data(msg)));
let server_clone = server.clone();
tokio::spawn(async move {
let mut connections = server_clone.connections.write().await;
connections.remove(&client_id);
});
Sse::new(stream).keep_alive(
axum::response::sse::KeepAlive::new()
.interval(Duration::from_secs(30))
.text("keep-alive"),
)
}
async fn prometheus_metrics(State(server): State<Arc<ProgressWebServer>>) -> String {
let metrics = server.tracker.metrics.read().await;
let agents = server.tracker.agents.read().await;
let mut output = String::new();
output.push_str("# HELP mapreduce_items_completed Number of completed items\n");
output.push_str("# TYPE mapreduce_items_completed counter\n");
output.push_str(&format!(
"mapreduce_items_completed{{job_id=\"{}\"}} {}\n",
server.tracker.job_id, metrics.completed_items
));
output.push_str("# HELP mapreduce_items_failed Number of failed items\n");
output.push_str("# TYPE mapreduce_items_failed counter\n");
output.push_str(&format!(
"mapreduce_items_failed{{job_id=\"{}\"}} {}\n",
server.tracker.job_id, metrics.failed_items
));
output.push_str("# HELP mapreduce_items_pending Number of pending items\n");
output.push_str("# TYPE mapreduce_items_pending gauge\n");
output.push_str(&format!(
"mapreduce_items_pending{{job_id=\"{}\"}} {}\n",
server.tracker.job_id, metrics.pending_items
));
output.push_str("# HELP mapreduce_active_agents Number of active agents\n");
output.push_str("# TYPE mapreduce_active_agents gauge\n");
output.push_str(&format!(
"mapreduce_active_agents{{job_id=\"{}\"}} {}\n",
server.tracker.job_id, metrics.active_agents
));
output.push_str("# HELP mapreduce_throughput_average Average throughput in items/sec\n");
output.push_str("# TYPE mapreduce_throughput_average gauge\n");
output.push_str(&format!(
"mapreduce_throughput_average{{job_id=\"{}\"}} {}\n",
server.tracker.job_id, metrics.throughput_average
));
output.push_str("# HELP mapreduce_success_rate Success rate percentage\n");
output.push_str("# TYPE mapreduce_success_rate gauge\n");
output.push_str(&format!(
"mapreduce_success_rate{{job_id=\"{}\"}} {}\n",
server.tracker.job_id, metrics.success_rate
));
let state_counts: HashMap<String, usize> =
agents.values().fold(HashMap::new(), |mut counts, agent| {
let state_name = match &agent.state {
AgentState::Queued => "queued",
AgentState::Initializing => "initializing",
AgentState::Running { .. } => "running",
AgentState::Merging => "merging",
AgentState::Completed => "completed",
AgentState::Failed { .. } => "failed",
AgentState::Retrying { .. } => "retrying",
AgentState::DeadLettered => "dead_lettered",
};
*counts.entry(state_name.to_string()).or_insert(0) += 1;
counts
});
output.push_str("# HELP mapreduce_agent_states Count of agents by state\n");
output.push_str("# TYPE mapreduce_agent_states gauge\n");
state_counts.iter().for_each(|(state, count)| {
output.push_str(&format!(
"mapreduce_agent_states{{job_id=\"{}\",state=\"{}\"}} {}\n",
server.tracker.job_id, state, count
));
});
let duration = server.tracker.start_time.elapsed().as_secs();
output.push_str("# HELP mapreduce_job_duration_seconds Job duration in seconds\n");
output.push_str("# TYPE mapreduce_job_duration_seconds gauge\n");
output.push_str(&format!(
"mapreduce_job_duration_seconds{{job_id=\"{}\"}} {}\n",
server.tracker.job_id, duration
));
output
}
async fn handle_socket(socket: WebSocket, server: Arc<ProgressWebServer>) {
use futures_util::{SinkExt, StreamExt};
let (mut sender, mut receiver) = socket.split();
let (tx, mut rx) = mpsc::unbounded_channel::<String>();
let client_id = Uuid::new_v4();
{
let mut connections = server.connections.write().await;
connections.insert(client_id, tx);
}
let mut send_task = tokio::spawn(async move {
while let Some(msg) = rx.recv().await {
if sender.send(Message::Text(msg.into())).await.is_err() {
break;
}
}
});
let mut recv_task = tokio::spawn(async move {
while let Some(Ok(_msg)) = receiver.next().await {
}
});
tokio::select! {
_ = &mut send_task => recv_task.abort(),
_ = &mut recv_task => send_task.abort(),
}
let mut connections = server.connections.write().await;
connections.remove(&client_id);
}
}
/// What the CLI viewer should do on a given render tick.
#[derive(Debug, Clone)]
pub(crate) enum RenderStrategy {
    /// Re-read tracker state and render everything.
    Full,
    /// Render from the sampler's cached metrics.
    Cached(ProgressMetrics),
    /// Render nothing this tick (no cache available yet).
    Skip,
}
impl PartialEq for RenderStrategy {
    /// Variant-only equality: two `Cached` values compare equal regardless
    /// of the metrics payload they carry.
    fn eq(&self, other: &Self) -> bool {
        match (self, other) {
            (RenderStrategy::Full, RenderStrategy::Full) => true,
            (RenderStrategy::Skip, RenderStrategy::Skip) => true,
            (RenderStrategy::Cached(_), RenderStrategy::Cached(_)) => true,
            _ => false,
        }
    }
}
/// Terminal renderer that periodically redraws job progress.
pub struct CLIProgressViewer {
    /// Tracker supplying the state to display.
    tracker: Arc<EnhancedProgressTracker>,
    /// Redraw interval (500 ms by default).
    update_interval: Duration,
    /// Optional sampler that rate-limits full renders.
    sampler: Option<ProgressSampler>,
}
impl CLIProgressViewer {
    /// Creates a viewer that fully redraws every 500 ms without sampling.
    pub fn new(tracker: Arc<EnhancedProgressTracker>) -> Self {
        Self {
            tracker,
            update_interval: Duration::from_millis(500),
            sampler: None,
        }
    }
    /// Creates a viewer that limits full re-renders to once per
    /// `sample_rate`, serving cached data in between.
    pub fn with_sampling(tracker: Arc<EnhancedProgressTracker>, sample_rate: Duration) -> Self {
        Self {
            tracker,
            update_interval: Duration::from_millis(500),
            sampler: Some(ProgressSampler::new(sample_rate)),
        }
    }
    /// A job is complete when nothing is pending and no agent is active.
    pub(crate) fn is_job_complete(metrics: &ProgressMetrics) -> bool {
        metrics.pending_items == 0 && metrics.active_agents == 0
    }
    /// True when the sampler says a full refresh is not yet due.
    pub(crate) async fn should_use_cached_render(sampler: &ProgressSampler) -> bool {
        !sampler.should_sample().await
    }
    /// Decides how to render this tick: full render when no sampler is
    /// configured or a sample is due; otherwise cached metrics, or skip
    /// while the cache is still empty.
    pub(crate) async fn determine_render_strategy(
        sampler: Option<&ProgressSampler>,
    ) -> RenderStrategy {
        if let Some(sampler) = sampler {
            if Self::should_use_cached_render(sampler).await {
                if let Some((_, metrics)) = sampler.get_cached().await {
                    RenderStrategy::Cached(metrics)
                } else {
                    RenderStrategy::Skip
                }
            } else {
                RenderStrategy::Full
            }
        } else {
            RenderStrategy::Full
        }
    }
    /// Render loop: redraws on every tick and returns once the job is
    /// complete (no pending items, no active agents).
    pub async fn display(&self) -> MapReduceResult<()> {
        let mut interval = interval(self.update_interval);
        loop {
            interval.tick().await;
            let strategy = Self::determine_render_strategy(self.sampler.as_ref()).await;
            match strategy {
                RenderStrategy::Full => {
                    // Refresh the sampler cache alongside each full render.
                    if let Some(ref sampler) = self.sampler {
                        let snapshot = self.tracker.create_snapshot().await;
                        let metrics = self.tracker.metrics.read().await;
                        sampler.update_cache(snapshot, metrics.clone()).await;
                    }
                    self.clear_screen();
                    self.render_header().await?;
                    self.render_metrics().await?;
                    self.render_agents().await?;
                }
                RenderStrategy::Cached(metrics) => {
                    self.clear_screen();
                    self.render_header_with_metrics(&metrics).await?;
                    self.render_cached_agents().await?;
                }
                RenderStrategy::Skip => {}
            }
            let metrics = self.tracker.metrics.read().await;
            if Self::is_job_complete(&metrics) {
                println!("\n✅ Job completed!");
                break;
            }
        }
        Ok(())
    }
    /// Header banner computed from a metrics value (used for cached renders,
    /// where the tracker is not re-read).
    async fn render_header_with_metrics(&self, metrics: &ProgressMetrics) -> MapReduceResult<()> {
        let total = metrics.completed_items + metrics.failed_items + metrics.pending_items;
        let progress = if total > 0 {
            ((metrics.completed_items + metrics.failed_items) as f32 / total as f32) * 100.0
        } else {
            0.0
        };
        let elapsed = self.tracker.start_time.elapsed();
        println!("╔════════════════════════════════════════════════════════════════╗");
        println!("║ MapReduce Job: {} ║", self.tracker.job_id);
        println!("║ Progress: {:.1}% | Elapsed: {:?} ║", progress, elapsed);
        println!("╚════════════════════════════════════════════════════════════════╝");
        Ok(())
    }
    /// Lists up to 10 agents from the sampler's cached snapshot.
    async fn render_cached_agents(&self) -> MapReduceResult<()> {
        if let Some(ref sampler) = self.sampler {
            if let Some((snapshot, _)) = sampler.get_cached().await {
                println!("\n👥 Agent Status (cached):");
                println!("{}", "─".repeat(60));
                for (id, state) in snapshot.agent_states.iter().take(10) {
                    let state_str = match state {
                        AgentState::Running { step, .. } => format!("🔄 {}", step),
                        AgentState::Completed => "✅ Completed".to_string(),
                        AgentState::Failed { error } => format!("❌ Failed: {}", error),
                        _ => format!("{:?}", state),
                    };
                    println!(" {}: {}", &id[..8.min(id.len())], state_str);
                }
                if snapshot.agent_states.len() > 10 {
                    println!(" ... and {} more agents", snapshot.agent_states.len() - 10);
                }
            }
        }
        Ok(())
    }
    /// Clears the terminal and homes the cursor (ANSI escape).
    fn clear_screen(&self) {
        print!("\x1B[2J\x1B[1;1H");
    }
    /// Header banner computed from live tracker state.
    async fn render_header(&self) -> MapReduceResult<()> {
        let progress = self.tracker.get_overall_progress().await;
        let elapsed = self.tracker.start_time.elapsed();
        println!("╔════════════════════════════════════════════════════════════════╗");
        println!("║ MapReduce Job: {} ║", self.tracker.job_id);
        println!("║ Progress: {:.1}% | Elapsed: {:?} ║", progress, elapsed);
        println!("╚════════════════════════════════════════════════════════════════╝");
        Ok(())
    }
    /// Prints the aggregate metrics section, including the ETA when one has
    /// been computed.
    async fn render_metrics(&self) -> MapReduceResult<()> {
        let metrics = self.tracker.metrics.read().await;
        println!("\n📊 Metrics:");
        println!("{}", "─".repeat(60));
        println!(
            " Completed: {} | Failed: {} | Pending: {}",
            metrics.completed_items, metrics.failed_items, metrics.pending_items
        );
        println!(
            " Active Agents: {} | Success Rate: {:.1}%",
            metrics.active_agents, metrics.success_rate
        );
        println!(
            " Throughput: {:.2} items/sec (avg)",
            metrics.throughput_average
        );
        if let Some(etc) = metrics.estimated_completion {
            let remaining = etc.signed_duration_since(Utc::now());
            println!(
                " ETC: {} ({} remaining)",
                etc.format("%H:%M:%S"),
                format_duration(remaining.to_std().unwrap_or_default())
            );
        }
        Ok(())
    }
    /// Lists up to 10 agents with a textual progress bar.
    async fn render_agents(&self) -> MapReduceResult<()> {
        let agents = self.tracker.agents.read().await;
        println!("\n👥 Agent Status:");
        println!("{}", "─".repeat(60));
        for (id, progress) in agents.iter().take(10) {
            let bar = self.create_progress_bar(progress.progress_percentage);
            let state_str = match &progress.state {
                AgentState::Running { step, .. } => format!("🔄 {}", step),
                AgentState::Completed => "✅ Completed".to_string(),
                AgentState::Failed { error } => format!("❌ Failed: {}", error),
                _ => format!("{:?}", progress.state),
            };
            // Fix: clamp the id prefix so ids shorter than 8 bytes cannot
            // panic (matches render_cached_agents). Byte slicing still
            // assumes ASCII ids — presumably UUID-like; confirm upstream.
            println!(
                " {}: {} [{}] {:.1}%",
                &id[..8.min(id.len())],
                state_str,
                bar,
                progress.progress_percentage
            );
        }
        if agents.len() > 10 {
            println!(" ... and {} more agents", agents.len() - 10);
        }
        Ok(())
    }
    /// Renders a 20-character bar of filled blocks followed by shades.
    ///
    /// Fix: the filled count is clamped to the bar width so percentages
    /// above 100.0 cannot underflow `width - filled` (which previously
    /// panicked in debug builds); negative inputs saturate to 0 via the
    /// float→usize `as` cast.
    pub fn create_progress_bar(&self, percentage: f32) -> String {
        let width = 20;
        let filled = (((percentage / 100.0) * width as f32) as usize).min(width);
        let empty = width - filled;
        format!("{}{}", "█".repeat(filled), "░".repeat(empty))
    }
}
/// Formats a [`Duration`] as a compact human-readable string.
///
/// Sub-second precision is discarded; the largest nonzero unit decides the
/// shape: `"2h 3m 4s"`, `"3m 4s"`, or `"4s"`.
pub fn format_duration(duration: Duration) -> String {
    let total = duration.as_secs();
    let hours = total / 3600;
    let minutes = (total % 3600) / 60;
    let seconds = total % 60;
    match (hours, minutes) {
        (0, 0) => format!("{}s", seconds),
        (0, _) => format!("{}m {}s", minutes, seconds),
        _ => format!("{}h {}m {}s", hours, minutes, seconds),
    }
}
/// Abstraction over progress reporting, allowing callers to depend on a
/// trait object instead of the concrete tracker.
#[async_trait::async_trait]
pub trait ProgressReporter: Send + Sync {
    /// Records a full progress update for one agent.
    async fn update_agent_progress(
        &self,
        agent_id: &str,
        progress: AgentProgress,
    ) -> MapReduceResult<()>;
    /// Overall job progress as a percentage (0.0–100.0).
    async fn get_overall_progress(&self) -> MapReduceResult<f32>;
    /// Estimated completion time, when one can be computed.
    async fn get_estimated_completion(&self) -> MapReduceResult<Option<DateTime<Utc>>>;
    /// Serializes the current progress in the requested format.
    async fn export_progress(&self, format: ExportFormat) -> MapReduceResult<Vec<u8>>;
}
#[async_trait::async_trait]
impl ProgressReporter for EnhancedProgressTracker {
    /// Delegates to the tracker's inherent implementation.
    async fn update_agent_progress(
        &self,
        agent_id: &str,
        progress: AgentProgress,
    ) -> MapReduceResult<()> {
        self.update_agent_progress_impl(agent_id, progress).await
    }
    /// Wraps the infallible inherent method in `Ok`. The fully-qualified
    /// call avoids resolving back to this trait method.
    async fn get_overall_progress(&self) -> MapReduceResult<f32> {
        let progress = EnhancedProgressTracker::get_overall_progress(self).await;
        Ok(progress)
    }
    /// Wraps the infallible inherent method in `Ok` (fully qualified, as
    /// above).
    async fn get_estimated_completion(&self) -> MapReduceResult<Option<DateTime<Utc>>> {
        let eta = EnhancedProgressTracker::get_estimated_completion(self).await;
        Ok(eta)
    }
    /// Delegates to the inherent `export_progress` (fully qualified, as
    /// above).
    async fn export_progress(&self, format: ExportFormat) -> MapReduceResult<Vec<u8>> {
        EnhancedProgressTracker::export_progress(self, format).await
    }
}