node-token 0.1.0

KeyCompute node-token client for personal PC nodes
/// 轮询循环逻辑
///
/// 定期从服务端领取任务,提交到执行器执行。
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::Duration;

use tracing::{debug, error, info};

use crate::client::KeyComputeClient;
use crate::protocol::types::NodePollRequest;
use crate::runtime::executor::TaskExecutor;
use crate::storage::SessionData;

/// 轮询循环
///
/// # 参数
/// - `client`: KeyCompute HTTP 客户端
/// - `session`: 当前 session 信息
/// - `executor`: 任务执行器
/// - `is_excluded`: 节点排除标志(由 heartbeat 循环更新)
/// - `stop_signal`: 退出信号
/// - `excluded_check_interval`: Excluded 节点 poll 检查间隔
/// - `poll_timeout_secs`: 服务端轮询超时(秒),用于计算无任务时的等待间隔
///
/// # 行为
/// - 定期调用 poll API 领取任务
/// - 如果节点被 excluded,停止 poll
/// - 服务端返回 retry_after_ms 时等待指定时间
/// - 网络错误指数退避(AGENTS.md 第 724 行)
/// - 无任务时等待间隔 = poll_timeout_secs / 10(默认 1 秒)
#[allow(dead_code)] // 在阶段五使用
pub async fn poll_loop(
    client: &KeyComputeClient,
    session: &SessionData,
    executor: Arc<TaskExecutor>,
    is_excluded: Arc<AtomicBool>,
    stop_signal: Arc<AtomicBool>,
    excluded_check_interval: Duration,
    poll_timeout_secs: u64,
) {
    info!("Starting poll loop");

    // 连续失败计数,用于指数退避
    let mut consecutive_failures: u32 = 0;
    let max_backoff = Duration::from_secs(16);

    // 计算无任务时的等待间隔:poll_timeout_secs / 10,默认 1 秒
    let empty_poll_interval = if poll_timeout_secs > 0 {
        Duration::from_secs(poll_timeout_secs / 10)
    } else {
        Duration::from_secs(1) // 默认 1 秒
    };

    info!(
        "Poll empty interval: {}s (poll_timeout_secs={})",
        empty_poll_interval.as_secs(),
        poll_timeout_secs
    );

    while !stop_signal.load(Ordering::Relaxed) {
        // 如果节点被 excluded,停止 poll
        if is_excluded.load(Ordering::Relaxed) {
            info!("Node excluded, stopping poll (will continue heartbeat only)");
            // 按配置的检查间隔等待后,检查是否恢复
            tokio::time::sleep(excluded_check_interval).await;
            continue;
        }

        let req = NodePollRequest {
            protocol_version: "node.v1".to_string(),
            node_id: session.node_id,
            session_id: session.session_id,
        };

        match client.poll(&req).await {
            Ok(resp) => {
                // 成功后重置失败计数
                consecutive_failures = 0;

                if let Some(task) = resp.task {
                    info!(
                        "Received task: task_id={}, model={}, deadline_unix_ms={}",
                        task.task_id, task.model, task.deadline_unix_ms
                    );

                    // 收到任务,提交到执行器
                    // 使用 Arc 克隆 executor,让 executor 在后台执行
                    let executor_clone = executor.clone();
                    tokio::spawn(async move {
                        executor_clone.execute(task).await;
                    });
                } else if let Some(retry_ms) = resp.retry_after_ms {
                    debug!("No task available, retry_after={}ms", retry_ms);
                    tokio::time::sleep(Duration::from_millis(retry_ms)).await;
                } else {
                    // 没有任务也没有 retry_after,使用计算的间隔等待后继续
                    tokio::time::sleep(empty_poll_interval).await;
                }
            }
            Err(e) => {
                error!("Poll failed: {}", e);
                // 网络错误指数退避(AGENTS.md 第 724 行)
                consecutive_failures += 1;
                let backoff = std::cmp::min(
                    Duration::from_secs(2_u64.pow(consecutive_failures.min(4))),
                    max_backoff,
                );
                info!(
                    "Poll retrying after {}s (consecutive_failures={})",
                    backoff.as_secs(),
                    consecutive_failures
                );
                tokio::time::sleep(backoff).await;
            }
        }
    }

    info!("Poll loop stopped");
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    /// 验证 poll 循环的 excluded 标志检查逻辑。
    /// excluded 节点应该停止 poll。
    fn test_excluded_flag_check() {
        let is_excluded = Arc::new(AtomicBool::new(false));

        // 非 excluded 时可以 poll
        assert!(!is_excluded.load(Ordering::Relaxed));

        // excluded 时停止 poll
        is_excluded.store(true, Ordering::Relaxed);
        assert!(is_excluded.load(Ordering::Relaxed));
    }

    #[test]
    /// 验证 poll 循环的停止信号检查逻辑。
    fn test_stop_signal_check() {
        let stop_signal = Arc::new(AtomicBool::new(false));

        // 未停止
        assert!(!stop_signal.load(Ordering::Relaxed));

        // 发送停止信号
        stop_signal.store(true, Ordering::Relaxed);
        assert!(stop_signal.load(Ordering::Relaxed));
    }

    #[test]
    /// 验证轮询指数退避间隔计算。
    /// 网络错误后应该指数级增加等待时间。
    fn test_poll_backoff_calculation() {
        let max_backoff = Duration::from_secs(16);

        // 第 1 次失败:2^1 = 2 秒
        let backoff_1 = Duration::from_secs(2_u64.pow(1));
        assert_eq!(backoff_1, Duration::from_secs(2));

        // 第 2 次失败:2^2 = 4 秒
        let backoff_2 = Duration::from_secs(2_u64.pow(2));
        assert_eq!(backoff_2, Duration::from_secs(4));

        // 第 3 次失败:2^3 = 8 秒
        let backoff_3 = Duration::from_secs(2_u64.pow(3));
        assert_eq!(backoff_3, Duration::from_secs(8));

        // 第 4 次失败:2^4 = 16 秒
        let backoff_4 = Duration::from_secs(2_u64.pow(4));
        assert_eq!(backoff_4, Duration::from_secs(16));

        // 第 5 次失败:min(2^5, 16) = 16 秒(达到最大值)
        let backoff_5 = std::cmp::min(Duration::from_secs(2_u64.pow(5)), max_backoff);
        assert_eq!(backoff_5, Duration::from_secs(16));

        // 第 10 次失败:min(2^10, 16) = 16 秒(仍然最大值)
        let backoff_10 = std::cmp::min(Duration::from_secs(2_u64.pow(10)), max_backoff);
        assert_eq!(backoff_10, Duration::from_secs(16));
    }

    #[test]
    /// 验证空轮询间隔计算。
    /// 无任务时的等待间隔 = poll_timeout_secs / 10
    fn test_empty_poll_interval_calculation() {
        // poll_timeout_secs = 20,间隔 = 2 秒
        let interval_1 = if 20 > 0 {
            Duration::from_secs(20 / 10)
        } else {
            Duration::from_secs(1) // 默认值
        };
        assert_eq!(interval_1, Duration::from_secs(2));

        // poll_timeout_secs = 30,间隔 = 3 秒
        let interval_2 = if 30 > 0 {
            Duration::from_secs(30 / 10)
        } else {
            Duration::from_secs(1) // 默认值
        };
        assert_eq!(interval_2, Duration::from_secs(3));

        // poll_timeout_secs = 5,间隔 = 0 秒(整数除法),使用默认 1 秒
        let interval_3 = if 5 > 0 {
            let calculated = 5 / 10; // = 0
            if calculated > 0 {
                Duration::from_secs(calculated)
            } else {
                Duration::from_secs(1) // 使用默认值
            }
        } else {
            Duration::from_secs(1)
        };
        assert_eq!(interval_3, Duration::from_secs(1));
    }

    #[test]
    /// 验证多个 AtomicBool 并发访问的安全性。
    fn test_atomic_bool_concurrent_access() {
        let is_excluded = Arc::new(AtomicBool::new(false));
        let stop_signal = Arc::new(AtomicBool::new(false));
        let mut handles = vec![];

        // 创建多个线程同时读写
        for i in 0..10 {
            let excluded = is_excluded.clone();
            let stop = stop_signal.clone();
            let handle = std::thread::spawn(move || {
                if i % 3 == 0 {
                    excluded.store(true, Ordering::Relaxed);
                } else if i % 3 == 1 {
                    stop.store(true, Ordering::Relaxed);
                } else {
                    let _ = excluded.load(Ordering::Relaxed);
                    let _ = stop.load(Ordering::Relaxed);
                }
            });
            handles.push(handle);
        }

        // 等待所有线程完成
        for handle in handles {
            handle.join().unwrap();
        }

        // 验证没有 panic
        let _ = is_excluded.load(Ordering::Relaxed);
        let _ = stop_signal.load(Ordering::Relaxed);
    }
}