actr_runtime/wire/websocket/
connection.rs

1//! WebSocket C/S Connection implementation
2
3use crate::transport::DataLane;
4use crate::transport::{NetworkError, NetworkResult};
5use actr_protocol::PayloadType;
6use futures_util::stream::{SplitSink, SplitStream};
7use futures_util::{SinkExt, StreamExt};
8use std::sync::Arc;
9use tokio::net::TcpStream;
10use tokio::sync::{Mutex, RwLock, mpsc};
11use tokio_tungstenite::tungstenite::Message as WsMessage;
12use tokio_tungstenite::{MaybeTlsStream, WebSocketStream, connect_async};
13
14/// WebSocket transmitting messagesprotocol
15///
16/// Forin single WebSocket Connect for multiple route reuse different Type'smessage。
17///
18/// ## Message format
19///
20/// ```text
21/// [payload_type: 1 byte][data_len: 4 bytes][data: N bytes]
22/// ```
23#[derive(Debug, Clone)]
24struct TransportMessage {
25    payload_type: PayloadType,
26    data: Vec<u8>,
27}
28
29impl TransportMessage {
30    /// frombytes stream decode
31    fn decode(data: &[u8]) -> NetworkResult<Self> {
32        if data.len() < 5 {
33            return Err(NetworkError::DeserializationError(
34                "WebSocket message too short".to_string(),
35            ));
36        }
37
38        // Parse payload_type (must match proto enum values)
39        let payload_type_raw = data[0];
40        let payload_type = match payload_type_raw {
41            0 => PayloadType::RpcReliable,
42            1 => PayloadType::RpcSignal,
43            2 => PayloadType::StreamReliable,
44            3 => PayloadType::StreamLatencyFirst,
45            4 => PayloadType::MediaRtp,
46            _ => {
47                return Err(NetworkError::DeserializationError(format!(
48                    "Invalid payload_type: {payload_type_raw}"
49                )));
50            }
51        };
52
53        // Parse length
54        let len = u32::from_be_bytes([data[1], data[2], data[3], data[4]]) as usize;
55
56        // Parse data
57        if data.len() < 5 + len {
58            return Err(NetworkError::DeserializationError(
59                "WebSocket message data incomplete".to_string(),
60            ));
61        }
62
63        let msg_data = data[5..5 + len].to_vec();
64
65        Ok(Self {
66            payload_type,
67            data: msg_data,
68        })
69    }
70}
71
72/// WebSocket Sink Type distinct name
73type WsSink = Arc<Mutex<Option<SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, WsMessage>>>>;
74
75/// WebSocketConnection - WebSocket C/S Connect
76#[derive(Clone, Debug)]
77pub struct WebSocketConnection {
78    /// URL
79    url: String,
80    /// Write end (Sink) - using Option to avoid initialization issues
81    sink: WsSink,
82
83    /// message route by table :PayloadType → Sender(using array index reference ,5 fixed elements,using Bytes zero-copy)
84    router: Arc<RwLock<[Option<mpsc::Sender<bytes::Bytes>>; 5]>>,
85
86    /// Lane Cache:PayloadType → Lane(using array index reference ,5 fixed elements)
87    lane_cache: Arc<RwLock<[Option<DataLane>; 5]>>,
88
89    /// connection status
90    connected: Arc<RwLock<bool>>,
91}
92
93impl WebSocketConnection {
94    /// Connectto WebSocket service device
95    ///
96    /// # Arguments
97    /// - `url`: WebSocket URL (ws:// or wss://)
98    ///
99    /// # Example
100    ///
101    /// ```rust,ignore
102    /// let conn = WebSocketConnection::new("ws://localhost:8080");
103    /// conn.connect().await?;
104    /// ```
105    pub fn new(url: String) -> Self {
106        Self {
107            url: url.clone(),
108            sink: Arc::new(Mutex::new(None)), // initial begin as None
109            router: Arc::new(RwLock::new([None, None, None, None, None])),
110            lane_cache: Arc::new(RwLock::new([None, None, None, None, None])),
111            connected: Arc::new(RwLock::new(false)),
112        }
113    }
114
115    /// establish Connect
116    pub async fn connect(&self) -> NetworkResult<()> {
117        // 1. establish WebSocket Connect
118        let (ws_stream, _) = connect_async(&self.url).await?;
119        let (sink, stream) = ws_stream.split();
120
121        // 2. update new sink
122        *self.sink.lock().await = Some(sink);
123        *self.connected.write().await = true;
124
125        // 3. Startmessage dispatch device (in background task, not retain handle)
126        let router = self.router.clone();
127        let connected = self.connected.clone();
128        Self::spawn_dispatcher(stream, router, connected);
129
130        tracing::info!("✅ WebSocketConnection already Connect: {}", self.url);
131
132        Ok(())
133    }
134
135    /// Checkwhether already Connect
136    #[inline]
137    pub fn is_connected(&self) -> bool {
138        *self.connected.blocking_read()
139    }
140
141    /// Startmessage dispatch device (in background task)
142    fn spawn_dispatcher(
143        mut stream: SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>,
144        router: Arc<RwLock<[Option<mpsc::Sender<bytes::Bytes>>; 5]>>,
145        connected: Arc<RwLock<bool>>,
146    ) -> tokio::task::JoinHandle<()> {
147        tokio::spawn(async move {
148            tracing::debug!("📡 WebSocket dispatcher Start");
149
150            while let Some(msg_result) = stream.next().await {
151                match msg_result {
152                    Ok(WsMessage::Binary(data)) => {
153                        // decodemessage
154                        match TransportMessage::decode(&data) {
155                            Ok(transport_msg) => {
156                                // Route to corresponding 's Lane(using array index reference )
157                                let idx = transport_msg.payload_type as usize;
158                                let router_guard = router.read().await;
159                                if let Some(tx) = &router_guard[idx] {
160                                    // convert exchange as Bytes( zero-copy)
161                                    let data = bytes::Bytes::from(transport_msg.data);
162                                    if let Err(e) = tx.send(data).await {
163                                        tracing::warn!(
164                                            "❌ WebSocket message route by failure (type={:?}): {}",
165                                            transport_msg.payload_type,
166                                            e
167                                        );
168                                    }
169                                } else {
170                                    tracing::warn!(
171                                        "⚠️ WebSocket received not RegisterType'smessage: {:?}",
172                                        transport_msg.payload_type
173                                    );
174                                }
175                            }
176                            Err(e) => {
177                                tracing::error!("❌ WebSocket message decodefailure: {}", e);
178                            }
179                        }
180                    }
181                    Ok(WsMessage::Close(_)) => {
182                        tracing::info!("🔌 WebSocket Connect be pair end Close");
183                        *connected.write().await = false;
184                        break;
185                    }
186                    Ok(WsMessage::Ping(_)) | Ok(WsMessage::Pong(_)) => {
187                        // ignore center skipmessage
188                    }
189                    Ok(_) => {
190                        tracing::debug!("⚠️ Received non-binary WebSocket message, ignoring");
191                    }
192                    Err(e) => {
193                        tracing::error!("❌ WebSocket Error: {}", e);
194                        *connected.write().await = false;
195                        break;
196                    }
197                }
198            }
199
200            tracing::debug!("📡 WebSocket dispatcher rollback exit ");
201        })
202    }
203
204    /// Register PayloadType route by
205    async fn register_route(
206        &self,
207        payload_type: PayloadType,
208        tx: mpsc::Sender<bytes::Bytes>,
209    ) -> NetworkResult<()> {
210        let mut router = self.router.write().await;
211        let idx = payload_type as usize;
212        router[idx] = Some(tx);
213        tracing::debug!("✅ Register WebSocket route by : {:?}", payload_type);
214        Ok(())
215    }
216}
217
218impl WebSocketConnection {
219    /// GetorCreate DataLane( carry Cache)
220    pub async fn get_lane(&self, payload_type: PayloadType) -> NetworkResult<DataLane> {
221        let idx = payload_type as usize;
222
223        // 1. CheckCache
224        {
225            let cache = self.lane_cache.read().await;
226            if let Some(lane) = &cache[idx] {
227                tracing::debug!("📦 ReuseCache DataLane: {:?}", payload_type);
228                return Ok(lane.clone());
229            }
230        }
231
232        // 2. Createnew DataLane
233        let lane = self.create_lane_internal(payload_type).await?;
234
235        // 3. Cache
236        {
237            let mut cache = self.lane_cache.write().await;
238            cache[idx] = Some(lane.clone());
239        }
240
241        tracing::info!(
242            "✨ WebSocketConnection Createnew DataLane: {:?}",
243            payload_type
244        );
245
246        Ok(lane)
247    }
248
249    /// inner part Method:Create DataLane( not carry Cache)
250    async fn create_lane_internal(&self, payload_type: PayloadType) -> NetworkResult<DataLane> {
251        // Check connection status
252        if !*self.connected.read().await {
253            return Err(NetworkError::ConnectionError(
254                "WebSocket connection closed".to_string(),
255            ));
256        }
257
258        // CreateReceive channel
259        let (tx, rx) = mpsc::channel(100);
260
261        // Register route by
262        self.register_route(payload_type, tx).await?;
263
264        // Getshared's Sink
265        let sink = self.sink.clone();
266
267        // Create DataLane(usingnew's websocket transform body )
268        Ok(DataLane::websocket(sink, payload_type, rx))
269    }
270
271    /// backwardaftercompatible hold Method:create_lane adjust usage get_lane
272    pub async fn create_lane(&self, payload_type: PayloadType) -> NetworkResult<DataLane> {
273        self.get_lane(payload_type).await
274    }
275
276    /// CloseConnect
277    pub async fn close(&self) -> NetworkResult<()> {
278        *self.connected.write().await = false;
279
280        // Close WebSocket(Send Close message)
281        let mut sink_opt = self.sink.lock().await;
282        if let Some(sink) = sink_opt.as_mut() {
283            let _ = sink.close().await;
284        }
285        *sink_opt = None;
286
287        // clear blank route by table
288        let mut router = self.router.write().await;
289        *router = [None, None, None, None, None];
290
291        // clear blank Lane Cache
292        let mut cache = self.lane_cache.write().await;
293        *cache = [None, None, None, None, None];
294
295        tracing::info!("🔌 WebSocketConnection already Close");
296        Ok(())
297    }
298}
299
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a frame in the `[payload_type: 1][data_len: 4 BE][data]` wire format.
    ///
    /// `declared_len` is written verbatim so tests can lie about the length.
    fn frame(payload_type: u8, declared_len: u32, data: &[u8]) -> Vec<u8> {
        let mut encoded = Vec::with_capacity(5 + data.len());
        encoded.push(payload_type);
        encoded.extend_from_slice(&declared_len.to_be_bytes());
        encoded.extend_from_slice(data);
        encoded
    }

    #[test]
    fn test_transport_message_decode() {
        let encoded = frame(PayloadType::RpcReliable as u8, 11, b"hello world");

        let decoded = TransportMessage::decode(&encoded)
            .expect("Should decode valid TransportMessage in test");

        assert_eq!(decoded.payload_type as u8, PayloadType::RpcReliable as u8);
        assert_eq!(decoded.data, b"hello world");
    }

    #[test]
    fn test_transport_message_decode_all_payload_types() {
        // Every known type must round-trip through its enum discriminant,
        // which is also the index used by the router and lane cache.
        let all = [
            PayloadType::RpcReliable,
            PayloadType::RpcSignal,
            PayloadType::StreamReliable,
            PayloadType::StreamLatencyFirst,
            PayloadType::MediaRtp,
        ];
        for payload_type in all {
            let decoded = TransportMessage::decode(&frame(payload_type as u8, 1, &[7]))
                .expect("every known payload type should decode");
            assert_eq!(decoded.payload_type as u8, payload_type as u8);
            assert_eq!(decoded.data, vec![7u8]);
        }
    }

    #[test]
    fn test_transport_message_decode_empty_payload() {
        let decoded = TransportMessage::decode(&frame(PayloadType::RpcSignal as u8, 0, &[]))
            .expect("a zero-length payload is valid");
        assert!(decoded.data.is_empty());
    }

    #[test]
    fn test_transport_message_decode_invalid() {
        // Shorter than the 5-byte header.
        assert!(TransportMessage::decode(&[1, 0, 0]).is_err());
        assert!(TransportMessage::decode(&[]).is_err());

        // Unknown payload_type byte.
        assert!(TransportMessage::decode(&frame(99, 5, &[1, 2, 3, 4, 5])).is_err());

        // Declared length exceeds the bytes actually present.
        assert!(
            TransportMessage::decode(&frame(PayloadType::RpcReliable as u8, 4, &[1, 2, 3]))
                .is_err()
        );

        // A maximal declared length must be rejected, never panic.
        assert!(
            TransportMessage::decode(&frame(PayloadType::RpcReliable as u8, u32::MAX, &[1]))
                .is_err()
        );
    }
}