actr_runtime/wire/webrtc/
signaling.rs

//! Signaling client implementation.
//!
//! Implements the signaling protocol defined in the protobuf schema, using
//! `SignalingEnvelope` as the wrapper for every message.
4
5#[cfg(feature = "opentelemetry")]
6use super::trace;
7use crate::lifecycle::CredentialState;
8use crate::transport::error::{NetworkError, NetworkResult};
9#[cfg(feature = "opentelemetry")]
10use crate::wire::webrtc::trace::extract_trace_context;
11#[cfg(feature = "opentelemetry")]
12use actr_protocol::ActrIdExt;
13use actr_protocol::prost::Message as ProstMessage;
14use actr_protocol::{
15    AIdCredential, ActrId, ActrToSignaling, CredentialUpdateRequest, PeerToSignaling, Ping, Pong,
16    RegisterRequest, RegisterResponse, RouteCandidatesRequest, RouteCandidatesResponse,
17    ServiceAvailabilityState, SignalingEnvelope, UnregisterRequest, UnregisterResponse,
18    actr_to_signaling, peer_to_signaling, signaling_envelope, signaling_to_actr,
19};
20use async_trait::async_trait;
21use base64::Engine as _;
22use futures_util::{SinkExt, StreamExt};
23use serde::{Deserialize, Serialize};
24use std::collections::HashMap;
25use std::sync::{
26    Arc,
27    atomic::{AtomicBool, AtomicU64, Ordering},
28};
29use tokio::net::TcpStream;
30use tokio::sync::{mpsc, oneshot, watch};
31use tokio_tungstenite::tungstenite::protocol::WebSocketConfig;
32use tokio_tungstenite::{MaybeTlsStream, WebSocketStream, connect_async_with_config};
33#[cfg(feature = "opentelemetry")]
34use tracing_opentelemetry::OpenTelemetrySpanExt;
35use url::Url;
36
37// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
38// Constants
39// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
40
/// Interval, in seconds, between WebSocket-level pings sent by the keepalive task.
///
/// The keepalive exists to detect silent half-open connections, where writes keep
/// succeeding even though the peer is gone.
const PING_INTERVAL_SECS: u64 = 5;
/// Seconds without inbound traffic after which the connection is considered dead.
const PONG_TIMEOUT_SECS: u64 = 10;
/// Seconds to wait for the signaling server to answer a request envelope.
const RESPONSE_TIMEOUT_SECS: u64 = 15;
46
47// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
48// configurationType
49// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
50
51/// signalingconfiguration
52#[derive(Debug, Clone)]
53pub struct SignalingConfig {
54    /// signaling server URL
55    pub server_url: Url,
56
57    /// Connecttimeout temporal duration (seconds)
58    pub connection_timeout: u64,
59
60    /// center skipinterval(seconds)
61    pub heartbeat_interval: u64,
62
63    /// reconnection configuration
64    pub reconnect_config: ReconnectConfig,
65
66    /// acknowledge verify configuration
67    pub auth_config: Option<AuthConfig>,
68}
69
/// Policy for re-establishing a lost signaling connection.
#[derive(Clone, Debug)]
pub struct ReconnectConfig {
    /// Whether automatic reconnection is enabled.
    pub enabled: bool,

    /// Maximum number of connection attempts.
    pub max_attempts: u32,

    /// Delay before the first retry, in seconds.
    pub initial_delay: u64,

    /// Upper bound on the retry delay, in seconds.
    pub max_delay: u64,

    /// Factor applied to the retry delay after each failed attempt.
    pub backoff_multiplier: f64,
}

impl Default for ReconnectConfig {
    /// Enabled, up to ten attempts, starting at one second and doubling up to one minute.
    fn default() -> Self {
        let initial_delay = 1;
        let max_delay = 60;
        ReconnectConfig {
            max_attempts: 10,
            backoff_multiplier: 2.0,
            initial_delay,
            max_delay,
            enabled: true,
        }
    }
}
100
/// Authentication settings for the signaling connection.
#[derive(Clone, Debug)]
pub struct AuthConfig {
    /// Authentication scheme to use.
    pub auth_type: AuthType,

    /// Credential data as string key/value pairs.
    pub credentials: HashMap<String, String>,
}

/// Supported authentication schemes.
#[derive(Clone, Debug)]
pub enum AuthType {
    /// No authentication.
    None,
    /// Bearer token.
    BearerToken,
    /// API key.
    ApiKey,
    /// JSON Web Token.
    Jwt,
}
123
124// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
125// Client interface and implementation
126// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
127
128/// signaling client connect port
129///
130/// # interior mutability
131/// allMethodusing `&self` and non `&mut self`, with for conveniencein Arc in shared.
132/// Implementation class needs interior mutability ( like Mutex)to manage WebSocket connection status.
133#[async_trait]
134pub trait SignalingClient: Send + Sync {
135    /// Connecttosignaling server
136    async fn connect(&self) -> NetworkResult<()>;
137
138    /// DisconnectConnect
139    async fn disconnect(&self) -> NetworkResult<()>;
140
141    /// SendRegisterrequest(Register front stream process, using PeerToSignaling)
142    async fn send_register_request(
143        &self,
144        request: RegisterRequest,
145    ) -> NetworkResult<RegisterResponse>;
146
147    /// Send UnregisterRequest to signaling server (Actr → Signaling flow)
148    ///
149    /// This is used when an Actor is shutting down gracefully and wants to
150    /// proactively notify the signaling server that it is no longer available.
151    async fn send_unregister_request(
152        &self,
153        actor_id: ActrId,
154        credential: AIdCredential,
155        reason: Option<String>,
156    ) -> NetworkResult<UnregisterResponse>;
157
158    /// Send center skip(Registerafter stream process, using ActrToSignaling)
159    /// Returns Pong response if received, error if timeout or no response
160    async fn send_heartbeat(
161        &self,
162        actor_id: ActrId,
163        credential: AIdCredential,
164        availability: ServiceAvailabilityState,
165        power_reserve: f32,
166        mailbox_backlog: f32,
167    ) -> NetworkResult<Pong>;
168
169    /// Send RouteCandidatesRequest (requires authenticated Actor session)
170    async fn send_route_candidates_request(
171        &self,
172        actor_id: ActrId,
173        credential: AIdCredential,
174        request: RouteCandidatesRequest,
175    ) -> NetworkResult<RouteCandidatesResponse>;
176
177    /// Send CredentialUpdateRequest to refresh the Actor's credential
178    ///
179    /// This is used to refresh the credential before it expires. The server responds
180    /// with a RegisterResponse containing the new credential and expiration time.
181    async fn send_credential_update_request(
182        &self,
183        actor_id: ActrId,
184        credential: AIdCredential,
185    ) -> NetworkResult<RegisterResponse>;
186
187    /// Sendsignalingsignal seal ( pass usage Method)
188    async fn send_envelope(&self, envelope: SignalingEnvelope) -> NetworkResult<()>;
189
190    /// Receivesignalingsignal seal
191    async fn receive_envelope(&self) -> NetworkResult<Option<SignalingEnvelope>>;
192
193    /// Check connection status
194    fn is_connected(&self) -> bool;
195
196    /// GetConnect statistics info
197    fn get_stats(&self) -> SignalingStats;
198    /// Subscribe to connection state changes (Connected/Disconnected).
199    fn subscribe_state(&self) -> watch::Receiver<ConnectionState>;
200
201    /// Set actor ID and credential state for reconnect URL parameters.
202    async fn set_actor_id(&self, actor_id: ActrId);
203    async fn set_credential_state(&self, credential_state: CredentialState);
204}
205
/// Coarse connection state published by a [`SignalingClient`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ConnectionState {
    /// No usable connection to the signaling server.
    Disconnected,
    /// A WebSocket connection to the signaling server is established.
    Connected,
}
212
213/// WebSocket signaling clientImplementation
214pub struct WebSocketSignalingClient {
215    config: SignalingConfig,
216    actor_id: tokio::sync::Mutex<Option<ActrId>>,
217    credential_state: tokio::sync::Mutex<Option<CredentialState>>,
218    /// WebSocket write end (using Mutex Implementation interior mutability )
219    ws_sink: Arc<
220        tokio::sync::Mutex<
221            Option<
222                futures_util::stream::SplitSink<
223                    WebSocketStream<MaybeTlsStream<TcpStream>>,
224                    tokio_tungstenite::tungstenite::Message,
225                >,
226            >,
227        >,
228    >,
229    /// WebSocket read end (using Mutex Implementation interior mutability )
230    ws_stream: tokio::sync::Mutex<
231        Option<futures_util::stream::SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>>,
232    >,
233    /// connection status
234    connected: Arc<AtomicBool>,
235    /// Connection in progress flag (prevents concurrent connect attempts)
236    connecting: Arc<AtomicBool>,
237    /// statistics info
238    stats: Arc<AtomicSignalingStats>,
239    /// Envelope count number device
240    envelope_counter: tokio::sync::Mutex<u64>,
241    /// Pending reply waiters (reply_for -> oneshot)
242    pending_replies: Arc<tokio::sync::Mutex<HashMap<String, oneshot::Sender<SignalingEnvelope>>>>,
243    /// Inbound envelope channel for unmatched messages (ActrRelay / push)
244    inbound_rx: Arc<tokio::sync::Mutex<mpsc::UnboundedReceiver<SignalingEnvelope>>>,
245    inbound_tx: tokio::sync::Mutex<mpsc::UnboundedSender<SignalingEnvelope>>,
246    /// Background receive task handle to allow graceful shutdown
247    receiver_task: Arc<tokio::sync::Mutex<Option<tokio::task::JoinHandle<()>>>>,
248    /// Background ping task to detect half-open connections
249    ping_task: tokio::sync::Mutex<Option<tokio::task::JoinHandle<()>>>,
250    /// Connection state broadcast channel
251    state_tx: watch::Sender<ConnectionState>,
252    /// Last time we saw inbound traffic (pong/any message), unix epoch seconds
253    last_pong: Arc<AtomicU64>,
254    /// Flag to track if auto-reconnector has been started (used with config.reconnect_config.enabled)
255    reconnector_started: Arc<AtomicBool>,
256}
257
258impl WebSocketSignalingClient {
259    /// Create new WebSocket signaling client
260    pub fn new(config: SignalingConfig) -> Self {
261        let (inbound_tx, inbound_rx) = mpsc::unbounded_channel();
262        let (state_tx, _state_rx) = watch::channel(ConnectionState::Disconnected);
263        Self {
264            config,
265            actor_id: tokio::sync::Mutex::new(None),
266            credential_state: tokio::sync::Mutex::new(None),
267            ws_sink: Arc::new(tokio::sync::Mutex::new(None)),
268            ws_stream: tokio::sync::Mutex::new(None),
269            connected: Arc::new(AtomicBool::new(false)),
270            connecting: Arc::new(AtomicBool::new(false)),
271            stats: Arc::new(AtomicSignalingStats::default()),
272            envelope_counter: tokio::sync::Mutex::new(0),
273            pending_replies: Arc::new(tokio::sync::Mutex::new(HashMap::new())),
274            inbound_rx: Arc::new(tokio::sync::Mutex::new(inbound_rx)),
275            inbound_tx: tokio::sync::Mutex::new(inbound_tx),
276            receiver_task: Arc::new(tokio::sync::Mutex::new(None)),
277            ping_task: tokio::sync::Mutex::new(None),
278            state_tx,
279            last_pong: Arc::new(AtomicU64::new(0)),
280            reconnector_started: Arc::new(AtomicBool::new(false)),
281        }
282    }
283
284    /// Start the auto-reconnector if enabled in config and not already started.
285    ///
286    /// This should be called after wrapping in Arc, typically right after creation or
287    /// on first connect(). It's safe to call multiple times - reconnector starts only once.
288    pub fn start_auto_reconnector(self: &Arc<Self>) {
289        // Check if auto-reconnect is enabled and not already started
290        if self.config.reconnect_config.enabled
291            && self
292                .reconnector_started
293                .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
294                .is_ok()
295        {
296            tracing::info!("🔄 Starting auto-reconnector for signaling client");
297
298            let self_clone = self.clone();
299            let mut state_rx = self.subscribe_state();
300
301            tokio::spawn(async move {
302                loop {
303                    match state_rx.changed().await {
304                        Err(_) => {
305                            // State channel closed, client is being dropped
306                            tracing::info!("� Signaling client dropped, stopping reconnect helper");
307                            break;
308                        }
309                        Ok(_) => {
310                            if *state_rx.borrow() == ConnectionState::Disconnected {
311                                // Cleanup old WebSocket resources before reconnecting
312                                tracing::debug!(
313                                    "🧹 Cleaning up old WebSocket resources before reconnect"
314                                );
315                                if let Err(e) = self_clone.disconnect().await {
316                                    tracing::warn!("⚠️ Disconnect cleanup failed (non-fatal): {e}");
317                                }
318
319                                tracing::warn!(
320                                    "📡 Signaling state is DISCONNECTED, attempting reconnect"
321                                );
322                                if let Err(e) = self_clone.connect().await {
323                                    tracing::error!("❌ Signaling reconnect failed: {e}");
324                                } else {
325                                    tracing::info!("✅ Signaling reconnect succeeded");
326                                }
327                            }
328                        }
329                    }
330                }
331            });
332        }
333    }
334
335    /// simple for convenience construct create Function
336    pub async fn connect_to(url: &str) -> NetworkResult<Arc<Self>> {
337        let config = SignalingConfig {
338            server_url: url.parse()?,
339            connection_timeout: 5,
340            heartbeat_interval: 30,
341            reconnect_config: ReconnectConfig::default(),
342            auth_config: None,
343        };
344
345        let client = Arc::new(Self::new(config));
346        client.start_auto_reconnector();
347        client.connect().await?;
348        Ok(client)
349    }
350
351    /// alive integrate down a envelope ID
352    async fn next_envelope_id(&self) -> String {
353        let mut counter = self.envelope_counter.lock().await;
354        *counter += 1;
355        format!("env-{}", *counter)
356    }
357
358    /// Create SignalingEnvelope
359    async fn create_envelope(&self, flow: signaling_envelope::Flow) -> SignalingEnvelope {
360        SignalingEnvelope {
361            envelope_version: 1,
362            envelope_id: self.next_envelope_id().await,
363            reply_for: None,
364            timestamp: prost_types::Timestamp {
365                seconds: chrono::Utc::now().timestamp(),
366                nanos: 0,
367            },
368            traceparent: None,
369            tracestate: None,
370            flow: Some(flow),
371        }
372    }
373
374    /// Reset inbound channel for a fresh session (useful after disconnects).
375    async fn reset_inbound_channel(&self) {
376        let (tx, rx) = mpsc::unbounded_channel();
377        *self.inbound_tx.lock().await = tx;
378        *self.inbound_rx.lock().await = rx;
379    }
380
381    /// Build signaling URL, attaching actor identity/token if available for reconnects.
382    async fn build_url_with_identity(&self) -> Url {
383        let mut url = self.config.server_url.clone();
384        let actor_id_opt = self.actor_id.lock().await.clone();
385        let credential_state_opt = self.credential_state.lock().await.clone();
386        if let (Some(actor_id), Some(credential_state)) = (actor_id_opt, credential_state_opt) {
387            let credential = credential_state.credential().await;
388            let actor_str = actr_protocol::ActrIdExt::to_string_repr(&actor_id);
389            let token_b64 =
390                base64::engine::general_purpose::STANDARD.encode(&credential.encrypted_token);
391            {
392                let mut pairs = url.query_pairs_mut();
393                pairs.append_pair("actor_id", &actor_str);
394                pairs.append_pair("token", &token_b64);
395                pairs.append_pair("token_key_id", &credential.token_key_id.to_string());
396            }
397        }
398        url
399    }
400
401    /// Establish a single signaling WebSocket connection attempt, honoring connection_timeout.
402    ///
403    /// This does not perform any retry logic; callers that want retries should wrap this.
404    async fn establish_connection_once(&self) -> NetworkResult<()> {
405        let url = self.build_url_with_identity().await;
406        let timeout_secs = self.config.connection_timeout;
407        tracing::debug!("Establishing connection to URL: {}", url.as_str());
408        // 断网后,写入到缓冲区的数据,网络恢复后会继续发送
409        let config = WebSocketConfig::default().write_buffer_size(0);
410        // Connect with an optional timeout. A timeout of 0 means "no timeout".
411        let connect_result = if timeout_secs == 0 {
412            connect_async_with_config(url.as_str(), Some(config), false).await
413        } else {
414            let timeout_duration = std::time::Duration::from_secs(timeout_secs);
415            tokio::time::timeout(
416                timeout_duration,
417                connect_async_with_config(url.as_str(), Some(config), false),
418            )
419            .await
420            .map_err(|_| {
421                NetworkError::ConnectionError(format!(
422                    "Signaling connect timeout after {}s",
423                    timeout_secs
424                ))
425            })?
426        }?;
427
428        let (ws_stream, _) = connect_result;
429
430        // Split read/write halves and initialize client state
431        let (sink, stream) = ws_stream.split();
432
433        *self.ws_sink.lock().await = Some(sink);
434        *self.ws_stream.lock().await = Some(stream);
435        self.connected.store(true, Ordering::Release);
436        self.last_pong.store(current_unix_secs(), Ordering::Release);
437        // Notify listeners that we are now connected
438        let _ = self.state_tx.send(ConnectionState::Connected);
439
440        self.stats.connections.fetch_add(1, Ordering::Relaxed);
441
442        Ok(())
443    }
444
445    /// Connect to signaling server with retry and exponential backoff based on reconnect_config.
446    async fn connect_with_retries(&self) -> NetworkResult<()> {
447        let cfg = &self.config.reconnect_config;
448
449        // If reconnect is disabled, just attempt once.
450        if !cfg.enabled {
451            return self.establish_connection_once().await;
452        }
453
454        let mut attempt: u32 = 0;
455        let mut delay_secs = cfg.initial_delay.max(1);
456
457        loop {
458            attempt += 1;
459
460            match self.establish_connection_once().await {
461                Ok(()) => {
462                    return Ok(());
463                }
464                Err(e) => {
465                    tracing::warn!("Signaling connect attempt {} failed: {e:?}", attempt);
466
467                    if attempt >= cfg.max_attempts {
468                        tracing::error!(
469                            "Signaling connect failed after {} attempts, giving up",
470                            attempt
471                        );
472                        return Err(e);
473                    }
474
475                    let sleep_secs = delay_secs.min(cfg.max_delay.max(1));
476                    tracing::info!("Retry signaling connect after {}s", sleep_secs);
477                    tokio::time::sleep(std::time::Duration::from_secs(sleep_secs)).await;
478
479                    // Exponential backoff for next attempt
480                    delay_secs = ((delay_secs as f64) * cfg.backoff_multiplier)
481                        .round()
482                        .max(1.0) as u64;
483                }
484            }
485        }
486    }
487
488    /// Send envelope and wait for response with timeout and error handling.
489    #[cfg_attr(
490        feature = "opentelemetry",
491        tracing::instrument(skip_all, fields(envelope_id = %envelope.envelope_id))
492    )]
493    async fn send_envelope_and_wait_response(
494        &self,
495        envelope: SignalingEnvelope,
496    ) -> NetworkResult<SignalingEnvelope> {
497        let reply_for = envelope.envelope_id.clone();
498
499        // Register waiter before sending
500        let (tx, rx) = oneshot::channel();
501        self.pending_replies
502            .lock()
503            .await
504            .insert(reply_for.clone(), tx);
505
506        if let Err(e) = self.send_envelope(envelope).await {
507            // Cleanup waiter on immediate send failure to avoid leaks.
508            self.pending_replies.lock().await.remove(&reply_for);
509            return Err(e);
510        }
511
512        let result =
513            tokio::time::timeout(std::time::Duration::from_secs(RESPONSE_TIMEOUT_SECS), rx).await;
514        // Clean up waiter on timeout
515        if result.is_err() {
516            self.pending_replies.lock().await.remove(&reply_for);
517        }
518
519        let response_envelope = result
520            .map_err(|_| {
521                NetworkError::ConnectionError(
522                    "Timed out waiting for signaling response".to_string(),
523                )
524            })?
525            .map_err(|_| {
526                NetworkError::ConnectionError(
527                    "Receiver dropped while waiting for signaling response".to_string(),
528                )
529            })?;
530
531        Ok(response_envelope)
532    }
533
534    /// Spawn background receiver to demux envelopes by reply_for.
535    async fn start_receiver(&self) {
536        let mut stream_guard = self.ws_stream.lock().await;
537        if stream_guard.is_none() {
538            return;
539        }
540
541        let mut stream = stream_guard.take().expect("stream exists");
542        let pending = self.pending_replies.clone();
543        let inbound_tx = { self.inbound_tx.lock().await.clone() };
544        let stats = self.stats.clone();
545        let connected = self.connected.clone();
546        let state_tx = self.state_tx.clone();
547        let last_pong = self.last_pong.clone();
548        let handle = tokio::spawn(async move {
549            while let Some(msg) = stream.next().await {
550                match msg {
551                    Ok(tokio_tungstenite::tungstenite::Message::Binary(data)) => {
552                        // Any inbound traffic counts as liveness
553                        last_pong.store(current_unix_secs(), Ordering::Release);
554                        match SignalingEnvelope::decode(&data[..]) {
555                            Ok(envelope) => {
556                                #[cfg(feature = "opentelemetry")]
557                                let span = {
558                                    let span = tracing::info_span!("signaling.receive_envelope", envelope_id = %envelope.envelope_id);
559                                    span.set_parent(extract_trace_context(&envelope));
560                                    span
561                                };
562
563                                stats.messages_received.fetch_add(1, Ordering::Relaxed);
564                                tracing::debug!("Received message: {:?}", envelope);
565                                if let Some(reply_for) = envelope.reply_for.clone() {
566                                    if let Some(sender) = pending.lock().await.remove(&reply_for) {
567                                        #[cfg(feature = "opentelemetry")]
568                                        let _ = span.enter();
569                                        if let Err(e) = sender.send(envelope) {
570                                            stats.errors.fetch_add(1, Ordering::Relaxed);
571                                            tracing::warn!(
572                                                "Failed to send reply envelope to waiter: {e:?}",
573                                            );
574                                        }
575                                        continue;
576                                    }
577                                }
578                                tracing::debug!(
579                                    "Unmatched or push message -> forward to inbound channel"
580                                );
581                                // Unmatched or push message -> forward to inbound channel
582                                if let Err(e) = inbound_tx.send(envelope) {
583                                    stats.errors.fetch_add(1, Ordering::Relaxed);
584                                    tracing::warn!(
585                                        "Failed to send envelope to inbound channel: {e:?}"
586                                    );
587                                }
588                            }
589                            Err(e) => {
590                                stats.errors.fetch_add(1, Ordering::Relaxed);
591                                tracing::warn!("Failed to decode SignalingEnvelope: {e}");
592                            }
593                        }
594                    }
595                    Ok(tokio_tungstenite::tungstenite::Message::Pong(_)) => {
596                        tracing::debug!("Received pong");
597                        last_pong.store(current_unix_secs(), Ordering::Release);
598                    }
599                    Ok(tokio_tungstenite::tungstenite::Message::Ping(_)) => {
600                        tracing::debug!("Received ping");
601                        last_pong.store(current_unix_secs(), Ordering::Release);
602                    }
603                    Ok(_) => {
604                        tracing::warn!("Received non-binary frame, ignoring");
605                    }
606                    Err(e) => {
607                        stats.errors.fetch_add(1, Ordering::Relaxed);
608                        tracing::error!("Signaling receive error: {e}");
609                        break;
610                    }
611                }
612            }
613
614            // Reaching here means the underlying WebSocket stream has terminated.
615            connected.store(false, Ordering::Release);
616            stats.disconnections.fetch_add(1, Ordering::Relaxed);
617            let _ = state_tx.send(ConnectionState::Disconnected);
618        });
619
620        *self.receiver_task.lock().await = Some(handle);
621    }
622
623    /// Spawn background ping task to detect half-open connections where writes do not fail but peer is gone.
624    /// fixme: merge to heartbeat task
625    async fn start_ping_task(&self) {
626        let mut existing = self.ping_task.lock().await;
627        if let Some(handle) = existing.as_ref() {
628            if handle.is_finished() {
629                existing.take();
630            } else {
631                return;
632            }
633        }
634
635        let sink = self.ws_sink.clone();
636        let connected = self.connected.clone();
637        let state_tx = self.state_tx.clone();
638        let last_pong = self.last_pong.clone();
639        let receiver_task_clone = Arc::clone(&self.receiver_task);
640
641        let handle = tokio::spawn(async move {
642            loop {
643                tokio::time::sleep(std::time::Duration::from_secs(PING_INTERVAL_SECS)).await;
644
645                if !connected.load(Ordering::Acquire) {
646                    break;
647                }
648
649                // Send ping; mark disconnect on failure.
650                let mut sink_guard = sink.lock().await;
651                if let Some(sink) = sink_guard.as_mut() {
652                    if let Err(e) = sink
653                        .send(tokio_tungstenite::tungstenite::Message::Ping(
654                            Vec::new().into(),
655                        ))
656                        .await
657                    {
658                        tracing::warn!("Signaling ping send failed: {e}");
659                        connected.store(false, Ordering::Release);
660                        let _ = state_tx.send(ConnectionState::Disconnected);
661                        break;
662                    }
663                } else {
664                    tracing::warn!("Signaling not connected");
665                    connected.store(false, Ordering::Release);
666                    let _ = state_tx.send(ConnectionState::Disconnected);
667                    break;
668                }
669                drop(sink_guard);
670
671                // Check for stale pong
672                let now = current_unix_secs();
673                let last = last_pong.load(Ordering::Acquire);
674                if now.saturating_sub(last) > PONG_TIMEOUT_SECS {
675                    tracing::warn!(
676                        "Signaling pong timeout (last seen {}s ago), marking disconnected",
677                        now.saturating_sub(last)
678                    );
679                    connected.store(false, Ordering::Release);
680                    let _ = state_tx.send(ConnectionState::Disconnected);
681                    if let Some(handle) = receiver_task_clone.lock().await.take() {
682                        handle.abort();
683                    }
684                    break;
685                }
686            }
687        });
688
689        *existing = Some(handle);
690    }
691
692    /// Wait for ongoing connection attempt to complete (used when another task is connecting).
693    ///
694    /// This uses the watch channel to efficiently wait for state changes instead of polling.
695    async fn wait_for_connection_result(&self) -> NetworkResult<()> {
696        let mut state_rx = self.subscribe_state();
697        let timeout = tokio::time::sleep(std::time::Duration::from_secs(30));
698        tokio::pin!(timeout);
699
700        loop {
701            tokio::select! {
702                _ = &mut timeout => {
703                    // Timeout: check final state
704                    if self.connected.load(Ordering::Acquire) {
705                        tracing::debug!("Connection succeeded just before timeout");
706                        return Ok(());
707                    }
708
709                    // Check if we can retry (connecting flag cleared)
710                    if !self.connecting.load(Ordering::Acquire) {
711                        tracing::warn!("Other connection attempt failed/timed out, will retry");
712                        // Recursively call connect() to retry
713                        return self.connect().await;
714                    }
715
716                    return Err(NetworkError::ConnectionError(
717                        "Timeout waiting for concurrent connection attempt".to_string(),
718                    ));
719                }
720
721                result = state_rx.changed() => {
722                    if result.is_err() {
723                        return Err(NetworkError::ConnectionError(
724                            "State channel closed while waiting for connection".to_string(),
725                        ));
726                    }
727
728                    let state = *state_rx.borrow();
729                    match state {
730                        ConnectionState::Connected => {
731                            tracing::debug!("Connection established by another task");
732                            return Ok(());
733                        }
734                        ConnectionState::Disconnected => {
735                            // Check if the connecting task gave up
736                            if !self.connecting.load(Ordering::Acquire) {
737                                tracing::warn!("Other connection attempt failed, will retry");
738                                // Recursively call connect() to retry with fresh attempt
739                                return self.connect().await;
740                            }
741                            // Otherwise, keep waiting (might be transient state)
742                        }
743                    }
744                }
745            }
746        }
747    }
748}
749
750#[async_trait]
751impl SignalingClient for WebSocketSignalingClient {
752    async fn connect(&self) -> NetworkResult<()> {
753        // 🔐 Fast path: Check if already connected
754        if self.connected.load(Ordering::Acquire) {
755            tracing::debug!("Already connected, skipping connect()");
756            return Ok(());
757        }
758
759        // 🔐 Try to acquire "connecting" lock using compare-and-swap
760        if self
761            .connecting
762            .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
763            .is_err()
764        {
765            // Another task is connecting, wait for state change using watch channel
766            tracing::debug!("Another connection attempt in progress, waiting for state change...");
767
768            return self.wait_for_connection_result().await;
769        }
770
771        // 🔐 We now hold the "connecting" lock, proceed with connection
772        tracing::debug!("Acquired connection lock, establishing connection...");
773
774        // Perform actual connection
775        let result = self.connect_with_retries().await;
776
777        // Clear "connecting" flag regardless of result
778        self.connecting.store(false, Ordering::Release);
779
780        // Handle connection result
781        match result {
782            Ok(()) => {
783                self.start_receiver().await;
784                self.start_ping_task().await;
785                Ok(())
786            }
787            Err(e) => {
788                // Explicitly notify waiting tasks that connection failed
789                // This allows them to retry immediately instead of waiting for timeout
790                let _ = self.state_tx.send(ConnectionState::Disconnected);
791                tracing::error!("Connection failed: {e}");
792                Err(e)
793            }
794        }
795    }
796
797    async fn disconnect(&self) -> NetworkResult<()> {
798        // fetch exit sink and stream
799        let mut sink_guard = self.ws_sink.lock().await;
800        let mut stream_guard = self.ws_stream.lock().await;
801
802        // Close sink
803        if let Some(mut sink) = sink_guard.take() {
804            let _ = sink.close().await;
805        }
806
807        // clear blank stream
808        stream_guard.take();
809
810        // Stop receiver task if running
811        if let Some(handle) = self.receiver_task.lock().await.take() {
812            handle.abort();
813        }
814        // Stop ping task if running
815        if let Some(handle) = self.ping_task.lock().await.take() {
816            handle.abort();
817        }
818
819        self.reset_inbound_channel().await;
820
821        self.connected.store(false, Ordering::Release);
822        self.stats.disconnections.fetch_add(1, Ordering::Relaxed);
823
824        Ok(())
825    }
826
827    #[cfg_attr(feature = "opentelemetry", tracing::instrument(skip_all))]
828    async fn send_register_request(
829        &self,
830        request: RegisterRequest,
831    ) -> NetworkResult<RegisterResponse> {
832        // Create PeerToSignaling stream process (Register front )
833        let flow = signaling_envelope::Flow::PeerToServer(PeerToSignaling {
834            payload: Some(peer_to_signaling::Payload::RegisterRequest(request)),
835        });
836
837        let envelope = self.create_envelope(flow).await;
838        let response_envelope = self.send_envelope_and_wait_response(envelope).await?;
839
840        if let Some(signaling_envelope::Flow::ServerToActr(server_to_actr)) = response_envelope.flow
841        {
842            if let Some(signaling_to_actr::Payload::RegisterResponse(response)) =
843                server_to_actr.payload
844            {
845                return Ok(response);
846            }
847        }
848
849        Err(NetworkError::ConnectionError(
850            "Invalid registration response".to_string(),
851        ))
852    }
853
854    #[cfg_attr(
855        feature = "opentelemetry",
856        tracing::instrument(skip_all, fields(actor_id = %actor_id.to_string_repr()))
857    )]
858    async fn send_unregister_request(
859        &self,
860        actor_id: ActrId,
861        credential: AIdCredential,
862        reason: Option<String>,
863    ) -> NetworkResult<UnregisterResponse> {
864        // Build UnregisterRequest payload
865        let request = UnregisterRequest {
866            actr_id: actor_id.clone(),
867            reason,
868        };
869
870        // Wrap into ActrToSignaling flow
871        let flow = signaling_envelope::Flow::ActrToServer(ActrToSignaling {
872            source: actor_id,
873            credential,
874            payload: Some(actr_to_signaling::Payload::UnregisterRequest(request)),
875        });
876
877        // Send envelope (fire-and-forget)
878        let envelope = self.create_envelope(flow).await;
879        self.send_envelope(envelope).await?;
880
881        // Do not wait for UnregisterResponse here because the signaling stream
882        // is also consumed by WebRtcCoordinator. Waiting could race with that loop
883        // and lead to spurious timeouts. Treat Unregister as best-effort.
884        // not wait for the response , because the signaling stream have multi customers use it, fixme: should wait for the response
885        Ok(UnregisterResponse {
886            result: Some(actr_protocol::unregister_response::Result::Success(
887                actr_protocol::unregister_response::UnregisterOk {},
888            )),
889        })
890    }
891
892    #[cfg_attr(
893        feature = "opentelemetry",
894        tracing::instrument(level = "debug", skip_all, fields(actor_id = %actor_id.to_string_repr()))
895    )]
896    async fn send_heartbeat(
897        &self,
898        actor_id: ActrId,
899        credential: AIdCredential,
900        availability: ServiceAvailabilityState,
901        power_reserve: f32,
902        mailbox_backlog: f32,
903    ) -> NetworkResult<Pong> {
904        let ping = Ping {
905            availability: availability as i32,
906            power_reserve,
907            mailbox_backlog,
908            sticky_client_ids: vec![], // TODO: Implement sticky session tracking
909        };
910
911        let flow = signaling_envelope::Flow::ActrToServer(ActrToSignaling {
912            source: actor_id,
913            credential,
914            payload: Some(actr_to_signaling::Payload::Ping(ping)),
915        });
916
917        let envelope = self.create_envelope(flow).await;
918        let reply_for = envelope.envelope_id.clone();
919
920        // Register waiter before sending
921        let (tx, rx) = oneshot::channel();
922        self.pending_replies
923            .lock()
924            .await
925            .insert(reply_for.clone(), tx);
926
927        if let Err(e) = self.send_envelope(envelope).await {
928            // Cleanup waiter on immediate send failure to avoid leaks.
929            self.pending_replies.lock().await.remove(&reply_for);
930            return Err(e);
931        }
932
933        // Wait for response
934        let response_envelope = rx.await.map_err(|_| {
935            NetworkError::ConnectionError(
936                "Receiver dropped while waiting for heartbeat response".to_string(),
937            )
938        })?;
939
940        // Extract Pong from response
941        if let Some(signaling_envelope::Flow::ServerToActr(server_to_actr)) = response_envelope.flow
942        {
943            if let Some(signaling_to_actr::Payload::Pong(pong)) = server_to_actr.payload {
944                return Ok(pong);
945            }
946        }
947
948        Err(NetworkError::ConnectionError(
949            "Received response but not a Pong message".to_string(),
950        ))
951    }
952
953    #[cfg_attr(feature = "opentelemetry", tracing::instrument(skip_all))]
954    async fn send_route_candidates_request(
955        &self,
956        actor_id: ActrId,
957        credential: AIdCredential,
958        request: RouteCandidatesRequest,
959    ) -> NetworkResult<RouteCandidatesResponse> {
960        let flow = signaling_envelope::Flow::ActrToServer(ActrToSignaling {
961            source: actor_id,
962            credential,
963            payload: Some(actr_to_signaling::Payload::RouteCandidatesRequest(request)),
964        });
965
966        let envelope = self.create_envelope(flow).await;
967        let response_envelope = self.send_envelope_and_wait_response(envelope).await?;
968
969        if let Some(signaling_envelope::Flow::ServerToActr(server_to_actr)) = response_envelope.flow
970        {
971            match server_to_actr.payload {
972                Some(signaling_to_actr::Payload::RouteCandidatesResponse(response)) => {
973                    return Ok(response);
974                }
975                Some(signaling_to_actr::Payload::Error(err)) => {
976                    return Err(NetworkError::ServiceDiscoveryError(format!(
977                        "{} ({})",
978                        err.message, err.code
979                    )));
980                }
981                _ => {}
982            }
983        }
984
985        Err(NetworkError::ConnectionError(
986            "Invalid route candidates response".to_string(),
987        ))
988    }
989
990    #[cfg_attr(
991        feature = "opentelemetry",
992        tracing::instrument(level = "debug", skip_all, fields(actor_id = %actor_id.to_string_repr()))
993    )]
994    async fn send_credential_update_request(
995        &self,
996        actor_id: ActrId,
997        credential: AIdCredential,
998    ) -> NetworkResult<RegisterResponse> {
999        let request = CredentialUpdateRequest {
1000            actr_id: actor_id.clone(),
1001        };
1002
1003        let flow = signaling_envelope::Flow::ActrToServer(ActrToSignaling {
1004            source: actor_id,
1005            credential,
1006            payload: Some(actr_to_signaling::Payload::CredentialUpdateRequest(request)),
1007        });
1008
1009        let envelope = self.create_envelope(flow).await;
1010        let response_envelope = self.send_envelope_and_wait_response(envelope).await?;
1011
1012        if let Some(signaling_envelope::Flow::ServerToActr(server_to_actr)) = response_envelope.flow
1013        {
1014            match server_to_actr.payload {
1015                Some(signaling_to_actr::Payload::RegisterResponse(response)) => {
1016                    return Ok(response);
1017                }
1018                Some(signaling_to_actr::Payload::Error(err)) => {
1019                    return Err(NetworkError::ConnectionError(format!(
1020                        "Credential update failed: {} ({})",
1021                        err.message, err.code
1022                    )));
1023                }
1024                _ => {}
1025            }
1026        }
1027
1028        Err(NetworkError::ConnectionError(
1029            "Invalid credential update response".to_string(),
1030        ))
1031    }
1032
1033    #[cfg_attr(
1034        feature = "opentelemetry",
1035        tracing::instrument(level = "debug", skip_all, fields(envelope_id = %envelope.envelope_id))
1036    )]
1037    async fn send_envelope(&self, envelope: SignalingEnvelope) -> NetworkResult<()> {
1038        #[cfg(feature = "opentelemetry")]
1039        let envelope = {
1040            let mut envelope = envelope;
1041            trace::inject_span_context(&tracing::Span::current(), &mut envelope);
1042            envelope
1043        };
1044
1045        // Check connection state first to avoid sending on stale/closed connections
1046        // This prevents "Broken pipe" errors when ws_sink exists but connection is dead
1047        if !self.is_connected() {
1048            return Err(NetworkError::ConnectionError(
1049                "Cannot send: WebSocket not connected".to_string(),
1050            ));
1051        }
1052
1053        let mut sink_guard = self.ws_sink.lock().await;
1054
1055        if let Some(sink) = sink_guard.as_mut() {
1056            // using protobuf binary serialization
1057            let mut buf = Vec::new();
1058            envelope.encode(&mut buf)?;
1059            let msg = tokio_tungstenite::tungstenite::Message::Binary(buf.into());
1060            sink.send(msg).await?;
1061
1062            self.stats.messages_sent.fetch_add(1, Ordering::Relaxed);
1063            tracing::debug!("Stats: {:?}", self.stats.snapshot());
1064            Ok(())
1065        } else {
1066            Err(NetworkError::ConnectionError("Not connected".to_string()))
1067        }
1068    }
1069
1070    async fn receive_envelope(&self) -> NetworkResult<Option<SignalingEnvelope>> {
1071        let mut rx = self.inbound_rx.lock().await;
1072        match rx.recv().await {
1073            Some(envelope) => Ok(Some(envelope)),
1074            None => {
1075                tracing::error!("Inbound channel closed");
1076                Err(NetworkError::ConnectionError(
1077                    "Inbound channel closed".to_string(),
1078                ))
1079            }
1080        }
1081    }
1082
1083    fn is_connected(&self) -> bool {
1084        self.connected.load(Ordering::Acquire)
1085    }
1086
1087    fn get_stats(&self) -> SignalingStats {
1088        self.stats.snapshot()
1089    }
1090
1091    fn subscribe_state(&self) -> watch::Receiver<ConnectionState> {
1092        self.state_tx.subscribe()
1093    }
1094
1095    async fn set_actor_id(&self, actor_id: ActrId) {
1096        *self.actor_id.lock().await = Some(actor_id);
1097    }
1098
1099    async fn set_credential_state(&self, credential_state: CredentialState) {
1100        *self.credential_state.lock().await = Some(credential_state);
1101    }
1102}
1103
/// Signaling statistics counters, updated concurrently through atomics.
///
/// All counters start at zero. Read a consistent-enough copy through
/// `AtomicSignalingStats::snapshot`.
#[derive(Debug, Default)]
pub(crate) struct AtomicSignalingStats {
    /// Connection attempts.
    pub connections: AtomicU64,

    /// Disconnections.
    pub disconnections: AtomicU64,

    /// Signaling messages sent.
    pub messages_sent: AtomicU64,

    /// Signaling messages received.
    pub messages_received: AtomicU64,

    /// Heartbeats sent.
    /// TODO: Wire heartbeat counters when heartbeat send/receive paths are instrumented; currently never incremented.
    pub heartbeats_sent: AtomicU64,

    /// Heartbeats received.
    /// TODO: Wire heartbeat counters when heartbeat send/receive paths are instrumented; currently never incremented.
    pub heartbeats_received: AtomicU64,

    /// Errors encountered.
    pub errors: AtomicU64,
}
1144
1145/// Snapshot of statistics for serialization and reading
1146#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)]
1147pub struct SignalingStats {
1148    /// Connect attempts
1149    pub connections: u64,
1150
1151    /// DisconnectConnect attempts
1152    pub disconnections: u64,
1153
1154    /// Send'smessage number
1155    pub messages_sent: u64,
1156
1157    /// Receive'smessage number
1158    pub messages_received: u64,
1159
1160    /// Send's center skip number
1161    pub heartbeats_sent: u64,
1162
1163    /// Receive's center skip number
1164    pub heartbeats_received: u64,
1165
1166    /// Error attempts
1167    pub errors: u64,
1168}
1169
1170impl AtomicSignalingStats {
1171    /// Create a snapshot of current statistics
1172    pub fn snapshot(&self) -> SignalingStats {
1173        SignalingStats {
1174            connections: self.connections.load(Ordering::Relaxed),
1175            disconnections: self.disconnections.load(Ordering::Relaxed),
1176            messages_sent: self.messages_sent.load(Ordering::Relaxed),
1177            messages_received: self.messages_received.load(Ordering::Relaxed),
1178            heartbeats_sent: self.heartbeats_sent.load(Ordering::Relaxed),
1179            heartbeats_received: self.heartbeats_received.load(Ordering::Relaxed),
1180            errors: self.errors.load(Ordering::Relaxed),
1181        }
1182    }
1183}
1184
/// Current wall-clock time in whole seconds since the Unix epoch.
///
/// Yields `0` when the system clock reports a time earlier than the epoch.
fn current_unix_secs() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_secs(),
        Err(_) => 0,
    }
}
1192
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering as UsizeOrdering};

    /// Simple fake SignalingClient implementation for testing the reconnect helper.
    ///
    /// `connect()` only counts invocations; connection state is whatever was last published
    /// on `state_tx`.
    struct FakeSignalingClient {
        state_tx: watch::Sender<ConnectionState>,
        connect_calls: Arc<AtomicUsize>,
        actor_id: tokio::sync::Mutex<Option<ActrId>>,
        credential_state: tokio::sync::Mutex<Option<CredentialState>>,
    }

    #[async_trait]
    impl SignalingClient for FakeSignalingClient {
        async fn connect(&self) -> NetworkResult<()> {
            self.connect_calls.fetch_add(1, UsizeOrdering::SeqCst);
            Ok(())
        }

        async fn disconnect(&self) -> NetworkResult<()> {
            Ok(())
        }

        async fn send_register_request(
            &self,
            _request: RegisterRequest,
        ) -> NetworkResult<RegisterResponse> {
            unimplemented!("not needed in tests");
        }

        async fn send_unregister_request(
            &self,
            _actor_id: ActrId,
            _credential: AIdCredential,
            _reason: Option<String>,
        ) -> NetworkResult<UnregisterResponse> {
            unimplemented!("not needed in tests");
        }

        async fn send_heartbeat(
            &self,
            _actor_id: ActrId,
            _credential: AIdCredential,
            _availability: ServiceAvailabilityState,
            _power_reserve: f32,
            _mailbox_backlog: f32,
        ) -> NetworkResult<Pong> {
            unimplemented!("not needed in tests");
        }

        async fn send_route_candidates_request(
            &self,
            _actor_id: ActrId,
            _credential: AIdCredential,
            _request: RouteCandidatesRequest,
        ) -> NetworkResult<RouteCandidatesResponse> {
            unimplemented!("not needed in tests");
        }

        async fn send_credential_update_request(
            &self,
            _actor_id: ActrId,
            _credential: AIdCredential,
        ) -> NetworkResult<RegisterResponse> {
            unimplemented!("not needed in tests");
        }

        async fn send_envelope(&self, _envelope: SignalingEnvelope) -> NetworkResult<()> {
            unimplemented!("not needed in tests");
        }

        async fn receive_envelope(&self) -> NetworkResult<Option<SignalingEnvelope>> {
            unimplemented!("not needed in tests");
        }

        fn is_connected(&self) -> bool {
            // Derived from last published state; keep implementation simple for tests.
            *self.state_tx.borrow() == ConnectionState::Connected
        }

        fn get_stats(&self) -> SignalingStats {
            SignalingStats::default()
        }

        fn subscribe_state(&self) -> watch::Receiver<ConnectionState> {
            self.state_tx.subscribe()
        }

        async fn set_actor_id(&self, actor_id: ActrId) {
            *self.actor_id.lock().await = Some(actor_id);
        }

        async fn set_credential_state(&self, credential_state: CredentialState) {
            *self.credential_state.lock().await = Some(credential_state);
        }
    }

    /// Build a fake client starting in `Disconnected`, plus a sender sharing its state channel.
    fn make_fake_client() -> (Arc<FakeSignalingClient>, watch::Sender<ConnectionState>) {
        let (state_tx, _rx) = watch::channel(ConnectionState::Disconnected);
        let client = Arc::new(FakeSignalingClient {
            state_tx: state_tx.clone(),
            connect_calls: Arc::new(AtomicUsize::new(0)),
            actor_id: tokio::sync::Mutex::new(None),
            credential_state: tokio::sync::Mutex::new(None),
        });
        (client, state_tx)
    }

    #[test]
    fn test_websocket_signaling_client_initial_state_disconnected() {
        // Build a minimal config; URL doesn't need to be reachable for this test.
        let config = SignalingConfig {
            server_url: Url::parse("ws://example.com/signaling/ws").unwrap(),
            connection_timeout: 30,
            heartbeat_interval: 30,
            reconnect_config: ReconnectConfig::default(),
            auth_config: None,
        };

        let client = WebSocketSignalingClient::new(config);
        let state_rx = client.subscribe_state();
        assert_eq!(*state_rx.borrow(), ConnectionState::Disconnected);
    }

    #[test]
    fn test_fake_client_reflects_published_state() {
        let (client, state_tx) = make_fake_client();

        // Keep a receiver alive: `watch::Sender::send` fails (and does not store the value)
        // once every receiver has been dropped.
        let state_rx = client.subscribe_state();
        assert!(!client.is_connected());

        state_tx.send(ConnectionState::Connected).unwrap();
        assert!(client.is_connected());
        assert_eq!(*state_rx.borrow(), ConnectionState::Connected);
        assert_eq!(client.connect_calls.load(UsizeOrdering::SeqCst), 0);
    }
}