pub mod codec;
use std::sync::{Arc, Mutex as StdMutex};
use async_trait::async_trait;
use serde::Serialize;
use serde_json::Value;
use crate::auth::transports::{Transport, TransportCallback};
use crate::auth::types::AuthMessage;
use crate::primitives::PublicKey;
use crate::Result;
use codec::{EngineIoPacket, SocketIoPacket};
/// Outbound half of a Socket.IO connection.
///
/// Implementors push packets onto the underlying connection and report
/// failures as a human-readable `String`.
pub trait SocketIoSink: Send + Sync {
    /// Sends a single Socket.IO packet.
    fn send_socketio(&self, pkt: &SocketIoPacket) -> std::result::Result<(), String>;

    /// Sends a raw Engine.IO packet.
    ///
    /// The default implementation only understands [`EngineIoPacket::Message`].
    /// It decodes the payload back into a [`SocketIoPacket`] and hands it to
    /// [`SocketIoSink::send_socketio`]. Any other packet kind yields `Err`. This
    /// includes the `Pong` replies that [`run_dispatch`] sends through this
    /// method, so sinks that must answer pings have to override it.
    fn send_engineio(&self, pkt: &EngineIoPacket) -> std::result::Result<(), String> {
        let payload = match pkt {
            EngineIoPacket::Message(payload) => payload,
            other => {
                return Err(format!(
                    "send_engineio default impl cannot send {other:?}; override SocketIoSink::send_engineio"
                ));
            }
        };
        let sio = SocketIoPacket::decode(payload)
            .map_err(|e| format!("send_engineio default: {e}"))?;
        self.send_socketio(&sio)
    }
}
/// Inbound half of a Socket.IO connection: yields Engine.IO packets one at a time.
///
/// [`run_dispatch`] treats any `Err` returned by
/// [`SocketIoFrameSource::recv_engineio`] as the end of the stream and stops.
#[async_trait]
pub trait SocketIoFrameSource: Send {
    /// Waits for the next Engine.IO packet from the connection.
    async fn recv_engineio(&mut self) -> std::result::Result<EngineIoPacket, String>;
}
/// [`Transport`] implementation that carries [`AuthMessage`]s as Socket.IO
/// `authMessage` events.
///
/// Outbound messages are written through the sink `S`. Inbound messages are
/// delivered by [`run_dispatch`] to whatever callback sits in the shared slot.
/// [`SocketIoTransport::callback_handle`] hands out a clone of that `Arc`,
/// which is the type `run_dispatch` expects.
#[derive(Clone)]
pub struct SocketIoTransport<S: SocketIoSink> {
    // Outbound packet sink.
    sink: S,
    // Currently registered inbound handler; `None` when no callback is set.
    callback: Arc<StdMutex<Option<Box<TransportCallback>>>>,
}
impl<S> std::fmt::Debug for SocketIoTransport<S>
where
    S: SocketIoSink + std::fmt::Debug,
{
    /// Shows the sink only; the callback slot is omitted and rendered as `..`.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = formatter.debug_struct("SocketIoTransport");
        out.field("sink", &self.sink);
        out.finish_non_exhaustive()
    }
}
impl<S> SocketIoTransport<S>
where
    S: SocketIoSink + Clone,
{
    /// Creates a transport that writes through `sink` and has no callback
    /// registered yet.
    pub fn new(sink: S) -> Self {
        let callback = Arc::new(StdMutex::new(None));
        Self { sink, callback }
    }

    /// Returns another handle to the shared callback slot, suitable for
    /// passing to [`run_dispatch`].
    pub fn callback_handle(&self) -> Arc<StdMutex<Option<Box<TransportCallback>>>> {
        Arc::clone(&self.callback)
    }

    /// Returns a clone of the outbound sink.
    pub fn sink(&self) -> S {
        S::clone(&self.sink)
    }
}
#[async_trait]
impl<S: SocketIoSink + Clone + 'static> Transport for SocketIoTransport<S> {
    /// Serializes `message` to JSON and emits it as the Socket.IO event
    /// `["authMessage", <message>]` on the default namespace (`/`), without an
    /// ack id.
    ///
    /// # Errors
    /// Returns [`crate::Error::AuthError`] if serialization fails or if the sink
    /// rejects the packet.
    async fn send(&self, message: &AuthMessage) -> Result<()> {
        let json = serde_json::to_value(message).map_err(|e| {
            crate::Error::AuthError(format!("SocketIoTransport::send: serialize: {e}"))
        })?;
        let pkt = SocketIoPacket::Event {
            nsp: "/".to_string(),
            ack_id: None,
            data: vec![Value::String("authMessage".to_string()), json],
        };
        self.sink
            .send_socketio(&pkt)
            .map_err(|e| crate::Error::AuthError(format!("SocketIoTransport::send: ws: {e}")))
    }

    /// Installs `callback` as the inbound-message handler, replacing any
    /// previous one.
    fn set_callback(&self, callback: Box<TransportCallback>) {
        // The lock can be poisoned if a callback panics while `run_dispatch`
        // invokes it under the lock. The slot only ever holds a complete
        // `Option`, so recovering the guard is sound. The previous approach,
        // silently skipping on poison, left the transport unable to ever
        // register a handler again.
        *self
            .callback
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner) = Some(callback);
    }

    /// Removes the inbound-message handler, if any.
    fn clear_callback(&self) {
        // See `set_callback` for why poisoning is recovered from here.
        *self
            .callback
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner) = None;
    }
}
pub async fn run_dispatch<F, S>(
mut frames: F,
sink: S,
callback: Arc<StdMutex<Option<Box<TransportCallback>>>>,
) where
F: SocketIoFrameSource,
S: SocketIoSink,
{
loop {
let frame = match frames.recv_engineio().await {
Ok(f) => f,
Err(_) => break, };
match frame {
EngineIoPacket::Ping(payload) => {
let _ = sink.send_engineio(&EngineIoPacket::Pong(payload));
}
EngineIoPacket::Message(payload) => {
let sio = match SocketIoPacket::decode(&payload) {
Ok(p) => p,
Err(_) => continue, };
if let SocketIoPacket::Event { data, .. } = sio {
if data.len() >= 2 && data[0].as_str() == Some("authMessage") {
let auth_msg: AuthMessage = match serde_json::from_value(data[1].clone()) {
Ok(m) => m,
Err(_) => continue,
};
let fut_opt = {
match callback.lock() {
Ok(guard) => guard.as_ref().map(|cb| cb(auth_msg)),
Err(_) => None, }
};
if let Some(fut) = fut_opt {
let _ = fut.await;
}
}
}
}
_ => { }
}
}
}
/// An application-level event received from a peer as a general message.
///
/// [`install_app_event_listener`] builds these from payloads decoded by
/// [`parse_app_event_payload`].
#[derive(Debug, Clone, PartialEq)]
pub struct AppEvent {
    /// Key reported as the sender by `Peer::listen_for_general_messages`.
    pub sender: PublicKey,
    /// The envelope's `eventName`; `""` when absent or not a string, as populated
    /// by `install_app_event_listener`.
    pub event_name: String,
    /// The envelope's `data`; `Value::Null` when absent, as populated by
    /// `install_app_event_listener`.
    pub data: Value,
}
/// Subscribes to `peer`'s general messages and forwards each one to an
/// unbounded channel as an [`AppEvent`].
///
/// Each payload is decoded with [`parse_app_event_payload`]. Malformed payloads
/// are still forwarded, with an empty event name and `Value::Null` data.
///
/// Returns the channel's receiver and the id returned by
/// `Peer::listen_for_general_messages`. Presumably that id can be used to
/// unregister the listener; confirm against `Peer`. Once the receiver is
/// dropped, further events are discarded silently.
pub async fn install_app_event_listener<W, T>(
    peer: &crate::auth::Peer<W, T>,
) -> (futures::channel::mpsc::UnboundedReceiver<AppEvent>, u32)
where
    W: crate::wallet::WalletInterface + 'static,
    T: Transport + 'static,
{
    let (event_tx, event_rx) = futures::channel::mpsc::unbounded::<AppEvent>();
    let listener_id = peer
        .listen_for_general_messages(move |sender, payload| {
            let event_tx = event_tx.clone();
            Box::pin(async move {
                let (event_name, data) = parse_app_event_payload(&payload);
                let event = AppEvent {
                    sender,
                    event_name,
                    data,
                };
                // A closed receiver just means nobody is listening any more.
                let _ = event_tx.unbounded_send(event);
                Ok(())
            })
        })
        .await;
    (event_rx, listener_id)
}
/// Decodes a `{"eventName": …, "data": …}` JSON envelope.
///
/// Returns `(event_name, data)`:
/// * `event_name` is `""` when the field is missing or not a JSON string.
/// * `data` is `Value::Null` when the field is missing.
///
/// If `payload` is not valid JSON, the result is `("", Value::Null)`.
pub fn parse_app_event_payload(payload: &[u8]) -> (String, Value) {
    serde_json::from_slice::<Value>(payload)
        .map(|envelope| {
            let name = envelope
                .get("eventName")
                .and_then(Value::as_str)
                .map(str::to_owned)
                .unwrap_or_default();
            let body = envelope.get("data").cloned().unwrap_or(Value::Null);
            (name, body)
        })
        .unwrap_or_else(|_| (String::new(), Value::Null))
}
/// Serializes the compact JSON envelope `{"eventName":…,"data":…}`.
///
/// Keys are emitted in that order, which is the declaration order of the
/// fields below. An empty vector is returned should serialization ever fail.
pub fn build_envelope_payload(event_name: &str, data: &Value) -> Vec<u8> {
    // Borrowed view of the wire envelope. `rename_all` maps `event_name` to
    // `eventName` and leaves `data` unchanged.
    #[derive(Serialize)]
    #[serde(rename_all = "camelCase")]
    struct WireEnvelope<'e> {
        event_name: &'e str,
        data: &'e Value,
    }

    serde_json::to_vec(&WireEnvelope { event_name, data }).unwrap_or_default()
}
// Unit tests for the Socket.IO transport, the inbound dispatcher, and the
// app-event envelope helpers.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::auth::types::MessageType;
    use crate::primitives::PrivateKey;
    use serde_json::json;

    // Test sink that records the most recent Socket.IO packet as a full
    // Engine.IO `Message` frame string. It relies on the default
    // `send_engineio`.
    #[derive(Clone, Default)]
    struct CapturingSink {
        last: Arc<StdMutex<Option<String>>>,
    }
    impl SocketIoSink for CapturingSink {
        fn send_socketio(&self, pkt: &SocketIoPacket) -> std::result::Result<(), String> {
            let frame = EngineIoPacket::Message(pkt.encode()).encode();
            *self.last.lock().unwrap() = Some(frame);
            Ok(())
        }
    }

    // `send` must emit `42["authMessage", <msg>]` on namespace "/" with no ack id,
    // and the embedded message must round-trip through serde.
    #[tokio::test]
    async fn send_emits_authmessage_event_on_default_namespace() {
        let sink = CapturingSink::default();
        let transport = SocketIoTransport::new(sink.clone());
        let key = PrivateKey::from_hex(
            "0000000000000000000000000000000000000000000000000000000000000001",
        )
        .unwrap()
        .public_key();
        let msg = AuthMessage::new(MessageType::InitialRequest, key.clone());
        transport.send(&msg).await.unwrap();
        let frame = sink.last.lock().unwrap().clone().unwrap();
        assert!(
            frame.starts_with("42[\"authMessage\","),
            "unexpected frame prefix: {frame}"
        );
        let eio = EngineIoPacket::decode(&frame).unwrap();
        let payload = match eio {
            EngineIoPacket::Message(p) => p,
            other => panic!("expected Message, got {other:?}"),
        };
        let sio = SocketIoPacket::decode(&payload).unwrap();
        match sio {
            SocketIoPacket::Event { nsp, ack_id, data } => {
                assert_eq!(nsp, "/");
                assert_eq!(ack_id, None);
                assert_eq!(data[0], json!("authMessage"));
                let decoded: AuthMessage = serde_json::from_value(data[1].clone()).unwrap();
                assert_eq!(decoded.message_type, MessageType::InitialRequest);
                assert_eq!(decoded.identity_key.to_hex(), key.to_hex());
            }
            other => panic!("expected Event, got {other:?}"),
        }
    }

    // The event name (first array element) is the literal string "authMessage"
    // regardless of the message type being sent.
    #[tokio::test]
    async fn send_authmessage_event_array_head_is_authmessage_literal() {
        let sink = CapturingSink::default();
        let transport = SocketIoTransport::new(sink.clone());
        let key = PrivateKey::random().public_key();
        let msg = AuthMessage::new(MessageType::General, key);
        transport.send(&msg).await.unwrap();
        let frame = sink.last.lock().unwrap().clone().unwrap();
        let payload = match EngineIoPacket::decode(&frame).unwrap() {
            EngineIoPacket::Message(p) => p,
            other => panic!("expected Message, got {other:?}"),
        };
        match SocketIoPacket::decode(&payload).unwrap() {
            SocketIoPacket::Event { data, .. } => {
                assert_eq!(data[0].as_str(), Some("authMessage"));
            }
            other => panic!("expected Event, got {other:?}"),
        }
    }

    // --- parse_app_event_payload: envelope decoding, including fallbacks ---

    #[test]
    fn parse_app_event_decodes_joinroom_envelope() {
        let payload = br#"{"eventName":"joinRoom","data":"02abc...xyz-payment_inbox"}"#;
        let (event_name, data) = parse_app_event_payload(payload);
        assert_eq!(event_name, "joinRoom");
        assert_eq!(data, json!("02abc...xyz-payment_inbox"));
    }
    #[test]
    fn parse_app_event_decodes_sendmessage_envelope() {
        let payload = br#"{"eventName":"sendMessage","data":{"roomId":"02abc-test","message":{"messageId":"h34","body":"hello"}}}"#;
        let (event_name, data) = parse_app_event_payload(payload);
        assert_eq!(event_name, "sendMessage");
        assert_eq!(
            data,
            json!({"roomId":"02abc-test","message":{"messageId":"h34","body":"hello"}})
        );
    }
    // Event names carrying a suffix are passed through verbatim.
    #[test]
    fn parse_app_event_decodes_sendmessageack_with_room_suffix() {
        let payload = br#"{"eventName":"sendMessageAck-02abc-h34-test","data":{"status":"success","messageId":"h34"}}"#;
        let (event_name, data) = parse_app_event_payload(payload);
        assert_eq!(event_name, "sendMessageAck-02abc-h34-test");
        assert_eq!(data["status"], json!("success"));
        assert_eq!(data["messageId"], json!("h34"));
    }
    #[test]
    fn parse_app_event_handles_empty_data() {
        let payload = br#"{"eventName":"authenticated","data":{}}"#;
        let (event_name, data) = parse_app_event_payload(payload);
        assert_eq!(event_name, "authenticated");
        assert_eq!(data, json!({}));
    }
    // Invalid JSON falls back to ("", Null) instead of erroring.
    #[test]
    fn parse_app_event_returns_empty_on_malformed_json() {
        let payload = b"this is not json";
        let (event_name, data) = parse_app_event_payload(payload);
        assert_eq!(event_name, "");
        assert_eq!(data, Value::Null);
    }
    #[test]
    fn parse_app_event_returns_empty_on_missing_fields() {
        let payload = br#"{"foo":"bar"}"#;
        let (event_name, data) = parse_app_event_payload(payload);
        assert_eq!(event_name, "");
        assert_eq!(data, Value::Null);
    }
    #[test]
    fn parse_app_event_handles_event_name_only() {
        let payload = br#"{"eventName":"someEvent"}"#;
        let (event_name, data) = parse_app_event_payload(payload);
        assert_eq!(event_name, "someEvent");
        assert_eq!(data, Value::Null);
    }
    // Byte string labelled as a TypeScript-emitted vector (per the test name).
    #[test]
    fn parse_app_event_byte_exact_against_ts_emit_vector() {
        let canonical_ts_bytes: &[u8] = b"{\"eventName\":\"sendMessage\",\"data\":{\"roomId\":\"abc-test\",\"message\":{\"messageId\":\"v1\",\"body\":\"hi\"}}}";
        let (event_name, data) = parse_app_event_payload(canonical_ts_bytes);
        assert_eq!(event_name, "sendMessage");
        assert_eq!(data["roomId"], json!("abc-test"));
        assert_eq!(data["message"]["messageId"], json!("v1"));
        assert_eq!(data["message"]["body"], json!("hi"));
    }

    // --- build_envelope_payload: exact byte layout (key order eventName, data) ---

    #[test]
    fn build_envelope_payload_joinroom_byte_exact() {
        let bytes = build_envelope_payload("joinRoom", &json!("02abc-test_inbox"));
        assert_eq!(
            bytes.as_slice(),
            b"{\"eventName\":\"joinRoom\",\"data\":\"02abc-test_inbox\"}".as_slice(),
        );
    }
    // The expected `data` bytes come from serde_json itself, so this pins the
    // envelope framing rather than the inner object's key order.
    #[test]
    fn build_envelope_payload_sendmessage_byte_exact() {
        let data = json!({"roomId": "abc-test", "message": {"messageId": "v1", "body": "hi"}});
        let bytes = build_envelope_payload("sendMessage", &data);
        let mut expected = b"{\"eventName\":\"sendMessage\",\"data\":".to_vec();
        expected.extend_from_slice(&serde_json::to_vec(&data).unwrap());
        expected.push(b'}');
        assert_eq!(bytes, expected);
        assert!(bytes.starts_with(b"{\"eventName\":\"sendMessage\",\"data\":"));
    }
    #[test]
    fn build_envelope_payload_leaveroom_byte_exact() {
        let bytes = build_envelope_payload("leaveRoom", &json!("02abc-test_inbox"));
        assert_eq!(
            bytes.as_slice(),
            b"{\"eventName\":\"leaveRoom\",\"data\":\"02abc-test_inbox\"}".as_slice(),
        );
    }
    #[test]
    fn build_envelope_payload_empty_data_object() {
        let bytes = build_envelope_payload("authenticated", &json!({}));
        assert_eq!(
            bytes.as_slice(),
            b"{\"eventName\":\"authenticated\",\"data\":{}}".as_slice(),
        );
    }
    // build -> parse must be the identity on (event_name, data).
    #[test]
    fn build_envelope_payload_round_trips_through_parser() {
        let cases: Vec<(&str, Value)> = vec![
            ("joinRoom", json!("02abc-room")),
            (
                "sendMessage",
                json!({"roomId": "02abc-room", "message": {"messageId": "m1", "body": "hi"}}),
            ),
            ("leaveRoom", json!("02abc-room")),
            ("authenticated", json!({})),
            ("sendMessageAck-02abc-room", json!({"status": "success"})),
        ];
        for (name, data) in cases {
            let bytes = build_envelope_payload(name, &data);
            let (decoded_name, decoded_data) = parse_app_event_payload(&bytes);
            assert_eq!(decoded_name, name, "event_name round-trip for {name}");
            assert_eq!(decoded_data, data, "data round-trip for {name}");
        }
    }

    // --- run_dispatch ---

    // A single authMessage event must invoke the registered callback exactly
    // once. The source then reports "closed", which ends the loop.
    #[tokio::test]
    async fn dispatch_routes_authmessage_event_to_callback() {
        use std::sync::atomic::{AtomicUsize, Ordering};
        struct OneShotFrames {
            frame: Option<EngineIoPacket>,
        }
        #[async_trait]
        impl SocketIoFrameSource for OneShotFrames {
            async fn recv_engineio(&mut self) -> std::result::Result<EngineIoPacket, String> {
                self.frame.take().ok_or_else(|| "closed".to_string())
            }
        }
        let key = PrivateKey::random().public_key();
        let msg = AuthMessage::new(MessageType::General, key);
        let json = serde_json::to_value(&msg).unwrap();
        let sio = SocketIoPacket::Event {
            nsp: "/".into(),
            ack_id: None,
            data: vec![json!("authMessage"), json],
        };
        let frame = EngineIoPacket::Message(sio.encode());
        let count = Arc::new(AtomicUsize::new(0));
        let count_cb = count.clone();
        // Explicit cast: the boxed future must match the future type that
        // `TransportCallback` returns.
        let callback: Arc<StdMutex<Option<Box<TransportCallback>>>> =
            Arc::new(StdMutex::new(Some(Box::new(move |_m: AuthMessage| {
                let count_cb = count_cb.clone();
                Box::pin(async move {
                    count_cb.fetch_add(1, Ordering::SeqCst);
                    Ok(())
                })
                    as std::pin::Pin<Box<dyn std::future::Future<Output = Result<()>> + Send>>
            }))));
        let sink = CapturingSink::default();
        run_dispatch(OneShotFrames { frame: Some(frame) }, sink, callback).await;
        assert_eq!(count.load(Ordering::SeqCst), 1, "callback should fire once");
    }

    // A bare ping must be answered with exactly one pong frame, "3". The sink
    // overrides `send_engineio` because the default rejects non-Message packets.
    #[tokio::test]
    async fn dispatch_replies_pong_to_ping() {
        struct PingThenClose {
            frame: Option<EngineIoPacket>,
        }
        #[async_trait]
        impl SocketIoFrameSource for PingThenClose {
            async fn recv_engineio(&mut self) -> std::result::Result<EngineIoPacket, String> {
                self.frame.take().ok_or_else(|| "closed".to_string())
            }
        }
        #[derive(Clone, Default)]
        struct PongSink {
            sent: Arc<StdMutex<Vec<String>>>,
        }
        impl SocketIoSink for PongSink {
            fn send_socketio(&self, pkt: &SocketIoPacket) -> std::result::Result<(), String> {
                self.sent
                    .lock()
                    .unwrap()
                    .push(EngineIoPacket::Message(pkt.encode()).encode());
                Ok(())
            }
            fn send_engineio(&self, pkt: &EngineIoPacket) -> std::result::Result<(), String> {
                self.sent.lock().unwrap().push(pkt.encode());
                Ok(())
            }
        }
        let sink = PongSink::default();
        let callback: Arc<StdMutex<Option<Box<TransportCallback>>>> = Arc::new(StdMutex::new(None));
        run_dispatch(
            PingThenClose {
                frame: Some(EngineIoPacket::Ping(String::new())),
            },
            sink.clone(),
            callback,
        )
        .await;
        let sent = sink.sent.lock().unwrap();
        assert_eq!(sent.len(), 1, "exactly one pong should be sent");
        assert_eq!(sent[0], "3", "pong frame for a bare ping is `3`");
    }
}