use crate::cell::messaging::{CellMessage, CellMessageBus, CellMessageType};
use crate::models::{Capability, CapabilityExt, CapabilityType};
use crate::Result;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
use tracing::{debug, info, instrument, warn};
/// Role a node currently plays in its squad's leader election.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ElectionState {
    /// Competing for leadership. This is the state a `LeaderElectionManager` starts in
    /// and the state it returns to when a re-election is triggered.
    Candidate,
    /// This node is the squad leader; `send_heartbeat_if_leader` only publishes
    /// heartbeats in this state.
    Leader,
    /// Deferring to another node that announced itself as leader.
    Follower,
}
/// A node's weighted suitability to lead its squad.
///
/// When built by [`LeadershipScore::from_capabilities`], `total` is
/// `0.30*compute + 0.25*communication + 0.20*sensors + 0.15*power + 0.10*reliability`.
/// The fields are public, so a hand-built value need not satisfy that relation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LeadershipScore {
    /// Compute component, derived from the confidence of `Compute` capabilities.
    pub compute: f64,
    /// Communication component, derived from the confidence of `Communication` capabilities.
    pub communication: f64,
    /// Sensor component: 0.25 per `Sensor` capability, capped at 1.0.
    pub sensors: f64,
    /// Power component; `from_capabilities` sets it to 1.0.
    pub power: f64,
    /// Reliability component; `from_capabilities` sets it to 1.0.
    pub reliability: f64,
    /// Weighted sum used to rank candidates (see [`LeadershipScore::compare`]).
    pub total: f64,
}
impl LeadershipScore {
    /// Weight of `compute` in `total`.
    const COMPUTE_WEIGHT: f64 = 0.30;
    /// Weight of `communication` in `total`.
    const COMMUNICATION_WEIGHT: f64 = 0.25;
    /// Weight of `sensors` in `total`.
    const SENSOR_WEIGHT: f64 = 0.20;
    /// Weight of `power` in `total`.
    const POWER_WEIGHT: f64 = 0.15;
    /// Weight of `reliability` in `total`.
    const RELIABILITY_WEIGHT: f64 = 0.10;
    /// Credit added to `sensors` for each `Sensor` capability (before capping at 1.0).
    const PER_SENSOR_CREDIT: f64 = 0.25;

    /// Builds a score from a node's advertised capabilities.
    ///
    /// * `compute` / `communication` — the highest confidence among capabilities of
    ///   that type, or 0.0 if there are none.
    /// * `sensors` — 0.25 per `Sensor` capability, capped at 1.0.
    /// * `power` / `reliability` — fixed at 1.0; no capability feeds them yet.
    ///
    /// `total` is the weighted sum
    /// `0.30*compute + 0.25*communication + 0.20*sensors + 0.15*power + 0.10*reliability`.
    pub fn from_capabilities(capabilities: &[Capability]) -> Self {
        let mut compute: f64 = 0.0;
        let mut communication: f64 = 0.0;
        let mut sensors: f64 = 0.0;
        let power = 1.0;
        let reliability = 1.0;
        for cap in capabilities {
            let confidence = cap.confidence as f64;
            match cap.get_capability_type() {
                // Keep the maximum so that several capabilities of the same type give
                // a result that does not depend on the order they are listed in.
                CapabilityType::Compute => compute = compute.max(confidence),
                CapabilityType::Communication => {
                    communication = communication.max(confidence)
                }
                CapabilityType::Sensor => sensors += Self::PER_SENSOR_CREDIT,
                _ => {}
            }
        }
        let sensors = sensors.min(1.0);
        let total = (compute * Self::COMPUTE_WEIGHT)
            + (communication * Self::COMMUNICATION_WEIGHT)
            + (sensors * Self::SENSOR_WEIGHT)
            + (power * Self::POWER_WEIGHT)
            + (reliability * Self::RELIABILITY_WEIGHT);
        Self {
            compute,
            communication,
            sensors,
            power,
            reliability,
            total,
        }
    }

    /// Orders this node (`my_id`) against another (`other_id`) by `total`; a higher
    /// total ranks `Greater`.
    ///
    /// Equal totals, and totals that cannot be compared because one is NaN, fall
    /// back to comparing `my_id` with `other_id` lexicographically, so the larger id
    /// ranks higher.
    pub fn compare(&self, other: &Self, my_id: &str, other_id: &str) -> std::cmp::Ordering {
        match self.total.partial_cmp(&other.total) {
            Some(std::cmp::Ordering::Equal) | None => my_id.cmp(other_id),
            Some(ordering) => ordering,
        }
    }
}
/// Bookkeeping for a single election round.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ElectionRound {
    /// Round number; `LeaderElectionManager::new` starts at 1 and re-election increments it.
    pub round: u32,
    /// Time the round was created, in whole seconds since the Unix epoch.
    pub started_at: u64,
    /// Candidate scores registered for this round, keyed by platform id
    /// (`start_election` registers this node's own score here).
    pub candidates: HashMap<String, LeadershipScore>,
}
impl ElectionRound {
    /// Creates an empty round numbered `round`, stamped with the current wall-clock
    /// time in whole seconds since the Unix epoch.
    ///
    /// If the system clock reports a time before the epoch, `started_at` is 0
    /// rather than panicking.
    pub fn new(round: u32) -> Self {
        let started_at = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map(|since_epoch| since_epoch.as_secs())
            // A clock set before 1970 is a host misconfiguration; it must not abort
            // an election, and the timestamp is informational only.
            .unwrap_or(0);
        Self {
            round,
            started_at,
            candidates: HashMap::new(),
        }
    }
}
/// Liveness tracker for the leader this node is following.
#[derive(Debug, Clone)]
struct LeaderHeartbeat {
    /// Platform id of the followed leader.
    leader_id: String,
    /// When the latest heartbeat was recorded (or tracking began).
    last_heartbeat: Instant,
    /// Cleared by `update`; `is_failed` judges failure by elapsed time, not by this counter.
    missed_count: u32,
}

impl LeaderHeartbeat {
    /// Begins tracking `leader_id`, treating the current instant as its latest heartbeat.
    fn new(leader_id: String) -> Self {
        let now = Instant::now();
        Self {
            leader_id,
            last_heartbeat: now,
            missed_count: 0,
        }
    }

    /// Records a heartbeat received just now and clears the miss counter.
    fn update(&mut self) {
        let now = Instant::now();
        self.last_heartbeat = now;
        self.missed_count = 0;
    }

    /// Returns `true` once strictly more than `max_missed` heartbeat intervals have
    /// elapsed since the last recorded heartbeat.
    fn is_failed(&self, heartbeat_interval: Duration, max_missed: u32) -> bool {
        let silence = self.last_heartbeat.elapsed();
        let allowance = heartbeat_interval * max_missed;
        silence > allowance
    }
}
/// Runs leader election for one squad from the point of view of a single node
/// (`platform_id`), publishing announcements and heartbeats on `message_bus`.
pub struct LeaderElectionManager {
    /// Squad this election belongs to (used in log messages).
    squad_id: String,
    /// This node's id; also the key compared against announced leaders' ids.
    platform_id: String,
    /// Bus used to publish `LeaderAnnounce` and `Heartbeat` messages.
    message_bus: Arc<CellMessageBus>,
    /// This node's current role.
    state: Arc<Mutex<ElectionState>>,
    /// Round currently being contested and its registered candidates.
    current_round: Arc<Mutex<ElectionRound>>,
    /// This node's leadership score, computed at construction.
    my_score: Arc<Mutex<LeadershipScore>>,
    /// Id of the node recognised as leader, if any (may be this node).
    current_leader: Arc<Mutex<Option<String>>>,
    /// Heartbeat tracker set when this node starts following a leader and cleared
    /// on re-election.
    leader_heartbeat: Arc<Mutex<Option<LeaderHeartbeat>>>,
    // NOTE(review): set to 5 s in `new` but not read by any code visible here,
    // hence the dead_code allowance.
    #[allow(dead_code)]
    election_timeout: Duration,
    /// Expected spacing between leader heartbeats.
    heartbeat_interval: Duration,
    /// Heartbeat intervals that may pass in silence before the leader is presumed failed.
    max_missed_heartbeats: u32,
}
impl LeaderElectionManager {
    /// Creates the election manager for node `platform_id` in squad `squad_id`.
    ///
    /// The node starts as an [`ElectionState::Candidate`] in round 1 with no known
    /// leader. Its [`LeadershipScore`] is computed once, here, from `capabilities`.
    /// Leader heartbeats are expected every 2 s, and the leader is presumed failed
    /// after more than 3 intervals without one.
    pub fn new(
        squad_id: String,
        platform_id: String,
        message_bus: Arc<CellMessageBus>,
        capabilities: Vec<Capability>,
    ) -> Self {
        let my_score = LeadershipScore::from_capabilities(&capabilities);
        let current_round = ElectionRound::new(1);
        Self {
            squad_id,
            platform_id,
            message_bus,
            state: Arc::new(Mutex::new(ElectionState::Candidate)),
            current_round: Arc::new(Mutex::new(current_round)),
            my_score: Arc::new(Mutex::new(my_score)),
            current_leader: Arc::new(Mutex::new(None)),
            leader_heartbeat: Arc::new(Mutex::new(None)),
            election_timeout: Duration::from_secs(5),
            heartbeat_interval: Duration::from_secs(2),
            max_missed_heartbeats: 3,
        }
    }

    /// Starts an election. It records this node's score as a candidate in the current
    /// round, then publishes a `LeaderAnnounce` for that round.
    ///
    /// # Errors
    /// Returns any error from publishing on the message bus.
    #[instrument(skip(self))]
    pub fn start_election(&self) -> Result<()> {
        info!("Starting leader election for squad {}", self.squad_id);
        // Clone the score out first so the two locks are never held together.
        let score = self.my_score.lock().unwrap().clone();
        let round = {
            let mut current = self.current_round.lock().unwrap();
            current.candidates.insert(self.platform_id.clone(), score);
            current.round
        };
        self.announce_candidacy(round)?;
        Ok(())
    }

    /// Publishes a `LeaderAnnounce` naming this node for `round`.
    fn announce_candidacy(&self, round: u32) -> Result<()> {
        debug!(
            "Node {} announcing candidacy for round {}",
            self.platform_id, round
        );
        let payload = CellMessageType::LeaderAnnounce {
            leader_id: self.platform_id.clone(),
            election_round: round,
        };
        self.message_bus.publish(payload)?;
        Ok(())
    }

    /// Routes an incoming message to the matching election handler.
    ///
    /// `LeaderAnnounce` and `Heartbeat` are handled; every other message type is ignored.
    ///
    /// # Errors
    /// Propagates errors from the handlers.
    #[instrument(skip(self, message))]
    pub fn process_election_message(&self, message: &CellMessage) -> Result<()> {
        match &message.payload {
            CellMessageType::LeaderAnnounce {
                leader_id,
                election_round,
            } => {
                self.handle_leader_announce(leader_id, *election_round)?;
            }
            CellMessageType::Heartbeat { platform_id } => {
                self.handle_heartbeat(platform_id)?;
            }
            // Not election traffic.
            _ => {}
        }
        Ok(())
    }

    /// Reacts to `leader_id` announcing itself as leader for `round`.
    ///
    /// The announcement is ignored in two cases:
    /// * it comes from a round older than ours;
    /// * this node is no longer a candidate.
    ///
    /// Otherwise, if [`Self::should_follow_leader`] approves, this node becomes a
    /// follower of `leader_id` and starts tracking its heartbeats.
    fn handle_leader_announce(&self, leader_id: &str, round: u32) -> Result<()> {
        let current_round = self.current_round.lock().unwrap().round;
        if round < current_round {
            debug!("Ignoring stale announcement from round {}", round);
            return Ok(());
        }
        let my_state = *self.state.lock().unwrap();
        if my_state != ElectionState::Candidate {
            return Ok(());
        }
        if self.should_follow_leader(leader_id)? {
            info!("Following leader: {}", leader_id);
            *self.state.lock().unwrap() = ElectionState::Follower;
            *self.current_leader.lock().unwrap() = Some(leader_id.to_string());
            *self.leader_heartbeat.lock().unwrap() =
                Some(LeaderHeartbeat::new(leader_id.to_string()));
        } else {
            // The decision is id-based (see `should_follow_leader`), not score-based.
            debug!(
                "{} does not outrank {} by id, remaining candidate",
                leader_id, self.platform_id
            );
        }
        Ok(())
    }

    /// Decides whether this node should defer to the announced `leader_id`.
    ///
    /// A `LeaderAnnounce` carries only the leader id and round, so the decision uses
    /// ids alone. This node follows when `leader_id` is lexicographically greater
    /// than its own id, the same rule [`LeadershipScore::compare`] uses to break ties.
    fn should_follow_leader(&self, leader_id: &str) -> Result<bool> {
        Ok(leader_id > self.platform_id.as_str())
    }

    /// Refreshes the leader's liveness tracker when `platform_id` is the current
    /// leader; heartbeats from any other node are ignored.
    fn handle_heartbeat(&self, platform_id: &str) -> Result<()> {
        let current_leader = self.current_leader.lock().unwrap().clone();
        if let Some(leader_id) = current_leader {
            if platform_id == leader_id {
                if let Some(ref mut heartbeat) = *self.leader_heartbeat.lock().unwrap() {
                    heartbeat.update();
                    debug!("Received heartbeat from leader {}", leader_id);
                }
            }
        }
        Ok(())
    }

    /// Checks whether the followed leader has been silent for more than
    /// `max_missed_heartbeats` heartbeat intervals. If so, it starts a re-election.
    ///
    /// Returns `Ok(true)` when a failure was detected and re-election triggered.
    /// Otherwise it returns `Ok(false)`, including when no leader is being tracked.
    ///
    /// # Errors
    /// Propagates publish errors from the re-election announcement.
    pub fn check_leader_failure(&self) -> Result<bool> {
        // Work on a copy so the lock is released before `trigger_reelection`, which
        // locks `leader_heartbeat` again (std `Mutex` is not reentrant).
        let heartbeat = self.leader_heartbeat.lock().unwrap().clone();
        if let Some(hb) = heartbeat {
            if hb.is_failed(self.heartbeat_interval, self.max_missed_heartbeats) {
                warn!(
                    "Leader {} has failed (no heartbeat for {:?})",
                    hb.leader_id,
                    hb.last_heartbeat.elapsed()
                );
                self.trigger_reelection()?;
                return Ok(true);
            }
        }
        Ok(false)
    }

    /// Abandons the current leader and starts a fresh round in which this node is a
    /// candidate. The steps are:
    /// 1. Reset state to `Candidate`, clear the known leader and its tracker.
    /// 2. Open round `n + 1` containing this node's score.
    /// 3. Announce candidacy for the new round.
    #[instrument(skip(self))]
    fn trigger_reelection(&self) -> Result<()> {
        info!("Triggering re-election for squad {}", self.squad_id);
        *self.state.lock().unwrap() = ElectionState::Candidate;
        *self.current_leader.lock().unwrap() = None;
        *self.leader_heartbeat.lock().unwrap() = None;
        // Take the score before locking the round, as `start_election` does, so the
        // two locks are never held together.
        let score = self.my_score.lock().unwrap().clone();
        let new_round = {
            let mut current = self.current_round.lock().unwrap();
            let mut next = ElectionRound::new(current.round + 1);
            // Register ourselves in the fresh round, matching `start_election`.
            // Otherwise the new round's candidate table would start empty.
            next.candidates.insert(self.platform_id.clone(), score);
            let number = next.round;
            *current = next;
            number
        };
        self.announce_candidacy(new_round)?;
        Ok(())
    }

    /// Publishes a heartbeat for this node if, and only if, it is the leader.
    ///
    /// # Errors
    /// Returns any error from publishing on the message bus.
    pub fn send_heartbeat_if_leader(&self) -> Result<()> {
        let state = *self.state.lock().unwrap();
        if state == ElectionState::Leader {
            let payload = CellMessageType::Heartbeat {
                platform_id: self.platform_id.clone(),
            };
            self.message_bus.publish(payload)?;
            debug!("Sent leader heartbeat");
        }
        Ok(())
    }

    /// Returns this node's current election role.
    pub fn get_state(&self) -> ElectionState {
        *self.state.lock().unwrap()
    }

    /// Returns the id of the recognised leader, if any.
    pub fn get_leader(&self) -> Option<String> {
        self.current_leader.lock().unwrap().clone()
    }

    /// Returns the current election round number.
    pub fn get_round(&self) -> u32 {
        self.current_round.lock().unwrap().round
    }

    /// Makes this node the leader. It sets the state to `Leader`, records itself as
    /// the current leader, and stops tracking any previously followed leader.
    pub fn set_as_leader(&self) {
        info!("Node {} set as leader", self.platform_id);
        *self.state.lock().unwrap() = ElectionState::Leader;
        *self.current_leader.lock().unwrap() = Some(self.platform_id.clone());
        // A tracker left over from following another node would no longer be
        // refreshed (its heartbeats no longer match `current_leader`). Later,
        // `check_leader_failure` would report it as failed and demote this node.
        *self.leader_heartbeat.lock().unwrap() = None;
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::cmp::Ordering;

    /// Builds a `Capability` from an id, a name, a type and a confidence.
    /// A macro (not a fn) so the confidence literal infers whatever numeric type
    /// `Capability::new` expects.
    macro_rules! cap {
        ($id:expr, $name:expr, $kind:expr, $confidence:expr) => {
            Capability::new($id.to_string(), $name.to_string(), $kind, $confidence)
        };
    }

    /// Builds a manager for `node_1` in `squad_1` on a fresh message bus.
    fn node_1_manager(capabilities: Vec<Capability>) -> LeaderElectionManager {
        let bus = Arc::new(CellMessageBus::new(
            "squad_1".to_string(),
            "node_1".to_string(),
        ));
        LeaderElectionManager::new(
            "squad_1".to_string(),
            "node_1".to_string(),
            bus,
            capabilities,
        )
    }

    /// Builds a score with explicitly chosen components (`total` is not recomputed).
    fn score(
        compute: f64,
        communication: f64,
        sensors: f64,
        power: f64,
        reliability: f64,
        total: f64,
    ) -> LeadershipScore {
        LeadershipScore {
            compute,
            communication,
            sensors,
            power,
            reliability,
            total,
        }
    }

    #[test]
    fn test_leadership_score_computation() {
        let capabilities = vec![
            cap!("cap1", "compute", CapabilityType::Compute, 0.8),
            cap!("cap2", "communication", CapabilityType::Communication, 0.6),
            cap!("cap3", "sensor", CapabilityType::Sensor, 1.0),
        ];
        let computed = LeadershipScore::from_capabilities(&capabilities);
        assert!((computed.compute - 0.8).abs() < 0.001);
        assert!((computed.communication - 0.6).abs() < 0.001);
        assert_eq!(computed.sensors, 0.25);
        assert_eq!(computed.power, 1.0);
        assert_eq!(computed.reliability, 1.0);
        assert!(computed.total > 0.6 && computed.total < 0.7);
    }

    #[test]
    fn test_leadership_score_comparison() {
        let strong = score(0.8, 0.6, 0.5, 0.9, 1.0, 0.75);
        let weak = score(0.6, 0.5, 0.4, 0.8, 0.9, 0.65);
        assert_eq!(
            strong.compare(&weak, "platform_a", "platform_b"),
            Ordering::Greater
        );
        assert_eq!(
            weak.compare(&strong, "platform_b", "platform_a"),
            Ordering::Less
        );
    }

    #[test]
    fn test_leadership_score_tie_breaking() {
        let first = score(0.8, 0.6, 0.5, 0.9, 1.0, 0.75);
        let second = first.clone();
        assert_eq!(
            first.compare(&second, "platform_a", "platform_b"),
            Ordering::Less
        );
        assert_eq!(
            first.compare(&second, "platform_b", "platform_a"),
            Ordering::Greater
        );
    }

    #[test]
    fn test_election_round_creation() {
        let fresh = ElectionRound::new(1);
        assert_eq!(fresh.round, 1);
        assert!(fresh.candidates.is_empty());
        assert!(fresh.started_at > 0);
    }

    #[test]
    fn test_leader_heartbeat_tracking() {
        let mut tracker = LeaderHeartbeat::new("leader_1".to_string());
        assert!(!tracker.is_failed(Duration::from_secs(2), 3));
        tracker.update();
        assert_eq!(tracker.missed_count, 0);
        assert_eq!(tracker.leader_id, "leader_1");
    }

    #[test]
    fn test_election_manager_creation() {
        let manager = node_1_manager(vec![cap!(
            "cap1",
            "compute",
            CapabilityType::Compute,
            0.8
        )]);
        assert_eq!(manager.get_state(), ElectionState::Candidate);
        assert_eq!(manager.get_leader(), None);
        assert_eq!(manager.get_round(), 1);
    }

    #[test]
    fn test_set_as_leader() {
        let manager = node_1_manager(Vec::new());
        manager.set_as_leader();
        assert_eq!(manager.get_state(), ElectionState::Leader);
        assert_eq!(manager.get_leader(), Some("node_1".to_string()));
    }

    #[test]
    fn test_election_state_transitions() {
        let manager = node_1_manager(Vec::new());
        assert_eq!(manager.get_state(), ElectionState::Candidate);
        manager.set_as_leader();
        assert_eq!(manager.get_state(), ElectionState::Leader);
        assert!(manager.send_heartbeat_if_leader().is_ok());
    }

    #[test]
    fn test_multiple_sensors_score() {
        let capabilities: Vec<Capability> = (1..=4)
            .map(|i| cap!(format!("sensor{}", i), "sensor", CapabilityType::Sensor, 1.0))
            .collect();
        let computed = LeadershipScore::from_capabilities(&capabilities);
        assert_eq!(computed.sensors, 1.0);
    }

    #[test]
    fn test_start_election() {
        let manager = node_1_manager(vec![cap!(
            "cap1",
            "compute",
            CapabilityType::Compute,
            0.8
        )]);
        assert!(manager.start_election().is_ok());
        let round = manager.current_round.lock().unwrap();
        assert!(round.candidates.contains_key("node_1"));
    }
}