use crate::transport::{
ConnType, DataLane, NetworkError, NetworkResult, WebSocketDataLane, WireHandle, WsSink,
};
use actr_protocol::PayloadType;
use async_trait::async_trait;
use futures_util::stream::SplitStream;
use futures_util::{SinkExt, StreamExt};
use std::sync::Arc;
use tokio::net::TcpStream;
use tokio::sync::{Mutex, RwLock, mpsc};
use tokio_tungstenite::tungstenite::Message as WsMessage;
use tokio_tungstenite::tungstenite::handshake::client::generate_key;
use tokio_tungstenite::tungstenite::http::Request as WsRequest;
use tokio_tungstenite::tungstenite::http::Uri as WsUri;
use tokio_tungstenite::{MaybeTlsStream, WebSocketStream, connect_async};
/// Shared, lock-protected table of lazily created lanes with `SLOTS` optional entries.
type LaneCache<const SLOTS: usize> = Arc<RwLock<[Option<Arc<dyn DataLane>>; SLOTS]>>;
/// One binary frame received over the WebSocket, after the header has been parsed.
#[derive(Clone, Debug)]
struct TransportMessage {
    /// Lane the frame belongs to, decoded from the first byte of the frame.
    payload_type: PayloadType,
    /// Frame payload with the 5-byte header stripped.
    data: Vec<u8>,
}
impl TransportMessage {
    /// Size of the fixed frame header: 1 byte payload type + 4 byte big-endian length.
    const HEADER_LEN: usize = 5;

    /// Decodes a binary WebSocket frame.
    ///
    /// Wire layout: `[payload_type: u8][len: u32 big-endian][data: len bytes]`.
    /// Any bytes after the declared `data` are ignored.
    ///
    /// # Errors
    /// Returns [`NetworkError::DeserializationError`] when the frame is shorter than
    /// the header, the payload-type byte is unknown, or fewer than `len` data bytes
    /// follow the header.
    fn decode(data: &[u8]) -> NetworkResult<Self> {
        if data.len() < Self::HEADER_LEN {
            return Err(NetworkError::DeserializationError(
                "WebSocket message too short".to_string(),
            ));
        }

        let payload_type_raw = data[0];
        let payload_type = match payload_type_raw {
            0 => PayloadType::RpcReliable,
            1 => PayloadType::RpcSignal,
            2 => PayloadType::StreamReliable,
            3 => PayloadType::StreamLatencyFirst,
            4 => PayloadType::MediaRtp,
            _ => {
                return Err(NetworkError::DeserializationError(format!(
                    "Invalid payload_type: {payload_type_raw}"
                )));
            }
        };

        // u32 -> usize is lossless on the 32/64-bit targets this transport runs on.
        let len = u32::from_be_bytes([data[1], data[2], data[3], data[4]]) as usize;

        // Compare `len` against the bytes actually present instead of computing
        // `HEADER_LEN + len`: a peer-supplied length near u32::MAX would overflow
        // usize on 32-bit targets and turn into a panic on the slice below.
        let body = &data[Self::HEADER_LEN..];
        let Some(msg_data) = body.get(..len) else {
            return Err(NetworkError::DeserializationError(
                "WebSocket message data incomplete".to_string(),
            ));
        };

        Ok(Self {
            payload_type,
            data: msg_data.to_vec(),
        })
    }
}
/// A WebSocket connection that multiplexes several [`PayloadType`] lanes.
///
/// Clones share the same sink, router, lane cache and connection flag (each held
/// behind an `Arc`); the string configuration fields are copied.
#[derive(Debug, Clone)]
pub(crate) struct WebSocketConnection {
    /// Target URL for outbound connections; `"<inbound>"` for server-accepted streams.
    url: String,
    /// Optional hex-encoded local ID, sent as `X-Actr-Source-ID` in the handshake.
    local_id_hex: Option<String>,
    /// Optional base64 credential, sent as `X-Actr-Credential` in the handshake.
    credential_b64: Option<String>,
    /// Write half of the socket; `None` until connected and after `close`.
    sink: WsSink,
    /// Per-lane senders indexed by `PayloadType as usize`, fed by the dispatcher task.
    router: Arc<RwLock<[Option<mpsc::Sender<bytes::Bytes>>; 5]>>,
    /// Lazily created lanes, indexed the same way as `router`.
    lane_cache: LaneCache<5>,
    /// Whether the connection is currently considered open.
    connected: Arc<RwLock<bool>>,
}
impl WebSocketConnection {
pub fn new(url: String) -> Self {
Self {
url: url.clone(),
local_id_hex: None,
credential_b64: None,
sink: Arc::new(Mutex::new(None)), router: Arc::new(RwLock::new([None, None, None, None, None])),
lane_cache: Arc::new(RwLock::new([None, None, None, None, None])),
connected: Arc::new(RwLock::new(false)),
}
}
pub fn with_local_id(mut self, id_hex: String) -> Self {
self.local_id_hex = Some(id_hex);
self
}
pub fn with_credential_b64(mut self, credential_b64: String) -> Self {
self.credential_b64 = Some(credential_b64);
self
}
pub fn from_server_stream(ws_stream: WebSocketStream<MaybeTlsStream<TcpStream>>) -> Self {
let (sink, stream) = ws_stream.split();
let router: Arc<RwLock<[Option<mpsc::Sender<bytes::Bytes>>; 5]>> =
Arc::new(RwLock::new([None, None, None, None, None]));
let connected = Arc::new(RwLock::new(true));
Self::spawn_dispatcher(stream, router.clone(), connected.clone());
tracing::info!("✅ WebSocketConnection created from server stream (already connected)");
Self {
url: String::from("<inbound>"),
local_id_hex: None,
credential_b64: None,
sink: Arc::new(Mutex::new(Some(sink))),
router,
lane_cache: Arc::new(RwLock::new([None, None, None, None, None])),
connected,
}
}
pub async fn connect(&self) -> NetworkResult<()> {
let (ws_stream, _) = if let Some(ref hex_id) = self.local_id_hex {
let uri: WsUri = self
.url
.parse()
.map_err(|e| NetworkError::ConnectionError(format!("Invalid WS URI: {e}")))?;
let host = uri
.host()
.ok_or_else(|| NetworkError::ConnectionError("WS URL missing host".to_string()))?;
let host_header = match uri.port_u16() {
Some(port) => format!("{host}:{port}"),
None => host.to_string(),
};
let mut builder = WsRequest::builder()
.uri(self.url.as_str())
.header("Host", host_header)
.header("Connection", "Upgrade")
.header("Upgrade", "websocket")
.header("Sec-WebSocket-Version", "13")
.header("Sec-WebSocket-Key", generate_key())
.header("X-Actr-Source-ID", hex_id);
if let Some(ref cred_b64) = self.credential_b64 {
builder = builder.header("X-Actr-Credential", cred_b64.as_str());
}
let request = builder.body(()).map_err(|e| {
NetworkError::ConnectionError(format!("WS request build failed: {e}"))
})?;
connect_async(request).await?
} else {
connect_async(&self.url).await?
};
let (sink, stream) = ws_stream.split();
*self.sink.lock().await = Some(sink);
*self.connected.write().await = true;
let router = self.router.clone();
let connected = self.connected.clone();
Self::spawn_dispatcher(stream, router, connected);
tracing::info!("✅ WebSocketConnection already Connect: {}", self.url);
Ok(())
}
fn spawn_dispatcher(
mut stream: SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>,
router: Arc<RwLock<[Option<mpsc::Sender<bytes::Bytes>>; 5]>>,
connected: Arc<RwLock<bool>>,
) -> tokio::task::JoinHandle<()> {
tokio::spawn(async move {
tracing::debug!("📡 WebSocket dispatcher Start");
while let Some(msg_result) = stream.next().await {
match msg_result {
Ok(WsMessage::Binary(data)) => {
match TransportMessage::decode(&data) {
Ok(transport_msg) => {
let idx = transport_msg.payload_type as usize;
let router_guard = router.read().await;
if let Some(tx) = &router_guard[idx] {
let data = bytes::Bytes::from(transport_msg.data);
if let Err(e) = tx.send(data).await {
tracing::warn!(
"❌ WebSocket message route by failure (type={:?}): {}",
transport_msg.payload_type,
e
);
}
} else {
tracing::warn!(
"⚠️ WebSocket received not RegisterType'smessage: {:?}",
transport_msg.payload_type
);
}
}
Err(e) => {
tracing::error!("❌ WebSocket message decodefailure: {}", e);
}
}
}
Ok(WsMessage::Close(_)) => {
tracing::info!("🔌 WebSocket Connect be pair end Close");
*connected.write().await = false;
break;
}
Ok(WsMessage::Ping(_)) | Ok(WsMessage::Pong(_)) => {
}
Ok(_) => {
tracing::debug!("⚠️ Received non-binary WebSocket message, ignoring");
}
Err(e) => {
tracing::error!("❌ WebSocket Error: {}", e);
*connected.write().await = false;
break;
}
}
}
tracing::debug!("📡 WebSocket dispatcher rollback exit ");
})
}
async fn register_route(
&self,
payload_type: PayloadType,
tx: mpsc::Sender<bytes::Bytes>,
) -> NetworkResult<()> {
let mut router = self.router.write().await;
let idx = payload_type as usize;
router[idx] = Some(tx);
tracing::debug!("✅ Register WebSocket route by : {:?}", payload_type);
Ok(())
}
}
impl WebSocketConnection {
    /// Buffer size of the channel between the dispatcher and each lane.
    const LANE_CHANNEL_CAPACITY: usize = 100;

    /// Returns the [`DataLane`] for `payload_type`, creating it on first use.
    ///
    /// # Errors
    /// Returns [`NetworkError::ConnectionError`] when a new lane is needed but the
    /// connection is not marked as connected.
    pub async fn get_lane(&self, payload_type: PayloadType) -> NetworkResult<Arc<dyn DataLane>> {
        self.get_lane_internal(payload_type).await
    }

    /// Cached lane lookup with creation on miss.
    ///
    /// Creation happens under the cache write lock with a re-check, so concurrent
    /// callers for the same payload type all receive the same lane. Without this, two
    /// racing misses would each register a route, the second would overwrite the
    /// first's sender, and the lane already returned to the first caller would never
    /// receive data.
    async fn get_lane_internal(
        &self,
        payload_type: PayloadType,
    ) -> NetworkResult<Arc<dyn DataLane>> {
        let idx = payload_type as usize;

        // Fast path: shared read lock, released at the end of this scope.
        {
            let cache = self.lane_cache.read().await;
            if let Some(lane) = &cache[idx] {
                tracing::debug!("Reuse cached DataLane: {:?}", payload_type);
                return Ok(Arc::clone(lane));
            }
        }

        // Slow path: exclusive lock, then re-check in case another task won the race.
        let mut cache = self.lane_cache.write().await;
        if let Some(lane) = &cache[idx] {
            tracing::debug!("Reuse cached DataLane: {:?}", payload_type);
            return Ok(Arc::clone(lane));
        }
        let lane = self.create_lane_internal(payload_type).await?;
        cache[idx] = Some(Arc::clone(&lane));
        drop(cache);

        tracing::info!(
            "WebSocketConnection created new DataLane: {:?}",
            payload_type
        );
        Ok(lane)
    }

    /// Registers a fresh route for `payload_type` and wraps its receiver in a lane
    /// that writes through the shared sink.
    ///
    /// # Errors
    /// Returns [`NetworkError::ConnectionError`] if the connection is not marked
    /// connected; errors from route registration are propagated.
    async fn create_lane_internal(
        &self,
        payload_type: PayloadType,
    ) -> NetworkResult<Arc<dyn DataLane>> {
        if !*self.connected.read().await {
            return Err(NetworkError::ConnectionError(
                "WebSocket connection closed".to_string(),
            ));
        }
        let (tx, rx) = mpsc::channel(Self::LANE_CHANNEL_CAPACITY);
        self.register_route(payload_type, tx).await?;
        let sink = self.sink.clone();
        Ok(Arc::new(WebSocketDataLane::new(sink, payload_type, rx)))
    }

    /// Marks the connection closed, closes and drops the sink, and clears all routes
    /// and cached lanes.
    ///
    /// Each lock is taken and released on its own rather than nested: lane creation
    /// holds the cache lock while registering a route, so holding the router lock
    /// while waiting for the cache lock here could deadlock against it. Closing the
    /// sink is best-effort; a failure is logged and does not fail `close`.
    pub async fn close(&self) -> NetworkResult<()> {
        *self.connected.write().await = false;

        {
            let mut sink_opt = self.sink.lock().await;
            if let Some(sink) = sink_opt.as_mut() {
                if let Err(e) = sink.close().await {
                    tracing::debug!("WebSocket sink close failed: {e}");
                }
            }
            *sink_opt = None;
        }

        *self.router.write().await = [None, None, None, None, None];
        *self.lane_cache.write().await = [None, None, None, None, None];

        tracing::info!("WebSocketConnection closed");
        Ok(())
    }
}
#[async_trait]
impl WireHandle for WebSocketConnection {
    /// Always [`ConnType::WebSocket`].
    fn connection_type(&self) -> ConnType {
        ConnType::WebSocket
    }

    /// Fixed priority value of `0` for WebSocket handles.
    fn priority(&self) -> u8 {
        0
    }

    /// Delegates to the inherent `WebSocketConnection::connect`; inherent associated
    /// functions take precedence over trait methods, so this does not recurse.
    async fn connect(&self) -> NetworkResult<()> {
        Self::connect(self).await
    }

    /// Reports the connection flag without blocking.
    ///
    /// `RwLock::blocking_read` panics when called from inside an async runtime, which
    /// is where this transport is used. `try_read` never blocks; if the lock cannot
    /// be acquired immediately (a connect/close transition is writing the flag), the
    /// connection is reported as not connected.
    fn is_connected(&self) -> bool {
        self.connected
            .try_read()
            .map(|guard| *guard)
            .unwrap_or(false)
    }

    /// Delegates to the inherent `WebSocketConnection::close`.
    async fn close(&self) -> NetworkResult<()> {
        Self::close(self).await
    }

    /// Returns (creating on first use) the lane for `payload_type`.
    async fn get_lane(&self, payload_type: PayloadType) -> NetworkResult<Arc<dyn DataLane>> {
        self.get_lane_internal(payload_type).await
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a frame in the `[type][len u32 BE][data]` layout that `decode` expects.
    fn encode(payload_type: u8, data: &[u8]) -> Vec<u8> {
        let mut frame = Vec::with_capacity(5 + data.len());
        frame.push(payload_type);
        frame.extend_from_slice(&(data.len() as u32).to_be_bytes());
        frame.extend_from_slice(data);
        frame
    }

    #[test]
    fn test_transport_message_decode() {
        let encoded = encode(PayloadType::RpcReliable as u8, b"hello world");
        let decoded = TransportMessage::decode(&encoded)
            .expect("Should decode valid TransportMessage in test");
        assert_eq!(decoded.payload_type as u8, PayloadType::RpcReliable as u8);
        assert_eq!(decoded.data, b"hello world");
    }

    #[test]
    fn test_transport_message_decode_round_trips_every_payload_type() {
        // The dispatcher indexes its router with `payload_type as usize`, so the wire
        // byte produced by `as u8` must decode back to the same variant.
        for variant in [
            PayloadType::RpcReliable,
            PayloadType::RpcSignal,
            PayloadType::StreamReliable,
            PayloadType::StreamLatencyFirst,
            PayloadType::MediaRtp,
        ] {
            let raw = variant as u8;
            let decoded = TransportMessage::decode(&encode(raw, b"x"))
                .expect("every known payload type should decode");
            assert_eq!(decoded.payload_type as u8, raw);
            assert_eq!(decoded.data, b"x");
        }
    }

    #[test]
    fn test_transport_message_decode_empty_payload() {
        let decoded = TransportMessage::decode(&encode(PayloadType::RpcSignal as u8, b""))
            .expect("a header-only frame with len 0 is valid");
        assert!(decoded.data.is_empty());
    }

    #[test]
    fn test_transport_message_decode_invalid() {
        // Shorter than the 5-byte header.
        let data = vec![1, 0, 0];
        assert!(TransportMessage::decode(&data).is_err());
        // Unknown payload-type byte.
        let data = vec![99, 0, 0, 0, 5, 1, 2, 3, 4, 5];
        assert!(TransportMessage::decode(&data).is_err());
    }

    #[test]
    fn test_transport_message_decode_truncated_payload() {
        let mut frame = encode(PayloadType::RpcReliable as u8, b"hello");
        frame.pop();
        assert!(TransportMessage::decode(&frame).is_err());

        // A maximal declared length must be rejected rather than overflow or panic.
        let huge = vec![0, 0xFF, 0xFF, 0xFF, 0xFF, 1, 2, 3];
        assert!(TransportMessage::decode(&huge).is_err());
    }
}