use crate::{
message::{Message, MessageError, MessageType},
persistence::{
MemoryBackend, PersistedState, PersistenceError, PersistenceManager, SqliteBackend,
StatePersistence, StateProvider,
},
state::ProtocolStateMachine,
types::{ProtocolError, ProtocolEvent},
};
use qudag_crypto::ml_kem::MlKem768;
use qudag_dag::Consensus;
use qudag_network::Transport;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;
use tokio::sync::{mpsc, RwLock};
use tracing::{debug, info, warn};
/// Runtime configuration for a node, loadable/serializable via serde.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeConfig {
/// Directory that holds persistent node data (e.g. `state.db`).
pub data_dir: PathBuf,
/// Network port for the transport layer — presumably the listen port; confirm against transport setup.
pub network_port: u16,
/// Maximum number of peer connections to maintain.
pub max_peers: usize,
/// Bootstrap peer addresses to contact on startup.
pub initial_peers: Vec<String>,
}
impl Default for NodeConfig {
fn default() -> Self {
Self {
data_dir: PathBuf::from("./data"),
network_port: 8000,
max_peers: 50,
initial_peers: Vec::new(),
}
}
}
/// A protocol node: owns the protocol state machine plus optionally
/// attached crypto keys, transport, consensus, and persistence components.
pub struct Node {
/// Node configuration; retained but not read after construction yet.
#[allow(dead_code)]
config: NodeConfig,
/// Shared protocol state machine, guarded by an async RwLock.
state_machine: Arc<RwLock<ProtocolStateMachine>>,
/// Internal event channel endpoints; not consumed anywhere visible yet.
#[allow(dead_code)]
events: NodeEvents,
/// ML-KEM key pair, populated by `init_keys` during `start`.
keys: Option<KeyPair>,
/// Network transport, when one has been attached.
transport: Option<Arc<dyn Transport + Send + Sync>>,
/// Consensus engine, when one has been attached.
#[allow(dead_code)]
consensus: Option<Arc<dyn Consensus + Send + Sync>>,
/// Persistence backend set up by `with_persistence`.
persistence: Option<PersistenceManager>,
/// Random 32-byte identifier generated at construction.
pub node_id: Vec<u8>,
}
/// Paired sender/receiver for protocol events; currently held but unused.
struct NodeEvents {
// Sender half for publishing protocol events.
#[allow(dead_code)]
tx: mpsc::Sender<ProtocolEvent>,
// Receiver half for consuming protocol events.
#[allow(dead_code)]
rx: mpsc::Receiver<ProtocolEvent>,
}
/// Raw byte encodings of an ML-KEM-768 key pair (see `init_keys`).
struct KeyPair {
// Encapsulation (public) key bytes.
#[allow(dead_code)]
public_key: Vec<u8>,
// Decapsulation (private) key bytes.
#[allow(dead_code)]
private_key: Vec<u8>,
}
impl Node {
pub async fn new(config: NodeConfig) -> Result<Self, ProtocolError> {
let (tx, rx) = mpsc::channel(1000);
let node_id = Self::generate_node_id();
let state_machine = Arc::new(RwLock::new(ProtocolStateMachine::new(
crate::message::ProtocolVersion::CURRENT,
)));
Ok(Self {
config,
state_machine,
events: NodeEvents { tx, rx },
keys: None,
transport: None,
consensus: None,
persistence: None,
node_id,
})
}
pub async fn with_persistence(config: NodeConfig) -> Result<Self, ProtocolError> {
let mut node = Self::new(config.clone()).await?;
let backend: Arc<dyn StatePersistence> = if config.data_dir.join("state.db").exists() {
let db_path = config.data_dir.join("state.db");
Arc::new(SqliteBackend::new(db_path).await.map_err(|e| {
ProtocolError::Internal(format!("Failed to create SQLite backend: {}", e))
})?)
} else {
Arc::new(MemoryBackend::default())
};
let persistence_manager: PersistenceManager = backend;
if let Some(recovered_state) = persistence_manager
.recover_state()
.await
.map_err(|e| ProtocolError::Internal(format!("Failed to recover state: {}", e)))?
{
info!("Recovered state from persistence");
let _state_machine = node.state_machine.write().await;
debug!(
"Recovered {} peers and {} sessions",
recovered_state.peers.len(),
recovered_state.sessions.len()
);
}
node.persistence = Some(persistence_manager);
Ok(node)
}
fn generate_node_id() -> Vec<u8> {
use rand::RngCore;
let mut rng = rand::thread_rng();
let mut id = vec![0u8; 32];
rng.fill_bytes(&mut id);
id
}
pub async fn start(&mut self) -> Result<(), ProtocolError> {
info!("Starting node...");
self.init_keys().await?;
self.init_transport().await?;
self.init_consensus().await?;
let mut state_machine = self.state_machine.write().await;
state_machine
.transition_to(
crate::state::ProtocolState::Handshake(crate::state::HandshakeState::Waiting),
"Node starting handshake".to_string(),
)
.map_err(|e| ProtocolError::StateError(e.to_string()))?;
state_machine
.transition_to(
crate::state::ProtocolState::Handshake(crate::state::HandshakeState::InProgress),
"Handshake in progress".to_string(),
)
.map_err(|e| ProtocolError::StateError(e.to_string()))?;
state_machine
.transition_to(
crate::state::ProtocolState::Handshake(crate::state::HandshakeState::Processing),
"Processing handshake".to_string(),
)
.map_err(|e| ProtocolError::StateError(e.to_string()))?;
state_machine
.transition_to(
crate::state::ProtocolState::Handshake(crate::state::HandshakeState::Completed),
"Handshake completed".to_string(),
)
.map_err(|e| ProtocolError::StateError(e.to_string()))?;
state_machine
.transition_to(
crate::state::ProtocolState::Active(crate::state::ActiveState::Normal),
"Node started".to_string(),
)
.map_err(|e| ProtocolError::StateError(e.to_string()))?;
drop(state_machine);
info!("Node started successfully");
Ok(())
}
pub async fn stop(&mut self) -> Result<(), ProtocolError> {
info!("Stopping node...");
let mut state_machine = self.state_machine.write().await;
state_machine
.transition_to(
crate::state::ProtocolState::Shutdown,
"Node stopping".to_string(),
)
.map_err(|e| ProtocolError::StateError(e.to_string()))?;
drop(state_machine);
if let Some(persistence) = &self.persistence {
let state = self.get_current_state().await.map_err(|e| {
ProtocolError::Internal(format!("Failed to get state for save: {}", e))
})?;
persistence.save_state(&state).await.map_err(|e| {
ProtocolError::Internal(format!("Failed to save final state: {}", e))
})?;
}
if let Some(_transport) = &self.transport {
}
info!("Node stopped successfully");
Ok(())
}
pub async fn handle_message(&mut self, message: Message) -> Result<(), MessageError> {
debug!("Handling message: {:?}", message.msg_type);
match message.msg_type {
MessageType::Handshake(_) => self.handle_handshake(message).await?,
MessageType::Data(_) => self.handle_data(message).await?,
MessageType::Control(_) => self.handle_control(message).await?,
MessageType::Sync(_) => self.handle_sync(message).await?,
_ => return Err(MessageError::InvalidFormat),
}
Ok(())
}
async fn init_keys(&mut self) -> Result<(), ProtocolError> {
let (pk, sk) = MlKem768::keygen().map_err(|e| ProtocolError::CryptoError(e.to_string()))?;
self.keys = Some(KeyPair {
public_key: pk.as_bytes().to_vec(),
private_key: sk.as_bytes().to_vec(),
});
Ok(())
}
async fn init_transport(&mut self) -> Result<(), ProtocolError> {
Ok(())
}
async fn init_consensus(&mut self) -> Result<(), ProtocolError> {
Ok(())
}
async fn handle_handshake(&mut self, _message: Message) -> Result<(), MessageError> {
Ok(())
}
async fn handle_data(&mut self, _message: Message) -> Result<(), MessageError> {
Ok(())
}
async fn handle_control(&mut self, _message: Message) -> Result<(), MessageError> {
Ok(())
}
async fn handle_sync(&mut self, _message: Message) -> Result<(), MessageError> {
Ok(())
}
pub async fn get_state(&self) -> crate::state::ProtocolState {
self.state_machine.read().await.current_state().clone()
}
pub fn has_persistence(&self) -> bool {
self.persistence.is_some()
}
pub async fn save_state(&self) -> Result<(), ProtocolError> {
if let Some(persistence) = &self.persistence {
let state = self
.get_current_state()
.await
.map_err(|e| ProtocolError::Internal(format!("Failed to get state: {}", e)))?;
persistence
.save_state(&state)
.await
.map_err(|e| ProtocolError::Internal(format!("Failed to save state: {}", e)))?;
info!("State saved successfully");
} else {
warn!("No persistence backend configured");
}
Ok(())
}
pub async fn create_backup(&self, backup_path: PathBuf) -> Result<(), ProtocolError> {
if let Some(persistence) = &self.persistence {
persistence
.create_backup(&backup_path)
.await
.map_err(|e| ProtocolError::Internal(format!("Failed to create backup: {}", e)))?;
info!("Backup created at {:?}", backup_path);
} else {
return Err(ProtocolError::Internal(
"No persistence backend configured".to_string(),
));
}
Ok(())
}
pub async fn restore_backup(&self, backup_path: PathBuf) -> Result<(), ProtocolError> {
if let Some(persistence) = &self.persistence {
persistence
.restore_backup(&backup_path)
.await
.map_err(|e| ProtocolError::Internal(format!("Failed to restore backup: {}", e)))?;
info!("Backup restored from {:?}", backup_path);
} else {
return Err(ProtocolError::Internal(
"No persistence backend configured".to_string(),
));
}
Ok(())
}
}
impl Node {
    /// Build a `PersistedState` snapshot of the node's current state.
    ///
    /// Sessions and metrics are copied from the state machine; peers and
    /// DAG state are currently persisted as empty placeholders.
    ///
    /// # Errors
    /// Currently infallible, but typed as `PersistenceError` for forward
    /// compatibility with callers that already propagate it.
    pub async fn get_current_state(&self) -> Result<PersistedState, PersistenceError> {
        let state_machine = self.state_machine.read().await;
        let current_state = state_machine.current_state().clone();
        let sessions = state_machine.get_sessions().clone();
        let metrics = state_machine.get_metrics();
        // TODO: peer tracking is not wired up yet — persisted as empty.
        let peers = vec![];
        // TODO: DAG state is not wired up yet — persisted as empty.
        let dag_state = crate::persistence::DagState {
            vertices: HashMap::new(),
            tips: std::collections::HashSet::new(),
            voting_records: HashMap::new(),
            last_checkpoint: None,
        };
        Ok(PersistedState {
            version: crate::persistence::CURRENT_STATE_VERSION,
            node_id: self.node_id.clone(),
            protocol_state: current_state,
            sessions,
            peers,
            dag_state,
            metrics,
            // Seconds since the Unix epoch; falls back to 0 instead of
            // panicking if the system clock is set before the epoch.
            last_saved: std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .map(|d| d.as_secs())
                .unwrap_or(0),
        })
    }
}
/// Adapter exposing a shared [`Node`]'s state store through `StateProvider`.
pub struct NodeStateProvider {
// Shared handle to the node whose persistence backend is surfaced.
node: Arc<RwLock<Node>>,
}
impl NodeStateProvider {
pub fn new(node: Arc<RwLock<Node>>) -> Self {
Self { node }
}
}
impl StateProvider for NodeStateProvider {
    /// Return the node's persistence backend when the node lock can be
    /// acquired without blocking and a backend is configured; otherwise
    /// hand back a fresh in-memory store.
    fn get_state_store(&self) -> Arc<dyn crate::persistence::StateStore + Send + Sync> {
        let configured = self
            .node
            .try_read()
            .ok()
            .and_then(|node| node.persistence.clone());
        if let Some(persistence) = configured {
            return persistence;
        }
        Arc::new(crate::persistence::MemoryStateStore::new())
    }
}
#[cfg(test)]
mod tests {
use super::*;
/// Drives a node through the full start → active → stop lifecycle and
/// checks the protocol state at each step.
#[tokio::test]
async fn test_node_lifecycle() {
let config = NodeConfig::default();
let mut node = Node::new(config).await.unwrap();
// Fresh nodes begin in the Initial state.
assert_eq!(node.get_state().await, crate::state::ProtocolState::Initial);
node.start().await.unwrap();
// After start, the node must be in some Active substate.
assert!(matches!(
node.get_state().await,
crate::state::ProtocolState::Active(_)
));
node.stop().await.unwrap();
assert_eq!(
node.get_state().await,
crate::state::ProtocolState::Shutdown
);
}
/// Exercises the persistence path: save state, create a backup, and
/// confirm a second node built from the same data dir gets a backend.
#[tokio::test]
async fn test_node_persistence() {
use tempfile::TempDir;
let temp_dir = TempDir::new().unwrap();
let config = NodeConfig {
data_dir: temp_dir.path().to_path_buf(),
..Default::default()
};
let mut node = Node::with_persistence(config.clone()).await.unwrap();
node.start().await.unwrap();
node.save_state().await.unwrap();
let backup_path = temp_dir.path().join("backup");
std::fs::create_dir_all(&backup_path).unwrap();
node.create_backup(backup_path.clone()).await.unwrap();
node.stop().await.unwrap();
// Re-open against the same data dir; a persistence backend must exist.
let node2 = Node::with_persistence(config).await.unwrap();
assert!(node2.persistence.is_some());
}
}