use crate::crypto::{SecurityContext, ResumptionTicket};
use crate::error::DeonError;
use crate::transport::{SecureTransport, TransportType};
use crate::types::{
ProtocolMessage, WireHeader, FLAG_ENCRYPTED, HEADER_LEN, MAGIC_BYTES, VERSION
};
use log::{info, debug};
use spake2::{Ed25519Group, Identity, Password, Spake2};
use std::time::SystemTime;
use hkdf::Hkdf;
use sha2::Sha256;
#[derive(Debug, PartialEq)]
pub enum ProtocolState {
Searching,
Handshaking,
Streaming,
Idle,
}
pub struct DeonProtocol {
state: ProtocolState,
transport: Box<dyn SecureTransport>,
security_context: Option<SecurityContext>,
resumption_ticket: Option<ResumptionTicket>,
_buffer: Vec<u8>,
}
impl DeonProtocol {
pub fn new(transport: Box<dyn SecureTransport>) -> Self {
Self {
state: ProtocolState::Searching,
transport,
security_context: None,
resumption_ticket: None,
_buffer: Vec::new(),
}
}
pub async fn handshake(&mut self, pin: &str) -> Result<(), DeonError> {
debug!("Starting Handshake with PIN: {}", pin);
self.state = ProtocolState::Handshaking;
if self.transport.get_type() == TransportType::Ble {
let rssi = self.transport.get_rssi().await?;
debug!("RSSI: {}", rssi);
if rssi < -45 {
info!("Handshake rejected due to low RSSI: {}", rssi);
return Err(DeonError::HandshakeError);
}
}
if let Some(ticket) = self.resumption_ticket.clone() {
let now = SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_secs();
if ticket.expiry == 0 || ticket.expiry > now {
debug!("Attempting Session Resumption...");
let resume_msg = ProtocolMessage::Resume {
session_id: ticket.session_id
};
if let Ok(_) = self.send_cleartext_message(resume_msg).await {
if let Ok(response_bytes) = self.transport.receive().await {
if let Ok(response) = postcard::from_bytes::<ProtocolMessage>(&response_bytes) {
if let ProtocolMessage::ResumeAck = response {
info!("Session Resumed Successfully!");
self.security_context = Some(SecurityContext::new(ticket.key, true));
self.state = ProtocolState::Idle;
return Ok(());
}
}
}
}
debug!("Resumption failed, falling back to full handshake.");
}
}
debug!("Initializing SPAKE2...");
let pwd = Password::new(pin.as_bytes());
let id_a = Identity::new(b"device_a");
let id_b = Identity::new(b"device_b");
let (s1, msg1) = Spake2::<Ed25519Group>::start_a(
&pwd,
&id_a,
&id_b
);
let hello_msg = ProtocolMessage::Hello {
public_key: msg1, device_id: "DeviceA".to_string()
};
debug!("Sending Client Hello...");
self.send_cleartext_message(hello_msg).await?;
debug!("Waiting for Server Hello...");
let response_bytes = self.transport.receive().await?;
debug!("Received {} bytes", response_bytes.len());
let peer_msg: ProtocolMessage = postcard::from_bytes(&response_bytes)?;
if let ProtocolMessage::Hello { public_key: msg2_bytes, .. } = peer_msg {
debug!("Got Server Hello. Finishing SPAKE2...");
let key = s1.finish(&msg2_bytes)
.map_err(|e| {
info!("SPAKE2 Finish failed: {:?}", e);
DeonError::HandshakeError
})?;
debug!("SPAKE2 Handshake Success! Shared Key Derived.");
let mut shared_secret = [0u8; 32];
shared_secret.copy_from_slice(&key[0..32]);
self.security_context = Some(SecurityContext::new(shared_secret, true));
self.send_encrypted_message(&ProtocolMessage::Ping).await?;
let pong_msg = self.read_and_decrypt_message().await?;
if let ProtocolMessage::Pong = pong_msg {
self.state = ProtocolState::Idle;
let mut session_id = [0u8; 32];
let hk = Hkdf::<Sha256>::new(None, &shared_secret);
hk.expand(b"deon_session_id", &mut session_id).unwrap();
self.resumption_ticket = Some(ResumptionTicket {
session_id,
key: shared_secret,
expiry: 0, });
info!("SPAKE2 Handshake Complete. Secure Session Established.");
Ok(())
} else {
Err(DeonError::HandshakeError)
}
} else {
Err(DeonError::ProtocolViolation)
}
}
pub async fn accept_handshake(&mut self, pin: &str) -> Result<(), DeonError> {
debug!("Waiting for Handshake from Client (PIN: {})", pin);
self.state = ProtocolState::Handshaking;
let msg_bytes = self.transport.receive().await?;
let msg: ProtocolMessage = postcard::from_bytes(&msg_bytes)?;
match msg {
ProtocolMessage::Resume { session_id } => {
debug!("Received Resume Request");
if let Some(ticket) = self.resumption_ticket.clone() {
if ticket.session_id == session_id {
let now = SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_secs();
if ticket.expiry == 0 || ticket.expiry > now {
debug!("Session ID matched. Resuming...");
self.send_cleartext_message(ProtocolMessage::ResumeAck).await?;
self.security_context = Some(SecurityContext::new(ticket.key, false));
self.state = ProtocolState::Idle;
info!("Session Resumed (Responder).");
return Ok(());
}
}
}
debug!("Resumption failed. Rejecting.");
return Err(DeonError::HandshakeError);
}
ProtocolMessage::Hello { public_key: msg1_bytes, .. } => {
debug!("Received Client Hello. Starting SPAKE2 Responder...");
let pwd = Password::new(pin.as_bytes());
let id_a = Identity::new(b"device_a");
let id_b = Identity::new(b"device_b");
let (s2, msg2) = Spake2::<Ed25519Group>::start_b(
&pwd,
&id_a,
&id_b
);
let key = s2.finish(&msg1_bytes)
.map_err(|e| {
info!("SPAKE2 Finish failed: {:?}", e);
DeonError::HandshakeError
})?;
let mut shared_secret = [0u8; 32];
shared_secret.copy_from_slice(&key[0..32]);
let mut session_id = [0u8; 32];
let hk = Hkdf::<Sha256>::new(None, &shared_secret);
hk.expand(b"deon_session_id", &mut session_id).unwrap();
self.resumption_ticket = Some(ResumptionTicket {
session_id,
key: shared_secret,
expiry: 0,
});
let response = ProtocolMessage::Hello {
public_key: msg2,
device_id: "DeviceB".to_string(),
};
self.send_cleartext_message(response).await?;
self.security_context = Some(SecurityContext::new(shared_secret, false));
debug!("SPAKE2 Shared Key Derived.");
let ping_msg = self.read_and_decrypt_message().await?;
if let ProtocolMessage::Ping = ping_msg {
debug!("Received Ping. Sending Pong...");
self.send_encrypted_message(&ProtocolMessage::Pong).await?;
self.state = ProtocolState::Idle;
info!("Secure Session Established (Responder).");
Ok(())
} else {
Err(DeonError::HandshakeError)
}
}
_ => Err(DeonError::ProtocolViolation),
}
}
pub async fn receive_file(&mut self) -> Result<(), DeonError> {
self.state = ProtocolState::Streaming;
let mut file: Option<tokio::fs::File> = None;
let mut expected_size = 0u64;
let mut received_size = 0u64;
let mut current_filename = String::new();
loop {
let msg = match self.read_and_decrypt_message().await {
Ok(m) => m,
Err(e) => {
info!("Connection closed or error: {:?}", e);
break;
}
};
match msg {
ProtocolMessage::FileHeader { filename, size, .. } => {
info!("Receiving File: {} ({} bytes)", filename, size);
current_filename = filename.clone();
expected_size = size;
received_size = 0;
file = Some(tokio::fs::File::create(¤t_filename).await.map_err(|_| DeonError::Io)?);
}
ProtocolMessage::FileChunk { offset, data } => {
if let Some(f) = file.as_mut() {
use tokio::io::{AsyncSeekExt, AsyncWriteExt};
f.seek(tokio::io::SeekFrom::Start(offset)).await.map_err(|_| DeonError::Io)?;
f.write_all(&data).await.map_err(|_| DeonError::Io)?;
received_size += data.len() as u64;
if received_size % (1024 * 1024) == 0 || received_size == expected_size {
info!("Progress: {}/{} bytes", received_size, expected_size);
}
if received_size >= expected_size {
info!("File Transfer Complete: {}", current_filename);
return Ok(());
}
}
}
_ => {
debug!("Received other message: {:?}", msg);
}
}
}
Ok(())
}
pub async fn send_file(&mut self, filename: &str, data: &[u8]) -> Result<(), DeonError> {
self.state = ProtocolState::Streaming;
if data.len() > 64 * 1024 && self.transport.get_type() == TransportType::Ble {
info!("Payload > 64KB. Initiating Wi-Fi Handover...");
self.perform_wifi_handover().await?;
}
let file_header = ProtocolMessage::FileHeader {
filename: filename.to_string(),
size: data.len() as u64,
checksum: crc32fast::hash(data),
};
self.send_encrypted_message(&file_header).await?;
const CHUNK_SIZE: usize = 64 * 1024;
let mut offset = 0;
let mut chunk_hasher = crc32fast::Hasher::new();
for chunk in data.chunks(CHUNK_SIZE) {
chunk_hasher.update(chunk);
let _chunk_checksum = chunk_hasher.clone().finalize();
let chunk_msg = ProtocolMessage::FileChunk {
offset,
data: chunk.to_vec(),
};
self.send_encrypted_message(&chunk_msg).await?;
if self.transport.get_type() == TransportType::Ble {
}
offset += chunk.len() as u64;
}
self.state = ProtocolState::Idle;
Ok(())
}
async fn perform_wifi_handover(&mut self) -> Result<(), DeonError> {
let switch_msg = ProtocolMessage::SwitchToWifi {
ssid: "DeonSecureNet".to_string(),
ip: "192.168.1.50".to_string(),
port: 8080,
};
self.send_encrypted_message(&switch_msg).await?;
let ack = self.read_and_decrypt_message().await?;
if let ProtocolMessage::Ack { .. } = ack {
} else {
return Err(DeonError::HandshakeError);
}
let new_transport = crate::transport::connect_tcp("127.0.0.1:8080").await?;
self.transport = new_transport;
info!("Transport switched to Wi-Fi TCP");
Ok(())
}
async fn send_cleartext_message(&mut self, msg: ProtocolMessage) -> Result<(), DeonError> {
let payload = postcard::to_stdvec(&msg)?;
self.transport.send(&payload).await
}
async fn send_encrypted_message(&mut self, msg: &ProtocolMessage) -> Result<(), DeonError> {
let context = self.security_context.as_ref().ok_or(DeonError::InvalidState)?;
let payload = postcard::to_stdvec(msg)?;
let (ciphertext, nonce) = context.encrypt(&payload, &[])?;
let header = WireHeader {
magic: MAGIC_BYTES,
version: VERSION,
flags: FLAG_ENCRYPTED,
nonce,
};
let mut frame = Vec::with_capacity(HEADER_LEN + ciphertext.len());
frame.extend_from_slice(&header.to_bytes());
frame.extend_from_slice(&ciphertext);
let mut attempts = 0;
let mut delay = 50; let max_retries = 3;
loop {
match self.transport.send(&frame).await {
Ok(_) => return Ok(()),
Err(e) => {
attempts += 1;
if attempts >= max_retries {
debug!("Transport send failed after {} attempts. Giving up.", attempts);
return Err(e);
}
debug!("Transport send failed. Retrying in {}ms (Attempt {}/{})", delay, attempts, max_retries);
tokio::time::sleep(tokio::time::Duration::from_millis(delay)).await;
delay *= 2; }
}
}
}
async fn read_and_decrypt_message(&mut self) -> Result<ProtocolMessage, DeonError> {
let frame = self.transport.receive().await?;
if frame.len() < HEADER_LEN {
return Err(DeonError::ProtocolViolation);
}
let header = WireHeader::from_bytes(&frame).ok_or(DeonError::ProtocolViolation)?;
if header.flags & FLAG_ENCRYPTED == 0 {
return Err(DeonError::ProtocolViolation);
}
let context = self.security_context.as_ref().ok_or(DeonError::InvalidState)?;
let ciphertext = &frame[HEADER_LEN..];
let plaintext = context.decrypt(ciphertext, &header.nonce, &[])?;
let msg = postcard::from_bytes(&plaintext)?;
Ok(msg)
}
}