use std::sync::{
Arc,
atomic::{AtomicU64, Ordering},
};
use std::time::{Duration, Instant};
use crate::transport::PeerTransport;
use crate::wire::webrtc::{SignalingClient, WebRtcCoordinator};
use tokio::sync::{mpsc, oneshot};
use tokio_util::sync::CancellationToken;
use super::connection_supervisor::{ConnectionFact, ConnectionSupervisor};
/// How long the reconciler keeps coalescing follow-up requests after the first one arrives.
const NETWORK_EVENT_SETTLE_WINDOW: Duration = Duration::from_millis(400);
/// Default time a `NetworkEventHandle` waits for the reconciler's answer.
const NETWORK_EVENT_RESULT_TIMEOUT: Duration = Duration::from_millis(5_000);
/// Timeout passed to each signaling `probe_alive` call.
const SIGNALING_PROBE_TIMEOUT: Duration = Duration::from_millis(1_000);
/// Background-duration threshold in milliseconds (30 s). Consumed outside this chunk —
/// presumably compared against `AppLifecycleState::Foreground::background_duration_ms`; confirm.
pub(super) const LONG_BACKGROUND_RECONNECT_THRESHOLD_MS: u64 = 30 * 1_000;
/// Counter handing out ids that correlate the `network_event.handle.*` log lines of one request.
static NEXT_NETWORK_EVENT_REQUEST_ID: AtomicU64 = AtomicU64::new(1_u64);
/// Point-in-time description of the device's network path.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct NetworkSnapshot {
    /// Sequence number of this snapshot (assigned by the producer; ordering semantics not
    /// enforced here).
    pub sequence: u64,
    /// Whether the path is usable.
    pub availability: NetworkAvailability,
    /// Transports the path is using.
    pub transport: NetworkTransportFlags,
    /// Path is flagged as expensive (e.g. metered), as reported by the platform.
    pub is_expensive: bool,
    /// Path is flagged as constrained, as reported by the platform.
    pub is_constrained: bool,
}
impl NetworkSnapshot {
    /// `true` only when availability is explicitly `Unavailable` (`Unknown` is not offline).
    pub fn is_offline(&self) -> bool {
        self.availability == NetworkAvailability::Unavailable
    }

    /// `true` only when availability is explicitly `Available` (`Unknown` does not restore).
    pub fn should_restore(&self) -> bool {
        self.availability == NetworkAvailability::Available
    }
}
/// Tri-state reachability of the current network path.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum NetworkAvailability {
    /// Not yet determined; neither offline nor restorable.
    Unknown,
    /// Path is usable; `NetworkSnapshot::should_restore` returns `true`.
    Available,
    /// Path is not usable; `NetworkSnapshot::is_offline` returns `true`.
    Unavailable,
}
/// Independent per-transport flags for a network path; any combination may be set.
/// `Default` yields all flags cleared.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct NetworkTransportFlags {
    /// Path uses Wi-Fi.
    pub wifi: bool,
    /// Path uses cellular data.
    pub cellular: bool,
    /// Path uses wired Ethernet.
    pub ethernet: bool,
    /// Path goes through a VPN.
    pub vpn: bool,
    /// Path uses some other transport.
    pub other: bool,
}
/// Application lifecycle transition reported by the host app.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AppLifecycleState {
    /// App moved to the background.
    Background,
    /// App returned to the foreground; carries the time spent in background
    /// (milliseconds, per the field name).
    Foreground { background_duration_ms: u64 },
}
/// Why a connection cleanup was requested.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CleanupReason {
    /// Application is shutting down.
    AppTerminating,
    /// User signed out.
    UserLogout,
    /// Connections are suspected to be stale.
    StaleConnectionSuspected,
    /// Explicit reset requested.
    ManualReset,
}
/// Why a forced reconnect was requested.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ReconnectReason {
    /// The network path changed.
    NetworkPathChanged,
    /// The app returned after a long time in background.
    LongBackground,
    /// A connectivity probe failed.
    ProbeFailed,
    /// Explicit reconnect requested.
    ManualReconnect,
    /// Connections are suspected to be stale.
    StaleConnectionSuspected,
}
/// Input event fed to the network event reconciler.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum NetworkEvent {
    /// The platform reported a new network path snapshot.
    NetworkPathChanged { snapshot: NetworkSnapshot },
    /// The application moved between background and foreground.
    AppLifecycleChanged { state: AppLifecycleState },
    /// Caller asked for all connections to be torn down.
    CleanupConnections { reason: CleanupReason },
    /// Caller asked for a full reconnect.
    ForceReconnect { reason: ReconnectReason },
}
/// Single action chosen for a batch of network events and executed by a
/// `NetworkEventProcessor`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NetworkRecoveryAction {
    /// Nothing to do.
    Noop,
    /// Treat the network as lost.
    Offline,
    /// Check connectivity (the default processor restores on probe failure).
    Probe,
    /// Restore connections.
    Restore,
    /// Tear connections down without reconnecting.
    CleanupOnly,
    /// Tear down, then reconnect.
    ForceReconnect,
}
/// Outcome reported back to the caller that submitted `event`.
#[derive(Debug, Clone)]
pub struct NetworkEventResult {
    /// The event this result answers.
    pub event: NetworkEvent,
    /// Whether the batch action containing this event succeeded.
    pub success: bool,
    /// Error message when `success` is `false`.
    pub error: Option<String>,
    /// Time the batch action took, in milliseconds.
    pub duration_ms: u64,
}
impl NetworkEventResult {
    /// Builds a successful result (`error` is `None`).
    pub fn success(event: NetworkEvent, duration_ms: u64) -> Self {
        Self {
            event,
            duration_ms,
            success: true,
            error: None,
        }
    }

    /// Builds a failed result carrying `error`.
    pub fn failure(event: NetworkEvent, error: String, duration_ms: u64) -> Self {
        Self {
            success: false,
            error: Some(error),
            ..Self::success(event, duration_ms)
        }
    }
}
/// Executes the side effects behind network recovery decisions.
///
/// Implementors must provide the four primitive handlers. `probe_connectivity`,
/// `force_reconnect` and `process_network_recovery_action` have defaults built on them.
#[async_trait::async_trait]
pub trait NetworkEventProcessor: Send + Sync {
    /// Handles the network becoming available.
    async fn process_network_available(&self) -> Result<(), String>;

    /// Handles the network being lost.
    async fn process_network_lost(&self) -> Result<(), String>;

    /// Handles a change of network transport type.
    async fn process_network_type_changed(
        &self,
        is_wifi: bool,
        is_cellular: bool,
    ) -> Result<(), String>;

    /// Tears down existing connections.
    async fn cleanup_connections(&self) -> Result<(), String>;

    /// Checks connectivity. The default implementation always reports success.
    async fn probe_connectivity(&self) -> Result<(), String> {
        Ok(())
    }

    /// Cleans up, then runs the "network available" path. Stops at the first error.
    async fn force_reconnect(&self) -> Result<(), String> {
        self.cleanup_connections().await?;
        self.process_network_available().await
    }

    /// Dispatches a recovery action to the matching primitive handler.
    async fn process_network_recovery_action(
        &self,
        action: NetworkRecoveryAction,
    ) -> Result<(), String> {
        match action {
            NetworkRecoveryAction::ForceReconnect => self.force_reconnect().await,
            NetworkRecoveryAction::CleanupOnly => self.cleanup_connections().await,
            NetworkRecoveryAction::Restore => self.process_network_available().await,
            NetworkRecoveryAction::Probe => self.probe_connectivity().await,
            NetworkRecoveryAction::Offline => self.process_network_lost().await,
            NetworkRecoveryAction::Noop => Ok(()),
        }
    }
}
/// Debounce settings for `DefaultNetworkEventProcessor`.
#[derive(Debug, Clone)]
pub struct DebounceConfig {
    /// Minimum spacing between two accepted events of the same kind.
    pub window: Duration,
}
impl Default for DebounceConfig {
    /// A two-second debounce window.
    fn default() -> Self {
        let window = Duration::from_millis(2_000);
        Self { window }
    }
}
/// Time of the last accepted event per debounced kind (`None` until the first one).
#[derive(Debug)]
struct DebounceState {
    /// Last accepted "network available" event.
    last_available: tokio::sync::Mutex<Option<Instant>>,
    /// Last accepted "network lost" event.
    last_lost: tokio::sync::Mutex<Option<Instant>>,
    /// Last accepted "network type changed" event.
    last_type_changed: tokio::sync::Mutex<Option<Instant>>,
}
impl DebounceState {
    /// Creates state with no recorded events, so the first event of each kind is accepted.
    fn new() -> Self {
        let unset = || tokio::sync::Mutex::new(None);
        Self {
            last_available: unset(),
            last_lost: unset(),
            last_type_changed: unset(),
        }
    }
}
/// Event kinds that are debounced independently of one another.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum DebounceEvent {
    /// Network became available.
    Available,
    /// Network was lost.
    Lost,
    /// Network transport type changed.
    TypeChanged,
}
/// Bookkeeping for signaling (re)connects.
#[derive(Debug)]
struct SignalingRecoveryState {
    /// Held for the duration of a health check / reconnect so those run one at a time.
    connect_lock: tokio::sync::Mutex<()>,
    /// When the last successful `connect_once` finished (`None` before the first one).
    last_successful_connect: tokio::sync::Mutex<Option<Instant>>,
}
impl SignalingRecoveryState {
    /// Creates an unlocked state with no recorded successful connect.
    fn new() -> Self {
        let last_successful_connect = tokio::sync::Mutex::new(None);
        let connect_lock = tokio::sync::Mutex::new(());
        Self {
            connect_lock,
            last_successful_connect,
        }
    }
}
/// Default [`NetworkEventProcessor`]: manages the signaling WebSocket, WebRTC recovery
/// (ICE restarts) and connection cleanup, with per-kind event debouncing.
pub struct DefaultNetworkEventProcessor {
    /// Signaling client that is connected, probed and disconnected during recovery.
    signaling_client: Arc<dyn SignalingClient>,
    /// Optional WebRTC coordinator; when absent only signaling is managed.
    webrtc_coordinator: Option<Arc<WebRtcCoordinator>>,
    /// Optional peer transport that `cleanup_connections` closes.
    peer_transport: Option<Arc<PeerTransport>>,
    /// Debounce window applied per event kind.
    debounce_config: DebounceConfig,
    /// Last accepted timestamp per event kind.
    debounce_state: Arc<DebounceState>,
    /// Connect lock and last-success bookkeeping for signaling.
    recovery_state: Arc<SignalingRecoveryState>,
}
impl DefaultNetworkEventProcessor {
    /// Creates a processor with the default debounce window and no `PeerTransport`.
    pub fn new(
        signaling_client: Arc<dyn SignalingClient>,
        webrtc_coordinator: Option<Arc<WebRtcCoordinator>>,
    ) -> Self {
        Self::new_with_debounce_and_peer_transport(
            signaling_client,
            webrtc_coordinator,
            DebounceConfig::default(),
            None,
        )
    }

    /// Creates a processor with a custom debounce window and no `PeerTransport`.
    pub fn new_with_debounce(
        signaling_client: Arc<dyn SignalingClient>,
        webrtc_coordinator: Option<Arc<WebRtcCoordinator>>,
        debounce_config: DebounceConfig,
    ) -> Self {
        Self::new_with_debounce_and_peer_transport(
            signaling_client,
            webrtc_coordinator,
            debounce_config,
            None,
        )
    }

    /// Creates a processor with the default debounce window and an optional `PeerTransport`
    /// that `cleanup_connections` will close.
    pub(crate) fn new_with_peer_transport(
        signaling_client: Arc<dyn SignalingClient>,
        webrtc_coordinator: Option<Arc<WebRtcCoordinator>>,
        peer_transport: Option<Arc<PeerTransport>>,
    ) -> Self {
        Self::new_with_debounce_and_peer_transport(
            signaling_client,
            webrtc_coordinator,
            DebounceConfig::default(),
            peer_transport,
        )
    }

    /// Fully parameterised constructor; every other constructor delegates here.
    pub(crate) fn new_with_debounce_and_peer_transport(
        signaling_client: Arc<dyn SignalingClient>,
        webrtc_coordinator: Option<Arc<WebRtcCoordinator>>,
        debounce_config: DebounceConfig,
        peer_transport: Option<Arc<PeerTransport>>,
    ) -> Self {
        Self {
            signaling_client,
            webrtc_coordinator,
            peer_transport,
            debounce_config,
            debounce_state: Arc::new(DebounceState::new()),
            recovery_state: Arc::new(SignalingRecoveryState::new()),
        }
    }

    /// Returns the last-accepted timestamp slot and the log label for `event`.
    fn debounce_slot(
        &self,
        event: DebounceEvent,
    ) -> (&tokio::sync::Mutex<Option<Instant>>, &'static str) {
        let state = &self.debounce_state;
        match event {
            DebounceEvent::Available => (&state.last_available, "Available"),
            DebounceEvent::Lost => (&state.last_lost, "Lost"),
            DebounceEvent::TypeChanged => (&state.last_type_changed, "TypeChanged"),
        }
    }

    /// Decides whether an event of kind `event` should be handled now.
    ///
    /// Returns `false` (and leaves the stored timestamp untouched) if an event of the same
    /// kind was accepted less than `debounce_config.window` ago. Otherwise it records the
    /// current time and returns `true`.
    async fn should_process_event(&self, event: DebounceEvent) -> bool {
        // Timestamp is taken before waiting on the lock, matching the event's arrival time.
        let now = Instant::now();
        let (slot, label) = self.debounce_slot(event);
        let mut last = slot.lock().await;
        if let Some(last_time) = *last {
            let elapsed = now.duration_since(last_time);
            if elapsed < self.debounce_config.window {
                tracing::debug!(
                    "⏸️ Debouncing Network {} event (last event was {:?} ago)",
                    label,
                    elapsed
                );
                return false;
            }
        }
        *last = Some(now);
        true
    }

    /// Performs one `connect_once` and records the success time.
    ///
    /// On failure, logs and returns `"<failure_prefix>: <error>"`.
    async fn connect_signaling_and_record(&self, failure_prefix: &str) -> Result<(), String> {
        self.signaling_client.connect_once().await.map_err(|e| {
            let err_msg = format!("{}: {}", failure_prefix, e);
            tracing::error!("❌ {}", err_msg);
            err_msg
        })?;
        *self.recovery_state.last_successful_connect.lock().await = Some(Instant::now());
        Ok(())
    }

    /// Makes sure signaling is connected and responsive, with at most one connect attempt.
    ///
    /// - If signaling is disconnected, it connects once.
    /// - If signaling is connected, it probes with `SIGNALING_PROBE_TIMEOUT`. A failed probe
    ///   triggers a best-effort disconnect followed by one reconnect.
    ///
    /// Runs under `recovery_state.connect_lock`, so concurrent calls are serialized.
    ///
    /// # Errors
    /// Returns the connect/rebuild error message if the single connect attempt fails.
    async fn ensure_signaling_healthy_once(&self, reason: &str) -> Result<(), String> {
        let _guard = self.recovery_state.connect_lock.lock().await;

        if !self.signaling_client.is_connected() {
            tracing::info!(reason = reason, "🔄 Connecting signaling");
            self.connect_signaling_and_record("WebSocket connect failed")
                .await?;
            tracing::info!(reason = reason, "✅ Signaling connected");
            return Ok(());
        }

        tracing::debug!(
            reason = reason,
            timeout_ms = SIGNALING_PROBE_TIMEOUT.as_millis() as u64,
            "🔎 Probing existing signaling WebSocket"
        );
        match self
            .signaling_client
            .probe_alive(SIGNALING_PROBE_TIMEOUT)
            .await
        {
            Ok(()) => {
                tracing::debug!(
                    reason = reason,
                    "✅ Signaling probe succeeded; keeping existing WebSocket"
                );
                Ok(())
            }
            Err(e) => {
                tracing::warn!(
                    reason = reason,
                    "⚠️ Signaling probe failed; rebuilding WebSocket: {}",
                    e
                );
                // Best effort: a failed disconnect must not prevent the rebuild attempt.
                if let Err(disconnect_err) = self.signaling_client.disconnect().await {
                    tracing::warn!(
                        reason = reason,
                        "⚠️ Failed to disconnect unhealthy signaling before rebuild: {}",
                        disconnect_err
                    );
                }
                tracing::info!(reason = reason, "🔄 Rebuilding signaling: connecting");
                self.connect_signaling_and_record("WebSocket rebuild failed")
                    .await?;
                tracing::info!(reason = reason, "✅ Signaling rebuilt");
                Ok(())
            }
        }
    }

    /// Enters network recovery on the coordinator, ensures signaling is healthy, then asks
    /// the coordinator to restart recovering connections.
    ///
    /// # Errors
    /// Propagates the signaling error. In that case no ICE restart is triggered.
    async fn restore_signaling_and_webrtc(&self, reason: &str) -> Result<(), String> {
        // The returned targets are only used below to choose the log message.
        let recovery_targets = match self.webrtc_coordinator.as_ref() {
            Some(coordinator) => coordinator.begin_network_recovery(reason).await,
            None => Vec::new(),
        };
        self.ensure_signaling_healthy_once(reason).await?;
        if let Some(coordinator) = self.webrtc_coordinator.clone() {
            if recovery_targets.is_empty() {
                tracing::info!("♻️ Resuming ICE restart for peers already in network recovery");
            } else {
                tracing::info!("♻️ Triggering ICE restart for recovering connections...");
            }
            coordinator.restart_network_recovery_connections().await;
        }
        Ok(())
    }

    /// Probes connectivity and falls back to a full restore if the probe fails.
    async fn probe_or_restore(&self, reason: &str) -> Result<(), String> {
        match self.probe_connectivity().await {
            Ok(()) => Ok(()),
            Err(e) => {
                tracing::warn!(
                    reason = reason,
                    "Connectivity probe failed; restoring connections: {}",
                    e
                );
                self.restore_signaling_and_webrtc(reason).await
            }
        }
    }

    /// Handles going offline.
    ///
    /// Enters network recovery on the coordinator, drops pending ICE restarts, and
    /// disconnects signaling. Always returns `Ok`; a disconnect failure is logged, not
    /// propagated.
    async fn process_offline(&self) -> Result<(), String> {
        tracing::info!("📱 Processing: Network offline");
        if let Some(coordinator) = self.webrtc_coordinator.as_ref() {
            coordinator.begin_network_recovery("NetworkLost").await;
            tracing::info!("🧹 Clearing pending ICE restart attempts...");
            coordinator.clear_pending_restarts().await;
        }
        tracing::info!("🔌 Disconnecting WebSocket...");
        // Best effort: the socket may already be gone when the network drops, but surface it.
        if let Err(e) = self.signaling_client.disconnect().await {
            tracing::warn!("⚠️ Failed to disconnect WebSocket while going offline: {}", e);
        }
        Ok(())
    }
}
#[async_trait::async_trait]
impl NetworkEventProcessor for DefaultNetworkEventProcessor {
    /// Restores signaling and WebRTC. The event is skipped only when it is debounced
    /// and signaling is already connected.
    async fn process_network_available(&self) -> Result<(), String> {
        let debounced = !self.should_process_event(DebounceEvent::Available).await;
        if debounced && self.signaling_client.is_connected() {
            return Ok(());
        }
        tracing::info!("📱 Processing: Network available");
        self.restore_signaling_and_webrtc("NetworkAvailable").await
    }

    /// Runs the offline path unless a "lost" event was accepted within the debounce window.
    async fn process_network_lost(&self) -> Result<(), String> {
        if self.should_process_event(DebounceEvent::Lost).await {
            self.process_offline().await
        } else {
            Ok(())
        }
    }

    /// Restores signaling and WebRTC after a transport change. Debounced events are skipped
    /// only while signaling is connected.
    async fn process_network_type_changed(
        &self,
        is_wifi: bool,
        is_cellular: bool,
    ) -> Result<(), String> {
        let debounced = !self
            .should_process_event(DebounceEvent::TypeChanged)
            .await;
        if debounced && self.signaling_client.is_connected() {
            return Ok(());
        }
        tracing::info!(
            "📱 Processing: Network type changed (WiFi={}, Cellular={})",
            is_wifi,
            is_cellular
        );
        self.restore_signaling_and_webrtc("NetworkTypeChanged").await
    }

    /// Best-effort teardown of all connections. Each step's failure is logged and the next
    /// step still runs; always returns `Ok`.
    async fn cleanup_connections(&self) -> Result<(), String> {
        let coordinator = self.webrtc_coordinator.as_ref();
        // Bound to a named variable so it lives until this method returns
        // (presumably an RAII guard on the coordinator — confirm).
        let _cleanup_guard = coordinator.map(|c| c.cleanup_guard());
        tracing::info!("🧹 Manually cleaning up all connections...");

        if let Some(c) = coordinator {
            tracing::info!("♻️ Clearing pending ICE restart attempts...");
            c.clear_pending_restarts().await;
        }

        if let Some(transport) = self.peer_transport.as_ref() {
            tracing::info!("🔻 Closing all PeerTransport connections...");
            match transport.close_all().await {
                Ok(_) => tracing::info!("✅ All PeerTransport connections closed"),
                Err(e) => {
                    let err_msg = format!("Failed to close peer transports: {}", e);
                    tracing::warn!("⚠️ {}", err_msg);
                }
            }
        }

        if let Some(c) = coordinator {
            tracing::info!("🔻 Closing all WebRTC peer connections...");
            match c.close_all_peers().await {
                Ok(_) => tracing::info!("✅ All WebRTC peer connections closed"),
                Err(e) => {
                    let err_msg = format!("Failed to close all peers: {}", e);
                    tracing::warn!("⚠️ {}", err_msg);
                }
            }
        }

        tracing::info!("🔌 Disconnecting WebSocket...");
        if let Err(e) = self.signaling_client.disconnect().await {
            let err_msg = format!("Failed to disconnect WebSocket: {}", e);
            tracing::warn!("⚠️ {}", err_msg);
        } else {
            tracing::info!("✅ WebSocket disconnected successfully");
        }
        tracing::info!("✅ Connection cleanup completed");
        Ok(())
    }

    /// Probes the signaling connection with `SIGNALING_PROBE_TIMEOUT`.
    async fn probe_connectivity(&self) -> Result<(), String> {
        let probe = self
            .signaling_client
            .probe_alive(SIGNALING_PROBE_TIMEOUT)
            .await;
        probe.map_err(|e| format!("Signaling probe failed: {}", e))
    }

    /// Cleans up everything, then restores signaling and WebRTC.
    async fn force_reconnect(&self) -> Result<(), String> {
        self.cleanup_connections().await?;
        self.restore_signaling_and_webrtc("ForceReconnect").await
    }

    /// Dispatches a batched recovery action. `Probe` falls back to a restore on failure, and
    /// `Offline` bypasses the "lost" debounce.
    async fn process_network_recovery_action(
        &self,
        action: NetworkRecoveryAction,
    ) -> Result<(), String> {
        match action {
            NetworkRecoveryAction::ForceReconnect => self.force_reconnect().await,
            NetworkRecoveryAction::CleanupOnly => self.cleanup_connections().await,
            NetworkRecoveryAction::Restore => {
                self.restore_signaling_and_webrtc("NetworkEventBatch").await
            }
            NetworkRecoveryAction::Probe => self.probe_or_restore("Probe").await,
            NetworkRecoveryAction::Offline => self.process_offline().await,
            NetworkRecoveryAction::Noop => Ok(()),
        }
    }
}
/// Reduces a batch of events to the single recovery action to execute.
///
/// Public entry point that delegates to [`ConnectionSupervisor::select_action`], so every
/// caller shares one decision policy.
pub fn select_network_recovery_action(events: &[NetworkEvent]) -> NetworkRecoveryAction {
    ConnectionSupervisor::select_action(events)
}
pub async fn process_network_event_batch(
events: Vec<NetworkEvent>,
processor: Arc<dyn NetworkEventProcessor>,
) -> Vec<NetworkEventResult> {
if events.is_empty() {
return Vec::new();
}
let action = select_network_recovery_action(&events);
let start = Instant::now();
tracing::info!(
event_count = events.len(),
action = ?action,
"network_event.action.start"
);
let result = processor.process_network_recovery_action(action).await;
let duration_ms = start.elapsed().as_millis() as u64;
match &result {
Ok(()) => tracing::info!(
event_count = events.len(),
action = ?action,
duration_ms,
"network_event.action.completed"
),
Err(e) => tracing::warn!(
event_count = events.len(),
action = ?action,
duration_ms,
error = %e,
"network_event.action.completed"
),
}
events
.into_iter()
.map(|event| match &result {
Ok(()) => NetworkEventResult::success(event, duration_ms),
Err(e) => NetworkEventResult::failure(event, e.clone(), duration_ms),
})
.collect()
}
/// A queued network event plus the channel on which its result is delivered.
pub struct NetworkEventRequest {
    /// Event to reconcile.
    pub event: NetworkEvent,
    /// Receives exactly one result once the batch containing `event` has been processed.
    pub result_tx: oneshot::Sender<NetworkEventResult>,
}
pub async fn run_network_event_reconciler(
mut event_rx: mpsc::Receiver<NetworkEventRequest>,
processor: Arc<dyn NetworkEventProcessor>,
shutdown_token: CancellationToken,
) {
tracing::info!("🔄 Network event reconciler started");
loop {
tokio::select! {
Some(first_request) = event_rx.recv() => {
tracing::debug!(
event = ?first_request.event,
"network_event.reconciler.received"
);
let mut requests = vec![first_request];
let settle = tokio::time::sleep(NETWORK_EVENT_SETTLE_WINDOW);
tokio::pin!(settle);
loop {
tokio::select! {
Some(next_request) = event_rx.recv() => {
tracing::debug!(
event = ?next_request.event,
"network_event.reconciler.coalesced"
);
requests.push(next_request);
}
_ = &mut settle => {
break;
}
_ = shutdown_token.cancelled() => {
tracing::info!("🛑 Network event reconciler shutting down");
return;
}
else => {
break;
}
}
}
while let Ok(next_request) = event_rx.try_recv() {
tracing::debug!(
event = ?next_request.event,
"network_event.reconciler.coalesced"
);
requests.push(next_request);
}
let events = requests
.iter()
.map(|request| request.event.clone())
.collect::<Vec<_>>();
let action = select_network_recovery_action(&events);
let facts = events
.iter()
.map(ConnectionFact::from_network_event)
.collect::<Vec<_>>();
tracing::info!(
event_count = events.len(),
action = ?action,
events = ?events,
facts = ?facts,
settle_window_ms = NETWORK_EVENT_SETTLE_WINDOW.as_millis() as u64,
"network_event.reconciler.batch_reconciled"
);
let results = process_network_event_batch(events, processor.clone()).await;
for (request, result) in requests.into_iter().zip(results) {
if request.result_tx.send(result).is_err() {
tracing::debug!("Network event caller dropped before receiving result");
}
}
}
_ = shutdown_token.cancelled() => {
tracing::info!("🛑 Network event reconciler shutting down");
break;
}
else => break,
}
}
}
/// Cloneable client for submitting network events to the reconciler and awaiting results.
pub struct NetworkEventHandle {
    /// Queue feeding the reconciler.
    event_tx: mpsc::Sender<NetworkEventRequest>,
    /// Maximum time to wait for each result after enqueueing.
    result_timeout: Duration,
}
impl NetworkEventHandle {
    /// Creates a handle that waits up to `NETWORK_EVENT_RESULT_TIMEOUT` for each result.
    pub fn new(event_tx: mpsc::Sender<NetworkEventRequest>) -> Self {
        Self::new_with_result_timeout(event_tx, NETWORK_EVENT_RESULT_TIMEOUT)
    }

    /// Creates a handle with a custom per-request result timeout.
    pub fn new_with_result_timeout(
        event_tx: mpsc::Sender<NetworkEventRequest>,
        result_timeout: Duration,
    ) -> Self {
        Self {
            result_timeout,
            event_tx,
        }
    }

    /// Submits a network path change and waits for its result.
    pub async fn handle_network_path_changed(
        &self,
        snapshot: NetworkSnapshot,
    ) -> Result<NetworkEventResult, String> {
        let event = NetworkEvent::NetworkPathChanged { snapshot };
        self.send_event_and_await_result(event).await
    }

    /// Submits an app lifecycle change and waits for its result.
    pub async fn handle_app_lifecycle_changed(
        &self,
        state: AppLifecycleState,
    ) -> Result<NetworkEventResult, String> {
        let event = NetworkEvent::AppLifecycleChanged { state };
        self.send_event_and_await_result(event).await
    }

    /// Submits a cleanup request and waits for its result.
    pub async fn cleanup_connections(
        &self,
        reason: CleanupReason,
    ) -> Result<NetworkEventResult, String> {
        let event = NetworkEvent::CleanupConnections { reason };
        self.send_event_and_await_result(event).await
    }

    /// Submits a forced reconnect and waits for its result.
    pub async fn force_reconnect(
        &self,
        reason: ReconnectReason,
    ) -> Result<NetworkEventResult, String> {
        let event = NetworkEvent::ForceReconnect { reason };
        self.send_event_and_await_result(event).await
    }

    /// Enqueues `event` and waits up to `result_timeout` for the reconciler's answer.
    ///
    /// # Errors
    /// - The reconciler queue is closed.
    /// - The result sender is dropped before replying.
    /// - No result arrives within `result_timeout`.
    async fn send_event_and_await_result(
        &self,
        event: NetworkEvent,
    ) -> Result<NetworkEventResult, String> {
        let event_request_id = NEXT_NETWORK_EVENT_REQUEST_ID.fetch_add(1, Ordering::Relaxed);
        let started = Instant::now();
        let (result_tx, result_rx) = oneshot::channel();
        tracing::info!(
            event_request_id,
            event = ?event,
            result_timeout_ms = self.result_timeout.as_millis() as u64,
            "network_event.handle.enqueue"
        );
        let request = NetworkEventRequest {
            event: event.clone(),
            result_tx,
        };

        if let Err(send_err) = self.event_tx.send(request).await {
            let err = format!("Failed to send network event: {}", send_err);
            tracing::warn!(
                event_request_id,
                event = ?event,
                error = %err,
                "network_event.handle.enqueue_failed"
            );
            return Err(err);
        }

        let outcome = tokio::time::timeout(self.result_timeout, result_rx)
            .await
            .map_err(|_elapsed| {
                format!(
                    "Timed out waiting for network event result after {}ms",
                    self.result_timeout.as_millis()
                )
            })
            .and_then(|received| {
                received.map_err(|_| "Failed to receive network event result".to_string())
            });

        let wait_ms = started.elapsed().as_millis() as u64;
        match &outcome {
            Err(e) => tracing::warn!(
                event_request_id,
                event = ?event,
                wait_ms,
                error = %e,
                "network_event.handle.result_failed"
            ),
            Ok(reply) if !reply.success => tracing::warn!(
                event_request_id,
                event = ?event,
                result_event = ?reply.event,
                duration_ms = reply.duration_ms,
                wait_ms,
                error = ?reply.error,
                "network_event.handle.result_received"
            ),
            Ok(reply) => tracing::info!(
                event_request_id,
                event = ?event,
                result_event = ?reply.event,
                duration_ms = reply.duration_ms,
                wait_ms,
                "network_event.handle.result_received"
            ),
        }
        outcome
    }
}
impl Clone for NetworkEventHandle {
    /// Produces another handle feeding the same reconciler queue with the same timeout.
    fn clone(&self) -> Self {
        let Self {
            event_tx,
            result_timeout,
        } = self;
        Self {
            event_tx: event_tx.clone(),
            result_timeout: *result_timeout,
        }
    }
}