use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::Duration;
use tracing::{debug, error, info};
use crate::client::KeyComputeClient;
use crate::protocol::types::NodePollRequest;
use crate::runtime::executor::TaskExecutor;
use crate::storage::SessionData;
#[allow(dead_code)] pub async fn poll_loop(
client: &KeyComputeClient,
session: &SessionData,
executor: Arc<TaskExecutor>,
is_excluded: Arc<AtomicBool>,
stop_signal: Arc<AtomicBool>,
excluded_check_interval: Duration,
poll_timeout_secs: u64,
) {
info!("Starting poll loop");
let mut consecutive_failures: u32 = 0;
let max_backoff = Duration::from_secs(16);
let empty_poll_interval = if poll_timeout_secs > 0 {
Duration::from_secs(poll_timeout_secs / 10)
} else {
Duration::from_secs(1) };
info!(
"Poll empty interval: {}s (poll_timeout_secs={})",
empty_poll_interval.as_secs(),
poll_timeout_secs
);
while !stop_signal.load(Ordering::Relaxed) {
if is_excluded.load(Ordering::Relaxed) {
info!("Node excluded, stopping poll (will continue heartbeat only)");
tokio::time::sleep(excluded_check_interval).await;
continue;
}
let req = NodePollRequest {
protocol_version: "node.v1".to_string(),
node_id: session.node_id,
session_id: session.session_id,
};
match client.poll(&req).await {
Ok(resp) => {
consecutive_failures = 0;
if let Some(task) = resp.task {
info!(
"Received task: task_id={}, model={}, deadline_unix_ms={}",
task.task_id, task.model, task.deadline_unix_ms
);
let executor_clone = executor.clone();
tokio::spawn(async move {
executor_clone.execute(task).await;
});
} else if let Some(retry_ms) = resp.retry_after_ms {
debug!("No task available, retry_after={}ms", retry_ms);
tokio::time::sleep(Duration::from_millis(retry_ms)).await;
} else {
tokio::time::sleep(empty_poll_interval).await;
}
}
Err(e) => {
error!("Poll failed: {}", e);
consecutive_failures += 1;
let backoff = std::cmp::min(
Duration::from_secs(2_u64.pow(consecutive_failures.min(4))),
max_backoff,
);
info!(
"Poll retrying after {}s (consecutive_failures={})",
backoff.as_secs(),
consecutive_failures
);
tokio::time::sleep(backoff).await;
}
}
}
info!("Poll loop stopped");
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Reference rule for the empty-poll wait: one tenth of the timeout,
    /// never less than one second (so small timeouts cannot busy-poll).
    fn expected_empty_interval(poll_timeout_secs: u64) -> Duration {
        Duration::from_secs((poll_timeout_secs / 10).max(1))
    }

    /// Reference rule for failure backoff: 2^n seconds with the exponent
    /// capped at 4, and the result capped at `max_backoff`.
    fn expected_backoff(consecutive_failures: u32, max_backoff: Duration) -> Duration {
        std::cmp::min(
            Duration::from_secs(2_u64.pow(consecutive_failures.min(4))),
            max_backoff,
        )
    }

    #[test]
    fn test_excluded_flag_check() {
        let is_excluded = Arc::new(AtomicBool::new(false));
        assert!(!is_excluded.load(Ordering::Relaxed));
        is_excluded.store(true, Ordering::Relaxed);
        assert!(is_excluded.load(Ordering::Relaxed));
    }

    #[test]
    fn test_stop_signal_check() {
        let stop_signal = Arc::new(AtomicBool::new(false));
        assert!(!stop_signal.load(Ordering::Relaxed));
        stop_signal.store(true, Ordering::Relaxed);
        assert!(stop_signal.load(Ordering::Relaxed));
    }

    #[test]
    fn test_poll_backoff_calculation() {
        let max_backoff = Duration::from_secs(16);
        // (consecutive_failures, expected backoff in seconds)
        let cases: [(u32, u64); 7] = [
            (1, 2),
            (2, 4),
            (3, 8),
            (4, 16),
            (5, 16),
            (10, 16),
            (u32::MAX, 16),
        ];
        for &(failures, secs) in &cases {
            assert_eq!(
                expected_backoff(failures, max_backoff),
                Duration::from_secs(secs),
                "failures={}",
                failures
            );
        }
    }

    #[test]
    fn test_empty_poll_interval_calculation() {
        // (poll_timeout_secs, expected interval in seconds)
        let cases: [(u64, u64); 6] = [(0, 1), (5, 1), (9, 1), (10, 1), (20, 2), (30, 3)];
        for &(timeout, secs) in &cases {
            assert_eq!(
                expected_empty_interval(timeout),
                Duration::from_secs(secs),
                "poll_timeout_secs={}",
                timeout
            );
        }
    }

    #[test]
    fn test_atomic_bool_concurrent_access() {
        let is_excluded = Arc::new(AtomicBool::new(false));
        let stop_signal = Arc::new(AtomicBool::new(false));
        // Collected eagerly so every thread is spawned before any is joined.
        let handles: Vec<_> = (0..10)
            .map(|i| {
                let excluded = Arc::clone(&is_excluded);
                let stop = Arc::clone(&stop_signal);
                std::thread::spawn(move || match i % 3 {
                    0 => excluded.store(true, Ordering::Relaxed),
                    1 => stop.store(true, Ordering::Relaxed),
                    _ => {
                        let _ = excluded.load(Ordering::Relaxed);
                        let _ = stop.load(Ordering::Relaxed);
                    }
                })
            })
            .collect();
        for handle in handles {
            handle.join().expect("flag thread panicked");
        }
        // `join` synchronizes with each finished thread, so the stores made by
        // i = 0 (excluded) and i = 1 (stop) are visible even with Relaxed loads.
        assert!(is_excluded.load(Ordering::Relaxed));
        assert!(stop_signal.load(Ordering::Relaxed));
    }
}