actr_runtime/transport/
inproc_manager.rs

1//! InprocTransportManager - Intra-process transport manager
2//!
3//! Manages mpsc channel communication between Workload and Shell
4//!
5//! # Usage Examples
6//!
7//! ## Workload Side (Subscribe to data streams)
8//!
9//! ```rust,ignore
10//! use actr_runtime::InprocTransportManager;
11//! use std::sync::Arc;
12//!
13//! struct MyWorkload {
14//!     inproc_mgr: Arc<InprocTransportManager>,
15//! }
16//!
17//! impl MyWorkload {
18//!     pub async fn subscribe_metrics_stream(&self) -> NetworkResult<()> {
19//!         // Create LatencyFirst channel
20//!         let rx = self.inproc_mgr
21//!             .create_latency_first_channel("metrics-stream".to_string())
22//!             .await;
23//!
24//!         // Start receive loop
25//!         tokio::spawn(async move {
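//!             loop {
//!                 // Lock only while waiting for one message, as before
//!                 let received = rx.lock().await.recv().await;
//!                 match received {
//!                     // Process streaming data
//!                     Some(envelope) => println!("Received: {:?}", envelope),
//!                     // `None` means the channel is closed; stop instead of spinning
//!                     None => break,
//!                 }
//!             }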
33//!         });
34//!
35//!         Ok(())
36//!     }
37//! }
38//! ```
39//!
40//! ## Shell Side (Send data)
41//!
42//! ```rust,ignore
43//! // Get InprocTransportManager from ActrNode
44//! if let Some(inproc_mgr) = node.inproc_mgr() {
45//!     // Send to LatencyFirst channel
46//!     let envelope = RpcEnvelope { /* ... */ };
47//!     inproc_mgr.send_message(
48//!         PayloadType::StreamLatencyFirst,
49//!         Some("metrics-stream".to_string()),
50//!         envelope
51//!     ).await?;
52//! }
53//! ```
54
55use super::{DataLane, NetworkError, NetworkResult};
56use actr_framework::Bytes;
57use actr_protocol::{PayloadType, RpcEnvelope};
58use std::collections::HashMap;
59use std::sync::Arc;
60use std::time::Duration;
61use tokio::sync::{Mutex, RwLock, mpsc, oneshot};
62
/// Inproc Transport Manager - manages intra-process transport (mpsc channels)
///
/// # Design Philosophy
/// - **Workload ↔ Shell communication bridge** (not for arbitrary Actor-to-Actor communication)
/// - **Reliable is mandatory, others are created on-demand**
/// - **Dynamic multi-channel management**: HashMap<String, Channel>
/// - **Bi-directional sharing**: Shell and Workload share the same Manager
///
/// Both ends of every channel are stored here: the sending half is handed out through
/// `get_lane`, and the receiving half is shared behind `Arc<Mutex<..>>`.
pub struct InprocTransportManager {
    // ========== Mandatory base channel ==========
    /// Reliable channel (must exist); created eagerly in `new` with a capacity of 1024
    reliable_tx: mpsc::Sender<RpcEnvelope>,
    /// Receiving half of the Reliable channel, polled by `recv` and shared with Reliable lanes
    reliable_rx: Arc<Mutex<mpsc::Receiver<RpcEnvelope>>>,

    // ========== Optional specialized channels ==========
    /// Signal channel (optional, lazy creation); `None` until `ensure_signal_channel` runs
    signal_channel: Arc<Mutex<Option<ChannelPair>>>,

    /// LatencyFirst channels (multi-instance, indexed by channel_id).
    /// `get_lane` also resolves `StreamReliable` lanes from this map.
    latency_first_channels: Arc<RwLock<HashMap<String, ChannelPair>>>,

    /// MediaTrack channels (multi-instance, indexed by track_id)
    media_track_channels: Arc<RwLock<HashMap<String, ChannelPair>>>,

    // ========== Management data ==========
    /// Lane cache (avoid repeated creation), keyed by payload type + optional identifier
    lane_cache: Arc<RwLock<HashMap<LaneKey, DataLane>>>,

    /// Pending requests (request/response matching), keyed by the envelope's `request_id`.
    /// The waiting requester receives either success (Bytes) or error (ProtocolError)
    /// through the stored oneshot sender.
    pending_requests:
        Arc<RwLock<HashMap<String, oneshot::Sender<actr_protocol::ActorResult<Bytes>>>>>,
}
95
/// Channel pair (tx + rx)
///
/// Cloning is cheap: it clones the sender handle and bumps the `Arc` around the receiver,
/// so every clone refers to the same underlying mpsc channel.
#[derive(Clone)]
struct ChannelPair {
    /// Sending half of the channel
    tx: mpsc::Sender<RpcEnvelope>,
    /// Receiving half, shared behind a mutex so several holders can take turns receiving
    rx: Arc<Mutex<mpsc::Receiver<RpcEnvelope>>>,
}
102
/// Lane cache key
///
/// Two lookups hit the same cached lane only when both the payload type and the
/// identifier are equal.
#[derive(Hash, Eq, PartialEq, Clone, Debug)]
struct LaneKey {
    /// Payload type the lane was requested for
    payload_type: PayloadType,
    /// channel_id (LatencyFirst) or track_id (MediaTrack)
    identifier: Option<String>,
}
110
111impl Default for InprocTransportManager {
112    fn default() -> Self {
113        Self::new()
114    }
115}
116
117impl InprocTransportManager {
118    /// Create new instance (only creates Reliable channel, others are lazy-initialized)
119    ///
120    /// InprocTransportManager manages intra-process communication channels between Workload and Shell.
121    /// It does not need ActorId as all communication is within a single process.
122    pub fn new() -> Self {
123        let (reliable_tx, reliable_rx) = mpsc::channel(1024);
124
125        tracing::debug!("Created InprocTransportManager");
126        tracing::debug!("✨ Created Reliable channel");
127
128        Self {
129            reliable_tx,
130            reliable_rx: Arc::new(Mutex::new(reliable_rx)),
131            signal_channel: Arc::new(Mutex::new(None)),
132            latency_first_channels: Arc::new(RwLock::new(HashMap::new())),
133            media_track_channels: Arc::new(RwLock::new(HashMap::new())),
134            lane_cache: Arc::new(RwLock::new(HashMap::new())),
135            pending_requests: Arc::new(RwLock::new(HashMap::new())),
136        }
137    }
138
139    // ========== Dynamic creation APIs ==========
140
141    /// Ensure Signal channel exists
142    async fn ensure_signal_channel(&self) -> ChannelPair {
143        let mut opt = self.signal_channel.lock().await;
144        if opt.is_none() {
145            let (tx, rx) = mpsc::channel(1024);
146            *opt = Some(ChannelPair {
147                tx,
148                rx: Arc::new(Mutex::new(rx)),
149            });
150            tracing::debug!("✨ Created Signal channel");
151        }
152        // Safe: we just created it if it was None
153        opt.as_ref()
154            .expect("Signal channel must exist after ensure_signal_channel")
155            .clone()
156    }
157
158    /// Create LatencyFirst channel
159    pub async fn create_latency_first_channel(
160        &self,
161        channel_id: String,
162    ) -> Arc<Mutex<mpsc::Receiver<RpcEnvelope>>> {
163        let mut channels = self.latency_first_channels.write().await;
164
165        if !channels.contains_key(&channel_id) {
166            let (tx, rx) = mpsc::channel(1024);
167            let pair = ChannelPair {
168                tx,
169                rx: Arc::new(Mutex::new(rx)),
170            };
171            let rx_clone = pair.rx.clone();
172            channels.insert(channel_id.clone(), pair);
173
174            tracing::debug!("✨ Created LatencyFirst channel '{}'", channel_id);
175            rx_clone
176        } else {
177            // Safe: we just checked contains_key
178            channels
179                .get(&channel_id)
180                .expect("LatencyFirst channel must exist after contains_key check")
181                .rx
182                .clone()
183        }
184    }
185
186    /// Create MediaTrack channel
187    pub async fn create_media_track_channel(
188        &self,
189        track_id: String,
190    ) -> Arc<Mutex<mpsc::Receiver<RpcEnvelope>>> {
191        let mut channels = self.media_track_channels.write().await;
192
193        if !channels.contains_key(&track_id) {
194            let (tx, rx) = mpsc::channel(1024);
195            let pair = ChannelPair {
196                tx,
197                rx: Arc::new(Mutex::new(rx)),
198            };
199            let rx_clone = pair.rx.clone();
200            channels.insert(track_id.clone(), pair);
201
202            tracing::debug!("✨ Created MediaTrack channel '{}'", track_id);
203            rx_clone
204        } else {
205            // Safe: we just checked contains_key
206            channels
207                .get(&track_id)
208                .expect("MediaTrack channel must exist after contains_key check")
209                .rx
210                .clone()
211        }
212    }
213
214    // ========== Lane retrieval APIs ==========
215
216    /// Get Lane (with optional channel_id/track_id)
217    ///
218    /// # Arguments
219    /// - `payload_type`: PayloadType
220    /// - `identifier`:
221    ///   - `None` for Reliable/Signal
222    ///   - `Some(channel_id)` for LatencyFirst
223    ///   - `Some(track_id)` for MediaTrack
224    pub async fn get_lane(
225        &self,
226        payload_type: PayloadType,
227        identifier: Option<String>,
228    ) -> NetworkResult<DataLane> {
229        let key = LaneKey {
230            payload_type,
231            identifier: identifier.clone(),
232        };
233
234        // 1. Check cache
235        {
236            let cache = self.lane_cache.read().await;
237            if let Some(lane) = cache.get(&key) {
238                tracing::debug!("📦 Reusing cached Inproc DataLane: {:?}", key);
239                return Ok(lane.clone());
240            }
241        }
242
243        // 2. Get corresponding ChannelPair
244        let pair = match payload_type {
245            PayloadType::RpcReliable => ChannelPair {
246                tx: self.reliable_tx.clone(),
247                rx: self.reliable_rx.clone(),
248            },
249
250            PayloadType::RpcSignal => self.ensure_signal_channel().await,
251
252            PayloadType::StreamReliable | PayloadType::StreamLatencyFirst => {
253                let channel_id = identifier
254                    .as_ref()
255                    .ok_or_else(|| {
256                        NetworkError::InvalidArgument("DataStream requires channel_id".into())
257                    })?
258                    .clone();
259
260                let channels = self.latency_first_channels.read().await;
261                channels
262                    .get(&channel_id)
263                    .ok_or_else(|| NetworkError::ChannelNotFound(channel_id))?
264                    .clone()
265            }
266
267            PayloadType::MediaRtp => {
268                let track_id = identifier
269                    .as_ref()
270                    .ok_or_else(|| {
271                        NetworkError::InvalidArgument("MediaRtp requires track_id".into())
272                    })?
273                    .clone();
274
275                let channels = self.media_track_channels.read().await;
276                channels
277                    .get(&track_id)
278                    .ok_or_else(|| NetworkError::ChannelNotFound(track_id))?
279                    .clone()
280            }
281        };
282
283        // 3. Create DataLane
284        let lane = DataLane::mpsc_shared(payload_type, pair.tx, pair.rx);
285
286        // 4. Cache it
287        self.lane_cache.write().await.insert(key, lane.clone());
288
289        tracing::debug!(
290            "✨ Created Inproc DataLane: type={:?}, identifier={:?}",
291            payload_type,
292            identifier
293        );
294
295        Ok(lane)
296    }
297
298    // ========== High-level APIs ==========
299
300    /// Send request (with response waiting)
301    #[cfg_attr(
302        feature = "opentelemetry",
303        tracing::instrument(skip_all, name = "InprocTransportManager.send_request")
304    )]
305    pub async fn send_request(
306        &self,
307        payload_type: PayloadType,
308        identifier: Option<String>,
309        envelope: RpcEnvelope,
310    ) -> NetworkResult<Bytes> {
311        let (response_tx, response_rx) = oneshot::channel();
312
313        // Register pending request
314        self.pending_requests
315            .write()
316            .await
317            .insert(envelope.request_id.clone(), response_tx);
318
319        // Send
320        let lane = self.get_lane(payload_type, identifier).await?;
321        lane.send_envelope(envelope).await?;
322
323        // Wait for response
324        let result = tokio::time::timeout(Duration::from_secs(30), response_rx)
325            .await
326            .map_err(|_| NetworkError::TimeoutError("Request timeout".into()))?
327            .map_err(|_| NetworkError::ConnectionError("Response channel closed".into()))?;
328
329        // result is ActorResult<Bytes>, convert to NetworkError if error
330        result.map_err(|e| NetworkError::ProtocolError(e.to_string()))
331    }
332
333    /// Send one-way message
334    #[cfg_attr(
335        feature = "opentelemetry",
336        tracing::instrument(skip_all, name = "InprocTransportManager.send_message")
337    )]
338    pub async fn send_message(
339        &self,
340        payload_type: PayloadType,
341        identifier: Option<String>,
342        envelope: RpcEnvelope,
343    ) -> NetworkResult<()> {
344        let lane = self.get_lane(payload_type, identifier).await?;
345        lane.send_envelope(envelope).await
346    }
347
    /// Receive the next non-response envelope from the Signal or Reliable channel.
    ///
    /// Each loop iteration polls both channels with a `biased` `tokio::select!`, so the
    /// Signal branch is always checked before the Reliable branch. An envelope whose
    /// `request_id` matches an entry in `pending_requests` is treated as a response: it is
    /// delivered to the waiting requester by `try_complete_response` and the loop continues.
    /// Any other envelope is returned to the caller.
    ///
    /// LatencyFirst and MediaTrack channels are not polled here (see the TODO below).
    ///
    /// # Returns
    /// - `Some(envelope)`: received message (response matching already handled)
    /// - `None`: never produced by the current body; a branch that yields `None` simply
    ///   restarts the loop
    ///
    /// NOTE(review): when the Signal channel does not exist at the start of an iteration,
    /// its branch waits on `std::future::pending()`. A Signal channel created afterwards is
    /// presumably only noticed once the Reliable branch wakes the loop — confirm.
    pub async fn recv(&self) -> Option<RpcEnvelope> {
        loop {
            tokio::select! {
                // Poll branches top to bottom instead of in random order
                biased;

                // Signal (highest priority)
                msg = Self::recv_from_channel_opt(&self.signal_channel) => {
                    if let Some(envelope) = msg {
                        if !self.try_complete_response(&envelope).await {
                            return Some(envelope);  // It's a request
                        }
                        // It's a response, already handled, continue loop
                    }
                }

                // Reliable
                msg = Self::recv_from_channel(&self.reliable_rx) => {
                    if let Some(envelope) = msg {
                        if !self.try_complete_response(&envelope).await {
                            return Some(envelope);
                        }
                        // Otherwise it was a response and has been delivered; keep looping
                    }
                }

                // TODO: LatencyFirst and MediaTrack reception
                // Need to implement receiving from all channels in HashMap
            }
        }
    }
382
383    /// Complete a pending request with response payload
384    ///
385    /// # Arguments
386    /// - `request_id`: The request ID to complete
387    /// - `response_bytes`: Response payload
388    ///
389    /// # Returns
390    /// - `Ok(())`: Successfully sent response to waiting sender
391    /// - `Err(NetworkError)`: No pending request found with this ID
392    pub async fn complete_response(
393        &self,
394        request_id: &str,
395        response_bytes: Bytes,
396    ) -> NetworkResult<()> {
397        let mut pending = self.pending_requests.write().await;
398        if let Some(tx) = pending.remove(request_id) {
399            let _ = tx.send(Ok(response_bytes));
400            tracing::debug!("✅ Completed pending request: {}", request_id);
401            Ok(())
402        } else {
403            Err(NetworkError::InvalidArgument(format!(
404                "No pending request found for id: {request_id}"
405            )))
406        }
407    }
408
409    /// Complete a pending request with an error
410    ///
411    /// # Returns
412    /// - `Ok(())`: Successfully sent error to waiting sender
413    /// - `Err(NetworkError)`: No pending request found with this ID
414    pub async fn complete_error(
415        &self,
416        request_id: &str,
417        error: actr_protocol::ProtocolError,
418    ) -> NetworkResult<()> {
419        let mut pending = self.pending_requests.write().await;
420        if let Some(tx) = pending.remove(request_id) {
421            let _ = tx.send(Err(error));
422            tracing::debug!("✅ Completed pending request with error: {}", request_id);
423            Ok(())
424        } else {
425            Err(NetworkError::InvalidArgument(format!(
426                "No pending request found for id: {request_id}"
427            )))
428        }
429    }
430
431    /// Handle response matching (returns true if it was a response)
432    async fn try_complete_response(&self, envelope: &RpcEnvelope) -> bool {
433        let mut pending = self.pending_requests.write().await;
434        if let Some(tx) = pending.remove(&envelope.request_id) {
435            // Check if response or error
436            match (&envelope.payload, &envelope.error) {
437                (Some(payload), None) => {
438                    let _ = tx.send(Ok(payload.clone()));
439                    tracing::debug!("✅ Completed pending request: {}", envelope.request_id);
440                }
441                (None, Some(error)) => {
442                    let protocol_err = actr_protocol::ProtocolError::TransportError(format!(
443                        "RPC error {}: {}",
444                        error.code, error.message
445                    ));
446                    let _ = tx.send(Err(protocol_err));
447                    tracing::debug!(
448                        "✅ Completed pending request with error: {}",
449                        envelope.request_id
450                    );
451                }
452                _ => {
453                    tracing::error!(
454                        "❌ Invalid RpcEnvelope: both payload and error present or both absent"
455                    );
456                    let _ = tx.send(Err(actr_protocol::ProtocolError::DecodeError(
457                        "Invalid RpcEnvelope: payload and error fields inconsistent".to_string(),
458                    )));
459                }
460            }
461            true
462        } else {
463            false
464        }
465    }
466
467    // ========== Helper methods ==========
468
469    async fn recv_from_channel(
470        rx: &Arc<Mutex<mpsc::Receiver<RpcEnvelope>>>,
471    ) -> Option<RpcEnvelope> {
472        rx.lock().await.recv().await
473    }
474
475    async fn recv_from_channel_opt(opt: &Arc<Mutex<Option<ChannelPair>>>) -> Option<RpcEnvelope> {
476        let rx = {
477            let guard = opt.lock().await;
478            guard.as_ref().map(|pair| pair.rx.clone())
479        };
480
481        if let Some(rx) = rx {
482            rx.lock().await.recv().await
483        } else {
484            std::future::pending().await // If doesn't exist, wait forever
485        }
486    }
487}