1use super::{DataLane, NetworkError, NetworkResult};
56use actr_framework::Bytes;
57use actr_protocol::{PayloadType, RpcEnvelope};
58use std::collections::HashMap;
59use std::sync::Arc;
60use std::time::Duration;
61use tokio::sync::{Mutex, RwLock, mpsc, oneshot};
62
63pub struct InprocTransportManager {
71 reliable_tx: mpsc::Sender<RpcEnvelope>,
74 reliable_rx: Arc<Mutex<mpsc::Receiver<RpcEnvelope>>>,
75
76 signal_channel: Arc<Mutex<Option<ChannelPair>>>,
79
80 latency_first_channels: Arc<RwLock<HashMap<String, ChannelPair>>>,
82
83 media_track_channels: Arc<RwLock<HashMap<String, ChannelPair>>>,
85
86 lane_cache: Arc<RwLock<HashMap<LaneKey, DataLane>>>,
89
90 pending_requests:
93 Arc<RwLock<HashMap<String, oneshot::Sender<actr_protocol::ActorResult<Bytes>>>>>,
94}
95
96#[derive(Clone)]
98struct ChannelPair {
99 tx: mpsc::Sender<RpcEnvelope>,
100 rx: Arc<Mutex<mpsc::Receiver<RpcEnvelope>>>,
101}
102
103#[derive(Hash, Eq, PartialEq, Clone, Debug)]
105struct LaneKey {
106 payload_type: PayloadType,
107 identifier: Option<String>,
109}
110
111impl Default for InprocTransportManager {
112 fn default() -> Self {
113 Self::new()
114 }
115}
116
117impl InprocTransportManager {
118 pub fn new() -> Self {
123 let (reliable_tx, reliable_rx) = mpsc::channel(1024);
124
125 tracing::debug!("Created InprocTransportManager");
126 tracing::debug!("✨ Created Reliable channel");
127
128 Self {
129 reliable_tx,
130 reliable_rx: Arc::new(Mutex::new(reliable_rx)),
131 signal_channel: Arc::new(Mutex::new(None)),
132 latency_first_channels: Arc::new(RwLock::new(HashMap::new())),
133 media_track_channels: Arc::new(RwLock::new(HashMap::new())),
134 lane_cache: Arc::new(RwLock::new(HashMap::new())),
135 pending_requests: Arc::new(RwLock::new(HashMap::new())),
136 }
137 }
138
139 async fn ensure_signal_channel(&self) -> ChannelPair {
143 let mut opt = self.signal_channel.lock().await;
144 if opt.is_none() {
145 let (tx, rx) = mpsc::channel(1024);
146 *opt = Some(ChannelPair {
147 tx,
148 rx: Arc::new(Mutex::new(rx)),
149 });
150 tracing::debug!("✨ Created Signal channel");
151 }
152 opt.as_ref()
154 .expect("Signal channel must exist after ensure_signal_channel")
155 .clone()
156 }
157
158 pub async fn create_latency_first_channel(
160 &self,
161 channel_id: String,
162 ) -> Arc<Mutex<mpsc::Receiver<RpcEnvelope>>> {
163 let mut channels = self.latency_first_channels.write().await;
164
165 if !channels.contains_key(&channel_id) {
166 let (tx, rx) = mpsc::channel(1024);
167 let pair = ChannelPair {
168 tx,
169 rx: Arc::new(Mutex::new(rx)),
170 };
171 let rx_clone = pair.rx.clone();
172 channels.insert(channel_id.clone(), pair);
173
174 tracing::debug!("✨ Created LatencyFirst channel '{}'", channel_id);
175 rx_clone
176 } else {
177 channels
179 .get(&channel_id)
180 .expect("LatencyFirst channel must exist after contains_key check")
181 .rx
182 .clone()
183 }
184 }
185
186 pub async fn create_media_track_channel(
188 &self,
189 track_id: String,
190 ) -> Arc<Mutex<mpsc::Receiver<RpcEnvelope>>> {
191 let mut channels = self.media_track_channels.write().await;
192
193 if !channels.contains_key(&track_id) {
194 let (tx, rx) = mpsc::channel(1024);
195 let pair = ChannelPair {
196 tx,
197 rx: Arc::new(Mutex::new(rx)),
198 };
199 let rx_clone = pair.rx.clone();
200 channels.insert(track_id.clone(), pair);
201
202 tracing::debug!("✨ Created MediaTrack channel '{}'", track_id);
203 rx_clone
204 } else {
205 channels
207 .get(&track_id)
208 .expect("MediaTrack channel must exist after contains_key check")
209 .rx
210 .clone()
211 }
212 }
213
214 pub async fn get_lane(
225 &self,
226 payload_type: PayloadType,
227 identifier: Option<String>,
228 ) -> NetworkResult<DataLane> {
229 let key = LaneKey {
230 payload_type,
231 identifier: identifier.clone(),
232 };
233
234 {
236 let cache = self.lane_cache.read().await;
237 if let Some(lane) = cache.get(&key) {
238 tracing::debug!("📦 Reusing cached Inproc DataLane: {:?}", key);
239 return Ok(lane.clone());
240 }
241 }
242
243 let pair = match payload_type {
245 PayloadType::RpcReliable => ChannelPair {
246 tx: self.reliable_tx.clone(),
247 rx: self.reliable_rx.clone(),
248 },
249
250 PayloadType::RpcSignal => self.ensure_signal_channel().await,
251
252 PayloadType::StreamReliable | PayloadType::StreamLatencyFirst => {
253 let channel_id = identifier
254 .as_ref()
255 .ok_or_else(|| {
256 NetworkError::InvalidArgument("DataStream requires channel_id".into())
257 })?
258 .clone();
259
260 let channels = self.latency_first_channels.read().await;
261 channels
262 .get(&channel_id)
263 .ok_or_else(|| NetworkError::ChannelNotFound(channel_id))?
264 .clone()
265 }
266
267 PayloadType::MediaRtp => {
268 let track_id = identifier
269 .as_ref()
270 .ok_or_else(|| {
271 NetworkError::InvalidArgument("MediaRtp requires track_id".into())
272 })?
273 .clone();
274
275 let channels = self.media_track_channels.read().await;
276 channels
277 .get(&track_id)
278 .ok_or_else(|| NetworkError::ChannelNotFound(track_id))?
279 .clone()
280 }
281 };
282
283 let lane = DataLane::mpsc_shared(payload_type, pair.tx, pair.rx);
285
286 self.lane_cache.write().await.insert(key, lane.clone());
288
289 tracing::debug!(
290 "✨ Created Inproc DataLane: type={:?}, identifier={:?}",
291 payload_type,
292 identifier
293 );
294
295 Ok(lane)
296 }
297
298 #[cfg_attr(
302 feature = "opentelemetry",
303 tracing::instrument(skip_all, name = "InprocTransportManager.send_request")
304 )]
305 pub async fn send_request(
306 &self,
307 payload_type: PayloadType,
308 identifier: Option<String>,
309 envelope: RpcEnvelope,
310 ) -> NetworkResult<Bytes> {
311 let (response_tx, response_rx) = oneshot::channel();
312
313 self.pending_requests
315 .write()
316 .await
317 .insert(envelope.request_id.clone(), response_tx);
318
319 let lane = self.get_lane(payload_type, identifier).await?;
321 lane.send_envelope(envelope).await?;
322
323 let result = tokio::time::timeout(Duration::from_secs(30), response_rx)
325 .await
326 .map_err(|_| NetworkError::TimeoutError("Request timeout".into()))?
327 .map_err(|_| NetworkError::ConnectionError("Response channel closed".into()))?;
328
329 result.map_err(|e| NetworkError::ProtocolError(e.to_string()))
331 }
332
333 #[cfg_attr(
335 feature = "opentelemetry",
336 tracing::instrument(skip_all, name = "InprocTransportManager.send_message")
337 )]
338 pub async fn send_message(
339 &self,
340 payload_type: PayloadType,
341 identifier: Option<String>,
342 envelope: RpcEnvelope,
343 ) -> NetworkResult<()> {
344 let lane = self.get_lane(payload_type, identifier).await?;
345 lane.send_envelope(envelope).await
346 }
347
348 pub async fn recv(&self) -> Option<RpcEnvelope> {
354 loop {
355 tokio::select! {
356 biased;
357
358 msg = Self::recv_from_channel_opt(&self.signal_channel) => {
360 if let Some(envelope) = msg {
361 if !self.try_complete_response(&envelope).await {
362 return Some(envelope); }
364 }
366 }
367
368 msg = Self::recv_from_channel(&self.reliable_rx) => {
370 if let Some(envelope) = msg {
371 if !self.try_complete_response(&envelope).await {
372 return Some(envelope);
373 }
374 }
375 }
376
377 }
380 }
381 }
382
383 pub async fn complete_response(
393 &self,
394 request_id: &str,
395 response_bytes: Bytes,
396 ) -> NetworkResult<()> {
397 let mut pending = self.pending_requests.write().await;
398 if let Some(tx) = pending.remove(request_id) {
399 let _ = tx.send(Ok(response_bytes));
400 tracing::debug!("✅ Completed pending request: {}", request_id);
401 Ok(())
402 } else {
403 Err(NetworkError::InvalidArgument(format!(
404 "No pending request found for id: {request_id}"
405 )))
406 }
407 }
408
409 pub async fn complete_error(
415 &self,
416 request_id: &str,
417 error: actr_protocol::ProtocolError,
418 ) -> NetworkResult<()> {
419 let mut pending = self.pending_requests.write().await;
420 if let Some(tx) = pending.remove(request_id) {
421 let _ = tx.send(Err(error));
422 tracing::debug!("✅ Completed pending request with error: {}", request_id);
423 Ok(())
424 } else {
425 Err(NetworkError::InvalidArgument(format!(
426 "No pending request found for id: {request_id}"
427 )))
428 }
429 }
430
431 async fn try_complete_response(&self, envelope: &RpcEnvelope) -> bool {
433 let mut pending = self.pending_requests.write().await;
434 if let Some(tx) = pending.remove(&envelope.request_id) {
435 match (&envelope.payload, &envelope.error) {
437 (Some(payload), None) => {
438 let _ = tx.send(Ok(payload.clone()));
439 tracing::debug!("✅ Completed pending request: {}", envelope.request_id);
440 }
441 (None, Some(error)) => {
442 let protocol_err = actr_protocol::ProtocolError::TransportError(format!(
443 "RPC error {}: {}",
444 error.code, error.message
445 ));
446 let _ = tx.send(Err(protocol_err));
447 tracing::debug!(
448 "✅ Completed pending request with error: {}",
449 envelope.request_id
450 );
451 }
452 _ => {
453 tracing::error!(
454 "❌ Invalid RpcEnvelope: both payload and error present or both absent"
455 );
456 let _ = tx.send(Err(actr_protocol::ProtocolError::DecodeError(
457 "Invalid RpcEnvelope: payload and error fields inconsistent".to_string(),
458 )));
459 }
460 }
461 true
462 } else {
463 false
464 }
465 }
466
467 async fn recv_from_channel(
470 rx: &Arc<Mutex<mpsc::Receiver<RpcEnvelope>>>,
471 ) -> Option<RpcEnvelope> {
472 rx.lock().await.recv().await
473 }
474
475 async fn recv_from_channel_opt(opt: &Arc<Mutex<Option<ChannelPair>>>) -> Option<RpcEnvelope> {
476 let rx = {
477 let guard = opt.lock().await;
478 guard.as_ref().map(|pair| pair.rx.clone())
479 };
480
481 if let Some(rx) = rx {
482 rx.lock().await.recv().await
483 } else {
484 std::future::pending().await }
486 }
487}