1use super::{DataLane, NetworkError, NetworkResult};
56use actr_framework::Bytes;
57use actr_protocol::{PayloadType, RpcEnvelope};
58use std::collections::HashMap;
59use std::sync::Arc;
60use std::time::Duration;
61use tokio::sync::{Mutex, RwLock, mpsc, oneshot};
62
63pub struct InprocTransportManager {
71 reliable_tx: mpsc::Sender<RpcEnvelope>,
74 reliable_rx: Arc<Mutex<mpsc::Receiver<RpcEnvelope>>>,
75
76 signal_channel: Arc<Mutex<Option<ChannelPair>>>,
79
80 latency_first_channels: Arc<RwLock<HashMap<String, ChannelPair>>>,
82
83 media_track_channels: Arc<RwLock<HashMap<String, ChannelPair>>>,
85
86 lane_cache: Arc<RwLock<HashMap<LaneKey, DataLane>>>,
89
90 pending_requests:
93 Arc<RwLock<HashMap<String, oneshot::Sender<actr_protocol::ActorResult<Bytes>>>>>,
94}
95
96#[derive(Clone)]
98struct ChannelPair {
99 tx: mpsc::Sender<RpcEnvelope>,
100 rx: Arc<Mutex<mpsc::Receiver<RpcEnvelope>>>,
101}
102
103#[derive(Hash, Eq, PartialEq, Clone, Debug)]
105struct LaneKey {
106 payload_type: PayloadType,
107 identifier: Option<String>,
109}
110
111impl Default for InprocTransportManager {
112 fn default() -> Self {
113 Self::new()
114 }
115}
116
117impl InprocTransportManager {
118 pub fn new() -> Self {
123 let (reliable_tx, reliable_rx) = mpsc::channel(1024);
124
125 tracing::debug!("Created InprocTransportManager");
126 tracing::debug!("✨ Created Reliable channel");
127
128 Self {
129 reliable_tx,
130 reliable_rx: Arc::new(Mutex::new(reliable_rx)),
131 signal_channel: Arc::new(Mutex::new(None)),
132 latency_first_channels: Arc::new(RwLock::new(HashMap::new())),
133 media_track_channels: Arc::new(RwLock::new(HashMap::new())),
134 lane_cache: Arc::new(RwLock::new(HashMap::new())),
135 pending_requests: Arc::new(RwLock::new(HashMap::new())),
136 }
137 }
138
139 async fn ensure_signal_channel(&self) -> ChannelPair {
143 let mut opt = self.signal_channel.lock().await;
144 if opt.is_none() {
145 let (tx, rx) = mpsc::channel(1024);
146 *opt = Some(ChannelPair {
147 tx,
148 rx: Arc::new(Mutex::new(rx)),
149 });
150 tracing::debug!("✨ Created Signal channel");
151 }
152 opt.as_ref()
154 .expect("Signal channel must exist after ensure_signal_channel")
155 .clone()
156 }
157
158 pub async fn create_latency_first_channel(
160 &self,
161 channel_id: String,
162 ) -> Arc<Mutex<mpsc::Receiver<RpcEnvelope>>> {
163 let mut channels = self.latency_first_channels.write().await;
164
165 if !channels.contains_key(&channel_id) {
166 let (tx, rx) = mpsc::channel(1024);
167 let pair = ChannelPair {
168 tx,
169 rx: Arc::new(Mutex::new(rx)),
170 };
171 let rx_clone = pair.rx.clone();
172 channels.insert(channel_id.clone(), pair);
173
174 tracing::debug!("✨ Created LatencyFirst channel '{}'", channel_id);
175 rx_clone
176 } else {
177 channels
179 .get(&channel_id)
180 .expect("LatencyFirst channel must exist after contains_key check")
181 .rx
182 .clone()
183 }
184 }
185
186 pub async fn create_media_track_channel(
188 &self,
189 track_id: String,
190 ) -> Arc<Mutex<mpsc::Receiver<RpcEnvelope>>> {
191 let mut channels = self.media_track_channels.write().await;
192
193 if !channels.contains_key(&track_id) {
194 let (tx, rx) = mpsc::channel(1024);
195 let pair = ChannelPair {
196 tx,
197 rx: Arc::new(Mutex::new(rx)),
198 };
199 let rx_clone = pair.rx.clone();
200 channels.insert(track_id.clone(), pair);
201
202 tracing::debug!("✨ Created MediaTrack channel '{}'", track_id);
203 rx_clone
204 } else {
205 channels
207 .get(&track_id)
208 .expect("MediaTrack channel must exist after contains_key check")
209 .rx
210 .clone()
211 }
212 }
213
214 pub async fn get_lane(
225 &self,
226 payload_type: PayloadType,
227 identifier: Option<String>,
228 ) -> NetworkResult<DataLane> {
229 let key = LaneKey {
230 payload_type,
231 identifier: identifier.clone(),
232 };
233
234 {
236 let cache = self.lane_cache.read().await;
237 if let Some(lane) = cache.get(&key) {
238 tracing::debug!("📦 Reusing cached Inproc DataLane: {:?}", key);
239 return Ok(lane.clone());
240 }
241 }
242
243 let pair = match payload_type {
245 PayloadType::RpcReliable => ChannelPair {
246 tx: self.reliable_tx.clone(),
247 rx: self.reliable_rx.clone(),
248 },
249
250 PayloadType::RpcSignal => self.ensure_signal_channel().await,
251
252 PayloadType::StreamReliable | PayloadType::StreamLatencyFirst => {
253 let channel_id = identifier
254 .as_ref()
255 .ok_or_else(|| {
256 NetworkError::InvalidArgument("DataStream requires channel_id".into())
257 })?
258 .clone();
259
260 let channels = self.latency_first_channels.read().await;
261 channels
262 .get(&channel_id)
263 .ok_or_else(|| NetworkError::ChannelNotFound(channel_id))?
264 .clone()
265 }
266
267 PayloadType::MediaRtp => {
268 let track_id = identifier
269 .as_ref()
270 .ok_or_else(|| {
271 NetworkError::InvalidArgument("MediaRtp requires track_id".into())
272 })?
273 .clone();
274
275 let channels = self.media_track_channels.read().await;
276 channels
277 .get(&track_id)
278 .ok_or_else(|| NetworkError::ChannelNotFound(track_id))?
279 .clone()
280 }
281 };
282
283 let lane = DataLane::mpsc_shared(payload_type, pair.tx, pair.rx);
285
286 self.lane_cache.write().await.insert(key, lane.clone());
288
289 tracing::debug!(
290 "✨ Created Inproc DataLane: type={:?}, identifier={:?}",
291 payload_type,
292 identifier
293 );
294
295 Ok(lane)
296 }
297
298 pub async fn send_request(
302 &self,
303 payload_type: PayloadType,
304 identifier: Option<String>,
305 envelope: RpcEnvelope,
306 ) -> NetworkResult<Bytes> {
307 let (response_tx, response_rx) = oneshot::channel();
308
309 self.pending_requests
311 .write()
312 .await
313 .insert(envelope.request_id.clone(), response_tx);
314
315 let lane = self.get_lane(payload_type, identifier).await?;
317 lane.send_envelope(envelope).await?;
318
319 let result = tokio::time::timeout(Duration::from_secs(30), response_rx)
321 .await
322 .map_err(|_| NetworkError::TimeoutError("Request timeout".into()))?
323 .map_err(|_| NetworkError::ConnectionError("Response channel closed".into()))?;
324
325 result.map_err(|e| NetworkError::ProtocolError(e.to_string()))
327 }
328
329 pub async fn send_message(
331 &self,
332 payload_type: PayloadType,
333 identifier: Option<String>,
334 envelope: RpcEnvelope,
335 ) -> NetworkResult<()> {
336 let lane = self.get_lane(payload_type, identifier).await?;
337 lane.send_envelope(envelope).await
338 }
339
340 pub async fn recv(&self) -> Option<RpcEnvelope> {
346 loop {
347 tokio::select! {
348 biased;
349
350 msg = Self::recv_from_channel_opt(&self.signal_channel) => {
352 if let Some(envelope) = msg {
353 if !self.try_complete_response(&envelope).await {
354 return Some(envelope); }
356 }
358 }
359
360 msg = Self::recv_from_channel(&self.reliable_rx) => {
362 if let Some(envelope) = msg {
363 if !self.try_complete_response(&envelope).await {
364 return Some(envelope);
365 }
366 }
367 }
368
369 }
372 }
373 }
374
375 pub async fn complete_response(
385 &self,
386 request_id: &str,
387 response_bytes: Bytes,
388 ) -> NetworkResult<()> {
389 let mut pending = self.pending_requests.write().await;
390 if let Some(tx) = pending.remove(request_id) {
391 let _ = tx.send(Ok(response_bytes));
392 tracing::debug!("✅ Completed pending request: {}", request_id);
393 Ok(())
394 } else {
395 Err(NetworkError::InvalidArgument(format!(
396 "No pending request found for id: {request_id}"
397 )))
398 }
399 }
400
401 pub async fn complete_error(
407 &self,
408 request_id: &str,
409 error: actr_protocol::ProtocolError,
410 ) -> NetworkResult<()> {
411 let mut pending = self.pending_requests.write().await;
412 if let Some(tx) = pending.remove(request_id) {
413 let _ = tx.send(Err(error));
414 tracing::debug!("✅ Completed pending request with error: {}", request_id);
415 Ok(())
416 } else {
417 Err(NetworkError::InvalidArgument(format!(
418 "No pending request found for id: {request_id}"
419 )))
420 }
421 }
422
423 async fn try_complete_response(&self, envelope: &RpcEnvelope) -> bool {
425 let mut pending = self.pending_requests.write().await;
426 if let Some(tx) = pending.remove(&envelope.request_id) {
427 match (&envelope.payload, &envelope.error) {
429 (Some(payload), None) => {
430 let _ = tx.send(Ok(payload.clone()));
431 tracing::debug!("✅ Completed pending request: {}", envelope.request_id);
432 }
433 (None, Some(error)) => {
434 let protocol_err = actr_protocol::ProtocolError::TransportError(format!(
435 "RPC error {}: {}",
436 error.code, error.message
437 ));
438 let _ = tx.send(Err(protocol_err));
439 tracing::debug!(
440 "✅ Completed pending request with error: {}",
441 envelope.request_id
442 );
443 }
444 _ => {
445 tracing::error!(
446 "❌ Invalid RpcEnvelope: both payload and error present or both absent"
447 );
448 let _ = tx.send(Err(actr_protocol::ProtocolError::DecodeError(
449 "Invalid RpcEnvelope: payload and error fields inconsistent".to_string(),
450 )));
451 }
452 }
453 true
454 } else {
455 false
456 }
457 }
458
459 async fn recv_from_channel(
462 rx: &Arc<Mutex<mpsc::Receiver<RpcEnvelope>>>,
463 ) -> Option<RpcEnvelope> {
464 rx.lock().await.recv().await
465 }
466
467 async fn recv_from_channel_opt(opt: &Arc<Mutex<Option<ChannelPair>>>) -> Option<RpcEnvelope> {
468 let rx = {
469 let guard = opt.lock().await;
470 guard.as_ref().map(|pair| pair.rx.clone())
471 };
472
473 if let Some(rx) = rx {
474 rx.lock().await.recv().await
475 } else {
476 std::future::pending().await }
478 }
479}