use alloc::{
string::{String, ToString},
vec::Vec,
};
use base64::{Engine as _, engine::general_purpose::URL_SAFE};
use serde::{Serialize, de::DeserializeOwned};
use uuid::Uuid;
use super::{crypto::MythicCrypto, error::MythicMessageError};
/// Byte length of a hyphenated UUID header: 32 hex digits plus 4 hyphens.
pub const MYTHIC_UUID_LEN: usize = 32 + 4;
/// Byte length of a raw binary UUID header (128 bits).
pub const MYTHIC_UUID_BIN_LEN: usize = 128 / 8;
/// How the UUID header at the front of a packet is laid out.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum UuidEncoding {
    /// ASCII hyphenated form, `MYTHIC_UUID_LEN` bytes long.
    Hyphenated,
    /// Raw bytes as returned by `Uuid::as_bytes`, `MYTHIC_UUID_BIN_LEN` bytes long.
    // The public encode/decode functions only construct `Hyphenated`; this
    // variant is currently constructed only by the test suite. Suppress the
    // lint for non-test builds only, so test builds still warn if nothing
    // exercises it any more.
    #[cfg_attr(not(test), allow(dead_code))]
    Binary,
}
/// Builds `header || payload`, where the header is `uuid` encoded per `encoding`.
///
/// No separator is inserted. The header is the hyphenated UUID string's bytes
/// for [`UuidEncoding::Hyphenated`], or the raw UUID bytes for
/// [`UuidEncoding::Binary`].
fn build_packet(uuid: Uuid, payload: &[u8], encoding: UuidEncoding) -> Vec<u8> {
    // Owns the formatted UUID so `header` can borrow from it below.
    let text;
    let header: &[u8] = match encoding {
        UuidEncoding::Hyphenated => {
            text = uuid.hyphenated().to_string();
            text.as_bytes()
        }
        UuidEncoding::Binary => uuid.as_bytes(),
    };
    // `concat` allocates exactly `header.len() + payload.len()` bytes once.
    [header, payload].concat()
}
/// Splits `packet` into its UUID header and the payload that follows.
///
/// The returned payload borrows from `packet`. When `expected_uuid` is
/// `Some`, the parsed UUID must equal it.
///
/// # Errors
/// - [`MythicMessageError::InvalidPacket`] if `packet` is shorter than the header.
/// - [`MythicMessageError::Utf8`] if a hyphenated header is not valid UTF-8.
/// - [`MythicMessageError::InvalidUuid`] if a hyphenated header does not parse.
/// - [`MythicMessageError::UuidMismatch`] if the UUID differs from `expected_uuid`.
fn parse_packet<'a>(
    packet: &'a [u8],
    expected_uuid: Option<Uuid>,
    encoding: UuidEncoding,
) -> Result<(Uuid, &'a [u8]), MythicMessageError> {
    let header_len = match encoding {
        UuidEncoding::Binary => MYTHIC_UUID_BIN_LEN,
        UuidEncoding::Hyphenated => MYTHIC_UUID_LEN,
    };
    let Some(header) = packet.get(..header_len) else {
        return Err(MythicMessageError::InvalidPacket);
    };
    let body = &packet[header_len..];

    let parsed = match encoding {
        UuidEncoding::Binary => {
            let mut raw = [0u8; MYTHIC_UUID_BIN_LEN];
            raw.copy_from_slice(header);
            Uuid::from_bytes(raw)
        }
        UuidEncoding::Hyphenated => {
            let text = core::str::from_utf8(header).map_err(|_| MythicMessageError::Utf8)?;
            Uuid::parse_str(text).map_err(|_| MythicMessageError::InvalidUuid)?
        }
    };

    match expected_uuid {
        Some(want) if want != parsed => Err(MythicMessageError::UuidMismatch),
        _ => Ok((parsed, body)),
    }
}
/// Decodes `packed` with the `URL_SAFE` base64 engine, ignoring leading and
/// trailing whitespace.
///
/// # Errors
/// [`MythicMessageError::Base64Decode`] if the trimmed text is not valid for the engine.
fn base64_decode(packed: &str) -> Result<Vec<u8>, MythicMessageError> {
    let text = packed.trim();
    let decoded = URL_SAFE.decode(text);
    decoded.map_err(|_| MythicMessageError::Base64Decode)
}
/// Encodes `data` with the `URL_SAFE` base64 engine.
fn base64_encode(data: &[u8]) -> String {
    let mut out = String::new();
    URL_SAFE.encode_string(data, &mut out);
    out
}
/// Serializes `msg` to JSON, encrypts it with `crypto`, prefixes the
/// hyphenated `uuid`, and base64-encodes the resulting packet.
///
/// # Errors
/// - [`MythicMessageError::Serialize`] if JSON serialization fails.
/// - Whatever error `crypto.encrypt` returns.
pub fn encode_message<T: Serialize>(
    msg: &T,
    uuid: Uuid,
    crypto: &impl MythicCrypto,
) -> Result<String, MythicMessageError> {
    let ciphertext = serde_json::to_vec(msg)
        .map_err(|_| MythicMessageError::Serialize)
        .and_then(|json| crypto.encrypt(&json))?;
    let packet = build_packet(uuid, &ciphertext, UuidEncoding::Hyphenated);
    Ok(base64_encode(&packet))
}
/// Serializes `msg` to JSON, prefixes the hyphenated `uuid`, and
/// base64-encodes the packet without any encryption.
///
/// # Errors
/// [`MythicMessageError::Serialize`] if JSON serialization fails.
pub fn encode_message_plain<T: Serialize>(
    msg: &T,
    uuid: Uuid,
) -> Result<String, MythicMessageError> {
    serde_json::to_vec(msg)
        .map(|json| base64_encode(&build_packet(uuid, &json, UuidEncoding::Hyphenated)))
        .map_err(|_| MythicMessageError::Serialize)
}
/// Reverses [`encode_message`]: base64-decodes `packed`, splits off the
/// hyphenated UUID header, decrypts the rest with `crypto`, and parses it as JSON.
///
/// Returns the UUID found in the header together with the decoded message.
///
/// # Errors
/// - Any error from base64 decoding or header parsing, including
///   [`MythicMessageError::UuidMismatch`] when `expected_uuid` does not match.
/// - Whatever error `crypto.decrypt` returns.
/// - [`MythicMessageError::Deserialize`] if the plaintext is not valid JSON for `T`.
pub fn decode_message<T: DeserializeOwned>(
    packed: &str,
    expected_uuid: Option<Uuid>,
    crypto: &impl MythicCrypto,
) -> Result<(Uuid, T), MythicMessageError> {
    let raw = base64_decode(packed)?;
    let (sender, ciphertext) = parse_packet(&raw, expected_uuid, UuidEncoding::Hyphenated)?;
    let plaintext = crypto.decrypt(ciphertext)?;
    serde_json::from_slice::<T>(&plaintext)
        .map(|msg| (sender, msg))
        .map_err(|_| MythicMessageError::Deserialize)
}
/// Reverses [`encode_message_plain`]: base64-decodes `packed`, splits off the
/// hyphenated UUID header, and parses the remaining bytes as JSON.
///
/// # Errors
/// - Any error from base64 decoding or header parsing, including
///   [`MythicMessageError::UuidMismatch`] when `expected_uuid` does not match.
/// - [`MythicMessageError::Deserialize`] if the payload is not valid JSON for `T`.
pub fn decode_message_plain<T: DeserializeOwned>(
    packed: &str,
    expected_uuid: Option<Uuid>,
) -> Result<(Uuid, T), MythicMessageError> {
    let raw = base64_decode(packed)?;
    parse_packet(&raw, expected_uuid, UuidEncoding::Hyphenated).and_then(|(sender, json)| {
        serde_json::from_slice::<T>(json)
            .map(|msg| (sender, msg))
            .map_err(|_| MythicMessageError::Deserialize)
    })
}
/// Wire-format convenience methods for serde-compatible Mythic messages.
///
/// Every method delegates to the corresponding free function in this module,
/// so trait calls and free-function calls produce identical results. The
/// blanket impl following the trait gives these methods to every
/// `Serialize + DeserializeOwned` type.
pub trait MythicMessage: Serialize + DeserializeOwned + Sized {
    /// Encrypts `self` with `crypto` and packs it for `uuid`; see [`encode_message`].
    fn to_wire(
        &self,
        uuid: Uuid,
        crypto: &impl MythicCrypto,
    ) -> Result<String, MythicMessageError> {
        encode_message(self, uuid, crypto)
    }

    /// Unpacks and decrypts `packed`; see [`decode_message`].
    fn from_wire(
        packed: &str,
        expected_uuid: Option<Uuid>,
        crypto: &impl MythicCrypto,
    ) -> Result<(Uuid, Self), MythicMessageError> {
        decode_message(packed, expected_uuid, crypto)
    }

    /// Packs `self` for `uuid` without encryption; see [`encode_message_plain`].
    fn to_wire_plain(&self, uuid: Uuid) -> Result<String, MythicMessageError> {
        encode_message_plain(self, uuid)
    }

    /// Unpacks an unencrypted `packed` message; see [`decode_message_plain`].
    fn from_wire_plain(
        packed: &str,
        expected_uuid: Option<Uuid>,
    ) -> Result<(Uuid, Self), MythicMessageError> {
        decode_message_plain(packed, expected_uuid)
    }
}

impl<T> MythicMessage for T where T: Serialize + DeserializeOwned {}
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::{string::ToString, vec, vec::Vec};
    use serde::{Serialize, Serializer};

    use crate::protocol::staging::{ReqStagingRSA, ReqStagingTranslation};

    /// Test "cipher" that reverses bytes; `decrypt` undoes `encrypt`.
    struct ReverseCrypto;

    impl MythicCrypto for ReverseCrypto {
        fn encrypt(&self, plaintext: &[u8]) -> Result<Vec<u8>, MythicMessageError> {
            let mut out = plaintext.to_vec();
            out.reverse();
            Ok(out)
        }

        fn decrypt(&self, ciphertext: &[u8]) -> Result<Vec<u8>, MythicMessageError> {
            let mut out = ciphertext.to_vec();
            out.reverse();
            Ok(out)
        }
    }

    /// Crypto stub whose operations always fail with `Crypto`.
    struct FailingCrypto;

    impl MythicCrypto for FailingCrypto {
        fn encrypt(&self, _plaintext: &[u8]) -> Result<Vec<u8>, MythicMessageError> {
            Err(MythicMessageError::Crypto)
        }

        fn decrypt(&self, _ciphertext: &[u8]) -> Result<Vec<u8>, MythicMessageError> {
            Err(MythicMessageError::Crypto)
        }
    }

    /// Type whose serialization always fails.
    struct BrokenSerialize;

    impl Serialize for BrokenSerialize {
        fn serialize<S>(&self, _serializer: S) -> Result<S::Ok, S::Error>
        where
            S: Serializer,
        {
            Err(serde::ser::Error::custom("broken"))
        }
    }

    #[test]
    fn plain_roundtrip() {
        let uuid = Uuid::nil();
        let req = ReqStagingRSA::new("pub".to_string(), "sid".to_string());
        let packed = req.to_wire_plain(uuid).unwrap();
        let (decoded_uuid, decoded_req) =
            ReqStagingRSA::from_wire_plain(&packed, Some(uuid)).unwrap();
        assert_eq!(decoded_uuid, uuid);
        assert_eq!(decoded_req, req);
    }

    #[test]
    fn encrypted_roundtrip() {
        let uuid = Uuid::nil();
        let req = ReqStagingTranslation::new(
            "sid".to_string(),
            "enc".to_string(),
            "dec".to_string(),
            "aes".to_string(),
            uuid,
            "hello".to_string(),
        );
        let packed = req.to_wire(uuid, &ReverseCrypto).unwrap();
        let (decoded_uuid, decoded_req) =
            ReqStagingTranslation::from_wire(&packed, Some(uuid), &ReverseCrypto).unwrap();
        assert_eq!(decoded_uuid, uuid);
        assert_eq!(decoded_req, req);
    }

    #[test]
    fn encoding_error_paths() {
        let uuid = Uuid::nil();
        assert!(matches!(
            decode_message::<ReqStagingRSA>("!!!", None, &ReverseCrypto),
            Err(MythicMessageError::Base64Decode)
        ));
        assert!(matches!(
            decode_message_plain::<ReqStagingRSA>(&URL_SAFE.encode(b"short"), None),
            Err(MythicMessageError::InvalidPacket)
        ));

        let mut bad = vec![0xff; MYTHIC_UUID_LEN];
        bad.extend_from_slice(b"payload");
        assert!(matches!(
            parse_packet(&bad, None, UuidEncoding::Hyphenated),
            Err(MythicMessageError::Utf8)
        ));

        // The malformed header must be exactly MYTHIC_UUID_LEN bytes; a shorter
        // fixture would make the parser read part of the payload as header.
        let invalid_header: &[u8] = b"xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx";
        assert_eq!(invalid_header.len(), MYTHIC_UUID_LEN);
        let mut invalid_uuid = invalid_header.to_vec();
        invalid_uuid.extend_from_slice(b"payload");
        assert!(matches!(
            parse_packet(&invalid_uuid, None, UuidEncoding::Hyphenated),
            Err(MythicMessageError::InvalidUuid)
        ));

        let ok_packet =
            encode_message_plain(&ReqStagingRSA::new("p".into(), "s".into()), uuid).unwrap();
        let other = Uuid::from_u128(7);
        assert!(matches!(
            decode_message_plain::<ReqStagingRSA>(&ok_packet, Some(other)),
            Err(MythicMessageError::UuidMismatch)
        ));
    }

    #[test]
    fn serialization_and_crypto_errors() {
        let uuid = Uuid::nil();
        assert!(matches!(
            encode_message(&BrokenSerialize, uuid, &ReverseCrypto),
            Err(MythicMessageError::Serialize)
        ));
        assert!(matches!(
            encode_message_plain(&BrokenSerialize, uuid),
            Err(MythicMessageError::Serialize)
        ));

        let mut packet = b"00000000-0000-0000-0000-000000000000".to_vec();
        packet.extend_from_slice(b"not-json");
        let encoded = base64_encode(&packet);
        assert!(matches!(
            decode_message_plain::<ReqStagingRSA>(&encoded, None),
            Err(MythicMessageError::Deserialize)
        ));

        let req = ReqStagingRSA::new("p".into(), "s".into());
        assert!(matches!(
            req.to_wire(uuid, &FailingCrypto),
            Err(MythicMessageError::Crypto)
        ));
        let ok_packet =
            encode_message_plain(&ReqStagingRSA::new("p".into(), "s".into()), uuid).unwrap();
        assert!(matches!(
            decode_message::<ReqStagingRSA>(&ok_packet, None, &FailingCrypto),
            Err(MythicMessageError::Crypto)
        ));
    }

    #[test]
    fn surrounding_whitespace_is_ignored() {
        let uuid = Uuid::from_u128(3);
        let req = ReqStagingRSA::new("pub".into(), "sid".into());
        let mut padded = String::from("  ");
        padded.push_str(&req.to_wire_plain(uuid).unwrap());
        padded.push_str("\r\n");
        let (decoded_uuid, decoded_req) =
            ReqStagingRSA::from_wire_plain(&padded, Some(uuid)).unwrap();
        assert_eq!(decoded_uuid, uuid);
        assert_eq!(decoded_req, req);
    }

    #[test]
    fn trait_methods_match_free_functions() {
        let uuid = Uuid::from_u128(2);
        let msg = ReqStagingTranslation::new(
            "sid".into(),
            "enc".into(),
            "dec".into(),
            "aes".into(),
            uuid,
            "hello".into(),
        );
        let packed_fn = encode_message(&msg, uuid, &ReverseCrypto).unwrap();
        let packed_trait = msg.to_wire(uuid, &ReverseCrypto).unwrap();
        assert_eq!(packed_fn, packed_trait);
        let (uuid_fn, msg_fn): (Uuid, ReqStagingTranslation) =
            decode_message(&packed_fn, Some(uuid), &ReverseCrypto).unwrap();
        let (uuid_trait, msg_trait) =
            ReqStagingTranslation::from_wire(&packed_trait, Some(uuid), &ReverseCrypto).unwrap();
        assert_eq!(uuid_fn, uuid_trait);
        assert_eq!(msg_fn, msg_trait);
    }

    #[test]
    fn plain_trait_methods_match() {
        let uuid = Uuid::from_u128(9);
        let msg = ReqStagingRSA::new("pub".into(), "sid".into());
        let packed_fn = encode_message_plain(&msg, uuid).unwrap();
        let packed_trait = msg.to_wire_plain(uuid).unwrap();
        assert_eq!(packed_fn, packed_trait);
        let (uuid_fn, msg_fn): (Uuid, ReqStagingRSA) =
            decode_message_plain(&packed_fn, Some(uuid)).unwrap();
        let (uuid_trait, msg_trait) =
            ReqStagingRSA::from_wire_plain(&packed_trait, Some(uuid)).unwrap();
        assert_eq!(uuid_fn, uuid_trait);
        assert_eq!(msg_fn, msg_trait);
    }

    #[test]
    fn hyphenated_header_layout() {
        let uuid = Uuid::from_u128(0x1234);
        let packet = build_packet(uuid, b"", UuidEncoding::Hyphenated);
        assert_eq!(packet.len(), MYTHIC_UUID_LEN);
        assert_eq!(packet, uuid.hyphenated().to_string().as_bytes());
        let (decoded_uuid, payload) =
            parse_packet(&packet, Some(uuid), UuidEncoding::Hyphenated).unwrap();
        assert_eq!(decoded_uuid, uuid);
        assert!(payload.is_empty());
    }

    #[test]
    fn binary_encoding_roundtrip() {
        let uuid = Uuid::from_u128(1);
        let payload = b"binary payload".to_vec();
        let packet = build_packet(uuid, &payload, UuidEncoding::Binary);
        let (decoded_uuid, decoded_payload) =
            parse_packet(&packet, Some(uuid), UuidEncoding::Binary).unwrap();
        assert_eq!(decoded_uuid, uuid);
        assert_eq!(decoded_payload, payload);
    }

    #[test]
    fn binary_encoding_error_paths() {
        let short = [0u8; MYTHIC_UUID_BIN_LEN - 1];
        assert!(matches!(
            parse_packet(&short, None, UuidEncoding::Binary),
            Err(MythicMessageError::InvalidPacket)
        ));
        let packet = build_packet(Uuid::from_u128(1), b"x", UuidEncoding::Binary);
        assert_eq!(packet.len(), MYTHIC_UUID_BIN_LEN + 1);
        assert!(matches!(
            parse_packet(&packet, Some(Uuid::from_u128(2)), UuidEncoding::Binary),
            Err(MythicMessageError::UuidMismatch)
        ));
    }
}