use crate::{ZoeyError, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, RwLock};
use std::time::Duration;
use tokio::sync::mpsc;
use tracing::{debug, info, warn};
/// Snapshot of one cluster node as tracked by `DistributedRuntime`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeInfo {
    /// Unique node id; used as the key of the runtime's node table.
    pub id: uuid::Uuid,
    /// Human-readable name, used in log messages.
    pub name: String,
    /// Network address of the node (the tests use `"host:port"`, e.g. `"127.0.0.1:8080"`).
    pub address: String,
    /// Current health classification; only `Healthy` nodes are eligible for load balancing.
    pub status: NodeStatus,
    /// Ids of the agents hosted on this node; registering the node maps each of them to `id`.
    pub agents: Vec<uuid::Uuid>,
    /// CPU usage as a fraction (heartbeat logging multiplies it by 100 to show a percentage).
    pub cpu_usage: f32,
    /// Memory usage as a fraction, on the same scale as `cpu_usage`.
    pub memory_usage: f32,
    /// Unix timestamp in seconds (`chrono::Utc::now().timestamp()`) of the last heartbeat.
    pub last_heartbeat: i64,
}
/// Health classification of a cluster node.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum NodeStatus {
    /// Normal operation; the only status considered for load balancing.
    Healthy,
    /// Elevated resource usage; assigned by `heartbeat` when usage exceeds 0.9.
    Degraded,
    /// More severe than `Degraded`; excluded from load balancing.
    Unhealthy,
    /// Not reachable. NOTE(review): nothing in this file assigns it; presumably set by
    /// callers acting on `check_node_health` — confirm.
    Offline,
}
/// A message addressed from one agent to another, possibly on a different node.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DistributedMessage {
    /// Unique message id (random v4 UUID assigned at send time).
    pub id: uuid::Uuid,
    /// Node id of the runtime that sent the message.
    pub from_node: uuid::Uuid,
    /// Node that hosted `to_agent` when the message was sent.
    pub to_node: uuid::Uuid,
    /// Sending agent.
    pub from_agent: uuid::Uuid,
    /// Receiving agent.
    pub to_agent: uuid::Uuid,
    /// Arbitrary JSON payload.
    pub payload: serde_json::Value,
    /// Caller-defined message kind (e.g. `"greeting"` in the tests).
    pub message_type: String,
    /// Unix timestamp in seconds at which the message was created.
    pub timestamp: i64,
}
/// Per-process view of a cluster: registered nodes, which node hosts each agent,
/// and an in-process unbounded queue of outgoing `DistributedMessage`s with counters.
pub struct DistributedRuntime {
    /// Id of the node this runtime runs on; stamped as `from_node` on sent messages.
    node_id: uuid::Uuid,
    /// Registered nodes keyed by node id.
    nodes: Arc<RwLock<HashMap<uuid::Uuid, NodeInfo>>>,
    /// Sending half of the in-process message queue.
    message_tx: mpsc::UnboundedSender<DistributedMessage>,
    /// Receiving half; behind a lock because `try_recv` needs `&mut` access.
    message_rx: Arc<RwLock<mpsc::UnboundedReceiver<DistributedMessage>>>,
    /// Maps agent id -> id of the node hosting that agent.
    agent_locations: Arc<RwLock<HashMap<uuid::Uuid, uuid::Uuid>>>,
    /// Number of queued messages not yet counted as received.
    pending_count: Arc<AtomicUsize>,
    /// Running count of successfully queued messages; cleared by `reset_message_stats`.
    messages_sent: Arc<AtomicUsize>,
    /// Running count of received messages; cleared by `reset_message_stats`.
    messages_received: Arc<AtomicUsize>,
}
impl DistributedRuntime {
pub fn new(node_id: uuid::Uuid) -> Self {
let (tx, rx) = mpsc::unbounded_channel();
Self {
node_id,
nodes: Arc::new(RwLock::new(HashMap::new())),
message_tx: tx,
message_rx: Arc::new(RwLock::new(rx)),
agent_locations: Arc::new(RwLock::new(HashMap::new())),
pending_count: Arc::new(AtomicUsize::new(0)),
messages_sent: Arc::new(AtomicUsize::new(0)),
messages_received: Arc::new(AtomicUsize::new(0)),
}
}
pub fn register_node(&self, node: NodeInfo) -> Result<()> {
info!("Registering node {} at {}", node.name, node.address);
debug!("Node {} has {} agents", node.id, node.agents.len());
for agent_id in &node.agents {
debug!("Mapping agent {} to node {}", agent_id, node.id);
self.agent_locations
.write()
.unwrap()
.insert(*agent_id, node.id);
}
self.nodes.write().unwrap().insert(node.id, node);
debug!(
"Total nodes in cluster: {}",
self.nodes.read().unwrap().len()
);
Ok(())
}
pub fn unregister_node(&self, node_id: uuid::Uuid) -> Result<()> {
info!("Unregistering node {}", node_id);
if let Some(node) = self.nodes.write().unwrap().remove(&node_id) {
for agent_id in &node.agents {
self.agent_locations.write().unwrap().remove(agent_id);
}
}
Ok(())
}
pub async fn send_to_agent(
&self,
from_agent: uuid::Uuid,
to_agent: uuid::Uuid,
payload: serde_json::Value,
message_type: String,
) -> Result<()> {
debug!(
"Sending {} message from agent {} to agent {}",
message_type, from_agent, to_agent
);
let to_node = self
.agent_locations
.read()
.unwrap()
.get(&to_agent)
.copied()
.ok_or_else(|| {
ZoeyError::not_found(format!("Agent {} not found in cluster", to_agent))
})?;
debug!("Target agent {} is on node {}", to_agent, to_node);
let message = DistributedMessage {
id: uuid::Uuid::new_v4(),
from_node: self.node_id,
to_node,
from_agent,
to_agent,
payload,
message_type: message_type.clone(),
timestamp: chrono::Utc::now().timestamp(),
};
self.message_tx
.send(message)
.map_err(|e| ZoeyError::other(format!("Failed to send message: {}", e)))?;
self.pending_count.fetch_add(1, Ordering::SeqCst);
self.messages_sent.fetch_add(1, Ordering::SeqCst);
debug!(
"Message queued successfully (pending: {})",
self.pending_count.load(Ordering::SeqCst)
);
Ok(())
}
pub fn get_agent_node(&self, agent_id: uuid::Uuid) -> Option<uuid::Uuid> {
self.agent_locations.read().unwrap().get(&agent_id).copied()
}
pub fn get_nodes(&self) -> Vec<NodeInfo> {
self.nodes.read().unwrap().values().cloned().collect()
}
pub fn get_healthy_nodes(&self) -> Vec<NodeInfo> {
self.nodes
.read()
.unwrap()
.values()
.filter(|n| n.status == NodeStatus::Healthy)
.cloned()
.collect()
}
pub fn find_best_node(&self) -> Option<uuid::Uuid> {
let nodes = self.get_healthy_nodes();
if nodes.is_empty() {
warn!("No healthy nodes available for load balancing");
return None;
}
debug!("Finding best node among {} healthy nodes", nodes.len());
let best = nodes
.iter()
.min_by(|a, b| {
let load_a = a.cpu_usage + a.memory_usage;
let load_b = b.cpu_usage + b.memory_usage;
load_a
.partial_cmp(&load_b)
.unwrap_or(std::cmp::Ordering::Equal)
})
.map(|n| {
let load = n.cpu_usage + n.memory_usage;
debug!("Selected node {} with load {:.2}", n.name, load);
n.id
});
best
}
pub fn heartbeat(&self, node_id: uuid::Uuid, cpu_usage: f32, memory_usage: f32) -> Result<()> {
if let Some(node) = self.nodes.write().unwrap().get_mut(&node_id) {
let old_status = node.status;
node.cpu_usage = cpu_usage;
node.memory_usage = memory_usage;
node.last_heartbeat = chrono::Utc::now().timestamp();
node.status = if cpu_usage > 0.9 || memory_usage > 0.9 {
NodeStatus::Degraded
} else if cpu_usage > 0.95 || memory_usage > 0.95 {
NodeStatus::Unhealthy
} else {
NodeStatus::Healthy
};
if old_status != node.status {
info!(
"Node {} status changed: {:?} -> {:?}",
node.name, old_status, node.status
);
}
debug!(
"Node {} heartbeat: CPU {:.1}%, Memory {:.1}%",
node.name,
cpu_usage * 100.0,
memory_usage * 100.0
);
} else {
warn!("Received heartbeat from unknown node {}", node_id);
}
Ok(())
}
pub fn check_node_health(&self, timeout_seconds: i64) -> Vec<uuid::Uuid> {
let now = chrono::Utc::now().timestamp();
let mut dead_nodes = Vec::new();
for (node_id, node) in self.nodes.read().unwrap().iter() {
if now - node.last_heartbeat > timeout_seconds {
warn!(
"Node {} hasn't sent heartbeat for {} seconds",
node.name,
now - node.last_heartbeat
);
dead_nodes.push(*node_id);
}
}
dead_nodes
}
pub fn try_recv_message(&self) -> Option<DistributedMessage> {
let mut rx = self.message_rx.write().unwrap();
match rx.try_recv() {
Ok(msg) => {
self.pending_count.fetch_sub(1, Ordering::SeqCst);
self.messages_received.fetch_add(1, Ordering::SeqCst);
Some(msg)
}
Err(_) => None,
}
}
pub async fn receive_messages<F>(&self, mut handler: F) -> Result<()>
where
F: FnMut(
DistributedMessage,
)
-> std::pin::Pin<Box<dyn std::future::Future<Output = Result<()>> + Send>>
+ Send,
{
loop {
let message = {
let mut rx = self.message_rx.write().unwrap();
match rx.try_recv() {
Ok(msg) => Some(msg),
Err(mpsc::error::TryRecvError::Empty) => None,
Err(mpsc::error::TryRecvError::Disconnected) => {
warn!("Message channel disconnected");
return Err(ZoeyError::other("Message channel disconnected"));
}
}
};
if let Some(msg) = message {
debug!("Received message {} from node {}", msg.id, msg.from_node);
handler(msg).await?;
} else {
tokio::time::sleep(Duration::from_millis(10)).await;
}
}
}
pub async fn process_pending_messages<F>(&self, handler: F) -> Result<usize>
where
F: Fn(&DistributedMessage) -> Result<()>,
{
let mut processed = 0;
loop {
let message = self.try_recv_message();
match message {
Some(msg) => {
debug!("Processing message {} type={}", msg.id, msg.message_type);
if msg.to_node != self.node_id {
warn!(
"Received message for wrong node: expected {}, got {}",
self.node_id, msg.to_node
);
continue;
}
match handler(&msg) {
Ok(_) => {
processed += 1;
debug!("Successfully processed message {}", msg.id);
}
Err(e) => {
warn!("Failed to process message {}: {}", msg.id, e);
}
}
}
None => {
break;
}
}
}
if processed > 0 {
info!("Processed {} distributed message(s)", processed);
}
Ok(processed)
}
pub fn pending_message_count(&self) -> usize {
self.pending_count.load(Ordering::SeqCst)
}
pub fn total_messages_sent(&self) -> usize {
self.messages_sent.load(Ordering::SeqCst)
}
pub fn total_messages_received(&self) -> usize {
self.messages_received.load(Ordering::SeqCst)
}
pub fn get_message_stats(&self) -> MessageStats {
MessageStats {
pending: self.pending_count.load(Ordering::SeqCst),
sent: self.messages_sent.load(Ordering::SeqCst),
received: self.messages_received.load(Ordering::SeqCst),
}
}
pub fn reset_message_stats(&self) {
self.messages_sent.store(0, Ordering::SeqCst);
self.messages_received.store(0, Ordering::SeqCst);
info!("Message statistics reset for node {}", self.node_id);
}
}
/// Point-in-time snapshot of a runtime's message counters.
///
/// Derives `PartialEq`/`Eq`/`Default` so snapshots can be compared directly and a
/// zeroed value is available without spelling out every field.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct MessageStats {
    /// Messages queued but not yet received.
    pub pending: usize,
    /// Messages queued since creation or the last stats reset.
    pub sent: usize,
    /// Messages received since creation or the last stats reset.
    pub received: usize,
}
/// Cluster-wide tuning knobs.
///
/// NOTE(review): nothing in this file reads these values; the descriptions below are
/// inferred from the field names — confirm against the consumers.
#[derive(Debug, Clone)]
pub struct ClusterConfig {
    /// Presumably how often nodes are expected to heartbeat (default: 5 s).
    pub heartbeat_interval: Duration,
    /// Presumably how long a silent node is tolerated before being considered dead (default: 30 s).
    pub node_timeout: Duration,
    /// Presumably whether agents are rebalanced automatically (default: `true`).
    pub auto_rebalance: bool,
    /// Presumably how many copies of each agent are kept (default: 1).
    pub replication_factor: usize,
}

impl Default for ClusterConfig {
    /// 5 s heartbeat interval, 30 s node timeout, auto-rebalance on, replication factor 1.
    fn default() -> Self {
        let heartbeat_interval = Duration::from_secs(5);
        let node_timeout = Duration::from_secs(30);
        ClusterConfig {
            heartbeat_interval,
            node_timeout,
            auto_rebalance: true,
            replication_factor: 1,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `Healthy` node with a fresh id and a current heartbeat timestamp.
    fn make_node(
        name: &str,
        address: &str,
        agents: Vec<uuid::Uuid>,
        cpu_usage: f32,
        memory_usage: f32,
    ) -> NodeInfo {
        NodeInfo {
            id: uuid::Uuid::new_v4(),
            name: name.to_string(),
            address: address.to_string(),
            status: NodeStatus::Healthy,
            agents,
            cpu_usage,
            memory_usage,
            last_heartbeat: chrono::Utc::now().timestamp(),
        }
    }

    #[test]
    fn test_distributed_runtime() {
        let runtime = DistributedRuntime::new(uuid::Uuid::new_v4());
        assert_eq!(runtime.get_nodes().len(), 0);
        assert_eq!(runtime.pending_message_count(), 0);
        assert!(runtime.find_best_node().is_none());
        assert!(runtime.try_recv_message().is_none());
    }

    #[test]
    fn test_node_registration() {
        let runtime = DistributedRuntime::new(uuid::Uuid::new_v4());
        let node = make_node("node1", "127.0.0.1:8080", vec![], 0.5, 0.6);
        runtime.register_node(node.clone()).unwrap();
        assert_eq!(runtime.get_nodes().len(), 1);
        assert_eq!(runtime.get_healthy_nodes().len(), 1);
    }

    #[test]
    fn test_node_unregistration_removes_agent_mappings() {
        let runtime = DistributedRuntime::new(uuid::Uuid::new_v4());
        let agent = uuid::Uuid::new_v4();
        let node = make_node("node1", "127.0.0.1:8080", vec![agent], 0.5, 0.5);
        let node_id = node.id;
        runtime.register_node(node).unwrap();
        assert_eq!(runtime.get_agent_node(agent), Some(node_id));

        runtime.unregister_node(node_id).unwrap();
        assert_eq!(runtime.get_agent_node(agent), None);
        assert!(runtime.get_nodes().is_empty());

        // Unregistering an unknown node is a no-op.
        assert!(runtime.unregister_node(uuid::Uuid::new_v4()).is_ok());
    }

    #[test]
    fn test_heartbeat_updates_usage_and_status() {
        let runtime = DistributedRuntime::new(uuid::Uuid::new_v4());
        let node = make_node("node1", "127.0.0.1:8080", vec![], 0.1, 0.1);
        let node_id = node.id;
        runtime.register_node(node).unwrap();

        runtime.heartbeat(node_id, 0.92, 0.5).unwrap();
        let updated = runtime.get_nodes().pop().unwrap();
        assert_eq!(updated.status, NodeStatus::Degraded);
        assert_eq!(updated.cpu_usage, 0.92);
        assert!(runtime.get_healthy_nodes().is_empty());

        runtime.heartbeat(node_id, 0.2, 0.3).unwrap();
        assert_eq!(runtime.get_nodes()[0].status, NodeStatus::Healthy);

        // Heartbeats from unknown nodes are logged and ignored.
        assert!(runtime.heartbeat(uuid::Uuid::new_v4(), 0.5, 0.5).is_ok());
        assert_eq!(runtime.get_nodes().len(), 1);
    }

    #[test]
    fn test_load_balancing() {
        let runtime = DistributedRuntime::new(uuid::Uuid::new_v4());
        let node1 = make_node("node1", "127.0.0.1:8080", vec![], 0.8, 0.7);
        let node2 = make_node("node2", "127.0.0.1:8081", vec![], 0.3, 0.4);
        runtime.register_node(node1).unwrap();
        runtime.register_node(node2.clone()).unwrap();
        let best = runtime.find_best_node().unwrap();
        assert_eq!(best, node2.id);
    }

    #[tokio::test]
    async fn test_cross_node_messaging() {
        let runtime = DistributedRuntime::new(uuid::Uuid::new_v4());
        let agent1 = uuid::Uuid::new_v4();
        let agent2 = uuid::Uuid::new_v4();
        let node1 = make_node("node1", "127.0.0.1:8080", vec![agent1], 0.5, 0.5);
        let node2 = make_node("node2", "127.0.0.1:8081", vec![agent2], 0.5, 0.5);
        let node2_id = node2.id;
        runtime.register_node(node1).unwrap();
        runtime.register_node(node2).unwrap();

        let result = runtime
            .send_to_agent(
                agent1,
                agent2,
                serde_json::json!({"message": "Hello from another node!"}),
                "greeting".to_string(),
            )
            .await;
        assert!(result.is_ok());
        assert_eq!(runtime.pending_message_count(), 1);
        assert_eq!(runtime.total_messages_sent(), 1);

        let msg = runtime.try_recv_message().expect("message should be queued");
        assert_eq!(msg.to_node, node2_id);
        assert_eq!(msg.from_agent, agent1);
        assert_eq!(msg.to_agent, agent2);
        assert_eq!(msg.message_type, "greeting");
        assert_eq!(runtime.pending_message_count(), 0);
        assert_eq!(runtime.total_messages_received(), 1);

        // Unknown recipients are rejected without touching the counters.
        let missing = runtime
            .send_to_agent(
                agent1,
                uuid::Uuid::new_v4(),
                serde_json::json!({}),
                "greeting".to_string(),
            )
            .await;
        assert!(missing.is_err());
        assert_eq!(runtime.total_messages_sent(), 1);
        assert_eq!(runtime.pending_message_count(), 0);
    }

    #[tokio::test]
    async fn test_process_pending_messages_handles_local_messages() {
        let agent = uuid::Uuid::new_v4();
        let local = make_node("local", "127.0.0.1:9000", vec![agent], 0.2, 0.2);
        let runtime = DistributedRuntime::new(local.id);
        runtime.register_node(local).unwrap();

        runtime
            .send_to_agent(
                uuid::Uuid::new_v4(),
                agent,
                serde_json::json!({"n": 1}),
                "ping".to_string(),
            )
            .await
            .unwrap();

        let processed = runtime
            .process_pending_messages(|msg| {
                assert_eq!(msg.message_type, "ping");
                Ok(())
            })
            .await
            .unwrap();
        assert_eq!(processed, 1);
        assert_eq!(runtime.pending_message_count(), 0);
        assert_eq!(runtime.get_message_stats().received, 1);
    }
}