1use super::error::{NetworkError, NetworkResult};
17use actr_protocol::PayloadType;
18use futures_util::SinkExt;
19use futures_util::stream::SplitSink;
20use std::sync::Arc;
21use tokio::net::TcpStream;
22use tokio::sync::{Mutex, mpsc};
23use tokio_tungstenite::tungstenite::Message as WsMessage;
24use tokio_tungstenite::{MaybeTlsStream, WebSocketStream};
25use webrtc::data_channel::RTCDataChannel;
26
/// Shared, optional write half of a (possibly TLS) WebSocket stream.
///
/// The `Option` is `None` when no connection is attached; `DataLane::send` reports
/// that case as a "WebSocket not connected" error.
type WsSink = Arc<Mutex<Option<SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, WsMessage>>>>;
29
/// A bidirectional data lane over one of several transports.
///
/// Byte-oriented variants (`WebRtcDataChannel`, `WebSocket`) are driven through
/// `send` / `recv` / `try_recv`. The `Mpsc` variant carries `RpcEnvelope`s and is
/// driven through `send_envelope` / `recv_envelope`.
///
/// All handles are `Arc`s, channel senders or `Copy` values, so `clone()` yields a
/// lane that shares the same underlying transport and receiver as the original.
#[derive(Clone)]
pub enum DataLane {
    /// Lane backed by a WebRTC DataChannel.
    WebRtcDataChannel {
        /// DataChannel that outgoing bytes are written to.
        data_channel: Arc<RTCDataChannel>,

        /// Inbound bytes for this lane, shared between clones.
        /// NOTE(review): presumably fed by a DataChannel message handler set up elsewhere — confirm.
        rx: Arc<Mutex<mpsc::Receiver<bytes::Bytes>>>,
    },

    /// In-process lane that exchanges `RpcEnvelope`s over tokio mpsc channels.
    Mpsc {
        /// Payload type associated with this lane (not read by the send/recv paths in this module).
        payload_type: PayloadType,

        /// Sender for outgoing envelopes.
        tx: mpsc::Sender<actr_protocol::RpcEnvelope>,

        /// Receiver for incoming envelopes, shared between clones.
        rx: Arc<Mutex<mpsc::Receiver<actr_protocol::RpcEnvelope>>>,
    },

    /// Lane backed by the write half of a WebSocket connection.
    WebSocket {
        /// Shared write half; `None` means the socket is not connected.
        sink: WsSink,

        /// Written as the first byte of every outgoing frame.
        payload_type: PayloadType,

        /// Inbound bytes for this lane, shared between clones.
        /// NOTE(review): presumably fed by a WebSocket reader task elsewhere — confirm.
        rx: Arc<Mutex<mpsc::Receiver<bytes::Bytes>>>,
    },
}
78
79impl DataLane {
80 pub async fn send(&self, data: bytes::Bytes) -> NetworkResult<()> {
92 match self {
93 DataLane::WebRtcDataChannel { data_channel, .. } => {
94 use webrtc::data_channel::data_channel_state::RTCDataChannelState;
95
96 let start = tokio::time::Instant::now();
98 loop {
99 let state = data_channel.ready_state();
100 if state == RTCDataChannelState::Open {
101 break;
102 }
103 if state == RTCDataChannelState::Closed || state == RTCDataChannelState::Closing
104 {
105 return Err(NetworkError::DataChannelError(format!(
106 "DataChannel closed: {state:?}"
107 )));
108 }
109 if start.elapsed() > std::time::Duration::from_secs(5) {
110 return Err(NetworkError::DataChannelError(format!(
111 "DataChannel open timeout: {state:?}"
112 )));
113 }
114 tokio::time::sleep(std::time::Duration::from_millis(10)).await;
115 }
116
117 data_channel
119 .send(&data)
120 .await
121 .map_err(|e| NetworkError::DataChannelError(format!("Send failed: {e}")))?;
122
123 tracing::trace!("📤 WebRTC DataChannel sent {} bytes", data.len());
124 Ok(())
125 }
126
127 DataLane::Mpsc { .. } => {
128 Err(NetworkError::InvalidOperation(
130 "Mpsc DataLane requires send_envelope(), not send(bytes)".to_string(),
131 ))
132 }
133
134 DataLane::WebSocket {
135 sink, payload_type, ..
136 } => {
137 let mut buf = Vec::with_capacity(5 + data.len());
139
140 buf.push(*payload_type as u8);
142
143 let len = data.len() as u32;
145 buf.extend_from_slice(&len.to_be_bytes());
146
147 buf.extend_from_slice(&data);
149
150 let mut sink_opt = sink.lock().await;
152 if let Some(s) = sink_opt.as_mut() {
153 s.send(WsMessage::Binary(buf.into())).await.map_err(|e| {
154 NetworkError::SendError(format!("WebSocket send failed: {e}"))
155 })?;
156
157 tracing::trace!(
158 "📤 WebSocket sent {} bytes (type={:?})",
159 data.len(),
160 payload_type
161 );
162 Ok(())
163 } else {
164 Err(NetworkError::ConnectionError(
165 "WebSocket not connected".to_string(),
166 ))
167 }
168 }
169 }
170 }
171
172 pub async fn send_envelope(&self, envelope: actr_protocol::RpcEnvelope) -> NetworkResult<()> {
189 match self {
190 DataLane::Mpsc { tx, .. } => {
191 tx.send(envelope)
192 .await
193 .map_err(|_| NetworkError::ChannelClosed("Mpsc channel closed".to_string()))?;
194
195 tracing::trace!("📤 Mpsc sent RpcEnvelope");
196 Ok(())
197 }
198 _ => Err(NetworkError::InvalidOperation(
199 "send_envelope() only supports Mpsc DataLane".to_string(),
200 )),
201 }
202 }
203
204 pub async fn recv(&self) -> NetworkResult<bytes::Bytes> {
219 match self {
220 DataLane::WebRtcDataChannel { rx, .. } | DataLane::WebSocket { rx, .. } => {
221 let mut receiver = rx.lock().await;
222 receiver.recv().await.ok_or_else(|| {
223 NetworkError::ChannelClosed("DataLane receiver closed".to_string())
224 })
225 }
226 DataLane::Mpsc { .. } => {
227 Err(NetworkError::InvalidOperation(
229 "Mpsc DataLane requires recv_envelope(), not recv()".to_string(),
230 ))
231 }
232 }
233 }
234
235 pub async fn recv_envelope(&self) -> NetworkResult<actr_protocol::RpcEnvelope> {
244 match self {
245 DataLane::Mpsc { rx, .. } => {
246 let mut receiver = rx.lock().await;
247 receiver
248 .recv()
249 .await
250 .ok_or_else(|| NetworkError::ChannelClosed("Mpsc channel closed".to_string()))
251 }
252 _ => Err(NetworkError::InvalidOperation(
253 "recv_envelope() only supports Mpsc DataLane".to_string(),
254 )),
255 }
256 }
257
258 pub async fn try_recv(&self) -> NetworkResult<Option<bytes::Bytes>> {
265 match self {
266 DataLane::WebRtcDataChannel { rx, .. } | DataLane::WebSocket { rx, .. } => {
267 let mut receiver = rx.lock().await;
268 match receiver.try_recv() {
269 Ok(data) => Ok(Some(data)),
270 Err(mpsc::error::TryRecvError::Empty) => Ok(None),
271 Err(mpsc::error::TryRecvError::Disconnected) => Err(
272 NetworkError::ChannelClosed("Lane receiver closed".to_string()),
273 ),
274 }
275 }
276 DataLane::Mpsc { .. } => {
277 Err(NetworkError::InvalidOperation(
279 "Mpsc Lane requires try_recv_envelope(), not try_recv()".to_string(),
280 ))
281 }
282 }
283 }
284
285 #[inline]
287 pub fn lane_type(&self) -> &'static str {
288 match self {
289 DataLane::WebRtcDataChannel { .. } => "WebRtcDataChannel",
290 DataLane::Mpsc { .. } => "Mpsc",
291 DataLane::WebSocket { .. } => "WebSocket",
292 }
293 }
294}
295
296impl std::fmt::Debug for DataLane {
297 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
298 match self {
299 DataLane::WebRtcDataChannel { .. } => write!(f, "DataLane::WebRtcDataChannel(..)"),
300 DataLane::Mpsc { .. } => write!(f, "DataLane::Mpsc(..)"),
301 DataLane::WebSocket { payload_type, .. } => {
302 write!(f, "DataLane::WebSocket(type={payload_type:?})")
303 }
304 }
305 }
306}
307
308impl DataLane {
310 #[inline]
317 pub fn mpsc(
318 payload_type: PayloadType,
319 tx: mpsc::Sender<actr_protocol::RpcEnvelope>,
320 rx: mpsc::Receiver<actr_protocol::RpcEnvelope>,
321 ) -> Self {
322 DataLane::Mpsc {
323 payload_type,
324 tx,
325 rx: Arc::new(Mutex::new(rx)),
326 }
327 }
328
329 #[inline]
336 pub fn mpsc_shared(
337 payload_type: PayloadType,
338 tx: mpsc::Sender<actr_protocol::RpcEnvelope>,
339 rx: Arc<Mutex<mpsc::Receiver<actr_protocol::RpcEnvelope>>>,
340 ) -> Self {
341 DataLane::Mpsc {
342 payload_type,
343 tx,
344 rx,
345 }
346 }
347
348 #[inline]
354 pub fn webrtc_data_channel(
355 data_channel: Arc<RTCDataChannel>,
356 rx: mpsc::Receiver<bytes::Bytes>,
357 ) -> Self {
358 DataLane::WebRtcDataChannel {
359 data_channel,
360 rx: Arc::new(Mutex::new(rx)),
361 }
362 }
363
364 #[inline]
371 pub fn websocket(
372 sink: WsSink,
373 payload_type: PayloadType,
374 rx: mpsc::Receiver<bytes::Bytes>,
375 ) -> Self {
376 DataLane::WebSocket {
377 sink,
378 payload_type,
379 rx: Arc::new(Mutex::new(rx)),
380 }
381 }
382}
383
#[cfg(test)]
mod tests {
    use super::*;
    use actr_protocol::RpcEnvelope;
    use bytes::Bytes;

    /// Builds a minimal `RpcEnvelope` with the given request id, payload and trace id.
    fn envelope(request_id: &str, payload: &'static [u8], trace_id: &str) -> RpcEnvelope {
        RpcEnvelope {
            request_id: request_id.to_string(),
            route_key: "test.route".to_string(),
            payload: Some(Bytes::from_static(payload)),
            trace_id: trace_id.to_string(),
            metadata: vec![],
            timeout_ms: 30000,
            error: None,
        }
    }

    /// Creates an `Mpsc` lane whose sender and receiver are the two ends of one channel,
    /// so anything sent on it can be received back from it.
    fn loopback_lane() -> DataLane {
        let (tx, rx) = mpsc::channel(10);
        DataLane::mpsc(PayloadType::RpcReliable, tx, rx)
    }

    #[tokio::test]
    async fn test_mpsc_lane() {
        let lane = loopback_lane();

        lane.send_envelope(envelope("test-1", b"hello", "trace-1"))
            .await
            .unwrap();

        let received = lane.recv_envelope().await.unwrap();
        assert_eq!(received.request_id, "test-1");
        assert_eq!(received.payload, Some(Bytes::from_static(b"hello")));
    }

    #[tokio::test]
    async fn test_mpsc_lane_clone() {
        let lane = loopback_lane();
        // A clone shares the same sender and receiver, so it sees what the original sent.
        let lane2 = lane.clone();

        lane.send_envelope(envelope("test-2", b"test", "trace-2"))
            .await
            .unwrap();

        let received = lane2.recv_envelope().await.unwrap();
        assert_eq!(received.request_id, "test-2");
        assert_eq!(received.payload, Some(Bytes::from_static(b"test")));
    }

    #[tokio::test]
    async fn test_mpsc_lane_with_shared_rx() {
        let (tx, rx) = mpsc::channel(10);
        let rx_shared = Arc::new(Mutex::new(rx));

        let lane = DataLane::mpsc_shared(PayloadType::RpcReliable, tx, rx_shared);

        lane.send_envelope(envelope("test-3", b"shared", "trace-3"))
            .await
            .unwrap();

        let received = lane.recv_envelope().await.unwrap();
        assert_eq!(received.request_id, "test-3");
        assert_eq!(received.payload, Some(Bytes::from_static(b"shared")));
    }

    #[tokio::test]
    async fn test_mpsc_lane_rejects_byte_api() {
        let lane = loopback_lane();

        assert!(matches!(
            lane.send(Bytes::from_static(b"raw")).await,
            Err(NetworkError::InvalidOperation(_))
        ));
        assert!(matches!(
            lane.recv().await,
            Err(NetworkError::InvalidOperation(_))
        ));
        assert!(matches!(
            lane.try_recv().await,
            Err(NetworkError::InvalidOperation(_))
        ));
    }

    #[test]
    fn test_lane_type_name() {
        let lane = loopback_lane();
        assert_eq!(lane.lane_type(), "Mpsc");
    }

    #[test]
    fn test_mpsc_lane_debug() {
        let lane = loopback_lane();
        assert_eq!(format!("{lane:?}"), "DataLane::Mpsc(..)");
    }
}