actr_runtime/transport/
manager.rs

1//! OutprocTransportManager - Cross-process transport manager
2//!
3//! Manages transport layer for multiple Dests, providing unified send/recv interface
4//!
5//! # Naming Convention
6//! - **OutprocTransportManager**: Manages cross-process communication (WebRTC, WebSocket)
7//! - **InprocTransportManager**: Manages intra-process communication (mpsc channels)
8//!
9//! These two form a symmetric design, handling different transport scenarios
10
11use super::Dest; // Re-exported from actr-framework
12use super::dest_transport::DestTransport;
13use super::error::{NetworkError, NetworkResult};
14use super::wire_handle::WireHandle;
15use actr_protocol::{ActrId, PayloadType};
16use async_trait::async_trait;
17use either::Either;
18use std::collections::HashMap;
19use std::collections::HashSet;
20use std::sync::Arc;
21use std::time::Duration;
22use tokio::sync::{Mutex, Notify, RwLock};
23use tokio_util::sync::CancellationToken;
24
25/// Wire builder trait: asynchronously creates Wire components based on Dest
26///
27/// Implement this trait to customize Wire layer component creation logic (e.g., WebRTC, WebSocket)
28#[async_trait]
29pub trait WireBuilder: Send + Sync {
30    /// Create Wire handle list to specified Dest
31    ///
32    /// # Arguments
33    /// - `dest`: Target destination
34    ///
35    /// # Returns
36    /// - Wire handle list (may contain multiple types: WebSocket, WebRTC, etc.)
37    async fn create_connections(&self, dest: &Dest) -> NetworkResult<Vec<WireHandle>>;
38
39    /// Create Wire handle list with cancellation support
40    ///
41    /// # Arguments
42    /// - `dest`: Target destination
43    /// - `cancel_token`: Optional cancellation token to terminate the operation
44    ///
45    /// # Returns
46    /// - Wire handle list (may contain multiple types: WebSocket, WebRTC, etc.)
47    /// - Returns error if cancelled
48    ///
49    /// Default implementation ignores the cancel token and calls `create_connections`.
50    async fn create_connections_with_cancel(
51        &self,
52        dest: &Dest,
53        cancel_token: Option<CancellationToken>,
54    ) -> NetworkResult<Vec<WireHandle>> {
55        // Check if already cancelled
56        if let Some(ref token) = cancel_token {
57            if token.is_cancelled() {
58                return Err(NetworkError::ConnectionClosed(
59                    "Connection creation cancelled".to_string(),
60                ));
61            }
62        }
63
64        // Default: just call create_connections
65        self.create_connections(dest).await
66    }
67}
68
/// Destination transport state
///
/// Uses Either to manage connection lifecycle:
/// - Left: Connecting state with shared Notify (multiple waiters)
/// - Right: Connected state with DestTransport
///
/// A Dest with no entry in the manager's map has no connection attempt in
/// progress; `get_or_create_transport` inserts the Left marker before it
/// starts creating connections.
type DestState = Either<Arc<Notify>, Arc<DestTransport>>;
75
/// OutprocTransportManager - Cross-process transport manager
///
/// Responsibilities:
/// - Manage transport layer for multiple Dests (each Dest maps to one DestTransport)
/// - Create DestTransport on-demand (lazy initialization)
/// - Provide unified send/recv interface
/// - Support custom connection factories
/// - Prevent duplicate connection creation using Either state machine
///
/// # Comparison with InprocTransportManager
/// - **OutprocTransportManager**: Cross-process, uses WebRTC/WebSocket
/// - **InprocTransportManager**: Intra-process, uses mpsc channels, zero serialization
///
/// # State Machine
/// ```text
/// None → Connecting(Notify) → Connected(Transport)
///         ↓                      ↓
///      (multiple waiters)     (ready)
/// ```
pub struct OutprocTransportManager {
    /// Local Actor ID
    local_id: ActrId,

    /// Dest → DestState mapping (Either state machine)
    ///
    /// Shared via Arc so that background tasks spawned by the manager can
    /// update the map after the spawning call has returned.
    transports: Arc<RwLock<HashMap<Dest, DestState>>>,

    /// Wire builder (used to create Wire handles for new DestTransport)
    conn_factory: Arc<dyn WireBuilder>,

    /// Cancellation tokens for in-progress connection creation
    /// Dest → CancellationToken (for cancelling ongoing connection attempts)
    pending_tokens: Arc<Mutex<HashMap<Dest, CancellationToken>>>,

    // NOTE(review): the field is read by `is_closing` and written by
    // `close_transport`, so this `allow` looks stale - confirm and remove.
    #[allow(unused)]
    /// TODO: Set of peers currently being closed (intended to reject new
    /// connection attempts); closed requests will be cleaned up in the event
    /// listener. `get_or_create_transport` does not consult this set.
    closing_peers: Arc<RwLock<HashSet<Dest>>>,
}
113
114impl OutprocTransportManager {
115    /// Create new OutprocTransportManager
116    ///
117    /// # Arguments
118    /// - `local_id`: Local Actor ID
119    /// - `conn_factory`: Wire builder, asynchronously creates Wire handle list based on Dest
120    pub fn new(local_id: ActrId, conn_factory: Arc<dyn WireBuilder>) -> Self {
121        Self {
122            local_id,
123            transports: Arc::new(RwLock::new(HashMap::new())),
124            conn_factory,
125            pending_tokens: Arc::new(Mutex::new(HashMap::new())),
126            closing_peers: Arc::new(RwLock::new(HashSet::new())),
127        }
128    }
129
130    /// Check if a destination is currently being closed
131    pub async fn is_closing(&self, dest: &Dest) -> bool {
132        self.closing_peers.read().await.contains(dest)
133    }
134
135    /// Get or create DestTransport for specified Dest
136    ///
137    /// # Arguments
138    /// - `dest`: Target destination
139    ///
140    /// # Returns
141    /// - DestTransport for this Dest (Arc-shared)
142    ///
143    /// # State Machine
144    /// Uses Either to prevent duplicate connections:
145    /// 1. If Connected → return transport
146    /// 2. If Connecting → wait for notify, then retry
147    /// 3. If None → insert Connecting(notify), create connection outside lock
148    #[cfg_attr(feature = "opentelemetry", tracing::instrument(skip_all))]
149    pub async fn get_or_create_transport(&self, dest: &Dest) -> NetworkResult<Arc<DestTransport>> {
150        // 0. Check if dest is being closed - fast fail
151        // if self.closing_peers.read().await.contains(dest) {
152        //     return Err(NetworkError::ConnectionClosed(format!(
153        //         "Destination {:?} is being closed.",
154        //         dest
155        //     )));
156        // }
157
158        loop {
159            // 1. Fast path: check current state
160            let state_opt = {
161                let transports = self.transports.read().await;
162                transports.get(dest).cloned()
163            };
164
165            match state_opt {
166                // Already connected - fast path
167                Some(Either::Right(transport)) => {
168                    tracing::debug!("📦 Reusing existing DestTransport: {:?}", dest);
169                    return Ok(transport);
170                }
171                // Currently connecting - wait for completion
172                Some(Either::Left(notify)) => {
173                    tracing::debug!("⏳ Waiting for ongoing connection: {:?}", dest);
174                    notify.notified().await;
175                    // Check if cancelled during wait
176                    // if self.closing_peers.read().await.contains(dest) {
177                    //     return Err(NetworkError::ConnectionClosed(format!(
178                    //         "Destination {:?} was closed while waiting",
179                    //         dest
180                    //     )));
181                    // }
182                    // Retry after notification
183                    continue;
184                }
185                // Not exists - need to create
186                None => {
187                    // Enter slow path
188                }
189            }
190
191            // 2. Slow path: try to become the creator
192            let notify = {
193                let mut transports = self.transports.write().await;
194
195                // Double-check: may have been created while waiting for write lock
196                match transports.get(dest) {
197                    Some(Either::Right(transport)) => {
198                        return Ok(Arc::clone(transport));
199                    }
200                    Some(Either::Left(notify)) => {
201                        // Another thread is creating, wait for it
202                        Arc::clone(notify)
203                    }
204                    None => {
205                        // Check closing again before creating
206                        // if self.closing_peers.read().await.contains(dest) {
207                        //     return Err(NetworkError::ConnectionClosed(format!(
208                        //         "Destination {:?} is being closed",
209                        //         dest
210                        //     )));
211                        // }
212                        // We are the creator, insert Connecting state
213                        let notify = Arc::new(Notify::new());
214                        transports.insert(dest.clone(), Either::Left(Arc::clone(&notify)));
215                        tracing::debug!("🔄 Inserted Connecting state for: {:?}", dest);
216                        Arc::clone(&notify)
217                    }
218                }
219            };
220
221            // Check if we are the creator (notify was just created)
222            let is_creator = {
223                let transports = self.transports.read().await;
224                matches!(transports.get(dest), Some(Either::Left(n)) if Arc::ptr_eq(n, &notify))
225            };
226
227            if !is_creator {
228                // Wait for the actual creator
229                tracing::debug!("⏳ Another thread is creating connection: {:?}", dest);
230                // notify 加超时 10秒
231                match tokio::time::timeout(Duration::from_secs(10), notify.notified()).await {
232                    Ok(_) => continue,
233                    Err(e) => {
234                        return Err(NetworkError::TimeoutError(format!(
235                            "Timeout waiting for notification: {:?} {}",
236                            dest, e
237                        )));
238                    }
239                }
240            }
241
242            // 3. We are the creator - create connections OUTSIDE lock
243            tracing::info!("🚀 Creating new connection for: {:?}", dest);
244
245            // Create cancellation token for this connection attempt
246            let cancel_token = CancellationToken::new();
247            {
248                let mut tokens = self.pending_tokens.lock().await;
249                tokens.insert(dest.clone(), cancel_token.clone());
250            }
251
252            let result = async {
253                let connections = self
254                    .conn_factory
255                    .create_connections_with_cancel(dest, Some(cancel_token.clone()))
256                    .await?;
257
258                if connections.is_empty() {
259                    return Err(NetworkError::ConfigurationError(format!(
260                        "Connection factory returned no connections: {dest:?}"
261                    )));
262                }
263
264                tracing::info!(
265                    "✨ Creating DestTransport: {:?} ({} connections)",
266                    dest,
267                    connections.len()
268                );
269                let transport = DestTransport::new(dest.clone(), connections).await?;
270                Ok(Arc::new(transport))
271            }
272            .await;
273
274            // 4. Clean up pending token (connection attempt finished)
275            {
276                let mut tokens = self.pending_tokens.lock().await;
277                tokens.remove(dest);
278            }
279
280            // 5. Update state and notify waiters
281            let mut transports = self.transports.write().await;
282
283            match result {
284                Ok(transport) => {
285                    tracing::info!("✅ Connection established: {:?}", dest);
286                    transports.insert(dest.clone(), Either::Right(Arc::clone(&transport)));
287                    drop(transports);
288                    self.spawn_ready_monitor(dest.clone(), Arc::clone(&transport));
289                    notify.notify_waiters();
290                    return Ok(transport);
291                }
292                Err(e) => {
293                    tracing::error!("❌ Connection failed: {:?}: {}", dest, e);
294                    transports.remove(dest);
295                    drop(transports);
296                    notify.notify_waiters();
297                    return Err(e);
298                }
299            }
300        }
301    }
302
303    /// Send message to specified Dest
304    ///
305    /// # Arguments
306    /// - `dest`: Target destination
307    /// - `payload_type`: Message type
308    /// - `data`: Message data
309    ///
310    /// # Example
311    ///
312    /// ```rust,ignore
313    /// mgr.send(&dest, PayloadType::RpcSignal, b"hello").await?;
314    /// ```
315    #[cfg_attr(
316        feature = "opentelemetry",
317        tracing::instrument(skip_all, name = "OutprocTransportManager.send")
318    )]
319    pub async fn send(
320        &self,
321        dest: &Dest,
322        payload_type: PayloadType,
323        data: &[u8],
324    ) -> NetworkResult<()> {
325        tracing::debug!(
326            "📤 [OutprocTransportManager] Sending to {:?}: type={:?}, size={}",
327            dest,
328            payload_type,
329            data.len()
330        );
331
332        // Get or create DestTransport for this Dest
333        let transport = self.get_or_create_transport(dest).await?;
334
335        // Send through DestTransport
336        transport.send(payload_type, data).await
337    }
338
339    /// Close DestTransport for specified Dest
340    ///
341    /// Called by OutprocOutGate when connection events indicate cleanup is needed.
342    /// This triggers the cleanup chain: OutprocTransportManager → DestTransport → WirePool
343    ///
344    /// # Arguments
345    /// - `dest`: Target destination
346    pub async fn close_transport(&self, dest: &Dest) -> NetworkResult<()> {
347        // 1. Mark as closing
348        self.closing_peers.write().await.insert(dest.clone());
349
350        // 2. Cancel in-progress connection creation
351        {
352            let mut tokens = self.pending_tokens.lock().await;
353            if let Some(token) = tokens.remove(dest) {
354                tracing::info!("🚫 Cancelling in-progress connection for {:?}", dest);
355                token.cancel();
356            }
357        }
358
359        // 3. Remove and close the transport
360        let mut transports = self.transports.write().await;
361
362        if let Some(state) = transports.remove(dest) {
363            drop(transports); // Release lock before calling close()
364
365            match state {
366                Either::Right(transport) => {
367                    tracing::info!("🔌 Closing DestTransport: {:?}", dest);
368                    transport.close().await?;
369                }
370                Either::Left(notify) => {
371                    tracing::debug!("⏸️ Removed Connecting state for: {:?}", dest);
372                    // Notify waiters that connection was cancelled
373                    notify.notify_waiters();
374                }
375            }
376        }
377
378        // 4. Remove from closing set after cleanup completes
379        self.closing_peers.write().await.remove(dest);
380
381        Ok(())
382    }
383
384    /// Close all DestTransports
385    pub async fn close_all(&self) -> NetworkResult<()> {
386        let mut transports = self.transports.write().await;
387
388        tracing::info!(
389            "🔌 Closing all DestTransports (count: {})",
390            transports.len()
391        );
392
393        for (dest, state) in transports.drain() {
394            match state {
395                Either::Right(transport) => {
396                    if let Err(e) = transport.close().await {
397                        tracing::warn!("❌ Failed to close DestTransport {:?}: {}", dest, e);
398                    }
399                }
400                Either::Left(_notify) => {
401                    tracing::debug!("⏸️ Skipped Connecting state for: {:?}", dest);
402                }
403            }
404        }
405
406        Ok(())
407    }
408
409    /// Get count of currently managed Dests
410    pub async fn dest_count(&self) -> usize {
411        self.transports.read().await.len()
412    }
413
    /// Get local Actor ID
    ///
    /// # Returns
    /// - Reference to the `ActrId` this manager was constructed with
    #[inline]
    pub fn local_id(&self) -> &ActrId {
        &self.local_id
    }
419
420    /// List all connected Dests
421    pub async fn list_dests(&self) -> Vec<Dest> {
422        self.transports.read().await.keys().cloned().collect()
423    }
424
425    /// Check if connection to specified Dest exists
426    pub async fn has_dest(&self, dest: &Dest) -> bool {
427        self.transports.read().await.contains_key(dest)
428    }
429
430    /// Monitor a DestTransport ready-set and remove it when all connections are gone.
431    fn spawn_ready_monitor(&self, dest: Dest, transport: Arc<DestTransport>) {
432        let transports = Arc::clone(&self.transports);
433        tokio::spawn(async move {
434            let mut rx = transport.watch_ready();
435            let mut had_ready = !rx.borrow().is_empty();
436
437            loop {
438                if rx.changed().await.is_err() {
439                    break;
440                }
441                let ready = rx.borrow().clone();
442
443                if ready.is_empty() && had_ready {
444                    // Only remove if the same transport is still mapped.
445                    let mut map = transports.write().await;
446                    let matched = matches!(
447                        map.get(&dest),
448                        Some(Either::Right(existing)) if Arc::ptr_eq(existing, &transport)
449                    );
450                    if matched {
451                        map.remove(&dest);
452                        drop(map);
453
454                        tracing::warn!(
455                            "🧹 Removing DestTransport for {:?} after all connections closed",
456                            dest
457                        );
458                        if let Err(e) = transport.close().await {
459                            tracing::warn!("⚠️ Failed to close DestTransport {:?}: {}", dest, e);
460                        }
461                    }
462                    break;
463                }
464
465                if !ready.is_empty() {
466                    had_ready = true;
467                }
468            }
469        });
470    }
471
472    /// Spawn health checker background task with smart reconnect
473    ///
474    /// Periodically checks all DestTransport health status:
475    /// - If some connections failed → trigger smart reconnect (reuse working connections)
476    /// - If all connections failed → remove entire DestTransport
477    ///
478    /// # Arguments
479    /// - `interval`: Health check interval (recommended: 10-30 seconds)
480    ///
481    /// # Returns
482    /// - JoinHandle for the background task (can be used to cancel)
483    ///
484    /// # Example
485    /// ```rust,ignore
486    /// let mgr = Arc::new(OutprocTransportManager::new(local_id, factory));
487    /// let health_check_handle = mgr.spawn_health_checker(Duration::from_secs(10));
488    /// ```
489    pub fn spawn_health_checker(&self, interval: Duration) -> tokio::task::JoinHandle<()> {
490        let transports = Arc::clone(&self.transports);
491        let conn_factory = Arc::clone(&self.conn_factory);
492
493        tokio::spawn(async move {
494            let mut interval_timer = tokio::time::interval(interval);
495            interval_timer.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
496
497            loop {
498                interval_timer.tick().await;
499
500                // Collect snapshot of connected Dests first (no async under lock)
501                let snapshot: Vec<(Dest, Arc<DestTransport>)> = {
502                    let transports_read = transports.read().await;
503
504                    transports_read
505                        .iter()
506                        .filter_map(|(dest, state)| {
507                            // Only check Connected transports, skip Connecting
508                            if let Either::Right(transport) = state {
509                                Some((dest.clone(), Arc::clone(transport)))
510                            } else {
511                                None
512                            }
513                        })
514                        .collect()
515                };
516
517                // Process each Dest outside of the lock
518                for (dest_clone, transport) in snapshot {
519                    let healthy = transport.has_healthy_connection().await;
520
521                    if !healthy {
522                        // All connections failed - schedule for removal
523                        tracing::warn!(
524                            "💀 All connections failed for {:?}, will remove",
525                            dest_clone
526                        );
527
528                        // Remove entire DestTransport
529                        let mut transports_write = transports.write().await;
530                        if let Some(Either::Right(transport)) = transports_write.remove(&dest_clone)
531                        {
532                            tracing::info!(
533                                "🗑️  Removing completely failed DestTransport: {:?}",
534                                dest_clone
535                            );
536                            // Drop lock before awaiting close
537                            drop(transports_write);
538
539                            if let Err(e) = transport.close().await {
540                                tracing::warn!(
541                                    "❌ Failed to close DestTransport {:?}: {}",
542                                    dest_clone,
543                                    e
544                                );
545                            }
546                        } else {
547                            // State changed between snapshot and removal; skip safely
548                            drop(transports_write);
549                        }
550                    } else {
551                        // At least one connection is working
552                        // Try to reconnect failed ones (smart reconnect)
553                        tracing::debug!("🔄 Triggering smart reconnect for: {:?}", dest_clone);
554                        if let Err(e) = transport
555                            .retry_failed_connections(&dest_clone, conn_factory.as_ref())
556                            .await
557                        {
558                            tracing::warn!("❌ Smart reconnect failed for {:?}: {}", dest_clone, e);
559                        }
560                    }
561                }
562            }
563        })
564    }
565}
566
// Drop only emits a debug log; it does not close any transport.
impl Drop for OutprocTransportManager {
    fn drop(&mut self) {
        tracing::debug!("🗑️  OutprocTransportManager dropped");
        // Note: closing transports is async and cannot be awaited from drop();
        // callers must invoke close_all() explicitly before dropping the manager.
    }
}
573
#[cfg(test)]
mod tests {
    use super::*;

    /// Wire builder stub that never produces a connection.
    struct TestFactory;

    #[async_trait]
    impl WireBuilder for TestFactory {
        async fn create_connections(&self, _dest: &Dest) -> NetworkResult<Vec<WireHandle>> {
            // Real usage requires actual connections; the tests below never connect
            Ok(Vec::new())
        }
    }

    fn create_test_factory() -> Arc<dyn WireBuilder> {
        Arc::new(TestFactory)
    }

    /// Manager wired to the stub factory, identified by `local_id`.
    fn manager_with_id(local_id: ActrId) -> OutprocTransportManager {
        OutprocTransportManager::new(local_id, create_test_factory())
    }

    #[tokio::test]
    async fn test_transport_manager_creation() {
        let local_id = ActrId::default();
        let mgr = manager_with_id(local_id.clone());

        // A fresh manager tracks nothing and remembers its own id
        assert_eq!(mgr.dest_count().await, 0);
        assert_eq!(mgr.local_id(), &local_id);
    }

    #[tokio::test]
    async fn test_list_dests() {
        let mgr = manager_with_id(ActrId::default());

        let dests = mgr.list_dests().await;
        assert_eq!(dests.len(), 0);
    }

    #[tokio::test]
    async fn test_has_dest() {
        let mgr = manager_with_id(ActrId::default());

        let dest = Dest::shell();
        assert!(!mgr.has_dest(&dest).await);
    }
}
619        assert!(!mgr.has_dest(&dest).await);
620    }
621}