1#[cfg(feature = "opentelemetry")]
6use super::trace;
7use crate::lifecycle::CredentialState;
8use crate::transport::{NetworkError, NetworkResult};
9#[cfg(feature = "opentelemetry")]
10use crate::wire::webrtc::trace::extract_trace_context;
11use actr_protocol::prost::Message as ProstMessage;
12use actr_protocol::{
13 AIdCredential, ActrId, ActrToSignaling, CredentialUpdateRequest, GetSigningKeyRequest,
14 PeerToSignaling, Ping, Pong, RegisterRequest, RegisterResponse, RouteCandidatesRequest,
15 RouteCandidatesResponse, ServiceAvailabilityState, SignalingEnvelope, UnregisterRequest,
16 UnregisterResponse, actr_to_signaling, peer_to_signaling, signaling_envelope,
17 signaling_to_actr,
18};
19use async_trait::async_trait;
20use base64::Engine as _;
21use futures_util::{SinkExt, StreamExt};
22use serde::{Deserialize, Serialize};
23use std::collections::HashMap;
24use std::future::Future;
25use std::pin::Pin;
26use std::sync::{
27 Arc, OnceLock,
28 atomic::{AtomicBool, AtomicU64, Ordering},
29};
30use std::time::Duration;
31use tokio::net::TcpStream;
32use tokio::sync::{broadcast, mpsc, oneshot};
33use tokio_tungstenite::tungstenite::protocol::WebSocketConfig;
34use tokio_tungstenite::{MaybeTlsStream, WebSocketStream, connect_async_with_config};
35#[cfg(feature = "opentelemetry")]
36use tracing_opentelemetry::OpenTelemetrySpanExt;
37use url::Url;
38
39type WsSink = Arc<
41 tokio::sync::Mutex<
42 Option<
43 futures_util::stream::SplitSink<
44 WebSocketStream<MaybeTlsStream<TcpStream>>,
45 tokio_tungstenite::tungstenite::Message,
46 >,
47 >,
48 >,
49>;
50
// --- Request / response -----------------------------------------------------

/// Seconds to wait for the reply to a request envelope before giving up.
const RESPONSE_TIMEOUT_SECS: u64 = 15;

// --- Liveness ---------------------------------------------------------------

/// Seconds between WebSocket pings sent by the ping task.
const PING_INTERVAL_SECS: u64 = 5;
/// Seconds without inbound traffic after which the connection is declared dead.
const PONG_TIMEOUT_SECS: u64 = 10;
/// Seconds allowed for a single ping frame to be written to the socket.
const SIGNALING_SEND_TIMEOUT_SECS: u64 = 5;
/// NOTE(review): presumably bounds how long a caller waits for a concurrent
/// connect attempt to finish; not referenced in this part of the file.
const CONCURRENT_CONNECT_WAIT_TIMEOUT_SECS: u64 = 5;

// --- Disconnect -------------------------------------------------------------

/// Seconds to wait for each internal lock while tearing the connection down.
const DISCONNECT_LOCK_TIMEOUT_SECS: u64 = 5;
/// Seconds allowed for closing the WebSocket sink during disconnect.
const DISCONNECT_CLOSE_TIMEOUT_SECS: u64 = 1;
64
65#[derive(Debug, Clone)]
71pub struct SignalingConfig {
72 pub server_url: Url,
74
75 pub connection_timeout: u64,
77
78 pub heartbeat_interval: u64,
80
81 pub reconnect_config: ReconnectConfig,
83
84 pub auth_config: Option<AuthConfig>,
86
87 pub webrtc_role: Option<String>,
89}
90
/// Retry policy for the initial connect and for automatic reconnects.
#[derive(Clone, Debug)]
pub struct ReconnectConfig {
    /// Master switch: when `false`, connects are attempted once and no
    /// reconnect manager is started.
    pub enabled: bool,

    /// Maximum number of retries handed to the backoff schedule.
    pub max_attempts: u32,

    /// First backoff delay in seconds (values below 1 are treated as 1).
    pub initial_delay: u64,

    /// Upper bound for a backoff delay in seconds (values below 1 are treated as 1).
    pub max_delay: u64,

    /// Growth factor between consecutive delays.
    /// NOTE(review): not passed to the backoff builder in the visible code.
    pub backoff_multiplier: f64,
}

impl Default for ReconnectConfig {
    /// Reconnect enabled, 10 retries, delays growing from 1s up to 60s, multiplier 2.
    fn default() -> Self {
        ReconnectConfig {
            backoff_multiplier: 2.0,
            max_delay: 60,
            initial_delay: 1,
            max_attempts: 10,
            enabled: true,
        }
    }
}
121
122#[derive(Debug, Clone)]
124pub struct AuthConfig {
125 pub auth_type: AuthType,
127
128 pub credentials: HashMap<String, String>,
130}
131
/// Authentication scheme selector for [`AuthConfig`].
#[derive(Clone, Debug)]
pub enum AuthType {
    /// No authentication.
    None,
    /// Bearer-token authentication.
    BearerToken,
    /// API-key authentication.
    ApiKey,
    /// JWT authentication.
    Jwt,
}
144
145#[async_trait]
155pub trait SignalingClient: Send + Sync {
156 async fn connect(&self) -> NetworkResult<()>;
158
159 async fn connect_once(&self) -> NetworkResult<()> {
164 self.connect().await
165 }
166
167 async fn disconnect(&self) -> NetworkResult<()>;
169
170 async fn probe_alive(&self, _timeout: Duration) -> NetworkResult<()> {
176 if self.is_connected() {
177 Ok(())
178 } else {
179 Err(NetworkError::ConnectionError(
180 "Signaling client is not connected".to_string(),
181 ))
182 }
183 }
184
185 async fn send_register_request(
188 &self,
189 request: RegisterRequest,
190 ) -> NetworkResult<RegisterResponse>;
191
192 async fn send_unregister_request(
197 &self,
198 actor_id: ActrId,
199 credential: AIdCredential,
200 reason: Option<String>,
201 ) -> NetworkResult<UnregisterResponse>;
202
203 async fn send_heartbeat(
206 &self,
207 actor_id: ActrId,
208 credential: AIdCredential,
209 availability: ServiceAvailabilityState,
210 power_reserve: f32,
211 mailbox_backlog: f32,
212 ) -> NetworkResult<Pong>;
213
214 async fn send_route_candidates_request(
216 &self,
217 actor_id: ActrId,
218 credential: AIdCredential,
219 request: RouteCandidatesRequest,
220 ) -> NetworkResult<RouteCandidatesResponse>;
221
222 async fn get_signing_key(
227 &self,
228 actor_id: ActrId,
229 credential: AIdCredential,
230 key_id: u32,
231 ) -> NetworkResult<(u32, Vec<u8>)>;
232
233 async fn send_credential_update_request(
238 &self,
239 actor_id: ActrId,
240 credential: AIdCredential,
241 ) -> NetworkResult<RegisterResponse>;
242
243 async fn send_envelope(&self, envelope: SignalingEnvelope) -> NetworkResult<()>;
245
246 async fn receive_envelope(&self) -> NetworkResult<Option<SignalingEnvelope>>;
248
249 fn is_connected(&self) -> bool;
251
252 fn get_stats(&self) -> SignalingStats;
254 fn subscribe_events(&self) -> broadcast::Receiver<SignalingEvent>;
256
257 async fn set_actor_id(&self, actor_id: ActrId);
259 async fn set_credential_state(&self, credential_state: CredentialState);
260
261 async fn clear_identity(&self);
268
269 fn set_hook_callback(&self, _cb: HookCallback) {}
273}
274
/// Coarse connection state of a signaling client.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ConnectionState {
    /// No usable connection.
    Disconnected,
    /// A connection is established.
    Connected,
}
281
282#[derive(Debug, Clone)]
288pub enum SignalingEvent {
289 ConnectStart { attempt: u32 },
291 Connected,
293 Disconnected { reason: DisconnectReason },
295}
296
/// Why a signaling connection transitioned to disconnected.
#[derive(Clone, Debug)]
pub enum DisconnectReason {
    /// The receive loop ended (stream closed or a receive error occurred).
    StreamEnded,
    /// No inbound traffic was seen within the pong timeout.
    PongTimeout,
    /// A ping could not be sent (send error, send timeout, or no socket).
    PingSendFailed,
    /// The credential expired.
    /// NOTE(review): not produced in this part of the file.
    CredentialExpired,
    /// The disconnect was requested through the client's own disconnect path.
    Manual,
    /// Connecting failed; the payload carries a description.
    /// NOTE(review): not produced in this part of the file.
    ConnectionFailed(String),
}
313
314#[derive(Clone, Debug)]
323pub enum HookEvent {
324 SignalingConnectStart {
326 attempt: u32,
327 },
328 SignalingConnected,
329 SignalingDisconnected,
330 WebRtcConnectStart {
332 peer_id: ActrId,
333 },
334 WebRtcConnected {
335 peer_id: ActrId,
336 relayed: bool,
337 },
338 WebRtcDisconnected {
339 peer_id: ActrId,
340 },
341 DataStreamDeliveryUncertain {
342 stream_id: String,
343 session_id: u64,
344 reason: String,
345 },
346 WebSocketConnectStart {
348 peer_id: ActrId,
349 },
350 WebSocketConnected {
351 peer_id: ActrId,
352 },
353 WebSocketDisconnected {
354 peer_id: ActrId,
355 },
356 CredentialRenewed {
358 new_expiry: std::time::SystemTime,
359 },
360 CredentialExpiring {
361 new_expiry: std::time::SystemTime,
362 },
363 MailboxBackpressure {
365 queue_len: usize,
366 threshold: usize,
367 },
368}
369
370pub type HookCallback =
375 Arc<dyn Fn(HookEvent) -> Pin<Box<dyn Future<Output = ()> + Send>> + Send + Sync>;
376
/// Who asked for a connection attempt.
#[derive(Clone, Copy, Debug)]
enum ConnectIntent {
    /// The attempt was requested directly (not by the reconnect cycle).
    Explicit,
    /// The attempt belongs to an auto-reconnect cycle that observed the given
    /// reconnect `generation` when it started.
    AutoReconnect { generation: u64 },
}
382
383pub struct WebSocketSignalingClient {
385 config: SignalingConfig,
386 actor_id: tokio::sync::Mutex<Option<ActrId>>,
387 credential_state: tokio::sync::Mutex<Option<CredentialState>>,
388 ws_sink: WsSink,
390 ws_stream: tokio::sync::Mutex<
392 Option<futures_util::stream::SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>>,
393 >,
394 connected: Arc<AtomicBool>,
396 connecting: Arc<AtomicBool>,
398 stats: Arc<AtomicSignalingStats>,
400 envelope_counter: tokio::sync::Mutex<u64>,
402 pending_replies: Arc<tokio::sync::Mutex<HashMap<String, oneshot::Sender<SignalingEnvelope>>>>,
404 pending_pongs: Arc<tokio::sync::Mutex<HashMap<Vec<u8>, oneshot::Sender<()>>>>,
406 probe_counter: AtomicU64,
408 inbound_rx: Arc<tokio::sync::Mutex<mpsc::UnboundedReceiver<SignalingEnvelope>>>,
410 inbound_tx: tokio::sync::Mutex<mpsc::UnboundedSender<SignalingEnvelope>>,
411 receiver_task: Arc<tokio::sync::Mutex<Option<tokio::task::JoinHandle<()>>>>,
413 ping_task: tokio::sync::Mutex<Option<tokio::task::JoinHandle<()>>>,
415 event_tx: broadcast::Sender<SignalingEvent>,
417 last_pong: Arc<AtomicU64>,
419 reconnector_started: Arc<AtomicBool>,
421 reconnect_notify: Arc<tokio::sync::Notify>,
423 auto_reconnect_suppressed: AtomicBool,
425 reconnect_generation: AtomicU64,
427 hook_callback: OnceLock<HookCallback>,
429}
430
431impl WebSocketSignalingClient {
432 pub fn new(config: SignalingConfig) -> Self {
434 let (inbound_tx, inbound_rx) = mpsc::unbounded_channel();
435 let (event_tx, _event_rx) = broadcast::channel(64);
436 Self {
437 config,
438 actor_id: tokio::sync::Mutex::new(None),
439 credential_state: tokio::sync::Mutex::new(None),
440 ws_sink: Arc::new(tokio::sync::Mutex::new(None)),
441 ws_stream: tokio::sync::Mutex::new(None),
442 connected: Arc::new(AtomicBool::new(false)),
443 connecting: Arc::new(AtomicBool::new(false)),
444 stats: Arc::new(AtomicSignalingStats::default()),
445 envelope_counter: tokio::sync::Mutex::new(0),
446 pending_replies: Arc::new(tokio::sync::Mutex::new(HashMap::new())),
447 pending_pongs: Arc::new(tokio::sync::Mutex::new(HashMap::new())),
448 probe_counter: AtomicU64::new(0),
449 inbound_rx: Arc::new(tokio::sync::Mutex::new(inbound_rx)),
450 inbound_tx: tokio::sync::Mutex::new(inbound_tx),
451 receiver_task: Arc::new(tokio::sync::Mutex::new(None)),
452 ping_task: tokio::sync::Mutex::new(None),
453 event_tx,
454 last_pong: Arc::new(AtomicU64::new(0)),
455 reconnector_started: Arc::new(AtomicBool::new(false)),
456 reconnect_notify: Arc::new(tokio::sync::Notify::new()),
457 auto_reconnect_suppressed: AtomicBool::new(false),
458 reconnect_generation: AtomicU64::new(0),
459 hook_callback: OnceLock::new(),
460 }
461 }
462
463 async fn invoke_hook(&self, event: HookEvent) {
470 if let Some(cb) = self.hook_callback.get() {
471 cb(event).await;
472 }
473 }
474
475 async fn publish_disconnected_transition(
476 was_connected: bool,
477 stats: &Arc<AtomicSignalingStats>,
478 event_tx: &broadcast::Sender<SignalingEvent>,
479 hook_callback: Option<HookCallback>,
480 reason: DisconnectReason,
481 reconnect_notify: Option<&Arc<tokio::sync::Notify>>,
482 ) -> bool {
483 if !was_connected {
484 return false;
485 }
486
487 stats.disconnections.fetch_add(1, Ordering::Relaxed);
488
489 if let Some(cb) = hook_callback {
490 cb(HookEvent::SignalingDisconnected).await;
491 }
492
493 let _ = event_tx.send(SignalingEvent::Disconnected { reason });
494
495 if let Some(notify) = reconnect_notify {
496 notify.notify_one();
497 }
498
499 true
500 }
501
502 pub fn start_reconnect_manager(self: &Arc<Self>) {
503 if !self.config.reconnect_config.enabled {
504 return;
505 }
506 if self
507 .reconnector_started
508 .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
509 .is_err()
510 {
511 return; }
513
514 tracing::info!("🔄 Starting reconnect manager for signaling client");
515
516 let client = self.clone();
517 let notify = self.reconnect_notify.clone();
518
519 tokio::spawn(async move {
520 loop {
521 notify.notified().await;
523
524 if !client.config.reconnect_config.enabled {
525 break;
526 }
527
528 client.run_reconnect_cycle().await;
530 }
531 });
532 }
533
534 async fn run_reconnect_cycle(self: &Arc<Self>) {
536 use actr_framework::ExponentialBackoff;
537
538 let cfg = &self.config.reconnect_config;
539 let generation = self.reconnect_generation.load(Ordering::Acquire);
540
541 if self.auto_reconnect_cancelled(generation) {
542 tracing::debug!("Skipping signaling auto-reconnect cycle after explicit disconnect");
543 return;
544 }
545
546 if self.connected.load(Ordering::Acquire) {
547 tracing::debug!("🔎 Probing connected signaling before reconnect cycle");
548 match self
549 .probe_alive(Duration::from_secs(PONG_TIMEOUT_SECS))
550 .await
551 {
552 Ok(()) => {
553 tracing::debug!("Signaling probe succeeded, skipping reconnect cycle");
554 return;
555 }
556 Err(e) => {
557 tracing::warn!("Signaling probe failed before reconnect: {e}");
558 if let Err(disconnect_err) = self.disconnect_internal(false).await {
559 tracing::warn!(
560 "⚠️ Disconnect cleanup failed after failed probe (non-fatal): {disconnect_err}"
561 );
562 }
563 }
564 }
565 }
566
567 let backoff = ExponentialBackoff::builder()
568 .initial_delay(std::time::Duration::from_secs(cfg.initial_delay.max(1)))
569 .max_delay(std::time::Duration::from_secs(cfg.max_delay.max(1)))
570 .max_retries(cfg.max_attempts)
571 .with_jitter()
572 .build();
573
574 let mut attempt: u32 = 0;
575
576 for delay in backoff {
577 if self.auto_reconnect_cancelled(generation) {
578 tracing::debug!(
579 "Stopping signaling auto-reconnect cycle after explicit disconnect"
580 );
581 return;
582 }
583
584 if self.connected.load(Ordering::Acquire) {
585 tracing::debug!("Already connected, aborting reconnect cycle");
586 return;
587 }
588
589 attempt += 1;
590 let _ = self.event_tx.send(SignalingEvent::ConnectStart { attempt });
591
592 match self.connect_once_for_auto_reconnect(generation).await {
593 Ok(()) => {
594 tracing::info!("✅ Signaling reconnect succeeded on attempt {attempt}");
595 return;
596 }
597 Err(e) => {
598 if self.auto_reconnect_cancelled(generation) {
599 tracing::debug!(
600 "Stopping signaling auto-reconnect cycle after explicit disconnect"
601 );
602 return;
603 }
604
605 tracing::warn!(
606 "❌ Reconnect attempt {attempt} failed: {e}, retrying in {delay:?}"
607 );
608 tokio::select! {
609 _ = tokio::time::sleep(delay) => {}
610 _ = self.reconnect_notify.notified() => {
611 tracing::debug!("Explicit reconnect request interrupted reconnect backoff");
612 }
613 }
614 if self.auto_reconnect_cancelled(generation) {
615 tracing::debug!(
616 "Stopping signaling auto-reconnect cycle after explicit disconnect"
617 );
618 return;
619 }
620 }
621 }
622 }
623
624 tracing::error!("Reconnect failed after {attempt} attempts, entering cooldown");
626 let cooldown = std::time::Duration::from_secs(cfg.max_delay.max(1) * 2);
627 tokio::select! {
628 _ = tokio::time::sleep(cooldown) => {}
629 _ = self.reconnect_notify.notified() => {
630 tracing::debug!("Explicit reconnect request interrupted reconnect cooldown");
631 }
632 }
633 if self.auto_reconnect_cancelled(generation) {
634 tracing::debug!(
635 "Signaling auto-reconnect cooldown ended after explicit disconnect suppression"
636 );
637 }
638 }
640
/// Test helper: creates a client for `url` with fixed settings (5s connect
/// timeout, default reconnect policy, no auth, no WebRTC role), starts its
/// reconnect manager and connects.
///
/// Fails if `url` cannot be parsed or the connection cannot be established.
#[cfg(feature = "test-utils")]
pub async fn connect_to(url: &str) -> NetworkResult<Arc<Self>> {
    let server_url: Url = url.parse()?;

    let client = Arc::new(Self::new(SignalingConfig {
        server_url,
        connection_timeout: 5,
        heartbeat_interval: 30,
        reconnect_config: ReconnectConfig::default(),
        auth_config: None,
        webrtc_role: None,
    }));

    client.start_reconnect_manager();
    client.connect().await?;
    Ok(client)
}
664
/// Test helper: like `connect_to`, but installs `actor_id` and
/// `credential_state` on the client before the reconnect manager is started
/// and the connection is made.
///
/// Fails if `url` cannot be parsed or the connection cannot be established.
#[cfg(feature = "test-utils")]
pub async fn connect_to_with_identity(
    url: &str,
    actor_id: ActrId,
    credential_state: CredentialState,
) -> NetworkResult<Arc<Self>> {
    let server_url: Url = url.parse()?;

    let client = Arc::new(Self::new(SignalingConfig {
        server_url,
        connection_timeout: 5,
        heartbeat_interval: 30,
        reconnect_config: ReconnectConfig::default(),
        auth_config: None,
        webrtc_role: None,
    }));

    // Identity goes in first so the connect below already uses it.
    client.set_actor_id(actor_id).await;
    client.set_credential_state(credential_state).await;

    client.start_reconnect_manager();
    client.connect().await?;
    Ok(client)
}
692
693 async fn next_envelope_id(&self) -> String {
695 let mut counter = self.envelope_counter.lock().await;
696 *counter += 1;
697 format!("env-{}", *counter)
698 }
699
700 async fn create_envelope(&self, flow: signaling_envelope::Flow) -> SignalingEnvelope {
702 SignalingEnvelope {
703 envelope_version: 1,
704 envelope_id: self.next_envelope_id().await,
705 reply_for: None,
706 timestamp: prost_types::Timestamp {
707 seconds: chrono::Utc::now().timestamp(),
708 nanos: 0,
709 },
710 traceparent: None,
711 tracestate: None,
712 flow: Some(flow),
713 }
714 }
715
716 async fn reset_inbound_channel(&self) {
718 self.drop_pending_replies("inbound channel reset").await;
719 self.drop_pending_pongs("inbound channel reset").await;
720
721 let (tx, rx) = mpsc::unbounded_channel();
722 *self.inbound_tx.lock().await = tx;
723 *self.inbound_rx.lock().await = rx;
724 }
725
726 async fn drop_pending_replies(&self, reason: &'static str) {
727 let dropped = {
728 let mut pending = self.pending_replies.lock().await;
729 let dropped = pending.len();
730 pending.clear();
731 dropped
732 };
733
734 if dropped > 0 {
735 tracing::debug!(reason, dropped, "Dropping pending signaling reply waiters");
736 }
737 }
738
739 async fn drop_pending_pongs(&self, reason: &'static str) {
740 let dropped = {
741 let mut pending = self.pending_pongs.lock().await;
742 let dropped = pending.len();
743 pending.clear();
744 dropped
745 };
746
747 if dropped > 0 {
748 tracing::debug!(reason, dropped, "Dropping pending signaling pong waiters");
749 }
750 }
751
752 async fn build_url_with_identity(&self) -> Url {
757 let mut url = self.config.server_url.clone();
758 let actor_id_opt = self.actor_id.lock().await.clone();
759 if let Some(actor_id) = actor_id_opt {
760 let actor_str = actr_protocol::ActrId::to_string_repr(&actor_id);
761 url.query_pairs_mut().append_pair("actor_id", &actor_str);
762 }
763
764 let cred_state_opt = self.credential_state.lock().await.clone();
766 if let Some(cred_state) = cred_state_opt {
767 let cred = cred_state.credential().await;
768 let claims_b64 = base64::engine::general_purpose::STANDARD.encode(&cred.claims);
769 let sig_b64 = base64::engine::general_purpose::STANDARD.encode(&cred.signature);
770 url.query_pairs_mut()
771 .append_pair("key_id", &cred.key_id.to_string())
772 .append_pair("claims", &claims_b64)
773 .append_pair("signature", &sig_b64);
774 }
775
776 if let Some(role) = &self.config.webrtc_role {
778 url.query_pairs_mut().append_pair("webrtc_role", role);
779 }
780
781 url
782 }
783
784 fn redact_signaling_url_for_log(url: &Url) -> String {
785 let mut redacted = url.clone();
786 let pairs: Vec<(String, String)> = redacted
787 .query_pairs()
788 .map(|(key, value)| {
789 let redacted_value = match key.to_ascii_lowercase().as_str() {
790 "claims" | "signature" | "token" | "authorization" | "bearer"
791 | "access_token" | "api_key" => "REDACTED".to_string(),
792 _ => value.into_owned(),
793 };
794 (key.into_owned(), redacted_value)
795 })
796 .collect();
797
798 redacted.set_query(None);
799 if !pairs.is_empty() {
800 let mut query = redacted.query_pairs_mut();
801 for (key, value) in pairs {
802 query.append_pair(&key, &value);
803 }
804 }
805
806 redacted.to_string()
807 }
808
809 fn auto_reconnect_cancelled(&self, generation: u64) -> bool {
810 self.auto_reconnect_suppressed.load(Ordering::Acquire)
811 || self.reconnect_generation.load(Ordering::Acquire) != generation
812 }
813
814 async fn establish_connection_once(&self) -> NetworkResult<()> {
818 self.establish_connection_once_with_intent(ConnectIntent::Explicit)
819 .await
820 }
821
822 async fn establish_connection_once_for_auto_reconnect(
823 &self,
824 generation: u64,
825 ) -> NetworkResult<()> {
826 self.establish_connection_once_with_intent(ConnectIntent::AutoReconnect { generation })
827 .await
828 }
829
830 async fn establish_connection_once_with_intent(
831 &self,
832 intent: ConnectIntent,
833 ) -> NetworkResult<()> {
834 if self.connected.load(Ordering::Acquire) {
836 tracing::debug!("Connection already established, skipping establish_connection_once()");
837 return Ok(());
838 }
839
840 let url = self.build_url_with_identity().await;
841 let timeout_secs = self.config.connection_timeout;
842 tracing::debug!(
843 "Establishing connection to URL: {}",
844 Self::redact_signaling_url_for_log(&url)
845 );
846 let config = WebSocketConfig::default().write_buffer_size(0);
848 let connect_result = if timeout_secs == 0 {
850 connect_async_with_config(url.as_str(), Some(config), false).await
851 } else {
852 let timeout_duration = std::time::Duration::from_secs(timeout_secs);
853 tokio::time::timeout(
854 timeout_duration,
855 connect_async_with_config(url.as_str(), Some(config), false),
856 )
857 .await
858 .map_err(|_| {
859 NetworkError::ConnectionError(format!(
860 "Signaling connect timeout after {}s",
861 timeout_secs
862 ))
863 })?
864 }?;
865
866 let (ws_stream, _) = connect_result;
867
868 let (sink, stream) = ws_stream.split();
870
871 if let ConnectIntent::AutoReconnect { generation } = intent
872 && self.auto_reconnect_cancelled(generation)
873 {
874 tracing::debug!(
875 generation,
876 "Discarding completed signaling auto-reconnect after explicit disconnect"
877 );
878 let mut sink = sink;
879 if let Err(e) = sink.close().await {
880 tracing::warn!(
881 "Signaling auto-reconnect socket close failed after cancellation: {}",
882 e
883 );
884 }
885 return Err(NetworkError::ConnectionError(
886 "Signaling auto-reconnect was cancelled by explicit disconnect".to_string(),
887 ));
888 }
889
890 *self.ws_sink.lock().await = Some(sink);
891 *self.ws_stream.lock().await = Some(stream);
892 self.connected.store(true, Ordering::Release);
893 self.auto_reconnect_suppressed
894 .store(false, Ordering::Release);
895 self.last_pong.store(current_unix_secs(), Ordering::Release);
896 self.invoke_hook(HookEvent::SignalingConnected).await;
898 let _ = self.event_tx.send(SignalingEvent::Connected);
899
900 self.stats.connections.fetch_add(1, Ordering::Relaxed);
901
902 Ok(())
903 }
904
905 async fn connect_with_retries(&self) -> NetworkResult<()> {
907 use actr_framework::ExponentialBackoff;
908
909 let cfg = &self.config.reconnect_config;
910
911 if !cfg.enabled {
913 return self.establish_connection_once().await;
914 }
915
916 let backoff = ExponentialBackoff::builder()
917 .initial_delay(std::time::Duration::from_secs(cfg.initial_delay.max(1)))
918 .max_delay(std::time::Duration::from_secs(cfg.max_delay.max(1)))
919 .max_retries(cfg.max_attempts)
920 .with_jitter()
921 .build();
922
923 let mut last_err = None;
924
925 for (attempt, delay) in std::iter::once(std::time::Duration::ZERO)
927 .chain(backoff)
928 .enumerate()
929 {
930 let attempt = attempt as u32 + 1;
931 self.invoke_hook(HookEvent::SignalingConnectStart { attempt })
932 .await;
933 if delay > std::time::Duration::ZERO {
934 tracing::info!("Retry signaling connect after {delay:?} (attempt {attempt})");
935 tokio::select! {
936 _ = tokio::time::sleep(delay) => {}
937 _ = self.reconnect_notify.notified() => {
938 tracing::debug!("Explicit reconnect request interrupted signaling connect backoff");
939 }
940 }
941 }
942
943 match self.establish_connection_once().await {
944 Ok(()) => return Ok(()),
945 Err(e) => {
946 tracing::warn!("Signaling connect attempt {attempt} failed: {e:?}");
947 last_err = Some(e);
948 }
949 }
950 }
951
952 let total = cfg.max_attempts + 1; tracing::error!("Signaling connect failed after {total} attempts, giving up");
954 Err(last_err.unwrap_or_else(|| {
955 NetworkError::ConnectionError("All connection attempts failed".to_string())
956 }))
957 }
958
959 #[cfg_attr(
961 feature = "opentelemetry",
962 tracing::instrument(skip_all, fields(envelope_id = %envelope.envelope_id))
963 )]
964 async fn send_envelope_and_wait_response(
965 &self,
966 envelope: SignalingEnvelope,
967 ) -> NetworkResult<SignalingEnvelope> {
968 let reply_for = envelope.envelope_id.clone();
969
970 let (tx, rx) = oneshot::channel();
972 self.pending_replies
973 .lock()
974 .await
975 .insert(reply_for.clone(), tx);
976
977 if let Err(e) = self.send_envelope(envelope).await {
978 self.pending_replies.lock().await.remove(&reply_for);
980 return Err(e);
981 }
982
983 let result =
984 tokio::time::timeout(std::time::Duration::from_secs(RESPONSE_TIMEOUT_SECS), rx).await;
985 if result.is_err() {
987 self.pending_replies.lock().await.remove(&reply_for);
988 }
989
990 let response_envelope = result
991 .map_err(|_| {
992 NetworkError::ConnectionError(
993 "Timed out waiting for signaling response".to_string(),
994 )
995 })?
996 .map_err(|_| {
997 NetworkError::ConnectionError(
998 "Receiver dropped while waiting for signaling response".to_string(),
999 )
1000 })?;
1001
1002 Ok(response_envelope)
1003 }
1004
1005 async fn start_receiver(&self) {
1007 let mut stream_guard = self.ws_stream.lock().await;
1008 if stream_guard.is_none() {
1009 return;
1010 }
1011
1012 let mut stream = stream_guard.take().expect("stream exists");
1013 let pending = self.pending_replies.clone();
1014 let inbound_tx = { self.inbound_tx.lock().await.clone() };
1015 let stats = self.stats.clone();
1016 let connected = self.connected.clone();
1017 let event_tx = self.event_tx.clone();
1018 let last_pong = self.last_pong.clone();
1019 let pending_pongs = self.pending_pongs.clone();
1020 let reconnect_notify = self.reconnect_notify.clone();
1021 let reconnect_enabled = self.config.reconnect_config.enabled;
1022 let hook_callback = self.hook_callback.get().cloned();
1023 let handle = tokio::spawn(async move {
1024 while let Some(msg) = stream.next().await {
1025 match msg {
1026 Ok(tokio_tungstenite::tungstenite::Message::Binary(data)) => {
1027 last_pong.store(current_unix_secs(), Ordering::Release);
1029 match SignalingEnvelope::decode(&data[..]) {
1030 Ok(envelope) => {
1031 #[cfg(feature = "opentelemetry")]
1032 let span = {
1033 let span = tracing::info_span!("signaling.receive_envelope", envelope_id = %envelope.envelope_id);
1034 span.set_parent(extract_trace_context(&envelope));
1035 span
1036 };
1037
1038 stats.messages_received.fetch_add(1, Ordering::Relaxed);
1039 tracing::debug!("Received message: {:?}", envelope);
1040 if let Some(reply_for) = envelope.reply_for.clone() {
1041 if let Some(sender) = pending.lock().await.remove(&reply_for) {
1042 #[cfg(feature = "opentelemetry")]
1043 let _ = span.enter();
1044 if let Err(e) = sender.send(envelope) {
1045 stats.errors.fetch_add(1, Ordering::Relaxed);
1046 tracing::warn!(
1047 "Failed to send reply envelope to waiter: {e:?}",
1048 );
1049 }
1050 continue;
1051 }
1052 }
1053 tracing::debug!(
1054 "Unmatched or push message -> forward to inbound channel"
1055 );
1056 if let Err(e) = inbound_tx.send(envelope) {
1058 stats.errors.fetch_add(1, Ordering::Relaxed);
1059 tracing::warn!(
1060 "Failed to send envelope to inbound channel: {e:?}"
1061 );
1062 }
1063 }
1064 Err(e) => {
1065 stats.errors.fetch_add(1, Ordering::Relaxed);
1066 tracing::warn!("Failed to decode SignalingEnvelope: {e}");
1067 }
1068 }
1069 }
1070 Ok(tokio_tungstenite::tungstenite::Message::Pong(payload)) => {
1071 tracing::debug!("Received pong");
1072 last_pong.store(current_unix_secs(), Ordering::Release);
1073 if let Some(sender) = pending_pongs.lock().await.remove(&payload.to_vec()) {
1074 let _ = sender.send(());
1075 }
1076 }
1077 Ok(tokio_tungstenite::tungstenite::Message::Ping(_)) => {
1078 tracing::debug!("Received ping");
1079 last_pong.store(current_unix_secs(), Ordering::Release);
1080 }
1081 Ok(other) => {
1082 tracing::warn!("Received non-binary frame, ignoring: {other:?}");
1083 }
1084 Err(e) => {
1085 stats.errors.fetch_add(1, Ordering::Relaxed);
1086 tracing::error!("Signaling receive error: {e}");
1087 break;
1088 }
1089 }
1090 }
1091
1092 tracing::warn!("Stream terminated");
1093 let was_connected = connected.swap(false, Ordering::AcqRel);
1097 Self::publish_disconnected_transition(
1098 was_connected,
1099 &stats,
1100 &event_tx,
1101 hook_callback,
1102 DisconnectReason::StreamEnded,
1103 reconnect_enabled.then_some(&reconnect_notify),
1104 )
1105 .await;
1106 pending_pongs.lock().await.clear();
1107 });
1108
1109 *self.receiver_task.lock().await = Some(handle);
1110 }
1111
1112 async fn start_ping_task(&self) {
1115 let mut existing = self.ping_task.lock().await;
1116 if let Some(handle) = existing.as_ref() {
1117 if handle.is_finished() {
1118 existing.take();
1119 } else {
1120 return;
1121 }
1122 }
1123
1124 let sink = self.ws_sink.clone();
1125 let connected = self.connected.clone();
1126 let stats = self.stats.clone();
1127 let event_tx = self.event_tx.clone();
1128 let last_pong = self.last_pong.clone();
1129 let receiver_task_clone = Arc::clone(&self.receiver_task);
1130 let reconnect_notify = self.reconnect_notify.clone();
1131 let reconnect_enabled = self.config.reconnect_config.enabled;
1132 let hook_callback = self.hook_callback.get().cloned();
1133
1134 let handle = tokio::spawn(async move {
1135 loop {
1136 tokio::time::sleep(std::time::Duration::from_secs(PING_INTERVAL_SECS)).await;
1137
1138 if !connected.load(Ordering::Acquire) {
1139 break;
1140 }
1141
1142 let mut disconnect_reason = None;
1144 {
1145 let mut sink_guard = sink.lock().await;
1146 if let Some(sink) = sink_guard.as_mut() {
1147 match tokio::time::timeout(
1148 std::time::Duration::from_secs(SIGNALING_SEND_TIMEOUT_SECS),
1149 sink.send(tokio_tungstenite::tungstenite::Message::Ping(
1150 Vec::new().into(),
1151 )),
1152 )
1153 .await
1154 {
1155 Ok(Ok(())) => {}
1156 Ok(Err(e)) => {
1157 tracing::warn!("Signaling ping send failed: {e}");
1158 disconnect_reason = Some(DisconnectReason::PingSendFailed);
1159 }
1160 Err(_) => {
1161 tracing::warn!("Signaling ping send timed out");
1162 disconnect_reason = Some(DisconnectReason::PingSendFailed);
1163 }
1164 }
1165 } else {
1166 tracing::warn!("Signaling not connected");
1167 disconnect_reason = Some(DisconnectReason::PingSendFailed);
1168 }
1169 }
1170
1171 if let Some(reason) = disconnect_reason {
1172 let was_connected = connected.swap(false, Ordering::AcqRel);
1173 Self::publish_disconnected_transition(
1174 was_connected,
1175 &stats,
1176 &event_tx,
1177 hook_callback.clone(),
1178 reason,
1179 reconnect_enabled.then_some(&reconnect_notify),
1180 )
1181 .await;
1182 break;
1183 }
1184
1185 let now = current_unix_secs();
1187 let last = last_pong.load(Ordering::Acquire);
1188 if now.saturating_sub(last) > PONG_TIMEOUT_SECS {
1189 tracing::warn!(
1190 "Signaling pong timeout (last seen {}s ago), marking disconnected",
1191 now.saturating_sub(last)
1192 );
1193 if let Some(handle) = receiver_task_clone.lock().await.take() {
1194 handle.abort();
1195 }
1196 let was_connected = connected.swap(false, Ordering::AcqRel);
1197 Self::publish_disconnected_transition(
1198 was_connected,
1199 &stats,
1200 &event_tx,
1201 hook_callback.clone(),
1202 DisconnectReason::PongTimeout,
1203 reconnect_enabled.then_some(&reconnect_notify),
1204 )
1205 .await;
1206 break;
1207 }
1208 }
1209 });
1210
1211 *existing = Some(handle);
1212 }
1213
1214 async fn disconnect_internal(&self, suppress_auto_reconnect: bool) -> NetworkResult<()> {
1215 if suppress_auto_reconnect {
1216 self.reconnect_generation.fetch_add(1, Ordering::AcqRel);
1217 self.auto_reconnect_suppressed
1218 .store(true, Ordering::Release);
1219 self.reconnect_notify.notify_waiters();
1220 }
1221
1222 self.drop_pending_replies("signaling disconnect").await;
1223 self.drop_pending_pongs("signaling disconnect").await;
1224 let was_connected = self.connected.swap(false, Ordering::AcqRel);
1225
1226 let ping_handle = match tokio::time::timeout(
1231 std::time::Duration::from_secs(DISCONNECT_LOCK_TIMEOUT_SECS),
1232 self.ping_task.lock(),
1233 )
1234 .await
1235 {
1236 Ok(mut task_guard) => task_guard.take(),
1237 Err(_) => {
1238 tracing::warn!("Timed out waiting for signaling ping task lock during disconnect");
1239 None
1240 }
1241 };
1242 if let Some(handle) = ping_handle {
1243 handle.abort();
1244 }
1245
1246 let receiver_handle = match tokio::time::timeout(
1247 std::time::Duration::from_secs(DISCONNECT_LOCK_TIMEOUT_SECS),
1248 self.receiver_task.lock(),
1249 )
1250 .await
1251 {
1252 Ok(mut task_guard) => task_guard.take(),
1253 Err(_) => {
1254 tracing::warn!(
1255 "Timed out waiting for signaling receiver task lock during disconnect"
1256 );
1257 None
1258 }
1259 };
1260 if let Some(handle) = receiver_handle {
1261 handle.abort();
1262 }
1263
1264 let sink = match tokio::time::timeout(
1268 std::time::Duration::from_secs(DISCONNECT_LOCK_TIMEOUT_SECS),
1269 self.ws_sink.lock(),
1270 )
1271 .await
1272 {
1273 Ok(mut sink_guard) => sink_guard.take(),
1274 Err(_) => {
1275 tracing::warn!(
1276 "Timed out waiting for signaling WebSocket sink lock during disconnect"
1277 );
1278 None
1279 }
1280 };
1281
1282 if let Some(mut sink) = sink {
1283 match tokio::time::timeout(
1284 std::time::Duration::from_secs(DISCONNECT_CLOSE_TIMEOUT_SECS),
1285 sink.close(),
1286 )
1287 .await
1288 {
1289 Ok(Ok(())) => {}
1290 Ok(Err(e)) => {
1291 tracing::warn!("Signaling WebSocket close failed during disconnect: {}", e);
1292 }
1293 Err(_) => {
1294 tracing::warn!(
1295 "Signaling WebSocket close timed out during disconnect; continuing cleanup"
1296 );
1297 }
1298 }
1299 }
1300
1301 match tokio::time::timeout(
1302 std::time::Duration::from_secs(DISCONNECT_LOCK_TIMEOUT_SECS),
1303 self.ws_stream.lock(),
1304 )
1305 .await
1306 {
1307 Ok(mut stream_guard) => {
1308 stream_guard.take();
1309 }
1310 Err(_) => {
1311 tracing::warn!(
1312 "Timed out waiting for signaling WebSocket stream lock during disconnect"
1313 );
1314 }
1315 }
1316
1317 self.reset_inbound_channel().await;
1318
1319 Self::publish_disconnected_transition(
1321 was_connected,
1322 &self.stats,
1323 &self.event_tx,
1324 self.hook_callback.get().cloned(),
1325 DisconnectReason::Manual,
1326 None,
1327 )
1328 .await;
1329
1330 Ok(())
1331 }
1332
1333 async fn connect_once_for_auto_reconnect(&self, generation: u64) -> NetworkResult<()> {
1334 if self.auto_reconnect_cancelled(generation) {
1335 return Err(NetworkError::ConnectionError(
1336 "Signaling auto-reconnect was cancelled".to_string(),
1337 ));
1338 }
1339
1340 if self.connected.load(Ordering::Acquire) {
1341 tracing::debug!("Already connected, skipping auto-reconnect connect_once()");
1342 return Ok(());
1343 }
1344
1345 match self
1346 .connecting
1347 .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
1348 {
1349 Ok(_) => {}
1350 Err(_) => {
1351 if self.connected.load(Ordering::Acquire) {
1352 tracing::debug!("Already connected, skipping auto-reconnect connect_once()");
1353 return Ok(());
1354 }
1355
1356 tracing::debug!(
1357 "Another connection attempt in progress, waiting for state change..."
1358 );
1359 let result = self.wait_for_connection_result().await;
1360 if self.auto_reconnect_cancelled(generation) {
1361 return Err(NetworkError::ConnectionError(
1362 "Signaling auto-reconnect was cancelled".to_string(),
1363 ));
1364 }
1365 return result;
1366 }
1367 }
1368
1369 if self.auto_reconnect_cancelled(generation) {
1370 self.connecting.store(false, Ordering::Release);
1371 return Err(NetworkError::ConnectionError(
1372 "Signaling auto-reconnect was cancelled".to_string(),
1373 ));
1374 }
1375
1376 if self.connected.load(Ordering::Acquire) {
1377 tracing::debug!("Connection completed by another task while acquiring lock");
1378 self.connecting.store(false, Ordering::Release);
1379 return Ok(());
1380 }
1381
1382 tracing::debug!(
1383 generation,
1384 "Acquired connection lock, establishing one auto-reconnect signaling attempt..."
1385 );
1386
1387 let result = self
1388 .establish_connection_once_for_auto_reconnect(generation)
1389 .await;
1390 self.connecting.store(false, Ordering::Release);
1391
1392 match result {
1393 Ok(()) => {
1394 if self.auto_reconnect_cancelled(generation) {
1395 self.disconnect_internal(false).await?;
1396 return Err(NetworkError::ConnectionError(
1397 "Signaling auto-reconnect was cancelled".to_string(),
1398 ));
1399 }
1400 self.start_receiver().await;
1401 self.start_ping_task().await;
1402 Ok(())
1403 }
1404 Err(e) => {
1405 if !self.auto_reconnect_cancelled(generation) {
1406 let _ = self.event_tx.send(SignalingEvent::Disconnected {
1407 reason: DisconnectReason::ConnectionFailed(e.to_string()),
1408 });
1409 tracing::error!("Connection attempt failed: {e}");
1410 }
1411 Err(e)
1412 }
1413 }
1414 }
1415
1416 async fn wait_for_connection_result(&self) -> NetworkResult<()> {
1420 let mut event_rx = self.event_tx.subscribe();
1421 let deadline = tokio::time::Instant::now()
1422 + std::time::Duration::from_secs(CONCURRENT_CONNECT_WAIT_TIMEOUT_SECS);
1423
1424 loop {
1425 tokio::select! {
1426 _ = tokio::time::sleep_until(deadline) => {
1427 if self.connected.load(Ordering::Acquire) {
1429 tracing::debug!("Connection succeeded just before timeout");
1430 return Ok(());
1431 }
1432 return Err(NetworkError::ConnectionError(
1433 "Timeout waiting for concurrent connection attempt".to_string(),
1434 ));
1435 }
1436 result = event_rx.recv() => {
1437 match result {
1438 Ok(SignalingEvent::Connected) => {
1439 tracing::debug!("Connection established by another task");
1440 return Ok(());
1441 }
1442 Ok(SignalingEvent::Disconnected { reason }) => {
1443 return Err(NetworkError::ConnectionError(format!(
1444 "Concurrent signaling connection failed: {reason:?}"
1445 )));
1446 }
1447 Ok(_) => continue, Err(broadcast::error::RecvError::Lagged(n)) => {
1449 tracing::warn!("Event receiver lagged by {n} events");
1450 if self.connected.load(Ordering::Acquire) {
1452 return Ok(());
1453 }
1454 continue;
1455 }
1456 Err(broadcast::error::RecvError::Closed) => {
1457 return Err(NetworkError::ConnectionError(
1458 "Event channel closed while waiting for connection".to_string(),
1459 ));
1460 }
1461 }
1462 }
1463 }
1464 }
1465 }
1466}
1467
1468#[async_trait]
1469impl SignalingClient for WebSocketSignalingClient {
1470 async fn connect(&self) -> NetworkResult<()> {
1471 match self
1476 .connecting
1477 .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
1478 {
1479 Ok(_) => {
1480 }
1483 Err(_) => {
1484 if self.connected.load(Ordering::Acquire) {
1487 tracing::debug!("Already connected, skipping connect()");
1488 return Ok(());
1489 }
1490
1491 tracing::debug!(
1493 "Another connection attempt in progress, waiting for state change..."
1494 );
1495 return self.wait_for_connection_result().await;
1496 }
1497 }
1498
1499 if self.connected.load(Ordering::Acquire) {
1504 tracing::debug!("Connection completed by another task while acquiring lock");
1505 self.connecting.store(false, Ordering::Release);
1506 return Ok(());
1507 }
1508
1509 tracing::debug!("Acquired connection lock, establishing connection...");
1510
1511 let result = self.connect_with_retries().await;
1513
1514 self.connecting.store(false, Ordering::Release);
1516
1517 match result {
1519 Ok(()) => {
1520 self.start_receiver().await;
1521 self.start_ping_task().await;
1522 Ok(())
1523 }
1524 Err(e) => {
1525 let _ = self.event_tx.send(SignalingEvent::Disconnected {
1527 reason: DisconnectReason::ConnectionFailed(e.to_string()),
1528 });
1529 tracing::error!("Connection failed: {e}");
1530 Err(e)
1531 }
1532 }
1533 }
1534
1535 async fn connect_once(&self) -> NetworkResult<()> {
1536 if self.connected.load(Ordering::Acquire) {
1537 tracing::debug!("Already connected, skipping connect_once()");
1538 return Ok(());
1539 }
1540
1541 match self
1542 .connecting
1543 .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
1544 {
1545 Ok(_) => {}
1546 Err(_) => {
1547 if self.connected.load(Ordering::Acquire) {
1548 tracing::debug!("Already connected, skipping connect_once()");
1549 return Ok(());
1550 }
1551
1552 tracing::debug!(
1553 "Another connection attempt in progress, waiting for state change..."
1554 );
1555 return self.wait_for_connection_result().await;
1556 }
1557 }
1558
1559 if self.connected.load(Ordering::Acquire) {
1560 tracing::debug!("Connection completed by another task while acquiring lock");
1561 self.connecting.store(false, Ordering::Release);
1562 return Ok(());
1563 }
1564
1565 tracing::debug!(
1566 "Acquired connection lock, establishing one signaling connection attempt..."
1567 );
1568
1569 let result = self.establish_connection_once().await;
1570 self.connecting.store(false, Ordering::Release);
1571
1572 match result {
1573 Ok(()) => {
1574 self.start_receiver().await;
1575 self.start_ping_task().await;
1576 Ok(())
1577 }
1578 Err(e) => {
1579 let _ = self.event_tx.send(SignalingEvent::Disconnected {
1580 reason: DisconnectReason::ConnectionFailed(e.to_string()),
1581 });
1582 tracing::error!("Connection attempt failed: {e}");
1583 Err(e)
1584 }
1585 }
1586 }
1587
1588 async fn disconnect(&self) -> NetworkResult<()> {
1589 self.disconnect_internal(true).await
1590 }
1591
1592 async fn probe_alive(&self, timeout: Duration) -> NetworkResult<()> {
1593 if !self.connected.load(Ordering::Acquire) {
1594 return Err(NetworkError::ConnectionError(
1595 "Signaling client is not connected".to_string(),
1596 ));
1597 }
1598
1599 let probe_id = self.probe_counter.fetch_add(1, Ordering::Relaxed) + 1;
1600 let payload =
1601 format!("actr-signaling-probe-{probe_id}-{}", current_unix_secs()).into_bytes();
1602 let (tx, rx) = oneshot::channel();
1603 self.pending_pongs.lock().await.insert(payload.clone(), tx);
1604
1605 let send_result = {
1606 let mut sink_guard = self.ws_sink.lock().await;
1607 match sink_guard.as_mut() {
1608 Some(sink) => sink
1609 .send(tokio_tungstenite::tungstenite::Message::Ping(
1610 payload.clone().into(),
1611 ))
1612 .await
1613 .map_err(|e| {
1614 NetworkError::ConnectionError(format!("Signaling probe ping failed: {e}"))
1615 }),
1616 None => Err(NetworkError::ConnectionError(
1617 "Signaling probe failed: WebSocket sink is not available".to_string(),
1618 )),
1619 }
1620 };
1621
1622 if let Err(e) = send_result {
1623 self.pending_pongs.lock().await.remove(&payload);
1624 let was_connected = self.connected.swap(false, Ordering::AcqRel);
1625 Self::publish_disconnected_transition(
1626 was_connected,
1627 &self.stats,
1628 &self.event_tx,
1629 self.hook_callback.get().cloned(),
1630 DisconnectReason::PingSendFailed,
1631 None,
1632 )
1633 .await;
1634 return Err(e);
1635 }
1636
1637 match tokio::time::timeout(timeout, rx).await {
1638 Ok(Ok(())) => {
1639 self.last_pong.store(current_unix_secs(), Ordering::Release);
1640 Ok(())
1641 }
1642 Ok(Err(_)) => {
1643 self.pending_pongs.lock().await.remove(&payload);
1644 Err(NetworkError::ConnectionError(
1645 "Signaling probe pong waiter dropped".to_string(),
1646 ))
1647 }
1648 Err(_) => {
1649 self.pending_pongs.lock().await.remove(&payload);
1650 Err(NetworkError::TimeoutError(format!(
1651 "Timed out waiting for signaling probe pong after {}ms",
1652 timeout.as_millis()
1653 )))
1654 }
1655 }
1656 }
1657
1658 #[cfg_attr(feature = "opentelemetry", tracing::instrument(skip_all))]
1659 async fn send_register_request(
1660 &self,
1661 request: RegisterRequest,
1662 ) -> NetworkResult<RegisterResponse> {
1663 let flow = signaling_envelope::Flow::PeerToServer(PeerToSignaling {
1665 payload: Some(peer_to_signaling::Payload::RegisterRequest(request)),
1666 });
1667
1668 let envelope = self.create_envelope(flow).await;
1669 let response_envelope = self.send_envelope_and_wait_response(envelope).await?;
1670
1671 if let Some(signaling_envelope::Flow::ServerToActr(server_to_actr)) = response_envelope.flow
1672 {
1673 if let Some(signaling_to_actr::Payload::RegisterResponse(response)) =
1674 server_to_actr.payload
1675 {
1676 return Ok(response);
1677 }
1678 }
1679
1680 Err(NetworkError::ConnectionError(
1681 "Invalid registration response".to_string(),
1682 ))
1683 }
1684
1685 #[cfg_attr(
1686 feature = "opentelemetry",
1687 tracing::instrument(skip_all, fields(actor_id = %actor_id))
1688 )]
1689 async fn send_unregister_request(
1690 &self,
1691 actor_id: ActrId,
1692 credential: AIdCredential,
1693 reason: Option<String>,
1694 ) -> NetworkResult<UnregisterResponse> {
1695 let request = UnregisterRequest {
1697 actr_id: actor_id.clone(),
1698 reason,
1699 };
1700
1701 let flow = signaling_envelope::Flow::ActrToServer(ActrToSignaling {
1703 source: actor_id,
1704 credential,
1705 payload: Some(actr_to_signaling::Payload::UnregisterRequest(request)),
1706 });
1707
1708 let envelope = self.create_envelope(flow).await;
1710 self.send_envelope(envelope).await?;
1711
1712 Ok(UnregisterResponse {
1717 result: Some(actr_protocol::unregister_response::Result::Success(
1718 actr_protocol::unregister_response::UnregisterOk {},
1719 )),
1720 })
1721 }
1722
1723 #[cfg_attr(
1724 feature = "opentelemetry",
1725 tracing::instrument(level = "debug", skip_all, fields(actor_id = %actor_id))
1726 )]
1727 async fn send_heartbeat(
1728 &self,
1729 actor_id: ActrId,
1730 credential: AIdCredential,
1731 availability: ServiceAvailabilityState,
1732 power_reserve: f32,
1733 mailbox_backlog: f32,
1734 ) -> NetworkResult<Pong> {
1735 let ping = Ping {
1736 availability: availability as i32,
1737 power_reserve,
1738 mailbox_backlog,
1739 sticky_client_ids: vec![], };
1741
1742 let flow = signaling_envelope::Flow::ActrToServer(ActrToSignaling {
1743 source: actor_id,
1744 credential,
1745 payload: Some(actr_to_signaling::Payload::Ping(ping)),
1746 });
1747
1748 let envelope = self.create_envelope(flow).await;
1749 let reply_for = envelope.envelope_id.clone();
1750
1751 let (tx, rx) = oneshot::channel();
1753 self.pending_replies
1754 .lock()
1755 .await
1756 .insert(reply_for.clone(), tx);
1757
1758 if let Err(e) = self.send_envelope(envelope).await {
1759 self.pending_replies.lock().await.remove(&reply_for);
1761 return Err(e);
1762 }
1763
1764 let response_envelope = rx.await.map_err(|_| {
1766 NetworkError::ConnectionError(
1767 "Receiver dropped while waiting for heartbeat response".to_string(),
1768 )
1769 })?;
1770
1771 if let Some(signaling_envelope::Flow::ServerToActr(server_to_actr)) = response_envelope.flow
1773 {
1774 match server_to_actr.payload {
1775 Some(signaling_to_actr::Payload::Pong(pong)) => {
1776 return Ok(pong);
1777 }
1778 Some(signaling_to_actr::Payload::Error(err)) => {
1779 if err.code == 401 {
1781 return Err(NetworkError::CredentialExpired(err.message));
1782 }
1783 return Err(NetworkError::AuthenticationError(format!(
1784 "{} ({})",
1785 err.message, err.code
1786 )));
1787 }
1788 _ => {}
1789 }
1790 }
1791
1792 Err(NetworkError::ConnectionError(
1793 "Received response but not a Pong message".to_string(),
1794 ))
1795 }
1796
1797 #[cfg_attr(feature = "opentelemetry", tracing::instrument(skip_all))]
1798 async fn send_route_candidates_request(
1799 &self,
1800 actor_id: ActrId,
1801 credential: AIdCredential,
1802 request: RouteCandidatesRequest,
1803 ) -> NetworkResult<RouteCandidatesResponse> {
1804 let flow = signaling_envelope::Flow::ActrToServer(ActrToSignaling {
1805 source: actor_id,
1806 credential,
1807 payload: Some(actr_to_signaling::Payload::RouteCandidatesRequest(request)),
1808 });
1809
1810 let envelope = self.create_envelope(flow).await;
1811 let response_envelope = self.send_envelope_and_wait_response(envelope).await?;
1812
1813 if let Some(signaling_envelope::Flow::ServerToActr(server_to_actr)) = response_envelope.flow
1814 {
1815 match server_to_actr.payload {
1816 Some(signaling_to_actr::Payload::RouteCandidatesResponse(response)) => {
1817 return Ok(response);
1818 }
1819 Some(signaling_to_actr::Payload::Error(err)) => {
1820 return Err(NetworkError::ServiceDiscoveryError(format!(
1821 "{} ({})",
1822 err.message, err.code
1823 )));
1824 }
1825 _ => {}
1826 }
1827 }
1828
1829 Err(NetworkError::ConnectionError(
1830 "Invalid route candidates response".to_string(),
1831 ))
1832 }
1833
1834 async fn get_signing_key(
1835 &self,
1836 actor_id: ActrId,
1837 credential: AIdCredential,
1838 key_id: u32,
1839 ) -> NetworkResult<(u32, Vec<u8>)> {
1840 let flow = signaling_envelope::Flow::ActrToServer(ActrToSignaling {
1841 source: actor_id,
1842 credential,
1843 payload: Some(actr_to_signaling::Payload::GetSigningKeyRequest(
1844 GetSigningKeyRequest { key_id },
1845 )),
1846 });
1847
1848 let envelope = self.create_envelope(flow).await;
1849 let response_envelope = self.send_envelope_and_wait_response(envelope).await?;
1850
1851 if let Some(signaling_envelope::Flow::ServerToActr(server_to_actr)) = response_envelope.flow
1852 {
1853 match server_to_actr.payload {
1854 Some(signaling_to_actr::Payload::GetSigningKeyResponse(resp)) => {
1855 return Ok((resp.key_id, resp.pubkey.to_vec()));
1856 }
1857 Some(signaling_to_actr::Payload::Error(err)) => {
1858 return Err(NetworkError::ConnectionError(format!(
1859 "get_signing_key failed: {} ({})",
1860 err.message, err.code
1861 )));
1862 }
1863 _ => {}
1864 }
1865 }
1866
1867 Err(NetworkError::ConnectionError(
1868 "get_signing_key: invalid response".to_string(),
1869 ))
1870 }
1871
1872 #[cfg_attr(
1873 feature = "opentelemetry",
1874 tracing::instrument(level = "debug", skip_all, fields(actor_id = %actor_id))
1875 )]
1876 async fn send_credential_update_request(
1877 &self,
1878 actor_id: ActrId,
1879 credential: AIdCredential,
1880 ) -> NetworkResult<RegisterResponse> {
1881 let request = CredentialUpdateRequest {
1882 actr_id: actor_id.clone(),
1883 };
1884
1885 let flow = signaling_envelope::Flow::ActrToServer(ActrToSignaling {
1886 source: actor_id,
1887 credential,
1888 payload: Some(actr_to_signaling::Payload::CredentialUpdateRequest(request)),
1889 });
1890
1891 let envelope = self.create_envelope(flow).await;
1892 let response_envelope = self.send_envelope_and_wait_response(envelope).await?;
1893
1894 if let Some(signaling_envelope::Flow::ServerToActr(server_to_actr)) = response_envelope.flow
1895 {
1896 match server_to_actr.payload {
1897 Some(signaling_to_actr::Payload::RegisterResponse(response)) => {
1898 return Ok(response);
1899 }
1900 Some(signaling_to_actr::Payload::Error(err)) => {
1901 return Err(NetworkError::ConnectionError(format!(
1902 "Credential update failed: {} ({})",
1903 err.message, err.code
1904 )));
1905 }
1906 _ => {}
1907 }
1908 }
1909
1910 Err(NetworkError::ConnectionError(
1911 "Invalid credential update response".to_string(),
1912 ))
1913 }
1914
1915 #[cfg_attr(
1916 feature = "opentelemetry",
1917 tracing::instrument(level = "debug", skip_all, fields(envelope_id = %envelope.envelope_id))
1918 )]
1919 async fn send_envelope(&self, envelope: SignalingEnvelope) -> NetworkResult<()> {
1920 #[cfg(feature = "opentelemetry")]
1921 let envelope = {
1922 let mut envelope = envelope;
1923 trace::inject_span_context(&tracing::Span::current(), &mut envelope);
1924 envelope
1925 };
1926
1927 if !self.is_connected() {
1930 return Err(NetworkError::ConnectionError(
1931 "Cannot send: WebSocket not connected".to_string(),
1932 ));
1933 }
1934
1935 let mut sink_guard = self.ws_sink.lock().await;
1936
1937 if let Some(sink) = sink_guard.as_mut() {
1938 let mut buf = Vec::new();
1940 envelope.encode(&mut buf)?;
1941 let msg = tokio_tungstenite::tungstenite::Message::Binary(buf.into());
1942 match tokio::time::timeout(
1943 std::time::Duration::from_secs(SIGNALING_SEND_TIMEOUT_SECS),
1944 sink.send(msg),
1945 )
1946 .await
1947 {
1948 Ok(Ok(())) => {}
1949 Ok(Err(e)) => return Err(e.into()),
1950 Err(_) => {
1951 self.connected.store(false, Ordering::Release);
1952 return Err(NetworkError::ConnectionError(
1953 "Signaling WebSocket send timed out".to_string(),
1954 ));
1955 }
1956 }
1957
1958 self.stats.messages_sent.fetch_add(1, Ordering::Relaxed);
1959 tracing::debug!("Stats: {:?}", self.stats.snapshot());
1960 Ok(())
1961 } else {
1962 Err(NetworkError::ConnectionError("Not connected".to_string()))
1963 }
1964 }
1965
1966 async fn receive_envelope(&self) -> NetworkResult<Option<SignalingEnvelope>> {
1967 let mut rx = self.inbound_rx.lock().await;
1968 match rx.recv().await {
1969 Some(envelope) => Ok(Some(envelope)),
1970 None => {
1971 tracing::error!("Inbound channel closed");
1972 Err(NetworkError::ConnectionError(
1973 "Inbound channel closed".to_string(),
1974 ))
1975 }
1976 }
1977 }
1978
1979 fn is_connected(&self) -> bool {
1980 self.connected.load(Ordering::Acquire)
1981 }
1982
1983 fn get_stats(&self) -> SignalingStats {
1984 self.stats.snapshot()
1985 }
1986
1987 fn subscribe_events(&self) -> broadcast::Receiver<SignalingEvent> {
1988 self.event_tx.subscribe()
1989 }
1990
1991 async fn set_actor_id(&self, actor_id: ActrId) {
1992 *self.actor_id.lock().await = Some(actor_id);
1993 }
1994
1995 async fn set_credential_state(&self, credential_state: CredentialState) {
1996 *self.credential_state.lock().await = Some(credential_state);
1997 }
1998
1999 async fn clear_identity(&self) {
2000 *self.actor_id.lock().await = None;
2001 *self.credential_state.lock().await = None;
2002 }
2003
2004 fn set_hook_callback(&self, cb: HookCallback) {
2005 let _ = self.hook_callback.set(cb);
2006 }
2007}
2008
/// Atomic counters describing the signaling client's activity.
///
/// All counters start at zero (`Default`). Use `snapshot` to obtain a plain-value copy.
#[derive(Debug, Default)]
pub(crate) struct AtomicSignalingStats {
    /// Number of connections (presumably successful ones; the increment site is not in this
    /// part of the file — TODO confirm).
    pub connections: AtomicU64,

    /// Number of disconnections (the unit tests show it advancing when a connected →
    /// disconnected transition is published).
    pub disconnections: AtomicU64,

    /// Number of envelopes written to the WebSocket by `send_envelope`.
    pub messages_sent: AtomicU64,

    /// Number of messages received (increment site not in this part of the file).
    pub messages_received: AtomicU64,

    /// Number of heartbeats sent (increment site not in this part of the file).
    pub heartbeats_sent: AtomicU64,

    /// Number of heartbeat replies received (increment site not in this part of the file).
    pub heartbeats_received: AtomicU64,

    /// Number of errors (increment site not in this part of the file).
    pub errors: AtomicU64,
}
2049
2050#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)]
2052pub struct SignalingStats {
2053 pub connections: u64,
2055
2056 pub disconnections: u64,
2058
2059 pub messages_sent: u64,
2061
2062 pub messages_received: u64,
2064
2065 pub heartbeats_sent: u64,
2067
2068 pub heartbeats_received: u64,
2070
2071 pub errors: u64,
2073}
2074
2075impl AtomicSignalingStats {
2076 pub fn snapshot(&self) -> SignalingStats {
2078 SignalingStats {
2079 connections: self.connections.load(Ordering::Relaxed),
2080 disconnections: self.disconnections.load(Ordering::Relaxed),
2081 messages_sent: self.messages_sent.load(Ordering::Relaxed),
2082 messages_received: self.messages_received.load(Ordering::Relaxed),
2083 heartbeats_sent: self.heartbeats_sent.load(Ordering::Relaxed),
2084 heartbeats_received: self.heartbeats_received.load(Ordering::Relaxed),
2085 errors: self.errors.load(Ordering::Relaxed),
2086 }
2087 }
2088}
2089
/// Returns the current wall-clock time as whole seconds since the Unix epoch.
///
/// Yields `0` if the system clock reports a time before the epoch.
fn current_unix_secs() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_secs(),
        Err(_) => 0,
    }
}
2097
2098#[cfg(test)]
2099mod tests {
2100 use super::*;
2101 use std::future::Future;
2102 use std::pin::Pin;
2103 use std::sync::atomic::{AtomicUsize, Ordering as UsizeOrdering};
2104
2105 struct FakeSignalingClient {
2107 event_tx: broadcast::Sender<SignalingEvent>,
2108 connected: AtomicBool,
2109 connect_calls: Arc<AtomicUsize>,
2110 actor_id: tokio::sync::Mutex<Option<ActrId>>,
2111 credential_state: tokio::sync::Mutex<Option<CredentialState>>,
2112 }
2113
2114 #[async_trait]
2115 impl SignalingClient for FakeSignalingClient {
2116 async fn connect(&self) -> NetworkResult<()> {
2117 self.connect_calls.fetch_add(1, UsizeOrdering::SeqCst);
2118 Ok(())
2119 }
2120
2121 async fn disconnect(&self) -> NetworkResult<()> {
2122 Ok(())
2123 }
2124
2125 async fn send_register_request(
2126 &self,
2127 _request: RegisterRequest,
2128 ) -> NetworkResult<RegisterResponse> {
2129 unimplemented!("not needed in tests");
2130 }
2131
2132 async fn send_unregister_request(
2133 &self,
2134 _actor_id: ActrId,
2135 _credential: AIdCredential,
2136 _reason: Option<String>,
2137 ) -> NetworkResult<UnregisterResponse> {
2138 unimplemented!("not needed in tests");
2139 }
2140
2141 async fn send_heartbeat(
2142 &self,
2143 _actor_id: ActrId,
2144 _credential: AIdCredential,
2145 _availability: ServiceAvailabilityState,
2146 _power_reserve: f32,
2147 _mailbox_backlog: f32,
2148 ) -> NetworkResult<Pong> {
2149 unimplemented!("not needed in tests");
2150 }
2151
2152 async fn send_route_candidates_request(
2153 &self,
2154 _actor_id: ActrId,
2155 _credential: AIdCredential,
2156 _request: RouteCandidatesRequest,
2157 ) -> NetworkResult<RouteCandidatesResponse> {
2158 unimplemented!("not needed in tests");
2159 }
2160
2161 async fn get_signing_key(
2162 &self,
2163 _actor_id: ActrId,
2164 _credential: AIdCredential,
2165 _key_id: u32,
2166 ) -> NetworkResult<(u32, Vec<u8>)> {
2167 unimplemented!("not needed in tests");
2168 }
2169
2170 async fn send_credential_update_request(
2171 &self,
2172 _actor_id: ActrId,
2173 _credential: AIdCredential,
2174 ) -> NetworkResult<RegisterResponse> {
2175 unimplemented!("not needed in tests");
2176 }
2177
2178 async fn send_envelope(&self, _envelope: SignalingEnvelope) -> NetworkResult<()> {
2179 unimplemented!("not needed in tests");
2180 }
2181
2182 async fn receive_envelope(&self) -> NetworkResult<Option<SignalingEnvelope>> {
2183 unimplemented!("not needed in tests");
2184 }
2185
2186 fn is_connected(&self) -> bool {
2187 self.connected.load(Ordering::SeqCst)
2188 }
2189
2190 fn get_stats(&self) -> SignalingStats {
2191 SignalingStats::default()
2192 }
2193
2194 fn subscribe_events(&self) -> broadcast::Receiver<SignalingEvent> {
2195 self.event_tx.subscribe()
2196 }
2197
2198 async fn set_actor_id(&self, actor_id: ActrId) {
2199 *self.actor_id.lock().await = Some(actor_id);
2200 }
2201
2202 async fn set_credential_state(&self, credential_state: CredentialState) {
2203 *self.credential_state.lock().await = Some(credential_state);
2204 }
2205
2206 async fn clear_identity(&self) {
2207 *self.actor_id.lock().await = None;
2208 *self.credential_state.lock().await = None;
2209 }
2210 }
2211
2212 fn make_fake_client() -> Arc<FakeSignalingClient> {
2213 let (event_tx, _erx) = broadcast::channel(64);
2214 Arc::new(FakeSignalingClient {
2215 event_tx,
2216 connected: AtomicBool::new(false),
2217 connect_calls: Arc::new(AtomicUsize::new(0)),
2218 actor_id: tokio::sync::Mutex::new(None),
2219 credential_state: tokio::sync::Mutex::new(None),
2220 })
2221 }
2222
2223 fn make_config() -> SignalingConfig {
2225 SignalingConfig {
2226 server_url: Url::parse("ws://127.0.0.1:1/signaling/ws").unwrap(),
2227 connection_timeout: 2,
2228 heartbeat_interval: 30,
2229 reconnect_config: ReconnectConfig::default(),
2230 auth_config: None,
2231 webrtc_role: None,
2232 }
2233 }
2234
2235 fn make_ws_client(config: SignalingConfig) -> Arc<WebSocketSignalingClient> {
2237 Arc::new(WebSocketSignalingClient::new(config))
2238 }
2239
2240 #[tokio::test]
2241 async fn test_publish_disconnected_transition_fires_hook_once() {
2242 let stats = Arc::new(AtomicSignalingStats::default());
2243 let (event_tx, mut event_rx) = broadcast::channel(4);
2244 let hook_count = Arc::new(AtomicUsize::new(0));
2245 let hook_count_for_cb = hook_count.clone();
2246 let hook_callback: HookCallback = Arc::new(move |event| {
2247 let hook_count = hook_count_for_cb.clone();
2248 Box::pin(async move {
2249 if matches!(event, HookEvent::SignalingDisconnected) {
2250 hook_count.fetch_add(1, UsizeOrdering::SeqCst);
2251 }
2252 }) as Pin<Box<dyn Future<Output = ()> + Send>>
2253 });
2254
2255 let first = WebSocketSignalingClient::publish_disconnected_transition(
2256 true,
2257 &stats,
2258 &event_tx,
2259 Some(hook_callback.clone()),
2260 DisconnectReason::StreamEnded,
2261 None,
2262 )
2263 .await;
2264 assert!(
2265 first,
2266 "first connected->disconnected transition should publish"
2267 );
2268 assert_eq!(hook_count.load(UsizeOrdering::SeqCst), 1);
2269 assert_eq!(stats.snapshot().disconnections, 1);
2270 assert!(matches!(
2271 event_rx.recv().await,
2272 Ok(SignalingEvent::Disconnected {
2273 reason: DisconnectReason::StreamEnded
2274 })
2275 ));
2276
2277 let second = WebSocketSignalingClient::publish_disconnected_transition(
2278 false,
2279 &stats,
2280 &event_tx,
2281 Some(hook_callback),
2282 DisconnectReason::PongTimeout,
2283 None,
2284 )
2285 .await;
2286 assert!(
2287 !second,
2288 "stale duplicate disconnected transition should be ignored"
2289 );
2290 assert_eq!(hook_count.load(UsizeOrdering::SeqCst), 1);
2291 assert_eq!(stats.snapshot().disconnections, 1);
2292 assert!(event_rx.try_recv().is_err());
2293 }
2294
2295 #[test]
2300 fn test_reconnect_config_defaults() {
2301 let cfg = ReconnectConfig::default();
2302 assert!(cfg.enabled);
2303 assert_eq!(cfg.max_attempts, 10);
2304 assert_eq!(cfg.initial_delay, 1);
2305 assert_eq!(cfg.max_delay, 60);
2306 assert!((cfg.backoff_multiplier - 2.0).abs() < f64::EPSILON);
2307 }
2308
2309 #[test]
2314 fn test_websocket_signaling_client_initial_state_disconnected() {
2315 let client = WebSocketSignalingClient::new(make_config());
2316 assert!(
2317 !client.is_connected(),
2318 "newly created client should be Disconnected"
2319 );
2320 assert!(
2321 !client.connecting.load(Ordering::Acquire),
2322 "newly created client should not be in connecting state"
2323 );
2324 assert!(
2325 !client.reconnector_started.load(Ordering::Acquire),
2326 "reconnect manager should not be started automatically"
2327 );
2328 }
2329
2330 #[test]
2331 fn test_initial_stats_are_zero() {
2332 let client = WebSocketSignalingClient::new(make_config());
2333 let stats = client.get_stats();
2334 assert_eq!(stats.connections, 0);
2335 assert_eq!(stats.disconnections, 0);
2336 assert_eq!(stats.messages_sent, 0);
2337 assert_eq!(stats.messages_received, 0);
2338 assert_eq!(stats.errors, 0);
2339 }
2340
2341 #[test]
2342 fn test_signaling_url_log_redacts_credential_query_params() {
2343 let url = Url::parse(
2344 "wss://example.com/signaling?actor_id=abc&key_id=7&claims=claims-value&signature=signature-value&token=token-value",
2345 )
2346 .unwrap();
2347
2348 let redacted = WebSocketSignalingClient::redact_signaling_url_for_log(&url);
2349
2350 assert!(redacted.contains("actor_id=abc"));
2351 assert!(redacted.contains("key_id=7"));
2352 assert!(redacted.contains("claims=REDACTED"));
2353 assert!(redacted.contains("signature=REDACTED"));
2354 assert!(redacted.contains("token=REDACTED"));
2355 assert!(!redacted.contains("claims-value"));
2356 assert!(!redacted.contains("signature-value"));
2357 assert!(!redacted.contains("token-value"));
2358 }
2359
2360 #[tokio::test]
2365 async fn test_reconnect_manager_idempotent() {
2366 let client = make_ws_client(make_config());
2367
2368 client.start_reconnect_manager();
2370 assert!(
2371 client.reconnector_started.load(Ordering::Acquire),
2372 "reconnector_started should be true after first call"
2373 );
2374
2375 client.start_reconnect_manager();
2377 assert!(client.reconnector_started.load(Ordering::Acquire));
2379 }
2380
2381 #[tokio::test]
2382 async fn test_reconnect_manager_disabled_when_config_disabled() {
2383 let mut config = make_config();
2384 config.reconnect_config.enabled = false;
2385 let client = make_ws_client(config);
2386
2387 client.start_reconnect_manager();
2388 assert!(
2389 !client.reconnector_started.load(Ordering::Acquire),
2390 "reconnect manager should not start when reconnect config is disabled"
2391 );
2392 }
2393
2394 #[tokio::test]
2399 async fn test_connect_fast_path_when_already_connected() {
2400 let client = make_ws_client(make_config());
2401 client.connected.store(true, Ordering::Release);
2403
2404 let result = client.connect().await;
2406 assert!(
2407 result.is_ok(),
2408 "connect() should return Ok when already connected"
2409 );
2410 assert!(!client.connecting.load(Ordering::Acquire));
2412 }
2413
2414 #[tokio::test]
2415 async fn test_connect_sets_connecting_flag() {
2416 let mut config = make_config();
2417 config.reconnect_config.enabled = false; config.connection_timeout = 1;
2419 let client = make_ws_client(config);
2420
2421 let result = client.connect().await;
2423 assert!(
2424 result.is_err(),
2425 "connecting to unreachable address should fail"
2426 );
2427 assert!(
2428 !client.connecting.load(Ordering::Acquire),
2429 "connecting flag should be cleared after connection failure"
2430 );
2431 }
2432
2433 #[tokio::test]
2438 async fn test_event_subscribe_receives_events() {
2439 let client = make_ws_client(make_config());
2440 let mut rx = client.subscribe_events();
2441
2442 let _ = client.event_tx.send(SignalingEvent::Connected);
2444
2445 match tokio::time::timeout(std::time::Duration::from_secs(1), rx.recv()).await {
2446 Ok(Ok(SignalingEvent::Connected)) => {} other => panic!("expected Connected event, but got {:?}", other),
2448 }
2449 }
2450
2451 #[tokio::test]
2452 async fn test_disconnect_event_on_connect_failure() {
2453 let mut config = make_config();
2454 config.reconnect_config.enabled = false;
2455 config.connection_timeout = 1;
2456 let client = make_ws_client(config);
2457 let mut rx = client.subscribe_events();
2458
2459 let _ = client.connect().await;
2461
2462 match tokio::time::timeout(std::time::Duration::from_secs(2), rx.recv()).await {
2464 Ok(Ok(SignalingEvent::Disconnected {
2465 reason: DisconnectReason::ConnectionFailed(_),
2466 })) => {} other => panic!(
2468 "expected Disconnected(ConnectionFailed) event, but got {:?}",
2469 other
2470 ),
2471 }
2472 }
2473
2474 #[tokio::test]
2479 async fn test_disconnect_clears_connected_flag() {
2480 let client = make_ws_client(make_config());
2481 client.connected.store(true, Ordering::Release);
2483 assert!(client.is_connected());
2484
2485 let result = client.disconnect().await;
2486 assert!(result.is_ok());
2487 assert!(
2488 !client.is_connected(),
2489 "should be Disconnected after disconnect()"
2490 );
2491 }
2492
2493 #[tokio::test]
2494 async fn test_disconnect_increments_disconnection_stat() {
2495 let client = make_ws_client(make_config());
2496 client.connected.store(true, Ordering::Release);
2497
2498 let stats_before = client.get_stats().disconnections;
2499 let _ = client.disconnect().await;
2500 let stats_after = client.get_stats().disconnections;
2501 assert_eq!(
2502 stats_after,
2503 stats_before + 1,
2504 "disconnect() should increment disconnection count"
2505 );
2506 }
2507
2508 #[tokio::test]
2509 async fn test_disconnect_idempotent() {
2510 let client = make_ws_client(make_config());
2511
2512 let r1 = client.disconnect().await;
2514 let r2 = client.disconnect().await;
2515 assert!(r1.is_ok());
2516 assert!(r2.is_ok());
2517 assert!(!client.is_connected());
2518 }
2519
2520 #[tokio::test]
2525 async fn test_reconnect_notify_wakes_waiter() {
2526 let notify = Arc::new(tokio::sync::Notify::new());
2527 let notify_clone = notify.clone();
2528 let woken = Arc::new(AtomicBool::new(false));
2529 let woken_clone = woken.clone();
2530
2531 let handle = tokio::spawn(async move {
2532 notify_clone.notified().await;
2533 woken_clone.store(true, Ordering::Release);
2534 });
2535
2536 tokio::time::sleep(std::time::Duration::from_millis(50)).await;
2538 assert!(
2539 !woken.load(Ordering::Acquire),
2540 "should not be woken before notification"
2541 );
2542
2543 notify.notify_one();
2545 tokio::time::sleep(std::time::Duration::from_millis(50)).await;
2546 assert!(
2547 woken.load(Ordering::Acquire),
2548 "should be woken after notification"
2549 );
2550
2551 handle.abort();
2552 }
2553
2554 #[tokio::test]
2555 async fn test_explicit_disconnect_suppresses_reconnect_cycle_in_backoff() {
2556 let mut config = make_config();
2557 config.connection_timeout = 1;
2558 config.reconnect_config = ReconnectConfig {
2559 enabled: true,
2560 max_attempts: 5,
2561 initial_delay: 1,
2562 max_delay: 1,
2563 backoff_multiplier: 1.0,
2564 };
2565 let client = make_ws_client(config);
2566 let mut rx = client.subscribe_events();
2567
2568 let reconnect_client = client.clone();
2569 let reconnect_task = tokio::spawn(async move {
2570 reconnect_client.run_reconnect_cycle().await;
2571 });
2572
2573 match tokio::time::timeout(Duration::from_secs(1), rx.recv()).await {
2574 Ok(Ok(SignalingEvent::ConnectStart { attempt: 1 })) => {}
2575 other => panic!("expected first reconnect attempt, got {other:?}"),
2576 }
2577
2578 client
2579 .disconnect()
2580 .await
2581 .expect("explicit disconnect should be idempotent");
2582
2583 tokio::time::timeout(Duration::from_secs(2), reconnect_task)
2584 .await
2585 .expect("suppressed reconnect cycle should exit promptly")
2586 .expect("reconnect task should not panic");
2587
2588 while let Ok(Ok(event)) = tokio::time::timeout(Duration::from_millis(100), rx.recv()).await
2589 {
2590 if let SignalingEvent::ConnectStart { attempt } = event {
2591 panic!("suppressed reconnect cycle sent unexpected attempt {attempt}");
2592 }
2593 }
2594
2595 assert!(
2596 client.auto_reconnect_suppressed.load(Ordering::Acquire),
2597 "explicit disconnect should suppress stale auto-reconnect cycles"
2598 );
2599 }
2600
2601 #[tokio::test]
2602 async fn test_explicit_disconnect_suppresses_in_flight_auto_reconnect_connected_event() {
2603 let listener = tokio::net::TcpListener::bind("127.0.0.1:0")
2604 .await
2605 .expect("test listener should bind");
2606 let server_url = format!(
2607 "ws://{}/signaling/ws",
2608 listener
2609 .local_addr()
2610 .expect("test listener should have local addr")
2611 );
2612 let (release_tx, release_rx) = tokio::sync::oneshot::channel::<()>();
2613
2614 let server_task = tokio::spawn(async move {
2615 let (stream, _) = listener
2616 .accept()
2617 .await
2618 .expect("test server should accept tcp connection");
2619 let _ = release_rx.await;
2620 let ws_stream = tokio_tungstenite::accept_async(stream)
2621 .await
2622 .expect("test server should complete websocket handshake");
2623 tokio::time::sleep(Duration::from_millis(100)).await;
2624 drop(ws_stream);
2625 });
2626
2627 let mut config = make_config();
2628 config.server_url = Url::parse(&server_url).expect("test websocket URL should parse");
2629 config.connection_timeout = 5;
2630 config.reconnect_config = ReconnectConfig {
2631 enabled: true,
2632 max_attempts: 3,
2633 initial_delay: 1,
2634 max_delay: 1,
2635 backoff_multiplier: 1.0,
2636 };
2637 let client = make_ws_client(config);
2638 let mut rx = client.subscribe_events();
2639
2640 let reconnect_client = client.clone();
2641 let reconnect_task = tokio::spawn(async move {
2642 reconnect_client.run_reconnect_cycle().await;
2643 });
2644
2645 match tokio::time::timeout(Duration::from_secs(1), rx.recv()).await {
2646 Ok(Ok(SignalingEvent::ConnectStart { attempt: 1 })) => {}
2647 other => panic!("expected first reconnect attempt, got {other:?}"),
2648 }
2649
2650 client
2651 .disconnect()
2652 .await
2653 .expect("explicit disconnect should cancel the in-flight auto-reconnect");
2654 release_tx
2655 .send(())
2656 .expect("test server handshake should still be waiting");
2657
2658 tokio::time::timeout(Duration::from_secs(2), reconnect_task)
2659 .await
2660 .expect("cancelled in-flight reconnect should exit promptly")
2661 .expect("reconnect task should not panic");
2662
2663 while let Ok(Ok(event)) = tokio::time::timeout(Duration::from_millis(150), rx.recv()).await
2664 {
2665 assert!(
2666 !matches!(event, SignalingEvent::Connected),
2667 "cancelled auto-reconnect must not publish Connected"
2668 );
2669 }
2670
2671 assert!(
2672 !client.is_connected(),
2673 "cancelled auto-reconnect must not leave signaling connected"
2674 );
2675
2676 tokio::time::timeout(Duration::from_secs(1), server_task)
2677 .await
2678 .expect("test server task should finish")
2679 .expect("test server task should not panic");
2680 }
2681
2682 #[tokio::test]
2687 async fn test_build_url_without_identity() {
2688 let config = make_config();
2689 let expected_base = config.server_url.to_string();
2690 let client = WebSocketSignalingClient::new(config);
2691
2692 let url = client.build_url_with_identity().await;
2693 assert_eq!(
2694 url.to_string(),
2695 expected_base,
2696 "URL should not contain identity parameters when actor_id is not set"
2697 );
2698 }
2699
2700 #[tokio::test]
2701 async fn test_build_url_with_webrtc_role() {
2702 let mut config = make_config();
2703 config.webrtc_role = Some("answer".to_string());
2704 let client = WebSocketSignalingClient::new(config);
2705
2706 let url = client.build_url_with_identity().await;
2707 assert!(
2708 url.query().unwrap_or("").contains("webrtc_role=answer"),
2709 "URL should contain webrtc_role parameter, actual URL: {}",
2710 url
2711 );
2712 }
2713
2714 #[tokio::test]
2719 async fn test_reset_inbound_channel_creates_fresh_channel() {
2720 let client = WebSocketSignalingClient::new(make_config());
2721
2722 {
2724 let tx = client.inbound_tx.lock().await;
2725 let _ = tx.send(SignalingEnvelope::default());
2726 }
2727
2728 client.reset_inbound_channel().await;
2730
2731 let mut rx = client.inbound_rx.lock().await;
2733 let result = rx.try_recv();
2734 assert!(
2735 result.is_err(),
2736 "old messages should not be visible in the new channel after reset"
2737 );
2738 }
2739
2740 #[tokio::test]
2745 async fn test_envelope_id_monotonically_increasing() {
2746 let client = WebSocketSignalingClient::new(make_config());
2747
2748 let id1 = client.next_envelope_id().await;
2749 let id2 = client.next_envelope_id().await;
2750 let id3 = client.next_envelope_id().await;
2751
2752 assert_eq!(id1, "env-1");
2753 assert_eq!(id2, "env-2");
2754 assert_eq!(id3, "env-3");
2755 }
2756
2757 #[tokio::test]
2762 async fn test_send_envelope_fails_when_not_connected() {
2763 let client = WebSocketSignalingClient::new(make_config());
2764 let envelope = SignalingEnvelope::default();
2765
2766 let result = client.send_envelope(envelope).await;
2767 assert!(
2768 result.is_err(),
2769 "send_envelope should return error when not connected"
2770 );
2771 match result {
2772 Err(NetworkError::ConnectionError(msg)) => {
2773 assert!(
2774 msg.contains("not connected") || msg.contains("Not connected"),
2775 "error message should contain 'not connected', actual: {}",
2776 msg
2777 );
2778 }
2779 other => panic!("expected ConnectionError, got {:?}", other),
2780 }
2781 }
2782
2783 #[tokio::test]
2788 async fn test_fake_client_tracks_connect_calls() {
2789 let client = make_fake_client();
2790 assert_eq!(client.connect_calls.load(UsizeOrdering::SeqCst), 0);
2791
2792 client.connect().await.unwrap();
2793 client.connect().await.unwrap();
2794 client.connect().await.unwrap();
2795
2796 assert_eq!(
2797 client.connect_calls.load(UsizeOrdering::SeqCst),
2798 3,
2799 "FakeSignalingClient should accurately track connect call count"
2800 );
2801 }
2802}