use serde::{Deserialize, Serialize};
/// Top-level envelope for the JSON messages defined in this file.
///
/// Serde encodes this enum as internally tagged: the variant is named by a
/// `kind` field (`"command"`, `"response"` or `"control"`), and the wrapped
/// enum's own tag and fields are written into the same JSON object, e.g.
/// `{"kind":"command","type":"socks","tunnel_id":1,"port":1080}`.
///
/// NOTE(review): the type name suggests these travel as WebSocket text
/// frames — confirm against the transport code, which is not visible here.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "kind", rename_all = "snake_case")]
pub enum WsTextMessage {
/// Wraps a [`Command`]; tagged `"kind": "command"`.
Command(Command),
/// Wraps a [`CommandResponse`]; tagged `"kind": "response"`.
Response(CommandResponse),
/// Wraps a [`ControlMessage`]; tagged `"kind": "control"`.
Control(ControlMessage),
}
/// Commands carried inside [`WsTextMessage::Command`].
///
/// Internally tagged by a `type` field whose value is the snake_case variant
/// name: `"socks"`, `"reverse_tunnel"`, `"ping"` or `"stop_tunnel"`.
///
/// NOTE(review): the per-variant descriptions below are inferred from the
/// names only; the code that acts on these commands is not visible here.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum Command {
/// Carries a tunnel id and a port; presumably requests a SOCKS proxy on
/// `port` — TODO confirm which side listens.
Socks { tunnel_id: u32, port: u16 },
/// Carries a tunnel id, a remote port and a local target string;
/// presumably requests a reverse tunnel between the two.
ReverseTunnel {
/// Identifier of the tunnel this command refers to.
tunnel_id: u32,
/// Presumably the port to expose on the remote side — TODO confirm.
remote_port: u16,
/// Presumably the local address traffic is forwarded to; the expected
/// format (e.g. `host:port`) is not established in this file.
local_target: String,
},
/// Carries a sequence number; [`CommandResponse::Pong`] has a matching
/// `seq` field, presumably echoing this one.
Ping { seq: u64 },
/// Names a tunnel by id; presumably asks for it to be shut down.
StopTunnel { tunnel_id: u32 },
}
/// Responses carried inside [`WsTextMessage::Response`].
///
/// Internally tagged by a `status` field whose value is the snake_case
/// variant name: `"ok"`, `"socks_ready"`, `"reverse_tunnel_ready"`,
/// `"error"` or `"pong"`.
///
/// NOTE(review): which [`Command`] each variant answers is inferred from the
/// names; the code that produces these responses is not visible here.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "status", rename_all = "snake_case")]
pub enum CommandResponse {
/// Generic success with an optional tunnel id and optional message.
Ok {
/// Tunnel the response refers to, if any.
tunnel_id: Option<u32>,
/// Optional free-form text accompanying the success.
message: Option<String>,
},
/// Presumably signals that the SOCKS tunnel with this id is ready.
SocksReady {
/// Identifier of the tunnel this response refers to.
tunnel_id: u32,
},
/// Presumably signals that the reverse tunnel with this id is ready.
ReverseTunnelReady {
/// Identifier of the tunnel this response refers to.
tunnel_id: u32,
},
/// Failure report; unlike `Ok`, the message is mandatory.
Error {
/// Tunnel the error refers to, if any.
tunnel_id: Option<u32>,
/// Text describing the failure.
message: String,
},
/// Carries a sequence number; presumably echoes the `seq` of a
/// [`Command::Ping`].
Pong {
/// Sequence number of this pong.
seq: u64,
},
}
/// Channel control messages carried inside [`WsTextMessage::Control`].
///
/// Internally tagged by a `type` field whose value is the snake_case variant
/// name: `"channel_open"`, `"channel_ready"` or `"channel_close"`.
///
/// `channel_id` is a `u32`, the same type [`frame_tunnel_data`] writes as the
/// frame prefix. NOTE(review): presumably these ids identify the same
/// channels as the binary frames — confirm against the code that uses them.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ControlMessage {
/// Announces a channel, naming the tunnel it belongs to and an optional
/// target.
ChannelOpen {
/// Identifier of the channel being opened.
channel_id: u32,
/// Identifier of the tunnel the channel is associated with.
tunnel_id: u32,
/// Optional target string; the tests in this file use
/// `"example.com:443"`, so presumably a `host:port` destination.
target: Option<String>,
},
/// Presumably acknowledges that the channel with this id is usable.
ChannelReady { channel_id: u32 },
/// Presumably signals that the channel with this id is finished.
ChannelClose { channel_id: u32 },
}
/// Builds a tunnel data frame: the channel id as four big-endian bytes,
/// followed by `payload` copied verbatim.
///
/// The returned buffer is always `4 + payload.len()` bytes long, so an empty
/// payload yields a header-only frame. [`parse_tunnel_data`] performs the
/// inverse split.
pub fn frame_tunnel_data(channel_id: u32, payload: &[u8]) -> Vec<u8> {
    let header = channel_id.to_be_bytes();
    // `concat` sizes the output for both parts up front, then copies each in order.
    [&header[..], payload].concat()
}
/// Splits a tunnel data frame into its channel id and payload.
///
/// The first four bytes are decoded as a big-endian `u32` channel id; the
/// remaining bytes are returned as a borrowed (possibly empty) payload slice.
/// Returns `None` when `data` is shorter than the four-byte header.
pub fn parse_tunnel_data(data: &[u8]) -> Option<(u32, &[u8])> {
    match data {
        // At least four header bytes are present; everything after them is payload.
        [b0, b1, b2, b3, payload @ ..] => {
            Some((u32::from_be_bytes([*b0, *b1, *b2, *b3]), payload))
        }
        _ => None,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes `msg` as JSON and decodes it straight back.
    fn json_roundtrip(msg: &WsTextMessage) -> WsTextMessage {
        let encoded = serde_json::to_string(msg).unwrap();
        serde_json::from_str(&encoded).unwrap()
    }

    /// A `Socks` command keeps its tunnel id and port across a JSON round trip.
    #[test]
    fn test_command_serde_roundtrip() {
        let sent = WsTextMessage::Command(Command::Socks {
            tunnel_id: 1,
            port: 1080,
        });
        if let WsTextMessage::Command(Command::Socks { tunnel_id, port }) = json_roundtrip(&sent) {
            assert_eq!(tunnel_id, 1);
            assert_eq!(port, 1080);
        } else {
            panic!("unexpected variant");
        }
    }

    /// An `Ok` response keeps its tunnel id across a JSON round trip.
    #[test]
    fn test_response_serde_roundtrip() {
        let sent = WsTextMessage::Response(CommandResponse::Ok {
            tunnel_id: Some(1),
            message: None,
        });
        if let WsTextMessage::Response(CommandResponse::Ok { tunnel_id, .. }) =
            json_roundtrip(&sent)
        {
            assert_eq!(tunnel_id, Some(1));
        } else {
            panic!("unexpected variant");
        }
    }

    /// A `ChannelOpen` control message keeps all three fields across a JSON round trip.
    #[test]
    fn test_control_serde_roundtrip() {
        let sent = WsTextMessage::Control(ControlMessage::ChannelOpen {
            channel_id: 3,
            tunnel_id: 1,
            target: Some("example.com:443".into()),
        });
        if let WsTextMessage::Control(ControlMessage::ChannelOpen {
            channel_id,
            tunnel_id,
            target,
        }) = json_roundtrip(&sent)
        {
            assert_eq!(channel_id, 3);
            assert_eq!(tunnel_id, 1);
            assert_eq!(target.as_deref(), Some("example.com:443"));
        } else {
            panic!("unexpected variant");
        }
    }

    /// Framing and then parsing yields the original channel id and payload.
    #[test]
    fn test_frame_parse_roundtrip() {
        let body = b"hello world";
        let frame = frame_tunnel_data(42, body);
        let (channel_id, payload) = parse_tunnel_data(&frame).unwrap();
        assert_eq!(channel_id, 42);
        assert_eq!(payload, &body[..]);
    }

    /// Inputs shorter than the four-byte header are rejected.
    #[test]
    fn test_parse_tunnel_data_too_short() {
        let truncated: [&[u8]; 2] = [&[0, 1, 2], &[]];
        for &input in &truncated {
            assert!(parse_tunnel_data(input).is_none());
        }
    }
}