use qudag_crypto::{Ciphertext, MlDsaKeyPair, MlDsaPublicKey, MlKem768, PublicKey, SecretKey};
use qudag_dag::vertex::VertexId;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use thiserror::Error;
use uuid::Uuid;
/// Errors raised while building, signing, encrypting, (de)serializing or
/// validating protocol messages.
#[derive(Debug, Error)]
pub enum MessageError {
    /// The message is not in the expected format (not raised in this section).
    #[error("Invalid message format")]
    InvalidFormat,

    /// The payload is larger than `Message::validate` allows (1 MiB);
    /// carries the payload length in bytes.
    #[error("Message too large: {0} bytes")]
    MessageTooLarge(usize),

    /// A signature was judged invalid (not raised in this section).
    #[error("Invalid signature")]
    InvalidSignature,

    /// `Message::verify` was called on a message without a signature.
    #[error("Missing signature")]
    MissingSignature,

    /// Signing with the `MlDsaKeyPair` in `Message::sign` failed.
    #[error("Message signing failed")]
    SigningFailed,

    /// The public key rejected the signature in `Message::verify`.
    #[error("Signature verification failed")]
    VerificationFailed,

    /// KEM encapsulation in `EncryptedMessage::encrypt` failed.
    #[error("Encryption failed")]
    EncryptionFailed,

    /// `EncryptedMessage::decrypt` could not open the encrypted message.
    #[error("Decryption failed")]
    DecryptionFailed,

    /// bincode could not encode a message or payload.
    #[error("Message serialization failed")]
    SerializationFailed,

    /// bincode could not decode a message.
    #[error("Message deserialization failed")]
    DeserializationFailed,

    /// The message's TTL has elapsed (see `Message::is_expired`).
    #[error("Message has expired")]
    MessageExpired,

    /// The timestamp lies more than five minutes ahead of the local clock.
    #[error("Invalid message timestamp")]
    InvalidTimestamp,

    /// The peer's protocol version is not acceptable; carries that version
    /// (not raised in this section).
    #[error("Incompatible protocol version: {0:?}")]
    IncompatibleVersion(ProtocolVersion),
}
/// A protocol version: three numeric components plus a list of optional
/// feature names.
///
/// The derived `PartialEq`/`Hash` include `features`, so two versions with the
/// same numbers but different feature lists compare unequal.
///
/// NOTE(review): this type is bincode-encoded as part of `Message`; reordering
/// or inserting fields likely changes the wire format.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct ProtocolVersion {
/// Major component; `is_compatible` requires it to match exactly.
pub major: u16,
/// Minor component; compared with `<=` by `is_compatible`.
pub minor: u16,
/// Patch component; not consulted by `is_compatible`.
pub patch: u16,
/// Optional feature names; not consulted by `is_compatible`.
pub features: Vec<String>,
}
impl ProtocolVersion {
    /// The version spoken by this implementation: 1.0.0 with no features.
    pub const CURRENT: ProtocolVersion = ProtocolVersion {
        major: 1,
        minor: 0,
        patch: 0,
        features: Vec::new(),
    };

    /// Reports whether `self` is compatible with `other`.
    ///
    /// The major components must be equal, and `self.minor` must not exceed
    /// `other.minor`. Patch level and features are ignored. Note that the
    /// check is not symmetric: swapping the arguments can change the result.
    pub fn is_compatible(&self, other: &ProtocolVersion) -> bool {
        if self.major != other.major {
            return false;
        }
        other.minor >= self.minor
    }
}
/// Top-level message category, each carrying its sub-type.
///
/// NOTE(review): this enum is bincode-encoded inside `Message`; bincode encodes
/// variants by position, so new variants should be appended, never inserted.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum MessageType {
/// Connection handshake steps.
Handshake(HandshakeType),
/// DAG consensus traffic.
Consensus(ConsensusMessageType),
/// Routing-layer traffic.
Routing(RoutingMessageType),
/// Anonymity-layer traffic.
Anonymous(AnonymousMessageType),
/// Connection control (ping/pong, keep-alive, disconnect).
Control(ControlMessageType),
/// State synchronisation traffic.
Sync(SyncMessageType),
/// Raw bytes carried in the type tag itself, separate from `Message::payload`.
/// NOTE(review): confirm which of the two is meant to hold application data.
Data(Vec<u8>),
}
/// Handshake sub-types.
///
/// Only `Init` is constructed in this section, by
/// `MessageFactory::create_handshake_init`. The remaining variants presumably
/// name the later handshake steps; confirm with the handshake state machine.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum HandshakeType {
/// First handshake message; its payload is a bincode `HandshakePayload`.
Init,
/// Reply to `Init` (presumably).
Response,
/// Final handshake step (presumably).
Complete,
/// Protocol-version negotiation (presumably).
VersionNegotiation,
}
/// Consensus message sub-types.
///
/// The variants mirror the variants of `ConsensusPayload`. Only
/// `VertexProposal` is constructed in this section.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ConsensusMessageType {
/// Proposal of a new DAG vertex.
VertexProposal,
/// Vote on a vertex.
Vote,
/// Finality announcement for a set of vertices.
Finality,
/// Request for vertices.
Query,
}
/// Routing message sub-types.
///
/// None of these are constructed in this section; the descriptions below
/// follow the variant names and should be confirmed against the router.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum RoutingMessageType {
/// Presumably relayed through an onion circuit.
OnionRouted,
/// Presumably sent straight to the destination peer.
Direct,
/// Presumably delivered to all peers.
Broadcast,
}
/// Anonymity-layer message sub-types.
///
/// None of these are constructed in this section; the descriptions below
/// follow the variant names and should be confirmed against the mixnet code.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum AnonymousMessageType {
/// Presumably real application data.
Data,
/// Presumably mix-network traffic.
Mix,
/// Presumably cover (dummy) traffic.
Cover,
}
/// Control message sub-types.
///
/// The variants mirror the variants of `ControlPayload`. Only `Ping` is
/// constructed in this section, by `MessageFactory::create_ping`.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ControlMessageType {
/// Liveness probe.
Ping,
/// Answer to a `Ping` (presumably echoing its nonce).
Pong,
/// Connection teardown notice.
Disconnect,
/// Keep-alive signal.
KeepAlive,
}
/// State-synchronisation message sub-types.
///
/// None of these are constructed in this section; the descriptions below
/// follow the variant names and should be confirmed against the sync code.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum SyncMessageType {
/// Presumably asks a peer for its state.
StateRequest,
/// Presumably answers a `StateRequest`.
StateResponse,
/// Presumably an incremental sync.
DeltaSync,
/// Presumably a sync from a checkpoint.
CheckpointSync,
}
/// A protocol message: a typed payload plus metadata, optionally signed.
///
/// Built with `Message::new` / `Message::new_with_version` and refined with the
/// `with_*` builders. Encoded with bincode by `to_bytes` / `from_bytes`, so
/// field order is part of the wire format.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
/// Message identifier; the constructors assign a random v4 UUID.
pub id: Uuid,
/// Protocol version the sender claims to speak.
pub version: ProtocolVersion,
/// Category and sub-type of the message.
pub msg_type: MessageType,
/// Opaque payload bytes; `validate` rejects payloads over 1 MiB.
pub payload: Vec<u8>,
/// Creation time in milliseconds since the UNIX epoch, as set by the
/// constructors; values decoded by `from_bytes` are whatever the sender wrote.
pub timestamp: u64,
/// Signature produced by `sign`; `None` until the message is signed.
pub signature: Option<Vec<u8>>,
/// Free-form string headers added with `with_header`.
pub headers: HashMap<String, String>,
/// BLAKE3 hash of the signer's public key, set by `sign` and checked by `verify`.
pub sender_key_hash: Option<Vec<u8>>,
/// Sequence number; 0 unless set with `with_sequence`.
pub sequence: u64,
/// Optional time-to-live in milliseconds, counted from `timestamp`.
pub ttl: Option<u64>,
}
impl Message {
pub fn new(msg_type: MessageType, payload: Vec<u8>) -> Self {
Self {
id: Uuid::new_v4(),
version: ProtocolVersion::CURRENT,
msg_type,
payload,
timestamp: std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_millis() as u64,
signature: None,
headers: HashMap::new(),
sender_key_hash: None,
sequence: 0,
ttl: None,
}
}
pub fn new_with_version(
version: ProtocolVersion,
msg_type: MessageType,
payload: Vec<u8>,
) -> Self {
Self {
id: Uuid::new_v4(),
version,
msg_type,
payload,
timestamp: std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_millis() as u64,
signature: None,
headers: HashMap::new(),
sender_key_hash: None,
sequence: 0,
ttl: None,
}
}
pub fn with_sequence(mut self, sequence: u64) -> Self {
self.sequence = sequence;
self
}
pub fn with_ttl(mut self, ttl: u64) -> Self {
self.ttl = Some(ttl);
self
}
pub fn with_header(mut self, key: String, value: String) -> Self {
self.headers.insert(key, value);
self
}
pub fn is_expired(&self) -> bool {
if let Some(ttl) = self.ttl {
let now = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_millis() as u64;
now > self.timestamp + ttl
} else {
false
}
}
fn get_signable_data(&self) -> Result<Vec<u8>, MessageError> {
let mut msg_copy = self.clone();
msg_copy.signature = None;
bincode::serialize(&msg_copy).map_err(|_| MessageError::SerializationFailed)
}
pub fn sign(&mut self, keypair: &MlDsaKeyPair) -> Result<(), MessageError> {
let signable_data = self.get_signable_data()?;
let signature = keypair
.sign(&signable_data, &mut rand::thread_rng())
.map_err(|_| MessageError::SigningFailed)?;
self.signature = Some(signature);
let public_key_bytes = keypair.public_key();
self.sender_key_hash = Some(blake3::hash(public_key_bytes).as_bytes().to_vec());
Ok(())
}
pub fn verify(&self, public_key: &MlDsaPublicKey) -> Result<bool, MessageError> {
let signature = self
.signature
.as_ref()
.ok_or(MessageError::MissingSignature)?;
if let Some(sender_hash) = &self.sender_key_hash {
let public_key_bytes = public_key.as_bytes();
let expected_hash = blake3::hash(public_key_bytes).as_bytes().to_vec();
if sender_hash != &expected_hash {
return Ok(false);
}
}
let signable_data = self.get_signable_data()?;
public_key
.verify(&signable_data, signature)
.map_err(|_| MessageError::VerificationFailed)
.map(|_| true)
}
pub fn to_bytes(&self) -> Result<Vec<u8>, MessageError> {
bincode::serialize(self).map_err(|_| MessageError::SerializationFailed)
}
pub fn from_bytes(data: &[u8]) -> Result<Self, MessageError> {
bincode::deserialize(data).map_err(|_| MessageError::DeserializationFailed)
}
pub fn validate(&self) -> Result<(), MessageError> {
if self.is_expired() {
return Err(MessageError::MessageExpired);
}
if self.payload.len() > 1024 * 1024 {
return Err(MessageError::MessageTooLarge(self.payload.len()));
}
let now = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_millis() as u64;
if self.timestamp > now + (5 * 60 * 1000) {
return Err(MessageError::InvalidTimestamp);
}
Ok(())
}
}
/// Container produced by `EncryptedMessage::encrypt` for a recipient's KEM
/// public key, and opened with `EncryptedMessage::decrypt`.
///
/// `headers` and `timestamp` are copied from the inner `Message` by `encrypt`.
/// They travel in the clear next to `ciphertext` and are not protected by it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EncryptedMessage {
/// Sealed message bytes as produced by `encrypt`.
/// NOTE(review): confirm `encrypt` actually encrypts these bytes.
pub ciphertext: Vec<u8>,
/// KEM ciphertext (encapsulation) bytes the recipient decapsulates.
pub encapsulation: Vec<u8>,
/// Cleartext copy of the inner message's headers.
pub headers: HashMap<String, String>,
/// Cleartext copy of the inner message's timestamp (ms since UNIX epoch).
pub timestamp: u64,
}
impl EncryptedMessage {
pub fn encrypt(
message: &Message,
recipient_public_key: &PublicKey,
) -> Result<Self, MessageError> {
let message_bytes = message.to_bytes()?;
let (ciphertext, shared_secret) = MlKem768::encapsulate(recipient_public_key)
.map_err(|_| MessageError::EncryptionFailed)?;
let _aes_key = &shared_secret.as_bytes()[..32];
let encrypted_data = message_bytes;
Ok(Self {
ciphertext: encrypted_data,
encapsulation: ciphertext.as_bytes().to_vec(),
headers: message.headers.clone(),
timestamp: message.timestamp,
})
}
pub fn decrypt(&self, recipient_secret_key: &SecretKey) -> Result<Message, MessageError> {
let encapsulation = Ciphertext::from_bytes(&self.encapsulation)
.map_err(|_| MessageError::DecryptionFailed)?;
let shared_secret = MlKem768::decapsulate(recipient_secret_key, &encapsulation)
.map_err(|_| MessageError::DecryptionFailed)?;
let _aes_key = &shared_secret.as_bytes()[..32];
let message_bytes = &self.ciphertext;
Message::from_bytes(message_bytes)
}
}
pub struct MessageFactory;
impl MessageFactory {
pub fn create_handshake_init(
protocol_version: ProtocolVersion,
public_key: &MlDsaPublicKey,
kem_public_key: &PublicKey,
) -> Result<Message, MessageError> {
let payload = HandshakePayload {
protocol_version: protocol_version.clone(),
public_key: public_key.as_bytes().to_vec(),
kem_public_key: kem_public_key.as_bytes().to_vec(),
capabilities: vec!["anonymous-routing".to_string(), "dag-consensus".to_string()],
timestamp: std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_millis() as u64,
};
let payload_bytes =
bincode::serialize(&payload).map_err(|_| MessageError::SerializationFailed)?;
Ok(Message::new_with_version(
protocol_version,
MessageType::Handshake(HandshakeType::Init),
payload_bytes,
))
}
pub fn create_vertex_proposal(
vertex_id: VertexId,
vertex_data: Vec<u8>,
parent_vertices: Vec<VertexId>,
) -> Result<Message, MessageError> {
let payload = ConsensusPayload::VertexProposal {
vertex_id,
vertex_data,
parent_vertices,
timestamp: std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_millis() as u64,
};
let payload_bytes =
bincode::serialize(&payload).map_err(|_| MessageError::SerializationFailed)?;
Ok(Message::new(
MessageType::Consensus(ConsensusMessageType::VertexProposal),
payload_bytes,
))
}
pub fn create_ping() -> Result<Message, MessageError> {
let payload = ControlPayload::Ping {
timestamp: std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_millis() as u64,
nonce: rand::random::<u64>(),
};
let payload_bytes =
bincode::serialize(&payload).map_err(|_| MessageError::SerializationFailed)?;
Ok(Message::new(
MessageType::Control(ControlMessageType::Ping),
payload_bytes,
)
.with_ttl(30000)) }
}
/// Payload of a `Handshake(Init)` message, bincode-encoded by
/// `MessageFactory::create_handshake_init`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HandshakePayload {
/// Protocol version the initiator proposes.
pub protocol_version: ProtocolVersion,
/// Initiator's signature public key bytes (`MlDsaPublicKey::as_bytes`).
pub public_key: Vec<u8>,
/// Initiator's KEM public key bytes (`PublicKey::as_bytes`).
pub kem_public_key: Vec<u8>,
/// Advertised capability names, e.g. "anonymous-routing", "dag-consensus".
pub capabilities: Vec<String>,
/// Creation time in milliseconds since the UNIX epoch.
pub timestamp: u64,
}
/// bincode payloads for `Consensus(..)` messages, one variant per
/// `ConsensusMessageType`.
///
/// Only `VertexProposal` is constructed in this section, by
/// `MessageFactory::create_vertex_proposal`, which stamps `timestamp` in
/// milliseconds since the UNIX epoch. The other variants are presumably
/// stamped the same way; confirm at their construction sites.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ConsensusPayload {
/// A new vertex, its serialized data and the vertices it references.
VertexProposal {
vertex_id: VertexId,
vertex_data: Vec<u8>,
parent_vertices: Vec<VertexId>,
timestamp: u64,
},
/// A yes/no vote on a vertex.
Vote {
vertex_id: VertexId,
vote: bool,
timestamp: u64,
},
/// A set of vertices announced as final.
Finality {
vertex_ids: Vec<VertexId>,
timestamp: u64,
},
/// A request for the listed vertices.
Query {
requested_vertices: Vec<VertexId>,
timestamp: u64,
},
}
/// bincode payloads for `Control(..)` messages, one variant per
/// `ControlMessageType`.
///
/// `MessageFactory::create_ping` builds `Ping` with the current time in
/// milliseconds and a random nonce. `Pong` presumably echoes the ping's nonce
/// so replies can be matched; confirm at the responder.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ControlPayload {
Ping { timestamp: u64, nonce: u64 },
Pong { timestamp: u64, nonce: u64 },
Disconnect { reason: String, timestamp: u64 },
KeepAlive { timestamp: u64 },
}