actr_runtime/transport/
inproc_manager.rs

1//! InprocTransportManager - Intra-process transport manager
2//!
3//! Manages mpsc channel communication between Workload and Shell
4//!
5//! # Usage Examples
6//!
7//! ## Workload Side (Subscribe to data streams)
8//!
9//! ```rust,ignore
10//! use actr_runtime::InprocTransportManager;
11//! use std::sync::Arc;
12//!
13//! struct MyWorkload {
14//!     inproc_mgr: Arc<InprocTransportManager>,
15//! }
16//!
17//! impl MyWorkload {
18//!     pub async fn subscribe_metrics_stream(&self) -> NetworkResult<()> {
19//!         // Create LatencyFirst channel
20//!         let rx = self.inproc_mgr
21//!             .create_latency_first_channel("metrics-stream".to_string())
22//!             .await;
23//!
24//!         // Start receive loop
25//!         tokio::spawn(async move {
//!             loop {
//!                 let mut receiver = rx.lock().await;
//!                 match receiver.recv().await {
//!                     // Process streaming data
//!                     Some(envelope) => println!("Received: {:?}", envelope),
//!                     // Channel closed (all senders dropped): stop instead of spinning on `None`
//!                     None => break,
//!                 }
//!             }
33//!         });
34//!
35//!         Ok(())
36//!     }
37//! }
38//! ```
39//!
40//! ## Shell Side (Send data)
41//!
42//! ```rust,ignore
43//! // Get InprocTransportManager from ActrNode
44//! if let Some(inproc_mgr) = node.inproc_mgr() {
45//!     // Send to LatencyFirst channel
46//!     let envelope = RpcEnvelope { /* ... */ };
47//!     inproc_mgr.send_message(
48//!         PayloadType::StreamLatencyFirst,
49//!         Some("metrics-stream".to_string()),
50//!         envelope
51//!     ).await?;
52//! }
53//! ```
54
55use super::{DataLane, NetworkError, NetworkResult};
56use actr_framework::Bytes;
57use actr_protocol::{PayloadType, RpcEnvelope};
58use std::collections::HashMap;
59use std::sync::Arc;
60use std::time::Duration;
61use tokio::sync::{Mutex, RwLock, mpsc, oneshot};
62
/// Inproc Transport Manager - manages intra-process transport (mpsc channels)
///
/// # Design Philosophy
/// - **Workload ↔ Shell communication bridge** (not for arbitrary Actor-to-Actor communication)
/// - **Reliable is mandatory, others are created on-demand**
/// - **Dynamic multi-channel management**: HashMap<String, Channel>
/// - **Bi-directional sharing**: Shell and Workload share the same Manager
///
/// # Sharing model
/// Every receiver is stored as `Arc<Mutex<mpsc::Receiver<_>>>`, so the same
/// receiving end can be handed to several holders; whoever holds the mutex is
/// the one that receives. The manager also keeps a sender for each channel it
/// owns (`reliable_tx`, and the `tx` inside every stored `ChannelPair`).
pub struct InprocTransportManager {
    // ========== Mandatory base channel ==========
    /// Reliable channel (must exist): sending half, created in `new()`
    reliable_tx: mpsc::Sender<RpcEnvelope>,
    /// Reliable channel: shared receiving half
    reliable_rx: Arc<Mutex<mpsc::Receiver<RpcEnvelope>>>,

    // ========== Optional specialized channels ==========
    /// Signal channel (optional, lazy creation); `None` until first requested
    signal_channel: Arc<Mutex<Option<ChannelPair>>>,

    /// LatencyFirst channels (multi-instance, indexed by channel_id)
    latency_first_channels: Arc<RwLock<HashMap<String, ChannelPair>>>,

    /// MediaTrack channels (multi-instance, indexed by track_id)
    media_track_channels: Arc<RwLock<HashMap<String, ChannelPair>>>,

    // ========== Management data ==========
    /// Lane cache (avoid repeated creation), keyed by payload type + optional identifier
    lane_cache: Arc<RwLock<HashMap<LaneKey, DataLane>>>,

    /// Pending requests (request/response matching), keyed by request id.
    /// Each entry is the one-shot sender used to wake the task waiting for that
    /// request; it carries either the success payload (`Bytes`) or an error
    /// (`ProtocolError`).
    pending_requests:
        Arc<RwLock<HashMap<String, oneshot::Sender<actr_protocol::ActorResult<Bytes>>>>>,
}
95
/// Channel pair (tx + rx)
///
/// Keeps both ends of one mpsc channel together so they can be stored in a map
/// and cloned as a unit. Cloning the pair clones the sender and bumps the
/// reference count of the shared receiver; it does not create a new channel.
#[derive(Clone)]
struct ChannelPair {
    /// Sending half of the channel
    tx: mpsc::Sender<RpcEnvelope>,
    /// Receiving half, shared behind a mutex so only one holder receives at a time
    rx: Arc<Mutex<mpsc::Receiver<RpcEnvelope>>>,
}
102
/// Lane cache key
///
/// Two lookups hit the same cached lane only when both the payload type and
/// the identifier (including `None` vs `Some`) are equal.
#[derive(Hash, Eq, PartialEq, Clone, Debug)]
struct LaneKey {
    /// Payload type the lane was requested for
    payload_type: PayloadType,
    /// channel_id (LatencyFirst) or track_id (MediaTrack)
    identifier: Option<String>,
}
110
111impl Default for InprocTransportManager {
112    fn default() -> Self {
113        Self::new()
114    }
115}
116
117impl InprocTransportManager {
118    /// Create new instance (only creates Reliable channel, others are lazy-initialized)
119    ///
120    /// InprocTransportManager manages intra-process communication channels between Workload and Shell.
121    /// It does not need ActorId as all communication is within a single process.
122    pub fn new() -> Self {
123        let (reliable_tx, reliable_rx) = mpsc::channel(1024);
124
125        tracing::debug!("Created InprocTransportManager");
126        tracing::debug!("✨ Created Reliable channel");
127
128        Self {
129            reliable_tx,
130            reliable_rx: Arc::new(Mutex::new(reliable_rx)),
131            signal_channel: Arc::new(Mutex::new(None)),
132            latency_first_channels: Arc::new(RwLock::new(HashMap::new())),
133            media_track_channels: Arc::new(RwLock::new(HashMap::new())),
134            lane_cache: Arc::new(RwLock::new(HashMap::new())),
135            pending_requests: Arc::new(RwLock::new(HashMap::new())),
136        }
137    }
138
139    // ========== Dynamic creation APIs ==========
140
141    /// Ensure Signal channel exists
142    async fn ensure_signal_channel(&self) -> ChannelPair {
143        let mut opt = self.signal_channel.lock().await;
144        if opt.is_none() {
145            let (tx, rx) = mpsc::channel(1024);
146            *opt = Some(ChannelPair {
147                tx,
148                rx: Arc::new(Mutex::new(rx)),
149            });
150            tracing::debug!("✨ Created Signal channel");
151        }
152        // Safe: we just created it if it was None
153        opt.as_ref()
154            .expect("Signal channel must exist after ensure_signal_channel")
155            .clone()
156    }
157
158    /// Create LatencyFirst channel
159    pub async fn create_latency_first_channel(
160        &self,
161        channel_id: String,
162    ) -> Arc<Mutex<mpsc::Receiver<RpcEnvelope>>> {
163        let mut channels = self.latency_first_channels.write().await;
164
165        if !channels.contains_key(&channel_id) {
166            let (tx, rx) = mpsc::channel(1024);
167            let pair = ChannelPair {
168                tx,
169                rx: Arc::new(Mutex::new(rx)),
170            };
171            let rx_clone = pair.rx.clone();
172            channels.insert(channel_id.clone(), pair);
173
174            tracing::debug!("✨ Created LatencyFirst channel '{}'", channel_id);
175            rx_clone
176        } else {
177            // Safe: we just checked contains_key
178            channels
179                .get(&channel_id)
180                .expect("LatencyFirst channel must exist after contains_key check")
181                .rx
182                .clone()
183        }
184    }
185
186    /// Create MediaTrack channel
187    pub async fn create_media_track_channel(
188        &self,
189        track_id: String,
190    ) -> Arc<Mutex<mpsc::Receiver<RpcEnvelope>>> {
191        let mut channels = self.media_track_channels.write().await;
192
193        if !channels.contains_key(&track_id) {
194            let (tx, rx) = mpsc::channel(1024);
195            let pair = ChannelPair {
196                tx,
197                rx: Arc::new(Mutex::new(rx)),
198            };
199            let rx_clone = pair.rx.clone();
200            channels.insert(track_id.clone(), pair);
201
202            tracing::debug!("✨ Created MediaTrack channel '{}'", track_id);
203            rx_clone
204        } else {
205            // Safe: we just checked contains_key
206            channels
207                .get(&track_id)
208                .expect("MediaTrack channel must exist after contains_key check")
209                .rx
210                .clone()
211        }
212    }
213
214    // ========== Lane retrieval APIs ==========
215
216    /// Get Lane (with optional channel_id/track_id)
217    ///
218    /// # Arguments
219    /// - `payload_type`: PayloadType
220    /// - `identifier`:
221    ///   - `None` for Reliable/Signal
222    ///   - `Some(channel_id)` for LatencyFirst
223    ///   - `Some(track_id)` for MediaTrack
224    pub async fn get_lane(
225        &self,
226        payload_type: PayloadType,
227        identifier: Option<String>,
228    ) -> NetworkResult<DataLane> {
229        let key = LaneKey {
230            payload_type,
231            identifier: identifier.clone(),
232        };
233
234        // 1. Check cache
235        {
236            let cache = self.lane_cache.read().await;
237            if let Some(lane) = cache.get(&key) {
238                tracing::debug!("📦 Reusing cached Inproc DataLane: {:?}", key);
239                return Ok(lane.clone());
240            }
241        }
242
243        // 2. Get corresponding ChannelPair
244        let pair = match payload_type {
245            PayloadType::RpcReliable => ChannelPair {
246                tx: self.reliable_tx.clone(),
247                rx: self.reliable_rx.clone(),
248            },
249
250            PayloadType::RpcSignal => self.ensure_signal_channel().await,
251
252            PayloadType::StreamReliable | PayloadType::StreamLatencyFirst => {
253                let channel_id = identifier
254                    .as_ref()
255                    .ok_or_else(|| {
256                        NetworkError::InvalidArgument("DataStream requires channel_id".into())
257                    })?
258                    .clone();
259
260                let channels = self.latency_first_channels.read().await;
261                channels
262                    .get(&channel_id)
263                    .ok_or_else(|| NetworkError::ChannelNotFound(channel_id))?
264                    .clone()
265            }
266
267            PayloadType::MediaRtp => {
268                let track_id = identifier
269                    .as_ref()
270                    .ok_or_else(|| {
271                        NetworkError::InvalidArgument("MediaRtp requires track_id".into())
272                    })?
273                    .clone();
274
275                let channels = self.media_track_channels.read().await;
276                channels
277                    .get(&track_id)
278                    .ok_or_else(|| NetworkError::ChannelNotFound(track_id))?
279                    .clone()
280            }
281        };
282
283        // 3. Create DataLane
284        let lane = DataLane::mpsc_shared(payload_type, pair.tx, pair.rx);
285
286        // 4. Cache it
287        self.lane_cache.write().await.insert(key, lane.clone());
288
289        tracing::debug!(
290            "✨ Created Inproc DataLane: type={:?}, identifier={:?}",
291            payload_type,
292            identifier
293        );
294
295        Ok(lane)
296    }
297
298    // ========== High-level APIs ==========
299
300    /// Send request (with response waiting)
301    pub async fn send_request(
302        &self,
303        payload_type: PayloadType,
304        identifier: Option<String>,
305        envelope: RpcEnvelope,
306    ) -> NetworkResult<Bytes> {
307        let (response_tx, response_rx) = oneshot::channel();
308
309        // Register pending request
310        self.pending_requests
311            .write()
312            .await
313            .insert(envelope.request_id.clone(), response_tx);
314
315        // Send
316        let lane = self.get_lane(payload_type, identifier).await?;
317        lane.send_envelope(envelope).await?;
318
319        // Wait for response
320        let result = tokio::time::timeout(Duration::from_secs(30), response_rx)
321            .await
322            .map_err(|_| NetworkError::TimeoutError("Request timeout".into()))?
323            .map_err(|_| NetworkError::ConnectionError("Response channel closed".into()))?;
324
325        // result is ActorResult<Bytes>, convert to NetworkError if error
326        result.map_err(|e| NetworkError::ProtocolError(e.to_string()))
327    }
328
329    /// Send one-way message
330    pub async fn send_message(
331        &self,
332        payload_type: PayloadType,
333        identifier: Option<String>,
334        envelope: RpcEnvelope,
335    ) -> NetworkResult<()> {
336        let lane = self.get_lane(payload_type, identifier).await?;
337        lane.send_envelope(envelope).await
338    }
339
340    /// Receive one message (select first available from all channels)
341    ///
342    /// # Returns
343    /// - `Some(envelope)`: received message (response matching already handled)
344    /// - `None`: all channels closed
345    pub async fn recv(&self) -> Option<RpcEnvelope> {
346        loop {
347            tokio::select! {
348                biased;
349
350                // Signal (highest priority)
351                msg = Self::recv_from_channel_opt(&self.signal_channel) => {
352                    if let Some(envelope) = msg {
353                        if !self.try_complete_response(&envelope).await {
354                            return Some(envelope);  // It's a request
355                        }
356                        // It's a response, already handled, continue loop
357                    }
358                }
359
360                // Reliable
361                msg = Self::recv_from_channel(&self.reliable_rx) => {
362                    if let Some(envelope) = msg {
363                        if !self.try_complete_response(&envelope).await {
364                            return Some(envelope);
365                        }
366                    }
367                }
368
369                // TODO: LatencyFirst and MediaTrack reception
370                // Need to implement receiving from all channels in HashMap
371            }
372        }
373    }
374
375    /// Complete a pending request with response payload
376    ///
377    /// # Arguments
378    /// - `request_id`: The request ID to complete
379    /// - `response_bytes`: Response payload
380    ///
381    /// # Returns
382    /// - `Ok(())`: Successfully sent response to waiting sender
383    /// - `Err(NetworkError)`: No pending request found with this ID
384    pub async fn complete_response(
385        &self,
386        request_id: &str,
387        response_bytes: Bytes,
388    ) -> NetworkResult<()> {
389        let mut pending = self.pending_requests.write().await;
390        if let Some(tx) = pending.remove(request_id) {
391            let _ = tx.send(Ok(response_bytes));
392            tracing::debug!("✅ Completed pending request: {}", request_id);
393            Ok(())
394        } else {
395            Err(NetworkError::InvalidArgument(format!(
396                "No pending request found for id: {request_id}"
397            )))
398        }
399    }
400
401    /// Complete a pending request with an error
402    ///
403    /// # Returns
404    /// - `Ok(())`: Successfully sent error to waiting sender
405    /// - `Err(NetworkError)`: No pending request found with this ID
406    pub async fn complete_error(
407        &self,
408        request_id: &str,
409        error: actr_protocol::ProtocolError,
410    ) -> NetworkResult<()> {
411        let mut pending = self.pending_requests.write().await;
412        if let Some(tx) = pending.remove(request_id) {
413            let _ = tx.send(Err(error));
414            tracing::debug!("✅ Completed pending request with error: {}", request_id);
415            Ok(())
416        } else {
417            Err(NetworkError::InvalidArgument(format!(
418                "No pending request found for id: {request_id}"
419            )))
420        }
421    }
422
423    /// Handle response matching (returns true if it was a response)
424    async fn try_complete_response(&self, envelope: &RpcEnvelope) -> bool {
425        let mut pending = self.pending_requests.write().await;
426        if let Some(tx) = pending.remove(&envelope.request_id) {
427            // Check if response or error
428            match (&envelope.payload, &envelope.error) {
429                (Some(payload), None) => {
430                    let _ = tx.send(Ok(payload.clone()));
431                    tracing::debug!("✅ Completed pending request: {}", envelope.request_id);
432                }
433                (None, Some(error)) => {
434                    let protocol_err = actr_protocol::ProtocolError::TransportError(format!(
435                        "RPC error {}: {}",
436                        error.code, error.message
437                    ));
438                    let _ = tx.send(Err(protocol_err));
439                    tracing::debug!(
440                        "✅ Completed pending request with error: {}",
441                        envelope.request_id
442                    );
443                }
444                _ => {
445                    tracing::error!(
446                        "❌ Invalid RpcEnvelope: both payload and error present or both absent"
447                    );
448                    let _ = tx.send(Err(actr_protocol::ProtocolError::DecodeError(
449                        "Invalid RpcEnvelope: payload and error fields inconsistent".to_string(),
450                    )));
451                }
452            }
453            true
454        } else {
455            false
456        }
457    }
458
459    // ========== Helper methods ==========
460
461    async fn recv_from_channel(
462        rx: &Arc<Mutex<mpsc::Receiver<RpcEnvelope>>>,
463    ) -> Option<RpcEnvelope> {
464        rx.lock().await.recv().await
465    }
466
467    async fn recv_from_channel_opt(opt: &Arc<Mutex<Option<ChannelPair>>>) -> Option<RpcEnvelope> {
468        let rx = {
469            let guard = opt.lock().await;
470            guard.as_ref().map(|pair| pair.rx.clone())
471        };
472
473        if let Some(rx) = rx {
474            rx.lock().await.recv().await
475        } else {
476            std::future::pending().await // If doesn't exist, wait forever
477        }
478    }
479}