use anyhow::{Context, Result};
use bytes::{Buf, BufMut};
use prost::Message as ProstMessage;
use prost::encoding::{
DecodeContext, WireType,
bytes::{encode as encode_bytes, encoded_len as len_bytes, merge as merge_bytes},
int32::{encode as encode_int32, encoded_len as len_int32, merge as merge_int32},
skip_field,
string::{encode as encode_string, encoded_len as len_string, merge as merge_string},
uint64::{encode as encode_u64, encoded_len as len_u64, merge as merge_uint64},
};
use rand::TryRngCore;
use rand::rngs::OsRng;
#[cfg(feature = "ser")]
use serde::{Deserialize, Serialize};
use std::time::{SystemTime, UNIX_EPOCH};
use tracing::{debug, instrument};
/// Kind tag carried by every [`Message`].
///
/// Each variant has an explicit integer discriminant; that integer is what
/// [`MessageType::as_i32`] returns and what [`MessageType::from_i32`] accepts.
/// `Unknown` (0) is the default and doubles as the fallback for any value or
/// name that is not recognised.
#[cfg_attr(feature = "ser", derive(Serialize, Deserialize))]
#[derive(Clone, Copy, PartialEq, Eq, Debug, Default)]
pub enum MessageType {
    #[default]
    Unknown = 0,
    Ping = 1,
    Broadcast = 2,
    FileTransfer = 3,
    Command = 4,
    DelegateTask = 5,
    RegisterKey = 6,
    Create = 7,
    Terminate = 8,
    Run = 9,
    Reply = 10,
}

impl MessageType {
    /// Maps an integer discriminant back to its variant.
    ///
    /// Any value that is not the discriminant of a named variant (including
    /// `0`, negatives and anything above `10`) yields [`MessageType::Unknown`].
    pub fn from_i32(value: i32) -> Self {
        // `Unknown` is deliberately absent: it is the fallback below.
        const KNOWN: [MessageType; 10] = [
            MessageType::Ping,
            MessageType::Broadcast,
            MessageType::FileTransfer,
            MessageType::Command,
            MessageType::DelegateTask,
            MessageType::RegisterKey,
            MessageType::Create,
            MessageType::Terminate,
            MessageType::Run,
            MessageType::Reply,
        ];

        // Comparing against each variant's own discriminant keeps this lookup
        // independent of the order in which `KNOWN` is written.
        KNOWN
            .iter()
            .copied()
            .find(|kind| kind.as_i32() == value)
            .unwrap_or(MessageType::Unknown)
    }

    /// Returns the variant's integer discriminant.
    pub fn as_i32(&self) -> i32 {
        let kind: MessageType = *self;
        kind as i32
    }
}

/// Case-insensitive (ASCII) lookup of a variant by name.
///
/// Multi-word variants accept both the run-together and the `snake_case`
/// spelling. Anything else, including the empty string, becomes
/// [`MessageType::Unknown`].
impl From<&str> for MessageType {
    fn from(value: &str) -> Self {
        const ALIASES: [(&str, MessageType); 13] = [
            ("ping", MessageType::Ping),
            ("broadcast", MessageType::Broadcast),
            ("filetransfer", MessageType::FileTransfer),
            ("file_transfer", MessageType::FileTransfer),
            ("command", MessageType::Command),
            ("delegatetask", MessageType::DelegateTask),
            ("delegate_task", MessageType::DelegateTask),
            ("registerkey", MessageType::RegisterKey),
            ("register_key", MessageType::RegisterKey),
            ("create", MessageType::Create),
            ("terminate", MessageType::Terminate),
            ("run", MessageType::Run),
            ("reply", MessageType::Reply),
        ];

        ALIASES
            .iter()
            .find(|(alias, _)| value.eq_ignore_ascii_case(alias))
            .map_or(MessageType::Unknown, |&(_, kind)| kind)
    }
}
/// Message envelope that is encoded to and decoded from protobuf wire format
/// by the hand-written `ProstMessage` impl in this file.
///
/// Fields map to wire tags 1 through 9 in declaration order.
#[cfg_attr(feature = "ser", derive(Serialize, Deserialize))]
#[derive(Clone, PartialEq, Debug, Default)]
pub struct Message {
/// Wire tag 1.
/// NOTE(review): presumably the sender's identifier; format is not
/// constrained anywhere in this file, so confirm with callers.
pub from: String,
/// Wire tag 2. `Message::broadcast` leaves this empty.
/// NOTE(review): presumably the recipient's identifier; confirm with callers.
pub to: String,
/// Wire tag 3, carried as an `int32` discriminant. Unrecognised values
/// decode to `MessageType::Unknown`.
pub msg_type: MessageType,
/// Wire tag 4. Stored and transmitted verbatim; this file never parses or
/// validates it (the name suggests JSON text, which is the caller's job).
pub payload_json: String,
/// Wire tag 5. The constructors in this file set it to whole seconds since
/// the Unix epoch at creation time.
pub timestamp: u64,
/// Wire tag 6. The constructors in this file fill it with a random `u64`
/// drawn from the operating-system RNG.
pub msg_id: u64,
/// Wire tag 7. `Message::new` sets it to 0; the other constructors take it
/// from the caller.
pub session_id: u64,
/// Wire tag 8. Empty until `Message::sign` fills it. It is blanked before
/// the bytes that are signed or verified get produced.
pub signature: Vec<u8>,
/// Wire tag 9. Left empty by every constructor in this file. Unlike
/// `signature`, it is part of the bytes that are signed and verified.
pub extra_data: Vec<u8>,
}
impl ProstMessage for Message {
fn encode_raw(&self, buf: &mut impl BufMut) {
encode_string(1, &self.from, buf);
encode_string(2, &self.to, buf);
encode_int32(3, &self.msg_type.as_i32(), buf);
encode_string(4, &self.payload_json, buf);
encode_u64(5, &self.timestamp, buf);
encode_u64(6, &self.msg_id, buf);
encode_u64(7, &self.session_id, buf);
encode_bytes(8, &self.signature, buf);
encode_bytes(9, &self.extra_data, buf);
}
fn merge_field(
&mut self,
tag: u32,
wire_type: WireType,
buf: &mut impl Buf,
ctx: DecodeContext,
) -> core::result::Result<(), prost::DecodeError> {
match tag {
1 => merge_string(wire_type, &mut self.from, buf, ctx),
2 => merge_string(wire_type, &mut self.to, buf, ctx),
3 => {
let mut raw = 0i32;
merge_int32(wire_type, &mut raw, buf, ctx)?;
self.msg_type = MessageType::from_i32(raw);
Ok(())
}
4 => merge_string(wire_type, &mut self.payload_json, buf, ctx),
5 => merge_uint64(wire_type, &mut self.timestamp, buf, ctx),
6 => merge_uint64(wire_type, &mut self.msg_id, buf, ctx),
7 => merge_uint64(wire_type, &mut self.session_id, buf, ctx),
8 => merge_bytes(wire_type, &mut self.signature, buf, ctx),
9 => merge_bytes(wire_type, &mut self.extra_data, buf, ctx),
_ => skip_field(wire_type, tag, buf, ctx),
}
}
fn encoded_len(&self) -> usize {
len_string(1, &self.from)
+ len_string(2, &self.to)
+ len_int32(3, &self.msg_type.as_i32())
+ len_string(4, &self.payload_json)
+ len_u64(5, &self.timestamp)
+ len_u64(6, &self.msg_id)
+ len_u64(7, &self.session_id)
+ len_bytes(8, &self.signature)
+ len_bytes(9, &self.extra_data)
}
fn clear(&mut self) {
*self = Self::default();
}
}
impl Message {
    /// Creates an unsigned message of the given type.
    ///
    /// `timestamp` is set to the current time in whole seconds since the Unix
    /// epoch, `msg_id` to a random value from the OS RNG, and `session_id`
    /// to 0. `signature` and `extra_data` start empty.
    ///
    /// # Panics
    ///
    /// Panics if the operating-system RNG reports an error.
    pub fn new(from: &str, to: &str, msg_type: MessageType, payload_json: &str) -> Self {
        Self::from_parts(from, to, msg_type, payload_json, 0)
    }

    /// Encodes the message, signature included, to protobuf bytes.
    ///
    /// # Errors
    ///
    /// Returns an error if the underlying encoder reports a failure.
    #[instrument(skip_all, fields(msg_id = self.msg_id, msg_type = ?self.msg_type))]
    pub fn serialize(&self) -> Result<Vec<u8>> {
        let mut buf = Vec::with_capacity(self.encoded_len());
        self.encode(&mut buf)
            .map_err(|e| anyhow::anyhow!("Failed to encode Message: {}", e))?;
        debug!(bytes = buf.len(), "✅ Message serialized");
        Ok(buf)
    }

    /// Decodes a message from protobuf bytes.
    ///
    /// Fields with unrecognised tags are skipped, and an unrecognised
    /// `msg_type` number is stored as `MessageType::Unknown`.
    ///
    /// # Errors
    ///
    /// Returns an error if `bytes` is not a valid encoding.
    #[instrument(skip_all)]
    pub fn deserialize(bytes: &[u8]) -> Result<Self> {
        let msg = Message::decode(bytes)
            .map_err(|e| anyhow::anyhow!("Failed to decode Message: {}", e))?;
        debug!(
            msg_id = msg.msg_id,
            msg_type = ?msg.msg_type,
            "📥 Message deserialized"
        );
        Ok(msg)
    }

    /// Signs the message and stores the result in `signature`.
    ///
    /// The signed bytes are the encoding of this message with `signature`
    /// blanked, so any previous signature is ignored and then overwritten.
    /// `self` is left untouched if serialization or signing fails.
    ///
    /// # Errors
    ///
    /// Returns an error if the message cannot be serialized or if `signer`
    /// fails.
    #[instrument(skip_all, fields(msg_id = self.msg_id))]
    pub fn sign(&mut self, signer: &crate::crypto::Signer) -> Result<()> {
        let data = self
            .unsigned_bytes()
            .context("Failed to serialize message for signing")?;
        self.signature = signer.sign(&data).context("Failed to sign message")?;
        debug!(sig_len = self.signature.len(), "✍️ Message signed");
        Ok(())
    }

    /// Checks `signature` against the message contents.
    ///
    /// The message is re-encoded locally (with `signature` blanked) and that
    /// encoding is what gets verified. Verification therefore only succeeds
    /// when the signer signed exactly the bytes this encoder produces. A
    /// message that arrived with an unrecognised `msg_type` number or with
    /// extra unknown fields will not re-encode to its original bytes.
    ///
    /// # Errors
    ///
    /// Returns an error if the message cannot be serialized or if `verifier`
    /// rejects the signature.
    #[instrument(skip_all, fields(msg_id = self.msg_id))]
    pub fn verify(&self, verifier: &crate::crypto::Verifier) -> Result<()> {
        let data = self
            .unsigned_bytes()
            .context("Failed to serialize message for verification")?;
        verifier
            .verify(&data, &self.signature)
            .context("Signature verification failed")?;
        debug!("🔐 Message signature verified");
        Ok(())
    }

    /// Creates an unsigned `Ping` message with an empty payload.
    ///
    /// # Panics
    ///
    /// Panics if the operating-system RNG reports an error.
    #[instrument(skip_all, fields(from = from, to = to))]
    pub fn ping(from: &str, to: &str, session_id: u64) -> Self {
        let msg = Self::from_parts(from, to, MessageType::Ping, "", session_id);
        debug!(
            msg_id = msg.msg_id,
            msg_type = ?msg.msg_type,
            "📡 Created PING message"
        );
        msg
    }

    /// Creates an unsigned `Broadcast` message; `to` is left empty.
    ///
    /// # Panics
    ///
    /// Panics if the operating-system RNG reports an error.
    #[instrument(skip_all, fields(from = from))]
    pub fn broadcast(from: &str, payload_json: &str, session_id: u64) -> Self {
        let msg = Self::from_parts(from, "", MessageType::Broadcast, payload_json, session_id);
        debug!(
            msg_id = msg.msg_id,
            msg_type = ?msg.msg_type,
            payload_len = payload_json.len(),
            "📢 Created BROADCAST message"
        );
        msg
    }

    /// Creates an unsigned `Reply` message addressed to `to`.
    ///
    /// # Panics
    ///
    /// Panics if the operating-system RNG reports an error.
    // The span records `to` and a creation event is emitted so that replies
    // are traced the same way as pings and broadcasts.
    #[instrument(skip_all, fields(from = from, to = to))]
    pub fn reply(from: &str, to: &str, payload_json: &str, session_id: u64) -> Self {
        let msg = Self::from_parts(from, to, MessageType::Reply, payload_json, session_id);
        debug!(
            msg_id = msg.msg_id,
            msg_type = ?msg.msg_type,
            payload_len = payload_json.len(),
            "↩️ Created REPLY message"
        );
        msg
    }

    /// Single place where a fresh, unsigned message is assembled; every
    /// public constructor funnels through here so they cannot drift apart.
    fn from_parts(
        from: &str,
        to: &str,
        msg_type: MessageType,
        payload_json: &str,
        session_id: u64,
    ) -> Self {
        Self {
            from: from.to_string(),
            to: to.to_string(),
            msg_type,
            payload_json: payload_json.to_string(),
            timestamp: curr_time(),
            msg_id: gen_msg_id(),
            session_id,
            signature: Vec::new(),
            extra_data: Vec::new(),
        }
    }

    /// Encodes a copy of the message with `signature` blanked.
    ///
    /// Shared by `sign` and `verify` so both always operate on the same
    /// bytes.
    fn unsigned_bytes(&self) -> Result<Vec<u8>> {
        let unsigned = Message {
            signature: Vec::new(),
            ..self.clone()
        };
        unsigned.serialize()
    }
}
/// Returns the current wall-clock time as whole seconds since the Unix epoch.
///
/// If the system clock reads earlier than the epoch, 0 is returned instead of
/// an error.
fn curr_time() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_secs(),
        // Clock is set before 1970: fall back to the zero duration.
        Err(_) => 0,
    }
}
/// Draws a random 64-bit message id from the operating-system RNG.
///
/// # Panics
///
/// Panics if the OS RNG returns an error.
fn gen_msg_id() -> u64 {
    let mut os_rng = OsRng;
    let drawn = os_rng.try_next_u64();
    drawn.expect("Secure RNG failed to initialize")
}