use crate::{COUNT_LEN, CsError, EventType, PUB_KEY_LEN, TIMESTAMP_LEN, UID_LEN, crypto::Crypto};
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use x25519_dalek::PublicKey;
/// Builds the associated data passed to `Crypto::encrypt` / `Crypto::decrypt` for
/// uid-bound frames: the `UID_LEN` uid bytes followed by the single event-type byte.
pub fn build_associated_data(uid: &[u8; UID_LEN], event_type: u8) -> [u8; UID_LEN + 1] {
    // Seed every slot with the event type, then overwrite the prefix with the uid so
    // that only the final byte keeps the event type.
    let mut data = [event_type; UID_LEN + 1];
    data[..UID_LEN].copy_from_slice(uid);
    data
}
/// Serializes protocol frames into byte buffers.
///
/// Every frame produced here ends with a single event-type byte, which the matching
/// [`Decoder`] method inspects before parsing the rest of the frame.
pub struct Encoder;

impl Encoder {
    /// Shared framing for counter-carrying session frames.
    ///
    /// Appends `count` (little-endian) to whatever `buf` already holds, encrypts the
    /// whole buffer in place with `session_crypto` using `uid` + `event_type` as
    /// associated data, then appends the clear-text `uid` and the `event_type` byte.
    ///
    /// Resulting layout: `ciphertext(buf ++ count) | uid | event_type`.
    fn seal_counted<C: Crypto>(
        session_crypto: &C,
        count: u64,
        uid: &[u8; UID_LEN],
        event_type: u8,
        buf: &mut Vec<u8>,
    ) -> Result<(), CsError> {
        // NOTE(review): `ADDITION_LEN` is presumably the ciphertext overhead added by
        // `encrypt` (tag/nonce); it is only used here as a capacity hint.
        buf.reserve(COUNT_LEN + C::ADDITION_LEN + UID_LEN + 1);
        buf.extend_from_slice(&count.to_le_bytes());
        let associated_data = build_associated_data(uid, event_type);
        session_crypto.encrypt(&associated_data, buf)?;
        // The uid travels in clear text so the receiver can pick the session key
        // before decrypting.
        buf.extend_from_slice(uid);
        buf.push(event_type);
        Ok(())
    }

    /// Seals an application payload as an `Encrypted` frame.
    ///
    /// `buf` must already contain the plaintext payload; it is NOT cleared. See
    /// [`Encoder::seal_counted`] for the resulting layout.
    ///
    /// # Errors
    /// Propagates any error returned by `session_crypto.encrypt`.
    pub fn encrypted<C: Crypto>(
        session_crypto: &C,
        count: u64,
        uid: &[u8; UID_LEN],
        buf: &mut Vec<u8>,
    ) -> Result<(), CsError> {
        Self::seal_counted(session_crypto, count, uid, EventType::Encrypted, buf)
    }

    /// Builds the 64-byte hello frame: 63 zero bytes followed by the `Hello` type byte.
    pub fn hello() -> [u8; 64] {
        let mut frame = [0; 64];
        frame[frame.len() - 1] = EventType::Hello;
        frame
    }

    /// Writes an `AckHello` frame into `buf` (cleared first): the raw `server_salt`
    /// bytes followed by the `AckHello` type byte.
    pub fn ack_hello<C: Crypto>(server_salt: &C::Salt, buf: &mut Vec<u8>) {
        buf.clear();
        buf.reserve(C::SALT_LEN + 1);
        buf.extend_from_slice(server_salt.as_ref());
        buf.push(EventType::AckHello);
    }

    /// Writes a `Connect` frame into `buf` (cleared first).
    ///
    /// Plaintext layout before encryption with `server_crypto`:
    /// `client_pub | ttl seconds (u64 LE) | now as Unix seconds (u64 LE) | uid`.
    /// Only the `Connect` type byte is used as associated data; unlike session frames,
    /// the uid is inside the ciphertext. The type byte is appended after encryption.
    /// Sub-second precision of `ttl` is discarded.
    ///
    /// # Errors
    /// Propagates a system clock earlier than the Unix epoch (via `?`) and any error
    /// returned by `server_crypto.encrypt`.
    pub fn connect<C: Crypto>(
        server_crypto: &C,
        client_pub: &[u8; PUB_KEY_LEN],
        ttl: Duration,
        uid: &[u8; UID_LEN],
        buf: &mut Vec<u8>,
    ) -> Result<(), CsError> {
        buf.clear();
        buf.reserve(
            PUB_KEY_LEN + size_of::<u64>() + TIMESTAMP_LEN + UID_LEN + C::ADDITION_LEN + 1,
        );
        buf.extend_from_slice(client_pub);
        buf.extend_from_slice(&ttl.as_secs().to_le_bytes());
        let timestamp = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs();
        buf.extend_from_slice(&timestamp.to_le_bytes());
        buf.extend_from_slice(uid);
        server_crypto.encrypt(&[EventType::Connect], buf)?;
        buf.push(EventType::Connect);
        Ok(())
    }

    /// Writes an `AckConnect` frame into `buf` (cleared first).
    ///
    /// `server_pub` is encrypted with `server_crypto` using `uid` + `AckConnect` as
    /// associated data. The uid itself is not written to the frame, so the receiver
    /// must already know it in order to decrypt.
    ///
    /// # Errors
    /// Propagates any error returned by `server_crypto.encrypt`.
    pub fn ack_connect<C: Crypto>(
        server_crypto: &C,
        server_pub: &[u8; PUB_KEY_LEN],
        uid: &[u8; UID_LEN],
        buf: &mut Vec<u8>,
    ) -> Result<(), CsError> {
        buf.clear();
        buf.reserve(PUB_KEY_LEN + C::ADDITION_LEN + 1);
        buf.extend_from_slice(server_pub);
        let associated_data = build_associated_data(uid, EventType::AckConnect);
        server_crypto.encrypt(&associated_data, buf)?;
        buf.push(EventType::AckConnect);
        Ok(())
    }

    /// Writes a `Heartbeat` frame carrying only `count` into `buf` (cleared first).
    ///
    /// # Errors
    /// Propagates any error returned by `session_crypto.encrypt`.
    pub fn heartbeat<C: Crypto>(
        session_crypto: &C,
        count: u64,
        uid: &[u8; UID_LEN],
        buf: &mut Vec<u8>,
    ) -> Result<(), CsError> {
        buf.clear();
        Self::seal_counted(session_crypto, count, uid, EventType::Heartbeat, buf)
    }

    /// Writes an `AckHeartbeat` frame carrying only `count` into `buf` (cleared first).
    ///
    /// # Errors
    /// Propagates any error returned by `session_crypto.encrypt`.
    pub fn ack_heartbeat<C: Crypto>(
        session_crypto: &C,
        count: u64,
        uid: &[u8; UID_LEN],
        buf: &mut Vec<u8>,
    ) -> Result<(), CsError> {
        buf.clear();
        Self::seal_counted(session_crypto, count, uid, EventType::AckHeartbeat, buf)
    }
}
/// Parses frames produced by [`Encoder`].
///
/// Each method first checks the trailing event-type byte (reporting
/// `CsError::InvalidType` on mismatch) and then the frame length (reporting
/// `CsError::InvalidFormat`).
pub struct Decoder;

impl Decoder {
    /// Removes the last `N` bytes from `buf` and returns them as an array.
    ///
    /// If `buf` holds fewer than `N` bytes, returns `CsError::InvalidFormat` and leaves
    /// `buf` untouched. A `Crypto::decrypt` that yields an unexpectedly short plaintext
    /// therefore surfaces as an error rather than an underflow/slice panic.
    fn take_tail<const N: usize>(buf: &mut Vec<u8>) -> Result<[u8; N], CsError> {
        let start = buf.len().checked_sub(N).ok_or(CsError::InvalidFormat)?;
        let mut tail = [0; N];
        tail.copy_from_slice(&buf[start..]);
        buf.truncate(start);
        Ok(tail)
    }

    /// Reads the clear-text uid that precedes the final event-type byte, without
    /// consuming anything.
    ///
    /// The event type is not checked. The result is only meaningful for frames with a
    /// clear-text uid trailer (`Encrypted`, `Heartbeat`, `AckHeartbeat`).
    ///
    /// # Errors
    /// `CsError::InvalidFormat` if `buf` is shorter than `UID_LEN + 1` bytes.
    pub fn peek_uid(buf: &[u8]) -> Result<[u8; UID_LEN], CsError> {
        let uid_start = buf.len().checked_sub(UID_LEN + 1).ok_or(CsError::InvalidFormat)?;
        let mut uid = [0; UID_LEN];
        uid.copy_from_slice(&buf[uid_start..uid_start + UID_LEN]);
        Ok(uid)
    }

    /// Opens an `Encrypted`, `Heartbeat` or `AckHeartbeat` frame in place.
    ///
    /// On success returns `(count, uid)`. `buf` is left holding the decrypted payload,
    /// i.e. everything before the counter, which is empty for heartbeat frames.
    ///
    /// # Errors
    /// * `CsError::InvalidType` if the trailing byte is not one of the three accepted
    ///   event types.
    /// * `CsError::InvalidFormat` if the frame or the decrypted plaintext is too short.
    /// * Any error returned by `session_crypto.decrypt`.
    pub fn encrypted<C: Crypto>(
        session_crypto: &C,
        buf: &mut Vec<u8>,
    ) -> Result<(u64, [u8; UID_LEN]), CsError> {
        if !matches!(
            buf.last(),
            Some(&EventType::Encrypted | &EventType::Heartbeat | &EventType::AckHeartbeat)
        ) {
            return Err(CsError::InvalidType(buf.last().cloned()));
        }
        if buf.len() < COUNT_LEN + C::ADDITION_LEN + UID_LEN + 1 {
            return Err(CsError::InvalidFormat);
        }
        let event_type = buf.pop().expect("event-type byte was checked above");
        let uid: [u8; UID_LEN] = Self::take_tail(buf)?;
        let associated_data = build_associated_data(&uid, event_type);
        session_crypto.decrypt(&associated_data, buf)?;
        // Checked: the plaintext length depends on the `Crypto` implementation.
        let count: [u8; COUNT_LEN] = Self::take_tail(buf)?;
        Ok((u64::from_le_bytes(count), uid))
    }

    /// Validates a hello frame: exactly 64 bytes, 63 zero bytes followed by the
    /// `Hello` type byte.
    ///
    /// # Errors
    /// `CsError::InvalidType` on a wrong or missing type byte; `CsError::InvalidFormat`
    /// on a wrong length or non-zero padding.
    pub fn hello(buf: &[u8]) -> Result<(), CsError> {
        // Must match the array length returned by `Encoder::hello`.
        const HELLO_LEN: usize = 64;
        if buf.last() != Some(&EventType::Hello) {
            return Err(CsError::InvalidType(buf.last().cloned()));
        }
        let padding = &buf[..buf.len() - 1];
        if buf.len() == HELLO_LEN && padding.iter().all(|&b| b == 0) {
            Ok(())
        } else {
            Err(CsError::InvalidFormat)
        }
    }

    /// Extracts the server salt from an `AckHello` frame.
    ///
    /// # Errors
    /// `CsError::InvalidType` on a wrong type byte; `CsError::InvalidFormat` unless the
    /// frame is exactly `C::SALT_LEN + 1` bytes long.
    pub fn ack_hello<C: Crypto>(buf: &[u8]) -> Result<C::Salt, CsError> {
        if buf.last() != Some(&EventType::AckHello) {
            return Err(CsError::InvalidType(buf.last().cloned()));
        }
        if buf.len() != C::SALT_LEN + 1 {
            return Err(CsError::InvalidFormat);
        }
        let mut salt = C::Salt::default();
        salt.as_mut().copy_from_slice(&buf[..C::SALT_LEN]);
        Ok(salt)
    }

    /// Opens a `Connect` frame in place.
    ///
    /// Returns `(client_pub, ttl, timestamp, uid)`, where `timestamp` is the sender's
    /// Unix time in seconds. This function does not check freshness against `ttl`;
    /// presumably the caller does. On success `buf` is left holding the client public
    /// key followed by the TTL bytes.
    ///
    /// # Errors
    /// * `CsError::InvalidType` on a wrong type byte.
    /// * `CsError::InvalidFormat` if the frame or the decrypted plaintext has an
    ///   unexpected length.
    /// * Any error returned by `server_crypto.decrypt`.
    pub fn connect<C: Crypto>(
        server_crypto: &C,
        buf: &mut Vec<u8>,
    ) -> Result<(PublicKey, Duration, u64, [u8; UID_LEN]), CsError> {
        if buf.last() != Some(&EventType::Connect) {
            return Err(CsError::InvalidType(buf.last().cloned()));
        }
        if buf.len()
            != PUB_KEY_LEN + size_of::<u64>() + TIMESTAMP_LEN + UID_LEN + C::ADDITION_LEN + 1
        {
            return Err(CsError::InvalidFormat);
        }
        buf.pop();
        server_crypto.decrypt(&[EventType::Connect], buf)?;
        let uid: [u8; UID_LEN] = Self::take_tail(buf)?;
        let timestamp: [u8; TIMESTAMP_LEN] = Self::take_tail(buf)?;
        let timestamp = u64::from_le_bytes(timestamp);
        // What remains must be exactly `client_pub | ttl seconds (u64 LE)`.
        if buf.len() != PUB_KEY_LEN + size_of::<u64>() {
            return Err(CsError::InvalidFormat);
        }
        let (key_bytes, ttl_bytes) = buf.split_at(PUB_KEY_LEN);
        let client_pub: [u8; PUB_KEY_LEN] =
            key_bytes.try_into().expect("length checked above");
        let ttl_secs = u64::from_le_bytes(ttl_bytes.try_into().expect("length checked above"));
        Ok((PublicKey::from(client_pub), Duration::from_secs(ttl_secs), timestamp, uid))
    }

    /// Opens an `AckConnect` frame in place and returns the server public key.
    ///
    /// `uid` must be supplied by the caller because it is not carried in the frame.
    /// Together with the `AckConnect` type byte it forms the associated data.
    ///
    /// # Errors
    /// * `CsError::InvalidType` on a wrong type byte.
    /// * `CsError::InvalidFormat` if the frame or the decrypted plaintext has an
    ///   unexpected length.
    /// * Any error returned by `server_crypto.decrypt`.
    pub fn ack_connect<C: Crypto>(
        server_crypto: &C,
        uid: &[u8; UID_LEN],
        buf: &mut Vec<u8>,
    ) -> Result<PublicKey, CsError> {
        if buf.last() != Some(&EventType::AckConnect) {
            return Err(CsError::InvalidType(buf.last().cloned()));
        }
        if buf.len() != PUB_KEY_LEN + C::ADDITION_LEN + 1 {
            return Err(CsError::InvalidFormat);
        }
        buf.pop();
        let associated_data = build_associated_data(uid, EventType::AckConnect);
        server_crypto.decrypt(&associated_data, buf)?;
        let server_pub: [u8; PUB_KEY_LEN] =
            buf.as_slice().try_into().map_err(|_| CsError::InvalidFormat)?;
        Ok(PublicKey::from(server_pub))
    }
}