unistore-progress 0.1.0

Progress tracking capability for UniStore
Documentation
//! 【进度追踪器】- 核心实现
//!
//! 职责:
//! - 追踪任务进度
//! - 计算 ETA
//! - 发布进度事件

use crate::config::ProgressConfig;
use crate::deps::*;
#[allow(unused_imports)]
use crate::error::{ProgressError, ProgressResult};
use crate::event::ProgressEvent;

/// Mutable bookkeeping for a tracker; lives behind the tracker's `RwLock`.
struct TrackerState {
    /// Recent rate samples (work units per second) used to smooth the ETA.
    rate_samples: Vec<f64>,
    /// Instant of the most recent notification; drives debouncing.
    last_notify: Option<Instant>,
    /// Free-form message describing the current step.
    message: String,
}

/// Progress tracker.
///
/// A thread-safe tracker whose updates are broadcast to any number of
/// subscribers.
///
/// # Example
///
/// ```rust
/// use unistore_progress::ProgressTracker;
///
/// let tracker = ProgressTracker::new(100);
/// tracker.advance(10);
/// tracker.set_message("处理中...");
/// assert_eq!(tracker.current(), 10);
/// ```
pub struct ProgressTracker {
    /// Total number of work units; fixed at construction.
    total: u64,
    /// When the tracker was created; the reference point for elapsed time.
    started_at: Instant,
    /// Tuning knobs (debounce, auto-finish threshold, ETA smoothing, ...).
    config: ProgressConfig,
    /// Completed work units, updated atomically.
    current: AtomicU64,
    /// Set once the tracker has been marked finished.
    finished: AtomicBool,
    /// Lock-protected mutable state (message, debounce time, rate samples).
    state: RwLock<TrackerState>,
    /// Broadcast channel used to publish progress events.
    sender: broadcast::Sender<ProgressEvent>,
}

impl ProgressTracker {
    /// Maximum number of rate samples kept for ETA smoothing. Once the buffer
    /// is full, the oldest sample is discarded before a new one is appended.
    const MAX_RATE_SAMPLES: usize = 10;

    /// Creates a tracker for `total` work units using the default configuration.
    ///
    /// # Arguments
    /// * `total` - total number of work units
    ///
    /// # Panics
    /// Panics if `total` is 0.
    pub fn new(total: u64) -> Self {
        Self::with_config(total, ProgressConfig::default())
    }

    /// Creates a tracker for `total` work units with a custom configuration.
    ///
    /// # Panics
    /// Panics if `total` is 0, because every completion ratio divides by it.
    /// May also panic if `config.channel_capacity` is rejected by
    /// `broadcast::channel`. NOTE(review): tokio's broadcast panics on a
    /// capacity of 0; confirm which implementation `crate::deps` re-exports.
    pub fn with_config(total: u64, config: ProgressConfig) -> Self {
        assert!(total > 0, "total must be greater than 0");

        // The initial receiver is dropped immediately; consumers obtain
        // their own receivers through `subscribe`.
        let (sender, _) = broadcast::channel(config.channel_capacity);

        Self {
            total,
            current: AtomicU64::new(0),
            finished: AtomicBool::new(false),
            started_at: Instant::now(),
            config,
            state: RwLock::new(TrackerState {
                message: String::new(),
                last_notify: None,
                rate_samples: Vec::with_capacity(Self::MAX_RATE_SAMPLES),
            }),
            sender,
        }
    }

    /// Returns the total number of work units.
    pub fn total(&self) -> u64 {
        self.total
    }

    /// Returns the number of completed work units (never more than `total`).
    pub fn current(&self) -> u64 {
        self.current.load(Ordering::Relaxed)
    }

    /// Returns completion as a percentage in `0.0..=100.0`.
    pub fn percentage(&self) -> f64 {
        (self.current() as f64 / self.total as f64) * 100.0
    }

    /// Returns whether the tracker has been marked finished.
    pub fn is_finished(&self) -> bool {
        self.finished.load(Ordering::Relaxed)
    }

    /// Returns the time elapsed since the tracker was created.
    pub fn elapsed(&self) -> Duration {
        self.started_at.elapsed()
    }

    /// Advances progress by `delta` work units.
    ///
    /// The counter saturates at `total`, matching [`set`](Self::set).
    /// Overshooting therefore cannot overflow the counter or push
    /// [`percentage`](Self::percentage) past 100.
    /// Reaching the configured auto-finish threshold marks the tracker
    /// finished; otherwise a debounced notification may be sent. The call is
    /// ignored once the tracker is finished.
    ///
    /// # Arguments
    /// * `delta` - number of work units to add
    pub fn advance(&self, delta: u64) {
        if self.is_finished() {
            return;
        }

        let total = self.total;
        let step = |cur: u64| cur.saturating_add(delta).min(total);
        // The closure always returns `Some`, so `fetch_update` cannot fail.
        // Both `Ok` and `Err` carry the previous value anyway.
        let previous = self
            .current
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |cur| Some(step(cur)))
            .unwrap_or_else(|prev| prev);

        self.after_update(step(previous));
    }

    /// Sets progress to `value`, clamped to `total`.
    ///
    /// Auto-finish and notification behave as in [`advance`](Self::advance).
    /// The call is ignored once the tracker is finished.
    pub fn set(&self, value: u64) {
        if self.is_finished() {
            return;
        }

        let value = value.min(self.total);
        self.current.store(value, Ordering::Relaxed);
        self.after_update(value);
    }

    /// Replaces the progress message and may send a debounced notification.
    pub fn set_message(&self, message: impl Into<String>) {
        // The write guard is a temporary released at the end of this
        // statement, before `maybe_notify` takes the lock again.
        self.state.write().message = message.into();
        self.maybe_notify();
    }

    /// Marks the tracker finished and sets `current` to `total`.
    ///
    /// Also sends an immediate (non-debounced) notification. Only the first
    /// call has any effect.
    pub fn finish(&self) {
        if self.finished.swap(true, Ordering::Relaxed) {
            return; // Already finished.
        }
        self.current.store(self.total, Ordering::Relaxed);
        self.notify_now();
    }

    /// Returns a new receiver for progress events.
    pub fn subscribe(&self) -> broadcast::Receiver<ProgressEvent> {
        self.sender.subscribe()
    }

    /// Builds a point-in-time snapshot of the tracker's progress.
    pub fn snapshot(&self) -> ProgressEvent {
        let current = self.current();
        let elapsed = self.elapsed();

        // Acquire the state lock exactly once. `calculate_eta` receives the
        // samples instead of re-locking: a thread re-acquiring a read lock it
        // already holds can deadlock while a writer is waiting. parking_lot
        // documents this; std's RwLock may also panic.
        let (message, eta) = {
            let state = self.state.read();
            let eta = self.calculate_eta(current, elapsed, &state.rate_samples);
            (state.message.clone(), eta)
        };

        ProgressEvent {
            current,
            total: self.total,
            message,
            elapsed,
            eta,
            finished: self.is_finished(),
        }
    }

    /// Shared tail of `advance` and `set`.
    ///
    /// Finishes the tracker when `value` reaches the configured auto-finish
    /// ratio; otherwise it sends a debounced notification.
    fn after_update(&self, value: u64) {
        let ratio = value as f64 / self.total as f64;
        if ratio >= self.config.auto_finish_threshold {
            self.finish();
        } else {
            self.maybe_notify();
        }
    }

    /// Sends a notification unless one was sent less than
    /// `debounce_interval` ago.
    fn maybe_notify(&self) {
        let now = Instant::now();
        let should_notify = {
            let state = self.state.read();
            match state.last_notify {
                Some(last) => now.duration_since(last) >= self.config.debounce_interval,
                None => true,
            }
        };

        if should_notify {
            self.notify_now();
        }
    }

    /// Broadcasts the current snapshot immediately.
    ///
    /// Also records the notification time and a new rate sample.
    fn notify_now(&self) {
        let event = self.snapshot();

        {
            let mut state = self.state.write();
            state.last_notify = Some(Instant::now());

            // Each sample is the average rate since start (units / second).
            let elapsed_secs = event.elapsed.as_secs_f64();
            if elapsed_secs > 0.0 {
                let rate = event.current as f64 / elapsed_secs;
                // Evict before pushing so the buffer never exceeds its
                // preallocated capacity.
                if state.rate_samples.len() >= Self::MAX_RATE_SAMPLES {
                    state.rate_samples.remove(0);
                }
                state.rate_samples.push(rate);
            }
        }

        // Sending fails only when nobody is subscribed, which is not an error.
        let _ = self.sender.send(event);
    }

    /// Estimates the time remaining after `current` units done in `elapsed`.
    ///
    /// `samples` holds rate samples ordered oldest to newest.
    /// - With fewer than `eta_min_samples` samples (or none), the overall
    ///   average rate is used.
    /// - Otherwise the samples are combined with an exponential moving
    ///   average weighted by `eta_smoothing_factor`.
    ///
    /// Returns `None` in any of these cases:
    /// - no progress has been made yet;
    /// - the rate is not positive;
    /// - the estimate does not fit in a `Duration`.
    fn calculate_eta(&self, current: u64, elapsed: Duration, samples: &[f64]) -> Option<Duration> {
        if current == 0 {
            return None;
        }

        let rate = match samples.split_first() {
            Some((&oldest, newer)) if samples.len() >= self.config.eta_min_samples => {
                let alpha = self.config.eta_smoothing_factor;
                // Seed with the oldest sample and fold forward in time. The
                // newest sample then carries the largest weight, and the
                // weights sum to 1, so there is no bias toward zero.
                newer
                    .iter()
                    .fold(oldest, |acc, &sample| acc * (1.0 - alpha) + sample * alpha)
            }
            // Not enough samples yet: use the overall average rate.
            _ => current as f64 / elapsed.as_secs_f64(),
        };

        if rate > 0.0 {
            // `saturating_sub` guards against `current` ever exceeding `total`.
            let remaining = self.total.saturating_sub(current) as f64;
            // `try_from_secs_f64` avoids the panic `from_secs_f64` raises when
            // a tiny positive rate makes the estimate too large.
            Duration::try_from_secs_f64(remaining / rate).ok()
        } else {
            None
        }
    }
}

impl std::fmt::Debug for ProgressTracker {
    /// Renders the tracker's counters and timing.
    ///
    /// The message and rate samples are omitted so formatting never touches
    /// the state lock.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let total = self.total;
        let current = self.current();
        let finished = self.is_finished();
        let elapsed = self.elapsed();

        let mut out = formatter.debug_struct("ProgressTracker");
        out.field("total", &total);
        out.field("current", &current);
        out.field("finished", &finished);
        out.field("elapsed", &elapsed);
        out.finish()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Fresh tracker over 100 units with the default configuration.
    fn tracker_of_100() -> ProgressTracker {
        ProgressTracker::new(100)
    }

    #[test]
    fn test_new() {
        let t = tracker_of_100();
        assert_eq!((t.total(), t.current()), (100, 0));
        assert!(!t.is_finished());
    }

    #[test]
    fn test_advance() {
        let t = tracker_of_100();
        // Advances accumulate: 10, then 10 + 20.
        for (delta, expected) in [(10, 10), (20, 30)] {
            t.advance(delta);
            assert_eq!(t.current(), expected);
        }
    }

    #[test]
    fn test_set() {
        let t = tracker_of_100();
        t.set(50);
        assert_eq!(t.current(), 50);
    }

    #[test]
    fn test_finish() {
        let t = tracker_of_100();
        t.advance(50);
        t.finish();
        // Finishing snaps the counter to the total.
        assert!(t.is_finished());
        assert_eq!(t.current(), 100);
    }

    #[test]
    fn test_percentage() {
        let t = tracker_of_100();
        t.set(25);
        let drift = (t.percentage() - 25.0).abs();
        assert!(drift < 0.001);
    }

    #[test]
    fn test_set_message() {
        let t = tracker_of_100();
        t.set_message("Processing...");
        assert_eq!(t.snapshot().message, "Processing...");
    }

    #[test]
    fn test_auto_finish() {
        let t = tracker_of_100();
        t.set(100);
        assert!(t.is_finished());
    }

    #[tokio::test]
    async fn test_subscribe() {
        let config = ProgressConfig::default().no_debounce();
        let tracker = ProgressTracker::with_config(100, config);
        let mut rx = tracker.subscribe();

        tracker.advance(10);

        // Wait (bounded) for the broadcast event.
        let received = tokio::time::timeout(Duration::from_millis(100), rx.recv()).await;
        let event = received.expect("timeout").expect("recv error");

        assert_eq!(event.current, 10);
    }
}