use std::sync::Arc;
use async_trait::async_trait;
use rust_socketio::asynchronous::Client;
use rust_socketio::Payload;
use tokio::sync::{mpsc, Mutex};
use bsv::auth::error::AuthError;
use bsv::auth::transports::Transport;
use bsv::auth::types::AuthMessage;
/// Serializes a websocket event envelope of the form
/// `{"eventName": <event_name>, "data": <data>}` into JSON bytes.
///
/// If serialization fails, an empty byte vector is returned instead of an error.
// NOTE(review): `dead_code` is allowed presumably because only tests call this
// in the current build — confirm before removing the attribute.
#[allow(dead_code)]
pub(crate) fn encode_ws_event(event_name: &str, data: serde_json::Value) -> Vec<u8> {
    let envelope = serde_json::json!({
        "eventName": event_name,
        "data": data
    });
    match serde_json::to_vec(&envelope) {
        Ok(bytes) => bytes,
        Err(_) => Vec::new(),
    }
}
/// Parses a JSON websocket envelope of the form
/// `{"eventName": <string>, "data": <any>}`.
///
/// Returns `None` when `payload` is not valid JSON, is not a JSON object, or
/// lacks a string-valued `eventName`. A missing `data` field decodes as
/// `serde_json::Value::Null`.
// NOTE(review): `dead_code` is allowed presumably because only tests call this
// in the current build — confirm before removing the attribute.
#[allow(dead_code)]
pub(crate) fn decode_ws_event(payload: &[u8]) -> Option<(String, serde_json::Value)> {
    // Only a JSON object can carry an `eventName` key; anything else is rejected.
    let mut fields = match serde_json::from_slice::<serde_json::Value>(payload) {
        Ok(serde_json::Value::Object(map)) => map,
        _ => return None,
    };
    let event_name = match fields.remove("eventName") {
        Some(serde_json::Value::String(name)) => name,
        _ => return None,
    };
    // Move `data` out of the parsed object instead of deep-cloning it; the
    // parsed map is discarded afterwards anyway.
    let data = fields.remove("data").unwrap_or(serde_json::Value::Null);
    Some((event_name, data))
}
/// Extracts an [`AuthMessage`] from the first JSON value of a Socket.IO
/// `Payload::Text` payload.
///
/// Returns `None` for any non-text payload, for an empty text payload, or when
/// the first value does not deserialize as an `AuthMessage`. Values after the
/// first are ignored.
// NOTE(review): `dead_code` is allowed presumably because only tests call this
// in the current build — confirm before removing the attribute.
#[allow(dead_code)]
pub(crate) fn parse_auth_message_from_payload(payload: &Payload) -> Option<AuthMessage> {
    if let Payload::Text(values) = payload {
        values
            .first()
            .cloned()
            .and_then(|value| serde_json::from_value(value).ok())
    } else {
        None
    }
}
/// [`Transport`] implementation on top of an asynchronous `rust_socketio` client.
///
/// Outgoing messages are emitted through `client`. Incoming messages are read
/// from a channel receiver that the caller supplies at construction time.
pub struct SocketIOTransport {
    /// Socket.IO client used to emit outgoing `authMessage` events.
    client: Client,
    /// Receiver for inbound messages. It is `None` once `subscribe()` has
    /// taken it.
    incoming_rx: Arc<Mutex<Option<mpsc::Receiver<AuthMessage>>>>,
}

impl SocketIOTransport {
    /// Wraps `client` together with the receiving end of the inbound message
    /// channel.
    ///
    /// NOTE(review): presumably the matching `mpsc::Sender` is fed by an event
    /// handler registered on `client` — confirm at the construction site.
    pub fn new(client: Client, incoming_rx: mpsc::Receiver<AuthMessage>) -> Self {
        let receiver_slot = Mutex::new(Some(incoming_rx));
        SocketIOTransport {
            client,
            incoming_rx: Arc::new(receiver_slot),
        }
    }
}
#[async_trait]
impl Transport for SocketIOTransport {
    /// Serializes `message` to JSON and emits it as an `authMessage`
    /// Socket.IO event.
    ///
    /// # Errors
    ///
    /// Returns [`AuthError::SerializationError`] if `message` cannot be
    /// converted to a JSON value, and [`AuthError::TransportError`] if the
    /// client fails to emit the event.
    async fn send(&self, message: AuthMessage) -> Result<(), AuthError> {
        let payload = match serde_json::to_value(&message) {
            Ok(json) => json,
            Err(err) => return Err(AuthError::SerializationError(err.to_string())),
        };
        let emitted = self.client.emit("authMessage", payload).await;
        emitted.map_err(|err| AuthError::TransportError(err.to_string()))
    }

    /// Hands out the inbound message receiver.
    ///
    /// # Panics
    ///
    /// Panics if it is called a second time on the same transport. It also
    /// panics if the internal lock is held at the moment of the call, because
    /// it uses `try_lock` and does not wait.
    fn subscribe(&self) -> mpsc::Receiver<AuthMessage> {
        let mut slot = self
            .incoming_rx
            .try_lock()
            .expect("subscribe() mutex should not be contended");
        slot.take()
            .expect("subscribe() can only be called once per SocketIOTransport")
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use bsv::auth::types::MessageType;

    /// Mirrors the lock-and-take chain in `SocketIOTransport::subscribe`.
    ///
    /// Building a real `SocketIOTransport` requires a `rust_socketio` `Client`,
    /// and these unit tests do not create one. The subscribe tests therefore
    /// run the same `try_lock` -> `take` sequence, with the same panic
    /// messages, on a standalone slot. Keep this in sync with `subscribe()`.
    fn take_like_subscribe(
        slot: &Mutex<Option<mpsc::Receiver<AuthMessage>>>,
    ) -> mpsc::Receiver<AuthMessage> {
        slot.try_lock()
            .expect("subscribe() mutex should not be contended")
            .take()
            .expect("subscribe() can only be called once per SocketIOTransport")
    }

    #[test]
    fn encode_decode_round_trip() {
        let data = serde_json::json!({"identityKey": "03abc"});
        let encoded = encode_ws_event("authenticated", data.clone());
        let (name, decoded_data) = decode_ws_event(&encoded).expect("round-trip should succeed");
        assert_eq!(name, "authenticated");
        assert_eq!(decoded_data, data);
    }

    #[test]
    fn encode_produces_correct_json_structure() {
        let encoded = encode_ws_event("myEvent", serde_json::json!(42));
        let v: serde_json::Value = serde_json::from_slice(&encoded).unwrap();
        assert_eq!(v["eventName"], "myEvent");
        assert_eq!(v["data"], 42);
    }

    #[test]
    fn decode_invalid_utf8_returns_none() {
        let bad = vec![0xFF, 0xFE, 0x00];
        assert!(decode_ws_event(&bad).is_none());
    }

    #[test]
    fn decode_valid_json_missing_event_name_returns_none() {
        let json = br#"{"data": {"foo": "bar"}}"#;
        assert!(decode_ws_event(json).is_none());
    }

    #[test]
    fn decode_non_string_event_name_or_non_object_returns_none() {
        assert!(decode_ws_event(br#"{"eventName": 7, "data": 1}"#).is_none());
        assert!(decode_ws_event(b"[1, 2, 3]").is_none());
        assert!(decode_ws_event(b"null").is_none());
    }

    #[test]
    fn decode_missing_data_yields_null() {
        let (name, data) =
            decode_ws_event(br#"{"eventName": "ping"}"#).expect("eventName alone should decode");
        assert_eq!(name, "ping");
        assert_eq!(data, serde_json::Value::Null);
    }

    #[test]
    fn decode_empty_bytes_returns_none() {
        assert!(decode_ws_event(&[]).is_none());
    }

    #[test]
    fn encode_decode_preserves_nested_data() {
        let data = serde_json::json!({"nested": {"a": 1, "b": [1, 2, 3]}});
        let encoded = encode_ws_event("complexEvent", data.clone());
        let (name, decoded) = decode_ws_event(&encoded).unwrap();
        assert_eq!(name, "complexEvent");
        assert_eq!(decoded, data);
    }

    fn make_valid_auth_message_value() -> serde_json::Value {
        serde_json::json!({
            "version": "0.1",
            "messageType": "initialRequest",
            "identityKey": "03abcdef1234567890abcdef1234567890abcdef1234567890abcdef1234567890ab"
        })
    }

    #[test]
    fn parse_auth_message_from_text_payload_valid() {
        let payload = Payload::Text(vec![make_valid_auth_message_value()]);
        let result = parse_auth_message_from_payload(&payload);
        assert!(result.is_some(), "should parse valid AuthMessage from Text payload");
        let msg = result.unwrap();
        assert_eq!(msg.message_type, MessageType::InitialRequest);
        assert_eq!(
            msg.identity_key,
            "03abcdef1234567890abcdef1234567890abcdef1234567890abcdef1234567890ab"
        );
    }

    #[test]
    fn parse_auth_message_from_text_payload_non_auth_json_returns_none() {
        let payload = Payload::Text(vec![serde_json::json!({"foo": "bar"})]);
        let result = parse_auth_message_from_payload(&payload);
        assert!(result.is_none(), "non-auth JSON should return None");
    }

    #[test]
    fn parse_auth_message_from_empty_text_returns_none() {
        let payload = Payload::Text(vec![]);
        let result = parse_auth_message_from_payload(&payload);
        assert!(result.is_none(), "empty Text payload should return None");
    }

    #[test]
    fn parse_auth_message_from_binary_returns_none() {
        let payload = Payload::from(b"some bytes".to_vec());
        let result = parse_auth_message_from_payload(&payload);
        assert!(result.is_none(), "Binary payload should return None");
    }

    #[test]
    #[should_panic(expected = "subscribe() can only be called once per SocketIOTransport")]
    fn subscribe_panics_on_second_call() {
        let (_tx, rx) = mpsc::channel::<AuthMessage>(1);
        let slot = Mutex::new(Some(rx));
        let _first = take_like_subscribe(&slot);
        // The slot is now empty, so the second take must panic.
        let _second = take_like_subscribe(&slot);
    }

    #[test]
    fn subscribe_first_call_returns_receiver() {
        let (_tx, rx) = mpsc::channel::<AuthMessage>(1);
        let slot = Mutex::new(Some(rx));
        let _receiver = take_like_subscribe(&slot);
        assert!(
            slot.try_lock()
                .expect("slot should be unlocked after subscribe")
                .is_none(),
            "slot should be empty after the first subscribe"
        );
    }
}