use super::{DataLane, MpscLane, NetworkError, NetworkResult};
use actr_framework::Bytes;
use actr_protocol::{ActrError, PayloadType, RpcEnvelope};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{Mutex, RwLock, mpsc, oneshot};
/// In-process transport that carries [`RpcEnvelope`]s over tokio mpsc channels.
///
/// Channels are grouped by payload kind; `get_lane` wraps them in cached [`DataLane`]s, and
/// `send_request` tracks in-flight requests until a response or error completes them.
pub struct HostTransport {
/// Sender half of the reliable RPC channel (created eagerly in `new`).
reliable_tx: mpsc::Sender<RpcEnvelope>,
/// Receiver half of the reliable RPC channel, shared with every lane built over it.
reliable_rx: Arc<Mutex<mpsc::Receiver<RpcEnvelope>>>,
/// Signal RPC channel; `None` until `ensure_signal_channel` first creates it.
signal_channel: Arc<Mutex<Option<ChannelPair>>>,
/// Stream channels keyed by channel id (serves both `StreamReliable` and `StreamLatencyFirst`).
latency_first_channels: Arc<RwLock<HashMap<String, ChannelPair>>>,
/// `MediaRtp` channels keyed by track id.
media_track_channels: Arc<RwLock<HashMap<String, ChannelPair>>>,
/// Lanes already built by `get_lane`, reused on later calls with the same key.
lane_cache: Arc<RwLock<HashMap<LaneKey, Arc<dyn DataLane>>>>,
/// Response waiters for in-flight `send_request` calls, keyed by request id.
pending_requests:
Arc<RwLock<HashMap<String, oneshot::Sender<actr_protocol::ActorResult<Bytes>>>>>,
}
/// Both ends of one in-process mpsc channel carrying [`RpcEnvelope`]s.
///
/// The receiver sits behind `Arc<Mutex<_>>`, so cloning the pair shares a single consumer
/// rather than creating a new one.
#[derive(Clone)]
struct ChannelPair {
/// Producer end; cloned into every lane built over this channel.
tx: mpsc::Sender<RpcEnvelope>,
/// Shared consumer end.
rx: Arc<Mutex<mpsc::Receiver<RpcEnvelope>>>,
}
/// Key under which `HostTransport::get_lane` caches the lanes it builds.
///
/// Field order matters: it determines the derived `Debug` output (which is logged) and the
/// derived `Hash`.
#[derive(Hash, Eq, PartialEq, Clone, Debug)]
struct LaneKey {
/// Payload kind the lane carries.
payload_type: PayloadType,
/// Optional channel / track identifier as used by `get_lane`.
identifier: Option<String>,
}
impl Default for HostTransport {
fn default() -> Self {
Self::new()
}
}
impl HostTransport {
pub fn new() -> Self {
let (reliable_tx, reliable_rx) = mpsc::channel(1024);
tracing::debug!("Created HostTransport");
tracing::debug!("Created Reliable channel");
Self {
reliable_tx,
reliable_rx: Arc::new(Mutex::new(reliable_rx)),
signal_channel: Arc::new(Mutex::new(None)),
latency_first_channels: Arc::new(RwLock::new(HashMap::new())),
media_track_channels: Arc::new(RwLock::new(HashMap::new())),
lane_cache: Arc::new(RwLock::new(HashMap::new())),
pending_requests: Arc::new(RwLock::new(HashMap::new())),
}
}
async fn ensure_signal_channel(&self) -> ChannelPair {
let mut opt = self.signal_channel.lock().await;
if opt.is_none() {
let (tx, rx) = mpsc::channel(1024);
*opt = Some(ChannelPair {
tx,
rx: Arc::new(Mutex::new(rx)),
});
tracing::debug!("Created Signal channel");
}
opt.as_ref()
.expect("Signal channel must exist after ensure_signal_channel")
.clone()
}
#[cfg(feature = "test-utils")]
pub async fn create_latency_first_channel(
&self,
channel_id: String,
) -> Arc<Mutex<mpsc::Receiver<RpcEnvelope>>> {
let mut channels = self.latency_first_channels.write().await;
if !channels.contains_key(&channel_id) {
let (tx, rx) = mpsc::channel(1024);
let pair = ChannelPair {
tx,
rx: Arc::new(Mutex::new(rx)),
};
let rx_clone = pair.rx.clone();
channels.insert(channel_id.clone(), pair);
tracing::debug!("Created LatencyFirst channel '{}'", channel_id);
rx_clone
} else {
channels
.get(&channel_id)
.expect("LatencyFirst channel must exist after contains_key check")
.rx
.clone()
}
}
#[cfg(feature = "test-utils")]
pub async fn create_media_track_channel(
&self,
track_id: String,
) -> Arc<Mutex<mpsc::Receiver<RpcEnvelope>>> {
let mut channels = self.media_track_channels.write().await;
if !channels.contains_key(&track_id) {
let (tx, rx) = mpsc::channel(1024);
let pair = ChannelPair {
tx,
rx: Arc::new(Mutex::new(rx)),
};
let rx_clone = pair.rx.clone();
channels.insert(track_id.clone(), pair);
tracing::debug!("Created MediaTrack channel '{}'", track_id);
rx_clone
} else {
channels
.get(&track_id)
.expect("MediaTrack channel must exist after contains_key check")
.rx
.clone()
}
}
pub async fn get_lane(
&self,
payload_type: PayloadType,
identifier: Option<String>,
) -> NetworkResult<Arc<dyn DataLane>> {
let key = LaneKey {
payload_type,
identifier: identifier.clone(),
};
{
let cache = self.lane_cache.read().await;
if let Some(lane) = cache.get(&key) {
tracing::debug!("Reusing cached Inproc DataLane: {:?}", key);
return Ok(lane.clone());
}
}
let pair = match payload_type {
PayloadType::RpcReliable => ChannelPair {
tx: self.reliable_tx.clone(),
rx: self.reliable_rx.clone(),
},
PayloadType::RpcSignal => self.ensure_signal_channel().await,
PayloadType::StreamReliable | PayloadType::StreamLatencyFirst => {
let channel_id = identifier
.as_ref()
.ok_or_else(|| {
NetworkError::InvalidArgument("DataStream requires channel_id".into())
})?
.clone();
let channels = self.latency_first_channels.read().await;
channels
.get(&channel_id)
.ok_or_else(|| NetworkError::ChannelNotFound(channel_id))?
.clone()
}
PayloadType::MediaRtp => {
let track_id = identifier
.as_ref()
.ok_or_else(|| {
NetworkError::InvalidArgument("MediaRtp requires track_id".into())
})?
.clone();
let channels = self.media_track_channels.read().await;
channels
.get(&track_id)
.ok_or_else(|| NetworkError::ChannelNotFound(track_id))?
.clone()
}
};
let lane: Arc<dyn DataLane> =
Arc::new(MpscLane::new_shared(payload_type, pair.tx, pair.rx));
self.lane_cache.write().await.insert(key, lane.clone());
tracing::debug!(
"Created Inproc DataLane: type={:?}, identifier={:?}",
payload_type,
identifier
);
Ok(lane)
}
#[cfg_attr(
feature = "opentelemetry",
tracing::instrument(skip_all, name = "HostTransport.send_request")
)]
pub async fn send_request(
&self,
payload_type: PayloadType,
identifier: Option<String>,
envelope: RpcEnvelope,
) -> NetworkResult<Bytes> {
let (response_tx, response_rx) = oneshot::channel();
let request_id = envelope.request_id.clone();
let timeout_ms = envelope.timeout_ms;
self.pending_requests
.write()
.await
.insert(request_id, response_tx);
let lane = self.get_lane(payload_type, identifier).await?;
lane.send_envelope(envelope).await?;
let timeout_duration = Duration::from_millis(timeout_ms as u64);
let result = tokio::time::timeout(timeout_duration, response_rx)
.await
.map_err(|_| NetworkError::TimeoutError(format!("Request timeout: {}ms", timeout_ms)))?
.map_err(|_| NetworkError::ConnectionError("Response channel closed".into()))?;
result.map_err(|e| NetworkError::Other(anyhow::anyhow!("{e}")))
}
#[cfg_attr(
feature = "opentelemetry",
tracing::instrument(skip_all, name = "HostTransport.send_message")
)]
pub async fn send_message(
&self,
payload_type: PayloadType,
identifier: Option<String>,
envelope: RpcEnvelope,
) -> NetworkResult<()> {
let lane = self.get_lane(payload_type, identifier).await?;
lane.send_envelope(envelope).await
}
#[cfg(feature = "test-utils")]
pub async fn recv(&self) -> Option<RpcEnvelope> {
loop {
tokio::select! {
biased;
msg = Self::recv_from_channel_opt(&self.signal_channel) => {
if let Some(envelope) = msg {
if !self.try_complete_response(&envelope).await {
return Some(envelope); }
}
}
msg = Self::recv_from_channel(&self.reliable_rx) => {
if let Some(envelope) = msg {
if !self.try_complete_response(&envelope).await {
return Some(envelope);
}
}
}
}
}
}
pub async fn complete_response(
&self,
request_id: &str,
response_bytes: Bytes,
) -> NetworkResult<()> {
let mut pending = self.pending_requests.write().await;
if let Some(tx) = pending.remove(request_id) {
let _ = tx.send(Ok(response_bytes));
tracing::debug!("Completed pending request: {}", request_id);
Ok(())
} else {
Err(NetworkError::InvalidArgument(format!(
"No pending request found for id: {request_id}"
)))
}
}
pub async fn complete_error(&self, request_id: &str, error: ActrError) -> NetworkResult<()> {
let mut pending = self.pending_requests.write().await;
if let Some(tx) = pending.remove(request_id) {
let _ = tx.send(Err(error));
tracing::debug!("Completed pending request with error: {}", request_id);
Ok(())
} else {
Err(NetworkError::InvalidArgument(format!(
"No pending request found for id: {request_id}"
)))
}
}
#[cfg(feature = "test-utils")]
async fn try_complete_response(&self, envelope: &RpcEnvelope) -> bool {
let mut pending = self.pending_requests.write().await;
if let Some(tx) = pending.remove(&envelope.request_id) {
match (&envelope.payload, &envelope.error) {
(Some(payload), None) => {
let _ = tx.send(Ok(payload.clone()));
tracing::debug!("Completed pending request: {}", envelope.request_id);
}
(None, Some(error)) => {
let protocol_err = ActrError::Unavailable(format!(
"RPC error {}: {}",
error.code, error.message
));
let _ = tx.send(Err(protocol_err));
tracing::debug!(
"Completed pending request with error: {}",
envelope.request_id
);
}
_ => {
tracing::error!(
"Invalid RpcEnvelope: both payload and error present or both absent"
);
let _ = tx.send(Err(ActrError::DecodeFailure(
"Invalid RpcEnvelope: payload and error fields inconsistent".to_string(),
)));
}
}
true
} else {
false
}
}
#[cfg(feature = "test-utils")]
async fn recv_from_channel(
rx: &Arc<Mutex<mpsc::Receiver<RpcEnvelope>>>,
) -> Option<RpcEnvelope> {
rx.lock().await.recv().await
}
#[cfg(feature = "test-utils")]
async fn recv_from_channel_opt(opt: &Arc<Mutex<Option<ChannelPair>>>) -> Option<RpcEnvelope> {
let rx = {
let guard = opt.lock().await;
guard.as_ref().map(|pair| pair.rx.clone())
};
if let Some(rx) = rx {
rx.lock().await.recv().await
} else {
std::future::pending().await }
}
}