1#[cfg(feature = "opentelemetry")]
6use super::trace;
7use crate::lifecycle::CredentialState;
8use crate::transport::{NetworkError, NetworkResult};
9#[cfg(feature = "opentelemetry")]
10use crate::wire::webrtc::trace::extract_trace_context;
11use actr_protocol::prost::Message as ProstMessage;
12use actr_protocol::{
13 AIdCredential, ActrId, ActrToSignaling, CredentialUpdateRequest, GetSigningKeyRequest,
14 PeerToSignaling, Ping, Pong, RegisterRequest, RegisterResponse, RouteCandidatesRequest,
15 RouteCandidatesResponse, ServiceAvailabilityState, SignalingEnvelope, UnregisterRequest,
16 UnregisterResponse, actr_to_signaling, peer_to_signaling, signaling_envelope,
17 signaling_to_actr,
18};
19use async_trait::async_trait;
20use base64::Engine as _;
21use futures_util::{SinkExt, StreamExt};
22use serde::{Deserialize, Serialize};
23use std::collections::HashMap;
24use std::future::Future;
25use std::pin::Pin;
26use std::sync::{
27 Arc, OnceLock,
28 atomic::{AtomicBool, AtomicU64, Ordering},
29};
30use std::time::Duration;
31use tokio::net::TcpStream;
32use tokio::sync::{broadcast, mpsc, oneshot};
33use tokio_tungstenite::tungstenite::protocol::WebSocketConfig;
34use tokio_tungstenite::{MaybeTlsStream, WebSocketStream, connect_async_with_config};
35#[cfg(feature = "opentelemetry")]
36use tracing_opentelemetry::OpenTelemetrySpanExt;
37use url::Url;
38
/// Shared, lockable handle to the write half of the signaling WebSocket.
///
/// The inner `Option` is `None` while no socket is installed: the connect path
/// stores `Some(sink)` and `disconnect` takes the sink back out before closing it.
type WsSink = Arc<
    tokio::sync::Mutex<
        Option<
            futures_util::stream::SplitSink<
                WebSocketStream<MaybeTlsStream<TcpStream>>,
                tokio_tungstenite::tungstenite::Message,
            >,
        >,
    >,
>;
50
/// Seconds to wait for the reply to a request envelope before giving up.
const RESPONSE_TIMEOUT_SECS: u64 = 15;
/// Seconds the ping task sleeps between two WebSocket pings.
const PING_INTERVAL_SECS: u64 = 5;
/// Maximum age, in seconds, of the last recorded inbound activity before the
/// ping task declares the connection dead.
const PONG_TIMEOUT_SECS: u64 = 10;
/// Upper bound, in seconds, on a single frame send (used for the ping send in this file section).
const SIGNALING_SEND_TIMEOUT_SECS: u64 = 5;
/// Seconds `disconnect` waits to acquire each internal lock before skipping that cleanup step.
const DISCONNECT_LOCK_TIMEOUT_SECS: u64 = 5;
/// Seconds `disconnect` waits for the WebSocket close to complete.
const DISCONNECT_CLOSE_TIMEOUT_SECS: u64 = 1;
63
/// Configuration for a signaling client.
#[derive(Debug, Clone)]
pub struct SignalingConfig {
    /// Base WebSocket URL of the signaling server; identity query parameters
    /// are appended to a copy of it when connecting.
    pub server_url: Url,

    /// Connect timeout in seconds; `0` disables the timeout.
    pub connection_timeout: u64,

    /// Heartbeat interval.
    /// NOTE(review): not read in the visible part of this file; presumably seconds — confirm.
    pub heartbeat_interval: u64,

    /// Retry/backoff policy for connecting and reconnecting.
    pub reconnect_config: ReconnectConfig,

    /// Optional authentication settings.
    /// NOTE(review): not read in the visible part of this file — confirm where it is consumed.
    pub auth_config: Option<AuthConfig>,

    /// When set, sent to the server as the `webrtc_role` query parameter.
    pub webrtc_role: Option<String>,
}
89
/// Retry and backoff policy used when connecting and reconnecting.
#[derive(Debug, Clone)]
pub struct ReconnectConfig {
    /// Enables the background reconnect manager and connect retries.
    /// When `false`, a connect performs a single attempt.
    pub enabled: bool,

    /// Maximum number of retries handed to the backoff builder.
    pub max_attempts: u32,

    /// Initial backoff delay in seconds (values below 1 are raised to 1 when used).
    pub initial_delay: u64,

    /// Maximum backoff delay in seconds (values below 1 are raised to 1 when used).
    pub max_delay: u64,

    /// Backoff growth factor.
    /// NOTE(review): the backoff builders visible in this file never receive this
    /// value — confirm whether it is intentionally unused.
    pub backoff_multiplier: f64,
}
108
109impl Default for ReconnectConfig {
110 fn default() -> Self {
111 Self {
112 enabled: true,
113 max_attempts: 10,
114 initial_delay: 1,
115 max_delay: 60,
116 backoff_multiplier: 2.0,
117 }
118 }
119}
120
/// Authentication settings attached to a [`SignalingConfig`].
#[derive(Debug, Clone)]
pub struct AuthConfig {
    /// Which authentication scheme to use.
    pub auth_type: AuthType,

    /// Free-form credential key/value pairs.
    /// NOTE(review): the expected keys per `AuthType` are not visible here — confirm.
    pub credentials: HashMap<String, String>,
}
130
/// Authentication scheme selector for [`AuthConfig`].
///
/// NOTE(review): none of these variants is matched on in the visible part of
/// this file; the descriptions below follow the variant names only.
#[derive(Debug, Clone)]
pub enum AuthType {
    /// No authentication.
    None,
    /// Bearer-token authentication (presumably) — confirm.
    BearerToken,
    /// API-key authentication (presumably) — confirm.
    ApiKey,
    /// JWT authentication (presumably) — confirm.
    Jwt,
}
143
/// Abstraction over a client that talks to the signaling server.
///
/// Implementations must be `Send + Sync`. Methods with a body below are
/// defaults that implementors may override.
#[async_trait]
pub trait SignalingClient: Send + Sync {
    /// Connects to the signaling server.
    async fn connect(&self) -> NetworkResult<()>;

    /// Connects with a single attempt.
    ///
    /// The default implementation simply delegates to [`Self::connect`].
    async fn connect_once(&self) -> NetworkResult<()> {
        self.connect().await
    }

    /// Disconnects from the signaling server.
    async fn disconnect(&self) -> NetworkResult<()>;

    /// Checks whether the connection is alive.
    ///
    /// The default implementation ignores `_timeout` and only reports the
    /// result of [`Self::is_connected`], returning a `ConnectionError` when
    /// the client is not connected.
    async fn probe_alive(&self, _timeout: Duration) -> NetworkResult<()> {
        if self.is_connected() {
            Ok(())
        } else {
            Err(NetworkError::ConnectionError(
                "Signaling client is not connected".to_string(),
            ))
        }
    }

    /// Sends a registration request and returns the server's response.
    async fn send_register_request(
        &self,
        request: RegisterRequest,
    ) -> NetworkResult<RegisterResponse>;

    /// Sends an unregister request for `actor_id`, with an optional `reason`.
    async fn send_unregister_request(
        &self,
        actor_id: ActrId,
        credential: AIdCredential,
        reason: Option<String>,
    ) -> NetworkResult<UnregisterResponse>;

    /// Sends a heartbeat carrying availability and load figures and returns the `Pong`.
    ///
    /// NOTE(review): units/ranges of `power_reserve` and `mailbox_backlog` are
    /// not visible here — confirm against the protocol definition.
    async fn send_heartbeat(
        &self,
        actor_id: ActrId,
        credential: AIdCredential,
        availability: ServiceAvailabilityState,
        power_reserve: f32,
        mailbox_backlog: f32,
    ) -> NetworkResult<Pong>;

    /// Sends a route-candidates request on behalf of `actor_id`.
    async fn send_route_candidates_request(
        &self,
        actor_id: ActrId,
        credential: AIdCredential,
        request: RouteCandidatesRequest,
    ) -> NetworkResult<RouteCandidatesResponse>;

    /// Requests the signing key identified by `key_id`.
    ///
    /// NOTE(review): the returned tuple is presumably `(key_id, key_bytes)` — confirm.
    async fn get_signing_key(
        &self,
        actor_id: ActrId,
        credential: AIdCredential,
        key_id: u32,
    ) -> NetworkResult<(u32, Vec<u8>)>;

    /// Sends a credential update request; the reply has the same type as a registration response.
    async fn send_credential_update_request(
        &self,
        actor_id: ActrId,
        credential: AIdCredential,
    ) -> NetworkResult<RegisterResponse>;

    /// Sends one envelope without waiting for a reply.
    async fn send_envelope(&self, envelope: SignalingEnvelope) -> NetworkResult<()>;

    /// Receives the next inbound envelope, if any.
    ///
    /// NOTE(review): the meaning of `Ok(None)` is implementation-defined and not visible here.
    async fn receive_envelope(&self) -> NetworkResult<Option<SignalingEnvelope>>;

    /// Returns whether the client currently considers itself connected.
    fn is_connected(&self) -> bool;

    /// Returns a snapshot of the client's statistics.
    fn get_stats(&self) -> SignalingStats;
    /// Subscribes to connection lifecycle events.
    fn subscribe_events(&self) -> broadcast::Receiver<SignalingEvent>;

    /// Sets the actor identity used by the client.
    async fn set_actor_id(&self, actor_id: ActrId);
    /// Sets the credential state used by the client.
    async fn set_credential_state(&self, credential_state: CredentialState);

    /// Clears the stored identity.
    async fn clear_identity(&self);

    /// Installs a hook callback. The default implementation discards it.
    fn set_hook_callback(&self, _cb: HookCallback) {}
}
273
/// Coarse connection state of a signaling client.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConnectionState {
    /// No signaling connection is established.
    Disconnected,
    /// A signaling connection is established.
    Connected,
}
280
/// Connection lifecycle events broadcast to subscribers.
#[derive(Debug, Clone)]
pub enum SignalingEvent {
    /// A reconnect attempt is starting; `attempt` counts from 1 within a reconnect cycle.
    ConnectStart { attempt: u32 },
    /// A WebSocket connection was established.
    Connected,
    /// The connection was lost, closed, or could not be established; see `reason`.
    Disconnected { reason: DisconnectReason },
}
295
/// Why a [`SignalingEvent::Disconnected`] event was emitted.
#[derive(Debug, Clone)]
pub enum DisconnectReason {
    /// The receive loop ended (stream closed or a receive error occurred).
    StreamEnded,
    /// No inbound activity was seen within the pong timeout.
    PongTimeout,
    /// Sending a ping failed, timed out, or no sink was available.
    PingSendFailed,
    /// The credential expired.
    /// NOTE(review): not emitted in the visible part of this file — confirm the producer.
    CredentialExpired,
    /// `disconnect()` was called.
    Manual,
    /// A connection attempt failed; carries the error text.
    ConnectionFailed(String),
}
312
/// Events delivered to the user-supplied [`HookCallback`].
///
/// Only the `Signaling*` variants are produced in the visible part of this
/// file; the remaining variants are presumably emitted by other modules.
#[derive(Clone, Debug)]
pub enum HookEvent {
    /// A signaling connect attempt is starting; `attempt` counts from 1.
    SignalingConnectStart {
        attempt: u32,
    },
    /// The signaling connection was established.
    SignalingConnected,
    /// The signaling connection transitioned from connected to disconnected.
    SignalingDisconnected,
    /// A WebRTC connection to `peer_id` is starting (presumably) — confirm.
    WebRtcConnectStart {
        peer_id: ActrId,
    },
    /// A WebRTC connection to `peer_id` was established.
    /// NOTE(review): `relayed` presumably marks a relayed (TURN) path — confirm.
    WebRtcConnected {
        peer_id: ActrId,
        relayed: bool,
    },
    /// The WebRTC connection to `peer_id` ended (presumably) — confirm.
    WebRtcDisconnected {
        peer_id: ActrId,
    },
    /// Delivery of data on a stream could not be confirmed (presumably) — confirm.
    DataStreamDeliveryUncertain {
        stream_id: String,
        session_id: u64,
        reason: String,
    },
    /// A WebSocket connection to `peer_id` is starting (presumably) — confirm.
    WebSocketConnectStart {
        peer_id: ActrId,
    },
    /// A WebSocket connection to `peer_id` was established (presumably) — confirm.
    WebSocketConnected {
        peer_id: ActrId,
    },
    /// The WebSocket connection to `peer_id` ended (presumably) — confirm.
    WebSocketDisconnected {
        peer_id: ActrId,
    },
    /// The credential was renewed; `new_expiry` is the associated expiry time.
    CredentialRenewed {
        new_expiry: std::time::SystemTime,
    },
    /// The credential is about to expire.
    /// NOTE(review): the field is named `new_expiry` here as well — confirm its meaning.
    CredentialExpiring {
        new_expiry: std::time::SystemTime,
    },
    /// The mailbox queue length is being reported against a threshold (presumably
    /// when it is exceeded) — confirm.
    MailboxBackpressure {
        queue_len: usize,
        threshold: usize,
    },
}
368
/// Shared asynchronous callback invoked with each [`HookEvent`].
///
/// The callback returns a boxed `Send` future; this file awaits that future
/// inline, so a slow hook delays the code path that raised the event.
pub type HookCallback =
    Arc<dyn Fn(HookEvent) -> Pin<Box<dyn Future<Output = ()> + Send>> + Send + Sync>;
375
/// WebSocket-based implementation of [`SignalingClient`].
pub struct WebSocketSignalingClient {
    /// Client configuration.
    config: SignalingConfig,
    /// Actor identity appended to the connect URL when set.
    actor_id: tokio::sync::Mutex<Option<ActrId>>,
    /// Credential source whose current credential is appended to the connect URL when set.
    credential_state: tokio::sync::Mutex<Option<CredentialState>>,
    /// Write half of the WebSocket; `None` while no socket is installed.
    ws_sink: WsSink,
    /// Read half of the WebSocket, parked here after connecting until the
    /// receiver task takes ownership of it.
    ws_stream: tokio::sync::Mutex<
        Option<futures_util::stream::SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>>,
    >,
    /// `true` while a connection is considered established.
    connected: Arc<AtomicBool>,
    /// `true` while a `connect`/`connect_once` call owns the connection attempt.
    connecting: Arc<AtomicBool>,
    /// Connection and message counters.
    stats: Arc<AtomicSignalingStats>,
    /// Monotonic counter used to build envelope ids (`env-<n>`).
    envelope_counter: tokio::sync::Mutex<u64>,
    /// Waiters for replies, keyed by the request's envelope id.
    pending_replies: Arc<tokio::sync::Mutex<HashMap<String, oneshot::Sender<SignalingEnvelope>>>>,
    /// Waiters for pong frames, keyed by the pong payload.
    pending_pongs: Arc<tokio::sync::Mutex<HashMap<Vec<u8>, oneshot::Sender<()>>>>,
    /// Counter used to make liveness-probe payloads unique.
    probe_counter: AtomicU64,
    /// Receiving end of the channel carrying envelopes that matched no pending reply.
    inbound_rx: Arc<tokio::sync::Mutex<mpsc::UnboundedReceiver<SignalingEnvelope>>>,
    /// Sending end of the inbound channel; cloned into the receiver task.
    inbound_tx: tokio::sync::Mutex<mpsc::UnboundedSender<SignalingEnvelope>>,
    /// Handle of the task reading frames from the socket.
    receiver_task: Arc<tokio::sync::Mutex<Option<tokio::task::JoinHandle<()>>>>,
    /// Handle of the task sending periodic pings.
    ping_task: tokio::sync::Mutex<Option<tokio::task::JoinHandle<()>>>,
    /// Broadcast channel for [`SignalingEvent`]s.
    event_tx: broadcast::Sender<SignalingEvent>,
    /// Timestamp (from `current_unix_secs()`) of the last inbound binary, pong or ping frame.
    last_pong: Arc<AtomicU64>,
    /// Ensures the reconnect manager task is spawned at most once.
    reconnector_started: Arc<AtomicBool>,
    /// Wakes the reconnect manager and interrupts backoff sleeps.
    reconnect_notify: Arc<tokio::sync::Notify>,
    /// Optional user hook, read via `OnceLock::get`.
    hook_callback: OnceLock<HookCallback>,
}
419
420impl WebSocketSignalingClient {
421 pub fn new(config: SignalingConfig) -> Self {
423 let (inbound_tx, inbound_rx) = mpsc::unbounded_channel();
424 let (event_tx, _event_rx) = broadcast::channel(64);
425 Self {
426 config,
427 actor_id: tokio::sync::Mutex::new(None),
428 credential_state: tokio::sync::Mutex::new(None),
429 ws_sink: Arc::new(tokio::sync::Mutex::new(None)),
430 ws_stream: tokio::sync::Mutex::new(None),
431 connected: Arc::new(AtomicBool::new(false)),
432 connecting: Arc::new(AtomicBool::new(false)),
433 stats: Arc::new(AtomicSignalingStats::default()),
434 envelope_counter: tokio::sync::Mutex::new(0),
435 pending_replies: Arc::new(tokio::sync::Mutex::new(HashMap::new())),
436 pending_pongs: Arc::new(tokio::sync::Mutex::new(HashMap::new())),
437 probe_counter: AtomicU64::new(0),
438 inbound_rx: Arc::new(tokio::sync::Mutex::new(inbound_rx)),
439 inbound_tx: tokio::sync::Mutex::new(inbound_tx),
440 receiver_task: Arc::new(tokio::sync::Mutex::new(None)),
441 ping_task: tokio::sync::Mutex::new(None),
442 event_tx,
443 last_pong: Arc::new(AtomicU64::new(0)),
444 reconnector_started: Arc::new(AtomicBool::new(false)),
445 reconnect_notify: Arc::new(tokio::sync::Notify::new()),
446 hook_callback: OnceLock::new(),
447 }
448 }
449
450 async fn invoke_hook(&self, event: HookEvent) {
457 if let Some(cb) = self.hook_callback.get() {
458 cb(event).await;
459 }
460 }
461
462 async fn publish_disconnected_transition(
463 was_connected: bool,
464 stats: &Arc<AtomicSignalingStats>,
465 event_tx: &broadcast::Sender<SignalingEvent>,
466 hook_callback: Option<HookCallback>,
467 reason: DisconnectReason,
468 reconnect_notify: Option<&Arc<tokio::sync::Notify>>,
469 ) -> bool {
470 if !was_connected {
471 return false;
472 }
473
474 stats.disconnections.fetch_add(1, Ordering::Relaxed);
475
476 if let Some(cb) = hook_callback {
477 cb(HookEvent::SignalingDisconnected).await;
478 }
479
480 let _ = event_tx.send(SignalingEvent::Disconnected { reason });
481
482 if let Some(notify) = reconnect_notify {
483 notify.notify_one();
484 }
485
486 true
487 }
488
489 pub fn start_reconnect_manager(self: &Arc<Self>) {
490 if !self.config.reconnect_config.enabled {
491 return;
492 }
493 if self
494 .reconnector_started
495 .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
496 .is_err()
497 {
498 return; }
500
501 tracing::info!("🔄 Starting reconnect manager for signaling client");
502
503 let client = self.clone();
504 let notify = self.reconnect_notify.clone();
505
506 tokio::spawn(async move {
507 loop {
508 notify.notified().await;
510
511 if !client.config.reconnect_config.enabled {
512 break;
513 }
514
515 client.run_reconnect_cycle().await;
517 }
518 });
519 }
520
521 async fn run_reconnect_cycle(self: &Arc<Self>) {
523 use actr_framework::ExponentialBackoff;
524
525 let cfg = &self.config.reconnect_config;
526
527 tracing::debug!("🧹 Cleaning up old WebSocket resources before reconnect");
529 if let Err(e) = self.disconnect().await {
530 tracing::warn!("⚠️ Disconnect cleanup failed (non-fatal): {e}");
531 }
532
533 let backoff = ExponentialBackoff::builder()
534 .initial_delay(std::time::Duration::from_secs(cfg.initial_delay.max(1)))
535 .max_delay(std::time::Duration::from_secs(cfg.max_delay.max(1)))
536 .max_retries(cfg.max_attempts)
537 .with_jitter()
538 .build();
539
540 let mut attempt: u32 = 0;
541
542 for delay in backoff {
543 if self.connected.load(Ordering::Acquire) {
544 tracing::debug!("Already connected, aborting reconnect cycle");
545 return;
546 }
547
548 attempt += 1;
549 let _ = self.event_tx.send(SignalingEvent::ConnectStart { attempt });
550
551 match self.establish_connection_once().await {
552 Ok(()) => {
553 tracing::info!("✅ Signaling reconnect succeeded on attempt {attempt}");
554 self.start_receiver().await;
555 self.start_ping_task().await;
556 return;
557 }
558 Err(e) => {
559 tracing::warn!(
560 "❌ Reconnect attempt {attempt} failed: {e}, retrying in {delay:?}"
561 );
562 tokio::select! {
563 _ = tokio::time::sleep(delay) => {}
564 _ = self.reconnect_notify.notified() => {
565 tracing::debug!("Explicit reconnect request interrupted reconnect backoff");
566 }
567 }
568 }
569 }
570 }
571
572 tracing::error!("Reconnect failed after {attempt} attempts, entering cooldown");
574 let cooldown = std::time::Duration::from_secs(cfg.max_delay.max(1) * 2);
575 tokio::select! {
576 _ = tokio::time::sleep(cooldown) => {}
577 _ = self.reconnect_notify.notified() => {
578 tracing::debug!("Explicit reconnect request interrupted reconnect cooldown");
579 }
580 }
581 }
583
584 #[cfg(feature = "test-utils")]
592 pub async fn connect_to(url: &str) -> NetworkResult<Arc<Self>> {
593 let config = SignalingConfig {
594 server_url: url.parse()?,
595 connection_timeout: 5,
596 heartbeat_interval: 30,
597 reconnect_config: ReconnectConfig::default(),
598 auth_config: None,
599 webrtc_role: None,
600 };
601
602 let client = Arc::new(Self::new(config));
603 client.start_reconnect_manager();
604 client.connect().await?;
605 Ok(client)
606 }
607
608 #[cfg(feature = "test-utils")]
614 pub async fn connect_to_with_identity(
615 url: &str,
616 actor_id: ActrId,
617 credential_state: CredentialState,
618 ) -> NetworkResult<Arc<Self>> {
619 let config = SignalingConfig {
620 server_url: url.parse()?,
621 connection_timeout: 5,
622 heartbeat_interval: 30,
623 reconnect_config: ReconnectConfig::default(),
624 auth_config: None,
625 webrtc_role: None,
626 };
627
628 let client = Arc::new(Self::new(config));
629 client.set_actor_id(actor_id).await;
630 client.set_credential_state(credential_state).await;
631 client.start_reconnect_manager();
632 client.connect().await?;
633 Ok(client)
634 }
635
636 async fn next_envelope_id(&self) -> String {
638 let mut counter = self.envelope_counter.lock().await;
639 *counter += 1;
640 format!("env-{}", *counter)
641 }
642
643 async fn create_envelope(&self, flow: signaling_envelope::Flow) -> SignalingEnvelope {
645 SignalingEnvelope {
646 envelope_version: 1,
647 envelope_id: self.next_envelope_id().await,
648 reply_for: None,
649 timestamp: prost_types::Timestamp {
650 seconds: chrono::Utc::now().timestamp(),
651 nanos: 0,
652 },
653 traceparent: None,
654 tracestate: None,
655 flow: Some(flow),
656 }
657 }
658
659 async fn reset_inbound_channel(&self) {
661 self.drop_pending_replies("inbound channel reset").await;
662 self.drop_pending_pongs("inbound channel reset").await;
663
664 let (tx, rx) = mpsc::unbounded_channel();
665 *self.inbound_tx.lock().await = tx;
666 *self.inbound_rx.lock().await = rx;
667 }
668
669 async fn drop_pending_replies(&self, reason: &'static str) {
670 let dropped = {
671 let mut pending = self.pending_replies.lock().await;
672 let dropped = pending.len();
673 pending.clear();
674 dropped
675 };
676
677 if dropped > 0 {
678 tracing::debug!(reason, dropped, "Dropping pending signaling reply waiters");
679 }
680 }
681
682 async fn drop_pending_pongs(&self, reason: &'static str) {
683 let dropped = {
684 let mut pending = self.pending_pongs.lock().await;
685 let dropped = pending.len();
686 pending.clear();
687 dropped
688 };
689
690 if dropped > 0 {
691 tracing::debug!(reason, dropped, "Dropping pending signaling pong waiters");
692 }
693 }
694
695 async fn build_url_with_identity(&self) -> Url {
700 let mut url = self.config.server_url.clone();
701 let actor_id_opt = self.actor_id.lock().await.clone();
702 if let Some(actor_id) = actor_id_opt {
703 let actor_str = actr_protocol::ActrId::to_string_repr(&actor_id);
704 url.query_pairs_mut().append_pair("actor_id", &actor_str);
705 }
706
707 let cred_state_opt = self.credential_state.lock().await.clone();
709 if let Some(cred_state) = cred_state_opt {
710 let cred = cred_state.credential().await;
711 let claims_b64 = base64::engine::general_purpose::STANDARD.encode(&cred.claims);
712 let sig_b64 = base64::engine::general_purpose::STANDARD.encode(&cred.signature);
713 url.query_pairs_mut()
714 .append_pair("key_id", &cred.key_id.to_string())
715 .append_pair("claims", &claims_b64)
716 .append_pair("signature", &sig_b64);
717 }
718
719 if let Some(role) = &self.config.webrtc_role {
721 url.query_pairs_mut().append_pair("webrtc_role", role);
722 }
723
724 url
725 }
726
727 fn redact_signaling_url_for_log(url: &Url) -> String {
728 let mut redacted = url.clone();
729 let pairs: Vec<(String, String)> = redacted
730 .query_pairs()
731 .map(|(key, value)| {
732 let redacted_value = match key.to_ascii_lowercase().as_str() {
733 "claims" | "signature" | "token" | "authorization" | "bearer"
734 | "access_token" | "api_key" => "REDACTED".to_string(),
735 _ => value.into_owned(),
736 };
737 (key.into_owned(), redacted_value)
738 })
739 .collect();
740
741 redacted.set_query(None);
742 if !pairs.is_empty() {
743 let mut query = redacted.query_pairs_mut();
744 for (key, value) in pairs {
745 query.append_pair(&key, &value);
746 }
747 }
748
749 redacted.to_string()
750 }
751
752 async fn establish_connection_once(&self) -> NetworkResult<()> {
756 if self.connected.load(Ordering::Acquire) {
758 tracing::debug!("Connection already established, skipping establish_connection_once()");
759 return Ok(());
760 }
761
762 let url = self.build_url_with_identity().await;
763 let timeout_secs = self.config.connection_timeout;
764 tracing::debug!(
765 "Establishing connection to URL: {}",
766 Self::redact_signaling_url_for_log(&url)
767 );
768 let config = WebSocketConfig::default().write_buffer_size(0);
770 let connect_result = if timeout_secs == 0 {
772 connect_async_with_config(url.as_str(), Some(config), false).await
773 } else {
774 let timeout_duration = std::time::Duration::from_secs(timeout_secs);
775 tokio::time::timeout(
776 timeout_duration,
777 connect_async_with_config(url.as_str(), Some(config), false),
778 )
779 .await
780 .map_err(|_| {
781 NetworkError::ConnectionError(format!(
782 "Signaling connect timeout after {}s",
783 timeout_secs
784 ))
785 })?
786 }?;
787
788 let (ws_stream, _) = connect_result;
789
790 let (sink, stream) = ws_stream.split();
792
793 *self.ws_sink.lock().await = Some(sink);
794 *self.ws_stream.lock().await = Some(stream);
795 self.connected.store(true, Ordering::Release);
796 self.last_pong.store(current_unix_secs(), Ordering::Release);
797 self.invoke_hook(HookEvent::SignalingConnected).await;
799 let _ = self.event_tx.send(SignalingEvent::Connected);
800
801 self.stats.connections.fetch_add(1, Ordering::Relaxed);
802
803 Ok(())
804 }
805
806 async fn connect_with_retries(&self) -> NetworkResult<()> {
808 use actr_framework::ExponentialBackoff;
809
810 let cfg = &self.config.reconnect_config;
811
812 if !cfg.enabled {
814 return self.establish_connection_once().await;
815 }
816
817 let backoff = ExponentialBackoff::builder()
818 .initial_delay(std::time::Duration::from_secs(cfg.initial_delay.max(1)))
819 .max_delay(std::time::Duration::from_secs(cfg.max_delay.max(1)))
820 .max_retries(cfg.max_attempts)
821 .with_jitter()
822 .build();
823
824 let mut last_err = None;
825
826 for (attempt, delay) in std::iter::once(std::time::Duration::ZERO)
828 .chain(backoff)
829 .enumerate()
830 {
831 let attempt = attempt as u32 + 1;
832 self.invoke_hook(HookEvent::SignalingConnectStart { attempt })
833 .await;
834 if delay > std::time::Duration::ZERO {
835 tracing::info!("Retry signaling connect after {delay:?} (attempt {attempt})");
836 tokio::select! {
837 _ = tokio::time::sleep(delay) => {}
838 _ = self.reconnect_notify.notified() => {
839 tracing::debug!("Explicit reconnect request interrupted signaling connect backoff");
840 }
841 }
842 }
843
844 match self.establish_connection_once().await {
845 Ok(()) => return Ok(()),
846 Err(e) => {
847 tracing::warn!("Signaling connect attempt {attempt} failed: {e:?}");
848 last_err = Some(e);
849 }
850 }
851 }
852
853 let total = cfg.max_attempts + 1; tracing::error!("Signaling connect failed after {total} attempts, giving up");
855 Err(last_err.unwrap_or_else(|| {
856 NetworkError::ConnectionError("All connection attempts failed".to_string())
857 }))
858 }
859
860 #[cfg_attr(
862 feature = "opentelemetry",
863 tracing::instrument(skip_all, fields(envelope_id = %envelope.envelope_id))
864 )]
865 async fn send_envelope_and_wait_response(
866 &self,
867 envelope: SignalingEnvelope,
868 ) -> NetworkResult<SignalingEnvelope> {
869 let reply_for = envelope.envelope_id.clone();
870
871 let (tx, rx) = oneshot::channel();
873 self.pending_replies
874 .lock()
875 .await
876 .insert(reply_for.clone(), tx);
877
878 if let Err(e) = self.send_envelope(envelope).await {
879 self.pending_replies.lock().await.remove(&reply_for);
881 return Err(e);
882 }
883
884 let result =
885 tokio::time::timeout(std::time::Duration::from_secs(RESPONSE_TIMEOUT_SECS), rx).await;
886 if result.is_err() {
888 self.pending_replies.lock().await.remove(&reply_for);
889 }
890
891 let response_envelope = result
892 .map_err(|_| {
893 NetworkError::ConnectionError(
894 "Timed out waiting for signaling response".to_string(),
895 )
896 })?
897 .map_err(|_| {
898 NetworkError::ConnectionError(
899 "Receiver dropped while waiting for signaling response".to_string(),
900 )
901 })?;
902
903 Ok(response_envelope)
904 }
905
906 async fn start_receiver(&self) {
908 let mut stream_guard = self.ws_stream.lock().await;
909 if stream_guard.is_none() {
910 return;
911 }
912
913 let mut stream = stream_guard.take().expect("stream exists");
914 let pending = self.pending_replies.clone();
915 let inbound_tx = { self.inbound_tx.lock().await.clone() };
916 let stats = self.stats.clone();
917 let connected = self.connected.clone();
918 let event_tx = self.event_tx.clone();
919 let last_pong = self.last_pong.clone();
920 let pending_pongs = self.pending_pongs.clone();
921 let reconnect_notify = self.reconnect_notify.clone();
922 let reconnect_enabled = self.config.reconnect_config.enabled;
923 let hook_callback = self.hook_callback.get().cloned();
924 let handle = tokio::spawn(async move {
925 while let Some(msg) = stream.next().await {
926 match msg {
927 Ok(tokio_tungstenite::tungstenite::Message::Binary(data)) => {
928 last_pong.store(current_unix_secs(), Ordering::Release);
930 match SignalingEnvelope::decode(&data[..]) {
931 Ok(envelope) => {
932 #[cfg(feature = "opentelemetry")]
933 let span = {
934 let span = tracing::info_span!("signaling.receive_envelope", envelope_id = %envelope.envelope_id);
935 span.set_parent(extract_trace_context(&envelope));
936 span
937 };
938
939 stats.messages_received.fetch_add(1, Ordering::Relaxed);
940 tracing::debug!("Received message: {:?}", envelope);
941 if let Some(reply_for) = envelope.reply_for.clone() {
942 if let Some(sender) = pending.lock().await.remove(&reply_for) {
943 #[cfg(feature = "opentelemetry")]
944 let _ = span.enter();
945 if let Err(e) = sender.send(envelope) {
946 stats.errors.fetch_add(1, Ordering::Relaxed);
947 tracing::warn!(
948 "Failed to send reply envelope to waiter: {e:?}",
949 );
950 }
951 continue;
952 }
953 }
954 tracing::debug!(
955 "Unmatched or push message -> forward to inbound channel"
956 );
957 if let Err(e) = inbound_tx.send(envelope) {
959 stats.errors.fetch_add(1, Ordering::Relaxed);
960 tracing::warn!(
961 "Failed to send envelope to inbound channel: {e:?}"
962 );
963 }
964 }
965 Err(e) => {
966 stats.errors.fetch_add(1, Ordering::Relaxed);
967 tracing::warn!("Failed to decode SignalingEnvelope: {e}");
968 }
969 }
970 }
971 Ok(tokio_tungstenite::tungstenite::Message::Pong(payload)) => {
972 tracing::debug!("Received pong");
973 last_pong.store(current_unix_secs(), Ordering::Release);
974 if let Some(sender) = pending_pongs.lock().await.remove(&payload.to_vec()) {
975 let _ = sender.send(());
976 }
977 }
978 Ok(tokio_tungstenite::tungstenite::Message::Ping(_)) => {
979 tracing::debug!("Received ping");
980 last_pong.store(current_unix_secs(), Ordering::Release);
981 }
982 Ok(other) => {
983 tracing::warn!("Received non-binary frame, ignoring: {other:?}");
984 }
985 Err(e) => {
986 stats.errors.fetch_add(1, Ordering::Relaxed);
987 tracing::error!("Signaling receive error: {e}");
988 break;
989 }
990 }
991 }
992
993 tracing::warn!("Stream terminated");
994 let was_connected = connected.swap(false, Ordering::AcqRel);
998 Self::publish_disconnected_transition(
999 was_connected,
1000 &stats,
1001 &event_tx,
1002 hook_callback,
1003 DisconnectReason::StreamEnded,
1004 reconnect_enabled.then_some(&reconnect_notify),
1005 )
1006 .await;
1007 pending_pongs.lock().await.clear();
1008 });
1009
1010 *self.receiver_task.lock().await = Some(handle);
1011 }
1012
1013 async fn start_ping_task(&self) {
1016 let mut existing = self.ping_task.lock().await;
1017 if let Some(handle) = existing.as_ref() {
1018 if handle.is_finished() {
1019 existing.take();
1020 } else {
1021 return;
1022 }
1023 }
1024
1025 let sink = self.ws_sink.clone();
1026 let connected = self.connected.clone();
1027 let stats = self.stats.clone();
1028 let event_tx = self.event_tx.clone();
1029 let last_pong = self.last_pong.clone();
1030 let receiver_task_clone = Arc::clone(&self.receiver_task);
1031 let reconnect_notify = self.reconnect_notify.clone();
1032 let reconnect_enabled = self.config.reconnect_config.enabled;
1033 let hook_callback = self.hook_callback.get().cloned();
1034
1035 let handle = tokio::spawn(async move {
1036 loop {
1037 tokio::time::sleep(std::time::Duration::from_secs(PING_INTERVAL_SECS)).await;
1038
1039 if !connected.load(Ordering::Acquire) {
1040 break;
1041 }
1042
1043 let mut disconnect_reason = None;
1045 {
1046 let mut sink_guard = sink.lock().await;
1047 if let Some(sink) = sink_guard.as_mut() {
1048 match tokio::time::timeout(
1049 std::time::Duration::from_secs(SIGNALING_SEND_TIMEOUT_SECS),
1050 sink.send(tokio_tungstenite::tungstenite::Message::Ping(
1051 Vec::new().into(),
1052 )),
1053 )
1054 .await
1055 {
1056 Ok(Ok(())) => {}
1057 Ok(Err(e)) => {
1058 tracing::warn!("Signaling ping send failed: {e}");
1059 disconnect_reason = Some(DisconnectReason::PingSendFailed);
1060 }
1061 Err(_) => {
1062 tracing::warn!("Signaling ping send timed out");
1063 disconnect_reason = Some(DisconnectReason::PingSendFailed);
1064 }
1065 }
1066 } else {
1067 tracing::warn!("Signaling not connected");
1068 disconnect_reason = Some(DisconnectReason::PingSendFailed);
1069 }
1070 }
1071
1072 if let Some(reason) = disconnect_reason {
1073 let was_connected = connected.swap(false, Ordering::AcqRel);
1074 Self::publish_disconnected_transition(
1075 was_connected,
1076 &stats,
1077 &event_tx,
1078 hook_callback.clone(),
1079 reason,
1080 reconnect_enabled.then_some(&reconnect_notify),
1081 )
1082 .await;
1083 break;
1084 }
1085
1086 let now = current_unix_secs();
1088 let last = last_pong.load(Ordering::Acquire);
1089 if now.saturating_sub(last) > PONG_TIMEOUT_SECS {
1090 tracing::warn!(
1091 "Signaling pong timeout (last seen {}s ago), marking disconnected",
1092 now.saturating_sub(last)
1093 );
1094 if let Some(handle) = receiver_task_clone.lock().await.take() {
1095 handle.abort();
1096 }
1097 let was_connected = connected.swap(false, Ordering::AcqRel);
1098 Self::publish_disconnected_transition(
1099 was_connected,
1100 &stats,
1101 &event_tx,
1102 hook_callback.clone(),
1103 DisconnectReason::PongTimeout,
1104 reconnect_enabled.then_some(&reconnect_notify),
1105 )
1106 .await;
1107 break;
1108 }
1109 }
1110 });
1111
1112 *existing = Some(handle);
1113 }
1114
1115 async fn wait_for_connection_result(&self) -> NetworkResult<()> {
1119 let mut event_rx = self.event_tx.subscribe();
1120 let deadline = tokio::time::Instant::now() + std::time::Duration::from_secs(30);
1121
1122 loop {
1123 tokio::select! {
1124 _ = tokio::time::sleep_until(deadline) => {
1125 if self.connected.load(Ordering::Acquire) {
1127 tracing::debug!("Connection succeeded just before timeout");
1128 return Ok(());
1129 }
1130 return Err(NetworkError::ConnectionError(
1131 "Timeout waiting for concurrent connection attempt".to_string(),
1132 ));
1133 }
1134 result = event_rx.recv() => {
1135 match result {
1136 Ok(SignalingEvent::Connected) => {
1137 tracing::debug!("Connection established by another task");
1138 return Ok(());
1139 }
1140 Ok(_) => continue, Err(broadcast::error::RecvError::Lagged(n)) => {
1142 tracing::warn!("Event receiver lagged by {n} events");
1143 if self.connected.load(Ordering::Acquire) {
1145 return Ok(());
1146 }
1147 continue;
1148 }
1149 Err(broadcast::error::RecvError::Closed) => {
1150 return Err(NetworkError::ConnectionError(
1151 "Event channel closed while waiting for connection".to_string(),
1152 ));
1153 }
1154 }
1155 }
1156 }
1157 }
1158 }
1159}
1160
1161#[async_trait]
1162impl SignalingClient for WebSocketSignalingClient {
1163 async fn connect(&self) -> NetworkResult<()> {
1164 match self
1169 .connecting
1170 .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
1171 {
1172 Ok(_) => {
1173 }
1176 Err(_) => {
1177 if self.connected.load(Ordering::Acquire) {
1180 tracing::debug!("Already connected, skipping connect()");
1181 return Ok(());
1182 }
1183
1184 tracing::debug!(
1186 "Another connection attempt in progress, waiting for state change..."
1187 );
1188 return self.wait_for_connection_result().await;
1189 }
1190 }
1191
1192 if self.connected.load(Ordering::Acquire) {
1197 tracing::debug!("Connection completed by another task while acquiring lock");
1198 self.connecting.store(false, Ordering::Release);
1199 return Ok(());
1200 }
1201
1202 tracing::debug!("Acquired connection lock, establishing connection...");
1203
1204 let result = self.connect_with_retries().await;
1206
1207 self.connecting.store(false, Ordering::Release);
1209
1210 match result {
1212 Ok(()) => {
1213 self.start_receiver().await;
1214 self.start_ping_task().await;
1215 Ok(())
1216 }
1217 Err(e) => {
1218 let _ = self.event_tx.send(SignalingEvent::Disconnected {
1220 reason: DisconnectReason::ConnectionFailed(e.to_string()),
1221 });
1222 tracing::error!("Connection failed: {e}");
1223 Err(e)
1224 }
1225 }
1226 }
1227
1228 async fn connect_once(&self) -> NetworkResult<()> {
1229 if self.connected.load(Ordering::Acquire) {
1230 tracing::debug!("Already connected, skipping connect_once()");
1231 return Ok(());
1232 }
1233
1234 match self
1235 .connecting
1236 .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
1237 {
1238 Ok(_) => {}
1239 Err(_) => {
1240 if self.connected.load(Ordering::Acquire) {
1241 tracing::debug!("Already connected, skipping connect_once()");
1242 return Ok(());
1243 }
1244
1245 tracing::debug!(
1246 "Another connection attempt in progress, waiting for state change..."
1247 );
1248 return self.wait_for_connection_result().await;
1249 }
1250 }
1251
1252 if self.connected.load(Ordering::Acquire) {
1253 tracing::debug!("Connection completed by another task while acquiring lock");
1254 self.connecting.store(false, Ordering::Release);
1255 return Ok(());
1256 }
1257
1258 tracing::debug!(
1259 "Acquired connection lock, establishing one signaling connection attempt..."
1260 );
1261
1262 let result = self.establish_connection_once().await;
1263 self.connecting.store(false, Ordering::Release);
1264
1265 match result {
1266 Ok(()) => {
1267 self.start_receiver().await;
1268 self.start_ping_task().await;
1269 Ok(())
1270 }
1271 Err(e) => {
1272 let _ = self.event_tx.send(SignalingEvent::Disconnected {
1273 reason: DisconnectReason::ConnectionFailed(e.to_string()),
1274 });
1275 tracing::error!("Connection attempt failed: {e}");
1276 Err(e)
1277 }
1278 }
1279 }
1280
    /// Tears down the current connection and resets per-connection state.
    ///
    /// Steps, in order: drop pending reply/pong waiters, mark the client
    /// disconnected, abort the ping and receiver tasks, close the WebSocket
    /// sink, discard any parked read half, reset the inbound channel, and
    /// publish the disconnect transition with reason `Manual`.
    ///
    /// Every lock acquisition and the close itself are bounded by a timeout;
    /// when a step times out a warning is logged and cleanup continues. The
    /// function always returns `Ok(())`.
    async fn disconnect(&self) -> NetworkResult<()> {
        self.drop_pending_replies("signaling disconnect").await;
        self.drop_pending_pongs("signaling disconnect").await;
        // Remember whether this call performed the connected -> disconnected
        // transition; only then is the transition published at the end.
        let was_connected = self.connected.swap(false, Ordering::AcqRel);

        // Take the ping task handle, but do not wait forever for its lock.
        let ping_handle = match tokio::time::timeout(
            std::time::Duration::from_secs(DISCONNECT_LOCK_TIMEOUT_SECS),
            self.ping_task.lock(),
        )
        .await
        {
            Ok(mut task_guard) => task_guard.take(),
            Err(_) => {
                tracing::warn!("Timed out waiting for signaling ping task lock during disconnect");
                None
            }
        };
        if let Some(handle) = ping_handle {
            handle.abort();
        }

        // Same for the receiver task.
        let receiver_handle = match tokio::time::timeout(
            std::time::Duration::from_secs(DISCONNECT_LOCK_TIMEOUT_SECS),
            self.receiver_task.lock(),
        )
        .await
        {
            Ok(mut task_guard) => task_guard.take(),
            Err(_) => {
                tracing::warn!(
                    "Timed out waiting for signaling receiver task lock during disconnect"
                );
                None
            }
        };
        if let Some(handle) = receiver_handle {
            handle.abort();
        }

        // Move the sink out of the shared slot so it can be closed without
        // holding the lock.
        let sink = match tokio::time::timeout(
            std::time::Duration::from_secs(DISCONNECT_LOCK_TIMEOUT_SECS),
            self.ws_sink.lock(),
        )
        .await
        {
            Ok(mut sink_guard) => sink_guard.take(),
            Err(_) => {
                tracing::warn!(
                    "Timed out waiting for signaling WebSocket sink lock during disconnect"
                );
                None
            }
        };

        // Best-effort close: failures and timeouts are logged, never returned.
        if let Some(mut sink) = sink {
            match tokio::time::timeout(
                std::time::Duration::from_secs(DISCONNECT_CLOSE_TIMEOUT_SECS),
                sink.close(),
            )
            .await
            {
                Ok(Ok(())) => {}
                Ok(Err(e)) => {
                    tracing::warn!("Signaling WebSocket close failed during disconnect: {}", e);
                }
                Err(_) => {
                    tracing::warn!(
                        "Signaling WebSocket close timed out during disconnect; continuing cleanup"
                    );
                }
            }
        }

        // Discard a read half that was stored but never handed to a receiver task.
        match tokio::time::timeout(
            std::time::Duration::from_secs(DISCONNECT_LOCK_TIMEOUT_SECS),
            self.ws_stream.lock(),
        )
        .await
        {
            Ok(mut stream_guard) => {
                stream_guard.take();
            }
            Err(_) => {
                tracing::warn!(
                    "Timed out waiting for signaling WebSocket stream lock during disconnect"
                );
            }
        }

        self.reset_inbound_channel().await;

        // No reconnect notifier is passed: a manual disconnect does not wake
        // the reconnect manager.
        Self::publish_disconnected_transition(
            was_connected,
            &self.stats,
            &self.event_tx,
            self.hook_callback.get().cloned(),
            DisconnectReason::Manual,
            None,
        )
        .await;

        Ok(())
    }
1392
1393 async fn probe_alive(&self, timeout: Duration) -> NetworkResult<()> {
1394 if !self.connected.load(Ordering::Acquire) {
1395 return Err(NetworkError::ConnectionError(
1396 "Signaling client is not connected".to_string(),
1397 ));
1398 }
1399
1400 let probe_id = self.probe_counter.fetch_add(1, Ordering::Relaxed) + 1;
1401 let payload =
1402 format!("actr-signaling-probe-{probe_id}-{}", current_unix_secs()).into_bytes();
1403 let (tx, rx) = oneshot::channel();
1404 self.pending_pongs.lock().await.insert(payload.clone(), tx);
1405
1406 let send_result = {
1407 let mut sink_guard = self.ws_sink.lock().await;
1408 match sink_guard.as_mut() {
1409 Some(sink) => sink
1410 .send(tokio_tungstenite::tungstenite::Message::Ping(
1411 payload.clone().into(),
1412 ))
1413 .await
1414 .map_err(|e| {
1415 NetworkError::ConnectionError(format!("Signaling probe ping failed: {e}"))
1416 }),
1417 None => Err(NetworkError::ConnectionError(
1418 "Signaling probe failed: WebSocket sink is not available".to_string(),
1419 )),
1420 }
1421 };
1422
1423 if let Err(e) = send_result {
1424 self.pending_pongs.lock().await.remove(&payload);
1425 let was_connected = self.connected.swap(false, Ordering::AcqRel);
1426 Self::publish_disconnected_transition(
1427 was_connected,
1428 &self.stats,
1429 &self.event_tx,
1430 self.hook_callback.get().cloned(),
1431 DisconnectReason::PingSendFailed,
1432 None,
1433 )
1434 .await;
1435 return Err(e);
1436 }
1437
1438 match tokio::time::timeout(timeout, rx).await {
1439 Ok(Ok(())) => {
1440 self.last_pong.store(current_unix_secs(), Ordering::Release);
1441 Ok(())
1442 }
1443 Ok(Err(_)) => {
1444 self.pending_pongs.lock().await.remove(&payload);
1445 Err(NetworkError::ConnectionError(
1446 "Signaling probe pong waiter dropped".to_string(),
1447 ))
1448 }
1449 Err(_) => {
1450 self.pending_pongs.lock().await.remove(&payload);
1451 Err(NetworkError::TimeoutError(format!(
1452 "Timed out waiting for signaling probe pong after {}ms",
1453 timeout.as_millis()
1454 )))
1455 }
1456 }
1457 }
1458
1459 #[cfg_attr(feature = "opentelemetry", tracing::instrument(skip_all))]
1460 async fn send_register_request(
1461 &self,
1462 request: RegisterRequest,
1463 ) -> NetworkResult<RegisterResponse> {
1464 let flow = signaling_envelope::Flow::PeerToServer(PeerToSignaling {
1466 payload: Some(peer_to_signaling::Payload::RegisterRequest(request)),
1467 });
1468
1469 let envelope = self.create_envelope(flow).await;
1470 let response_envelope = self.send_envelope_and_wait_response(envelope).await?;
1471
1472 if let Some(signaling_envelope::Flow::ServerToActr(server_to_actr)) = response_envelope.flow
1473 {
1474 if let Some(signaling_to_actr::Payload::RegisterResponse(response)) =
1475 server_to_actr.payload
1476 {
1477 return Ok(response);
1478 }
1479 }
1480
1481 Err(NetworkError::ConnectionError(
1482 "Invalid registration response".to_string(),
1483 ))
1484 }
1485
1486 #[cfg_attr(
1487 feature = "opentelemetry",
1488 tracing::instrument(skip_all, fields(actor_id = %actor_id))
1489 )]
1490 async fn send_unregister_request(
1491 &self,
1492 actor_id: ActrId,
1493 credential: AIdCredential,
1494 reason: Option<String>,
1495 ) -> NetworkResult<UnregisterResponse> {
1496 let request = UnregisterRequest {
1498 actr_id: actor_id.clone(),
1499 reason,
1500 };
1501
1502 let flow = signaling_envelope::Flow::ActrToServer(ActrToSignaling {
1504 source: actor_id,
1505 credential,
1506 payload: Some(actr_to_signaling::Payload::UnregisterRequest(request)),
1507 });
1508
1509 let envelope = self.create_envelope(flow).await;
1511 self.send_envelope(envelope).await?;
1512
1513 Ok(UnregisterResponse {
1518 result: Some(actr_protocol::unregister_response::Result::Success(
1519 actr_protocol::unregister_response::UnregisterOk {},
1520 )),
1521 })
1522 }
1523
1524 #[cfg_attr(
1525 feature = "opentelemetry",
1526 tracing::instrument(level = "debug", skip_all, fields(actor_id = %actor_id))
1527 )]
1528 async fn send_heartbeat(
1529 &self,
1530 actor_id: ActrId,
1531 credential: AIdCredential,
1532 availability: ServiceAvailabilityState,
1533 power_reserve: f32,
1534 mailbox_backlog: f32,
1535 ) -> NetworkResult<Pong> {
1536 let ping = Ping {
1537 availability: availability as i32,
1538 power_reserve,
1539 mailbox_backlog,
1540 sticky_client_ids: vec![], };
1542
1543 let flow = signaling_envelope::Flow::ActrToServer(ActrToSignaling {
1544 source: actor_id,
1545 credential,
1546 payload: Some(actr_to_signaling::Payload::Ping(ping)),
1547 });
1548
1549 let envelope = self.create_envelope(flow).await;
1550 let reply_for = envelope.envelope_id.clone();
1551
1552 let (tx, rx) = oneshot::channel();
1554 self.pending_replies
1555 .lock()
1556 .await
1557 .insert(reply_for.clone(), tx);
1558
1559 if let Err(e) = self.send_envelope(envelope).await {
1560 self.pending_replies.lock().await.remove(&reply_for);
1562 return Err(e);
1563 }
1564
1565 let response_envelope = rx.await.map_err(|_| {
1567 NetworkError::ConnectionError(
1568 "Receiver dropped while waiting for heartbeat response".to_string(),
1569 )
1570 })?;
1571
1572 if let Some(signaling_envelope::Flow::ServerToActr(server_to_actr)) = response_envelope.flow
1574 {
1575 match server_to_actr.payload {
1576 Some(signaling_to_actr::Payload::Pong(pong)) => {
1577 return Ok(pong);
1578 }
1579 Some(signaling_to_actr::Payload::Error(err)) => {
1580 if err.code == 401 {
1582 return Err(NetworkError::CredentialExpired(err.message));
1583 }
1584 return Err(NetworkError::AuthenticationError(format!(
1585 "{} ({})",
1586 err.message, err.code
1587 )));
1588 }
1589 _ => {}
1590 }
1591 }
1592
1593 Err(NetworkError::ConnectionError(
1594 "Received response but not a Pong message".to_string(),
1595 ))
1596 }
1597
1598 #[cfg_attr(feature = "opentelemetry", tracing::instrument(skip_all))]
1599 async fn send_route_candidates_request(
1600 &self,
1601 actor_id: ActrId,
1602 credential: AIdCredential,
1603 request: RouteCandidatesRequest,
1604 ) -> NetworkResult<RouteCandidatesResponse> {
1605 let flow = signaling_envelope::Flow::ActrToServer(ActrToSignaling {
1606 source: actor_id,
1607 credential,
1608 payload: Some(actr_to_signaling::Payload::RouteCandidatesRequest(request)),
1609 });
1610
1611 let envelope = self.create_envelope(flow).await;
1612 let response_envelope = self.send_envelope_and_wait_response(envelope).await?;
1613
1614 if let Some(signaling_envelope::Flow::ServerToActr(server_to_actr)) = response_envelope.flow
1615 {
1616 match server_to_actr.payload {
1617 Some(signaling_to_actr::Payload::RouteCandidatesResponse(response)) => {
1618 return Ok(response);
1619 }
1620 Some(signaling_to_actr::Payload::Error(err)) => {
1621 return Err(NetworkError::ServiceDiscoveryError(format!(
1622 "{} ({})",
1623 err.message, err.code
1624 )));
1625 }
1626 _ => {}
1627 }
1628 }
1629
1630 Err(NetworkError::ConnectionError(
1631 "Invalid route candidates response".to_string(),
1632 ))
1633 }
1634
1635 async fn get_signing_key(
1636 &self,
1637 actor_id: ActrId,
1638 credential: AIdCredential,
1639 key_id: u32,
1640 ) -> NetworkResult<(u32, Vec<u8>)> {
1641 let flow = signaling_envelope::Flow::ActrToServer(ActrToSignaling {
1642 source: actor_id,
1643 credential,
1644 payload: Some(actr_to_signaling::Payload::GetSigningKeyRequest(
1645 GetSigningKeyRequest { key_id },
1646 )),
1647 });
1648
1649 let envelope = self.create_envelope(flow).await;
1650 let response_envelope = self.send_envelope_and_wait_response(envelope).await?;
1651
1652 if let Some(signaling_envelope::Flow::ServerToActr(server_to_actr)) = response_envelope.flow
1653 {
1654 match server_to_actr.payload {
1655 Some(signaling_to_actr::Payload::GetSigningKeyResponse(resp)) => {
1656 return Ok((resp.key_id, resp.pubkey.to_vec()));
1657 }
1658 Some(signaling_to_actr::Payload::Error(err)) => {
1659 return Err(NetworkError::ConnectionError(format!(
1660 "get_signing_key failed: {} ({})",
1661 err.message, err.code
1662 )));
1663 }
1664 _ => {}
1665 }
1666 }
1667
1668 Err(NetworkError::ConnectionError(
1669 "get_signing_key: invalid response".to_string(),
1670 ))
1671 }
1672
1673 #[cfg_attr(
1674 feature = "opentelemetry",
1675 tracing::instrument(level = "debug", skip_all, fields(actor_id = %actor_id))
1676 )]
1677 async fn send_credential_update_request(
1678 &self,
1679 actor_id: ActrId,
1680 credential: AIdCredential,
1681 ) -> NetworkResult<RegisterResponse> {
1682 let request = CredentialUpdateRequest {
1683 actr_id: actor_id.clone(),
1684 };
1685
1686 let flow = signaling_envelope::Flow::ActrToServer(ActrToSignaling {
1687 source: actor_id,
1688 credential,
1689 payload: Some(actr_to_signaling::Payload::CredentialUpdateRequest(request)),
1690 });
1691
1692 let envelope = self.create_envelope(flow).await;
1693 let response_envelope = self.send_envelope_and_wait_response(envelope).await?;
1694
1695 if let Some(signaling_envelope::Flow::ServerToActr(server_to_actr)) = response_envelope.flow
1696 {
1697 match server_to_actr.payload {
1698 Some(signaling_to_actr::Payload::RegisterResponse(response)) => {
1699 return Ok(response);
1700 }
1701 Some(signaling_to_actr::Payload::Error(err)) => {
1702 return Err(NetworkError::ConnectionError(format!(
1703 "Credential update failed: {} ({})",
1704 err.message, err.code
1705 )));
1706 }
1707 _ => {}
1708 }
1709 }
1710
1711 Err(NetworkError::ConnectionError(
1712 "Invalid credential update response".to_string(),
1713 ))
1714 }
1715
1716 #[cfg_attr(
1717 feature = "opentelemetry",
1718 tracing::instrument(level = "debug", skip_all, fields(envelope_id = %envelope.envelope_id))
1719 )]
1720 async fn send_envelope(&self, envelope: SignalingEnvelope) -> NetworkResult<()> {
1721 #[cfg(feature = "opentelemetry")]
1722 let envelope = {
1723 let mut envelope = envelope;
1724 trace::inject_span_context(&tracing::Span::current(), &mut envelope);
1725 envelope
1726 };
1727
1728 if !self.is_connected() {
1731 return Err(NetworkError::ConnectionError(
1732 "Cannot send: WebSocket not connected".to_string(),
1733 ));
1734 }
1735
1736 let mut sink_guard = self.ws_sink.lock().await;
1737
1738 if let Some(sink) = sink_guard.as_mut() {
1739 let mut buf = Vec::new();
1741 envelope.encode(&mut buf)?;
1742 let msg = tokio_tungstenite::tungstenite::Message::Binary(buf.into());
1743 match tokio::time::timeout(
1744 std::time::Duration::from_secs(SIGNALING_SEND_TIMEOUT_SECS),
1745 sink.send(msg),
1746 )
1747 .await
1748 {
1749 Ok(Ok(())) => {}
1750 Ok(Err(e)) => return Err(e.into()),
1751 Err(_) => {
1752 self.connected.store(false, Ordering::Release);
1753 return Err(NetworkError::ConnectionError(
1754 "Signaling WebSocket send timed out".to_string(),
1755 ));
1756 }
1757 }
1758
1759 self.stats.messages_sent.fetch_add(1, Ordering::Relaxed);
1760 tracing::debug!("Stats: {:?}", self.stats.snapshot());
1761 Ok(())
1762 } else {
1763 Err(NetworkError::ConnectionError("Not connected".to_string()))
1764 }
1765 }
1766
1767 async fn receive_envelope(&self) -> NetworkResult<Option<SignalingEnvelope>> {
1768 let mut rx = self.inbound_rx.lock().await;
1769 match rx.recv().await {
1770 Some(envelope) => Ok(Some(envelope)),
1771 None => {
1772 tracing::error!("Inbound channel closed");
1773 Err(NetworkError::ConnectionError(
1774 "Inbound channel closed".to_string(),
1775 ))
1776 }
1777 }
1778 }
1779
1780 fn is_connected(&self) -> bool {
1781 self.connected.load(Ordering::Acquire)
1782 }
1783
1784 fn get_stats(&self) -> SignalingStats {
1785 self.stats.snapshot()
1786 }
1787
1788 fn subscribe_events(&self) -> broadcast::Receiver<SignalingEvent> {
1789 self.event_tx.subscribe()
1790 }
1791
1792 async fn set_actor_id(&self, actor_id: ActrId) {
1793 *self.actor_id.lock().await = Some(actor_id);
1794 }
1795
1796 async fn set_credential_state(&self, credential_state: CredentialState) {
1797 *self.credential_state.lock().await = Some(credential_state);
1798 }
1799
1800 async fn clear_identity(&self) {
1801 *self.actor_id.lock().await = None;
1802 *self.credential_state.lock().await = None;
1803 }
1804
1805 fn set_hook_callback(&self, cb: HookCallback) {
1806 let _ = self.hook_callback.set(cb);
1807 }
1808}
1809
/// Signaling counters stored as atomics so they can be updated through a shared reference.
///
/// `Default` yields all-zero counters; `snapshot` copies the current values into a plain
/// `SignalingStats`.
#[derive(Debug, Default)]
pub(crate) struct AtomicSignalingStats {
    /// Connections counter.
    pub connections: AtomicU64,

    /// Disconnections counter.
    pub disconnections: AtomicU64,

    /// Messages sent; `send_envelope` increments this after each successful write.
    pub messages_sent: AtomicU64,

    /// Messages received counter.
    pub messages_received: AtomicU64,

    /// Heartbeats sent counter.
    pub heartbeats_sent: AtomicU64,

    /// Heartbeats received counter.
    pub heartbeats_received: AtomicU64,

    /// Errors counter.
    pub errors: AtomicU64,
}
1850
1851#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)]
1853pub struct SignalingStats {
1854 pub connections: u64,
1856
1857 pub disconnections: u64,
1859
1860 pub messages_sent: u64,
1862
1863 pub messages_received: u64,
1865
1866 pub heartbeats_sent: u64,
1868
1869 pub heartbeats_received: u64,
1871
1872 pub errors: u64,
1874}
1875
1876impl AtomicSignalingStats {
1877 pub fn snapshot(&self) -> SignalingStats {
1879 SignalingStats {
1880 connections: self.connections.load(Ordering::Relaxed),
1881 disconnections: self.disconnections.load(Ordering::Relaxed),
1882 messages_sent: self.messages_sent.load(Ordering::Relaxed),
1883 messages_received: self.messages_received.load(Ordering::Relaxed),
1884 heartbeats_sent: self.heartbeats_sent.load(Ordering::Relaxed),
1885 heartbeats_received: self.heartbeats_received.load(Ordering::Relaxed),
1886 errors: self.errors.load(Ordering::Relaxed),
1887 }
1888 }
1889}
1890
/// Returns the current wall-clock time as whole seconds since the Unix epoch.
///
/// Yields 0 if the system clock reports a time before the epoch.
fn current_unix_secs() -> u64 {
    let now = std::time::SystemTime::now();
    match now.duration_since(std::time::UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_secs(),
        Err(_) => 0,
    }
}
1898
#[cfg(test)]
mod tests {
    use super::*;
    use std::future::Future;
    use std::pin::Pin;
    use std::sync::atomic::{AtomicUsize, Ordering as UsizeOrdering};

    /// Minimal in-memory `SignalingClient` that records how often `connect` is called.
    struct FakeSignalingClient {
        event_tx: broadcast::Sender<SignalingEvent>,
        connected: AtomicBool,
        connect_calls: Arc<AtomicUsize>,
        actor_id: tokio::sync::Mutex<Option<ActrId>>,
        credential_state: tokio::sync::Mutex<Option<CredentialState>>,
    }

    #[async_trait]
    impl SignalingClient for FakeSignalingClient {
        /// Counts the call and reports success without touching the network.
        async fn connect(&self) -> NetworkResult<()> {
            self.connect_calls.fetch_add(1, UsizeOrdering::SeqCst);
            Ok(())
        }

        async fn disconnect(&self) -> NetworkResult<()> {
            Ok(())
        }

        async fn send_register_request(
            &self,
            _request: RegisterRequest,
        ) -> NetworkResult<RegisterResponse> {
            unimplemented!("not needed in tests");
        }

        async fn send_unregister_request(
            &self,
            _actor_id: ActrId,
            _credential: AIdCredential,
            _reason: Option<String>,
        ) -> NetworkResult<UnregisterResponse> {
            unimplemented!("not needed in tests");
        }

        async fn send_heartbeat(
            &self,
            _actor_id: ActrId,
            _credential: AIdCredential,
            _availability: ServiceAvailabilityState,
            _power_reserve: f32,
            _mailbox_backlog: f32,
        ) -> NetworkResult<Pong> {
            unimplemented!("not needed in tests");
        }

        async fn send_route_candidates_request(
            &self,
            _actor_id: ActrId,
            _credential: AIdCredential,
            _request: RouteCandidatesRequest,
        ) -> NetworkResult<RouteCandidatesResponse> {
            unimplemented!("not needed in tests");
        }

        async fn get_signing_key(
            &self,
            _actor_id: ActrId,
            _credential: AIdCredential,
            _key_id: u32,
        ) -> NetworkResult<(u32, Vec<u8>)> {
            unimplemented!("not needed in tests");
        }

        async fn send_credential_update_request(
            &self,
            _actor_id: ActrId,
            _credential: AIdCredential,
        ) -> NetworkResult<RegisterResponse> {
            unimplemented!("not needed in tests");
        }

        async fn send_envelope(&self, _envelope: SignalingEnvelope) -> NetworkResult<()> {
            unimplemented!("not needed in tests");
        }

        async fn receive_envelope(&self) -> NetworkResult<Option<SignalingEnvelope>> {
            unimplemented!("not needed in tests");
        }

        fn is_connected(&self) -> bool {
            self.connected.load(Ordering::SeqCst)
        }

        fn get_stats(&self) -> SignalingStats {
            SignalingStats::default()
        }

        fn subscribe_events(&self) -> broadcast::Receiver<SignalingEvent> {
            self.event_tx.subscribe()
        }

        async fn set_actor_id(&self, actor_id: ActrId) {
            *self.actor_id.lock().await = Some(actor_id);
        }

        async fn set_credential_state(&self, credential_state: CredentialState) {
            *self.credential_state.lock().await = Some(credential_state);
        }

        async fn clear_identity(&self) {
            *self.actor_id.lock().await = None;
            *self.credential_state.lock().await = None;
        }
    }

    /// Builds a disconnected fake client with a zeroed connect counter.
    fn make_fake_client() -> Arc<FakeSignalingClient> {
        let (event_tx, _erx) = broadcast::channel(64);
        Arc::new(FakeSignalingClient {
            event_tx,
            connected: AtomicBool::new(false),
            connect_calls: Arc::new(AtomicUsize::new(0)),
            actor_id: tokio::sync::Mutex::new(None),
            credential_state: tokio::sync::Mutex::new(None),
        })
    }

    /// Builds a config pointing at `127.0.0.1:1`; the connect tests below expect connection
    /// attempts against it to fail.
    fn make_config() -> SignalingConfig {
        SignalingConfig {
            server_url: Url::parse("ws://127.0.0.1:1/signaling/ws").unwrap(),
            connection_timeout: 2,
            heartbeat_interval: 30,
            reconnect_config: ReconnectConfig::default(),
            auth_config: None,
            webrtc_role: None,
        }
    }

    /// Wraps a freshly constructed WebSocket client in an `Arc`.
    fn make_ws_client(config: SignalingConfig) -> Arc<WebSocketSignalingClient> {
        Arc::new(WebSocketSignalingClient::new(config))
    }

    /// A connected->disconnected transition publishes once; a second call made with
    /// `was_connected = false` is ignored.
    #[tokio::test]
    async fn test_publish_disconnected_transition_fires_hook_once() {
        let stats = Arc::new(AtomicSignalingStats::default());
        let (event_tx, mut event_rx) = broadcast::channel(4);
        let hook_count = Arc::new(AtomicUsize::new(0));
        let hook_count_for_cb = hook_count.clone();
        let hook_callback: HookCallback = Arc::new(move |event| {
            let hook_count = hook_count_for_cb.clone();
            Box::pin(async move {
                if matches!(event, HookEvent::SignalingDisconnected) {
                    hook_count.fetch_add(1, UsizeOrdering::SeqCst);
                }
            }) as Pin<Box<dyn Future<Output = ()> + Send>>
        });

        let first = WebSocketSignalingClient::publish_disconnected_transition(
            true,
            &stats,
            &event_tx,
            Some(hook_callback.clone()),
            DisconnectReason::StreamEnded,
            None,
        )
        .await;
        assert!(
            first,
            "first connected->disconnected transition should publish"
        );
        assert_eq!(hook_count.load(UsizeOrdering::SeqCst), 1);
        assert_eq!(stats.snapshot().disconnections, 1);
        assert!(matches!(
            event_rx.recv().await,
            Ok(SignalingEvent::Disconnected {
                reason: DisconnectReason::StreamEnded
            })
        ));

        let second = WebSocketSignalingClient::publish_disconnected_transition(
            false,
            &stats,
            &event_tx,
            Some(hook_callback),
            DisconnectReason::PongTimeout,
            None,
        )
        .await;
        assert!(
            !second,
            "stale duplicate disconnected transition should be ignored"
        );
        assert_eq!(hook_count.load(UsizeOrdering::SeqCst), 1);
        assert_eq!(stats.snapshot().disconnections, 1);
        assert!(event_rx.try_recv().is_err());
    }

    /// Pins the default reconnect settings.
    #[test]
    fn test_reconnect_config_defaults() {
        let cfg = ReconnectConfig::default();
        assert!(cfg.enabled);
        assert_eq!(cfg.max_attempts, 10);
        assert_eq!(cfg.initial_delay, 1);
        assert_eq!(cfg.max_delay, 60);
        assert!((cfg.backoff_multiplier - 2.0).abs() < f64::EPSILON);
    }

    /// A new client is neither connected, connecting, nor running a reconnect manager.
    #[test]
    fn test_websocket_signaling_client_initial_state_disconnected() {
        let client = WebSocketSignalingClient::new(make_config());
        assert!(
            !client.is_connected(),
            "newly created client should be Disconnected"
        );
        assert!(
            !client.connecting.load(Ordering::Acquire),
            "newly created client should not be in connecting state"
        );
        assert!(
            !client.reconnector_started.load(Ordering::Acquire),
            "reconnect manager should not be started automatically"
        );
    }

    /// Every counter of a new client, including the heartbeat counters, starts at zero.
    #[test]
    fn test_initial_stats_are_zero() {
        let client = WebSocketSignalingClient::new(make_config());
        let stats = client.get_stats();
        assert_eq!(stats.connections, 0);
        assert_eq!(stats.disconnections, 0);
        assert_eq!(stats.messages_sent, 0);
        assert_eq!(stats.messages_received, 0);
        assert_eq!(stats.heartbeats_sent, 0);
        assert_eq!(stats.heartbeats_received, 0);
        assert_eq!(stats.errors, 0);
    }

    /// Credential-bearing query parameters are replaced by `REDACTED` in log output while
    /// `actor_id` and `key_id` stay readable.
    #[test]
    fn test_signaling_url_log_redacts_credential_query_params() {
        let url = Url::parse(
            "wss://example.com/signaling?actor_id=abc&key_id=7&claims=claims-value&signature=signature-value&token=token-value",
        )
        .unwrap();

        let redacted = WebSocketSignalingClient::redact_signaling_url_for_log(&url);

        assert!(redacted.contains("actor_id=abc"));
        assert!(redacted.contains("key_id=7"));
        assert!(redacted.contains("claims=REDACTED"));
        assert!(redacted.contains("signature=REDACTED"));
        assert!(redacted.contains("token=REDACTED"));
        assert!(!redacted.contains("claims-value"));
        assert!(!redacted.contains("signature-value"));
        assert!(!redacted.contains("token-value"));
    }

    /// Calling `start_reconnect_manager` twice leaves the started flag set.
    #[tokio::test]
    async fn test_reconnect_manager_idempotent() {
        let client = make_ws_client(make_config());

        client.start_reconnect_manager();
        assert!(
            client.reconnector_started.load(Ordering::Acquire),
            "reconnector_started should be true after first call"
        );

        client.start_reconnect_manager();
        assert!(client.reconnector_started.load(Ordering::Acquire));
    }

    /// With reconnects disabled in the config the manager must not start.
    #[tokio::test]
    async fn test_reconnect_manager_disabled_when_config_disabled() {
        let mut config = make_config();
        config.reconnect_config.enabled = false;
        let client = make_ws_client(config);

        client.start_reconnect_manager();
        assert!(
            !client.reconnector_started.load(Ordering::Acquire),
            "reconnect manager should not start when reconnect config is disabled"
        );
    }

    /// `connect` returns `Ok` immediately when the connected flag is already set.
    #[tokio::test]
    async fn test_connect_fast_path_when_already_connected() {
        let client = make_ws_client(make_config());
        client.connected.store(true, Ordering::Release);

        let result = client.connect().await;
        assert!(
            result.is_ok(),
            "connect() should return Ok when already connected"
        );
        assert!(!client.connecting.load(Ordering::Acquire));
    }

    /// Despite its name, this checks that the `connecting` flag is cleared again after a
    /// failed connection attempt.
    #[tokio::test]
    async fn test_connect_sets_connecting_flag() {
        let mut config = make_config();
        config.reconnect_config.enabled = false;
        config.connection_timeout = 1;
        let client = make_ws_client(config);

        let result = client.connect().await;
        assert!(
            result.is_err(),
            "connecting to unreachable address should fail"
        );
        assert!(
            !client.connecting.load(Ordering::Acquire),
            "connecting flag should be cleared after connection failure"
        );
    }

    /// An event sent on `event_tx` reaches a receiver obtained from `subscribe_events`.
    #[tokio::test]
    async fn test_event_subscribe_receives_events() {
        let client = make_ws_client(make_config());
        let mut rx = client.subscribe_events();

        let _ = client.event_tx.send(SignalingEvent::Connected);

        match tokio::time::timeout(std::time::Duration::from_secs(1), rx.recv()).await {
            Ok(Ok(SignalingEvent::Connected)) => {}
            other => panic!("expected Connected event, but got {:?}", other),
        }
    }

    /// A failed connect publishes `Disconnected(ConnectionFailed)`.
    #[tokio::test]
    async fn test_disconnect_event_on_connect_failure() {
        let mut config = make_config();
        config.reconnect_config.enabled = false;
        config.connection_timeout = 1;
        let client = make_ws_client(config);
        let mut rx = client.subscribe_events();

        let _ = client.connect().await;

        match tokio::time::timeout(std::time::Duration::from_secs(2), rx.recv()).await {
            Ok(Ok(SignalingEvent::Disconnected {
                reason: DisconnectReason::ConnectionFailed(_),
            })) => {}
            other => panic!(
                "expected Disconnected(ConnectionFailed) event, but got {:?}",
                other
            ),
        }
    }

    /// `disconnect` clears the connected flag.
    #[tokio::test]
    async fn test_disconnect_clears_connected_flag() {
        let client = make_ws_client(make_config());
        client.connected.store(true, Ordering::Release);
        assert!(client.is_connected());

        let result = client.disconnect().await;
        assert!(result.is_ok());
        assert!(
            !client.is_connected(),
            "should be Disconnected after disconnect()"
        );
    }

    /// `disconnect` from the connected state bumps the disconnection counter by one.
    #[tokio::test]
    async fn test_disconnect_increments_disconnection_stat() {
        let client = make_ws_client(make_config());
        client.connected.store(true, Ordering::Release);

        let stats_before = client.get_stats().disconnections;
        let _ = client.disconnect().await;
        let stats_after = client.get_stats().disconnections;
        assert_eq!(
            stats_after,
            stats_before + 1,
            "disconnect() should increment disconnection count"
        );
    }

    /// Repeated `disconnect` calls on a never-connected client both succeed.
    #[tokio::test]
    async fn test_disconnect_idempotent() {
        let client = make_ws_client(make_config());

        let r1 = client.disconnect().await;
        let r2 = client.disconnect().await;
        assert!(r1.is_ok());
        assert!(r2.is_ok());
        assert!(!client.is_connected());
    }

    /// A task parked on `Notify::notified` stays parked until `notify_one` is called.
    ///
    /// The wake-up is observed by awaiting the task itself (bounded by a timeout) instead of
    /// sleeping for a fixed interval, so the test does not depend on scheduler timing.
    #[tokio::test]
    async fn test_reconnect_notify_wakes_waiter() {
        let notify = Arc::new(tokio::sync::Notify::new());
        let notify_clone = notify.clone();
        let woken = Arc::new(AtomicBool::new(false));
        let woken_clone = woken.clone();

        let handle = tokio::spawn(async move {
            notify_clone.notified().await;
            woken_clone.store(true, Ordering::Release);
        });

        tokio::time::sleep(std::time::Duration::from_millis(50)).await;
        assert!(
            !woken.load(Ordering::Acquire),
            "should not be woken before notification"
        );

        notify.notify_one();
        tokio::time::timeout(std::time::Duration::from_secs(1), handle)
            .await
            .expect("waiter should finish promptly after notification")
            .expect("waiter task should not panic");
        assert!(
            woken.load(Ordering::Acquire),
            "should be woken after notification"
        );
    }

    /// Without an actor id the built URL equals the configured server URL.
    #[tokio::test]
    async fn test_build_url_without_identity() {
        let config = make_config();
        let expected_base = config.server_url.to_string();
        let client = WebSocketSignalingClient::new(config);

        let url = client.build_url_with_identity().await;
        assert_eq!(
            url.to_string(),
            expected_base,
            "URL should not contain identity parameters when actor_id is not set"
        );
    }

    /// A configured WebRTC role appears as the `webrtc_role` query parameter.
    #[tokio::test]
    async fn test_build_url_with_webrtc_role() {
        let mut config = make_config();
        config.webrtc_role = Some("answer".to_string());
        let client = WebSocketSignalingClient::new(config);

        let url = client.build_url_with_identity().await;
        assert!(
            url.query().unwrap_or("").contains("webrtc_role=answer"),
            "URL should contain webrtc_role parameter, actual URL: {}",
            url
        );
    }

    /// Messages queued before `reset_inbound_channel` are not visible afterwards.
    #[tokio::test]
    async fn test_reset_inbound_channel_creates_fresh_channel() {
        let client = WebSocketSignalingClient::new(make_config());

        {
            let tx = client.inbound_tx.lock().await;
            let _ = tx.send(SignalingEnvelope::default());
        }

        client.reset_inbound_channel().await;

        let mut rx = client.inbound_rx.lock().await;
        let result = rx.try_recv();
        assert!(
            result.is_err(),
            "old messages should not be visible in the new channel after reset"
        );
    }

    /// Envelope ids are issued as `env-1`, `env-2`, `env-3`, ...
    #[tokio::test]
    async fn test_envelope_id_monotonically_increasing() {
        let client = WebSocketSignalingClient::new(make_config());

        let id1 = client.next_envelope_id().await;
        let id2 = client.next_envelope_id().await;
        let id3 = client.next_envelope_id().await;

        assert_eq!(id1, "env-1");
        assert_eq!(id2, "env-2");
        assert_eq!(id3, "env-3");
    }

    /// Sending on a never-connected client fails with a "not connected" `ConnectionError`.
    #[tokio::test]
    async fn test_send_envelope_fails_when_not_connected() {
        let client = WebSocketSignalingClient::new(make_config());
        let envelope = SignalingEnvelope::default();

        let result = client.send_envelope(envelope).await;
        assert!(
            result.is_err(),
            "send_envelope should return error when not connected"
        );
        match result {
            Err(NetworkError::ConnectionError(msg)) => {
                assert!(
                    msg.contains("not connected") || msg.contains("Not connected"),
                    "error message should contain 'not connected', actual: {}",
                    msg
                );
            }
            other => panic!("expected ConnectionError, got {:?}", other),
        }
    }

    /// The fake client counts each `connect` call.
    #[tokio::test]
    async fn test_fake_client_tracks_connect_calls() {
        let client = make_fake_client();
        assert_eq!(client.connect_calls.load(UsizeOrdering::SeqCst), 0);

        client.connect().await.unwrap();
        client.connect().await.unwrap();
        client.connect().await.unwrap();

        assert_eq!(
            client.connect_calls.load(UsizeOrdering::SeqCst),
            3,
            "FakeSignalingClient should accurately track connect call count"
        );
    }
}