use std::io::{self, BufReader, Write};
use std::net::TcpStream;
use super::crypto;
use super::{MessageType, RelayMessage, epoch_ms, gen_msg_id};
/// Largest accepted size, in bytes, of one newline-delimited message (1 MiB).
///
/// `read_message` rejects any line that grows beyond this many bytes.
const MAX_LINE_SIZE: usize = 1024 * 1024;
/// Serializes `msg` to JSON and sends it on `stream` as a single line.
///
/// The JSON text is followed by exactly one `\n` terminator, and the stream is
/// flushed before returning.
///
/// # Errors
///
/// Returns an [`io::Error`] when serialization fails (message prefixed with
/// `serialize: `) or when writing to / flushing the stream fails.
pub fn write_message(stream: &mut TcpStream, msg: &RelayMessage) -> io::Result<()> {
    let mut frame = match serde_json::to_string(msg) {
        Ok(json) => json,
        Err(e) => return Err(io::Error::other(format!("serialize: {e}"))),
    };
    // Newline is the frame delimiter expected by the line-based reader.
    frame.push('\n');
    stream.write_all(frame.as_bytes())?;
    stream.flush()
}
/// Reads one newline-delimited JSON message from `reader`.
///
/// Bytes are consumed up to and including the next `\n` (or until end of
/// stream). The terminator is not part of the parsed text, and surrounding
/// whitespace is trimmed before JSON parsing.
///
/// # Returns
///
/// * `Ok(Some(msg))` when a line was read and parsed successfully.
/// * `Ok(None)` when the stream ended with no pending bytes, or when the line
///   was empty (a bare `\n`).
///
/// # Errors
///
/// Returns an [`io::Error`] when the underlying read fails, when a line grows
/// beyond `MAX_LINE_SIZE` bytes without a terminator, when the bytes are not
/// valid UTF-8 (`utf8: …`), or when the JSON cannot be parsed (`parse: …`).
pub fn read_message(reader: &mut BufReader<TcpStream>) -> io::Result<Option<RelayMessage>> {
    use std::io::{BufRead, Read};

    let mut line = Vec::with_capacity(4096);

    // Allow one byte more than the limit so that a maximum-length line plus its
    // terminator still fits, while an over-long line is detectable afterwards.
    // `read_until` scans the buffered data in bulk and retries reads that fail
    // with `ErrorKind::Interrupted` instead of surfacing them to the caller.
    let limit = MAX_LINE_SIZE as u64 + 1;
    reader.by_ref().take(limit).read_until(b'\n', &mut line)?;

    if line.last() == Some(&b'\n') {
        // Complete line: drop the delimiter.
        line.pop();
    } else if line.len() > MAX_LINE_SIZE {
        // The byte budget was exhausted before any terminator showed up.
        return Err(io::Error::other("message exceeds 1MB size limit"));
    }
    // Otherwise the stream ended; whatever was collected is treated as the line.

    if line.is_empty() {
        return Ok(None);
    }

    let text = String::from_utf8(line).map_err(|e| io::Error::other(format!("utf8: {e}")))?;
    let msg: RelayMessage =
        serde_json::from_str(text.trim()).map_err(|e| io::Error::other(format!("parse: {e}")))?;
    Ok(Some(msg))
}
/// Sends a `Challenge` message carrying a fresh nonce and returns that nonce.
///
/// The nonce comes from `crypto::random_hex(32)` and is placed in the payload
/// under the `"nonce"` key. `from_peer` is left empty.
///
/// # Errors
///
/// Propagates any error from [`write_message`].
pub fn send_challenge(stream: &mut TcpStream) -> io::Result<String> {
    let nonce = crypto::random_hex(32);
    let challenge = RelayMessage {
        id: gen_msg_id(),
        msg_type: MessageType::Challenge,
        from_peer: String::new(),
        timestamp: epoch_ms(),
        payload: serde_json::json!({ "nonce": nonce }),
    };
    // Hand the nonce back only once the challenge has actually been sent.
    write_message(stream, &challenge).map(|()| nonce)
}
/// Validates a peer's handshake against the challenge `nonce` and shared `psk`.
///
/// The message must be of type `Handshake` and carry a string `"proof"` field
/// in its payload equal to the hex-encoded HMAC-SHA256 of `nonce` keyed with
/// `psk`.
///
/// # Returns
///
/// The sender's `from_peer` identity when the proof matches.
///
/// # Errors
///
/// Returns a descriptive message when the message type is wrong, the proof
/// field is missing or not a string, or the proof does not match.
pub fn verify_handshake(msg: &RelayMessage, nonce: &str, psk: &[u8; 32]) -> Result<String, String> {
    // Best-effort constant-time equality: every byte pair is inspected, so the
    // time taken does not reveal the position of the first mismatch. Only the
    // length, which is not secret for a hex-encoded MAC, short-circuits.
    fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
        a.len() == b.len() && a.iter().zip(b).fold(0u8, |acc, (x, y)| acc | (x ^ y)) == 0
    }

    if msg.msg_type != MessageType::Handshake {
        return Err("expected handshake message".into());
    }
    let proof = msg
        .payload
        .get("proof")
        .and_then(|v| v.as_str())
        .ok_or("missing proof field")?;
    let expected = crypto::hmac_sha256(psk, nonce.as_bytes());
    let expected_hex = crypto::hex_encode(&expected);
    // The proof is supplied by the remote peer, so avoid an early-exit `!=`.
    if !constant_time_eq(proof.as_bytes(), expected_hex.as_bytes()) {
        return Err("HMAC verification failed".into());
    }
    Ok(msg.from_peer.clone())
}
/// Sends a `HandshakeAck` message from `identity` with the given `status`.
///
/// The status string is placed in the payload under the `"status"` key.
///
/// # Errors
///
/// Propagates any error from [`write_message`].
pub fn send_handshake_ack(stream: &mut TcpStream, identity: &str, status: &str) -> io::Result<()> {
    let ack = RelayMessage {
        id: gen_msg_id(),
        msg_type: MessageType::HandshakeAck,
        from_peer: identity.to_owned(),
        timestamp: epoch_ms(),
        payload: serde_json::json!({ "status": status }),
    };
    write_message(stream, &ack)
}
/// Computes the handshake proof for `nonce`: the HMAC-SHA256 of the nonce
/// bytes keyed with `psk`, rendered through `crypto::hex_encode`.
pub fn compute_proof(nonce: &str, psk: &[u8; 32]) -> String {
    crypto::hex_encode(&crypto::hmac_sha256(psk, nonce.as_bytes()))
}
/// Answers a challenge by sending a `Handshake` message from `identity`.
///
/// The payload carries the proof computed by [`compute_proof`] for `nonce` and
/// `psk` under `"proof"`, plus this crate's version (`CARGO_PKG_VERSION`)
/// under `"version"`.
///
/// # Errors
///
/// Propagates any error from [`write_message`].
pub fn send_handshake(
    stream: &mut TcpStream,
    identity: &str,
    nonce: &str,
    psk: &[u8; 32],
) -> io::Result<()> {
    let payload = serde_json::json!({
        "proof": compute_proof(nonce, psk),
        "version": env!("CARGO_PKG_VERSION"),
    });
    let handshake = RelayMessage {
        id: gen_msg_id(),
        msg_type: MessageType::Handshake,
        from_peer: identity.to_owned(),
        timestamp: epoch_ms(),
        payload,
    };
    write_message(stream, &handshake)
}
/// Waits for the peer's `HandshakeAck` and returns the peer identity on success.
///
/// # Returns
///
/// The acknowledging message's `from_peer` when its payload `"status"` is
/// `"ok"`.
///
/// # Errors
///
/// Returns a descriptive message when reading fails, when no message is
/// available (`read_message` returned `None`), when the message is not a
/// `HandshakeAck`, or when the status is anything other than `"ok"` (a missing
/// or non-string status is reported as `unknown`).
pub fn await_handshake_ack(reader: &mut BufReader<TcpStream>) -> Result<String, String> {
    let reply = match read_message(reader) {
        Ok(Some(reply)) => reply,
        Ok(None) => return Err("connection closed before ack".into()),
        Err(e) => return Err(format!("read ack: {e}")),
    };

    if reply.msg_type != MessageType::HandshakeAck {
        return Err(format!("expected handshake_ack, got {:?}", reply.msg_type));
    }

    match reply.payload.get("status").and_then(|v| v.as_str()) {
        Some("ok") => Ok(reply.from_peer),
        other => Err(format!(
            "handshake denied: {}",
            other.unwrap_or("unknown")
        )),
    }
}
/// Builds a `Heartbeat` message from `identity` with an empty JSON object as
/// payload, a fresh message id, and the current `epoch_ms()` timestamp.
pub fn heartbeat_message(identity: &str) -> RelayMessage {
    let from_peer = identity.to_owned();
    RelayMessage {
        id: gen_msg_id(),
        msg_type: MessageType::Heartbeat,
        from_peer,
        timestamp: epoch_ms(),
        payload: serde_json::json!({}),
    }
}
/// Builds a `Heartbeat` message from `identity` that also reports `sessions`.
///
/// The payload contains `"worker_id"` (the identity), `"timestamp"`, and
/// `"sessions"` (the given values as a JSON array). The clock is sampled once,
/// so the envelope timestamp and the payload timestamp always agree.
pub fn heartbeat_with_sessions(identity: &str, sessions: &[serde_json::Value]) -> RelayMessage {
    let id = gen_msg_id();
    // Read the clock a single time; two separate reads could straddle a
    // millisecond boundary and yield mismatched timestamps in one message.
    let now = epoch_ms();
    RelayMessage {
        id,
        msg_type: MessageType::Heartbeat,
        from_peer: identity.to_string(),
        timestamp: now,
        payload: serde_json::json!({
            "worker_id": identity,
            "timestamp": now,
            "sessions": sessions,
        }),
    }
}
/// Builds an `Ack` message from `identity` acknowledging `original_id`.
///
/// The acknowledged id is placed in the payload under the `"ack_id"` key.
pub fn ack_message(identity: &str, original_id: &str) -> RelayMessage {
    let from_peer = identity.to_owned();
    RelayMessage {
        id: gen_msg_id(),
        msg_type: MessageType::Ack,
        from_peer,
        timestamp: epoch_ms(),
        payload: serde_json::json!({ "ack_id": original_id }),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a fixed test message from `client-1` with the given type and payload.
    fn peer_message(msg_type: MessageType, payload: serde_json::Value) -> RelayMessage {
        RelayMessage {
            id: "test".into(),
            msg_type,
            from_peer: "client-1".into(),
            timestamp: 0,
            payload,
        }
    }

    /// A proof computed with the right key and nonce is accepted and yields the peer id.
    #[test]
    fn challenge_response_flow() {
        let psk = crypto::generate_psk();
        let nonce = crypto::random_hex(32);
        let proof = compute_proof(&nonce, &psk);
        let msg = peer_message(
            MessageType::Handshake,
            serde_json::json!({ "proof": proof, "version": "0.35.0" }),
        );

        let outcome = verify_handshake(&msg, &nonce, &psk);

        assert_eq!(outcome, Ok("client-1".to_string()));
    }

    /// A proof that does not match the expected HMAC is rejected.
    #[test]
    fn bad_proof_rejected() {
        let psk = crypto::generate_psk();
        let nonce = crypto::random_hex(32);
        let msg = peer_message(
            MessageType::Handshake,
            serde_json::json!({ "proof": "deadbeef", "version": "0.35.0" }),
        );

        assert!(verify_handshake(&msg, &nonce, &psk).is_err());
    }

    /// Even with a valid proof, a non-handshake message type is rejected.
    #[test]
    fn wrong_message_type_rejected() {
        let psk = crypto::generate_psk();
        let nonce = crypto::random_hex(32);
        let proof = compute_proof(&nonce, &psk);
        let msg = peer_message(
            MessageType::Heartbeat,
            serde_json::json!({ "proof": proof }),
        );

        assert!(verify_handshake(&msg, &nonce, &psk).is_err());
    }

    /// The plain heartbeat carries the right type, sender, and a non-zero timestamp.
    #[test]
    fn heartbeat_message_valid() {
        let hb = heartbeat_message("test-peer");

        assert_eq!(hb.msg_type, MessageType::Heartbeat);
        assert_eq!(hb.from_peer, "test-peer");
        assert!(hb.timestamp > 0);
    }

    /// The ack message records the acknowledged id in its payload.
    #[test]
    fn ack_message_valid() {
        let ack = ack_message("test-peer", "msg_123");

        assert_eq!(ack.msg_type, MessageType::Ack);
        let ack_id = ack.payload.get("ack_id").and_then(|v| v.as_str());
        assert_eq!(ack_id, Some("msg_123"));
    }

    /// A heartbeat with sessions exposes the worker id and every session entry.
    #[test]
    fn heartbeat_with_sessions_includes_payload() {
        let sessions = vec![
            serde_json::json!({"pid": 1234, "project": "backend", "status": "Processing"}),
            serde_json::json!({"pid": 5678, "project": "frontend", "status": "Idle"}),
        ];

        let hb = heartbeat_with_sessions("worker-01", &sessions);

        assert_eq!(hb.msg_type, MessageType::Heartbeat);
        assert_eq!(hb.from_peer, "worker-01");
        let worker_id = hb.payload.get("worker_id").and_then(|v| v.as_str());
        assert_eq!(worker_id, Some("worker-01"));
        let session_count = hb
            .payload
            .get("sessions")
            .and_then(|v| v.as_array())
            .map(|entries| entries.len());
        assert_eq!(session_count, Some(2));
    }

    /// An empty session list still produces a (zero-length) `sessions` array.
    #[test]
    fn heartbeat_with_empty_sessions() {
        let hb = heartbeat_with_sessions("worker-02", &[]);

        let session_count = hb
            .payload
            .get("sessions")
            .and_then(|v| v.as_array())
            .map(|entries| entries.len());
        assert_eq!(session_count, Some(0));
    }
}