1use super::error::{NetworkError, NetworkResult};
17use actr_protocol::PayloadType;
18use futures_util::SinkExt;
19use futures_util::stream::SplitSink;
20use std::sync::Arc;
21use tokio::net::TcpStream;
22use tokio::sync::{Mutex, mpsc};
23use tokio_tungstenite::tungstenite::Message as WsMessage;
24use tokio_tungstenite::{MaybeTlsStream, WebSocketStream};
25use webrtc::data_channel::RTCDataChannel;
26
27type WsSink = Arc<Mutex<Option<SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, WsMessage>>>>;
29
/// A data path to a peer, abstracted over its transport.
///
/// The byte-oriented variants (`WebRtcDataChannel`, `WebSocket`) are driven through
/// [`DataLane::send`], [`DataLane::recv`] and [`DataLane::try_recv`]. The in-process `Mpsc`
/// variant carries whole `RpcEnvelope`s through [`DataLane::send_envelope`] and
/// [`DataLane::recv_envelope`]. Calling the other family of methods returns
/// `NetworkError::InvalidOperation`.
///
/// Every field is an `Arc`, an `mpsc::Sender`, or a `Copy` value. A `clone()` is therefore a
/// new handle to the same data channel, sink, sender and receiver as the original.
#[derive(Clone)]
pub enum DataLane {
    /// Lane backed by a WebRTC data channel.
    WebRtcDataChannel {
        /// Channel that outbound bytes are written to by `send`.
        data_channel: Arc<RTCDataChannel>,

        /// Inbound bytes for this lane, shared by all clones.
        /// NOTE(review): presumably fed by the data channel's message callback elsewhere — confirm.
        rx: Arc<Mutex<mpsc::Receiver<bytes::Bytes>>>,
    },

    /// In-process lane that passes `RpcEnvelope`s over a tokio mpsc channel.
    Mpsc {
        /// Payload type this lane was created for (not read by any method in this file).
        payload_type: PayloadType,

        /// Sending half used by `send_envelope`.
        tx: mpsc::Sender<actr_protocol::RpcEnvelope>,

        /// Receiving half used by `recv_envelope`; may be shared with other holders of
        /// the same `Arc` (see `DataLane::mpsc_shared`).
        rx: Arc<Mutex<mpsc::Receiver<actr_protocol::RpcEnvelope>>>,
    },

    /// Lane backed by a shared WebSocket connection.
    WebSocket {
        /// Shared write half; `None` means no connection is attached and `send` fails.
        sink: WsSink,

        /// Written as the first header byte of every frame produced by `send`.
        payload_type: PayloadType,

        /// Inbound bytes for this lane.
        /// NOTE(review): presumably already de-framed by a reader task elsewhere — confirm.
        rx: Arc<Mutex<mpsc::Receiver<bytes::Bytes>>>,
    },
}
78
79impl DataLane {
80 pub async fn send(&self, data: bytes::Bytes) -> NetworkResult<()> {
92 match self {
93 DataLane::WebRtcDataChannel { data_channel, .. } => {
94 use webrtc::data_channel::data_channel_state::RTCDataChannelState;
95
96 let start = tokio::time::Instant::now();
98 loop {
99 let state = data_channel.ready_state();
100 if state == RTCDataChannelState::Open {
101 break;
102 }
103 if state == RTCDataChannelState::Closed || state == RTCDataChannelState::Closing
104 {
105 return Err(NetworkError::DataChannelError(format!(
106 "DataChannel closed: {state:?}"
107 )));
108 }
109 if start.elapsed() > std::time::Duration::from_secs(5) {
110 return Err(NetworkError::DataChannelError(format!(
111 "DataChannel open timeout: {state:?}"
112 )));
113 }
114 tokio::time::sleep(std::time::Duration::from_millis(10)).await;
115 }
116 tracing::debug!("🔄 WebRTC DataChannel send");
117 data_channel
119 .send(&data)
120 .await
121 .map_err(|e| NetworkError::DataChannelError(format!("Send failed: {e}")))?;
122
123 tracing::trace!("📤 WebRTC DataChannel sent {} bytes", data.len());
124 Ok(())
125 }
126
127 DataLane::Mpsc { .. } => {
128 Err(NetworkError::InvalidOperation(
130 "Mpsc DataLane requires send_envelope(), not send(bytes)".to_string(),
131 ))
132 }
133
134 DataLane::WebSocket {
135 sink, payload_type, ..
136 } => {
137 let mut buf = Vec::with_capacity(5 + data.len());
139
140 buf.push(*payload_type as u8);
142
143 let len = data.len() as u32;
145 buf.extend_from_slice(&len.to_be_bytes());
146
147 buf.extend_from_slice(&data);
149
150 let mut sink_opt = sink.lock().await;
152 if let Some(s) = sink_opt.as_mut() {
153 s.send(WsMessage::Binary(buf.into())).await.map_err(|e| {
154 NetworkError::SendError(format!("WebSocket send failed: {e}"))
155 })?;
156
157 tracing::trace!(
158 "📤 WebSocket sent {} bytes (type={:?})",
159 data.len(),
160 payload_type
161 );
162 Ok(())
163 } else {
164 Err(NetworkError::ConnectionError(
165 "WebSocket not connected".to_string(),
166 ))
167 }
168 }
169 }
170 }
171
172 #[cfg_attr(
189 feature = "opentelemetry",
190 tracing::instrument(skip_all, name = "DataLane.send_envelope")
191 )]
192 pub async fn send_envelope(&self, envelope: actr_protocol::RpcEnvelope) -> NetworkResult<()> {
193 match self {
194 DataLane::Mpsc { tx, .. } => {
195 tx.send(envelope)
196 .await
197 .map_err(|_| NetworkError::ChannelClosed("Mpsc channel closed".to_string()))?;
198
199 tracing::trace!("📤 Mpsc sent RpcEnvelope");
200 Ok(())
201 }
202 _ => Err(NetworkError::InvalidOperation(
203 "send_envelope() only supports Mpsc DataLane".to_string(),
204 )),
205 }
206 }
207
208 pub async fn recv(&self) -> NetworkResult<bytes::Bytes> {
223 match self {
224 DataLane::WebRtcDataChannel { rx, .. } | DataLane::WebSocket { rx, .. } => {
225 let mut receiver = rx.lock().await;
226 tracing::debug!("🔄 WebRTC DataLane recv: {:?}", receiver);
227 receiver.recv().await.ok_or_else(|| {
228 NetworkError::ChannelClosed("DataLane receiver closed".to_string())
229 })
230 }
231 DataLane::Mpsc { .. } => {
232 Err(NetworkError::InvalidOperation(
234 "Mpsc DataLane requires recv_envelope(), not recv()".to_string(),
235 ))
236 }
237 }
238 }
239
240 pub async fn recv_envelope(&self) -> NetworkResult<actr_protocol::RpcEnvelope> {
249 match self {
250 DataLane::Mpsc { rx, .. } => {
251 let mut receiver = rx.lock().await;
252 receiver
253 .recv()
254 .await
255 .ok_or_else(|| NetworkError::ChannelClosed("Mpsc channel closed".to_string()))
256 }
257 _ => Err(NetworkError::InvalidOperation(
258 "recv_envelope() only supports Mpsc DataLane".to_string(),
259 )),
260 }
261 }
262
263 pub async fn try_recv(&self) -> NetworkResult<Option<bytes::Bytes>> {
270 match self {
271 DataLane::WebRtcDataChannel { rx, .. } | DataLane::WebSocket { rx, .. } => {
272 let mut receiver = rx.lock().await;
273 match receiver.try_recv() {
274 Ok(data) => Ok(Some(data)),
275 Err(mpsc::error::TryRecvError::Empty) => Ok(None),
276 Err(mpsc::error::TryRecvError::Disconnected) => Err(
277 NetworkError::ChannelClosed("Lane receiver closed".to_string()),
278 ),
279 }
280 }
281 DataLane::Mpsc { .. } => {
282 Err(NetworkError::InvalidOperation(
284 "Mpsc Lane requires try_recv_envelope(), not try_recv()".to_string(),
285 ))
286 }
287 }
288 }
289
290 #[inline]
292 pub fn lane_type(&self) -> &'static str {
293 match self {
294 DataLane::WebRtcDataChannel { .. } => "WebRtcDataChannel",
295 DataLane::Mpsc { .. } => "Mpsc",
296 DataLane::WebSocket { .. } => "WebSocket",
297 }
298 }
299}
300
301impl std::fmt::Debug for DataLane {
302 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
303 match self {
304 DataLane::WebRtcDataChannel { .. } => write!(f, "DataLane::WebRtcDataChannel(..)"),
305 DataLane::Mpsc { .. } => write!(f, "DataLane::Mpsc(..)"),
306 DataLane::WebSocket { payload_type, .. } => {
307 write!(f, "DataLane::WebSocket(type={payload_type:?})")
308 }
309 }
310 }
311}
312
313impl DataLane {
315 #[inline]
322 pub fn mpsc(
323 payload_type: PayloadType,
324 tx: mpsc::Sender<actr_protocol::RpcEnvelope>,
325 rx: mpsc::Receiver<actr_protocol::RpcEnvelope>,
326 ) -> Self {
327 DataLane::Mpsc {
328 payload_type,
329 tx,
330 rx: Arc::new(Mutex::new(rx)),
331 }
332 }
333
334 #[inline]
341 pub fn mpsc_shared(
342 payload_type: PayloadType,
343 tx: mpsc::Sender<actr_protocol::RpcEnvelope>,
344 rx: Arc<Mutex<mpsc::Receiver<actr_protocol::RpcEnvelope>>>,
345 ) -> Self {
346 DataLane::Mpsc {
347 payload_type,
348 tx,
349 rx,
350 }
351 }
352
353 #[inline]
359 pub fn webrtc_data_channel(
360 data_channel: Arc<RTCDataChannel>,
361 rx: mpsc::Receiver<bytes::Bytes>,
362 ) -> Self {
363 DataLane::WebRtcDataChannel {
364 data_channel,
365 rx: Arc::new(Mutex::new(rx)),
366 }
367 }
368
369 #[inline]
376 pub fn websocket(
377 sink: WsSink,
378 payload_type: PayloadType,
379 rx: mpsc::Receiver<bytes::Bytes>,
380 ) -> Self {
381 DataLane::WebSocket {
382 sink,
383 payload_type,
384 rx: Arc::new(Mutex::new(rx)),
385 }
386 }
387}
388
#[cfg(test)]
mod tests {
    use super::*;
    use actr_protocol::RpcEnvelope;
    use bytes::Bytes;

    /// Builds a minimal envelope for round-trip tests (replaces three copies of the literal).
    fn envelope(request_id: &str, payload: &'static [u8]) -> RpcEnvelope {
        RpcEnvelope {
            request_id: request_id.to_string(),
            route_key: "test.route".to_string(),
            payload: Some(Bytes::from_static(payload)),
            traceparent: None,
            tracestate: None,
            metadata: vec![],
            timeout_ms: 30000,
            error: None,
        }
    }

    /// Builds a WebSocket lane with no sink attached; returns the lane and its inbound sender.
    fn detached_websocket_lane() -> (DataLane, mpsc::Sender<Bytes>) {
        let (tx, rx) = mpsc::channel(10);
        let lane = DataLane::websocket(Arc::new(Mutex::new(None)), PayloadType::RpcReliable, rx);
        (lane, tx)
    }

    #[tokio::test]
    async fn test_mpsc_lane() {
        let (tx, rx) = mpsc::channel(10);
        let lane = DataLane::mpsc(PayloadType::RpcReliable, tx, rx);

        lane.send_envelope(envelope("test-1", b"hello")).await.unwrap();

        let received = lane.recv_envelope().await.unwrap();
        assert_eq!(received.request_id, "test-1");
        assert_eq!(received.payload, Some(Bytes::from_static(b"hello")));
    }

    #[tokio::test]
    async fn test_mpsc_lane_clone() {
        let (tx, rx) = mpsc::channel(10);
        let lane = DataLane::mpsc(PayloadType::RpcReliable, tx, rx);
        // Clones share the same sender and receiver.
        let lane2 = lane.clone();

        lane.send_envelope(envelope("test-2", b"test")).await.unwrap();

        let received = lane2.recv_envelope().await.unwrap();
        assert_eq!(received.request_id, "test-2");
        assert_eq!(received.payload, Some(Bytes::from_static(b"test")));
    }

    #[tokio::test]
    async fn test_mpsc_lane_with_shared_rx() {
        let (tx, rx) = mpsc::channel(10);
        let rx_shared = Arc::new(Mutex::new(rx));
        let lane = DataLane::mpsc_shared(PayloadType::RpcReliable, tx, rx_shared.clone());

        lane.send_envelope(envelope("test-3", b"shared")).await.unwrap();

        let received = lane.recv_envelope().await.unwrap();
        assert_eq!(received.request_id, "test-3");
        assert_eq!(received.payload, Some(Bytes::from_static(b"shared")));
    }

    #[tokio::test]
    async fn test_mpsc_lane_rejects_byte_api() {
        let (tx, rx) = mpsc::channel(10);
        let lane = DataLane::mpsc(PayloadType::RpcReliable, tx, rx);

        assert!(matches!(
            lane.send(Bytes::from_static(b"x")).await,
            Err(NetworkError::InvalidOperation(_))
        ));
        assert!(matches!(
            lane.recv().await,
            Err(NetworkError::InvalidOperation(_))
        ));
        assert!(matches!(
            lane.try_recv().await,
            Err(NetworkError::InvalidOperation(_))
        ));
    }

    #[tokio::test]
    async fn test_websocket_lane_byte_receive() {
        let (lane, tx) = detached_websocket_lane();

        assert!(matches!(lane.try_recv().await, Ok(None)));

        tx.send(Bytes::from_static(b"frame")).await.unwrap();
        assert_eq!(lane.recv().await.unwrap(), Bytes::from_static(b"frame"));

        tx.send(Bytes::from_static(b"again")).await.unwrap();
        assert!(matches!(lane.try_recv().await, Ok(Some(ref b)) if b.as_ref() == b"again"));
    }

    #[tokio::test]
    async fn test_websocket_lane_rejects_envelope_api() {
        let (lane, _tx) = detached_websocket_lane();

        assert!(matches!(
            lane.send_envelope(envelope("ws", b"ws")).await,
            Err(NetworkError::InvalidOperation(_))
        ));
        assert!(matches!(
            lane.recv_envelope().await,
            Err(NetworkError::InvalidOperation(_))
        ));
    }

    #[tokio::test]
    async fn test_websocket_send_without_sink() {
        let (lane, _tx) = detached_websocket_lane();

        assert!(matches!(
            lane.send(Bytes::from_static(b"payload")).await,
            Err(NetworkError::ConnectionError(_))
        ));
    }

    #[tokio::test]
    async fn test_byte_lane_reports_closed_channel() {
        let (lane, tx) = detached_websocket_lane();
        drop(tx);

        assert!(matches!(
            lane.try_recv().await,
            Err(NetworkError::ChannelClosed(_))
        ));
        assert!(matches!(
            lane.recv().await,
            Err(NetworkError::ChannelClosed(_))
        ));
    }

    #[test]
    fn test_lane_type_name() {
        let (tx, rx) = mpsc::channel(10);
        let lane = DataLane::mpsc(PayloadType::RpcReliable, tx, rx);
        assert_eq!(lane.lane_type(), "Mpsc");

        let (ws_lane, _tx) = detached_websocket_lane();
        assert_eq!(ws_lane.lane_type(), "WebSocket");
    }

    #[test]
    fn test_debug_format() {
        let (tx, rx) = mpsc::channel(10);
        let lane = DataLane::mpsc(PayloadType::RpcReliable, tx, rx);
        assert_eq!(format!("{lane:?}"), "DataLane::Mpsc(..)");

        let (ws_lane, _tx) = detached_websocket_lane();
        assert_eq!(
            format!("{ws_lane:?}"),
            format!("DataLane::WebSocket(type={:?})", PayloadType::RpcReliable)
        );
    }
}