use std::io::{self, Read, Write};
// ---- Client -> server frame tags ----

/// Client frame tag `0x01` (event).
pub const C_EVENT: u8 = 0x01;
/// Client frame tag `0x02` (detach request).
pub const C_DETACH: u8 = 0x02;
/// Client frame tag `0x03` (resize); payload presumably built by `encode_resize`.
pub const C_RESIZE: u8 = 0x03;
/// Client frame tag `0x04` (kill request).
pub const C_KILL: u8 = 0x04;
/// Client frame tag `0x05` (ping).
pub const C_PING: u8 = 0x05;
/// Client frame tag `0x06` (attach).
pub const C_ATTACH: u8 = 0x06;

// ---- Server -> client frame tags ----

/// Server frame tag `0x81` (output).
pub const S_OUTPUT: u8 = 0x81;
/// Server frame tag `0x82` (detached).
pub const S_DETACHED: u8 = 0x82;
/// Server frame tag `0x83` (exit).
pub const S_EXIT: u8 = 0x83;
/// Server frame tag `0x84` (pong).
pub const S_PONG: u8 = 0x84;

// ---- Version-handshake tags (reserved 0x10..=0x1F range) ----

/// First frame sent by the server: JSON `ServerHello`.
pub const S_VERSION: u8 = 0x10;
/// Client reply to `S_VERSION`: JSON `ClientHello`.
pub const C_HELLO: u8 = 0x11;
/// Server notice that the peers cannot talk: JSON `IncompatNotice`.
pub const S_INCOMPAT: u8 = 0x12;

// ---- Protocol version ----

/// Protocol major version; peers with different majors are incompatible.
pub const PROTO_MAJOR: u16 = 1;
/// Protocol minor version; minor differences are tolerated by the handshake.
pub const PROTO_MINOR: u16 = 0;

/// Largest payload `write_msg` will send or `read_msg` will accept (16 MiB).
const MAX_PAYLOAD: usize = 16 << 20;
/// Writes one frame to `w` and flushes it.
///
/// Wire layout: a 1-byte `tag`, the payload length as a big-endian `u32`, then
/// the payload bytes. `read_msg` is the inverse.
///
/// # Errors
///
/// - `InvalidInput` if `payload` is longer than `MAX_PAYLOAD`; nothing is
///   written in that case.
/// - Any I/O error from `w` while writing or flushing.
pub fn write_msg(w: &mut impl Write, tag: u8, payload: &[u8]) -> io::Result<()> {
    // A checked conversion keeps this correct even if MAX_PAYLOAD is ever
    // raised past u32::MAX (a plain `as u32` would silently truncate).
    let len = match u32::try_from(payload.len()) {
        Ok(n) if payload.len() <= MAX_PAYLOAD => n,
        _ => {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                format!("payload too large: {} bytes", payload.len()),
            ))
        }
    };

    // Emit tag + length as one 5-byte write so unbuffered writers (e.g. raw
    // sockets) issue a single write for the header instead of two.
    let mut header = [0u8; 5];
    header[0] = tag;
    header[1..].copy_from_slice(&len.to_be_bytes());
    w.write_all(&header)?;
    // `write_all` on an empty slice is a no-op, so no empty-payload special case.
    w.write_all(payload)?;
    w.flush()
}
/// Reads one frame produced by `write_msg` and returns `(tag, payload)`.
///
/// # Errors
///
/// - `UnexpectedEof` if the stream ends before the 5-byte header or the full
///   declared payload has been read.
/// - `InvalidData` if the declared length exceeds `MAX_PAYLOAD`.
/// - Any other I/O error from `r`.
pub fn read_msg(r: &mut impl Read) -> io::Result<(u8, Vec<u8>)> {
    // Cap on memory reserved up front. The length prefix comes from the peer,
    // so don't commit up to MAX_PAYLOAD bytes before any payload arrives;
    // larger payloads grow the buffer as data is actually read.
    const PREALLOC_LIMIT: usize = 64 * 1024;

    let mut header = [0u8; 5];
    r.read_exact(&mut header)?;
    let tag = header[0];
    let len = u32::from_be_bytes([header[1], header[2], header[3], header[4]]) as usize;
    if len > MAX_PAYLOAD {
        return Err(io::Error::new(
            io::ErrorKind::InvalidData,
            format!("message too large: {} bytes", len),
        ));
    }

    let mut payload = Vec::with_capacity(len.min(PREALLOC_LIMIT));
    // `len <= MAX_PAYLOAD`, so widening to u64 is lossless.
    Read::take(&mut *r, len as u64).read_to_end(&mut payload)?;
    if payload.len() != len {
        return Err(io::Error::new(
            io::ErrorKind::UnexpectedEof,
            format!(
                "truncated message: expected {} payload bytes, got {}",
                len,
                payload.len()
            ),
        ));
    }
    Ok((tag, payload))
}
/// Packs a terminal size into the 4-byte resize payload: `cols` then `rows`,
/// each as a big-endian `u16`. `decode_resize` is the inverse.
pub fn encode_resize(cols: u16, rows: u16) -> [u8; 4] {
    // Putting `cols` in the high half of a u32 and serializing it big-endian
    // yields exactly [cols_hi, cols_lo, rows_hi, rows_lo].
    let packed = (u32::from(cols) << 16) | u32::from(rows);
    packed.to_be_bytes()
}
/// Unpacks a resize payload produced by `encode_resize` into `(cols, rows)`.
///
/// Returns `None` if `payload` is shorter than 4 bytes; any bytes beyond the
/// first four are ignored.
pub fn decode_resize(payload: &[u8]) -> Option<(u16, u16)> {
    match *payload {
        [cols_hi, cols_lo, rows_hi, rows_lo, ..] => Some((
            u16::from_be_bytes([cols_hi, cols_lo]),
            u16::from_be_bytes([rows_hi, rows_lo]),
        )),
        _ => None,
    }
}
/// How a client attaches to a session (carried in `AttachRequest::mode`).
///
/// Serialized via serde in snake_case (`"steal"`, `"shared"`, `"readonly"`);
/// variant names are therefore part of the wire format. Defaults to `Steal`.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum AttachMode {
    /// Default mode. NOTE(review): presumably takes the session over from
    /// already-attached clients — confirm against the server implementation.
    #[default]
    Steal,
    /// NOTE(review): presumably attaches alongside existing clients — confirm.
    Shared,
    /// NOTE(review): presumably attaches without forwarding input — confirm.
    Readonly,
}
/// Parameters a client supplies when attaching: its requested size and the
/// desired `AttachMode`. Serialized as JSON via serde, so field names are
/// part of the wire format.
///
/// NOTE(review): presumably the payload of a `C_ATTACH` frame — confirm with
/// the sending code.
// `PartialEq`/`Eq` match the other wire types in this module and let callers
// and tests compare whole requests instead of field by field.
#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct AttachRequest {
    /// Requested width in columns.
    pub cols: u16,
    /// Requested height in rows.
    pub rows: u16,
    /// How the client wants to attach.
    pub mode: AttachMode,
}
/// Server half of the version handshake, sent as the JSON payload of an
/// `S_VERSION` frame (see `server_hello`).
///
/// Field names are part of the JSON wire format; do not rename them.
#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct ServerHello {
    /// Server protocol major version; `client_handshake` treats any difference
    /// from its own `PROTO_MAJOR` as incompatible.
    pub proto_major: u16,
    /// Server protocol minor version; differences are tolerated.
    pub proto_minor: u16,
    /// Human-readable build identifier (as produced by `build_string`).
    pub build: String,
}
/// Client half of the version handshake: sent by `client_handshake` as the
/// JSON payload of a `C_HELLO` frame and decoded by `parse_client_hello`.
///
/// Unknown extra JSON fields are ignored on decode (there is no
/// `deny_unknown_fields`), so newer clients can add fields additively.
/// Field names are part of the wire format; do not rename them.
#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct ClientHello {
    /// Client protocol major version.
    pub proto_major: u16,
    /// Client protocol minor version.
    pub proto_minor: u16,
    /// Human-readable client build identifier (as produced by `build_string`).
    pub client_build: String,
    /// Feature identifiers the client advertises (`client_handshake` sends
    /// `CLIENT_FEATURES`).
    pub supported_features: Vec<String>,
}
/// Explanation of a protocol incompatibility.
///
/// Sent by the server as the JSON payload of an `S_INCOMPAT` frame, or built
/// locally by `client_handshake` when it detects a major-version mismatch.
/// Field names are part of the wire format; do not rename them.
#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct IncompatNotice {
    /// Server protocol version formatted as `"major.minor"`.
    pub server_proto: String,
    /// Client protocol version as `"major.minor"`, or `"unknown"` when the
    /// client performed no version handshake (legacy client).
    pub client_proto: String,
    /// Human-readable explanation, including a suggested remedy.
    pub message: String,
}
/// Feature identifiers this client advertises in `ClientHello::supported_features`.
pub const CLIENT_FEATURES: &[&str] = &[
    "scrollback-v3",
    "kitty-kbd-stack",
    "osc-52-confirm",
];
pub fn build_string(rev: Option<&str>) -> String {
match rev {
Some(sha) if !sha.is_empty() => {
format!("ezpn {} (rev {})", env!("CARGO_PKG_VERSION"), sha)
}
_ => format!("ezpn {} (rev unknown)", env!("CARGO_PKG_VERSION")),
}
}
/// Builds a complete `S_VERSION` frame (header + JSON `ServerHello`) that
/// advertises this build's `PROTO_MAJOR`/`PROTO_MINOR`, ready to be written
/// to a freshly connected client.
///
/// # Panics
///
/// Not in practice: JSON-encoding `ServerHello` and writing into a `Vec<u8>`
/// cannot fail.
// NOTE(review): `dead_code` is allowed presumably because no non-test caller
// exists yet — confirm and drop the allow once the server uses this.
#[allow(dead_code)]
pub fn server_hello() -> Vec<u8> {
    let payload = serde_json::to_vec(&ServerHello {
        proto_major: PROTO_MAJOR,
        proto_minor: PROTO_MINOR,
        build: build_string(None),
    })
    .expect("ServerHello serialization is infallible");

    // 5 = 1-byte tag + 4-byte length prefix written by `write_msg`.
    let mut frame = Vec::with_capacity(payload.len() + 5);
    write_msg(&mut frame, S_VERSION, &payload).expect("writing to Vec<u8> never fails");
    frame
}
/// Classification of the first byte a peer sends, as returned by
/// `classify_first_byte`. Used to tell framed clients apart from legacy
/// JSON-speaking ones.
// NOTE(review): `dead_code` allowed presumably because the server path is not
// wired up yet — confirm.
#[allow(dead_code)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum FirstByteKind {
    /// Byte is in the binary frame-tag range.
    Tag,
    /// Byte opens a JSON object or array (`{` or `[`).
    LegacyJson,
    /// Anything else.
    Unknown,
}
/// Classifies the first byte received from a peer.
///
/// - `0x00..=0x20` → `FirstByteKind::Tag`
/// - `b'{'` or `b'['` → `FirstByteKind::LegacyJson`
/// - anything else → `FirstByteKind::Unknown`
///
/// NOTE(review): the Tag range includes `0x20` (ASCII space), which no tag
/// currently uses and which could also begin whitespace-prefixed JSON —
/// confirm that this boundary is intentional.
#[allow(dead_code)]
pub fn classify_first_byte(b: u8) -> FirstByteKind {
    if b <= 0x20 {
        FirstByteKind::Tag
    } else if matches!(b, b'{' | b'[') {
        FirstByteKind::LegacyJson
    } else {
        FirstByteKind::Unknown
    }
}
/// Builds a complete `S_INCOMPAT` frame telling `client` that its protocol
/// version cannot attach to this server.
///
/// This function does not compare versions itself; the caller decides that
/// the majors differ. The message includes an `ezpn kill <session_name>` hint
/// so the user can restart the daemon with a matching build.
// NOTE(review): `dead_code` allowed presumably because the server path is not
// wired up yet — confirm.
#[allow(dead_code)]
pub fn incompat_for_major_mismatch(client: &ClientHello, session_name: &str) -> Vec<u8> {
    let server_proto = format!("{}.{}", PROTO_MAJOR, PROTO_MINOR);
    let client_proto = format!("{}.{}", client.proto_major, client.proto_minor);
    let message = format!(
        "client v{} cannot attach to server v{} \u{2014} restart the daemon with 'ezpn kill {}' to upgrade.",
        client_proto, server_proto, session_name
    );
    let notice = IncompatNotice {
        server_proto,
        client_proto,
        message,
    };
    // Pass by reference: the original argument was mis-encoded and did not compile.
    encode_incompat(&notice)
}
/// Builds a complete `S_INCOMPAT` frame for a client that skipped the
/// version handshake (e.g. one whose first byte classified as legacy JSON).
///
/// `client_proto` is reported as `"unknown"`. The message includes an
/// `ezpn kill <session_name>` hint for restarting the daemon.
// NOTE(review): `dead_code` allowed presumably because the server path is not
// wired up yet — confirm.
#[allow(dead_code)]
pub fn incompat_for_legacy_client(session_name: &str) -> Vec<u8> {
    let server_proto = format!("{}.{}", PROTO_MAJOR, PROTO_MINOR);
    let message = format!(
        "legacy client detected (no version handshake) \u{2014} restart the daemon with 'ezpn kill {}' to upgrade.",
        session_name
    );
    let notice = IncompatNotice {
        server_proto,
        client_proto: "unknown".to_string(),
        message,
    };
    // Pass by reference: the original argument was mis-encoded and did not compile.
    encode_incompat(&notice)
}
/// Serializes `notice` to JSON and wraps it in a complete `S_INCOMPAT` frame.
///
/// # Panics
///
/// Not in practice: an all-`String` struct always JSON-encodes, and writing
/// into a `Vec<u8>` cannot fail.
#[allow(dead_code)]
fn encode_incompat(notice: &IncompatNotice) -> Vec<u8> {
    // 1-byte tag + 4-byte big-endian length prefix written by `write_msg`.
    const HEADER_LEN: usize = 5;

    let body = serde_json::to_vec(notice).expect("IncompatNotice serialization is infallible");
    let mut frame = Vec::with_capacity(HEADER_LEN + body.len());
    write_msg(&mut frame, S_INCOMPAT, &body).expect("writing to Vec<u8> never fails");
    frame
}
/// Result of `client_handshake`.
// Both payload types are `Clone + PartialEq + Eq`, so the outcome derives them
// too; callers and tests can then copy and compare outcomes directly.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum HandshakeOutcome {
    /// Compatible server (same protocol major). Carries the server's hello;
    /// the client's `C_HELLO` has already been sent.
    Ok(ServerHello),
    /// Incompatible: either the server sent `S_INCOMPAT`, or the client saw a
    /// major-version mismatch and sent nothing.
    Incompat(IncompatNotice),
}
pub fn client_handshake<R: Read, W: Write>(
reader: &mut R,
writer: &mut W,
) -> io::Result<HandshakeOutcome> {
let (tag, payload) = read_msg(reader)?;
match tag {
S_VERSION => {
let server: ServerHello = serde_json::from_slice(&payload).map_err(|e| {
io::Error::new(
io::ErrorKind::InvalidData,
format!("malformed S_VERSION payload: {}", e),
)
})?;
if server.proto_major != PROTO_MAJOR {
return Ok(HandshakeOutcome::Incompat(IncompatNotice {
server_proto: format!("{}.{}", server.proto_major, server.proto_minor),
client_proto: format!("{}.{}", PROTO_MAJOR, PROTO_MINOR),
message: format!(
"client v{}.{} cannot attach to server v{}.{} \u{2014} reinstall the matching ezpn binary or restart the daemon.",
PROTO_MAJOR, PROTO_MINOR, server.proto_major, server.proto_minor
),
}));
}
let hello = ClientHello {
proto_major: PROTO_MAJOR,
proto_minor: PROTO_MINOR,
client_build: build_string(None),
supported_features: CLIENT_FEATURES.iter().map(|s| s.to_string()).collect(),
};
let json = serde_json::to_vec(&hello).map_err(io::Error::other)?;
write_msg(writer, C_HELLO, &json)?;
Ok(HandshakeOutcome::Ok(server))
}
S_INCOMPAT => {
let notice: IncompatNotice = serde_json::from_slice(&payload).map_err(|e| {
io::Error::new(
io::ErrorKind::InvalidData,
format!("malformed S_INCOMPAT payload: {}", e),
)
})?;
Ok(HandshakeOutcome::Incompat(notice))
}
other => Err(io::Error::new(
io::ErrorKind::InvalidData,
format!(
"expected S_VERSION (0x10) or S_INCOMPAT (0x12) as first frame, got 0x{:02x}",
other
),
)),
}
}
/// Decodes the JSON payload of a `C_HELLO` frame into a `ClientHello`.
///
/// Unknown extra fields are tolerated, so newer clients may add fields.
///
/// # Errors
///
/// `InvalidData` if `payload` is not a valid `ClientHello` JSON document.
#[allow(dead_code)]
pub fn parse_client_hello(payload: &[u8]) -> io::Result<ClientHello> {
    match serde_json::from_slice::<ClientHello>(payload) {
        Ok(hello) => Ok(hello),
        Err(err) => Err(io::Error::new(
            io::ErrorKind::InvalidData,
            format!("malformed C_HELLO payload: {}", err),
        )),
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for framing, resize encoding, and the version handshake.
    use super::*;

    #[test]
    fn attach_request_round_trip() {
        let req = AttachRequest {
            cols: 120,
            rows: 40,
            mode: AttachMode::Shared,
        };
        let json = serde_json::to_vec(&req).unwrap();
        let decoded: AttachRequest = serde_json::from_slice(&json).unwrap();
        assert_eq!(decoded.cols, 120);
        assert_eq!(decoded.rows, 40);
        assert_eq!(decoded.mode, AttachMode::Shared);
    }

    #[test]
    fn attach_mode_default_is_steal() {
        assert_eq!(AttachMode::default(), AttachMode::Steal);
    }

    #[test]
    fn resize_encode_decode() {
        let encoded = encode_resize(200, 50);
        let (cols, rows) = decode_resize(&encoded).unwrap();
        assert_eq!(cols, 200);
        assert_eq!(rows, 50);
    }

    #[test]
    fn decode_resize_rejects_short_payload() {
        assert_eq!(decode_resize(&[]), None);
        assert_eq!(decode_resize(&[0, 80, 0]), None);
    }

    #[test]
    fn framed_message_round_trip() {
        let mut buf: Vec<u8> = Vec::new();
        write_msg(&mut buf, C_EVENT, b"hello").unwrap();
        let (tag, payload) = read_msg(&mut buf.as_slice()).unwrap();
        assert_eq!(tag, C_EVENT);
        assert_eq!(payload, b"hello");
    }

    #[test]
    fn empty_payload_round_trip() {
        let mut buf: Vec<u8> = Vec::new();
        write_msg(&mut buf, C_PING, b"").unwrap();
        // Header only: 1-byte tag + 4-byte length.
        assert_eq!(buf.len(), 5);
        let (tag, payload) = read_msg(&mut buf.as_slice()).unwrap();
        assert_eq!(tag, C_PING);
        assert!(payload.is_empty());
    }

    #[test]
    fn oversized_payload_is_rejected_on_write() {
        let big = vec![0u8; MAX_PAYLOAD + 1];
        let mut buf: Vec<u8> = Vec::new();
        let err = write_msg(&mut buf, C_EVENT, &big).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidInput);
        assert!(buf.is_empty(), "nothing may be written for a rejected frame");
    }

    #[test]
    fn oversized_length_header_is_rejected_on_read() {
        let mut frame = vec![S_OUTPUT];
        frame.extend_from_slice(&((MAX_PAYLOAD as u32) + 1).to_be_bytes());
        let err = read_msg(&mut frame.as_slice()).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
    }

    #[test]
    fn truncated_payload_is_unexpected_eof() {
        let mut frame = vec![S_OUTPUT];
        frame.extend_from_slice(&10u32.to_be_bytes());
        frame.extend_from_slice(b"abc");
        let err = read_msg(&mut frame.as_slice()).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::UnexpectedEof);
    }

    #[test]
    fn s_version_frame_round_trip() {
        let buf = server_hello();
        assert_eq!(buf[0], S_VERSION);
        let (tag, payload) = read_msg(&mut buf.as_slice()).unwrap();
        assert_eq!(tag, S_VERSION);
        let parsed: ServerHello = serde_json::from_slice(&payload).unwrap();
        assert_eq!(parsed.proto_major, PROTO_MAJOR);
        assert_eq!(parsed.proto_minor, PROTO_MINOR);
        assert!(parsed.build.starts_with("ezpn "));
    }

    #[test]
    fn server_hello_tag_is_in_reserved_range() {
        assert!((0x10..=0x1F).contains(&S_VERSION));
        assert!((0x10..=0x1F).contains(&C_HELLO));
        assert!((0x10..=0x1F).contains(&S_INCOMPAT));
    }

    #[test]
    fn c_hello_payload_round_trip() {
        let hello = ClientHello {
            proto_major: 1,
            proto_minor: 0,
            client_build: "ezpn 0.12.0 (rev test)".into(),
            supported_features: vec![
                "scrollback-v3".into(),
                "kitty-kbd-stack".into(),
                "osc-52-confirm".into(),
            ],
        };
        let json = serde_json::to_vec(&hello).unwrap();
        let parsed = parse_client_hello(&json).unwrap();
        assert_eq!(parsed, hello);
    }

    #[test]
    fn c_hello_tolerates_unknown_additive_fields() {
        let json = br#"{
            "proto_major": 1,
            "proto_minor": 7,
            "client_build": "ezpn 99.0.0 (rev future)",
            "supported_features": ["scrollback-v3"],
            "future_field": {"nested": true}
        }"#;
        let parsed = parse_client_hello(json).unwrap();
        assert_eq!(parsed.proto_minor, 7);
        assert_eq!(parsed.supported_features, vec!["scrollback-v3".to_string()]);
    }

    #[test]
    fn major_mismatch_emits_incompat() {
        let client = ClientHello {
            proto_major: PROTO_MAJOR + 1,
            proto_minor: 0,
            client_build: "ezpn 1.0.0".into(),
            supported_features: vec![],
        };
        let frame = incompat_for_major_mismatch(&client, "myproj");
        assert_eq!(frame[0], S_INCOMPAT);
        let (tag, payload) = read_msg(&mut frame.as_slice()).unwrap();
        assert_eq!(tag, S_INCOMPAT);
        let notice: IncompatNotice = serde_json::from_slice(&payload).unwrap();
        assert_eq!(
            notice.server_proto,
            format!("{}.{}", PROTO_MAJOR, PROTO_MINOR)
        );
        assert_eq!(notice.client_proto, format!("{}.0", PROTO_MAJOR + 1));
        assert!(notice.message.contains("ezpn kill myproj"));
        assert!(notice.message.contains("cannot attach"));
    }

    #[test]
    fn minor_mismatch_is_tolerated_by_client_handshake() {
        let server = ServerHello {
            proto_major: PROTO_MAJOR,
            proto_minor: PROTO_MINOR + 1,
            build: format!("ezpn {} (rev future)", env!("CARGO_PKG_VERSION")),
        };
        let mut server_to_client: Vec<u8> = Vec::new();
        let json = serde_json::to_vec(&server).unwrap();
        write_msg(&mut server_to_client, S_VERSION, &json).unwrap();
        let mut client_to_server: Vec<u8> = Vec::new();
        let outcome =
            client_handshake(&mut server_to_client.as_slice(), &mut client_to_server).unwrap();
        match outcome {
            HandshakeOutcome::Ok(parsed) => {
                assert_eq!(parsed.proto_minor, PROTO_MINOR + 1);
            }
            HandshakeOutcome::Incompat(n) => {
                panic!("minor mismatch must NOT be reported as incompat: {:?}", n);
            }
        }
        let (tag, payload) = read_msg(&mut client_to_server.as_slice()).unwrap();
        assert_eq!(tag, C_HELLO);
        let hello = parse_client_hello(&payload).unwrap();
        assert_eq!(hello.proto_major, PROTO_MAJOR);
        assert_eq!(hello.proto_minor, PROTO_MINOR);
        assert!(hello
            .supported_features
            .contains(&"scrollback-v3".to_string()));
    }

    #[test]
    fn major_mismatch_short_circuits_client_handshake() {
        let server = ServerHello {
            proto_major: PROTO_MAJOR + 1,
            proto_minor: 0,
            build: "ezpn future".into(),
        };
        let mut server_to_client: Vec<u8> = Vec::new();
        let json = serde_json::to_vec(&server).unwrap();
        write_msg(&mut server_to_client, S_VERSION, &json).unwrap();
        let mut client_to_server: Vec<u8> = Vec::new();
        let outcome =
            client_handshake(&mut server_to_client.as_slice(), &mut client_to_server).unwrap();
        assert!(
            client_to_server.is_empty(),
            "client must not send C_HELLO on major mismatch"
        );
        match outcome {
            HandshakeOutcome::Incompat(notice) => {
                assert_eq!(notice.server_proto, format!("{}.0", PROTO_MAJOR + 1));
                assert_eq!(
                    notice.client_proto,
                    format!("{}.{}", PROTO_MAJOR, PROTO_MINOR)
                );
            }
            HandshakeOutcome::Ok(_) => panic!("expected Incompat for major mismatch"),
        }
    }

    #[test]
    fn server_pushed_incompat_is_surfaced() {
        let frame = incompat_for_legacy_client("demo");
        let mut client_to_server: Vec<u8> = Vec::new();
        let outcome = client_handshake(&mut frame.as_slice(), &mut client_to_server).unwrap();
        match outcome {
            HandshakeOutcome::Incompat(notice) => {
                assert!(notice.message.contains("legacy client detected"));
                assert!(notice.message.contains("ezpn kill demo"));
            }
            HandshakeOutcome::Ok(_) => panic!("expected Incompat from server-pushed S_INCOMPAT"),
        }
    }

    #[test]
    fn legacy_first_byte_classification() {
        assert_eq!(classify_first_byte(C_EVENT), FirstByteKind::Tag);
        assert_eq!(classify_first_byte(C_RESIZE), FirstByteKind::Tag);
        assert_eq!(classify_first_byte(C_ATTACH), FirstByteKind::Tag);
        assert_eq!(classify_first_byte(S_VERSION), FirstByteKind::Tag);
        assert_eq!(classify_first_byte(C_HELLO), FirstByteKind::Tag);
        assert_eq!(classify_first_byte(b'{'), FirstByteKind::LegacyJson);
        assert_eq!(classify_first_byte(b'['), FirstByteKind::LegacyJson);
        assert_eq!(classify_first_byte(b'A'), FirstByteKind::Unknown);
        assert_eq!(classify_first_byte(0xFF), FirstByteKind::Unknown);
    }

    #[test]
    fn build_string_with_and_without_rev() {
        let with_rev = build_string(Some("abc1234"));
        assert!(with_rev.contains("rev abc1234"));
        assert!(with_rev.contains(env!("CARGO_PKG_VERSION")));
        let no_rev = build_string(None);
        assert!(no_rev.contains("rev unknown"));
        let empty_rev = build_string(Some(""));
        assert!(empty_rev.contains("rev unknown"));
    }

    #[test]
    fn unknown_first_frame_is_an_error() {
        let mut bad: Vec<u8> = Vec::new();
        write_msg(&mut bad, S_OUTPUT, b"raw bytes").unwrap();
        let mut sink: Vec<u8> = Vec::new();
        let err = client_handshake(&mut bad.as_slice(), &mut sink).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
    }
}