actr_runtime/wire/websocket/
connection.rs1use crate::transport::DataLane;
4use crate::transport::{NetworkError, NetworkResult};
5use actr_protocol::PayloadType;
6use futures_util::stream::{SplitSink, SplitStream};
7use futures_util::{SinkExt, StreamExt};
8use std::sync::Arc;
9use tokio::net::TcpStream;
10use tokio::sync::{Mutex, RwLock, mpsc};
11use tokio_tungstenite::tungstenite::Message as WsMessage;
12use tokio_tungstenite::{MaybeTlsStream, WebSocketStream, connect_async};
13
14#[derive(Debug, Clone)]
24struct TransportMessage {
25 payload_type: PayloadType,
26 data: Vec<u8>,
27}
28
29impl TransportMessage {
30 fn decode(data: &[u8]) -> NetworkResult<Self> {
32 if data.len() < 5 {
33 return Err(NetworkError::DeserializationError(
34 "WebSocket message too short".to_string(),
35 ));
36 }
37
38 let payload_type_raw = data[0];
40 let payload_type = match payload_type_raw {
41 0 => PayloadType::RpcReliable,
42 1 => PayloadType::RpcSignal,
43 2 => PayloadType::StreamReliable,
44 3 => PayloadType::StreamLatencyFirst,
45 4 => PayloadType::MediaRtp,
46 _ => {
47 return Err(NetworkError::DeserializationError(format!(
48 "Invalid payload_type: {payload_type_raw}"
49 )));
50 }
51 };
52
53 let len = u32::from_be_bytes([data[1], data[2], data[3], data[4]]) as usize;
55
56 if data.len() < 5 + len {
58 return Err(NetworkError::DeserializationError(
59 "WebSocket message data incomplete".to_string(),
60 ));
61 }
62
63 let msg_data = data[5..5 + len].to_vec();
64
65 Ok(Self {
66 payload_type,
67 data: msg_data,
68 })
69 }
70}
71
72type WsSink = Arc<Mutex<Option<SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, WsMessage>>>>;
74
75#[derive(Clone, Debug)]
77pub struct WebSocketConnection {
78 url: String,
80 sink: WsSink,
82
83 router: Arc<RwLock<[Option<mpsc::Sender<bytes::Bytes>>; 5]>>,
85
86 lane_cache: Arc<RwLock<[Option<DataLane>; 5]>>,
88
89 connected: Arc<RwLock<bool>>,
91}
92
93impl WebSocketConnection {
94 pub fn new(url: String) -> Self {
106 Self {
107 url: url.clone(),
108 sink: Arc::new(Mutex::new(None)), router: Arc::new(RwLock::new([None, None, None, None, None])),
110 lane_cache: Arc::new(RwLock::new([None, None, None, None, None])),
111 connected: Arc::new(RwLock::new(false)),
112 }
113 }
114
115 pub async fn connect(&self) -> NetworkResult<()> {
117 let (ws_stream, _) = connect_async(&self.url).await?;
119 let (sink, stream) = ws_stream.split();
120
121 *self.sink.lock().await = Some(sink);
123 *self.connected.write().await = true;
124
125 let router = self.router.clone();
127 let connected = self.connected.clone();
128 Self::spawn_dispatcher(stream, router, connected);
129
130 tracing::info!("✅ WebSocketConnection already Connect: {}", self.url);
131
132 Ok(())
133 }
134
135 #[inline]
137 pub fn is_connected(&self) -> bool {
138 *self.connected.blocking_read()
139 }
140
141 fn spawn_dispatcher(
143 mut stream: SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>,
144 router: Arc<RwLock<[Option<mpsc::Sender<bytes::Bytes>>; 5]>>,
145 connected: Arc<RwLock<bool>>,
146 ) -> tokio::task::JoinHandle<()> {
147 tokio::spawn(async move {
148 tracing::debug!("📡 WebSocket dispatcher Start");
149
150 while let Some(msg_result) = stream.next().await {
151 match msg_result {
152 Ok(WsMessage::Binary(data)) => {
153 match TransportMessage::decode(&data) {
155 Ok(transport_msg) => {
156 let idx = transport_msg.payload_type as usize;
158 let router_guard = router.read().await;
159 if let Some(tx) = &router_guard[idx] {
160 let data = bytes::Bytes::from(transport_msg.data);
162 if let Err(e) = tx.send(data).await {
163 tracing::warn!(
164 "❌ WebSocket message route by failure (type={:?}): {}",
165 transport_msg.payload_type,
166 e
167 );
168 }
169 } else {
170 tracing::warn!(
171 "⚠️ WebSocket received not RegisterType'smessage: {:?}",
172 transport_msg.payload_type
173 );
174 }
175 }
176 Err(e) => {
177 tracing::error!("❌ WebSocket message decodefailure: {}", e);
178 }
179 }
180 }
181 Ok(WsMessage::Close(_)) => {
182 tracing::info!("🔌 WebSocket Connect be pair end Close");
183 *connected.write().await = false;
184 break;
185 }
186 Ok(WsMessage::Ping(_)) | Ok(WsMessage::Pong(_)) => {
187 }
189 Ok(_) => {
190 tracing::debug!("⚠️ Received non-binary WebSocket message, ignoring");
191 }
192 Err(e) => {
193 tracing::error!("❌ WebSocket Error: {}", e);
194 *connected.write().await = false;
195 break;
196 }
197 }
198 }
199
200 tracing::debug!("📡 WebSocket dispatcher rollback exit ");
201 })
202 }
203
204 async fn register_route(
206 &self,
207 payload_type: PayloadType,
208 tx: mpsc::Sender<bytes::Bytes>,
209 ) -> NetworkResult<()> {
210 let mut router = self.router.write().await;
211 let idx = payload_type as usize;
212 router[idx] = Some(tx);
213 tracing::debug!("✅ Register WebSocket route by : {:?}", payload_type);
214 Ok(())
215 }
216}
217
218impl WebSocketConnection {
219 pub async fn get_lane(&self, payload_type: PayloadType) -> NetworkResult<DataLane> {
221 let idx = payload_type as usize;
222
223 {
225 let cache = self.lane_cache.read().await;
226 if let Some(lane) = &cache[idx] {
227 tracing::debug!("📦 ReuseCache DataLane: {:?}", payload_type);
228 return Ok(lane.clone());
229 }
230 }
231
232 let lane = self.create_lane_internal(payload_type).await?;
234
235 {
237 let mut cache = self.lane_cache.write().await;
238 cache[idx] = Some(lane.clone());
239 }
240
241 tracing::info!(
242 "✨ WebSocketConnection Createnew DataLane: {:?}",
243 payload_type
244 );
245
246 Ok(lane)
247 }
248
249 async fn create_lane_internal(&self, payload_type: PayloadType) -> NetworkResult<DataLane> {
251 if !*self.connected.read().await {
253 return Err(NetworkError::ConnectionError(
254 "WebSocket connection closed".to_string(),
255 ));
256 }
257
258 let (tx, rx) = mpsc::channel(100);
260
261 self.register_route(payload_type, tx).await?;
263
264 let sink = self.sink.clone();
266
267 Ok(DataLane::websocket(sink, payload_type, rx))
269 }
270
271 pub async fn create_lane(&self, payload_type: PayloadType) -> NetworkResult<DataLane> {
273 self.get_lane(payload_type).await
274 }
275
276 pub async fn close(&self) -> NetworkResult<()> {
278 *self.connected.write().await = false;
279
280 let mut sink_opt = self.sink.lock().await;
282 if let Some(sink) = sink_opt.as_mut() {
283 let _ = sink.close().await;
284 }
285 *sink_opt = None;
286
287 let mut router = self.router.write().await;
289 *router = [None, None, None, None, None];
290
291 let mut cache = self.lane_cache.write().await;
293 *cache = [None, None, None, None, None];
294
295 tracing::info!("🔌 WebSocketConnection already Close");
296 Ok(())
297 }
298}
299
#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes a wire frame: `[payload_type][len: u32 BE][payload]`.
    fn frame(payload_type: u8, payload: &[u8]) -> Vec<u8> {
        let mut encoded = Vec::with_capacity(5 + payload.len());
        encoded.push(payload_type);
        encoded.extend_from_slice(&(payload.len() as u32).to_be_bytes());
        encoded.extend_from_slice(payload);
        encoded
    }

    #[test]
    fn test_transport_message_decode() {
        let encoded = frame(PayloadType::RpcReliable as u8, b"hello world");
        let decoded = TransportMessage::decode(&encoded)
            .expect("Should decode valid TransportMessage in test");

        assert_eq!(decoded.payload_type as u8, PayloadType::RpcReliable as u8);
        assert_eq!(decoded.data, b"hello world");
    }

    #[test]
    fn test_transport_message_decode_empty_payload() {
        let encoded = frame(PayloadType::RpcReliable as u8, b"");
        let decoded = TransportMessage::decode(&encoded)
            .expect("Should decode zero-length TransportMessage in test");

        assert_eq!(decoded.payload_type as u8, PayloadType::RpcReliable as u8);
        assert!(decoded.data.is_empty());
    }

    #[test]
    fn test_transport_message_decode_invalid() {
        // Shorter than the 5-byte header.
        assert!(TransportMessage::decode(&[]).is_err());
        let data = vec![1, 0, 0];
        assert!(TransportMessage::decode(&data).is_err());

        // Unknown payload-type byte.
        let data = vec![99, 0, 0, 0, 5, 1, 2, 3, 4, 5];
        assert!(TransportMessage::decode(&data).is_err());
    }

    #[test]
    fn test_transport_message_decode_truncated_payload() {
        // Header declares 4 payload bytes but only 2 follow.
        assert!(TransportMessage::decode(&[0, 0, 0, 0, 4, 1, 2]).is_err());

        // Maximum declared length must be rejected cleanly, not overflow.
        assert!(TransportMessage::decode(&[0, 0xFF, 0xFF, 0xFF, 0xFF, 1]).is_err());
    }
}