use crate::raft::{OxirsNodeId, RdfCommand, RdfResponse};
use crate::tls::{TlsConfig, TlsManager};
use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use std::time::SystemTime;
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::RwLock;
use tokio::time::timeout;
/// Messages exchanged between cluster nodes through the RPC layer.
///
/// The enum derives `Serialize`/`Deserialize`, so every variant can be encoded
/// with any serde format.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum RpcMessage {
/// Vote request carrying the candidate's id, its term and the index/term of
/// its last log entry.
RequestVote {
term: u64,
candidate_id: OxirsNodeId,
last_log_index: u64,
last_log_term: u64,
},
/// Reply to `RequestVote`: the responder's term and whether the vote was granted.
VoteResponse { term: u64, vote_granted: bool },
/// Log replication request: the leader's id and term, the index/term of the
/// entry preceding `entries`, the entries themselves and the leader's commit index.
AppendEntries {
term: u64,
leader_id: OxirsNodeId,
prev_log_index: u64,
prev_log_term: u64,
entries: Vec<LogEntry>,
leader_commit: u64,
},
/// Reply to `AppendEntries`.
AppendEntriesResponse {
term: u64,
success: bool,
last_log_index: u64,
},
/// A client-issued command.
ClientRequest { command: RdfCommand },
/// Result of a client command.
ClientResponse { response: RdfResponse },
/// Heartbeat carrying the leader's id and term.
Heartbeat { term: u64, leader_id: OxirsNodeId },
/// Reply to `Heartbeat`.
HeartbeatResponse { term: u64 },
/// Opaque byte payload; this variant only exists when the `bft` feature is enabled.
#[cfg(feature = "bft")]
Bft { data: Vec<u8> },
/// Wrapper around a shard-manager operation.
ShardOperation(crate::shard_manager::ShardOperation),
/// Carries one triple together with the shard it targets.
StoreTriple {
shard_id: crate::shard::ShardId,
triple: oxirs_core::model::Triple,
},
/// Same payload shape as `StoreTriple`; presumably used for replica copies — confirm with callers.
ReplicateTriple {
shard_id: crate::shard::ShardId,
triple: oxirs_core::model::Triple,
},
/// Query against one shard with optional subject/predicate/object strings.
QueryShard {
shard_id: crate::shard::ShardId,
subject: Option<String>,
predicate: Option<String>,
object: Option<String>,
},
/// Triples returned for a `QueryShard`.
QueryShardResponse {
shard_id: crate::shard::ShardId,
results: Vec<oxirs_core::model::Triple>,
},
/// Transaction prepare message listing the operations for one shard.
TransactionPrepare {
tx_id: String,
shard_id: crate::shard::ShardId,
operations: Vec<crate::transaction::TransactionOp>,
},
/// A shard's boolean vote for a transaction.
TransactionVote {
tx_id: String,
shard_id: crate::shard::ShardId,
vote: bool,
},
/// Commit message for a transaction on one shard.
TransactionCommit {
tx_id: String,
shard_id: crate::shard::ShardId,
},
/// Abort message for a transaction on one shard.
TransactionAbort {
tx_id: String,
shard_id: crate::shard::ShardId,
},
/// Acknowledgement for a transaction on one shard.
TransactionAck {
tx_id: String,
shard_id: crate::shard::ShardId,
},
/// One batch belonging to the migration identified by `migration_id`.
MigrationBatch {
migration_id: String,
batch: crate::shard_migration::MigrationBatch,
},
/// A set of triples for a shard, tagged with the node they come from.
ShardTransfer {
shard_id: crate::shard::ShardId,
triples: Vec<oxirs_core::model::Triple>,
source_node: crate::raft::OxirsNodeId,
},
}
/// A single replicated log record: its position, its term and the command it carries.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LogEntry {
    /// Position of the entry in the log.
    pub index: u64,
    /// Term associated with the entry.
    pub term: u64,
    /// Command payload stored in the entry.
    pub command: RdfCommand,
}

impl LogEntry {
    /// Builds an entry from its index, term and command.
    pub fn new(index: u64, term: u64, command: RdfCommand) -> Self {
        LogEntry { index, term, command }
    }
}
/// Tunable settings for the network layer.
#[derive(Debug, Clone)]
pub struct NetworkConfig {
    /// Address the listener binds to.
    pub local_address: SocketAddr,
    /// Upper bound for establishing a TCP connection; also used as the
    /// staleness threshold for cached connections.
    pub connection_timeout: Duration,
    /// Upper bound for one request/response exchange.
    pub request_timeout: Duration,
    /// Configured connection limit.
    pub max_connections: usize,
    /// Configured keep-alive interval.
    pub keep_alive_interval: Duration,
    /// TLS settings; TLS is only set up when its `enabled` flag is true.
    pub tls_config: TlsConfig,
    /// Whether compression is requested.
    pub enable_compression: bool,
    /// Configured maximum message size (default is 16 MiB).
    pub max_message_size: usize,
}

impl Default for NetworkConfig {
    /// Loopback defaults: `127.0.0.1:8080`, 5 s connect timeout, 10 s request
    /// timeout, 100 connections, 30 s keep-alive, compression on, 16 MiB messages.
    fn default() -> Self {
        const MIB: usize = 1024 * 1024;

        NetworkConfig {
            // Built from parts so no string parsing (and no panic path) is involved.
            local_address: SocketAddr::from(([127, 0, 0, 1], 8080)),
            connection_timeout: Duration::from_secs(5),
            request_timeout: Duration::from_secs(10),
            max_connections: 100,
            keep_alive_interval: Duration::from_secs(30),
            tls_config: TlsConfig::default(),
            enable_compression: true,
            max_message_size: 16 * MIB,
        }
    }
}
/// Holds a node's networking state: configuration, the peer connection table,
/// the optional listener and TLS manager, and message counters.
#[derive(Debug)]
pub struct NetworkManager {
config: NetworkConfig,
node_id: OxirsNodeId,
// Peer table; shared through `Arc` with clones and the cleanup task.
connections: Arc<RwLock<HashMap<OxirsNodeId, Connection>>>,
// Populated by `start`; clones of the manager always carry `None`.
listener: Option<TcpListener>,
// Running flag shared with clones and read by the cleanup task's loop.
running: Arc<RwLock<bool>>,
// Present only when the manager was built by `with_tls` with TLS enabled.
tls_manager: Option<Arc<TlsManager>>,
message_stats: Arc<RwLock<MessageStats>>,
}
/// Plain counters describing network activity; all start at zero via `Default`.
#[derive(Debug, Clone, Default)]
pub struct MessageStats {
/// Number of messages sent.
pub messages_sent: u64,
/// Number of messages received.
pub messages_received: u64,
/// Bytes sent (the visible code feeds this from size estimates).
pub bytes_sent: u64,
/// Bytes received (the visible code feeds this from size estimates).
pub bytes_received: u64,
/// Connections successfully established.
pub connections_established: u64,
/// Connection attempts that failed.
pub connections_failed: u64,
/// TLS handshakes that completed.
pub tls_handshakes_completed: u64,
/// TLS handshakes that failed.
pub tls_handshakes_failed: u64,
}
/// Bookkeeping record for a link to one peer.
#[derive(Debug, Clone)]
pub struct Connection {
    /// Identifier of the remote node.
    pub peer_id: OxirsNodeId,
    /// Address of the remote node.
    pub address: SocketAddr,
    /// Moment the record was created or last refreshed.
    pub last_activity: std::time::Instant,
    /// Whether the link is currently considered established.
    pub is_connected: bool,
}

impl Connection {
    /// Creates a record for `peer_id` at `address`, stamped with the current
    /// time and marked as not connected.
    pub fn new(peer_id: OxirsNodeId, address: SocketAddr) -> Self {
        let created_at = std::time::Instant::now();
        Connection {
            peer_id,
            address,
            last_activity: created_at,
            is_connected: false,
        }
    }

    /// Returns `true` when strictly more than `timeout` has elapsed since the
    /// last recorded activity.
    pub fn is_stale(&self, timeout: Duration) -> bool {
        let idle_for = self.last_activity.elapsed();
        idle_for > timeout
    }

    /// Stamps the record with the current time.
    pub fn update_activity(&mut self) {
        self.last_activity = std::time::Instant::now();
    }
}
impl NetworkManager {
/// Creates a stopped manager without TLS support, with an empty connection
/// table and zeroed statistics.
pub fn new(node_id: OxirsNodeId, config: NetworkConfig) -> Self {
    let connections = Arc::new(RwLock::new(HashMap::new()));
    let running = Arc::new(RwLock::new(false));
    let message_stats = Arc::new(RwLock::new(MessageStats::default()));

    Self {
        config,
        node_id,
        connections,
        listener: None,
        running,
        tls_manager: None,
        message_stats,
    }
}
pub async fn with_tls(node_id: OxirsNodeId, config: NetworkConfig) -> Result<Self> {
let tls_manager = if config.tls_config.enabled {
let tls_mgr = TlsManager::new(config.tls_config.clone(), node_id);
tls_mgr.initialize().await?;
Some(Arc::new(tls_mgr))
} else {
None
};
Ok(Self {
config,
node_id,
connections: Arc::new(RwLock::new(HashMap::new())),
listener: None,
running: Arc::new(RwLock::new(false)),
tls_manager,
message_stats: Arc::new(RwLock::new(MessageStats::default())),
})
}
/// Binds the listener on `config.local_address`, marks the manager as
/// running and spawns the background cleanup task.
///
/// Calling `start` on an already running manager is a no-op.
///
/// # Errors
/// Returns the bind error if the address cannot be bound. In that case the
/// manager is left in the stopped state so that `start` can be retried.
pub async fn start(&mut self) -> Result<()> {
    {
        let mut running = self.running.write().await;
        if *running {
            return Ok(());
        }

        // Bind BEFORE flipping the flag: if binding fails the `?` returns with
        // `running` still false, so a later `start` is not silently skipped.
        let listener = TcpListener::bind(self.config.local_address).await?;
        tracing::info!(
            "Network manager for node {} listening on {}",
            self.node_id,
            self.config.local_address
        );
        self.listener = Some(listener);
        *running = true;
    }

    // The flag must already be true here: the cleanup loop exits as soon as
    // it observes `running == false`.
    self.start_background_tasks().await;
    Ok(())
}
/// Marks the manager as stopped, releases this instance's listener and
/// clears the connection table.
///
/// Calling `stop` on a manager that is not running is a no-op. Only the
/// listener owned by this instance is released; clones never hold one.
pub async fn stop(&mut self) -> Result<()> {
    let mut running = self.running.write().await;
    if !*running {
        return Ok(());
    }
    tracing::info!("Stopping network manager for node {}", self.node_id);
    *running = false;

    // Drop the listener so the bound port is freed; otherwise a subsequent
    // `start` on the same address would fail because it is still in use.
    self.listener = None;

    let mut connections = self.connections.write().await;
    connections.clear();
    Ok(())
}
/// Sends `message` to the given peer and waits for its reply, bounded by
/// `config.request_timeout`.
///
/// # Errors
/// Fails if the connection cannot be obtained, if the exchange itself fails,
/// or if the request timeout elapses first.
pub async fn send_rpc(
    &self,
    peer_id: OxirsNodeId,
    peer_address: SocketAddr,
    message: RpcMessage,
) -> Result<RpcMessage> {
    let link = self.get_or_create_connection(peer_id, peer_address).await?;
    let exchange = self.send_message_to_connection(link, message);
    // Outer layer: timeout expiry; inner layer: the exchange's own result.
    let outcome = timeout(self.config.request_timeout, exchange).await;
    outcome?
}
/// Sends a `RequestVote` on behalf of this node and returns the responder's
/// `(term, vote_granted)`.
///
/// # Errors
/// Fails if the RPC fails or the reply is not a `VoteResponse`.
pub async fn send_request_vote(
    &self,
    peer_id: OxirsNodeId,
    peer_address: SocketAddr,
    term: u64,
    last_log_index: u64,
    last_log_term: u64,
) -> Result<(u64, bool)> {
    let request = RpcMessage::RequestVote {
        term,
        candidate_id: self.node_id,
        last_log_index,
        last_log_term,
    };
    let reply = self.send_rpc(peer_id, peer_address, request).await?;

    if let RpcMessage::VoteResponse { term, vote_granted } = reply {
        Ok((term, vote_granted))
    } else {
        Err(anyhow::anyhow!("Unexpected response type"))
    }
}
/// Sends an `AppendEntries` with this node as leader and returns the
/// responder's `(term, success, last_log_index)`.
///
/// # Errors
/// Fails if the RPC fails or the reply is not an `AppendEntriesResponse`.
// The argument list mirrors the message fields one-to-one.
#[allow(clippy::too_many_arguments)]
pub async fn send_append_entries(
    &self,
    peer_id: OxirsNodeId,
    peer_address: SocketAddr,
    term: u64,
    prev_log_index: u64,
    prev_log_term: u64,
    entries: Vec<LogEntry>,
    leader_commit: u64,
) -> Result<(u64, bool, u64)> {
    let request = RpcMessage::AppendEntries {
        term,
        leader_id: self.node_id,
        prev_log_index,
        prev_log_term,
        entries,
        leader_commit,
    };
    let reply = self.send_rpc(peer_id, peer_address, request).await?;

    if let RpcMessage::AppendEntriesResponse {
        term,
        success,
        last_log_index,
    } = reply
    {
        Ok((term, success, last_log_index))
    } else {
        Err(anyhow::anyhow!("Unexpected response type"))
    }
}
/// Fires a heartbeat at every peer except this node itself.
///
/// Each send runs in its own spawned task using a clone of the manager; the
/// method does not wait for replies, and failures are only logged.
pub async fn send_heartbeat(&self, term: u64, peers: &[(OxirsNodeId, SocketAddr)]) {
    let heartbeat = RpcMessage::Heartbeat {
        term,
        leader_id: self.node_id,
    };

    let remote_peers = peers
        .iter()
        .copied()
        .filter(|(id, _)| *id != self.node_id);

    for (peer_id, peer_address) in remote_peers {
        let sender = self.clone();
        let payload = heartbeat.clone();
        tokio::spawn(async move {
            if let Err(e) = sender.send_rpc(peer_id, peer_address, payload).await {
                tracing::warn!("Failed to send heartbeat to peer {}: {}", peer_id, e);
            }
        });
    }
}
/// Returns the cached connection record for `peer_id` when it is connected
/// and not stale; otherwise dials `peer_address` and stores a fresh record.
///
/// Every dial attempt is reflected in the shared statistics
/// (`connections_established` / `connections_failed`), matching what the
/// secure path records.
///
/// # Errors
/// Fails if the TCP connect fails or does not finish within
/// `config.connection_timeout`.
async fn get_or_create_connection(
    &self,
    peer_id: OxirsNodeId,
    peer_address: SocketAddr,
) -> Result<Connection> {
    // Fast path: reuse a live record under the read lock only.
    {
        let connections = self.connections.read().await;
        if let Some(existing) = connections.get(&peer_id) {
            if existing.is_connected && !existing.is_stale(self.config.connection_timeout) {
                return Ok(existing.clone());
            }
        }
    }

    let dial = timeout(
        self.config.connection_timeout,
        TcpStream::connect(peer_address),
    )
    .await;

    // Unpack both error layers explicitly so failures can be counted before
    // they are propagated (same error types as `??` would produce).
    let _stream = match dial {
        Ok(Ok(stream)) => stream,
        Ok(Err(io_err)) => {
            self.message_stats.write().await.connections_failed += 1;
            return Err(io_err.into());
        }
        Err(elapsed) => {
            self.message_stats.write().await.connections_failed += 1;
            return Err(elapsed.into());
        }
    };
    self.message_stats.write().await.connections_established += 1;

    let mut connection = Connection::new(peer_id, peer_address);
    connection.is_connected = true;
    connection.update_activity();
    {
        let mut connections = self.connections.write().await;
        connections.insert(peer_id, connection.clone());
    }
    Ok(connection)
}
/// Simulated message exchange: refreshes the peer's activity timestamp,
/// waits 1 ms and fabricates a reply for the supported request kinds.
///
/// Fabricated replies: `RequestVote` → vote denied, `AppendEntries` →
/// success with `last_log_index` 0, `Heartbeat` → echo of the term,
/// `ClientRequest` → `RdfResponse::Success`.
///
/// # Errors
/// Returns an error for every other message kind.
async fn send_message_to_connection(
    &self,
    connection: Connection,
    message: RpcMessage,
) -> Result<RpcMessage> {
    // `connection` is a by-value copy of the table entry, so the timestamp
    // must be refreshed on the shared entry itself; updating the copy would
    // be lost and the peer would always be considered stale.
    {
        let mut connections = self.connections.write().await;
        if let Some(entry) = connections.get_mut(&connection.peer_id) {
            entry.update_activity();
        }
    }

    tokio::time::sleep(Duration::from_millis(1)).await;

    let response = match message {
        RpcMessage::RequestVote { term, .. } => RpcMessage::VoteResponse {
            term,
            vote_granted: false,
        },
        RpcMessage::AppendEntries { term, .. } => RpcMessage::AppendEntriesResponse {
            term,
            success: true,
            last_log_index: 0,
        },
        RpcMessage::Heartbeat { term, .. } => RpcMessage::HeartbeatResponse { term },
        RpcMessage::ClientRequest { .. } => RpcMessage::ClientResponse {
            response: RdfResponse::Success,
        },
        _ => return Err(anyhow::anyhow!("Unexpected message type")),
    };
    Ok(response)
}
/// Spawns the cleanup task: while the shared running flag is true, it drops
/// every connection record idle for longer than `config.connection_timeout`,
/// then sleeps 30 seconds before checking again.
async fn start_background_tasks(&self) {
    let running_flag = Arc::clone(&self.running);
    let peer_table = Arc::clone(&self.connections);
    let max_idle = self.config.connection_timeout;

    tokio::spawn(async move {
        loop {
            if !*running_flag.read().await {
                break;
            }

            {
                let mut table = peer_table.write().await;
                table.retain(|peer_id, conn| {
                    let stale = conn.is_stale(max_idle);
                    if stale {
                        tracing::debug!("Removed stale connection to peer {}", peer_id);
                    }
                    !stale
                });
            }

            tokio::time::sleep(Duration::from_secs(30)).await;
        }
    });
}
/// Returns a snapshot with the number of known and currently connected
/// peers, plus this node's address and id.
pub async fn get_stats(&self) -> NetworkStats {
    let table = self.connections.read().await;
    let active_connections = table.values().filter(|conn| conn.is_connected).count();

    NetworkStats {
        total_connections: table.len(),
        active_connections,
        local_address: self.config.local_address,
        node_id: self.node_id,
    }
}
/// Sends `message` over TLS when a TLS manager is configured, otherwise
/// falls back to [`send_rpc`](Self::send_rpc).
///
/// On the TLS path a fresh TCP connection is dialed and a handshake is
/// performed against the server name `node-{peer_id}`; the reply is then
/// produced by the simulated secure exchange. Successful and failed
/// connections/handshakes are both recorded in the shared statistics.
///
/// # Errors
/// Fails if the connector cannot be obtained, the TCP connect fails or times
/// out, the server name is invalid, the handshake fails, or the exchange fails.
pub async fn send_secure_rpc(
    &self,
    peer_id: OxirsNodeId,
    peer_address: SocketAddr,
    message: RpcMessage,
) -> Result<RpcMessage> {
    let tls_manager = match &self.tls_manager {
        Some(manager) => manager,
        None => return self.send_rpc(peer_id, peer_address, message).await,
    };

    let connector = tls_manager.get_connector().await?;

    let dial = timeout(
        self.config.connection_timeout,
        TcpStream::connect(peer_address),
    )
    .await;
    let tcp_stream = match dial {
        Ok(Ok(stream)) => stream,
        Ok(Err(io_err)) => {
            self.message_stats.write().await.connections_failed += 1;
            return Err(io_err.into());
        }
        Err(elapsed) => {
            self.message_stats.write().await.connections_failed += 1;
            return Err(elapsed.into());
        }
    };

    let server_name = rustls::pki_types::ServerName::try_from(format!("node-{peer_id}"))?;

    let _tls_stream = match connector.connect(server_name, tcp_stream).await {
        Ok(stream) => stream,
        Err(handshake_err) => {
            // Without this the failure counters reported by `get_tls_status`
            // could never move off zero.
            let mut stats = self.message_stats.write().await;
            stats.tls_handshakes_failed += 1;
            stats.connections_failed += 1;
            return Err(handshake_err.into());
        }
    };

    {
        let mut stats = self.message_stats.write().await;
        stats.tls_handshakes_completed += 1;
        stats.connections_established += 1;
    }
    self.simulate_secure_communication(message).await
}
/// Simulated secure exchange: counts the outgoing message, waits 5 ms,
/// fabricates a reply and counts it as received.
///
/// Fabricated replies: `RequestVote` → vote granted, `AppendEntries` →
/// success with `last_log_index` 0, `Heartbeat` → echo of the term,
/// `ClientRequest` → `RdfResponse::Success`.
///
/// # Errors
/// Returns an error for every other message kind; the message has already
/// been counted as sent at that point.
async fn simulate_secure_communication(&self, message: RpcMessage) -> Result<RpcMessage> {
    let outbound_size = self.estimate_message_size(&message);
    {
        let mut stats = self.message_stats.write().await;
        stats.messages_sent += 1;
        stats.bytes_sent += outbound_size;
    }

    tokio::time::sleep(Duration::from_millis(5)).await;

    let reply = match message {
        RpcMessage::Heartbeat { term, .. } => RpcMessage::HeartbeatResponse { term },
        RpcMessage::RequestVote { term, .. } => RpcMessage::VoteResponse {
            term,
            vote_granted: true,
        },
        RpcMessage::AppendEntries { term, .. } => RpcMessage::AppendEntriesResponse {
            term,
            success: true,
            last_log_index: 0,
        },
        RpcMessage::ClientRequest { .. } => RpcMessage::ClientResponse {
            response: RdfResponse::Success,
        },
        _ => return Err(anyhow::anyhow!("Unsupported message type")),
    };

    let inbound_size = self.estimate_message_size(&reply);
    let mut stats = self.message_stats.write().await;
    stats.messages_received += 1;
    stats.bytes_received += inbound_size;
    drop(stats);

    Ok(reply)
}
/// Returns a rough, fixed size estimate (in bytes) for `message`.
///
/// These are constants per message kind, not measured serialized sizes;
/// `AppendEntries` additionally adds 256 per contained entry.
fn estimate_message_size(&self, message: &RpcMessage) -> u64 {
    // The only kind whose estimate depends on its contents.
    if let RpcMessage::AppendEntries { entries, .. } = message {
        let entry_count = entries.len() as u64;
        return 128 + entry_count * 256;
    }

    match message {
        RpcMessage::ClientRequest { .. } => 512,
        RpcMessage::ClientResponse { .. } => 256,
        RpcMessage::RequestVote { .. } => 64,
        RpcMessage::AppendEntriesResponse { .. } => 48,
        RpcMessage::VoteResponse { .. } => 32,
        RpcMessage::Heartbeat { .. } => 24,
        RpcMessage::HeartbeatResponse { .. } => 16,
        _ => 128,
    }
}
/// Returns a copy of the current message counters.
pub async fn get_message_stats(&self) -> MessageStats {
    let guard = self.message_stats.read().await;
    (*guard).clone()
}
/// Resets every message counter to zero.
pub async fn reset_stats(&self) {
    *self.message_stats.write().await = MessageStats::default();
}
/// Reports whether TLS is enabled, how many certificates the TLS manager
/// lists, when the `"server"` certificate expires (if present) and the
/// handshake counters.
///
/// Without a TLS manager, a disabled status with zeroed fields is returned.
pub async fn get_tls_status(&self) -> Result<TlsStatus> {
    let tls_manager = match &self.tls_manager {
        Some(manager) => manager,
        None => {
            return Ok(TlsStatus {
                enabled: false,
                certificates_count: 0,
                server_cert_expires: None,
                handshakes_completed: 0,
                handshakes_failed: 0,
            });
        }
    };

    let certificates = tls_manager.list_certificates().await;

    // Read both counters under ONE guard. Acquiring a second read guard while
    // the first is still alive can deadlock with tokio's write-preferring
    // `RwLock` when a writer queues in between, and a single guard also
    // guarantees the two values belong to the same snapshot.
    let (handshakes_completed, handshakes_failed) = {
        let stats = self.message_stats.read().await;
        (stats.tls_handshakes_completed, stats.tls_handshakes_failed)
    };

    Ok(TlsStatus {
        enabled: true,
        certificates_count: certificates.len(),
        server_cert_expires: certificates.get("server").map(|c| c.not_after),
        handshakes_completed,
        handshakes_failed,
    })
}
/// Wraps `data` for transport.
///
/// NOTE: this is placeholder framing, not encryption. With a TLS manager
/// configured the bytes are copied unchanged after an `ENCRYPTED:` marker;
/// without one they are returned as-is.
pub async fn encrypt_data(&self, data: &[u8]) -> Result<Vec<u8>> {
    if self.tls_manager.is_none() {
        return Ok(data.to_vec());
    }
    Ok([&b"ENCRYPTED:"[..], data].concat())
}
/// Reverses [`encrypt_data`](Self::encrypt_data): with a TLS manager
/// configured, strips the leading `ENCRYPTED:` marker; without one, returns
/// the input unchanged.
///
/// # Errors
/// With a TLS manager configured, fails if the input does not start with the marker.
pub async fn decrypt_data(&self, encrypted_data: &[u8]) -> Result<Vec<u8>> {
    if self.tls_manager.is_none() {
        return Ok(encrypted_data.to_vec());
    }

    match encrypted_data.strip_prefix(b"ENCRYPTED:") {
        Some(payload) => Ok(payload.to_vec()),
        None => Err(anyhow::anyhow!("Invalid encrypted data format")),
    }
}
}
impl Clone for NetworkManager {
fn clone(&self) -> Self {
Self {
config: self.config.clone(),
node_id: self.node_id,
connections: Arc::clone(&self.connections),
listener: None, running: Arc::clone(&self.running),
tls_manager: self.tls_manager.clone(),
message_stats: Arc::clone(&self.message_stats),
}
}
}
/// Snapshot of connection counts for one node, as produced by `NetworkManager::get_stats`.
#[derive(Debug, Clone)]
pub struct NetworkStats {
/// Number of entries in the connection table.
pub total_connections: usize,
/// Number of entries whose `is_connected` flag is set.
pub active_connections: usize,
/// The node's configured local address.
pub local_address: SocketAddr,
/// The node's identifier.
pub node_id: OxirsNodeId,
}
/// Summary of the TLS state, as produced by `NetworkManager::get_tls_status`.
#[derive(Debug, Clone)]
pub struct TlsStatus {
/// Whether a TLS manager is configured.
pub enabled: bool,
/// Number of certificates listed by the TLS manager.
pub certificates_count: usize,
/// `not_after` of the certificate stored under the key `"server"`, if any.
pub server_cert_expires: Option<SystemTime>,
/// Value of the `tls_handshakes_completed` counter.
pub handshakes_completed: u64,
/// Value of the `tls_handshakes_failed` counter.
pub handshakes_failed: u64,
}
/// Thin service facade over a `NetworkManager`.
#[derive(Debug, Clone)]
pub struct NetworkService {
// The wrapped manager; `start`/`stop` delegate to it directly.
manager: NetworkManager,
}
impl NetworkService {
/// Creates a service wrapping a fresh, non-TLS `NetworkManager`.
pub fn new(node_id: OxirsNodeId, config: NetworkConfig) -> Self {
    let manager = NetworkManager::new(node_id, config);
    Self { manager }
}
/// Starts the wrapped manager; see `NetworkManager::start`.
///
/// # Errors
/// Propagates any error from the manager.
pub async fn start(&mut self) -> Result<()> {
self.manager.start().await
}
/// Stops the wrapped manager; see `NetworkManager::stop`.
///
/// # Errors
/// Propagates any error from the manager.
pub async fn stop(&mut self) -> Result<()> {
self.manager.stop().await
}
/// Parses `peer_id` into a node id and forwards `message` to `send_message`.
///
/// # Errors
/// Returns `"Invalid peer ID"` when the string does not parse, or whatever
/// error `send_message` reports.
pub async fn send_to(&self, peer_id: &str, message: RpcMessage) -> Result<()> {
    let target = peer_id
        .parse::<OxirsNodeId>()
        .map_err(|_| anyhow::anyhow!("Invalid peer ID"))?;
    self.send_message(target, message).await
}
pub async fn broadcast(&self, message: RpcMessage) -> Result<()> {
let connections = self.manager.connections.read().await;
for peer_id in connections.keys() {
let _ = self.send_message(*peer_id, message.clone()).await;
}
Ok(())
}
/// Placeholder send: logs the target node and the message at debug level and
/// returns `Ok(())`. Nothing is transmitted by this method.
pub async fn send_message(&self, node_id: OxirsNodeId, message: RpcMessage) -> Result<()> {
tracing::debug!("Sending message to node {}: {:?}", node_id, message);
Ok(())
}
/// Entry point for incoming messages. No message kind is handled yet.
///
/// # Errors
/// Always returns an error: `Bft` messages (only present with the `bft`
/// feature) are rejected with a dedicated message, every other kind with
/// "Message handling not yet implemented".
pub async fn handle_message(&self, message: RpcMessage) -> Result<RpcMessage> {
match message {
// This arm only exists when the `bft` feature compiles the variant in.
#[cfg(feature = "bft")]
RpcMessage::Bft { .. } => {
Err(anyhow::anyhow!(
"BFT messages should be handled by BFT network service"
))
}
_ => {
Err(anyhow::anyhow!("Message handling not yet implemented"))
}
}
}
}
/// Typed errors for the network layer; `Display` text comes from the
/// `#[error(...)]` attributes via `thiserror`.
#[derive(Debug, thiserror::Error)]
pub enum NetworkError {
/// A connection to `peer_id` at `address` could not be established.
#[error("Connection failed to {peer_id} at {address}: {message}")]
ConnectionFailed {
peer_id: OxirsNodeId,
address: SocketAddr,
message: String,
},
/// The named operation exceeded `duration` (rendered with `Debug`, e.g. `5s`).
#[error("Timeout: {operation} timed out after {duration:?}")]
Timeout {
operation: String,
duration: Duration,
},
/// Encoding or decoding failed.
#[error("Serialization error: {message}")]
Serialization { message: String },
/// A protocol-level problem.
#[error("Protocol error: {message}")]
Protocol { message: String },
/// Wrapped I/O error; `#[from]` provides the `From<std::io::Error>` conversion.
#[error("IO error: {0}")]
Io(#[from] std::io::Error),
}
// Unit tests for the configuration defaults, the small value types, message
// serialization and error formatting. None of them opens a socket.
#[cfg(test)]
mod tests {
use super::*;
use std::net::{IpAddr, Ipv4Addr};
// Checks the timeout/limit defaults of `NetworkConfig`.
#[test]
fn test_network_config_default() {
let config = NetworkConfig::default();
assert_eq!(config.connection_timeout, Duration::from_secs(5));
assert_eq!(config.request_timeout, Duration::from_secs(10));
assert_eq!(config.max_connections, 100);
assert_eq!(config.keep_alive_interval, Duration::from_secs(30));
}
// `LogEntry::new` stores its arguments unchanged.
#[test]
fn test_log_entry_creation() {
let command = RdfCommand::Insert {
subject: "s".to_string(),
predicate: "p".to_string(),
object: "o".to_string(),
};
let entry = LogEntry::new(1, 1, command.clone());
assert_eq!(entry.index, 1);
assert_eq!(entry.term, 1);
assert_eq!(entry.command, command);
}
// A new `Connection` keeps its peer/address and starts disconnected.
#[test]
fn test_connection_creation() {
let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080);
let connection = Connection::new(1, addr);
assert_eq!(connection.peer_id, 1);
assert_eq!(connection.address, addr);
assert!(!connection.is_connected);
}
// Fresh connections are not stale for a long timeout, but are for a 1 ns one.
#[test]
fn test_connection_staleness() {
let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080);
let connection = Connection::new(1, addr);
assert!(!connection.is_stale(Duration::from_secs(10)));
// Sleep so that more than 1 ns has certainly elapsed.
std::thread::sleep(Duration::from_millis(1));
assert!(connection.is_stale(Duration::from_nanos(1)));
}
// `update_activity` moves `last_activity` forward.
#[test]
fn test_connection_activity_update() {
let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080);
let mut connection = Connection::new(1, addr);
let old_activity = connection.last_activity;
// Sleep so the new timestamp is strictly later than the old one.
std::thread::sleep(Duration::from_millis(1));
connection.update_activity();
assert!(connection.last_activity > old_activity);
}
// A new manager is stopped and has no connections.
#[tokio::test]
async fn test_network_manager_creation() {
let config = NetworkConfig::default();
let manager = NetworkManager::new(1, config);
assert_eq!(manager.node_id, 1);
assert!(!*manager.running.read().await);
let stats = manager.get_stats().await;
assert_eq!(stats.node_id, 1);
assert_eq!(stats.total_connections, 0);
assert_eq!(stats.active_connections, 0);
}
// `RequestVote` survives a JSON round-trip with all fields intact.
#[tokio::test]
async fn test_rpc_message_serialization() {
let message = RpcMessage::RequestVote {
term: 1,
candidate_id: 1,
last_log_index: 0,
last_log_term: 0,
};
let serialized = serde_json::to_string(&message).unwrap();
let deserialized: RpcMessage = serde_json::from_str(&serialized).unwrap();
match deserialized {
RpcMessage::RequestVote {
term,
candidate_id,
last_log_index,
last_log_term,
} => {
assert_eq!(term, 1);
assert_eq!(candidate_id, 1);
assert_eq!(last_log_index, 0);
assert_eq!(last_log_term, 0);
}
_ => panic!("Unexpected message type"),
}
}
// Each `NetworkError` variant renders the expected `Display` text.
#[test]
fn test_network_error_display() {
let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080);
let err = NetworkError::ConnectionFailed {
peer_id: 1,
address: addr,
message: "refused".to_string(),
};
assert!(err
.to_string()
.contains("Connection failed to 1 at 127.0.0.1:8080: refused"));
let err = NetworkError::Timeout {
operation: "connect".to_string(),
duration: Duration::from_secs(5),
};
assert!(err
.to_string()
.contains("Timeout: connect timed out after 5s"));
let err = NetworkError::Serialization {
message: "invalid json".to_string(),
};
assert!(err
.to_string()
.contains("Serialization error: invalid json"));
let err = NetworkError::Protocol {
message: "unknown message".to_string(),
};
assert!(err.to_string().contains("Protocol error: unknown message"));
}
}