use crate::bft::{BftConsensus, BftMessage};
use crate::network::{NetworkService, RpcMessage};
use crate::{ClusterError, Result};
use ed25519_dalek::{Signature, Signer, SigningKey, Verifier, VerifyingKey};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::{mpsc, RwLock};
use tokio::time::{interval, Duration};
use tracing::{debug, error, info, warn};
/// A [`BftMessage`] together with the metadata used to authenticate it:
/// the sender's node id, a sender-assigned sequence number, the creation
/// time and an Ed25519 signature.
///
/// The signature is computed over the `serde_json` encoding of this struct
/// with `signature` set to an empty vector. That encoding follows field
/// declaration order, so renaming or reordering fields changes the signed
/// bytes and breaks verification against nodes built from the old definition.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuthenticatedMessage {
/// The consensus message being carried.
pub message: BftMessage,
/// Node id of the signer; receivers use it to look up the verifying key.
pub sender: String,
/// Value of the sender's outgoing counter. Receivers drop any message whose
/// sequence is not greater than the highest already accepted from the same
/// sender.
pub sequence: u64,
/// Ed25519 signature bytes; receivers reject anything that is not 64 bytes.
pub signature: Vec<u8>,
/// Creation time in Unix seconds; receivers reject messages more than
/// 300 seconds old.
pub timestamp: u64,
}
/// Authenticated transport between this node's [`BftConsensus`] instance and
/// its peers.
///
/// Outgoing messages are wrapped in signed [`AuthenticatedMessage`]s and sent
/// through [`NetworkService`]. Incoming ones are queued, checked for age,
/// ordering and signature, and then passed to the consensus engine.
pub struct BftNetworkService {
/// Identifier this node signs outgoing messages as.
node_id: String,
/// Consensus engine that receives verified messages.
consensus: Arc<BftConsensus>,
/// Transport used for broadcast and point-to-point sends.
network: Arc<NetworkService>,
/// Last sequence number assigned to an outgoing message (starts at 0).
sequence_counter: Arc<RwLock<u64>>,
/// Per-sender high-water marks and recently accepted messages.
message_cache: Arc<RwLock<MessageCache>>,
/// Verifying keys of trusted peers, keyed by node id.
peer_keys: Arc<RwLock<HashMap<String, VerifyingKey>>>,
/// Producer side of the inbound queue (filled by `handle_network_message`).
tx: mpsc::Sender<AuthenticatedMessage>,
/// Consumer side of the inbound queue; the processor task holds its lock
/// for as long as it runs.
rx: Arc<RwLock<mpsc::Receiver<AuthenticatedMessage>>>,
/// This node's Ed25519 signing key.
keypair: SigningKey,
}
/// Duplicate / replay filter for inbound messages.
///
/// `highest_seq` records, per sender, the largest sequence number accepted so
/// far. `messages` keeps the accepted messages themselves, bounded by
/// `max_size`, with the oldest (by timestamp) evicted first. Eviction never
/// touches `highest_seq`, so duplicate detection keeps working for messages
/// that have been evicted.
struct MessageCache {
    /// Accepted messages keyed by `(sender, sequence)`.
    messages: HashMap<(String, u64), AuthenticatedMessage>,
    /// Highest accepted sequence number per sender.
    highest_seq: HashMap<String, u64>,
    /// Maximum number of entries kept in `messages`.
    max_size: usize,
}

impl MessageCache {
    /// Creates an empty cache that retains at most `max_size` messages.
    fn new(max_size: usize) -> Self {
        MessageCache {
            messages: HashMap::new(),
            highest_seq: HashMap::new(),
            max_size,
        }
    }

    /// Returns `true` if `sequence` is not above the highest sequence number
    /// already accepted from `sender`. Senders never seen before are never
    /// duplicates.
    ///
    /// Ordering is strictly monotonic per sender. A message that was never
    /// seen, but arrives after a higher-numbered one from the same sender,
    /// is still reported as old.
    fn is_duplicate_or_old(&self, sender: &str, sequence: u64) -> bool {
        matches!(self.highest_seq.get(sender), Some(&highest) if sequence <= highest)
    }

    /// Records an accepted message and raises its sender's high-water mark.
    /// If the cache is now over capacity, the oldest entries are evicted.
    fn add_message(&mut self, msg: AuthenticatedMessage) {
        let highest = self
            .highest_seq
            .entry(msg.sender.clone())
            .or_insert(msg.sequence);
        *highest = (*highest).max(msg.sequence);
        // Move the message into the map; only the sender id is cloned for the key.
        self.messages.insert((msg.sender.clone(), msg.sequence), msg);
        if self.messages.len() > self.max_size {
            self.evict_oldest();
        }
    }

    /// Removes the entries with the smallest timestamps until `messages` is
    /// back within `max_size`. Does nothing if it already is.
    ///
    /// `add_message` calls this after each insert that crosses the limit, so
    /// normally a single entry is removed. One linear scan per removed entry
    /// avoids cloning every key and sorting the whole cache on each insert.
    /// Ties between equal timestamps are broken by `HashMap` iteration order.
    fn evict_oldest(&mut self) {
        let excess = self.messages.len().saturating_sub(self.max_size);
        for _ in 0..excess {
            let oldest = self
                .messages
                .iter()
                .min_by_key(|(_, msg)| msg.timestamp)
                .map(|(key, _)| key.clone());
            match oldest {
                Some(key) => {
                    self.messages.remove(&key);
                }
                None => break,
            }
        }
    }
}
impl BftNetworkService {
pub fn new(
node_id: String,
consensus: Arc<BftConsensus>,
network: Arc<NetworkService>,
) -> Self {
let (tx, rx) = mpsc::channel(1000);
let seed_bytes: [u8; 32] = rand::random();
let keypair = SigningKey::from_bytes(&seed_bytes);
BftNetworkService {
node_id,
consensus,
network,
sequence_counter: Arc::new(RwLock::new(0)),
message_cache: Arc::new(RwLock::new(MessageCache::new(10000))),
peer_keys: Arc::new(RwLock::new(HashMap::new())),
tx,
rx: Arc::new(RwLock::new(rx)),
keypair,
}
}
pub fn with_keypair(
node_id: String,
consensus: Arc<BftConsensus>,
network: Arc<NetworkService>,
keypair: SigningKey,
) -> Self {
let (tx, rx) = mpsc::channel(1000);
BftNetworkService {
node_id,
consensus,
network,
sequence_counter: Arc::new(RwLock::new(0)),
message_cache: Arc::new(RwLock::new(MessageCache::new(10000))),
peer_keys: Arc::new(RwLock::new(HashMap::new())),
tx,
rx: Arc::new(RwLock::new(rx)),
keypair,
}
}
pub fn public_key(&self) -> VerifyingKey {
self.keypair.verifying_key()
}
pub fn public_key_bytes(&self) -> [u8; 32] {
self.keypair.verifying_key().to_bytes()
}
pub async fn register_peer(&self, peer_id: String, public_key: VerifyingKey) -> Result<()> {
let mut keys = self.peer_keys.write().await;
keys.insert(peer_id.clone(), public_key);
self.consensus.register_node(peer_id, public_key)?;
Ok(())
}
pub async fn start(self: Arc<Self>) -> Result<()> {
let processor = self.clone();
tokio::spawn(async move {
processor.process_messages().await;
});
let heartbeat = self.clone();
tokio::spawn(async move {
heartbeat.send_heartbeats().await;
});
let monitor = self.clone();
tokio::spawn(async move {
monitor.monitor_view_changes().await;
});
Ok(())
}
async fn process_messages(self: Arc<Self>) {
let mut rx = self.rx.write().await;
while let Some(auth_msg) = rx.recv().await {
match self.handle_authenticated_message(auth_msg).await {
Ok(_) => {}
Err(e) => error!("Failed to handle message: {}", e),
}
}
}
async fn handle_authenticated_message(&self, auth_msg: AuthenticatedMessage) -> Result<()> {
let current_time = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.expect("SystemTime should be after UNIX_EPOCH")
.as_secs();
if current_time > auth_msg.timestamp + 300 {
return Err(ClusterError::Network("Message too old".to_string()));
}
let mut cache = self.message_cache.write().await;
if cache.is_duplicate_or_old(&auth_msg.sender, auth_msg.sequence) {
debug!("Duplicate or old message from {}", auth_msg.sender);
return Ok(());
}
if !self.verify_message_signature(&auth_msg).await? {
warn!("Invalid signature from {}", auth_msg.sender);
return Err(ClusterError::Network("Invalid signature".to_string()));
}
cache.add_message(auth_msg.clone());
drop(cache);
self.consensus
.handle_message(auth_msg.message, &auth_msg.sender)?;
Ok(())
}
pub async fn broadcast(&self, message: BftMessage) -> Result<()> {
let auth_msg = self.create_authenticated_message(message).await?;
let data = serde_json::to_vec(&auth_msg)
.map_err(|e| ClusterError::Network(format!("Serialization error: {e}")))?;
self.network.broadcast(RpcMessage::Bft { data }).await?;
Ok(())
}
pub async fn send_to(&self, peer_id: &str, message: BftMessage) -> Result<()> {
let auth_msg = self.create_authenticated_message(message).await?;
let data = serde_json::to_vec(&auth_msg)
.map_err(|e| ClusterError::Network(format!("Serialization error: {e}")))?;
self.network
.send_to(peer_id, RpcMessage::Bft { data })
.await?;
Ok(())
}
async fn create_authenticated_message(
&self,
message: BftMessage,
) -> Result<AuthenticatedMessage> {
let mut seq = self.sequence_counter.write().await;
*seq += 1;
let sequence = *seq;
let timestamp = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.expect("SystemTime should be after UNIX_EPOCH")
.as_secs();
let mut auth_msg = AuthenticatedMessage {
message,
sender: self.node_id.clone(),
sequence,
signature: vec![],
timestamp,
};
let msg_bytes = serde_json::to_vec(&auth_msg)
.map_err(|e| ClusterError::Network(format!("Serialization error: {e}")))?;
let signature = self.keypair.sign(&msg_bytes);
auth_msg.signature = signature.to_bytes().to_vec();
Ok(auth_msg)
}
async fn verify_message_signature(&self, auth_msg: &AuthenticatedMessage) -> Result<bool> {
let peer_keys = self.peer_keys.read().await;
let public_key = match peer_keys.get(&auth_msg.sender) {
Some(key) => key,
None => {
warn!("No public key found for peer: {}", auth_msg.sender);
return Ok(false);
}
};
let mut msg_for_verification = auth_msg.clone();
msg_for_verification.signature = vec![];
let msg_bytes = match serde_json::to_vec(&msg_for_verification) {
Ok(bytes) => bytes,
Err(e) => {
error!("Failed to serialize message for verification: {}", e);
return Ok(false);
}
};
if auth_msg.signature.len() != 64 {
warn!(
"Invalid signature length from {}: expected 64, got {}",
auth_msg.sender,
auth_msg.signature.len()
);
return Ok(false);
}
let mut signature_bytes = [0u8; 64];
signature_bytes.copy_from_slice(&auth_msg.signature);
let signature = ed25519_dalek::Signature::from_bytes(&signature_bytes);
match public_key.verify(&msg_bytes, &signature) {
Ok(_) => {
debug!("Signature verification successful for {}", auth_msg.sender);
Ok(true)
}
Err(e) => {
warn!(
"Signature verification failed for {}: {}",
auth_msg.sender, e
);
Ok(false)
}
}
}
async fn send_heartbeats(&self) {
let mut interval = interval(Duration::from_secs(1));
loop {
interval.tick().await;
match self.consensus.is_primary() {
Ok(true) => {
let heartbeat = BftMessage::Request {
client_id: format!("{}-heartbeat", self.node_id),
operation: b"HEARTBEAT".to_vec(),
timestamp: std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.expect("SystemTime should be after UNIX_EPOCH")
.as_secs(),
signature: None,
};
if let Err(e) = self.broadcast(heartbeat).await {
warn!("Failed to send heartbeat: {}", e);
}
}
_ => {}
}
}
}
async fn monitor_view_changes(&self) {
let mut interval = interval(Duration::from_secs(5));
loop {
interval.tick().await;
match self.consensus.check_view_timeout() {
Ok(true) => {
info!("View change timeout detected");
if let Err(e) = self.initiate_view_change().await {
error!("Failed to initiate view change: {}", e);
}
}
_ => {}
}
}
}
async fn initiate_view_change(&self) -> Result<()> {
let current_view = self.consensus.current_view()?;
let new_view = current_view + 1;
info!("Initiating view change to view {}", new_view);
let prepared_messages = self.consensus.collect_prepared_messages()?;
info!(
"Collected {} prepared messages for view change",
prepared_messages.len()
);
let view_change = BftMessage::ViewChange {
new_view,
node_id: self.node_id.clone(),
prepared_messages,
signature: vec![],
};
self.broadcast(view_change).await?;
Ok(())
}
pub async fn handle_network_message(&self, data: Vec<u8>) -> Result<()> {
let auth_msg: AuthenticatedMessage = serde_json::from_slice(&data)
.map_err(|e| ClusterError::Network(format!("Deserialization error: {e}")))?;
self.tx
.send(auth_msg)
.await
.map_err(|e| ClusterError::Network(format!("Channel send error: {e}")))?;
Ok(())
}
pub async fn remove_peer(&self, peer_id: &str) -> Result<()> {
let mut keys = self.peer_keys.write().await;
keys.remove(peer_id);
info!("Removed public key for peer: {}", peer_id);
Ok(())
}
pub async fn get_trusted_peers(&self) -> Vec<String> {
let keys = self.peer_keys.read().await;
keys.keys().cloned().collect()
}
pub fn verify_signature(
&self,
message: &[u8],
signature: &[u8],
public_key: &VerifyingKey,
) -> Result<bool> {
if signature.len() != 64 {
return Ok(false);
}
let mut signature_bytes = [0u8; 64];
signature_bytes.copy_from_slice(signature);
let signature = Signature::from_bytes(&signature_bytes);
match public_key.verify(message, &signature) {
Ok(_) => Ok(true),
Err(_) => Ok(false),
}
}
pub fn sign_message(&self, message: &[u8]) -> Vec<u8> {
let signature = self.keypair.sign(message);
signature.to_bytes().to_vec()
}
pub async fn is_peer_trusted(&self, peer_id: &str) -> bool {
let keys = self.peer_keys.read().await;
keys.contains_key(peer_id)
}
}
/// Counters for BFT network activity.
///
/// NOTE(review): nothing in this module's visible code updates these
/// counters; confirm where (and whether) they are populated. The field
/// descriptions below are inferred from the field names.
#[derive(Debug, Clone, Default)]
pub struct BftMetrics {
/// Messages sent.
pub messages_sent: u64,
/// Messages received.
pub messages_received: u64,
/// Messages rejected for an invalid signature.
pub invalid_signatures: u64,
/// Nodes flagged as Byzantine.
pub byzantine_nodes_detected: u64,
/// View changes performed.
pub view_changes: u64,
/// Consensus rounds completed.
pub consensus_rounds: u64,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an unsigned `Request` message from `sender`. The same
    /// `timestamp` is used for the inner request and the outer envelope.
    fn request_from(
        sender: &str,
        client_id: &str,
        operation: Vec<u8>,
        sequence: u64,
        timestamp: u64,
    ) -> AuthenticatedMessage {
        AuthenticatedMessage {
            message: BftMessage::Request {
                client_id: client_id.to_string(),
                operation,
                timestamp,
                signature: None,
            },
            sender: sender.to_string(),
            sequence,
            signature: vec![],
            timestamp,
        }
    }

    #[test]
    fn test_message_cache() {
        let mut cache = MessageCache::new(100);
        assert!(!cache.is_duplicate_or_old("node1", 1));
        cache.add_message(request_from("node1", "test", vec![1, 2, 3], 1, 1000));
        assert!(cache.is_duplicate_or_old("node1", 1));
        assert!(!cache.is_duplicate_or_old("node1", 2));
    }

    #[test]
    fn test_cache_eviction() {
        let mut cache = MessageCache::new(2);
        for n in 0u64..3 {
            let sender = format!("node{}", n);
            let client = format!("test{}", n);
            cache.add_message(request_from(&sender, &client, vec![n as u8], 1, n));
        }
        assert_eq!(cache.messages.len(), 2);
        assert!(!cache.messages.contains_key(&("node0".to_string(), 1)));
    }
}