1#[cfg(feature = "opentelemetry")]
6use super::trace;
7use crate::lifecycle::CredentialState;
8use crate::transport::error::{NetworkError, NetworkResult};
9#[cfg(feature = "opentelemetry")]
10use crate::wire::webrtc::trace::extract_trace_context;
11#[cfg(feature = "opentelemetry")]
12use actr_protocol::ActrIdExt;
13use actr_protocol::prost::Message as ProstMessage;
14use actr_protocol::{
15 AIdCredential, ActrId, ActrToSignaling, CredentialUpdateRequest, PeerToSignaling, Ping, Pong,
16 RegisterRequest, RegisterResponse, RouteCandidatesRequest, RouteCandidatesResponse,
17 ServiceAvailabilityState, SignalingEnvelope, UnregisterRequest, UnregisterResponse,
18 actr_to_signaling, peer_to_signaling, signaling_envelope, signaling_to_actr,
19};
20use async_trait::async_trait;
21use base64::Engine as _;
22use futures_util::{SinkExt, StreamExt};
23use serde::{Deserialize, Serialize};
24use std::collections::HashMap;
25use std::sync::{
26 Arc,
27 atomic::{AtomicBool, AtomicU64, Ordering},
28};
29use tokio::net::TcpStream;
30use tokio::sync::{mpsc, oneshot, watch};
31use tokio_tungstenite::tungstenite::protocol::WebSocketConfig;
32use tokio_tungstenite::{MaybeTlsStream, WebSocketStream, connect_async_with_config};
33use tokio_util::sync::CancellationToken;
34#[cfg(feature = "opentelemetry")]
35use tracing_opentelemetry::OpenTelemetrySpanExt;
36use url::Url;
37
/// How long (in seconds) a request/response exchange waits for the envelope whose
/// `reply_for` matches the request's `envelope_id` before failing with a timeout.
const RESPONSE_TIMEOUT_SECS: u64 = 15;
/// Interval (in seconds) between WebSocket keep-alive pings sent by the ping task.
const PING_INTERVAL_SECS: u64 = 5;
/// Maximum time (in seconds) since `last_pong` was last refreshed (by inbound binary,
/// ping or pong frames) before the ping task marks the connection as disconnected.
const PONG_TIMEOUT_SECS: u64 = 10;
47
/// Settings for a signaling client connection.
#[derive(Debug, Clone)]
pub struct SignalingConfig {
    /// WebSocket URL of the signaling server. When an actor id and credential state
    /// are set on the client, identity query parameters are appended to a copy of it.
    pub server_url: Url,

    /// Per-attempt connect timeout in seconds; `0` disables the timeout entirely.
    pub connection_timeout: u64,

    /// NOTE(review): not read by `WebSocketSignalingClient` in this file (keep-alive
    /// pings use `PING_INTERVAL_SECS`); presumably seconds — confirm where it is used.
    pub heartbeat_interval: u64,

    /// Retry/backoff policy applied when establishing the connection.
    pub reconnect_config: ReconnectConfig,

    /// Optional authentication settings. NOTE(review): not read by the WebSocket
    /// client in this file — confirm where it is consumed.
    pub auth_config: Option<AuthConfig>,
}
70
/// Retry policy used when (re)establishing the signaling connection.
#[derive(Debug, Clone)]
pub struct ReconnectConfig {
    /// When `false`, exactly one connection attempt is made and all other fields are
    /// ignored.
    pub enabled: bool,

    /// Total number of connection attempts before giving up. At least one attempt is
    /// always made, so `0` behaves like `1`.
    pub max_attempts: u32,

    /// Delay in seconds before the first retry; values below 1 are treated as 1.
    pub initial_delay: u64,

    /// Upper bound in seconds for a single retry sleep; values below 1 are treated as 1.
    pub max_delay: u64,

    /// Factor applied to the delay after each failed attempt. The result is rounded
    /// and never drops below 1 second.
    pub backoff_multiplier: f64,
}
89
90impl Default for ReconnectConfig {
91 fn default() -> Self {
92 Self {
93 enabled: true,
94 max_attempts: 10,
95 initial_delay: 1,
96 max_delay: 60,
97 backoff_multiplier: 2.0,
98 }
99 }
100}
101
/// Authentication settings for the signaling connection.
///
/// NOTE(review): nothing in this file reads this type; confirm how it is applied.
#[derive(Debug, Clone)]
pub struct AuthConfig {
    /// Which authentication scheme `credentials` belongs to.
    pub auth_type: AuthType,

    /// Scheme-specific key/value credentials (key names are not defined in this file).
    pub credentials: HashMap<String, String>,
}
111
/// Authentication scheme selector for [`AuthConfig`].
///
/// NOTE(review): the variants are not interpreted anywhere in this file; the
/// descriptions below are inferred from the names — confirm.
#[derive(Debug, Clone)]
pub enum AuthType {
    /// No authentication.
    None,
    /// Presumably a bearer token taken from `AuthConfig::credentials`.
    BearerToken,
    /// Presumably an API key taken from `AuthConfig::credentials`.
    ApiKey,
    /// Presumably a JWT taken from `AuthConfig::credentials`.
    Jwt,
}
124
/// Abstraction over a connection to the signaling server.
///
/// Implementations must be `Send + Sync` so they can be shared (e.g. as
/// `Arc<dyn SignalingClient>`) with background tasks such as the reconnect helper.
#[async_trait]
pub trait SignalingClient: Send + Sync {
    /// Establishes the connection to the signaling server.
    async fn connect(&self) -> NetworkResult<()>;

    /// Tears the connection down.
    async fn disconnect(&self) -> NetworkResult<()>;

    /// Sends a registration request and returns the server's response.
    async fn send_register_request(
        &self,
        request: RegisterRequest,
    ) -> NetworkResult<RegisterResponse>;

    /// Unregisters `actor_id`, authenticated by `credential`, with an optional reason.
    ///
    /// NOTE(review): whether the response reflects a server reply is
    /// implementation-specific — the WebSocket implementation synthesizes it locally.
    async fn send_unregister_request(
        &self,
        actor_id: ActrId,
        credential: AIdCredential,
        reason: Option<String>,
    ) -> NetworkResult<UnregisterResponse>;

    /// Sends a heartbeat `Ping` carrying the actor's availability and load figures,
    /// and returns the server's `Pong`.
    async fn send_heartbeat(
        &self,
        actor_id: ActrId,
        credential: AIdCredential,
        availability: ServiceAvailabilityState,
        power_reserve: f32,
        mailbox_backlog: f32,
    ) -> NetworkResult<Pong>;

    /// Asks the server for route candidates on behalf of `actor_id`.
    async fn send_route_candidates_request(
        &self,
        actor_id: ActrId,
        credential: AIdCredential,
        request: RouteCandidatesRequest,
    ) -> NetworkResult<RouteCandidatesResponse>;

    /// Requests a refreshed credential for `actor_id`; the server answers with a
    /// `RegisterResponse`.
    async fn send_credential_update_request(
        &self,
        actor_id: ActrId,
        credential: AIdCredential,
    ) -> NetworkResult<RegisterResponse>;

    /// Sends a pre-built envelope without waiting for any reply.
    async fn send_envelope(&self, envelope: SignalingEnvelope) -> NetworkResult<()>;

    /// Waits for the next envelope that was not consumed as a reply to a request.
    async fn receive_envelope(&self) -> NetworkResult<Option<SignalingEnvelope>>;

    /// Returns whether the client currently considers itself connected.
    fn is_connected(&self) -> bool;

    /// Returns a snapshot of the client's counters.
    fn get_stats(&self) -> SignalingStats;
    /// Returns a watch receiver that observes connection state transitions.
    fn subscribe_state(&self) -> watch::Receiver<ConnectionState>;

    /// Sets the actor identity used by the client (e.g. when connecting).
    async fn set_actor_id(&self, actor_id: ActrId);
    /// Sets the credential state used by the client (e.g. when connecting).
    async fn set_credential_state(&self, credential_state: CredentialState);
}
206
/// Connection state published through [`SignalingClient::subscribe_state`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConnectionState {
    /// No usable connection; the reconnect helper reacts to this state.
    Disconnected,
    /// A WebSocket connection has been established.
    Connected,
}
213
/// [`SignalingClient`] that exchanges protobuf `SignalingEnvelope`s as binary
/// WebSocket frames.
///
/// Fields shared with the background receiver and ping tasks are wrapped in `Arc`.
pub struct WebSocketSignalingClient {
    /// Settings supplied to `new`.
    config: SignalingConfig,
    /// Actor id appended to the connection URL (together with the credential) once both are set.
    actor_id: tokio::sync::Mutex<Option<ActrId>>,
    /// Credential source whose token is appended to the connection URL once both are set.
    credential_state: tokio::sync::Mutex<Option<CredentialState>>,
    /// Write half of the WebSocket; `None` while not connected.
    ws_sink: Arc<
        tokio::sync::Mutex<
            Option<
                futures_util::stream::SplitSink<
                    WebSocketStream<MaybeTlsStream<TcpStream>>,
                    tokio_tungstenite::tungstenite::Message,
                >,
            >,
        >,
    >,
    /// Read half of the WebSocket; taken by the receiver task when it starts.
    ws_stream: tokio::sync::Mutex<
        Option<futures_util::stream::SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>>,
    >,
    /// Fast-path connection flag, read by `is_connected` and the ping task.
    connected: Arc<AtomicBool>,
    /// Lock-free counters exposed via `get_stats`.
    stats: Arc<AtomicSignalingStats>,
    /// Sequence used to build `env-<n>` envelope ids.
    envelope_counter: tokio::sync::Mutex<u64>,
    /// Waiters for replies, keyed by the `envelope_id` of the request they answer.
    pending_replies: Arc<tokio::sync::Mutex<HashMap<String, oneshot::Sender<SignalingEnvelope>>>>,
    /// Consumer side of the channel carrying non-reply envelopes (read by `receive_envelope`).
    inbound_rx: Arc<tokio::sync::Mutex<mpsc::UnboundedReceiver<SignalingEnvelope>>>,
    /// Producer side of the inbound channel; cloned into each receiver task.
    inbound_tx: tokio::sync::Mutex<mpsc::UnboundedSender<SignalingEnvelope>>,
    /// Handle of the task draining `ws_stream`.
    receiver_task: Arc<tokio::sync::Mutex<Option<tokio::task::JoinHandle<()>>>>,
    /// Handle of the keep-alive ping task.
    ping_task: tokio::sync::Mutex<Option<tokio::task::JoinHandle<()>>>,
    /// Publisher for connection state transitions.
    state_tx: watch::Sender<ConnectionState>,
    /// Unix seconds of the last inbound binary/ping/pong frame (or of connection setup).
    last_pong: Arc<AtomicU64>,
}
254
255impl WebSocketSignalingClient {
256 pub fn new(config: SignalingConfig) -> Self {
258 let (inbound_tx, inbound_rx) = mpsc::unbounded_channel();
259 let (state_tx, _state_rx) = watch::channel(ConnectionState::Disconnected);
260 Self {
261 config,
262 actor_id: tokio::sync::Mutex::new(None),
263 credential_state: tokio::sync::Mutex::new(None),
264 ws_sink: Arc::new(tokio::sync::Mutex::new(None)),
265 ws_stream: tokio::sync::Mutex::new(None),
266 connected: Arc::new(AtomicBool::new(false)),
267 stats: Arc::new(AtomicSignalingStats::default()),
268 envelope_counter: tokio::sync::Mutex::new(0),
269 pending_replies: Arc::new(tokio::sync::Mutex::new(HashMap::new())),
270 inbound_rx: Arc::new(tokio::sync::Mutex::new(inbound_rx)),
271 inbound_tx: tokio::sync::Mutex::new(inbound_tx),
272 receiver_task: Arc::new(tokio::sync::Mutex::new(None)),
273 ping_task: tokio::sync::Mutex::new(None),
274 state_tx,
275 last_pong: Arc::new(AtomicU64::new(0)),
276 }
277 }
278
279 pub async fn connect_to(url: &str) -> NetworkResult<Self> {
281 let config = SignalingConfig {
282 server_url: url.parse()?,
283 connection_timeout: 30,
284 heartbeat_interval: 30,
285 reconnect_config: ReconnectConfig::default(),
286 auth_config: None,
287 };
288
289 let client = Self::new(config);
290 client.connect().await?;
291 Ok(client)
292 }
293
294 async fn next_envelope_id(&self) -> String {
296 let mut counter = self.envelope_counter.lock().await;
297 *counter += 1;
298 format!("env-{}", *counter)
299 }
300
301 async fn create_envelope(&self, flow: signaling_envelope::Flow) -> SignalingEnvelope {
303 SignalingEnvelope {
304 envelope_version: 1,
305 envelope_id: self.next_envelope_id().await,
306 reply_for: None,
307 timestamp: prost_types::Timestamp {
308 seconds: chrono::Utc::now().timestamp(),
309 nanos: 0,
310 },
311 traceparent: None,
312 tracestate: None,
313 flow: Some(flow),
314 }
315 }
316
317 async fn reset_inbound_channel(&self) {
319 let (tx, rx) = mpsc::unbounded_channel();
320 *self.inbound_tx.lock().await = tx;
321 *self.inbound_rx.lock().await = rx;
322 }
323
324 async fn build_url_with_identity(&self) -> Url {
326 let mut url = self.config.server_url.clone();
327 let actor_id_opt = self.actor_id.lock().await.clone();
328 let credential_state_opt = self.credential_state.lock().await.clone();
329 if let (Some(actor_id), Some(credential_state)) = (actor_id_opt, credential_state_opt) {
330 let credential = credential_state.credential().await;
331 let actor_str = actr_protocol::ActrIdExt::to_string_repr(&actor_id);
332 let token_b64 =
333 base64::engine::general_purpose::STANDARD.encode(&credential.encrypted_token);
334 {
335 let mut pairs = url.query_pairs_mut();
336 pairs.append_pair("actor_id", &actor_str);
337 pairs.append_pair("token", &token_b64);
338 pairs.append_pair("token_key_id", &credential.token_key_id.to_string());
339 }
340 }
341 url
342 }
343
344 async fn establish_connection_once(&self) -> NetworkResult<()> {
348 let url = self.build_url_with_identity().await;
349 let timeout_secs = self.config.connection_timeout;
350 tracing::debug!("Establishing connection to URL: {}", url.as_str());
351 let config = WebSocketConfig::default().write_buffer_size(0);
353 let connect_result = if timeout_secs == 0 {
355 connect_async_with_config(url.as_str(), Some(config), false).await
356 } else {
357 let timeout_duration = std::time::Duration::from_secs(timeout_secs);
358 tokio::time::timeout(
359 timeout_duration,
360 connect_async_with_config(url.as_str(), Some(config), false),
361 )
362 .await
363 .map_err(|_| {
364 NetworkError::ConnectionError(format!(
365 "Signaling connect timeout after {}s",
366 timeout_secs
367 ))
368 })?
369 }?;
370
371 let (ws_stream, _) = connect_result;
372
373 let (sink, stream) = ws_stream.split();
375
376 *self.ws_sink.lock().await = Some(sink);
377 *self.ws_stream.lock().await = Some(stream);
378 self.connected.store(true, Ordering::Release);
379 self.last_pong.store(current_unix_secs(), Ordering::Release);
380 let _ = self.state_tx.send(ConnectionState::Connected);
382
383 self.stats.connections.fetch_add(1, Ordering::Relaxed);
384
385 Ok(())
386 }
387
388 async fn connect_with_retries(&self) -> NetworkResult<()> {
390 let cfg = &self.config.reconnect_config;
391
392 if !cfg.enabled {
394 return self.establish_connection_once().await;
395 }
396
397 let mut attempt: u32 = 0;
398 let mut delay_secs = cfg.initial_delay.max(1);
399
400 loop {
401 attempt += 1;
402
403 match self.establish_connection_once().await {
404 Ok(()) => {
405 return Ok(());
406 }
407 Err(e) => {
408 tracing::warn!("Signaling connect attempt {} failed: {e:?}", attempt);
409
410 if attempt >= cfg.max_attempts {
411 tracing::error!(
412 "Signaling connect failed after {} attempts, giving up",
413 attempt
414 );
415 return Err(e);
416 }
417
418 let sleep_secs = delay_secs.min(cfg.max_delay.max(1));
419 tracing::info!("Retry signaling connect after {}s", sleep_secs);
420 tokio::time::sleep(std::time::Duration::from_secs(sleep_secs)).await;
421
422 delay_secs = ((delay_secs as f64) * cfg.backoff_multiplier)
424 .round()
425 .max(1.0) as u64;
426 }
427 }
428 }
429 }
430
431 #[cfg_attr(
433 feature = "opentelemetry",
434 tracing::instrument(skip_all, fields(envelope_id = %envelope.envelope_id))
435 )]
436 async fn send_envelope_and_wait_response(
437 &self,
438 envelope: SignalingEnvelope,
439 ) -> NetworkResult<SignalingEnvelope> {
440 let reply_for = envelope.envelope_id.clone();
441
442 let (tx, rx) = oneshot::channel();
444 self.pending_replies
445 .lock()
446 .await
447 .insert(reply_for.clone(), tx);
448
449 if let Err(e) = self.send_envelope(envelope).await {
450 self.pending_replies.lock().await.remove(&reply_for);
452 return Err(e);
453 }
454
455 let result =
456 tokio::time::timeout(std::time::Duration::from_secs(RESPONSE_TIMEOUT_SECS), rx).await;
457 if result.is_err() {
459 self.pending_replies.lock().await.remove(&reply_for);
460 }
461
462 let response_envelope = result
463 .map_err(|_| {
464 NetworkError::ConnectionError(
465 "Timed out waiting for signaling response".to_string(),
466 )
467 })?
468 .map_err(|_| {
469 NetworkError::ConnectionError(
470 "Receiver dropped while waiting for signaling response".to_string(),
471 )
472 })?;
473
474 Ok(response_envelope)
475 }
476
477 async fn start_receiver(&self) {
479 let mut stream_guard = self.ws_stream.lock().await;
480 if stream_guard.is_none() {
481 return;
482 }
483
484 let mut stream = stream_guard.take().expect("stream exists");
485 let pending = self.pending_replies.clone();
486 let inbound_tx = { self.inbound_tx.lock().await.clone() };
487 let stats = self.stats.clone();
488 let connected = self.connected.clone();
489 let state_tx = self.state_tx.clone();
490 let last_pong = self.last_pong.clone();
491 tracing::debug!("Start receiver");
492 let handle = tokio::spawn(async move {
493 while let Some(msg) = stream.next().await {
494 match msg {
495 Ok(tokio_tungstenite::tungstenite::Message::Binary(data)) => {
496 last_pong.store(current_unix_secs(), Ordering::Release);
498 match SignalingEnvelope::decode(&data[..]) {
499 Ok(envelope) => {
500 #[cfg(feature = "opentelemetry")]
501 let span = {
502 let span = tracing::info_span!("signaling.receive_envelope", envelope_id = %envelope.envelope_id);
503 span.set_parent(extract_trace_context(&envelope));
504 span
505 };
506
507 stats.messages_received.fetch_add(1, Ordering::Relaxed);
508 tracing::debug!("Received message: {:?}", envelope);
509 if let Some(reply_for) = envelope.reply_for.clone() {
510 if let Some(sender) = pending.lock().await.remove(&reply_for) {
511 #[cfg(feature = "opentelemetry")]
512 let _ = span.enter();
513 if let Err(e) = sender.send(envelope) {
514 stats.errors.fetch_add(1, Ordering::Relaxed);
515 tracing::warn!(
516 "Failed to send reply envelope to waiter: {e:?}",
517 );
518 }
519 continue;
520 }
521 }
522 tracing::debug!(
523 "Unmatched or push message -> forward to inbound channel"
524 );
525 if let Err(e) = inbound_tx.send(envelope) {
527 stats.errors.fetch_add(1, Ordering::Relaxed);
528 tracing::warn!(
529 "Failed to send envelope to inbound channel: {e:?}"
530 );
531 }
532 }
533 Err(e) => {
534 stats.errors.fetch_add(1, Ordering::Relaxed);
535 tracing::warn!("Failed to decode SignalingEnvelope: {e}");
536 }
537 }
538 }
539 Ok(tokio_tungstenite::tungstenite::Message::Pong(_)) => {
540 tracing::debug!("Received pong");
541 last_pong.store(current_unix_secs(), Ordering::Release);
542 }
543 Ok(tokio_tungstenite::tungstenite::Message::Ping(_)) => {
544 tracing::debug!("Received ping");
545 last_pong.store(current_unix_secs(), Ordering::Release);
546 }
547 Ok(_) => {
548 tracing::warn!("Received non-binary frame, ignoring");
549 }
550 Err(e) => {
551 stats.errors.fetch_add(1, Ordering::Relaxed);
552 tracing::error!("Signaling receive error: {e}");
553 break;
554 }
555 }
556 }
557
558 connected.store(false, Ordering::Release);
560 stats.disconnections.fetch_add(1, Ordering::Relaxed);
561 let _ = state_tx.send(ConnectionState::Disconnected);
562 });
563
564 *self.receiver_task.lock().await = Some(handle);
565 }
566
567 async fn start_ping_task(&self) {
569 let mut existing = self.ping_task.lock().await;
570 if let Some(handle) = existing.as_ref() {
571 if handle.is_finished() {
572 existing.take();
573 } else {
574 return;
575 }
576 }
577
578 let sink = self.ws_sink.clone();
579 let connected = self.connected.clone();
580 let state_tx = self.state_tx.clone();
581 let last_pong = self.last_pong.clone();
582 let receiver_task_clone = Arc::clone(&self.receiver_task);
583
584 let handle = tokio::spawn(async move {
585 loop {
586 tokio::time::sleep(std::time::Duration::from_secs(PING_INTERVAL_SECS)).await;
587
588 if !connected.load(Ordering::Acquire) {
589 break;
590 }
591
592 let mut sink_guard = sink.lock().await;
594 if let Some(sink) = sink_guard.as_mut() {
595 if let Err(e) = sink
596 .send(tokio_tungstenite::tungstenite::Message::Ping(
597 Vec::new().into(),
598 ))
599 .await
600 {
601 tracing::warn!("Signaling ping send failed: {e}");
602 connected.store(false, Ordering::Release);
603 let _ = state_tx.send(ConnectionState::Disconnected);
604 break;
605 }
606 } else {
607 tracing::warn!("Signaling not connected");
608 connected.store(false, Ordering::Release);
609 let _ = state_tx.send(ConnectionState::Disconnected);
610 break;
611 }
612 drop(sink_guard);
613
614 let now = current_unix_secs();
616 let last = last_pong.load(Ordering::Acquire);
617 if now.saturating_sub(last) > PONG_TIMEOUT_SECS {
618 tracing::warn!(
619 "Signaling pong timeout (last seen {}s ago), marking disconnected",
620 now.saturating_sub(last)
621 );
622 connected.store(false, Ordering::Release);
623 let _ = state_tx.send(ConnectionState::Disconnected);
624 if let Some(handle) = receiver_task_clone.lock().await.take() {
625 handle.abort();
626 }
627 break;
628 }
629 }
630 });
631
632 *existing = Some(handle);
633 }
634}
635
636#[async_trait]
637impl SignalingClient for WebSocketSignalingClient {
638 async fn connect(&self) -> NetworkResult<()> {
639 self.connect_with_retries().await?;
640 self.start_receiver().await;
641 self.start_ping_task().await;
642 Ok(())
643 }
644
645 async fn disconnect(&self) -> NetworkResult<()> {
646 let mut sink_guard = self.ws_sink.lock().await;
648 let mut stream_guard = self.ws_stream.lock().await;
649
650 if let Some(mut sink) = sink_guard.take() {
652 let _ = sink.close().await;
653 }
654
655 stream_guard.take();
657
658 if let Some(handle) = self.receiver_task.lock().await.take() {
660 handle.abort();
661 }
662 if let Some(handle) = self.ping_task.lock().await.take() {
664 handle.abort();
665 }
666
667 self.reset_inbound_channel().await;
668
669 self.connected.store(false, Ordering::Release);
670 self.stats.disconnections.fetch_add(1, Ordering::Relaxed);
671
672 Ok(())
673 }
674
675 #[cfg_attr(feature = "opentelemetry", tracing::instrument(skip_all))]
676 async fn send_register_request(
677 &self,
678 request: RegisterRequest,
679 ) -> NetworkResult<RegisterResponse> {
680 let flow = signaling_envelope::Flow::PeerToServer(PeerToSignaling {
682 payload: Some(peer_to_signaling::Payload::RegisterRequest(request)),
683 });
684
685 let envelope = self.create_envelope(flow).await;
686 let response_envelope = self.send_envelope_and_wait_response(envelope).await?;
687
688 if let Some(signaling_envelope::Flow::ServerToActr(server_to_actr)) = response_envelope.flow
689 {
690 if let Some(signaling_to_actr::Payload::RegisterResponse(response)) =
691 server_to_actr.payload
692 {
693 return Ok(response);
694 }
695 }
696
697 Err(NetworkError::ConnectionError(
698 "Invalid registration response".to_string(),
699 ))
700 }
701
702 #[cfg_attr(
703 feature = "opentelemetry",
704 tracing::instrument(skip_all, fields(actor_id = %actor_id.to_string_repr()))
705 )]
706 async fn send_unregister_request(
707 &self,
708 actor_id: ActrId,
709 credential: AIdCredential,
710 reason: Option<String>,
711 ) -> NetworkResult<UnregisterResponse> {
712 let request = UnregisterRequest {
714 actr_id: actor_id.clone(),
715 reason,
716 };
717
718 let flow = signaling_envelope::Flow::ActrToServer(ActrToSignaling {
720 source: actor_id,
721 credential,
722 payload: Some(actr_to_signaling::Payload::UnregisterRequest(request)),
723 });
724
725 let envelope = self.create_envelope(flow).await;
727 self.send_envelope(envelope).await?;
728
729 Ok(UnregisterResponse {
734 result: Some(actr_protocol::unregister_response::Result::Success(
735 actr_protocol::unregister_response::UnregisterOk {},
736 )),
737 })
738 }
739
740 #[cfg_attr(
741 feature = "opentelemetry",
742 tracing::instrument(level = "debug", skip_all, fields(actor_id = %actor_id.to_string_repr()))
743 )]
744 async fn send_heartbeat(
745 &self,
746 actor_id: ActrId,
747 credential: AIdCredential,
748 availability: ServiceAvailabilityState,
749 power_reserve: f32,
750 mailbox_backlog: f32,
751 ) -> NetworkResult<Pong> {
752 let ping = Ping {
753 availability: availability as i32,
754 power_reserve,
755 mailbox_backlog,
756 sticky_client_ids: vec![], };
758
759 let flow = signaling_envelope::Flow::ActrToServer(ActrToSignaling {
760 source: actor_id,
761 credential,
762 payload: Some(actr_to_signaling::Payload::Ping(ping)),
763 });
764
765 let envelope = self.create_envelope(flow).await;
766 let reply_for = envelope.envelope_id.clone();
767
768 let (tx, rx) = oneshot::channel();
770 self.pending_replies
771 .lock()
772 .await
773 .insert(reply_for.clone(), tx);
774
775 if let Err(e) = self.send_envelope(envelope).await {
776 self.pending_replies.lock().await.remove(&reply_for);
778 return Err(e);
779 }
780
781 let response_envelope = rx.await.map_err(|_| {
783 NetworkError::ConnectionError(
784 "Receiver dropped while waiting for heartbeat response".to_string(),
785 )
786 })?;
787
788 if let Some(signaling_envelope::Flow::ServerToActr(server_to_actr)) = response_envelope.flow
790 {
791 if let Some(signaling_to_actr::Payload::Pong(pong)) = server_to_actr.payload {
792 return Ok(pong);
793 }
794 }
795
796 Err(NetworkError::ConnectionError(
797 "Received response but not a Pong message".to_string(),
798 ))
799 }
800
801 #[cfg_attr(feature = "opentelemetry", tracing::instrument(skip_all))]
802 async fn send_route_candidates_request(
803 &self,
804 actor_id: ActrId,
805 credential: AIdCredential,
806 request: RouteCandidatesRequest,
807 ) -> NetworkResult<RouteCandidatesResponse> {
808 let flow = signaling_envelope::Flow::ActrToServer(ActrToSignaling {
809 source: actor_id,
810 credential,
811 payload: Some(actr_to_signaling::Payload::RouteCandidatesRequest(request)),
812 });
813
814 let envelope = self.create_envelope(flow).await;
815 let response_envelope = self.send_envelope_and_wait_response(envelope).await?;
816
817 if let Some(signaling_envelope::Flow::ServerToActr(server_to_actr)) = response_envelope.flow
818 {
819 match server_to_actr.payload {
820 Some(signaling_to_actr::Payload::RouteCandidatesResponse(response)) => {
821 return Ok(response);
822 }
823 Some(signaling_to_actr::Payload::Error(err)) => {
824 return Err(NetworkError::ServiceDiscoveryError(format!(
825 "{} ({})",
826 err.message, err.code
827 )));
828 }
829 _ => {}
830 }
831 }
832
833 Err(NetworkError::ConnectionError(
834 "Invalid route candidates response".to_string(),
835 ))
836 }
837
838 #[cfg_attr(
839 feature = "opentelemetry",
840 tracing::instrument(level = "debug", skip_all, fields(actor_id = %actor_id.to_string_repr()))
841 )]
842 async fn send_credential_update_request(
843 &self,
844 actor_id: ActrId,
845 credential: AIdCredential,
846 ) -> NetworkResult<RegisterResponse> {
847 let request = CredentialUpdateRequest {
848 actr_id: actor_id.clone(),
849 };
850
851 let flow = signaling_envelope::Flow::ActrToServer(ActrToSignaling {
852 source: actor_id,
853 credential,
854 payload: Some(actr_to_signaling::Payload::CredentialUpdateRequest(request)),
855 });
856
857 let envelope = self.create_envelope(flow).await;
858 let response_envelope = self.send_envelope_and_wait_response(envelope).await?;
859
860 if let Some(signaling_envelope::Flow::ServerToActr(server_to_actr)) = response_envelope.flow
861 {
862 match server_to_actr.payload {
863 Some(signaling_to_actr::Payload::RegisterResponse(response)) => {
864 return Ok(response);
865 }
866 Some(signaling_to_actr::Payload::Error(err)) => {
867 return Err(NetworkError::ConnectionError(format!(
868 "Credential update failed: {} ({})",
869 err.message, err.code
870 )));
871 }
872 _ => {}
873 }
874 }
875
876 Err(NetworkError::ConnectionError(
877 "Invalid credential update response".to_string(),
878 ))
879 }
880
881 #[cfg_attr(
882 feature = "opentelemetry",
883 tracing::instrument(level = "debug", skip_all, fields(envelope_id = %envelope.envelope_id))
884 )]
885 async fn send_envelope(&self, envelope: SignalingEnvelope) -> NetworkResult<()> {
886 #[cfg(feature = "opentelemetry")]
887 let envelope = {
888 let mut envelope = envelope;
889 trace::inject_span_context(&tracing::Span::current(), &mut envelope);
890 envelope
891 };
892
893 let mut sink_guard = self.ws_sink.lock().await;
894
895 if let Some(sink) = sink_guard.as_mut() {
896 let mut buf = Vec::new();
898 envelope.encode(&mut buf)?;
899 let msg = tokio_tungstenite::tungstenite::Message::Binary(buf.into());
900 sink.send(msg).await?;
901
902 self.stats.messages_sent.fetch_add(1, Ordering::Relaxed);
903 tracing::debug!("Stats: {:?}", self.stats.snapshot());
904 Ok(())
905 } else {
906 Err(NetworkError::ConnectionError("Not connected".to_string()))
907 }
908 }
909
910 async fn receive_envelope(&self) -> NetworkResult<Option<SignalingEnvelope>> {
911 let mut rx = self.inbound_rx.lock().await;
912 match rx.recv().await {
913 Some(envelope) => Ok(Some(envelope)),
914 None => {
915 tracing::error!("Inbound channel closed");
916 Err(NetworkError::ConnectionError(
917 "Inbound channel closed".to_string(),
918 ))
919 }
920 }
921 }
922
923 fn is_connected(&self) -> bool {
924 self.connected.load(Ordering::Acquire)
925 }
926
927 fn get_stats(&self) -> SignalingStats {
928 self.stats.snapshot()
929 }
930
931 fn subscribe_state(&self) -> watch::Receiver<ConnectionState> {
932 self.state_tx.subscribe()
933 }
934
935 async fn set_actor_id(&self, actor_id: ActrId) {
936 *self.actor_id.lock().await = Some(actor_id);
937 }
938
939 async fn set_credential_state(&self, credential_state: CredentialState) {
940 *self.credential_state.lock().await = Some(credential_state);
941 }
942}
943
944pub fn spawn_signaling_reconnector(client: Arc<dyn SignalingClient>, shutdown: CancellationToken) {
947 let mut state_rx = client.subscribe_state();
948
949 tokio::spawn(async move {
950 loop {
951 tokio::select! {
952 _ = shutdown.cancelled() => {
953 tracing::info!("🛑 Stopping signaling reconnect helper due to shutdown");
954 break;
955 }
956 changed = state_rx.changed() => {
957 if changed.is_err() {
958 tracing::info!("Signaling state channel closed, stopping reconnect helper");
959 break;
960 }
961
962 if *state_rx.borrow() == ConnectionState::Disconnected {
963 if shutdown.is_cancelled() {
965 tracing::info!(
966 "Shutdown already requested when disconnect event observed; skipping reconnect"
967 );
968 break;
969 }
970
971 tracing::warn!("📡 Signaling state is DISCONNECTED, attempting reconnect");
972 if let Err(e) = client.connect().await {
973 tracing::error!("❌ Signaling reconnect failed: {e}");
974 } else {
975 tracing::info!("✅ Signaling reconnect succeeded");
976 }
977
978 }
979 }
980 }
981 }
982 });
983}
984
/// Lock-free counters backing [`SignalingStats`]; every counter starts at zero.
///
/// `AtomicU64::default()` is zero, so deriving `Default` yields the all-zero state.
#[derive(Debug, Default)]
pub(crate) struct AtomicSignalingStats {
    /// Successful connection establishments.
    pub connections: AtomicU64,

    /// Observed connection losses / teardowns.
    pub disconnections: AtomicU64,

    /// Envelopes written to the socket.
    pub messages_sent: AtomicU64,

    /// Envelopes decoded from the socket.
    pub messages_received: AtomicU64,

    /// Heartbeat pings sent.
    pub heartbeats_sent: AtomicU64,

    /// Heartbeat pongs received.
    pub heartbeats_received: AtomicU64,

    /// Errors observed (decode, delivery or receive failures).
    pub errors: AtomicU64,
}
1025
/// Plain-value snapshot of the signaling client's counters, as returned by
/// [`SignalingClient::get_stats`].
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)]
pub struct SignalingStats {
    /// Successful connection establishments.
    pub connections: u64,

    /// Observed connection losses / teardowns.
    pub disconnections: u64,

    /// Envelopes written to the socket.
    pub messages_sent: u64,

    /// Envelopes decoded from the socket.
    pub messages_received: u64,

    /// Heartbeat pings sent.
    pub heartbeats_sent: u64,

    /// Heartbeat pongs received.
    pub heartbeats_received: u64,

    /// Errors observed by the client.
    pub errors: u64,
}
1050
1051impl AtomicSignalingStats {
1052 pub fn snapshot(&self) -> SignalingStats {
1054 SignalingStats {
1055 connections: self.connections.load(Ordering::Relaxed),
1056 disconnections: self.disconnections.load(Ordering::Relaxed),
1057 messages_sent: self.messages_sent.load(Ordering::Relaxed),
1058 messages_received: self.messages_received.load(Ordering::Relaxed),
1059 heartbeats_sent: self.heartbeats_sent.load(Ordering::Relaxed),
1060 heartbeats_received: self.heartbeats_received.load(Ordering::Relaxed),
1061 errors: self.errors.load(Ordering::Relaxed),
1062 }
1063 }
1064}
1065
/// Whole seconds since the Unix epoch according to the system clock, or `0` if the
/// clock reports a time before the epoch.
fn current_unix_secs() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
1073
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering as UsizeOrdering};
    use tokio_util::sync::CancellationToken;

    /// Minimal `SignalingClient` that only counts `connect` calls and exposes a
    /// state watch, for exercising `spawn_signaling_reconnector`.
    struct FakeSignalingClient {
        state_tx: watch::Sender<ConnectionState>,
        connect_calls: Arc<AtomicUsize>,
        actor_id: tokio::sync::Mutex<Option<ActrId>>,
        credential_state: tokio::sync::Mutex<Option<CredentialState>>,
    }

    #[async_trait]
    impl SignalingClient for FakeSignalingClient {
        // Always succeeds; records that a (re)connect was attempted.
        async fn connect(&self) -> NetworkResult<()> {
            self.connect_calls.fetch_add(1, UsizeOrdering::SeqCst);
            Ok(())
        }

        async fn disconnect(&self) -> NetworkResult<()> {
            Ok(())
        }

        async fn send_register_request(
            &self,
            _request: RegisterRequest,
        ) -> NetworkResult<RegisterResponse> {
            unimplemented!("not needed in tests");
        }

        async fn send_unregister_request(
            &self,
            _actor_id: ActrId,
            _credential: AIdCredential,
            _reason: Option<String>,
        ) -> NetworkResult<UnregisterResponse> {
            unimplemented!("not needed in tests");
        }

        async fn send_heartbeat(
            &self,
            _actor_id: ActrId,
            _credential: AIdCredential,
            _availability: ServiceAvailabilityState,
            _power_reserve: f32,
            _mailbox_backlog: f32,
        ) -> NetworkResult<Pong> {
            unimplemented!("not needed in tests");
        }

        async fn send_route_candidates_request(
            &self,
            _actor_id: ActrId,
            _credential: AIdCredential,
            _request: RouteCandidatesRequest,
        ) -> NetworkResult<RouteCandidatesResponse> {
            unimplemented!("not needed in tests");
        }

        async fn send_credential_update_request(
            &self,
            _actor_id: ActrId,
            _credential: AIdCredential,
        ) -> NetworkResult<RegisterResponse> {
            unimplemented!("not needed in tests");
        }

        async fn send_envelope(&self, _envelope: SignalingEnvelope) -> NetworkResult<()> {
            unimplemented!("not needed in tests");
        }

        async fn receive_envelope(&self) -> NetworkResult<Option<SignalingEnvelope>> {
            unimplemented!("not needed in tests");
        }

        // Connectivity is derived from the current watch value.
        fn is_connected(&self) -> bool {
            *self.state_tx.borrow() == ConnectionState::Connected
        }

        fn get_stats(&self) -> SignalingStats {
            SignalingStats::default()
        }

        fn subscribe_state(&self) -> watch::Receiver<ConnectionState> {
            self.state_tx.subscribe()
        }

        async fn set_actor_id(&self, actor_id: ActrId) {
            *self.actor_id.lock().await = Some(actor_id);
        }

        async fn set_credential_state(&self, credential_state: CredentialState) {
            *self.credential_state.lock().await = Some(credential_state);
        }
    }

    /// Builds a fake client starting in `Disconnected`, plus a sender clone the test
    /// uses to drive state transitions.
    fn make_fake_client() -> (Arc<FakeSignalingClient>, watch::Sender<ConnectionState>) {
        let (state_tx, _rx) = watch::channel(ConnectionState::Disconnected);
        let client = Arc::new(FakeSignalingClient {
            state_tx: state_tx.clone(),
            connect_calls: Arc::new(AtomicUsize::new(0)),
            actor_id: tokio::sync::Mutex::new(None),
            credential_state: tokio::sync::Mutex::new(None),
        });
        (client, state_tx)
    }

    // A transition to `Connected` must not trigger a reconnect.
    #[tokio::test]
    async fn test_spawn_signaling_reconnector_does_not_trigger_on_connected() {
        let (fake_client, state_tx) = make_fake_client();
        let shutdown = CancellationToken::new();

        let client_trait: Arc<dyn SignalingClient> = fake_client.clone();
        spawn_signaling_reconnector(client_trait, shutdown.clone());

        let _ = state_tx.send(ConnectionState::Connected);

        // Give the spawned helper a chance to observe the change.
        tokio::time::sleep(std::time::Duration::from_millis(50)).await;

        let calls = fake_client.connect_calls.load(UsizeOrdering::SeqCst);
        assert_eq!(
            calls, 0,
            "connect() should not be called on Connected state"
        );

        shutdown.cancel();
    }

    // A `Disconnected` notification must trigger at least one `connect()`.
    #[tokio::test]
    async fn test_spawn_signaling_reconnector_triggers_connect_on_disconnect() {
        let (fake_client, state_tx) = make_fake_client();
        let shutdown = CancellationToken::new();

        let client_trait: Arc<dyn SignalingClient> = fake_client.clone();
        spawn_signaling_reconnector(client_trait, shutdown.clone());

        // Re-sending the same value still notifies watch receivers.
        let _ = state_tx.send(ConnectionState::Disconnected);

        tokio::time::sleep(std::time::Duration::from_millis(50)).await;

        let calls = fake_client.connect_calls.load(UsizeOrdering::SeqCst);
        assert!(
            calls >= 1,
            "expected at least one reconnect attempt, got {}",
            calls
        );

        shutdown.cancel();
    }

    // Once shutdown is cancelled, a later disconnect must not cause a reconnect.
    #[tokio::test]
    async fn test_spawn_signaling_reconnector_stops_on_shutdown_before_disconnect() {
        let (fake_client, state_tx) = make_fake_client();
        let shutdown = CancellationToken::new();

        let client_trait: Arc<dyn SignalingClient> = fake_client.clone();
        spawn_signaling_reconnector(client_trait, shutdown.clone());

        shutdown.cancel();

        let _ = state_tx.send(ConnectionState::Disconnected);

        tokio::time::sleep(std::time::Duration::from_millis(50)).await;

        let calls = fake_client.connect_calls.load(UsizeOrdering::SeqCst);
        assert_eq!(calls, 0, "reconnect helper should not run after shutdown");
    }

    // A freshly constructed client reports `Disconnected` without any I/O.
    #[test]
    fn test_websocket_signaling_client_initial_state_disconnected() {
        let config = SignalingConfig {
            server_url: Url::parse("ws://example.com/signaling/ws").unwrap(),
            connection_timeout: 30,
            heartbeat_interval: 30,
            reconnect_config: ReconnectConfig::default(),
            auth_config: None,
        };

        let client = WebSocketSignalingClient::new(config);
        let state_rx = client.subscribe_state();
        assert_eq!(*state_rx.borrow(), ConnectionState::Disconnected);
    }
}