use crate::types::ProtocolError;
use qudag_crypto::ml_kem::{MlKem768, KemError};
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, VecDeque};
use std::net::SocketAddr;
use std::time::{Duration, Instant};
use thiserror::Error;
use tracing::{debug, error, info, warn};
/// Errors produced while selecting routes and building or processing onion headers.
#[derive(Debug, Error)]
pub enum RoutingError {
/// No path from the local node to the requested destination could be computed.
#[error("Route not found")]
RouteNotFound,
/// An onion header had no layers left, or a layer carried a next-hop id that is
/// neither empty nor exactly 32 bytes long.
#[error("Invalid route")]
InvalidRoute,
/// A routing loop was detected. Not raised anywhere in the code visible here.
#[error("Routing loop detected")]
RoutingLoop,
/// Setting up the KEM or encapsulating to a hop's public key failed.
#[error("Encryption failed")]
EncryptionFailed,
/// Decapsulating an onion layer with the supplied secret key failed.
#[error("Decryption failed")]
DecryptionFailed,
/// The selected path has more hops than the router allows.
#[error("Hop limit exceeded")]
HopLimitExceeded,
/// A hop of the path is missing from the peer table.
#[error("Peer not found")]
PeerNotFound,
/// A transport-level failure, described by the contained message. Not raised
/// anywhere in the code visible here.
#[error("Network error: {0}")]
NetworkError(String),
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PeerInfo {
pub peer_id: [u8; 32],
pub address: SocketAddr,
pub public_key: Vec<u8>,
pub latency: Duration,
pub bandwidth: u64,
pub reliability: f64,
pub last_seen: Instant,
}
impl PeerInfo {
    /// Creates a peer record with default metrics: 100 ms latency, a bandwidth of
    /// 1_000_000, full reliability, and `last_seen` set to the current instant.
    pub fn new(peer_id: [u8; 32], address: SocketAddr, public_key: Vec<u8>) -> Self {
        Self {
            peer_id,
            address,
            public_key,
            latency: Duration::from_millis(100),
            bandwidth: 1_000_000,
            reliability: 1.0,
            last_seen: Instant::now(),
        }
    }

    /// Records a fresh measurement for this peer.
    ///
    /// Latency and bandwidth are overwritten, `last_seen` is refreshed, and
    /// `reliability` is moved towards 1.0 (on `success`) or 0.0 (on failure) with an
    /// exponential moving average that gives the new sample a weight of 0.1.
    pub fn update_metrics(&mut self, latency: Duration, bandwidth: u64, success: bool) {
        // Weight of the newest sample in the moving average.
        const ALPHA: f64 = 0.1;

        self.latency = latency;
        self.bandwidth = bandwidth;
        self.last_seen = Instant::now();

        let sample = if success { 1.0 } else { 0.0 };
        self.reliability = ALPHA * sample + (1.0 - ALPHA) * self.reliability;
    }

    /// Returns `true` while less than `timeout` has elapsed since `last_seen`.
    pub fn is_alive(&self, timeout: Duration) -> bool {
        self.last_seen.elapsed() < timeout
    }

    /// Scores the peer for routing purposes; higher is better.
    ///
    /// The score is `reliability * (0.5 * latency_score + 0.3 * bandwidth_score + 0.2)`
    /// where `latency_score` is the inverse of the latency in seconds (offset by 1 ms
    /// so zero latency does not divide by zero) and `bandwidth_score` is
    /// `log10(bandwidth) / 10`.
    pub fn routing_score(&self) -> f64 {
        let latency_score = 1.0 / (self.latency.as_secs_f64() + 0.001);
        // log10(0) is -inf, which would turn the score into -inf (or NaN once
        // multiplied by a reliability of 0). Treat "no bandwidth" as the smallest
        // representable bandwidth so the score always stays finite.
        let bandwidth_score = (self.bandwidth.max(1) as f64).log10() / 10.0;
        self.reliability * (0.5 * latency_score + 0.3 * bandwidth_score + 0.2)
    }
}
/// One layer of an onion header, intended for a single hop of the path.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OnionLayer {
/// Identifier of the hop to forward to: either empty (no further hop) or exactly
/// 32 bytes. Any other length is rejected when the layer is processed.
pub next_hop: Vec<u8>,
/// Bytes handed to KEM decapsulation when the layer is processed.
pub payload: Vec<u8>,
/// Authentication tag. The header builder in this file always leaves it empty.
pub auth_tag: Vec<u8>,
}
/// Header carrying the per-hop onion layers of one routed message.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OnionHeader {
/// Remaining layers; processing a header consumes the layer at index 0.
pub layers: Vec<OnionLayer>,
/// Number of hops in the path the header was built for.
pub hop_count: u8,
/// 16-byte identifier of the message this header belongs to.
pub message_id: [u8; 16],
}
/// A candidate route through the network together with its estimated quality.
#[derive(Debug, Clone)]
pub struct RoutingPath {
/// Peer ids along the route; the path finder emits them source first,
/// destination last.
pub hops: Vec<[u8; 32]>,
/// Combined quality figure filled in by `calculate_quality`; 0.0 until then.
pub quality_score: f64,
/// Sum of the latencies of the hops found in the peer table.
pub estimated_latency: Duration,
/// Smallest bandwidth among the hops found in the peer table.
pub estimated_bandwidth: u64,
/// Moment the path object was created; used to expire old paths.
pub created_at: Instant,
}
impl RoutingPath {
    /// Wraps `hops` in a path with zeroed estimates and `created_at` set to now.
    pub fn new(hops: Vec<[u8; 32]>) -> Self {
        Self {
            hops,
            quality_score: 0.0,
            estimated_latency: Duration::from_millis(0),
            estimated_bandwidth: 0,
            created_at: Instant::now(),
        }
    }

    /// Recomputes the latency, bandwidth and quality estimates from `peer_table`.
    ///
    /// Hops that are missing from the table are skipped. Latency is the sum over the
    /// known hops, bandwidth the minimum, and the quality score is
    /// `reliability_product * (0.5 * latency_score + 0.5 * bandwidth_score)`.
    ///
    /// When none of the hops is known (including an empty path) there is nothing to
    /// measure, so all three estimates are reset to zero instead of reporting the
    /// `u64::MAX` bandwidth placeholder as if it were a real measurement.
    pub fn calculate_quality(&mut self, peer_table: &PeerTable) {
        let mut total_latency = Duration::from_millis(0);
        let mut min_bandwidth: Option<u64> = None;
        let mut reliability_product = 1.0;

        for peer in self.hops.iter().filter_map(|id| peer_table.get_peer(id)) {
            total_latency += peer.latency;
            min_bandwidth = Some(match min_bandwidth {
                Some(lowest) => lowest.min(peer.bandwidth),
                None => peer.bandwidth,
            });
            reliability_product *= peer.reliability;
        }

        match min_bandwidth {
            Some(bandwidth) => {
                self.estimated_latency = total_latency;
                self.estimated_bandwidth = bandwidth;

                let latency_score = 1.0 / (total_latency.as_secs_f64() + 0.001);
                // Clamp before the logarithm: log10(0) is -inf and would poison the
                // score of any path containing a zero-bandwidth hop.
                let bandwidth_score = (bandwidth.max(1) as f64).log10() / 10.0;
                self.quality_score =
                    reliability_product * (0.5 * latency_score + 0.5 * bandwidth_score);
            }
            None => {
                self.estimated_latency = Duration::from_millis(0);
                self.estimated_bandwidth = 0;
                self.quality_score = 0.0;
            }
        }
    }

    /// Returns `true` when the path is at most `max_age` old and every hop is present
    /// in `peer_table` and has been seen within the last 60 seconds.
    pub fn is_valid(&self, peer_table: &PeerTable, max_age: Duration) -> bool {
        if self.created_at.elapsed() > max_age {
            return false;
        }
        self.hops.iter().all(|peer_id| {
            peer_table
                .get_peer(peer_id)
                .map_or(false, |peer| peer.is_alive(Duration::from_secs(60)))
        })
    }
}
/// Known peers plus the connectivity graph between them.
#[derive(Debug)]
pub struct PeerTable {
// Peer records keyed by peer id.
peers: HashMap<[u8; 32], PeerInfo>,
// Adjacency lists keyed by peer id; `add_connection` records each link in both
// directions.
connections: HashMap<[u8; 32], Vec<[u8; 32]>>,
}
impl PeerTable {
    /// Creates an empty table.
    pub fn new() -> Self {
        Self {
            peers: HashMap::new(),
            connections: HashMap::new(),
        }
    }

    /// Inserts or replaces `peer` and makes sure it has an adjacency list.
    pub fn add_peer(&mut self, peer: PeerInfo) {
        let peer_id = peer.peer_id;
        self.peers.insert(peer_id, peer);
        self.connections.entry(peer_id).or_default();
    }

    /// Removes the peer, its adjacency list, and every link pointing at it.
    pub fn remove_peer(&mut self, peer_id: &[u8; 32]) {
        self.peers.remove(peer_id);
        self.connections.remove(peer_id);
        for links in self.connections.values_mut() {
            links.retain(|id| id != peer_id);
        }
    }

    /// Looks up a peer record by id.
    pub fn get_peer(&self, peer_id: &[u8; 32]) -> Option<&PeerInfo> {
        self.peers.get(peer_id)
    }

    /// Returns references to all peer records, in unspecified order.
    pub fn get_all_peers(&self) -> Vec<&PeerInfo> {
        self.peers.values().collect()
    }

    /// Records an undirected link between `peer1` and `peer2`.
    ///
    /// Neither endpoint has to be a registered peer. Links that already exist and
    /// links from a node to itself are ignored so adjacency lists stay duplicate-free.
    pub fn add_connection(&mut self, peer1: [u8; 32], peer2: [u8; 32]) {
        if peer1 == peer2 {
            return;
        }
        self.link(peer1, peer2);
        self.link(peer2, peer1);
    }

    /// Adds the directed edge `from -> to` unless it is already present.
    fn link(&mut self, from: [u8; 32], to: [u8; 32]) {
        let links = self.connections.entry(from).or_default();
        if !links.contains(&to) {
            links.push(to);
        }
    }

    /// Returns a copy of the adjacency list of `peer_id` (empty if unknown).
    pub fn get_connections(&self, peer_id: &[u8; 32]) -> Vec<[u8; 32]> {
        self.connections.get(peer_id).cloned().unwrap_or_default()
    }

    /// Finds the cheapest path from `source` to `destination` with Dijkstra's
    /// algorithm, where stepping onto a peer costs `1 / routing_score`.
    ///
    /// `source` only needs connections, not a peer record, so the local node can be
    /// used as a starting point without registering itself. Every other hop,
    /// including `destination`, must be a registered peer. Peers whose routing score
    /// is not a finite positive number are treated as unusable.
    ///
    /// Returns `None` when no path exists. When `source == destination` the result is
    /// the single-hop path `[source]` with untouched (zero) quality estimates.
    pub fn find_shortest_path(
        &self,
        source: &[u8; 32],
        destination: &[u8; 32],
    ) -> Option<RoutingPath> {
        if source == destination {
            return Some(RoutingPath::new(vec![*source]));
        }

        // Only nodes that have been reached get an entry; a missing entry means
        // "infinitely far away".
        let mut distances: HashMap<[u8; 32], f64> = HashMap::new();
        let mut previous: HashMap<[u8; 32], [u8; 32]> = HashMap::new();
        let mut unvisited: Vec<[u8; 32]> = self.peers.keys().copied().collect();
        if !self.peers.contains_key(source) {
            // Seed the search even when the source is not a registered peer.
            unvisited.push(*source);
        }
        distances.insert(*source, 0.0);

        loop {
            // Closest unvisited node that has been reached so far.
            let nearest = unvisited
                .iter()
                .enumerate()
                .filter_map(|(idx, id)| distances.get(id).map(|dist| (idx, *dist)))
                .min_by(|a, b| a.1.total_cmp(&b.1));
            let (idx, current_distance) = match nearest {
                Some(found) => found,
                // Everything still unvisited is unreachable.
                None => break,
            };
            let current = unvisited.swap_remove(idx);
            if current == *destination {
                break;
            }

            let neighbors = match self.connections.get(&current) {
                Some(neighbors) => neighbors,
                None => continue,
            };
            for neighbor in neighbors {
                let neighbor_peer = match self.peers.get(neighbor) {
                    Some(peer) => peer,
                    None => continue,
                };
                let score = neighbor_peer.routing_score();
                if !(score.is_finite() && score > 0.0) {
                    // A non-positive or non-finite score would yield an infinite or
                    // negative edge weight, which Dijkstra cannot handle.
                    continue;
                }
                let alt_distance = current_distance + 1.0 / score;
                let known = distances.get(neighbor).copied().unwrap_or(f64::INFINITY);
                if alt_distance < known {
                    distances.insert(*neighbor, alt_distance);
                    previous.insert(*neighbor, current);
                }
            }
        }

        // Walk the predecessor chain back from the destination; a missing link means
        // the destination was never reached.
        let mut hops = vec![*destination];
        let mut cursor = *destination;
        while cursor != *source {
            cursor = *previous.get(&cursor)?;
            hops.push(cursor);
        }
        hops.reverse();

        let mut routing_path = RoutingPath::new(hops);
        routing_path.calculate_quality(self);
        Some(routing_path)
    }

    /// Finds up to `count` paths from `source` to `destination` that share no
    /// intermediate hop.
    ///
    /// After each path is found its intermediate hops are removed from a working copy
    /// of the table and the search is repeated. The search stops early when no
    /// further path exists, or when a path has no intermediate hops: with nothing to
    /// exclude, searching again would only return the same path.
    pub fn find_multiple_paths(
        &self,
        source: &[u8; 32],
        destination: &[u8; 32],
        count: usize,
    ) -> Vec<RoutingPath> {
        let mut paths = Vec::new();
        if count == 0 {
            return paths;
        }

        // One working copy that shrinks as hops get excluded.
        let mut remaining = self.clone();
        while paths.len() < count {
            let path = match remaining.find_shortest_path(source, destination) {
                Some(path) => path,
                None => break,
            };
            let intermediates: Vec<[u8; 32]> = path
                .hops
                .iter()
                .skip(1)
                .take(path.hops.len().saturating_sub(2))
                .copied()
                .collect();
            paths.push(path);

            if intermediates.is_empty() {
                break;
            }
            for node in &intermediates {
                remaining.remove_peer(node);
            }
        }
        paths
    }

    /// Removes every peer that has not been seen within `timeout`, logging each
    /// removal at warn level.
    pub fn cleanup_dead_peers(&mut self, timeout: Duration) {
        let dead_peers: Vec<[u8; 32]> = self
            .peers
            .iter()
            .filter(|(_, peer)| !peer.is_alive(timeout))
            .map(|(id, _)| *id)
            .collect();
        for peer_id in dead_peers {
            warn!("Removing dead peer: {:?}", hex::encode(peer_id));
            self.remove_peer(&peer_id);
        }
    }
}
impl Clone for PeerTable {
fn clone(&self) -> Self {
Self {
peers: self.peers.clone(),
connections: self.connections.clone(),
}
}
}
impl Default for PeerTable {
fn default() -> Self {
Self::new()
}
}
/// Routes messages from the local node through the peer network.
pub struct MessageRouter {
// Identifier of the node this router runs on; used as the source of every route.
local_peer_id: [u8; 32],
// Known peers and the links between them.
peer_table: PeerTable,
// Previously computed paths, keyed by destination peer id.
path_cache: HashMap<[u8; 32], Vec<RoutingPath>>,
// Per-peer key material. The code visible here only ever removes entries.
onion_keys: HashMap<[u8; 32], Vec<u8>>,
// Bookkeeping for messages sent by this router, keyed by message id.
message_tracking: HashMap<[u8; 16], RouteTrackingInfo>,
// KEM instance used to encapsulate to, and decapsulate for, onion layers.
ml_kem: MlKem768,
}
/// Bookkeeping record for one message handed to the router.
#[derive(Debug)]
struct RouteTrackingInfo {
// Peer id of the node that originated the message (the local node).
source: [u8; 32],
// Peer id the message is addressed to.
destination: [u8; 32],
// When the record was created; old records are dropped during cleanup.
created_at: Instant,
// Number of hops in the path chosen for the message.
hop_count: u8,
}
impl MessageRouter {
    /// Largest number of hops a path may have before `route_message` rejects it.
    const MAX_HOPS: usize = 10;
    /// Age after which a cached path is no longer trusted.
    const PATH_MAX_AGE: Duration = Duration::from_secs(300);
    /// Silence period after which `cleanup` drops a peer.
    const PEER_TIMEOUT: Duration = Duration::from_secs(300);
    /// Age after which `cleanup` drops a message tracking record.
    const TRACKING_MAX_AGE: Duration = Duration::from_secs(3600);

    /// Creates a router for the node identified by `local_peer_id`.
    ///
    /// # Errors
    /// Returns [`RoutingError::EncryptionFailed`] when the KEM cannot be initialised.
    pub fn new(local_peer_id: [u8; 32]) -> Result<Self, RoutingError> {
        let ml_kem = MlKem768::new().map_err(|_| RoutingError::EncryptionFailed)?;
        Ok(Self {
            local_peer_id,
            peer_table: PeerTable::new(),
            path_cache: HashMap::new(),
            onion_keys: HashMap::new(),
            message_tracking: HashMap::new(),
            ml_kem,
        })
    }

    /// Registers `peer` and discards all cached paths, since a new peer may open up
    /// better routes.
    pub fn add_peer(&mut self, peer: PeerInfo) {
        self.peer_table.add_peer(peer);
        self.path_cache.clear();
    }

    /// Forgets a peer, its key material, the paths cached for it as a destination,
    /// and every cached path that travels through it.
    pub fn remove_peer(&mut self, peer_id: &[u8; 32]) {
        self.peer_table.remove_peer(peer_id);
        self.onion_keys.remove(peer_id);
        self.path_cache.remove(peer_id);
        self.path_cache.retain(|_, paths| {
            paths.retain(|path| !path.hops.contains(peer_id));
            !paths.is_empty()
        });
    }

    /// Returns a route from the local node to `destination`.
    ///
    /// A still-valid cached path is reused when available. Otherwise stale cache
    /// entries for the destination are dropped, a fresh shortest path is computed and
    /// cached, and that path is returned.
    ///
    /// # Errors
    /// Returns [`RoutingError::RouteNotFound`] when no path exists.
    pub fn find_route(&mut self, destination: &[u8; 32]) -> Result<RoutingPath, RoutingError> {
        let peer_table = &self.peer_table;
        if let Some(cached) = self.path_cache.get_mut(destination) {
            // Prune stale paths here so the per-destination list cannot grow without
            // bound across repeated lookups.
            cached.retain(|path| path.is_valid(peer_table, Self::PATH_MAX_AGE));
            if let Some(path) = cached.first() {
                return Ok(path.clone());
            }
        }
        // Whatever is cached for this destination is now an empty list.
        self.path_cache.remove(destination);

        let path = self
            .peer_table
            .find_shortest_path(&self.local_peer_id, destination)
            .ok_or(RoutingError::RouteNotFound)?;
        self.path_cache.insert(*destination, vec![path.clone()]);
        Ok(path)
    }

    /// Builds the onion header for sending `payload` along `path`.
    ///
    /// One layer is produced per hop, in path order. Each layer names the hop that
    /// follows it (empty for the last hop). The last hop's layer carries `payload`;
    /// every earlier layer carries the KEM ciphertext produced for the hop after it.
    ///
    /// NOTE(review): the shared secret returned by encapsulation is discarded, the
    /// payload is stored unencrypted in the final layer, and `auth_tag` is left
    /// empty. This looks like placeholder cryptography; confirm before relying on
    /// it for confidentiality. Every hop, including `hops[0]`, must be present in
    /// the peer table.
    ///
    /// # Errors
    /// * [`RoutingError::HopLimitExceeded`] when the hop count does not fit in `u8`.
    /// * [`RoutingError::PeerNotFound`] when a hop is missing from the peer table.
    /// * [`RoutingError::EncryptionFailed`] when encapsulation to a hop fails.
    pub fn create_onion_header(
        &self,
        path: &RoutingPath,
        payload: &[u8],
        message_id: [u8; 16],
    ) -> Result<OnionHeader, RoutingError> {
        // Reject instead of silently truncating the count with `as u8`.
        if path.hops.len() > usize::from(u8::MAX) {
            return Err(RoutingError::HopLimitExceeded);
        }
        let hop_count = path.hops.len() as u8;

        let mut layers = Vec::with_capacity(path.hops.len());
        let mut current_payload = payload.to_vec();
        // Build from the destination backwards so each layer can wrap the next one.
        for (index, peer_id) in path.hops.iter().enumerate().rev() {
            let peer = self
                .peer_table
                .get_peer(peer_id)
                .ok_or(RoutingError::PeerNotFound)?;
            let (ciphertext, _shared_secret) = self
                .ml_kem
                .encapsulate(&peer.public_key)
                .map_err(|_| RoutingError::EncryptionFailed)?;
            // The hop after this one in path order; none for the destination.
            let next_hop = path
                .hops
                .get(index + 1)
                .map(|hop| hop.to_vec())
                .unwrap_or_default();
            layers.push(OnionLayer {
                next_hop,
                payload: current_payload,
                auth_tag: Vec::new(),
            });
            current_payload = ciphertext;
        }
        layers.reverse();

        Ok(OnionHeader {
            layers,
            hop_count,
            message_id,
        })
    }

    /// Consumes the outermost layer of `header`.
    ///
    /// Returns the id of the next hop (`None` when the layer names none) together
    /// with the decapsulated payload. The layer is removed from `header` only after
    /// it has been processed successfully, so a failed call leaves `header` intact.
    ///
    /// # Errors
    /// * [`RoutingError::InvalidRoute`] when no layer is left or the next-hop id is
    ///   neither empty nor 32 bytes long.
    /// * [`RoutingError::DecryptionFailed`] when decapsulation fails.
    pub fn process_onion_layer(
        &self,
        header: &mut OnionHeader,
        secret_key: &[u8],
    ) -> Result<(Option<[u8; 32]>, Vec<u8>), RoutingError> {
        let layer = header.layers.first().ok_or(RoutingError::InvalidRoute)?;
        let decrypted_payload = self
            .ml_kem
            .decapsulate(secret_key, &layer.payload)
            .map_err(|_| RoutingError::DecryptionFailed)?;
        let next_hop = match layer.next_hop.len() {
            0 => None,
            32 => {
                let mut hop = [0u8; 32];
                hop.copy_from_slice(&layer.next_hop);
                Some(hop)
            }
            _ => return Err(RoutingError::InvalidRoute),
        };
        header.layers.remove(0);
        Ok((next_hop, decrypted_payload))
    }

    /// Selects a route to `destination`, builds the onion header for `payload` and
    /// records the message for tracking.
    ///
    /// NOTE(review): the header is built but not transmitted anywhere in this
    /// function; only log lines are emitted. Presumably the send is still to be
    /// implemented.
    ///
    /// # Errors
    /// Propagates the errors of `find_route` and `create_onion_header`, and returns
    /// [`RoutingError::HopLimitExceeded`] for paths longer than 10 hops.
    pub async fn route_message(
        &mut self,
        destination: &[u8; 32],
        payload: &[u8],
    ) -> Result<(), RoutingError> {
        let path = self.find_route(destination)?;
        if path.hops.len() > Self::MAX_HOPS {
            return Err(RoutingError::HopLimitExceeded);
        }

        let message_id: [u8; 16] = rand::random();
        // Build the header before recording the message so a failure here does not
        // leave an orphaned tracking entry behind.
        let _onion_header = self.create_onion_header(&path, payload, message_id)?;
        self.message_tracking.insert(
            message_id,
            RouteTrackingInfo {
                source: self.local_peer_id,
                destination: *destination,
                created_at: Instant::now(),
                // Cannot truncate: the length was checked against MAX_HOPS above.
                hop_count: path.hops.len() as u8,
            },
        );

        if let Some(first_hop) = path.hops.get(1) {
            info!(
                "Routing message {} to {:?} via path of {} hops",
                hex::encode(message_id),
                hex::encode(destination),
                path.hops.len()
            );
            debug!("Sending onion header to first hop: {:?}", hex::encode(first_hop));
        }
        Ok(())
    }

    /// Feeds a new measurement into the peer's metrics; unknown peers are ignored.
    ///
    /// The whole path cache is discarded when the measurement reports a failure or
    /// the peer's reliability has dropped below 0.8.
    pub fn update_peer_metrics(
        &mut self,
        peer_id: &[u8; 32],
        latency: Duration,
        bandwidth: u64,
        success: bool,
    ) {
        if let Some(peer) = self.peer_table.peers.get_mut(peer_id) {
            peer.update_metrics(latency, bandwidth, success);
            if !success || peer.reliability < 0.8 {
                self.path_cache.clear();
            }
        }
    }

    /// Drops peers not seen for 300 seconds, tracking records older than an hour,
    /// and cached paths that are no longer valid.
    pub fn cleanup(&mut self) {
        self.peer_table.cleanup_dead_peers(Self::PEER_TIMEOUT);

        let now = Instant::now();
        self.message_tracking
            .retain(|_, info| now.duration_since(info.created_at) < Self::TRACKING_MAX_AGE);

        // Borrow the table separately so the closure does not capture all of `self`
        // while `path_cache` is mutably borrowed (an error before edition 2021).
        let peer_table = &self.peer_table;
        self.path_cache.retain(|_, paths| {
            paths.retain(|path| path.is_valid(peer_table, Self::PATH_MAX_AGE));
            !paths.is_empty()
        });
    }

    /// Returns a snapshot of the router's bookkeeping.
    ///
    /// `cached_paths` is the number of destinations with cached paths.
    /// `average_path_length` is the mean hop count over all cached paths, or 0.0
    /// when nothing is cached.
    pub fn get_stats(&self) -> RoutingStats {
        let (path_count, total_hops) = self
            .path_cache
            .values()
            .flatten()
            .fold((0usize, 0usize), |(count, hops), path| {
                (count + 1, hops + path.hops.len())
            });
        let average_path_length = if path_count == 0 {
            0.0
        } else {
            total_hops as f64 / path_count as f64
        };

        RoutingStats {
            peer_count: self.peer_table.peers.len(),
            cached_paths: self.path_cache.len(),
            tracked_messages: self.message_tracking.len(),
            average_path_length,
        }
    }
}
/// Snapshot of router bookkeeping returned by `MessageRouter::get_stats`.
#[derive(Debug, Clone)]
pub struct RoutingStats {
/// Number of peers currently in the peer table.
pub peer_count: usize,
/// Number of entries in the path cache (one entry per destination).
pub cached_paths: usize,
/// Number of messages currently being tracked.
pub tracked_messages: usize,
/// Hop-count average over the cached paths, as computed by `get_stats`.
pub average_path_length: f64,
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::SocketAddr;

    /// Builds a peer whose id is zero except for its first byte, which is `id`, and
    /// whose public key is 32 copies of `id`.
    fn create_test_peer(id: u8, addr: &str) -> PeerInfo {
        let mut raw_id = [0u8; 32];
        raw_id[0] = id;
        let socket = addr.parse::<SocketAddr>().unwrap();
        let key_bytes = vec![id; 32];
        PeerInfo::new(raw_id, socket, key_bytes)
    }

    /// A new peer carries its id, full reliability, and counts as alive.
    #[test]
    fn test_peer_info_creation() {
        let fresh = create_test_peer(1, "127.0.0.1:8000");
        assert_eq!(fresh.peer_id[0], 1);
        assert_eq!(fresh.reliability, 1.0);
        assert!(fresh.is_alive(Duration::from_secs(1)));
    }

    /// Peers can be added and looked up, and connections are visible afterwards.
    #[test]
    fn test_peer_table_operations() {
        let first = create_test_peer(1, "127.0.0.1:8001");
        let second = create_test_peer(2, "127.0.0.1:8002");

        let mut registry = PeerTable::new();
        registry.add_peer(first.clone());
        registry.add_peer(second.clone());
        assert_eq!(registry.get_all_peers().len(), 2);
        assert!(registry.get_peer(&first.peer_id).is_some());

        registry.add_connection(first.peer_id, second.peer_id);
        let linked = registry.get_connections(&first.peer_id);
        assert!(linked.contains(&second.peer_id));
    }

    /// A new path keeps its hops and starts with a zero quality score.
    #[test]
    fn test_routing_path_creation() {
        let expected_hops = vec![[1; 32], [2; 32], [3; 32]];
        let route = RoutingPath::new(expected_hops.clone());
        assert_eq!(route.hops, expected_hops);
        assert_eq!(route.quality_score, 0.0);
    }

    /// Constructing a router succeeds.
    #[tokio::test]
    async fn test_message_router_creation() {
        let node_id = [1; 32];
        let outcome = MessageRouter::new(node_id);
        assert!(outcome.is_ok());
    }

    /// In a chain 1 - 2 - 3 the shortest path from 1 to 3 goes through 2.
    #[test]
    fn test_shortest_path_finding() {
        let first = create_test_peer(1, "127.0.0.1:8001");
        let middle = create_test_peer(2, "127.0.0.1:8002");
        let last = create_test_peer(3, "127.0.0.1:8003");

        let mut registry = PeerTable::new();
        registry.add_peer(first.clone());
        registry.add_peer(middle.clone());
        registry.add_peer(last.clone());
        registry.add_connection(first.peer_id, middle.peer_id);
        registry.add_connection(middle.peer_id, last.peer_id);

        let found = registry.find_shortest_path(&first.peer_id, &last.peer_id);
        assert!(found.is_some());
        let found = found.unwrap();
        assert_eq!(found.hops.len(), 3);
        assert_eq!(found.hops[0], first.peer_id);
        assert_eq!(found.hops[1], middle.peer_id);
        assert_eq!(found.hops[2], last.peer_id);
    }

    /// The router finds a three-hop route from the local node through one relay.
    #[tokio::test]
    async fn test_route_finding() {
        let node_id = [1; 32];
        let mut router = MessageRouter::new(node_id).unwrap();

        let relay = create_test_peer(2, "127.0.0.1:8002");
        let target = create_test_peer(3, "127.0.0.1:8003");
        router.add_peer(relay.clone());
        router.add_peer(target.clone());
        router.peer_table.add_connection(node_id, relay.peer_id);
        router.peer_table.add_connection(relay.peer_id, target.peer_id);

        let outcome = router.find_route(&target.peer_id);
        assert!(outcome.is_ok());
        let found = outcome.unwrap();
        assert_eq!(found.hops.len(), 3);
    }
}