use std::future::Future;
use anyhow::Context as _;
use serde::{Deserialize, Serialize};
use tokio::io::{AsyncRead, AsyncReadExt as _, AsyncWrite, AsyncWriteExt as _};
use crate::base::{Constant, Res, SessionPath, Visibility};
/// Envelope reserved for end-to-end encrypted payloads; carried by
/// [`Payload::Encrypted`].
///
/// Nothing in this module encrypts, decrypts, or inspects these fields — they are
/// serialized as-is. Field order is part of the bincode wire format used by
/// [`encode`]/[`decode`], so do not reorder or remove fields.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct Envelope {
/// Ciphertext bytes, passed through unmodified by this module.
pub ciphertext: Vec<u8>,
/// Optional key identifier for the ciphertext. NOTE(review): presumably names the
/// key used to encrypt (tests use `"channel-key-1"`) — confirm against the crypto
/// layer.
pub key_id: Option<String>,
}
/// Message body carried by [`ProtocolMessage::ChannelMsg`] and
/// [`ProtocolMessage::Whisper`].
///
/// Variants are encoded by declaration index, so append new variants at the end
/// rather than reordering existing ones.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum Payload {
/// Unencrypted text.
Plain(String),
/// Encrypted body; see [`Envelope`].
Encrypted(Envelope),
}
/// Element type of [`ProtocolMessage::ChannelList`].
///
/// Field order is part of the bincode wire format; do not reorder fields.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct ChannelInfo {
/// Channel name.
pub name: String,
/// Channel visibility (type defined in `crate::base`).
pub visibility: Visibility,
/// NOTE(review): presumably whether the requesting user/session is a member of
/// this channel — confirm against the server-side listing code.
pub member: bool,
}
/// Element type of [`ProtocolMessage::MachineList`].
///
/// Field order is part of the bincode wire format; do not reorder fields.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct MachineInfo {
/// Machine name.
pub name: String,
/// Public key in text form, unlike the raw `Vec<u8>` keys in
/// [`AdminOp::MachineAdd`] and [`ProtocolMessage::Register`]. NOTE(review): the
/// text encoding is not established in this module — confirm.
pub pubkey: String,
/// Timestamp string, presumably when the machine was added. NOTE(review): tests
/// use an RFC 3339-style value (`2026-07-02T00:00:00Z`); no format is enforced here.
pub added_at: String,
}
/// Element type of [`ProtocolMessage::InviteList`] (paired with
/// [`AdminOp::InviteList`] in the tests).
///
/// Field order is part of the bincode wire format; do not reorder fields.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct InviteInfo {
/// Invite token string.
pub token: String,
/// Remaining uses. NOTE(review): presumably `None` means unlimited — confirm.
/// The type is `i64` here but `u32` in [`AdminOp::InviteCreate`]; changing either
/// alters the wire format.
pub uses_remaining: Option<i64>,
/// Expiry as a string, or `None`. NOTE(review): presumably `None` means no expiry
/// and the string is a timestamp — confirm.
pub expires_at: Option<String>,
}
/// Administrative operations, sent wrapped in [`ProtocolMessage::Admin`].
///
/// Variants are bincode-encoded by declaration index, so new operations must be
/// appended at the end and existing ones never reordered or removed.
///
/// NOTE(review): the per-variant notes describe the intent suggested by the
/// names. Authorization and the actual effects live in the handler that
/// processes these operations, which is not part of this module.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum AdminOp {
/// Create channel `name` with the given visibility.
CreateChannel {
name: String,
visibility: Visibility,
},
/// Delete channel `name`.
DeleteChannel {
name: String,
},
/// Rename channel `name` to `new_name`.
RenameChannel {
name: String,
new_name: String,
},
/// Change the visibility of channel `name`.
SetVisibility {
name: String,
visibility: Visibility,
},
/// Add `user` to the access list of `channel`.
AclAdd {
channel: String,
user: String,
},
/// Remove `user` from the access list of `channel`.
AclRemove {
channel: String,
user: String,
},
/// Create an invite for `channel`. `uses` and `expires_in_secs` are optional
/// limits (presumably `None` = unlimited — confirm).
InviteCreate {
channel: String,
uses: Option<u32>,
expires_in_secs: Option<u64>,
},
/// Revoke the invite identified by `token`.
InviteRevoke {
token: String,
},
/// Kick `target` from `channel`. NOTE(review): `target` is a plain string;
/// whether it names a user or a session path is not established here.
Kick {
channel: String,
target: String,
},
/// Ban `user` from `channel`.
Ban {
channel: String,
user: String,
},
/// Remove the user `username`.
UserRemove {
username: String,
},
/// Remove the machine `name`.
MachineRemove {
name: String,
},
/// Add machine `name` with raw public-key bytes `pubkey`.
MachineAdd {
name: String,
pubkey: Vec<u8>,
},
/// List the access list of `channel`.
AclList {
channel: String,
},
/// Lift a ban on `user` in `channel`.
Unban {
channel: String,
user: String,
},
/// List the bans on `channel`.
BanList {
channel: String,
},
/// List invites for `channel`; see [`ProtocolMessage::InviteList`].
InviteList {
channel: String,
},
}
/// Every frame exchanged between peers. [`encode`] and [`decode`] convert one
/// message to and from a frame body.
///
/// Variants are bincode-encoded by declaration index (the test
/// `appending_variants_preserves_existing_wire_indices` pins `Hello` at 0). New
/// variants must therefore be appended at the end, and existing ones never
/// reordered or removed. Doing so would break decoding between peers that both
/// accept the same `Constant::PROTOCOL_VERSION`.
///
/// NOTE(review): the per-variant notes describe roles suggested by names and
/// field types. Which side sends each frame, and in what order, is decided by
/// code outside this module — confirm there.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProtocolMessage {
/// Opening frame: the sender's protocol version (checkable with
/// [`negotiate_version`]) and a session name.
Hello {
protocol_version: u32,
session: String,
},
/// Nonce bytes for the peer to sign; presumably answered by `Auth` — confirm.
Challenge {
nonce: Vec<u8>,
},
/// Raw public key and signature, presumably over the `Challenge` nonce — confirm.
Auth {
pubkey: Vec<u8>,
signature: Vec<u8>,
},
/// Session path reported once the session is set up (presumably after a
/// successful `Auth`).
Established {
path: SessionPath,
},
/// Registration of `username` on `machine` with raw public-key bytes.
Register {
username: String,
machine: String,
pubkey: Vec<u8>,
},
/// Join `channel`; `token` is optional, presumably an invite token.
Join {
channel: String,
token: Option<String>,
},
/// Leave `channel`.
Leave {
channel: String,
},
/// Presence query; `channel: None` presumably means all channels.
Who {
channel: Option<String>,
},
/// Administrative request; see [`AdminOp`].
Admin(AdminOp),
/// Message in `channel` from session `from`.
ChannelMsg {
channel: String,
from: SessionPath,
payload: Payload,
},
/// Direct message from session `from` to session `target`.
Whisper {
from: SessionPath,
target: SessionPath,
payload: Payload,
},
/// Sessions present in `channel` (presumably global presence when `None`).
Presence {
channel: Option<String>,
sessions: Vec<SessionPath>,
},
/// Error reported to the peer; see [`ProtocolError`].
Error(ProtocolError),
/// Channel listing request (presumably answered by `ChannelList`).
ListChannels,
/// Channel listing; see [`ChannelInfo`].
ChannelList {
channels: Vec<ChannelInfo>,
},
/// Confirmation that `channel` was joined.
Joined {
channel: String,
},
/// Generic acknowledgement with optional detail text.
Ack {
detail: Option<String>,
},
/// Carries an invite token string.
InviteToken {
token: String,
},
/// Liveness probe (presumably answered by `Pong`).
Ping,
/// Reply to `Ping`.
Pong,
/// Server information; `admin` presumably reports whether the client has admin
/// rights — confirm.
ServerInfo {
admin: bool,
},
/// Machine listing request (presumably answered by `MachineList`).
ListMachines,
/// Machine listing; see [`MachineInfo`].
MachineList {
machines: Vec<MachineInfo>,
},
/// User listing request (presumably answered by `UserList`).
ListUsers,
/// User names.
UserList {
users: Vec<String>,
},
/// Invite listing; paired with [`AdminOp::InviteList`] in the tests.
InviteList {
invites: Vec<InviteInfo>,
},
}
/// Errors that can be sent to a peer inside [`ProtocolMessage::Error`].
///
/// Derives `Serialize`/`Deserialize` so it can travel on the wire, and
/// `thiserror::Error` so it can also be used as a regular Rust error. Because it
/// is encoded as part of `ProtocolMessage`, variant order is wire-significant:
/// append new variants at the end.
// The type name repeats the module name (presumably this module is `protocol`);
// the lint is silenced rather than renaming a type used across the crate.
#[allow(clippy::module_name_repetitions)]
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, thiserror::Error)]
pub enum ProtocolError {
/// Peer advertised a version other than ours; produced by [`negotiate_version`].
#[error("incompatible protocol version: ours={ours}, theirs={theirs}")]
VersionMismatch {
ours: u32,
theirs: u32,
},
/// A frame could not be understood; the string carries detail text.
#[error("malformed frame: {0}")]
MalformedFrame(String),
/// The request was not permitted; the string carries detail text.
#[error("unauthorized: {0}")]
Unauthorized(String),
/// A referenced entity does not exist; the string carries detail text.
#[error("not found: {0}")]
NotFound(String),
/// Unexpected failure on the reporting side; the string carries detail text.
#[error("internal error: {0}")]
Internal(String),
}
/// Checks a peer's advertised protocol version against ours.
///
/// Only an exact match is accepted; there is no range negotiation. On success
/// the agreed version (equal to both sides') is returned.
///
/// # Errors
///
/// Returns [`ProtocolError::VersionMismatch`] carrying both versions when `theirs`
/// differs from `Constant::PROTOCOL_VERSION`.
pub fn negotiate_version(theirs: u32) -> Result<u32, ProtocolError> {
    let ours = Constant::PROTOCOL_VERSION;
    match theirs {
        agreed if agreed == ours => Ok(agreed),
        _ => Err(ProtocolError::VersionMismatch { ours, theirs }),
    }
}
/// Serializes `message` into a frame body (without the length prefix) using
/// bincode's standard configuration.
///
/// # Errors
///
/// Propagates bincode encoding failures, with added context.
pub fn encode(message: &ProtocolMessage) -> Res<Vec<u8>> {
    let config = bincode::config::standard();
    let encoded = bincode::serde::encode_to_vec(message, config);
    encoded.context("failed to encode protocol frame")
}
pub fn decode(bytes: &[u8]) -> Res<ProtocolMessage> {
let (message, _) = bincode::serde::decode_from_slice(bytes, bincode::config::standard()).context("failed to decode protocol frame")?;
Ok(message)
}
pub trait ProtocolWrite: AsyncWrite + Unpin {
fn send_message(&mut self, message: &ProtocolMessage) -> impl Future<Output = Res<()>> {
async move {
let body = encode(message)?;
let len = u32::try_from(body.len()).context("protocol frame exceeds u32 length")?;
self.write_all(&len.to_be_bytes()).await?;
self.write_all(&body).await?;
self.flush().await?;
Ok(())
}
}
}
impl<T: AsyncWrite + Unpin + ?Sized> ProtocolWrite for T {}
pub trait ProtocolRead: AsyncRead + Unpin {
fn recv_message(&mut self) -> impl Future<Output = Res<ProtocolMessage>> {
async move {
let mut len_buf = [0_u8; 4];
self.read_exact(&mut len_buf).await?;
let len = usize::try_from(u32::from_be_bytes(len_buf)).context("frame length overflow")?;
anyhow::ensure!(len <= Constant::MAX_FRAME_SIZE, "protocol frame of {len} bytes exceeds the {} byte cap", Constant::MAX_FRAME_SIZE);
let mut body = vec![0_u8; len];
self.read_exact(&mut body).await?;
decode(&body)
}
}
}
impl<T: AsyncRead + Unpin + ?Sized> ProtocolRead for T {}
#[cfg(test)]
mod tests {
    // Panicking via `unwrap` is the intended failure mode inside tests.
    #![allow(clippy::unwrap_used)]
    use super::*;
    use crate::tests::duplex;
    use pretty_assertions::assert_eq;

    /// Encodes `message`, decodes the bytes back, and asserts equality.
    fn assert_round_trips(message: &ProtocolMessage) {
        let bytes = encode(message).unwrap();
        assert_eq!(&decode(&bytes).unwrap(), message);
    }

    #[test]
    fn hello_round_trips_with_version_field() {
        assert_round_trips(&ProtocolMessage::Hello {
            protocol_version: Constant::PROTOCOL_VERSION,
            session: "razel".to_owned(),
        });
    }

    #[test]
    fn channel_message_round_trips_plaintext() {
        assert_round_trips(&ProtocolMessage::ChannelMsg {
            channel: "ops".to_owned(),
            from: SessionPath::new("aaron", "workstation", "razel"),
            payload: Payload::Plain("hello, agents".to_owned()),
        });
    }

    #[test]
    fn data_frame_round_trips_the_reserved_e2e_envelope() {
        assert_round_trips(&ProtocolMessage::Whisper {
            from: SessionPath::new("aaron", "workstation", "razel"),
            target: SessionPath::new("david", "desktop", "main"),
            payload: Payload::Encrypted(Envelope {
                ciphertext: vec![0xDE, 0xAD, 0xBE, 0xEF],
                key_id: Some("channel-key-1".to_owned()),
            }),
        });
    }

    #[test]
    fn admin_op_round_trips() {
        assert_round_trips(&ProtocolMessage::Admin(AdminOp::CreateChannel {
            name: "ops".to_owned(),
            visibility: Visibility::Private,
        }));
    }

    #[test]
    fn machine_add_admin_op_round_trips() {
        assert_round_trips(&ProtocolMessage::Admin(AdminOp::MachineAdd {
            name: "sno-box".to_owned(),
            pubkey: vec![1, 2, 3, 4],
        }));
    }

    #[test]
    fn acl_list_admin_op_round_trips() {
        assert_round_trips(&ProtocolMessage::Admin(AdminOp::AclList { channel: "ops".to_owned() }));
    }

    #[test]
    fn invite_list_round_trips_op_and_response() {
        assert_round_trips(&ProtocolMessage::Admin(AdminOp::InviteList { channel: "ops".to_owned() }));
        assert_round_trips(&ProtocolMessage::InviteList {
            invites: vec![InviteInfo {
                token: "tok".to_owned(),
                uses_remaining: Some(3),
                expires_at: None,
            }],
        });
    }

    #[test]
    fn ban_visibility_admin_ops_round_trip() {
        assert_round_trips(&ProtocolMessage::Admin(AdminOp::Unban {
            channel: "ops".to_owned(),
            user: "bob".to_owned(),
        }));
        assert_round_trips(&ProtocolMessage::Admin(AdminOp::BanList { channel: "ops".to_owned() }));
    }

    #[test]
    fn m2_response_frames_round_trip() {
        assert_round_trips(&ProtocolMessage::ListChannels);
        assert_round_trips(&ProtocolMessage::ChannelList {
            channels: vec![ChannelInfo {
                name: "ops".to_owned(),
                visibility: Visibility::Private,
                member: true,
            }],
        });
        assert_round_trips(&ProtocolMessage::Joined { channel: "ops".to_owned() });
        assert_round_trips(&ProtocolMessage::Ack { detail: Some("ops".to_owned()) });
        assert_round_trips(&ProtocolMessage::InviteToken { token: "tok-abc".to_owned() });
        assert_round_trips(&ProtocolMessage::Ping);
        assert_round_trips(&ProtocolMessage::Pong);
    }

    #[test]
    fn m4_frames_round_trip() {
        assert_round_trips(&ProtocolMessage::ServerInfo { admin: true });
        assert_round_trips(&ProtocolMessage::ListMachines);
        assert_round_trips(&ProtocolMessage::MachineList {
            machines: vec![MachineInfo {
                name: "workstation".to_owned(),
                pubkey: "PUBKEY".to_owned(),
                added_at: "2026-07-02T00:00:00Z".to_owned(),
            }],
        });
        assert_round_trips(&ProtocolMessage::ListUsers);
        assert_round_trips(&ProtocolMessage::UserList {
            users: vec!["aaron".to_owned(), "david".to_owned()],
        });
    }

    #[test]
    fn appending_variants_preserves_existing_wire_indices() {
        let hello = ProtocolMessage::Hello {
            protocol_version: Constant::PROTOCOL_VERSION,
            session: "razel".to_owned(),
        };
        assert_eq!(encode(&hello).unwrap()[0], 0, "the first variant's discriminant must remain 0");
        // Field-less variants encode as nothing but their varint discriminant.
        // These assertions pin declaration indices across the enum: inserting or
        // reordering a variant ahead of them changes the bytes and fails here.
        assert_eq!(encode(&ProtocolMessage::ListChannels).unwrap(), vec![13_u8]);
        assert_eq!(encode(&ProtocolMessage::Ping).unwrap(), vec![18_u8]);
        assert_eq!(encode(&ProtocolMessage::Pong).unwrap(), vec![19_u8]);
        assert_eq!(encode(&ProtocolMessage::ListMachines).unwrap(), vec![21_u8]);
        assert_eq!(encode(&ProtocolMessage::ListUsers).unwrap(), vec![23_u8]);
        // `Admin` (index 8) wrapping `AdminOp::InviteList` (index 16), followed by
        // an empty channel name (length 0). This pins the nested `AdminOp` indices too.
        let admin = ProtocolMessage::Admin(AdminOp::InviteList { channel: String::new() });
        assert_eq!(encode(&admin).unwrap(), vec![8_u8, 16, 0]);
    }

    #[test]
    fn error_frame_round_trips() {
        assert_round_trips(&ProtocolMessage::Error(ProtocolError::VersionMismatch { ours: 1, theirs: 2 }));
    }

    #[tokio::test]
    async fn frames_stream_over_an_async_duplex() {
        let (mut a, mut b) = duplex();
        let sent = ProtocolMessage::Presence {
            channel: Some("ops".to_owned()),
            sessions: vec![SessionPath::new("aaron", "workstation", "razel"), SessionPath::new("david", "desktop", "main")],
        };
        a.send_message(&sent).await.unwrap();
        let got = b.recv_message().await.unwrap();
        assert_eq!(got, sent);
    }

    #[test]
    fn version_negotiation_accepts_matching_and_rejects_mismatch() {
        assert_eq!(negotiate_version(Constant::PROTOCOL_VERSION).unwrap(), Constant::PROTOCOL_VERSION);
        assert_eq!(
            negotiate_version(999),
            Err(ProtocolError::VersionMismatch {
                ours: Constant::PROTOCOL_VERSION,
                theirs: 999,
            })
        );
    }

    #[tokio::test]
    async fn recv_rejects_a_frame_larger_than_the_cap() {
        // Only the length prefix is supplied. Even without a cap check the read
        // would fail (EOF while reading the body), so assert on the error text
        // to prove it was the cap that rejected the frame.
        let oversized = u32::try_from(Constant::MAX_FRAME_SIZE + 1).unwrap();
        let framed = oversized.to_be_bytes();
        let mut reader = framed.as_slice();
        let err = reader.recv_message().await.unwrap_err();
        assert!(err.to_string().contains("exceeds the"), "unexpected error: {err:#}");
    }
}