use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::time::{SystemTime, UNIX_EPOCH};
use tokio::sync::{broadcast, mpsc};
use async_trait::async_trait;
use crate::{Result, OdinError};
use crate::message::{OdinMessage, MessageType, MessagePriority};
use crate::config::OdinConfig;
use crate::metrics::MetricsCollector;
/// One ODIN protocol node: its configuration, its open connections, and the
/// broadcast channel its outgoing messages are published on.
#[derive(Debug)]
pub struct OdinProtocol {
/// Configuration, checked with `OdinConfig::validate` in `OdinProtocol::new`.
config: OdinConfig,
/// This node's identifier, copied from `config.node_id`.
node_id: String,
/// Open connections keyed by name (`initialize_network` uses `"default"`).
connections: HashMap<String, Connection>,
/// Sending half of the outgoing-message broadcast channel (capacity 1024).
message_tx: broadcast::Sender<OdinMessage>,
/// Receiver held by the node itself. It is never read, but while `self`
/// exists it guarantees `message_tx` has at least one live receiver.
message_rx: broadcast::Receiver<OdinMessage>,
/// Sender for `ControlMessage` commands.
/// NOTE(review): check where the matching receiver lives before relying on it.
control_tx: mpsc::Sender<ControlMessage>,
/// Counters and timings surfaced through `get_metrics` and `get_status`.
metrics: MetricsCollector,
/// Set by `start` and cleared by `stop`; gates message sending.
is_running: bool,
}
impl OdinProtocol {
pub fn new(config: OdinConfig) -> Result<Self> {
config.validate()?;
let (message_tx, message_rx) = broadcast::channel(1024);
let (control_tx, _control_rx) = mpsc::channel(256);
let metrics = MetricsCollector::new();
Ok(Self {
node_id: config.node_id.clone(),
config,
connections: HashMap::new(),
message_tx,
message_rx,
control_tx,
metrics,
is_running: false,
})
}
pub async fn start(&mut self) -> Result<()> {
if self.is_running {
return Err(OdinError::Protocol("Protocol already running".to_string()));
}
self.is_running = true;
self.metrics.record_startup();
self.initialize_network().await?;
self.start_heartbeat().await?;
Ok(())
}
pub async fn stop(&mut self) -> Result<()> {
if !self.is_running {
return Ok(());
}
self.is_running = false;
for (_, mut connection) in self.connections.drain() {
connection.close().await?;
}
self.metrics.record_shutdown();
Ok(())
}
pub async fn send_message(&self, target_node: &str, content: &str, priority: MessagePriority) -> Result<String> {
if !self.is_running {
return Err(OdinError::Protocol("Protocol not running".to_string()));
}
let message = OdinMessage::new(
MessageType::Standard,
&self.node_id,
target_node,
content,
priority,
);
self.metrics.record_message_sent();
self.message_tx.send(message.clone()).map_err(|e| {
OdinError::Network(format!("Failed to send message: {}", e))
})?;
Ok(message.id.clone())
}
pub async fn broadcast_message(&self, content: &str, priority: MessagePriority) -> Result<String> {
if !self.is_running {
return Err(OdinError::Protocol("Protocol not running".to_string()));
}
let message = OdinMessage::new(
MessageType::Broadcast,
&self.node_id,
"all",
content,
priority,
);
self.metrics.record_broadcast_sent();
self.message_tx.send(message.clone()).map_err(|e| {
OdinError::Network(format!("Failed to broadcast message: {}", e))
})?;
Ok(message.id.clone())
}
pub fn subscribe_to_messages(&self) -> broadcast::Receiver<OdinMessage> {
self.message_tx.subscribe()
}
pub fn get_metrics(&self) -> HashMap<String, f64> {
self.metrics.get_metrics()
}
pub fn get_status(&self) -> NodeStatus {
NodeStatus {
node_id: self.node_id.clone(),
is_running: self.is_running,
connection_count: self.connections.len(),
uptime: self.metrics.get_uptime(),
messages_sent: self.metrics.get_messages_sent(),
messages_received: self.metrics.get_messages_received(),
}
}
async fn initialize_network(&mut self) -> Result<()> {
let connection = Connection::new(&self.config.network_endpoint).await?;
self.connections.insert("default".to_string(), connection);
Ok(())
}
async fn start_heartbeat(&self) -> Result<()> {
let interval = self.config.heartbeat_interval;
let tx = self.message_tx.clone();
let node_id = self.node_id.clone();
tokio::spawn(async move {
let mut interval_timer = tokio::time::interval(interval);
loop {
interval_timer.tick().await;
let heartbeat = OdinMessage::new(
MessageType::Heartbeat,
&node_id,
"all",
"heartbeat",
MessagePriority::Low,
);
if tx.send(heartbeat).is_err() {
break; }
}
});
Ok(())
}
}
/// Record of one network connection held in `OdinProtocol`'s connection table.
#[derive(Debug)]
pub struct Connection {
/// Endpoint string the connection was created with.
endpoint: String,
/// `true` after `Connection::new`, `false` after `Connection::close`.
is_connected: bool,
}
impl Connection {
    /// Creates a connection record for `endpoint`, marked as connected.
    ///
    /// NOTE(review): no network I/O happens here; the endpoint is only stored.
    ///
    /// # Errors
    ///
    /// Never fails at present. The `Result` leaves room for a real handshake.
    pub async fn new(endpoint: &str) -> Result<Self> {
        Ok(Self {
            endpoint: endpoint.to_string(),
            is_connected: true,
        })
    }

    /// Returns the endpoint this connection was created for.
    pub fn endpoint(&self) -> &str {
        &self.endpoint
    }

    /// Returns `true` until `close` has been called.
    pub fn is_connected(&self) -> bool {
        self.is_connected
    }

    /// Marks the connection as closed. Calling it more than once is harmless.
    ///
    /// # Errors
    ///
    /// Never fails at present.
    pub async fn close(&mut self) -> Result<()> {
        self.is_connected = false;
        Ok(())
    }
}
/// Commands carried over a node's control channel (`mpsc::Sender<ControlMessage>`).
#[derive(Debug, Clone)]
pub enum ControlMessage {
/// Connection request; the payload is presumably an endpoint or node id — confirm.
Connect(String),
/// Disconnection request; the payload is presumably an endpoint or node id — confirm.
Disconnect(String),
/// Shutdown request.
Shutdown,
}
/// Point-in-time snapshot of a node, as built by `OdinProtocol::get_status`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeStatus {
/// Identifier of the node this snapshot describes.
pub node_id: String,
/// Whether the node was marked running when the snapshot was taken.
pub is_running: bool,
/// Number of entries in the node's connection table.
pub connection_count: usize,
/// Uptime as reported by the node's `MetricsCollector`.
/// NOTE(review): the unit (seconds?) is not visible here — confirm.
pub uptime: f64,
/// Messages sent, as counted by the metrics collector.
pub messages_sent: u64,
/// Messages received, as counted by the metrics collector.
pub messages_received: u64,
}
/// Hook for components that react to protocol messages and connection events.
///
/// Implementors must be `Send + Sync`; the async methods are provided through
/// `async_trait`.
#[async_trait]
pub trait ProtocolHandler: Send + Sync {
/// Handles one message.
/// NOTE(review): `Some(msg)` presumably means "send this reply" and `None`
/// means no reply — confirm against the caller.
async fn handle_message(&self, message: &OdinMessage) -> Result<Option<OdinMessage>>;
/// Reacts to a connection lifecycle event.
async fn handle_connection_event(&self, event: ConnectionEvent) -> Result<()>;
}
/// Connection lifecycle notification passed to `ProtocolHandler::handle_connection_event`.
#[derive(Debug, Clone)]
pub enum ConnectionEvent {
/// A connection was established; the payload presumably identifies it — confirm.
Connected(String),
/// A connection was closed; the payload presumably identifies it — confirm.
Disconnected(String),
/// A connection error.
/// NOTE(review): the two strings are presumably (connection id, error
/// description) — confirm with the producer.
Error(String, String),
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::OdinConfig;
    use tokio::time::{timeout, Duration};

    /// Builds a stopped protocol from the default configuration.
    fn new_protocol() -> OdinProtocol {
        OdinProtocol::new(OdinConfig::default()).expect("default config must be valid")
    }

    #[tokio::test]
    async fn test_protocol_creation() {
        let protocol = new_protocol();
        assert!(!protocol.is_running);
        assert_eq!(protocol.connections.len(), 0);
    }

    #[tokio::test]
    async fn test_message_sending() {
        let mut protocol = new_protocol();
        protocol.start().await.unwrap();
        let message_id = protocol
            .send_message("target-node", "test message", MessagePriority::Normal)
            .await
            .unwrap();
        assert!(!message_id.is_empty());
        protocol.stop().await.unwrap();
    }

    #[tokio::test]
    async fn test_message_subscription() {
        let mut protocol = new_protocol();
        protocol.start().await.unwrap();
        let mut rx = protocol.subscribe_to_messages();
        protocol
            .send_message("target-node", "test message", MessagePriority::Normal)
            .await
            .unwrap();
        let message = timeout(Duration::from_millis(100), rx.recv())
            .await
            .expect("message should arrive within the timeout")
            .expect("channel should still be open");
        assert_eq!(message.content, "test message");
        assert_eq!(message.target_node, "target-node");
        protocol.stop().await.unwrap();
    }

    #[tokio::test]
    async fn test_protocol_status() {
        let protocol = new_protocol();
        let status = protocol.get_status();
        assert!(!status.is_running);
        assert_eq!(status.connection_count, 0);
        assert_eq!(status.messages_sent, 0);
    }

    #[tokio::test]
    async fn test_sending_requires_running() {
        let protocol = new_protocol();
        assert!(protocol
            .send_message("target-node", "x", MessagePriority::Normal)
            .await
            .is_err());
        assert!(protocol
            .broadcast_message("x", MessagePriority::Normal)
            .await
            .is_err());
    }

    #[tokio::test]
    async fn test_double_start_is_rejected() {
        let mut protocol = new_protocol();
        protocol.start().await.unwrap();
        assert!(protocol.start().await.is_err());
        protocol.stop().await.unwrap();
    }

    #[tokio::test]
    async fn test_stop_is_idempotent_and_restart_works() {
        let mut protocol = new_protocol();
        // Stopping a node that was never started is a no-op.
        protocol.stop().await.unwrap();

        protocol.start().await.unwrap();
        assert_eq!(protocol.get_status().connection_count, 1);
        protocol.stop().await.unwrap();
        assert!(!protocol.is_running);
        assert_eq!(protocol.connections.len(), 0);
        protocol.stop().await.unwrap();

        protocol.start().await.unwrap();
        assert!(protocol.is_running);
        protocol.stop().await.unwrap();
    }
}