use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::Duration;
use tracing::{error, info, warn};
use crate::client::{KeyComputeClient, OllamaClient};
use crate::config::NodeTokenConfig;
use crate::protocol::types::NodeHeartbeatRequest;
use crate::storage::SessionData;
/// Periodically sends node heartbeats to the server until `stop_signal` is set.
///
/// Each cycle waits for the next interval tick, lists the models via
/// `ollama_client`, and sends a `NodeHeartbeatRequest` through `client`.
/// If listing models fails the cycle is skipped (no heartbeat is sent).
///
/// # Parameters
/// - `client`: used to send the heartbeat request.
/// - `ollama_client`: used to list the models reported in the heartbeat.
/// - `session`: supplies the `node_id` and `session_id` placed in the request.
/// - `config`: supplies `heartbeat_interval_secs`, the base heartbeat period.
/// - `is_excluded`: shared flag, set to `true` whenever the server reports
///   `node_status == "excluded"` and to `false` otherwise.
/// - `stop_signal`: shared flag; once it is `true` the loop exits.
///
/// # Behavior
/// - While the node is excluded the heartbeat period is three times the base
///   period; it returns to the base period once the node is no longer excluded.
/// - A configured period of zero is raised to one second, because
///   `tokio::time::interval` panics on a zero period.
// NOTE(review): `dead_code` is allowed by the original author; presumably the
// caller is not wired up in every build - confirm before removing.
#[allow(dead_code)]
pub async fn heartbeat_loop(
    client: &KeyComputeClient,
    ollama_client: &OllamaClient,
    session: &SessionData,
    config: &NodeTokenConfig,
    is_excluded: Arc<AtomicBool>,
    stop_signal: Arc<AtomicBool>,
) {
    /// Builds a tick-skipping interval whose first tick fires one full
    /// `period` from now (a plain `interval` would tick immediately).
    fn delayed_interval(period: Duration) -> tokio::time::Interval {
        let mut ticker =
            tokio::time::interval_at(tokio::time::Instant::now() + period, period);
        ticker.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
        ticker
    }

    // Guard against a zero period, which would make `tokio::time::interval` panic.
    let configured_interval = Duration::from_secs(config.heartbeat_interval_secs);
    let base_interval = configured_interval.max(Duration::from_secs(1));
    if base_interval != configured_interval {
        warn!(
            "Configured heartbeat interval of {}s is invalid, using {}s instead",
            configured_interval.as_secs(),
            base_interval.as_secs()
        );
    }
    let excluded_interval = base_interval * 3;

    // The initial interval ticks immediately so the first heartbeat is sent at startup.
    let mut interval = tokio::time::interval(base_interval);
    interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);

    info!(
        "Starting heartbeat loop: interval={}s",
        base_interval.as_secs()
    );

    let mut consecutive_failures: u32 = 0;

    while !stop_signal.load(Ordering::Relaxed) {
        interval.tick().await;

        // The stop request may have arrived while waiting for the tick; do not
        // send another heartbeat in that case.
        if stop_signal.load(Ordering::Relaxed) {
            break;
        }

        let models = match ollama_client.list_models().await {
            Ok(m) => m,
            Err(e) => {
                warn!("Failed to list Ollama models for heartbeat: {}", e);
                continue;
            }
        };

        let req = NodeHeartbeatRequest {
            protocol_version: "node.v1".to_string(),
            node_id: session.node_id,
            session_id: session.session_id,
            accepted_models: models,
        };

        match client.heartbeat(&req).await {
            Ok(resp) => {
                consecutive_failures = 0;
                info!(
                    "Heartbeat: accepted={}, status={}, failure_count={}/{}",
                    resp.accepted,
                    resp.node_status,
                    resp.server_failure_count,
                    resp.failure_threshold
                );

                let now_excluded = resp.node_status == "excluded";
                // `swap` publishes the new state and returns the previous one
                // in a single atomic step.
                let was_excluded = is_excluded.swap(now_excluded, Ordering::Relaxed);

                match (was_excluded, now_excluded) {
                    (false, true) => {
                        warn!("Node has been EXCLUDED - will stop poll but continue heartbeat");
                        interval = delayed_interval(excluded_interval);
                    }
                    (true, false) => {
                        info!(
                            "Node status changed from excluded to {}, restoring normal heartbeat interval",
                            resp.node_status
                        );
                        interval = delayed_interval(base_interval);
                    }
                    // No transition: keep the current interval.
                    _ => {}
                }

                if !resp.accepted {
                    warn!(
                        "Heartbeat not accepted by server, node_status={}",
                        resp.node_status
                    );
                }
            }
            Err(e) => {
                // Saturate instead of overflowing during a very long outage.
                consecutive_failures = consecutive_failures.saturating_add(1);
                error!(
                    "Heartbeat failed (consecutive={}): {}",
                    consecutive_failures, e
                );
            }
        }
    }

    info!("Heartbeat loop stopped");
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::AtomicBool;

    /// The shared exclusion flag starts cleared and reflects each store.
    #[test]
    fn test_is_excluded_flag_update() {
        let is_excluded = Arc::new(AtomicBool::new(false));
        assert!(!is_excluded.load(Ordering::Relaxed));

        is_excluded.store(true, Ordering::Relaxed);
        assert!(is_excluded.load(Ordering::Relaxed));

        is_excluded.store(false, Ordering::Relaxed);
        assert!(!is_excluded.load(Ordering::Relaxed));
    }

    /// The excluded-state heartbeat period is three times the base period.
    #[test]
    fn test_heartbeat_interval_calculation() {
        let base_interval = Duration::from_secs(30);
        let excluded_interval = base_interval * 3;
        assert_eq!(excluded_interval, Duration::from_secs(90));

        let short_interval = Duration::from_secs(10);
        assert_eq!(short_interval * 3, Duration::from_secs(30));

        let long_interval = Duration::from_secs(60);
        assert_eq!(long_interval * 3, Duration::from_secs(180));
    }

    /// Tripling behaves as expected for the smallest and for a zero period.
    #[test]
    fn test_heartbeat_interval_edge_cases() {
        let min_interval = Duration::from_secs(1);
        assert_eq!(min_interval * 3, Duration::from_secs(3));

        let zero_interval = Duration::from_secs(0);
        assert_eq!(zero_interval * 3, Duration::from_secs(0));
    }

    /// Concurrent readers and writers can share the flag; because some threads
    /// store `true` and none store `false`, the flag must be set after all
    /// threads have been joined.
    #[test]
    fn test_atomic_bool_concurrent_access() {
        let is_excluded = Arc::new(AtomicBool::new(false));

        let handles: Vec<_> = (0..10)
            .map(|i| {
                let flag = Arc::clone(&is_excluded);
                std::thread::spawn(move || {
                    if i % 2 == 0 {
                        flag.store(true, Ordering::Relaxed);
                    } else {
                        let _ = flag.load(Ordering::Relaxed);
                    }
                })
            })
            .collect();

        for handle in handles {
            handle.join().expect("worker thread panicked");
        }

        // Joining a thread makes its writes visible to the joining thread.
        assert!(is_excluded.load(Ordering::Relaxed));
    }
}