use crate::config::ProgressConfig;
use crate::deps::*;
use crate::event::ProgressEvent;
use crate::tracker::ProgressTracker;
/// One weighted step of a job tracked by `MultiStageTracker`.
///
/// The `weight`s of all stages handed to a `MultiStageTracker` must sum to 1.0
/// (checked with a tolerance of 0.001 when the tracker is constructed).
#[derive(Debug, Clone, PartialEq)]
pub struct Stage {
    /// Display name, used in the aggregate progress message (`"[i/n] name"`).
    pub name: String,
    /// This stage's share of overall progress.
    pub weight: f64,
    /// Number of work units in this stage; passed as the total of the stage's
    /// own `ProgressTracker`.
    pub total: u64,
}
/// Aggregates progress over several weighted stages, each backed by its own
/// [`ProgressTracker`], and publishes combined [`ProgressEvent`]s on a broadcast
/// channel.
pub struct MultiStageTracker {
    /// Stage definitions, in order; the active-stage index moves from `i` to `i + 1`.
    stages: Vec<Stage>,
    /// One tracker per stage: `trackers[i]` is built from, and paired with, `stages[i]`.
    trackers: Vec<ProgressTracker>,
    /// Index into `stages` of the stage currently reported as active. Held in an
    /// atomic so it can be updated through `&self`.
    current_stage: AtomicU64,
    // The config is only consumed while building the per-stage trackers and the
    // channel; it is kept on the struct but never read afterwards, hence the allow.
    #[allow(dead_code)]
    config: ProgressConfig,
    /// Sending half of the channel that aggregate progress events are published on.
    sender: broadcast::Sender<ProgressEvent>,
}
impl MultiStageTracker {
    /// Resolution of the synthetic `current`/`total` pair carried by aggregate
    /// events: overall progress is reported as `current / VIRTUAL_TOTAL`.
    const VIRTUAL_TOTAL: u64 = 10_000;

    /// Creates a tracker for `stages` using `ProgressConfig::default()`.
    ///
    /// # Panics
    ///
    /// Under the same conditions as [`MultiStageTracker::with_config`].
    pub fn new(stages: Vec<Stage>) -> Self {
        Self::with_config(stages, ProgressConfig::default())
    }

    /// Creates a tracker with one [`ProgressTracker`] per stage. Each is built from
    /// the stage's `total` and a clone of `config`. The broadcast channel uses
    /// `config.channel_capacity`.
    ///
    /// # Panics
    ///
    /// Panics if `stages` is empty, if any weight is negative or not finite, or if
    /// the weights do not sum to 1.0 within a tolerance of 0.001.
    pub fn with_config(stages: Vec<Stage>, config: ProgressConfig) -> Self {
        assert!(!stages.is_empty(), "stages cannot be empty");
        // A negative weight can still let the sum reach 1.0, but it would make the
        // overall percentage move backwards as that stage advances.
        assert!(
            stages
                .iter()
                .all(|s| s.weight.is_finite() && s.weight >= 0.0),
            "stage weights must be finite and non-negative"
        );
        let total_weight: f64 = stages.iter().map(|s| s.weight).sum();
        assert!(
            (total_weight - 1.0).abs() < 0.001,
            "stage weights must sum to 1.0, got {}",
            total_weight
        );
        let trackers = stages
            .iter()
            .map(|s| ProgressTracker::with_config(s.total, config.clone()))
            .collect();
        let (sender, _) = broadcast::channel(config.channel_capacity);
        Self {
            stages,
            trackers,
            current_stage: AtomicU64::new(0),
            config,
            sender,
        }
    }

    /// Number of stages.
    pub fn stage_count(&self) -> usize {
        self.stages.len()
    }

    /// Index of the stage currently reported as active.
    pub fn current_stage_index(&self) -> usize {
        self.current_stage.load(Ordering::Relaxed) as usize
    }

    /// The stage currently reported as active.
    pub fn current_stage(&self) -> &Stage {
        &self.stages[self.current_stage_index()]
    }

    /// Advances stage `stage_index` by `delta` units and publishes an aggregate
    /// event. An out-of-range index is ignored.
    pub fn advance(&self, stage_index: usize, delta: u64) {
        if let Some(tracker) = self.trackers.get(stage_index) {
            tracker.advance(delta);
            self.notify_total();
        }
    }

    /// Sets the message of stage `stage_index`'s own tracker. No aggregate event is
    /// published. An out-of-range index is ignored.
    pub fn set_stage_message(&self, stage_index: usize, message: impl Into<String>) {
        if let Some(tracker) = self.trackers.get(stage_index) {
            tracker.set_message(message);
        }
    }

    /// Marks stage `stage_index` finished and publishes an aggregate event.
    ///
    /// If that stage was the active one and is not the last, the active index
    /// moves on to the next stage. An out-of-range index is ignored.
    pub fn finish_stage(&self, stage_index: usize) {
        if let Some(tracker) = self.trackers.get(stage_index) {
            tracker.finish();
            if stage_index + 1 < self.stages.len() {
                // Atomic "if current == stage_index then current = stage_index + 1".
                // A separate load and store could let a stale caller overwrite a
                // newer index and move it backwards. Err just means this stage
                // was not the active one, so there is nothing to do.
                let _ = self.current_stage.compare_exchange(
                    stage_index as u64,
                    (stage_index + 1) as u64,
                    Ordering::Relaxed,
                    Ordering::Relaxed,
                );
            }
            self.notify_total();
        }
    }

    /// Overall progress in percent: each stage's percentage weighted by its
    /// `weight`, clamped to `0.0..=100.0`.
    pub fn total_percentage(&self) -> f64 {
        let weighted: f64 = self
            .stages
            .iter()
            .zip(&self.trackers)
            .map(|(stage, tracker)| {
                let stage_fraction = tracker.percentage() / 100.0;
                stage_fraction * stage.weight * 100.0
            })
            .sum();
        // Weights only need to sum to 1.0 within 0.001, so a fully finished run
        // could otherwise report slightly more than 100%.
        weighted.clamp(0.0, 100.0)
    }

    /// `true` once every stage's tracker reports finished.
    pub fn is_finished(&self) -> bool {
        self.trackers.iter().all(|t| t.is_finished())
    }

    /// Returns a new receiver for aggregate progress events.
    pub fn subscribe(&self) -> broadcast::Receiver<ProgressEvent> {
        self.sender.subscribe()
    }

    /// The per-stage tracker at `index`, or `None` if out of range.
    pub fn stage_tracker(&self, index: usize) -> Option<&ProgressTracker> {
        self.trackers.get(index)
    }

    /// Builds an aggregate [`ProgressEvent`] from the current state and broadcasts it.
    fn notify_total(&self) {
        // Read the active index once so the "[i/n]" prefix and the stage name
        // always describe the same stage, even if another thread advances it.
        let stage_index = self.current_stage_index();
        let stage = &self.stages[stage_index];
        let total_pct = self.total_percentage();
        let virtual_current = (total_pct / 100.0 * Self::VIRTUAL_TOTAL as f64) as u64;
        let event = ProgressEvent {
            current: virtual_current,
            total: Self::VIRTUAL_TOTAL,
            message: format!(
                "[{}/{}] {}",
                stage_index + 1,
                self.stages.len(),
                stage.name
            ),
            // All stage trackers are constructed together in `with_config`; the
            // first one's elapsed time is used as the overall figure
            // (NOTE(review): assumes its clock starts at construction — confirm).
            elapsed: self.trackers[0].elapsed(),
            eta: None,
            finished: self.is_finished(),
        };
        // Best-effort publish. With a tokio-style broadcast channel, sending fails
        // only when there are no receivers, which is not an error here.
        let _ = self.sender.send(event);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the two equally weighted, 100-unit stages that every test uses.
    fn two_equal_stages() -> MultiStageTracker {
        let stage = |name: &str| Stage {
            name: name.into(),
            weight: 0.5,
            total: 100,
        };
        MultiStageTracker::new(vec![stage("阶段1"), stage("阶段2")])
    }

    /// Asserts that the overall percentage is within 0.1 of `expected`.
    fn assert_total_near(tracker: &MultiStageTracker, expected: f64) {
        let actual = tracker.total_percentage();
        assert!((actual - expected).abs() < 0.1);
    }

    #[test]
    fn test_multi_stage() {
        let tracker = two_equal_stages();
        assert_eq!(tracker.stage_count(), 2);
        assert_eq!(tracker.current_stage_index(), 0);
    }

    #[test]
    fn test_progress_calculation() {
        let tracker = two_equal_stages();

        // Half of stage 1 (weight 0.5) -> 25% overall.
        tracker.advance(0, 50);
        assert_total_near(&tracker, 25.0);

        // All of stage 1 -> 50% overall.
        tracker.finish_stage(0);
        assert_total_near(&tracker, 50.0);

        // Plus half of stage 2 -> 75% overall.
        tracker.advance(1, 50);
        assert_total_near(&tracker, 75.0);
    }

    #[test]
    fn test_auto_stage_switch() {
        let tracker = two_equal_stages();
        assert_eq!(tracker.current_stage_index(), 0);
        tracker.finish_stage(0);
        assert_eq!(tracker.current_stage_index(), 1);
    }
}