unistore-progress 0.1.0

Progress tracking capability for UniStore
Documentation
//! 【多阶段进度】- 支持子任务的进度追踪
//!
//! 职责:
//! - 管理多个阶段/子任务的进度
//! - 聚合子进度到总进度

use crate::config::ProgressConfig;
use crate::deps::*;
use crate::event::ProgressEvent;
use crate::tracker::ProgressTracker;

/// 阶段信息
#[derive(Debug, Clone)]
pub struct Stage {
    /// 阶段名称
    pub name: String,
    /// 阶段权重(占总进度的比例)
    pub weight: f64,
    /// 阶段内部总数
    pub total: u64,
}

/// 多阶段进度追踪器
///
/// 用于追踪由多个阶段组成的复杂任务。
///
/// # Example
///
/// ```rust
/// use unistore_progress::{MultiStageTracker, Stage};
///
/// let tracker = MultiStageTracker::new(vec![
///     Stage { name: "下载".into(), weight: 0.3, total: 100 },
///     Stage { name: "解压".into(), weight: 0.2, total: 50 },
///     Stage { name: "安装".into(), weight: 0.5, total: 200 },
/// ]);
///
/// // 第一阶段
/// tracker.advance(0, 50);  // 下载 50%
/// assert!((tracker.total_percentage() - 15.0).abs() < 0.1);
///
/// // 完成第一阶段,开始第二阶段
/// tracker.finish_stage(0);
/// tracker.advance(1, 25);  // 解压 50%
/// ```
pub struct MultiStageTracker {
    /// 阶段列表
    stages: Vec<Stage>,
    /// 每个阶段的追踪器
    trackers: Vec<ProgressTracker>,
    /// 当前活动阶段索引
    current_stage: AtomicU64,
    /// 配置(预留给未来扩展)
    #[allow(dead_code)]
    config: ProgressConfig,
    /// 总进度广播
    sender: broadcast::Sender<ProgressEvent>,
}

impl MultiStageTracker {
    /// 创建多阶段追踪器
    pub fn new(stages: Vec<Stage>) -> Self {
        Self::with_config(stages, ProgressConfig::default())
    }

    /// 使用自定义配置创建
    pub fn with_config(stages: Vec<Stage>, config: ProgressConfig) -> Self {
        assert!(!stages.is_empty(), "stages cannot be empty");

        // 验证权重总和
        let total_weight: f64 = stages.iter().map(|s| s.weight).sum();
        assert!(
            (total_weight - 1.0).abs() < 0.001,
            "stage weights must sum to 1.0, got {}",
            total_weight
        );

        let trackers = stages
            .iter()
            .map(|s| ProgressTracker::with_config(s.total, config.clone()))
            .collect();

        let (sender, _) = broadcast::channel(config.channel_capacity);

        Self {
            stages,
            trackers,
            current_stage: AtomicU64::new(0),
            config,
            sender,
        }
    }

    /// 获取阶段数量
    pub fn stage_count(&self) -> usize {
        self.stages.len()
    }

    /// 获取当前阶段索引
    pub fn current_stage_index(&self) -> usize {
        self.current_stage.load(Ordering::Relaxed) as usize
    }

    /// 获取当前阶段信息
    pub fn current_stage(&self) -> &Stage {
        &self.stages[self.current_stage_index()]
    }

    /// 增加指定阶段的进度
    pub fn advance(&self, stage_index: usize, delta: u64) {
        if stage_index < self.trackers.len() {
            self.trackers[stage_index].advance(delta);
            self.notify_total();
        }
    }

    /// 设置指定阶段的消息
    pub fn set_stage_message(&self, stage_index: usize, message: impl Into<String>) {
        if stage_index < self.trackers.len() {
            self.trackers[stage_index].set_message(message);
        }
    }

    /// 完成指定阶段
    pub fn finish_stage(&self, stage_index: usize) {
        if stage_index < self.trackers.len() {
            self.trackers[stage_index].finish();

            // 自动切换到下一阶段
            if stage_index == self.current_stage_index() && stage_index + 1 < self.stages.len() {
                self.current_stage
                    .store((stage_index + 1) as u64, Ordering::Relaxed);
            }

            self.notify_total();
        }
    }

    /// 获取总体进度百分比
    pub fn total_percentage(&self) -> f64 {
        self.stages
            .iter()
            .zip(self.trackers.iter())
            .map(|(stage, tracker)| {
                let stage_pct = tracker.percentage() / 100.0;
                stage_pct * stage.weight * 100.0
            })
            .sum()
    }

    /// 是否全部完成
    pub fn is_finished(&self) -> bool {
        self.trackers.iter().all(|t| t.is_finished())
    }

    /// 订阅总进度更新
    pub fn subscribe(&self) -> broadcast::Receiver<ProgressEvent> {
        self.sender.subscribe()
    }

    /// 获取指定阶段的追踪器
    pub fn stage_tracker(&self, index: usize) -> Option<&ProgressTracker> {
        self.trackers.get(index)
    }

    /// 发送总进度通知
    fn notify_total(&self) {
        let total_pct = self.total_percentage();
        let current_stage = self.current_stage();

        // 计算等效的 current/total
        let virtual_total = 10000u64; // 使用 10000 作为虚拟总数提高精度
        let virtual_current = (total_pct / 100.0 * virtual_total as f64) as u64;

        let elapsed = self.trackers[0].elapsed(); // 使用第一个阶段的开始时间

        let event = ProgressEvent {
            current: virtual_current,
            total: virtual_total,
            message: format!(
                "[{}/{}] {}",
                self.current_stage_index() + 1,
                self.stages.len(),
                current_stage.name
            ),
            elapsed,
            eta: None, // 多阶段 ETA 计算复杂,暂不实现
            finished: self.is_finished(),
        };

        let _ = self.sender.send(event);
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_multi_stage() {
        let tracker = MultiStageTracker::new(vec![
            Stage {
                name: "阶段1".into(),
                weight: 0.5,
                total: 100,
            },
            Stage {
                name: "阶段2".into(),
                weight: 0.5,
                total: 100,
            },
        ]);

        assert_eq!(tracker.stage_count(), 2);
        assert_eq!(tracker.current_stage_index(), 0);
    }

    #[test]
    fn test_progress_calculation() {
        let tracker = MultiStageTracker::new(vec![
            Stage {
                name: "阶段1".into(),
                weight: 0.5,
                total: 100,
            },
            Stage {
                name: "阶段2".into(),
                weight: 0.5,
                total: 100,
            },
        ]);

        // 阶段1完成50%,总进度应该是25%
        tracker.advance(0, 50);
        assert!((tracker.total_percentage() - 25.0).abs() < 0.1);

        // 阶段1完成100%,总进度应该是50%
        tracker.finish_stage(0);
        assert!((tracker.total_percentage() - 50.0).abs() < 0.1);

        // 阶段2完成50%,总进度应该是75%
        tracker.advance(1, 50);
        assert!((tracker.total_percentage() - 75.0).abs() < 0.1);
    }

    #[test]
    fn test_auto_stage_switch() {
        let tracker = MultiStageTracker::new(vec![
            Stage {
                name: "阶段1".into(),
                weight: 0.5,
                total: 100,
            },
            Stage {
                name: "阶段2".into(),
                weight: 0.5,
                total: 100,
            },
        ]);

        assert_eq!(tracker.current_stage_index(), 0);
        tracker.finish_stage(0);
        assert_eq!(tracker.current_stage_index(), 1);
    }
}