use super::error::{NetworkError, NetworkResult};
use crate::INITIAL_CONNECTION_TIMEOUT;
use actr_protocol::PayloadType;
use async_trait::async_trait;
use futures_util::SinkExt;
use futures_util::stream::SplitSink;
use std::collections::HashMap;
#[cfg(feature = "test-utils")]
use std::future::Future;
#[cfg(feature = "test-utils")]
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicU32, Ordering};
#[cfg(feature = "test-utils")]
use std::sync::{Mutex as StdMutex, OnceLock};
use std::time::Instant;
use tokio::net::TcpStream;
use tokio::sync::{Mutex, mpsc};
use tokio_tungstenite::tungstenite::Message as WsMessage;
use tokio_tungstenite::{MaybeTlsStream, WebSocketStream};
use webrtc::data_channel::RTCDataChannel;
/// Upper bound, in bytes, on one frame passed to `RTCDataChannel::send`, header included.
/// NOTE(review): the reason for 65535 (e.g. a peer/SCTP message-size limit) is not visible here.
const DC_MAX_MESSAGE_SIZE: usize = 65535;
/// Size of the header written by `encode_fragment_header`:
/// `msg_id` (u32, big-endian) + `frag_index` (u16, big-endian) + `total_frags` (u16, big-endian).
const FRAGMENT_HEADER_SIZE: usize = 8;
/// How long a partially reassembled message may wait for its remaining fragments before
/// `ReassemblyBuffer::evict_stale` drops it (6 hours).
const REASSEMBLY_TTL: std::time::Duration = std::time::Duration::from_secs(6 * 60 * 60);
/// Maximum application payload carried by one frame once the fragment header is accounted for.
const DC_MAX_PAYLOAD_SIZE: usize = DC_MAX_MESSAGE_SIZE - FRAGMENT_HEADER_SIZE;
/// Test-only description of one DataChannel frame that `WebRtcDataLane::send` has just sent.
///
/// Delivered to the hook installed via `install_webrtc_fragment_send_hook_for_test`.
#[cfg(feature = "test-utils")]
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WebRtcFragmentSendEvent {
    /// Message id shared by every fragment of the same logical message.
    pub msg_id: u32,
    /// Zero-based index of this fragment within the message.
    pub frag_index: u16,
    /// Number of fragments the message was split into (1 for unfragmented sends).
    pub total_frags: u16,
    /// Payload bytes carried by this frame, excluding the fragment header.
    pub fragment_payload_len: usize,
    /// Length in bytes of the whole message being sent.
    pub message_len: usize,
}
/// Boxed future returned by a fragment-send hook; `send` awaits it after each frame.
#[cfg(feature = "test-utils")]
pub type WebRtcFragmentSendHookFuture = Pin<Box<dyn Future<Output = ()> + Send + 'static>>;
/// Test callback invoked after each DataChannel frame is sent.
#[cfg(feature = "test-utils")]
pub type WebRtcFragmentSendHook =
    Arc<dyn Fn(WebRtcFragmentSendEvent) -> WebRtcFragmentSendHookFuture + Send + Sync + 'static>;
/// Process-wide slot holding the currently installed hook (lazily initialised to `None`).
#[cfg(feature = "test-utils")]
static WEBRTC_FRAGMENT_SEND_HOOK: OnceLock<StdMutex<Option<WebRtcFragmentSendHook>>> =
    OnceLock::new();
/// Returns the global hook slot, creating it (empty) on first access.
#[cfg(feature = "test-utils")]
fn webrtc_fragment_send_hook_slot() -> &'static StdMutex<Option<WebRtcFragmentSendHook>> {
    // `Mutex<Option<_>>::default()` is `Mutex::new(None)`.
    WEBRTC_FRAGMENT_SEND_HOOK.get_or_init(Default::default)
}
/// RAII guard returned by `install_webrtc_fragment_send_hook_for_test`.
///
/// Dropping it puts back whatever hook (or `None`) was installed before.
#[cfg(feature = "test-utils")]
pub struct WebRtcFragmentSendHookGuard {
    previous: Option<WebRtcFragmentSendHook>,
}
#[cfg(feature = "test-utils")]
impl Drop for WebRtcFragmentSendHookGuard {
    fn drop(&mut self) {
        let restored = self.previous.take();
        *webrtc_fragment_send_hook_slot()
            .lock()
            .expect("fragment send hook mutex poisoned") = restored;
    }
}
/// Installs `hook` as the global fragment-send hook and returns a guard that restores the
/// previously installed hook when dropped.
///
/// # Panics
/// Panics if the hook mutex is poisoned.
#[cfg(feature = "test-utils")]
pub fn install_webrtc_fragment_send_hook_for_test(
    hook: WebRtcFragmentSendHook,
) -> WebRtcFragmentSendHookGuard {
    let previous = webrtc_fragment_send_hook_slot()
        .lock()
        .expect("fragment send hook mutex poisoned")
        .replace(hook);
    WebRtcFragmentSendHookGuard { previous }
}
/// Invokes the currently installed fragment-send hook, if any, and awaits its future.
#[cfg(feature = "test-utils")]
async fn notify_webrtc_fragment_sent_for_test(event: WebRtcFragmentSendEvent) {
    // Clone the hook out inside its own scope so the std mutex guard is released
    // before the `.await` below.
    let installed: Option<WebRtcFragmentSendHook> = {
        let slot = webrtc_fragment_send_hook_slot()
            .lock()
            .expect("fragment send hook mutex poisoned");
        slot.clone()
    };
    match installed {
        Some(callback) => callback(event).await,
        None => {}
    }
}
/// Fragments received so far for one in-flight message.
struct FragmentEntry {
    /// Fragment count announced by the first fragment seen for this message.
    total: u16,
    /// When the first fragment arrived; compared against `REASSEMBLY_TTL` for eviction.
    created_at: Instant,
    /// Fragment payloads keyed by fragment index.
    fragments: HashMap<u16, bytes::Bytes>,
}
/// Collects DataChannel fragments until every fragment of a message has arrived.
pub(crate) struct ReassemblyBuffer {
    /// Partially received messages keyed by sender-assigned message id.
    pending: HashMap<u32, FragmentEntry>,
}
impl ReassemblyBuffer {
    /// Creates an empty reassembly buffer.
    fn new() -> Self {
        Self {
            pending: HashMap::new(),
        }
    }

    /// Records one fragment of message `msg_id`.
    ///
    /// Returns the concatenated message, ordered by fragment index, once all
    /// `total_frags` fragments have arrived; otherwise returns `None`.
    ///
    /// Stale partial messages are evicted first (see `evict_stale`). Some fragments cannot
    /// belong to a well-formed message and are logged and dropped, so they neither complete
    /// a message with data missing nor pin an entry that can never complete:
    /// - `frag_index >= total_frags`, which also covers `total_frags == 0`;
    /// - a `total_frags` that disagrees with the value recorded for an earlier fragment
    ///   of the same `msg_id`.
    ///
    /// A repeated `frag_index` replaces the earlier payload for that index.
    fn insert(
        &mut self,
        msg_id: u32,
        frag_index: u16,
        total_frags: u16,
        payload: bytes::Bytes,
    ) -> Option<bytes::Bytes> {
        self.evict_stale();
        // `frag_index < total_frags` also rules out `total_frags == 0`, which could never
        // complete and would otherwise occupy memory until REASSEMBLY_TTL expires.
        if frag_index >= total_frags {
            tracing::warn!(
                "dropping out-of-range fragment: msg_id={} frag_index={} total_frags={}",
                msg_id,
                frag_index,
                total_frags
            );
            return None;
        }
        let entry = self.pending.entry(msg_id).or_insert_with(|| FragmentEntry {
            total: total_frags,
            created_at: Instant::now(),
            fragments: HashMap::new(),
        });
        // Every fragment of one message carries the same total, so a disagreeing fragment
        // cannot belong to the message already being assembled under this id.
        if entry.total != total_frags {
            tracing::warn!(
                "dropping fragment with inconsistent total: msg_id={} frag_index={} \
                 total_frags={} (expected {})",
                msg_id,
                frag_index,
                total_frags,
                entry.total
            );
            return None;
        }
        entry.fragments.insert(frag_index, payload);
        // Keys are distinct and all < total, so reaching `total` entries means every index
        // in 0..total is present.
        if entry.fragments.len() < usize::from(entry.total) {
            return None;
        }
        let entry = self
            .pending
            .remove(&msg_id)
            .expect("reassembly entry must exist: it was just looked up");
        let mut ordered: Vec<(u16, bytes::Bytes)> = entry.fragments.into_iter().collect();
        // Indices are unique map keys, so an unstable sort yields the same order.
        ordered.sort_unstable_by_key(|(idx, _)| *idx);
        let total_len: usize = ordered.iter().map(|(_, b)| b.len()).sum();
        let mut out = bytes::BytesMut::with_capacity(total_len);
        for (_, frag) in ordered {
            out.extend_from_slice(&frag);
        }
        Some(out.freeze())
    }

    /// Drops every pending message whose first fragment arrived more than `REASSEMBLY_TTL`
    /// ago, logging each eviction at warn level and a summary at info level.
    fn evict_stale(&mut self) {
        let now = Instant::now();
        let before = self.pending.len();
        self.pending.retain(|msg_id, entry| {
            let age = now.duration_since(entry.created_at);
            if age > REASSEMBLY_TTL {
                tracing::warn!(
                    "evicting stale reassembly entry: msg_id={} \
                    (age={:.1}s, {}/{} fragments received)",
                    msg_id,
                    age.as_secs_f64(),
                    entry.fragments.len(),
                    entry.total,
                );
                false
            } else {
                true
            }
        });
        let evicted = before - self.pending.len();
        if evicted > 0 {
            tracing::info!(
                "evicted {} stale reassembly entries ({} remaining)",
                evicted,
                self.pending.len()
            );
        }
    }
}
/// Appends the 8-byte fragment header to `buf`: `msg_id`, `frag_index`, then `total_frags`,
/// each big-endian.
#[inline]
fn encode_fragment_header(buf: &mut Vec<u8>, msg_id: u32, frag_index: u16, total_frags: u16) {
    let [m0, m1, m2, m3] = msg_id.to_be_bytes();
    let [i0, i1] = frag_index.to_be_bytes();
    let [t0, t1] = total_frags.to_be_bytes();
    buf.extend_from_slice(&[m0, m1, m2, m3, i0, i1, t0, t1]);
}
/// Splits a received frame into `(msg_id, frag_index, total_frags, payload)`.
///
/// The payload is a zero-copy slice of `raw` after the 8-byte header.
///
/// # Errors
/// `NetworkError::DataChannelError` if `raw` is shorter than `FRAGMENT_HEADER_SIZE`.
#[inline]
fn decode_fragment_header(raw: bytes::Bytes) -> NetworkResult<(u32, u16, u16, bytes::Bytes)> {
    let Some(header) = raw.get(..FRAGMENT_HEADER_SIZE) else {
        return Err(NetworkError::DataChannelError(format!(
            "fragment too short: {} bytes (minimum {})",
            raw.len(),
            FRAGMENT_HEADER_SIZE
        )));
    };
    let msg_id = u32::from_be_bytes([header[0], header[1], header[2], header[3]]);
    let frag_index = u16::from_be_bytes([header[4], header[5]]);
    let total_frags = u16::from_be_bytes([header[6], header[7]]);
    Ok((msg_id, frag_index, total_frags, raw.slice(FRAGMENT_HEADER_SIZE..)))
}
/// Shared write half of a WebSocket connection; `None` means no connection is available
/// (`WebSocketDataLane::send` reports `ConnectionError` in that case).
pub(crate) type WsSink =
    Arc<Mutex<Option<SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, WsMessage>>>>;
/// A transport-specific lane that carries payloads for one connection.
///
/// Every operation except `lane_type` has a default body. The I/O defaults fail with
/// `NetworkError::InvalidOperation`, so an implementor overrides only the operations
/// its transport supports.
#[async_trait]
pub trait DataLane: Send + Sync + std::fmt::Debug {
    /// Sends raw bytes. The default reports the operation as unsupported.
    async fn send(&self, _data: bytes::Bytes) -> NetworkResult<()> {
        Err(NetworkError::InvalidOperation(String::from(
            "send(bytes) not supported on this lane type",
        )))
    }

    /// Waits for the next raw message. The default reports the operation as unsupported.
    async fn recv(&self) -> NetworkResult<bytes::Bytes> {
        Err(NetworkError::InvalidOperation(String::from(
            "recv() not supported on this lane type",
        )))
    }

    /// Returns the next raw message if one is ready, `Ok(None)` otherwise.
    /// The default reports the operation as unsupported.
    #[allow(dead_code)]
    async fn try_recv(&self) -> NetworkResult<Option<bytes::Bytes>> {
        Err(NetworkError::InvalidOperation(String::from(
            "try_recv() not supported on this lane type",
        )))
    }

    /// Sends a structured RPC envelope. The default reports the operation as unsupported.
    async fn send_envelope(&self, _envelope: actr_protocol::RpcEnvelope) -> NetworkResult<()> {
        Err(NetworkError::InvalidOperation(String::from(
            "send_envelope() not supported on this lane type",
        )))
    }

    /// Waits for the next RPC envelope. The default reports the operation as unsupported.
    async fn recv_envelope(&self) -> NetworkResult<actr_protocol::RpcEnvelope> {
        Err(NetworkError::InvalidOperation(String::from(
            "recv_envelope() not supported on this lane type",
        )))
    }

    /// Short static name of the lane's transport, e.g. for diagnostics.
    #[allow(dead_code)]
    fn lane_type(&self) -> &'static str;

    /// Whether the lane is believed usable. The default always answers `true`.
    fn is_healthy(&self) -> bool {
        true
    }
}
/// In-process lane that moves `RpcEnvelope`s over a tokio mpsc channel.
///
/// Clones share the sender and, via the `Arc<Mutex<..>>`, the same receiver.
#[derive(Clone, Debug)]
pub(crate) struct MpscLane {
    // Stored but not read in this file, hence the allow.
    #[allow(dead_code)]
    payload_type: PayloadType,
    tx: mpsc::Sender<actr_protocol::RpcEnvelope>,
    rx: Arc<Mutex<mpsc::Receiver<actr_protocol::RpcEnvelope>>>,
}
impl MpscLane {
    /// Test-only constructor: wraps an owned receiver and delegates to `new_shared`.
    #[cfg(test)]
    #[inline]
    pub(crate) fn new(
        payload_type: PayloadType,
        tx: mpsc::Sender<actr_protocol::RpcEnvelope>,
        rx: mpsc::Receiver<actr_protocol::RpcEnvelope>,
    ) -> Self {
        Self::new_shared(payload_type, tx, Arc::new(Mutex::new(rx)))
    }

    /// Builds a lane around a receiver that may also be held by other owners.
    #[inline]
    pub(crate) fn new_shared(
        payload_type: PayloadType,
        tx: mpsc::Sender<actr_protocol::RpcEnvelope>,
        rx: Arc<Mutex<mpsc::Receiver<actr_protocol::RpcEnvelope>>>,
    ) -> Self {
        Self {
            tx,
            rx,
            payload_type,
        }
    }
}
#[async_trait]
impl DataLane for MpscLane {
    /// Queues `envelope` on the channel, waiting for capacity if it is full.
    ///
    /// # Errors
    /// `NetworkError::ChannelClosed` if the receiving side has been dropped.
    #[cfg_attr(
        feature = "opentelemetry",
        tracing::instrument(skip_all, name = "MpscLane.send_envelope")
    )]
    async fn send_envelope(&self, envelope: actr_protocol::RpcEnvelope) -> NetworkResult<()> {
        match self.tx.send(envelope).await {
            Ok(()) => {
                tracing::trace!("Mpsc sent RpcEnvelope");
                Ok(())
            }
            Err(_) => Err(NetworkError::ChannelClosed(
                "Mpsc channel closed".to_string(),
            )),
        }
    }

    /// Waits for the next envelope, holding the shared receiver lock while waiting.
    ///
    /// # Errors
    /// `NetworkError::ChannelClosed` once all senders are gone and the channel is drained.
    async fn recv_envelope(&self) -> NetworkResult<actr_protocol::RpcEnvelope> {
        match self.rx.lock().await.recv().await {
            Some(envelope) => Ok(envelope),
            None => Err(NetworkError::ChannelClosed(
                "Mpsc channel closed".to_string(),
            )),
        }
    }

    #[inline]
    fn lane_type(&self) -> &'static str {
        "Mpsc"
    }
}
/// Lane over a WebRTC DataChannel that fragments outgoing messages and reassembles
/// incoming ones using the 8-byte fragment header.
///
/// Clones share the channel, the inbound receiver, the message-id counter and the
/// reassembly buffer.
#[derive(Clone)]
pub(crate) struct WebRtcDataLane {
    pub(crate) data_channel: Arc<RTCDataChannel>,
    /// Raw inbound frames (header + payload) fed in by the owner of this lane.
    rx: Arc<Mutex<mpsc::Receiver<bytes::Bytes>>>,
    /// Source of per-message ids for outgoing messages (wraps on overflow).
    msg_id_counter: Arc<AtomicU32>,
    reassembly: Arc<Mutex<ReassemblyBuffer>>,
}
/// Opaque debug representation; the channel and buffers are not printed.
impl std::fmt::Debug for WebRtcDataLane {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("WebRtcDataLane(..)")
    }
}
impl WebRtcDataLane {
    /// Creates a lane over `data_channel` that reads raw inbound frames from `rx`.
    /// Message ids start at 0 and the reassembly buffer starts empty.
    #[inline]
    pub(crate) fn new(data_channel: Arc<RTCDataChannel>, rx: mpsc::Receiver<bytes::Bytes>) -> Self {
        let rx = Arc::new(Mutex::new(rx));
        let msg_id_counter = Arc::new(AtomicU32::new(0));
        let reassembly = Arc::new(Mutex::new(ReassemblyBuffer::new()));
        Self {
            data_channel,
            rx,
            msg_id_counter,
            reassembly,
        }
    }
}
#[async_trait]
impl DataLane for WebRtcDataLane {
    /// Sends `data` over the DataChannel, fragmenting it when it does not fit in one frame.
    ///
    /// First waits, polling every 10 ms, for the channel to reach `Open`. Each frame
    /// carries the 8-byte fragment header:
    /// - payloads of at most `DC_MAX_PAYLOAD_SIZE` bytes go out as one `(0, 1)` frame;
    /// - larger ones go out as `ceil(len / DC_MAX_PAYLOAD_SIZE)` frames sharing one `msg_id`.
    ///
    /// # Errors
    /// `NetworkError::DataChannelError` when:
    /// - the channel is `Closing`/`Closed`;
    /// - the channel does not open within `INITIAL_CONNECTION_TIMEOUT`;
    /// - the message would need more than `u16::MAX` fragments;
    /// - a frame send fails. Earlier fragments of that message may already have been sent.
    async fn send(&self, data: bytes::Bytes) -> NetworkResult<()> {
        use webrtc::data_channel::data_channel_state::RTCDataChannelState;
        let start = tokio::time::Instant::now();
        loop {
            let state = self.data_channel.ready_state();
            if state == RTCDataChannelState::Open {
                break;
            }
            if state == RTCDataChannelState::Closed || state == RTCDataChannelState::Closing {
                return Err(NetworkError::DataChannelError(format!(
                    "DataChannel closed: {state:?}"
                )));
            }
            if start.elapsed() > INITIAL_CONNECTION_TIMEOUT {
                return Err(NetworkError::DataChannelError(format!(
                    "DataChannel open timeout: {state:?}"
                )));
            }
            tokio::time::sleep(std::time::Duration::from_millis(10)).await;
        }
        let msg_id = self.msg_id_counter.fetch_add(1, Ordering::Relaxed);
        let data_len = data.len();
        if data_len <= DC_MAX_PAYLOAD_SIZE {
            let mut buf = Vec::with_capacity(FRAGMENT_HEADER_SIZE + data_len);
            encode_fragment_header(&mut buf, msg_id, 0, 1);
            buf.extend_from_slice(&data);
            let frame = bytes::Bytes::from(buf);
            self.data_channel
                .send(&frame)
                .await
                .map_err(|e| NetworkError::DataChannelError(format!("Send failed: {e}")))?;
            #[cfg(feature = "test-utils")]
            notify_webrtc_fragment_sent_for_test(WebRtcFragmentSendEvent {
                msg_id,
                frag_index: 0,
                total_frags: 1,
                fragment_payload_len: data_len,
                message_len: data_len,
            })
            .await;
            tracing::trace!(
                "sent single fragment: msg_id={} payload={} bytes",
                msg_id,
                data_len
            );
        } else {
            let total_frags = data_len.div_ceil(DC_MAX_PAYLOAD_SIZE);
            if total_frags > u16::MAX as usize {
                return Err(NetworkError::DataChannelError(format!(
                    "message too large: {data_len} bytes would require {total_frags} fragments (max {})",
                    u16::MAX
                )));
            }
            // Fits: checked against u16::MAX just above.
            let total_frags = total_frags as u16;
            tracing::debug!(
                "fragmenting message: msg_id={} total_bytes={} fragments={}",
                msg_id,
                data_len,
                total_frags
            );
            for (frag_index, chunk) in data.chunks(DC_MAX_PAYLOAD_SIZE).enumerate() {
                let mut buf = Vec::with_capacity(FRAGMENT_HEADER_SIZE + chunk.len());
                // frag_index < total_frags <= u16::MAX, so the cast cannot truncate.
                encode_fragment_header(&mut buf, msg_id, frag_index as u16, total_frags);
                buf.extend_from_slice(chunk);
                let frame = bytes::Bytes::from(buf);
                self.data_channel.send(&frame).await.map_err(|e| {
                    NetworkError::DataChannelError(format!(
                        "Send fragment {frag_index} failed: {e}"
                    ))
                })?;
                #[cfg(feature = "test-utils")]
                notify_webrtc_fragment_sent_for_test(WebRtcFragmentSendEvent {
                    msg_id,
                    frag_index: frag_index as u16,
                    total_frags,
                    fragment_payload_len: chunk.len(),
                    message_len: data_len,
                })
                .await;
                tracing::debug!(
                    "sent fragment {}/{}: msg_id={} chunk={} bytes",
                    frag_index + 1,
                    total_frags,
                    msg_id,
                    chunk.len()
                );
            }
        }
        Ok(())
    }

    /// Waits for the next complete message, reassembling fragmented messages.
    ///
    /// Frames with `frag_index >= total_frags` (which includes `total_frags == 0`) cannot
    /// come from `send`. They are logged and skipped rather than delivered or parked in
    /// the reassembly buffer.
    ///
    /// # Errors
    /// - `NetworkError::ChannelClosed` once the inbound channel is closed;
    /// - the `decode_fragment_header` error for a frame shorter than the header.
    async fn recv(&self) -> NetworkResult<bytes::Bytes> {
        loop {
            let raw = {
                let mut receiver = self.rx.lock().await;
                receiver.recv().await.ok_or_else(|| {
                    NetworkError::ChannelClosed("DataLane receiver closed".to_string())
                })?
            };
            let (msg_id, frag_index, total_frags, payload) = decode_fragment_header(raw)?;
            if frag_index >= total_frags {
                tracing::warn!(
                    "dropping malformed fragment: msg_id={} frag_index={} total_frags={}",
                    msg_id,
                    frag_index,
                    total_frags
                );
                continue;
            }
            if total_frags == 1 {
                return Ok(payload);
            }
            let mut buf = self.reassembly.lock().await;
            if let Some(complete) = buf.insert(msg_id, frag_index, total_frags, payload) {
                tracing::debug!(
                    "reassembled message: msg_id={} total_bytes={}",
                    msg_id,
                    complete.len()
                );
                return Ok(complete);
            }
        }
    }

    /// Non-blocking variant of `recv`.
    ///
    /// Drains whatever frames are already queued and returns the first completed message,
    /// or `Ok(None)` once the queue is empty. Malformed frames are skipped as in `recv`.
    ///
    /// # Errors
    /// - `NetworkError::ChannelClosed` if the inbound channel is disconnected;
    /// - the `decode_fragment_header` error for a frame shorter than the header.
    async fn try_recv(&self) -> NetworkResult<Option<bytes::Bytes>> {
        loop {
            let raw = {
                let mut receiver = self.rx.lock().await;
                match receiver.try_recv() {
                    Ok(data) => data,
                    Err(mpsc::error::TryRecvError::Empty) => return Ok(None),
                    Err(mpsc::error::TryRecvError::Disconnected) => {
                        return Err(NetworkError::ChannelClosed(
                            "Lane receiver closed".to_string(),
                        ));
                    }
                }
            };
            let (msg_id, frag_index, total_frags, payload) = decode_fragment_header(raw)?;
            if frag_index >= total_frags {
                tracing::warn!(
                    "dropping malformed fragment: msg_id={} frag_index={} total_frags={}",
                    msg_id,
                    frag_index,
                    total_frags
                );
                continue;
            }
            if total_frags == 1 {
                return Ok(Some(payload));
            }
            let mut buf = self.reassembly.lock().await;
            if let Some(complete) = buf.insert(msg_id, frag_index, total_frags, payload) {
                tracing::debug!(
                    "reassembled message (try_recv): msg_id={} total_bytes={}",
                    msg_id,
                    complete.len()
                );
                return Ok(Some(complete));
            }
        }
    }

    #[inline]
    fn lane_type(&self) -> &'static str {
        "WebRtcDataChannel"
    }

    /// Healthy unless the DataChannel is `Closing` or `Closed`.
    fn is_healthy(&self) -> bool {
        use webrtc::data_channel::data_channel_state::RTCDataChannelState;
        let state = self.data_channel.ready_state();
        !matches!(
            state,
            RTCDataChannelState::Closed | RTCDataChannelState::Closing
        )
    }
}
/// Lane over a shared WebSocket sink for one `PayloadType`.
///
/// Outgoing payloads are framed as `[type: u8][len: u32 BE][data]`. Inbound payloads
/// arrive through `rx`.
#[derive(Clone)]
pub(crate) struct WebSocketDataLane {
    pub(crate) sink: WsSink,
    payload_type: PayloadType,
    rx: Arc<Mutex<mpsc::Receiver<bytes::Bytes>>>,
}
/// Debug output shows only the payload type carried by this lane.
impl std::fmt::Debug for WebSocketDataLane {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let kind = &self.payload_type;
        f.write_fmt(format_args!("WebSocketDataLane(type={kind:?})"))
    }
}
impl WebSocketDataLane {
    /// Creates a lane that writes through `sink` and reads inbound payloads from `rx`.
    #[inline]
    pub(crate) fn new(
        sink: WsSink,
        payload_type: PayloadType,
        rx: mpsc::Receiver<bytes::Bytes>,
    ) -> Self {
        let rx = Arc::new(Mutex::new(rx));
        Self {
            sink,
            payload_type,
            rx,
        }
    }
}
#[async_trait]
impl DataLane for WebSocketDataLane {
    /// Sends `data` as one binary WebSocket message framed as
    /// `[payload_type: u8][len: u32 big-endian][data]`.
    ///
    /// # Errors
    /// - `NetworkError::SendError` if `data` is too long for the u32 length prefix, or if
    ///   the WebSocket write fails;
    /// - `NetworkError::ConnectionError` if no sink is currently available.
    async fn send(&self, data: bytes::Bytes) -> NetworkResult<()> {
        // The length prefix is a u32; refuse oversized payloads instead of letting an
        // `as` cast silently truncate the length and corrupt the frame.
        let len = u32::try_from(data.len()).map_err(|_| {
            NetworkError::SendError(format!(
                "WebSocket payload too large: {} bytes (max {})",
                data.len(),
                u32::MAX
            ))
        })?;
        let mut buf = Vec::with_capacity(5 + data.len());
        buf.push(self.payload_type as u8);
        buf.extend_from_slice(&len.to_be_bytes());
        buf.extend_from_slice(&data);
        let mut sink_opt = self.sink.lock().await;
        if let Some(s) = sink_opt.as_mut() {
            s.send(WsMessage::Binary(buf.into()))
                .await
                .map_err(|e| NetworkError::SendError(format!("WebSocket send failed: {e}")))?;
            tracing::trace!(
                "WebSocket sent {} bytes (type={:?})",
                data.len(),
                self.payload_type
            );
            Ok(())
        } else {
            Err(NetworkError::ConnectionError(
                "WebSocket not connected".to_string(),
            ))
        }
    }

    /// Waits for the next inbound payload.
    ///
    /// # Errors
    /// `NetworkError::ChannelClosed` once the inbound channel is closed and drained.
    async fn recv(&self) -> NetworkResult<bytes::Bytes> {
        let mut receiver = self.rx.lock().await;
        receiver
            .recv()
            .await
            .ok_or_else(|| NetworkError::ChannelClosed("DataLane receiver closed".to_string()))
    }

    /// Returns the next inbound payload if one is queued, `Ok(None)` otherwise.
    ///
    /// # Errors
    /// `NetworkError::ChannelClosed` if the inbound channel is disconnected.
    async fn try_recv(&self) -> NetworkResult<Option<bytes::Bytes>> {
        let mut receiver = self.rx.lock().await;
        match receiver.try_recv() {
            Ok(data) => Ok(Some(data)),
            Err(mpsc::error::TryRecvError::Empty) => Ok(None),
            Err(mpsc::error::TryRecvError::Disconnected) => Err(NetworkError::ChannelClosed(
                "Lane receiver closed".to_string(),
            )),
        }
    }

    #[inline]
    fn lane_type(&self) -> &'static str {
        "WebSocket"
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;

    /// Builds the `RpcEnvelope` used by the MpscLane tests; only `request_id` and `payload`
    /// vary between them.
    fn test_envelope(request_id: &str, payload: &'static [u8]) -> actr_protocol::RpcEnvelope {
        actr_protocol::RpcEnvelope {
            request_id: request_id.to_string(),
            route_key: "test.route".to_string(),
            payload: Some(Bytes::from_static(payload)),
            traceparent: None,
            tracestate: None,
            metadata: vec![],
            timeout_ms: 30000,
            error: None,
        }
    }

    /// Encodes a complete frame (header + payload) as `WebRtcDataLane::send` would.
    fn make_frame(msg_id: u32, frag_index: u16, total_frags: u16, payload: &[u8]) -> Bytes {
        let mut buf = Vec::with_capacity(FRAGMENT_HEADER_SIZE + payload.len());
        encode_fragment_header(&mut buf, msg_id, frag_index, total_frags);
        buf.extend_from_slice(payload);
        Bytes::from(buf)
    }

    /// Mirrors the receive path of `WebRtcDataLane::recv` for a single frame.
    fn recv_one(raw: Bytes, reassembly: &mut ReassemblyBuffer) -> Option<Bytes> {
        let (msg_id, frag_index, total_frags, payload) = decode_fragment_header(raw).unwrap();
        if total_frags == 1 {
            return Some(payload);
        }
        reassembly.insert(msg_id, frag_index, total_frags, payload)
    }

    #[test]
    fn test_encode_decode_fragment_header_single() {
        let mut buf = Vec::new();
        encode_fragment_header(&mut buf, 42, 0, 1);
        assert_eq!(buf.len(), FRAGMENT_HEADER_SIZE);
        let payload = Bytes::from(b"hello".as_slice().to_vec());
        let mut raw = buf;
        raw.extend_from_slice(&payload);
        let (msg_id, frag_index, total_frags, decoded_payload) =
            decode_fragment_header(Bytes::from(raw)).unwrap();
        assert_eq!(msg_id, 42);
        assert_eq!(frag_index, 0);
        assert_eq!(total_frags, 1);
        assert_eq!(decoded_payload, payload);
    }

    #[test]
    fn test_encode_decode_fragment_header_multi() {
        let mut buf = Vec::new();
        encode_fragment_header(&mut buf, 0xDEAD_BEEF, 3, 7);
        let (msg_id, frag_index, total_frags, _) =
            decode_fragment_header(Bytes::from(buf)).unwrap();
        assert_eq!(msg_id, 0xDEAD_BEEF);
        assert_eq!(frag_index, 3);
        assert_eq!(total_frags, 7);
    }

    #[test]
    fn test_decode_too_short_returns_error() {
        let short = Bytes::from_static(b"short");
        assert!(decode_fragment_header(short).is_err());
    }

    #[test]
    fn test_reassembly_single_fragment() {
        let mut buf = ReassemblyBuffer::new();
        let payload = Bytes::from_static(b"single");
        let result = buf.insert(1, 0, 1, payload.clone());
        assert_eq!(result, Some(payload));
    }

    #[test]
    fn test_reassembly_two_fragments_in_order() {
        let mut buf = ReassemblyBuffer::new();
        let part0 = Bytes::from_static(b"hello ");
        let part1 = Bytes::from_static(b"world");
        assert!(buf.insert(5, 0, 2, part0).is_none());
        let result = buf.insert(5, 1, 2, part1).unwrap();
        assert_eq!(result, Bytes::from_static(b"hello world"));
    }

    #[test]
    fn test_reassembly_two_fragments_out_of_order() {
        let mut buf = ReassemblyBuffer::new();
        let part0 = Bytes::from_static(b"hello ");
        let part1 = Bytes::from_static(b"world");
        assert!(buf.insert(7, 1, 2, part1).is_none());
        let result = buf.insert(7, 0, 2, part0).unwrap();
        assert_eq!(result, Bytes::from_static(b"hello world"));
    }

    #[test]
    fn test_reassembly_multiple_messages_interleaved() {
        let mut buf = ReassemblyBuffer::new();
        assert!(buf.insert(1, 0, 2, Bytes::from_static(b"A1")).is_none());
        assert!(buf.insert(2, 0, 2, Bytes::from_static(b"B1")).is_none());
        let msg1 = buf.insert(1, 1, 2, Bytes::from_static(b"A2")).unwrap();
        assert_eq!(msg1, Bytes::from_static(b"A1A2"));
        let msg2 = buf.insert(2, 1, 2, Bytes::from_static(b"B2")).unwrap();
        assert_eq!(msg2, Bytes::from_static(b"B1B2"));
    }

    #[test]
    fn test_fragment_count_small_message() {
        let size = DC_MAX_PAYLOAD_SIZE;
        let count = size.div_ceil(DC_MAX_PAYLOAD_SIZE);
        assert_eq!(
            count, 1,
            "message equal to payload size should be 1 fragment"
        );
    }

    #[test]
    fn test_fragment_count_one_byte_over() {
        let size = DC_MAX_PAYLOAD_SIZE + 1;
        let count = size.div_ceil(DC_MAX_PAYLOAD_SIZE);
        assert_eq!(count, 2, "one byte over should require 2 fragments");
    }

    #[test]
    fn test_fragment_count_200kb() {
        let size: usize = 200 * 1024;
        let count = size.div_ceil(DC_MAX_PAYLOAD_SIZE);
        assert_eq!(count, 4);
    }

    #[test]
    fn test_roundtrip_small_message() {
        let data = b"small message";
        let frame = make_frame(0, 0, 1, data);
        let mut buf = ReassemblyBuffer::new();
        let result = recv_one(frame, &mut buf).unwrap();
        assert_eq!(result.as_ref(), data);
    }

    #[test]
    fn test_roundtrip_exactly_max_payload() {
        let data = vec![0xABu8; DC_MAX_PAYLOAD_SIZE];
        let frame = make_frame(1, 0, 1, &data);
        let mut buf = ReassemblyBuffer::new();
        let result = recv_one(frame, &mut buf).unwrap();
        assert_eq!(result.as_ref(), data.as_slice());
    }

    #[test]
    fn test_roundtrip_one_byte_over_max_payload() {
        let data = vec![0xCDu8; DC_MAX_PAYLOAD_SIZE + 1];
        let (part0, part1) = data.split_at(DC_MAX_PAYLOAD_SIZE);
        let frame0 = make_frame(2, 0, 2, part0);
        let frame1 = make_frame(2, 1, 2, part1);
        let mut buf = ReassemblyBuffer::new();
        assert!(recv_one(frame0, &mut buf).is_none());
        let result = recv_one(frame1, &mut buf).unwrap();
        assert_eq!(result.as_ref(), data.as_slice());
    }

    #[test]
    fn test_roundtrip_200kb_message() {
        let data: Vec<u8> = (0u8..=255).cycle().take(200 * 1024).collect();
        let total_frags = data.len().div_ceil(DC_MAX_PAYLOAD_SIZE) as u16;
        let mut buf = ReassemblyBuffer::new();
        let mut result = None;
        for (i, chunk) in data.chunks(DC_MAX_PAYLOAD_SIZE).enumerate() {
            let frame = make_frame(99, i as u16, total_frags, chunk);
            result = recv_one(frame, &mut buf);
            // Only the final fragment may complete the message.
            if i + 1 < usize::from(total_frags) {
                assert!(result.is_none(), "fragment {i} completed the message early");
            }
        }
        let result = result.unwrap();
        assert_eq!(result.as_ref(), data.as_slice());
    }

    #[tokio::test]
    async fn test_mpsc_lane() {
        let (tx, rx) = mpsc::channel(10);
        let lane = MpscLane::new(PayloadType::RpcReliable, tx.clone(), rx);
        lane.send_envelope(test_envelope("test-1", b"hello"))
            .await
            .unwrap();
        let received = lane.recv_envelope().await.unwrap();
        assert_eq!(received.request_id, "test-1");
        assert_eq!(received.payload, Some(Bytes::from_static(b"hello")));
    }

    #[tokio::test]
    async fn test_mpsc_lane_clone() {
        let (tx, rx) = mpsc::channel(10);
        let lane = MpscLane::new(PayloadType::RpcReliable, tx.clone(), rx);
        let lane2 = lane.clone();
        lane.send_envelope(test_envelope("test-2", b"test"))
            .await
            .unwrap();
        let received = lane2.recv_envelope().await.unwrap();
        assert_eq!(received.request_id, "test-2");
        assert_eq!(received.payload, Some(Bytes::from_static(b"test")));
    }

    #[tokio::test]
    async fn test_mpsc_lane_with_shared_rx() {
        let (tx, rx) = mpsc::channel(10);
        let rx_shared = Arc::new(Mutex::new(rx));
        let lane = MpscLane::new_shared(PayloadType::RpcReliable, tx.clone(), rx_shared.clone());
        lane.send_envelope(test_envelope("test-3", b"shared"))
            .await
            .unwrap();
        let received = lane.recv_envelope().await.unwrap();
        assert_eq!(received.request_id, "test-3");
        assert_eq!(received.payload, Some(Bytes::from_static(b"shared")));
    }

    #[test]
    fn test_lane_type_name() {
        let (tx, rx) = mpsc::channel(10);
        let lane = MpscLane::new(PayloadType::RpcReliable, tx, rx);
        assert_eq!(lane.lane_type(), "Mpsc");
    }
}