1#[cfg(feature = "opentelemetry")]
6use super::trace;
7use crate::lifecycle::CredentialState;
8use crate::transport::error::{NetworkError, NetworkResult};
9#[cfg(feature = "opentelemetry")]
10use crate::wire::webrtc::trace::extract_trace_context;
11#[cfg(feature = "opentelemetry")]
12use actr_protocol::ActrIdExt;
13use actr_protocol::prost::Message as ProstMessage;
14use actr_protocol::{
15 AIdCredential, ActrId, ActrToSignaling, CredentialUpdateRequest, PeerToSignaling, Ping, Pong,
16 RegisterRequest, RegisterResponse, RouteCandidatesRequest, RouteCandidatesResponse,
17 ServiceAvailabilityState, SignalingEnvelope, UnregisterRequest, UnregisterResponse,
18 actr_to_signaling, peer_to_signaling, signaling_envelope, signaling_to_actr,
19};
20use async_trait::async_trait;
21use base64::Engine as _;
22use futures_util::{SinkExt, StreamExt};
23use serde::{Deserialize, Serialize};
24use std::collections::HashMap;
25use std::sync::{
26 Arc,
27 atomic::{AtomicBool, AtomicU64, Ordering},
28};
29use tokio::net::TcpStream;
30use tokio::sync::{mpsc, oneshot, watch};
31use tokio_tungstenite::tungstenite::protocol::WebSocketConfig;
32use tokio_tungstenite::{MaybeTlsStream, WebSocketStream, connect_async_with_config};
33#[cfg(feature = "opentelemetry")]
34use tracing_opentelemetry::OpenTelemetrySpanExt;
35use url::Url;
36
/// Maximum time, in seconds, that a request waits for its reply envelope
/// before failing with a timeout error.
const RESPONSE_TIMEOUT_SECS: u64 = 15;
/// Delay, in seconds, between two WebSocket ping frames sent by the ping task.
const PING_INTERVAL_SECS: u64 = 5;
/// Maximum age, in seconds, of the last recorded inbound traffic before the
/// ping task marks the connection as disconnected.
const PONG_TIMEOUT_SECS: u64 = 10;
46
/// Configuration for a signaling client.
#[derive(Debug, Clone)]
pub struct SignalingConfig {
    /// Base URL of the signaling server; identity query parameters are
    /// appended to a copy of it when an actor id and credential are set.
    pub server_url: Url,

    /// Connection timeout in seconds. A value of `0` disables the timeout.
    pub connection_timeout: u64,

    /// Heartbeat interval.
    /// NOTE(review): not read by the code visible in this file; presumably
    /// seconds and consumed by the caller — confirm.
    pub heartbeat_interval: u64,

    /// Retry/backoff policy used when establishing the connection.
    pub reconnect_config: ReconnectConfig,

    /// Optional authentication settings.
    /// NOTE(review): not read by the code visible in this file — confirm usage.
    pub auth_config: Option<AuthConfig>,
}
69
/// Retry and backoff policy applied while connecting to the signaling server.
#[derive(Debug, Clone)]
pub struct ReconnectConfig {
    /// Whether retries (and the auto-reconnector) are enabled at all.
    pub enabled: bool,

    /// Number of connection attempts after which connecting gives up.
    pub max_attempts: u32,

    /// Delay before the first retry, in seconds.
    pub initial_delay: u64,

    /// Upper bound for the delay between two attempts, in seconds.
    pub max_delay: u64,

    /// Factor by which the delay grows after every failed attempt.
    pub backoff_multiplier: f64,
}

impl Default for ReconnectConfig {
    /// Retries enabled: up to 10 attempts, starting at 1 s, doubling each
    /// time, capped at 60 s.
    fn default() -> Self {
        const MAX_ATTEMPTS: u32 = 10;
        const INITIAL_DELAY_SECS: u64 = 1;
        const MAX_DELAY_SECS: u64 = 60;
        const BACKOFF_MULTIPLIER: f64 = 2.0;

        ReconnectConfig {
            enabled: true,
            max_attempts: MAX_ATTEMPTS,
            initial_delay: INITIAL_DELAY_SECS,
            max_delay: MAX_DELAY_SECS,
            backoff_multiplier: BACKOFF_MULTIPLIER,
        }
    }
}
100
/// Authentication settings attached to a [`SignalingConfig`].
///
/// NOTE(review): nothing in the visible code reads these values — confirm
/// where they are consumed.
#[derive(Debug, Clone)]
pub struct AuthConfig {
    /// Which authentication scheme to use.
    pub auth_type: AuthType,

    /// Scheme-specific key/value credentials.
    /// NOTE(review): the expected keys are not defined in this file.
    pub credentials: HashMap<String, String>,
}
110
/// Authentication scheme selector used by [`AuthConfig`].
#[derive(Debug, Clone)]
pub enum AuthType {
    /// No authentication.
    None,
    /// Bearer-token authentication.
    BearerToken,
    /// API-key authentication.
    ApiKey,
    /// JWT authentication.
    Jwt,
}
123
/// Abstraction over a client that talks to the signaling server.
///
/// The descriptions below reflect the WebSocket implementation in this file;
/// other implementations may differ in detail.
#[async_trait]
pub trait SignalingClient: Send + Sync {
    /// Establishes the connection to the signaling server.
    async fn connect(&self) -> NetworkResult<()>;

    /// Tears down the connection and releases its resources.
    async fn disconnect(&self) -> NetworkResult<()>;

    /// Sends a registration request and waits for the matching response.
    async fn send_register_request(
        &self,
        request: RegisterRequest,
    ) -> NetworkResult<RegisterResponse>;

    /// Sends an unregister request for `actor_id`, with an optional
    /// human-readable `reason`.
    async fn send_unregister_request(
        &self,
        actor_id: ActrId,
        credential: AIdCredential,
        reason: Option<String>,
    ) -> NetworkResult<UnregisterResponse>;

    /// Sends an application-level heartbeat (`Ping`) carrying the actor's
    /// availability, power reserve and mailbox backlog, and returns the
    /// server's `Pong`.
    async fn send_heartbeat(
        &self,
        actor_id: ActrId,
        credential: AIdCredential,
        availability: ServiceAvailabilityState,
        power_reserve: f32,
        mailbox_backlog: f32,
    ) -> NetworkResult<Pong>;

    /// Asks the server for route candidates and returns its response.
    async fn send_route_candidates_request(
        &self,
        actor_id: ActrId,
        credential: AIdCredential,
        request: RouteCandidatesRequest,
    ) -> NetworkResult<RouteCandidatesResponse>;

    /// Requests a credential update for `actor_id`; the server answers with a
    /// `RegisterResponse`.
    async fn send_credential_update_request(
        &self,
        actor_id: ActrId,
        credential: AIdCredential,
    ) -> NetworkResult<RegisterResponse>;

    /// Sends a raw envelope without waiting for any reply.
    async fn send_envelope(&self, envelope: SignalingEnvelope) -> NetworkResult<()>;

    /// Receives the next inbound envelope that was not consumed as a reply to
    /// a pending request.
    async fn receive_envelope(&self) -> NetworkResult<Option<SignalingEnvelope>>;

    /// Returns whether the client currently considers itself connected.
    fn is_connected(&self) -> bool;

    /// Returns a point-in-time copy of the client's counters.
    fn get_stats(&self) -> SignalingStats;
    /// Returns a watch receiver that observes connection state changes.
    fn subscribe_state(&self) -> watch::Receiver<ConnectionState>;

    /// Stores the actor identity used by subsequent connection attempts.
    async fn set_actor_id(&self, actor_id: ActrId);
    /// Stores the credential state used by subsequent connection attempts.
    async fn set_credential_state(&self, credential_state: CredentialState);
}
205
/// Connection state published through the client's watch channel.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConnectionState {
    /// No usable connection to the signaling server.
    Disconnected,
    /// A WebSocket connection has been established.
    Connected,
}
212
/// WebSocket-based implementation of [`SignalingClient`].
pub struct WebSocketSignalingClient {
    /// Static configuration supplied at construction time.
    config: SignalingConfig,
    /// Actor identity appended to the connection URL once set.
    actor_id: tokio::sync::Mutex<Option<ActrId>>,
    /// Credential source used to build the `token` URL parameters once set.
    credential_state: tokio::sync::Mutex<Option<CredentialState>>,
    /// Write half of the WebSocket; shared with the ping task.
    ws_sink: Arc<
        tokio::sync::Mutex<
            Option<
                futures_util::stream::SplitSink<
                    WebSocketStream<MaybeTlsStream<TcpStream>>,
                    tokio_tungstenite::tungstenite::Message,
                >,
            >,
        >,
    >,
    /// Read half of the WebSocket; moved into the receiver task when it starts.
    ws_stream: tokio::sync::Mutex<
        Option<futures_util::stream::SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>>,
    >,
    /// `true` while the client considers the connection usable.
    connected: Arc<AtomicBool>,
    /// `true` while one task is running a connection attempt; other callers wait.
    connecting: Arc<AtomicBool>,
    /// Atomic counters exposed through `get_stats()`.
    stats: Arc<AtomicSignalingStats>,
    /// Counter used to generate `env-N` envelope ids.
    envelope_counter: tokio::sync::Mutex<u64>,
    /// Waiters keyed by request envelope id; fulfilled when a reply with a
    /// matching `reply_for` arrives.
    pending_replies: Arc<tokio::sync::Mutex<HashMap<String, oneshot::Sender<SignalingEnvelope>>>>,
    /// Receiving end of the channel carrying envelopes that matched no waiter.
    inbound_rx: Arc<tokio::sync::Mutex<mpsc::UnboundedReceiver<SignalingEnvelope>>>,
    /// Sending end of the inbound channel; cloned into the receiver task.
    inbound_tx: tokio::sync::Mutex<mpsc::UnboundedSender<SignalingEnvelope>>,
    /// Handle of the task reading frames from the WebSocket.
    receiver_task: Arc<tokio::sync::Mutex<Option<tokio::task::JoinHandle<()>>>>,
    /// Handle of the task sending periodic ping frames.
    ping_task: tokio::sync::Mutex<Option<tokio::task::JoinHandle<()>>>,
    /// Publishes connection state changes to subscribers.
    state_tx: watch::Sender<ConnectionState>,
    /// Unix timestamp (seconds) of the most recent inbound binary, ping or
    /// pong frame; also refreshed when a connection is established.
    last_pong: Arc<AtomicU64>,
    /// Ensures the auto-reconnector task is started at most once.
    reconnector_started: Arc<AtomicBool>,
}
257
258impl WebSocketSignalingClient {
259 pub fn new(config: SignalingConfig) -> Self {
261 let (inbound_tx, inbound_rx) = mpsc::unbounded_channel();
262 let (state_tx, _state_rx) = watch::channel(ConnectionState::Disconnected);
263 Self {
264 config,
265 actor_id: tokio::sync::Mutex::new(None),
266 credential_state: tokio::sync::Mutex::new(None),
267 ws_sink: Arc::new(tokio::sync::Mutex::new(None)),
268 ws_stream: tokio::sync::Mutex::new(None),
269 connected: Arc::new(AtomicBool::new(false)),
270 connecting: Arc::new(AtomicBool::new(false)),
271 stats: Arc::new(AtomicSignalingStats::default()),
272 envelope_counter: tokio::sync::Mutex::new(0),
273 pending_replies: Arc::new(tokio::sync::Mutex::new(HashMap::new())),
274 inbound_rx: Arc::new(tokio::sync::Mutex::new(inbound_rx)),
275 inbound_tx: tokio::sync::Mutex::new(inbound_tx),
276 receiver_task: Arc::new(tokio::sync::Mutex::new(None)),
277 ping_task: tokio::sync::Mutex::new(None),
278 state_tx,
279 last_pong: Arc::new(AtomicU64::new(0)),
280 reconnector_started: Arc::new(AtomicBool::new(false)),
281 }
282 }
283
284 pub fn start_auto_reconnector(self: &Arc<Self>) {
289 if self.config.reconnect_config.enabled
291 && self
292 .reconnector_started
293 .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
294 .is_ok()
295 {
296 tracing::info!("🔄 Starting auto-reconnector for signaling client");
297
298 let self_clone = self.clone();
299 let mut state_rx = self.subscribe_state();
300
301 tokio::spawn(async move {
302 loop {
303 match state_rx.changed().await {
304 Err(_) => {
305 tracing::info!("� Signaling client dropped, stopping reconnect helper");
307 break;
308 }
309 Ok(_) => {
310 if *state_rx.borrow() == ConnectionState::Disconnected {
311 tracing::debug!(
313 "🧹 Cleaning up old WebSocket resources before reconnect"
314 );
315 if let Err(e) = self_clone.disconnect().await {
316 tracing::warn!("⚠️ Disconnect cleanup failed (non-fatal): {e}");
317 }
318
319 tracing::warn!(
320 "📡 Signaling state is DISCONNECTED, attempting reconnect"
321 );
322 if let Err(e) = self_clone.connect().await {
323 tracing::error!("❌ Signaling reconnect failed: {e}");
324 } else {
325 tracing::info!("✅ Signaling reconnect succeeded");
326 }
327 }
328 }
329 }
330 }
331 });
332 }
333 }
334
335 pub async fn connect_to(url: &str) -> NetworkResult<Arc<Self>> {
337 let config = SignalingConfig {
338 server_url: url.parse()?,
339 connection_timeout: 5,
340 heartbeat_interval: 30,
341 reconnect_config: ReconnectConfig::default(),
342 auth_config: None,
343 };
344
345 let client = Arc::new(Self::new(config));
346 client.start_auto_reconnector();
347 client.connect().await?;
348 Ok(client)
349 }
350
351 async fn next_envelope_id(&self) -> String {
353 let mut counter = self.envelope_counter.lock().await;
354 *counter += 1;
355 format!("env-{}", *counter)
356 }
357
358 async fn create_envelope(&self, flow: signaling_envelope::Flow) -> SignalingEnvelope {
360 SignalingEnvelope {
361 envelope_version: 1,
362 envelope_id: self.next_envelope_id().await,
363 reply_for: None,
364 timestamp: prost_types::Timestamp {
365 seconds: chrono::Utc::now().timestamp(),
366 nanos: 0,
367 },
368 traceparent: None,
369 tracestate: None,
370 flow: Some(flow),
371 }
372 }
373
374 async fn reset_inbound_channel(&self) {
376 let (tx, rx) = mpsc::unbounded_channel();
377 *self.inbound_tx.lock().await = tx;
378 *self.inbound_rx.lock().await = rx;
379 }
380
381 async fn build_url_with_identity(&self) -> Url {
383 let mut url = self.config.server_url.clone();
384 let actor_id_opt = self.actor_id.lock().await.clone();
385 let credential_state_opt = self.credential_state.lock().await.clone();
386 if let (Some(actor_id), Some(credential_state)) = (actor_id_opt, credential_state_opt) {
387 let credential = credential_state.credential().await;
388 let actor_str = actr_protocol::ActrIdExt::to_string_repr(&actor_id);
389 let token_b64 =
390 base64::engine::general_purpose::STANDARD.encode(&credential.encrypted_token);
391 {
392 let mut pairs = url.query_pairs_mut();
393 pairs.append_pair("actor_id", &actor_str);
394 pairs.append_pair("token", &token_b64);
395 pairs.append_pair("token_key_id", &credential.token_key_id.to_string());
396 }
397 }
398 url
399 }
400
401 async fn establish_connection_once(&self) -> NetworkResult<()> {
405 let url = self.build_url_with_identity().await;
406 let timeout_secs = self.config.connection_timeout;
407 tracing::debug!("Establishing connection to URL: {}", url.as_str());
408 let config = WebSocketConfig::default().write_buffer_size(0);
410 let connect_result = if timeout_secs == 0 {
412 connect_async_with_config(url.as_str(), Some(config), false).await
413 } else {
414 let timeout_duration = std::time::Duration::from_secs(timeout_secs);
415 tokio::time::timeout(
416 timeout_duration,
417 connect_async_with_config(url.as_str(), Some(config), false),
418 )
419 .await
420 .map_err(|_| {
421 NetworkError::ConnectionError(format!(
422 "Signaling connect timeout after {}s",
423 timeout_secs
424 ))
425 })?
426 }?;
427
428 let (ws_stream, _) = connect_result;
429
430 let (sink, stream) = ws_stream.split();
432
433 *self.ws_sink.lock().await = Some(sink);
434 *self.ws_stream.lock().await = Some(stream);
435 self.connected.store(true, Ordering::Release);
436 self.last_pong.store(current_unix_secs(), Ordering::Release);
437 let _ = self.state_tx.send(ConnectionState::Connected);
439
440 self.stats.connections.fetch_add(1, Ordering::Relaxed);
441
442 Ok(())
443 }
444
445 async fn connect_with_retries(&self) -> NetworkResult<()> {
447 let cfg = &self.config.reconnect_config;
448
449 if !cfg.enabled {
451 return self.establish_connection_once().await;
452 }
453
454 let mut attempt: u32 = 0;
455 let mut delay_secs = cfg.initial_delay.max(1);
456
457 loop {
458 attempt += 1;
459
460 match self.establish_connection_once().await {
461 Ok(()) => {
462 return Ok(());
463 }
464 Err(e) => {
465 tracing::warn!("Signaling connect attempt {} failed: {e:?}", attempt);
466
467 if attempt >= cfg.max_attempts {
468 tracing::error!(
469 "Signaling connect failed after {} attempts, giving up",
470 attempt
471 );
472 return Err(e);
473 }
474
475 let sleep_secs = delay_secs.min(cfg.max_delay.max(1));
476 tracing::info!("Retry signaling connect after {}s", sleep_secs);
477 tokio::time::sleep(std::time::Duration::from_secs(sleep_secs)).await;
478
479 delay_secs = ((delay_secs as f64) * cfg.backoff_multiplier)
481 .round()
482 .max(1.0) as u64;
483 }
484 }
485 }
486 }
487
488 #[cfg_attr(
490 feature = "opentelemetry",
491 tracing::instrument(skip_all, fields(envelope_id = %envelope.envelope_id))
492 )]
493 async fn send_envelope_and_wait_response(
494 &self,
495 envelope: SignalingEnvelope,
496 ) -> NetworkResult<SignalingEnvelope> {
497 let reply_for = envelope.envelope_id.clone();
498
499 let (tx, rx) = oneshot::channel();
501 self.pending_replies
502 .lock()
503 .await
504 .insert(reply_for.clone(), tx);
505
506 if let Err(e) = self.send_envelope(envelope).await {
507 self.pending_replies.lock().await.remove(&reply_for);
509 return Err(e);
510 }
511
512 let result =
513 tokio::time::timeout(std::time::Duration::from_secs(RESPONSE_TIMEOUT_SECS), rx).await;
514 if result.is_err() {
516 self.pending_replies.lock().await.remove(&reply_for);
517 }
518
519 let response_envelope = result
520 .map_err(|_| {
521 NetworkError::ConnectionError(
522 "Timed out waiting for signaling response".to_string(),
523 )
524 })?
525 .map_err(|_| {
526 NetworkError::ConnectionError(
527 "Receiver dropped while waiting for signaling response".to_string(),
528 )
529 })?;
530
531 Ok(response_envelope)
532 }
533
534 async fn start_receiver(&self) {
536 let mut stream_guard = self.ws_stream.lock().await;
537 if stream_guard.is_none() {
538 return;
539 }
540
541 let mut stream = stream_guard.take().expect("stream exists");
542 let pending = self.pending_replies.clone();
543 let inbound_tx = { self.inbound_tx.lock().await.clone() };
544 let stats = self.stats.clone();
545 let connected = self.connected.clone();
546 let state_tx = self.state_tx.clone();
547 let last_pong = self.last_pong.clone();
548 let handle = tokio::spawn(async move {
549 while let Some(msg) = stream.next().await {
550 match msg {
551 Ok(tokio_tungstenite::tungstenite::Message::Binary(data)) => {
552 last_pong.store(current_unix_secs(), Ordering::Release);
554 match SignalingEnvelope::decode(&data[..]) {
555 Ok(envelope) => {
556 #[cfg(feature = "opentelemetry")]
557 let span = {
558 let span = tracing::info_span!("signaling.receive_envelope", envelope_id = %envelope.envelope_id);
559 span.set_parent(extract_trace_context(&envelope));
560 span
561 };
562
563 stats.messages_received.fetch_add(1, Ordering::Relaxed);
564 tracing::debug!("Received message: {:?}", envelope);
565 if let Some(reply_for) = envelope.reply_for.clone() {
566 if let Some(sender) = pending.lock().await.remove(&reply_for) {
567 #[cfg(feature = "opentelemetry")]
568 let _ = span.enter();
569 if let Err(e) = sender.send(envelope) {
570 stats.errors.fetch_add(1, Ordering::Relaxed);
571 tracing::warn!(
572 "Failed to send reply envelope to waiter: {e:?}",
573 );
574 }
575 continue;
576 }
577 }
578 tracing::debug!(
579 "Unmatched or push message -> forward to inbound channel"
580 );
581 if let Err(e) = inbound_tx.send(envelope) {
583 stats.errors.fetch_add(1, Ordering::Relaxed);
584 tracing::warn!(
585 "Failed to send envelope to inbound channel: {e:?}"
586 );
587 }
588 }
589 Err(e) => {
590 stats.errors.fetch_add(1, Ordering::Relaxed);
591 tracing::warn!("Failed to decode SignalingEnvelope: {e}");
592 }
593 }
594 }
595 Ok(tokio_tungstenite::tungstenite::Message::Pong(_)) => {
596 tracing::debug!("Received pong");
597 last_pong.store(current_unix_secs(), Ordering::Release);
598 }
599 Ok(tokio_tungstenite::tungstenite::Message::Ping(_)) => {
600 tracing::debug!("Received ping");
601 last_pong.store(current_unix_secs(), Ordering::Release);
602 }
603 Ok(_) => {
604 tracing::warn!("Received non-binary frame, ignoring");
605 }
606 Err(e) => {
607 stats.errors.fetch_add(1, Ordering::Relaxed);
608 tracing::error!("Signaling receive error: {e}");
609 break;
610 }
611 }
612 }
613
614 connected.store(false, Ordering::Release);
616 stats.disconnections.fetch_add(1, Ordering::Relaxed);
617 let _ = state_tx.send(ConnectionState::Disconnected);
618 });
619
620 *self.receiver_task.lock().await = Some(handle);
621 }
622
623 async fn start_ping_task(&self) {
626 let mut existing = self.ping_task.lock().await;
627 if let Some(handle) = existing.as_ref() {
628 if handle.is_finished() {
629 existing.take();
630 } else {
631 return;
632 }
633 }
634
635 let sink = self.ws_sink.clone();
636 let connected = self.connected.clone();
637 let state_tx = self.state_tx.clone();
638 let last_pong = self.last_pong.clone();
639 let receiver_task_clone = Arc::clone(&self.receiver_task);
640
641 let handle = tokio::spawn(async move {
642 loop {
643 tokio::time::sleep(std::time::Duration::from_secs(PING_INTERVAL_SECS)).await;
644
645 if !connected.load(Ordering::Acquire) {
646 break;
647 }
648
649 let mut sink_guard = sink.lock().await;
651 if let Some(sink) = sink_guard.as_mut() {
652 if let Err(e) = sink
653 .send(tokio_tungstenite::tungstenite::Message::Ping(
654 Vec::new().into(),
655 ))
656 .await
657 {
658 tracing::warn!("Signaling ping send failed: {e}");
659 connected.store(false, Ordering::Release);
660 let _ = state_tx.send(ConnectionState::Disconnected);
661 break;
662 }
663 } else {
664 tracing::warn!("Signaling not connected");
665 connected.store(false, Ordering::Release);
666 let _ = state_tx.send(ConnectionState::Disconnected);
667 break;
668 }
669 drop(sink_guard);
670
671 let now = current_unix_secs();
673 let last = last_pong.load(Ordering::Acquire);
674 if now.saturating_sub(last) > PONG_TIMEOUT_SECS {
675 tracing::warn!(
676 "Signaling pong timeout (last seen {}s ago), marking disconnected",
677 now.saturating_sub(last)
678 );
679 connected.store(false, Ordering::Release);
680 let _ = state_tx.send(ConnectionState::Disconnected);
681 if let Some(handle) = receiver_task_clone.lock().await.take() {
682 handle.abort();
683 }
684 break;
685 }
686 }
687 });
688
689 *existing = Some(handle);
690 }
691
692 async fn wait_for_connection_result(&self) -> NetworkResult<()> {
696 let mut state_rx = self.subscribe_state();
697 let timeout = tokio::time::sleep(std::time::Duration::from_secs(30));
698 tokio::pin!(timeout);
699
700 loop {
701 tokio::select! {
702 _ = &mut timeout => {
703 if self.connected.load(Ordering::Acquire) {
705 tracing::debug!("Connection succeeded just before timeout");
706 return Ok(());
707 }
708
709 if !self.connecting.load(Ordering::Acquire) {
711 tracing::warn!("Other connection attempt failed/timed out, will retry");
712 return self.connect().await;
714 }
715
716 return Err(NetworkError::ConnectionError(
717 "Timeout waiting for concurrent connection attempt".to_string(),
718 ));
719 }
720
721 result = state_rx.changed() => {
722 if result.is_err() {
723 return Err(NetworkError::ConnectionError(
724 "State channel closed while waiting for connection".to_string(),
725 ));
726 }
727
728 let state = *state_rx.borrow();
729 match state {
730 ConnectionState::Connected => {
731 tracing::debug!("Connection established by another task");
732 return Ok(());
733 }
734 ConnectionState::Disconnected => {
735 if !self.connecting.load(Ordering::Acquire) {
737 tracing::warn!("Other connection attempt failed, will retry");
738 return self.connect().await;
740 }
741 }
743 }
744 }
745 }
746 }
747 }
748}
749
750#[async_trait]
751impl SignalingClient for WebSocketSignalingClient {
752 async fn connect(&self) -> NetworkResult<()> {
753 if self.connected.load(Ordering::Acquire) {
755 tracing::debug!("Already connected, skipping connect()");
756 return Ok(());
757 }
758
759 if self
761 .connecting
762 .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
763 .is_err()
764 {
765 tracing::debug!("Another connection attempt in progress, waiting for state change...");
767
768 return self.wait_for_connection_result().await;
769 }
770
771 tracing::debug!("Acquired connection lock, establishing connection...");
773
774 let result = self.connect_with_retries().await;
776
777 self.connecting.store(false, Ordering::Release);
779
780 match result {
782 Ok(()) => {
783 self.start_receiver().await;
784 self.start_ping_task().await;
785 Ok(())
786 }
787 Err(e) => {
788 let _ = self.state_tx.send(ConnectionState::Disconnected);
791 tracing::error!("Connection failed: {e}");
792 Err(e)
793 }
794 }
795 }
796
797 async fn disconnect(&self) -> NetworkResult<()> {
798 let mut sink_guard = self.ws_sink.lock().await;
800 let mut stream_guard = self.ws_stream.lock().await;
801
802 if let Some(mut sink) = sink_guard.take() {
804 let _ = sink.close().await;
805 }
806
807 stream_guard.take();
809
810 if let Some(handle) = self.receiver_task.lock().await.take() {
812 handle.abort();
813 }
814 if let Some(handle) = self.ping_task.lock().await.take() {
816 handle.abort();
817 }
818
819 self.reset_inbound_channel().await;
820
821 self.connected.store(false, Ordering::Release);
822 self.stats.disconnections.fetch_add(1, Ordering::Relaxed);
823
824 Ok(())
825 }
826
827 #[cfg_attr(feature = "opentelemetry", tracing::instrument(skip_all))]
828 async fn send_register_request(
829 &self,
830 request: RegisterRequest,
831 ) -> NetworkResult<RegisterResponse> {
832 let flow = signaling_envelope::Flow::PeerToServer(PeerToSignaling {
834 payload: Some(peer_to_signaling::Payload::RegisterRequest(request)),
835 });
836
837 let envelope = self.create_envelope(flow).await;
838 let response_envelope = self.send_envelope_and_wait_response(envelope).await?;
839
840 if let Some(signaling_envelope::Flow::ServerToActr(server_to_actr)) = response_envelope.flow
841 {
842 if let Some(signaling_to_actr::Payload::RegisterResponse(response)) =
843 server_to_actr.payload
844 {
845 return Ok(response);
846 }
847 }
848
849 Err(NetworkError::ConnectionError(
850 "Invalid registration response".to_string(),
851 ))
852 }
853
854 #[cfg_attr(
855 feature = "opentelemetry",
856 tracing::instrument(skip_all, fields(actor_id = %actor_id.to_string_repr()))
857 )]
858 async fn send_unregister_request(
859 &self,
860 actor_id: ActrId,
861 credential: AIdCredential,
862 reason: Option<String>,
863 ) -> NetworkResult<UnregisterResponse> {
864 let request = UnregisterRequest {
866 actr_id: actor_id.clone(),
867 reason,
868 };
869
870 let flow = signaling_envelope::Flow::ActrToServer(ActrToSignaling {
872 source: actor_id,
873 credential,
874 payload: Some(actr_to_signaling::Payload::UnregisterRequest(request)),
875 });
876
877 let envelope = self.create_envelope(flow).await;
879 self.send_envelope(envelope).await?;
880
881 Ok(UnregisterResponse {
886 result: Some(actr_protocol::unregister_response::Result::Success(
887 actr_protocol::unregister_response::UnregisterOk {},
888 )),
889 })
890 }
891
892 #[cfg_attr(
893 feature = "opentelemetry",
894 tracing::instrument(level = "debug", skip_all, fields(actor_id = %actor_id.to_string_repr()))
895 )]
896 async fn send_heartbeat(
897 &self,
898 actor_id: ActrId,
899 credential: AIdCredential,
900 availability: ServiceAvailabilityState,
901 power_reserve: f32,
902 mailbox_backlog: f32,
903 ) -> NetworkResult<Pong> {
904 let ping = Ping {
905 availability: availability as i32,
906 power_reserve,
907 mailbox_backlog,
908 sticky_client_ids: vec![], };
910
911 let flow = signaling_envelope::Flow::ActrToServer(ActrToSignaling {
912 source: actor_id,
913 credential,
914 payload: Some(actr_to_signaling::Payload::Ping(ping)),
915 });
916
917 let envelope = self.create_envelope(flow).await;
918 let reply_for = envelope.envelope_id.clone();
919
920 let (tx, rx) = oneshot::channel();
922 self.pending_replies
923 .lock()
924 .await
925 .insert(reply_for.clone(), tx);
926
927 if let Err(e) = self.send_envelope(envelope).await {
928 self.pending_replies.lock().await.remove(&reply_for);
930 return Err(e);
931 }
932
933 let response_envelope = rx.await.map_err(|_| {
935 NetworkError::ConnectionError(
936 "Receiver dropped while waiting for heartbeat response".to_string(),
937 )
938 })?;
939
940 if let Some(signaling_envelope::Flow::ServerToActr(server_to_actr)) = response_envelope.flow
942 {
943 if let Some(signaling_to_actr::Payload::Pong(pong)) = server_to_actr.payload {
944 return Ok(pong);
945 }
946 }
947
948 Err(NetworkError::ConnectionError(
949 "Received response but not a Pong message".to_string(),
950 ))
951 }
952
953 #[cfg_attr(feature = "opentelemetry", tracing::instrument(skip_all))]
954 async fn send_route_candidates_request(
955 &self,
956 actor_id: ActrId,
957 credential: AIdCredential,
958 request: RouteCandidatesRequest,
959 ) -> NetworkResult<RouteCandidatesResponse> {
960 let flow = signaling_envelope::Flow::ActrToServer(ActrToSignaling {
961 source: actor_id,
962 credential,
963 payload: Some(actr_to_signaling::Payload::RouteCandidatesRequest(request)),
964 });
965
966 let envelope = self.create_envelope(flow).await;
967 let response_envelope = self.send_envelope_and_wait_response(envelope).await?;
968
969 if let Some(signaling_envelope::Flow::ServerToActr(server_to_actr)) = response_envelope.flow
970 {
971 match server_to_actr.payload {
972 Some(signaling_to_actr::Payload::RouteCandidatesResponse(response)) => {
973 return Ok(response);
974 }
975 Some(signaling_to_actr::Payload::Error(err)) => {
976 return Err(NetworkError::ServiceDiscoveryError(format!(
977 "{} ({})",
978 err.message, err.code
979 )));
980 }
981 _ => {}
982 }
983 }
984
985 Err(NetworkError::ConnectionError(
986 "Invalid route candidates response".to_string(),
987 ))
988 }
989
990 #[cfg_attr(
991 feature = "opentelemetry",
992 tracing::instrument(level = "debug", skip_all, fields(actor_id = %actor_id.to_string_repr()))
993 )]
994 async fn send_credential_update_request(
995 &self,
996 actor_id: ActrId,
997 credential: AIdCredential,
998 ) -> NetworkResult<RegisterResponse> {
999 let request = CredentialUpdateRequest {
1000 actr_id: actor_id.clone(),
1001 };
1002
1003 let flow = signaling_envelope::Flow::ActrToServer(ActrToSignaling {
1004 source: actor_id,
1005 credential,
1006 payload: Some(actr_to_signaling::Payload::CredentialUpdateRequest(request)),
1007 });
1008
1009 let envelope = self.create_envelope(flow).await;
1010 let response_envelope = self.send_envelope_and_wait_response(envelope).await?;
1011
1012 if let Some(signaling_envelope::Flow::ServerToActr(server_to_actr)) = response_envelope.flow
1013 {
1014 match server_to_actr.payload {
1015 Some(signaling_to_actr::Payload::RegisterResponse(response)) => {
1016 return Ok(response);
1017 }
1018 Some(signaling_to_actr::Payload::Error(err)) => {
1019 return Err(NetworkError::ConnectionError(format!(
1020 "Credential update failed: {} ({})",
1021 err.message, err.code
1022 )));
1023 }
1024 _ => {}
1025 }
1026 }
1027
1028 Err(NetworkError::ConnectionError(
1029 "Invalid credential update response".to_string(),
1030 ))
1031 }
1032
1033 #[cfg_attr(
1034 feature = "opentelemetry",
1035 tracing::instrument(level = "debug", skip_all, fields(envelope_id = %envelope.envelope_id))
1036 )]
1037 async fn send_envelope(&self, envelope: SignalingEnvelope) -> NetworkResult<()> {
1038 #[cfg(feature = "opentelemetry")]
1039 let envelope = {
1040 let mut envelope = envelope;
1041 trace::inject_span_context(&tracing::Span::current(), &mut envelope);
1042 envelope
1043 };
1044
1045 if !self.is_connected() {
1048 return Err(NetworkError::ConnectionError(
1049 "Cannot send: WebSocket not connected".to_string(),
1050 ));
1051 }
1052
1053 let mut sink_guard = self.ws_sink.lock().await;
1054
1055 if let Some(sink) = sink_guard.as_mut() {
1056 let mut buf = Vec::new();
1058 envelope.encode(&mut buf)?;
1059 let msg = tokio_tungstenite::tungstenite::Message::Binary(buf.into());
1060 sink.send(msg).await?;
1061
1062 self.stats.messages_sent.fetch_add(1, Ordering::Relaxed);
1063 tracing::debug!("Stats: {:?}", self.stats.snapshot());
1064 Ok(())
1065 } else {
1066 Err(NetworkError::ConnectionError("Not connected".to_string()))
1067 }
1068 }
1069
1070 async fn receive_envelope(&self) -> NetworkResult<Option<SignalingEnvelope>> {
1071 let mut rx = self.inbound_rx.lock().await;
1072 match rx.recv().await {
1073 Some(envelope) => Ok(Some(envelope)),
1074 None => {
1075 tracing::error!("Inbound channel closed");
1076 Err(NetworkError::ConnectionError(
1077 "Inbound channel closed".to_string(),
1078 ))
1079 }
1080 }
1081 }
1082
1083 fn is_connected(&self) -> bool {
1084 self.connected.load(Ordering::Acquire)
1085 }
1086
1087 fn get_stats(&self) -> SignalingStats {
1088 self.stats.snapshot()
1089 }
1090
1091 fn subscribe_state(&self) -> watch::Receiver<ConnectionState> {
1092 self.state_tx.subscribe()
1093 }
1094
1095 async fn set_actor_id(&self, actor_id: ActrId) {
1096 *self.actor_id.lock().await = Some(actor_id);
1097 }
1098
1099 async fn set_credential_state(&self, credential_state: CredentialState) {
1100 *self.credential_state.lock().await = Some(credential_state);
1101 }
1102}
1103
/// Lock-free counters updated by the client and its background tasks.
///
/// Every counter starts at zero; `AtomicU64::default()` is `0`, so the derived
/// `Default` matches an explicit all-zero initialiser.
#[derive(Debug, Default)]
pub(crate) struct AtomicSignalingStats {
    /// Number of successfully established connections.
    pub connections: AtomicU64,

    /// Number of recorded disconnections.
    pub disconnections: AtomicU64,

    /// Number of envelopes written to the socket.
    pub messages_sent: AtomicU64,

    /// Number of envelopes decoded from the socket.
    pub messages_received: AtomicU64,

    /// Heartbeats sent.
    /// NOTE(review): not updated by the code visible in this file.
    pub heartbeats_sent: AtomicU64,

    /// Heartbeats received.
    /// NOTE(review): not updated by the code visible in this file.
    pub heartbeats_received: AtomicU64,

    /// Number of receive, decode and dispatch errors.
    pub errors: AtomicU64,
}
1144
/// Plain, serializable snapshot of the client's counters.
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)]
pub struct SignalingStats {
    /// Number of successfully established connections.
    pub connections: u64,

    /// Number of recorded disconnections.
    pub disconnections: u64,

    /// Number of envelopes written to the socket.
    pub messages_sent: u64,

    /// Number of envelopes decoded from the socket.
    pub messages_received: u64,

    /// Heartbeats sent.
    /// NOTE(review): the matching atomic counter is not updated in this file.
    pub heartbeats_sent: u64,

    /// Heartbeats received.
    /// NOTE(review): the matching atomic counter is not updated in this file.
    pub heartbeats_received: u64,

    /// Number of receive, decode and dispatch errors.
    pub errors: u64,
}
1169
1170impl AtomicSignalingStats {
1171 pub fn snapshot(&self) -> SignalingStats {
1173 SignalingStats {
1174 connections: self.connections.load(Ordering::Relaxed),
1175 disconnections: self.disconnections.load(Ordering::Relaxed),
1176 messages_sent: self.messages_sent.load(Ordering::Relaxed),
1177 messages_received: self.messages_received.load(Ordering::Relaxed),
1178 heartbeats_sent: self.heartbeats_sent.load(Ordering::Relaxed),
1179 heartbeats_received: self.heartbeats_received.load(Ordering::Relaxed),
1180 errors: self.errors.load(Ordering::Relaxed),
1181 }
1182 }
1183}
1184
/// Returns the current wall-clock time as whole seconds since the Unix epoch,
/// or `0` when the system clock reports a time before the epoch.
fn current_unix_secs() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_secs(),
        // Clock set before 1970: fall back to zero rather than failing.
        Err(_) => 0,
    }
}
1192
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering as UsizeOrdering};
    use tokio_util::sync::CancellationToken;

    /// Minimal in-memory `SignalingClient` used as a test double.
    ///
    /// Only connection bookkeeping is implemented; every request method
    /// panics with `unimplemented!`.
    struct FakeSignalingClient {
        /// Lets a test drive the connection state.
        state_tx: watch::Sender<ConnectionState>,
        /// Counts how often `connect()` was called.
        connect_calls: Arc<AtomicUsize>,
        actor_id: tokio::sync::Mutex<Option<ActrId>>,
        credential_state: tokio::sync::Mutex<Option<CredentialState>>,
    }

    #[async_trait]
    impl SignalingClient for FakeSignalingClient {
        // Records the call and always succeeds.
        async fn connect(&self) -> NetworkResult<()> {
            self.connect_calls.fetch_add(1, UsizeOrdering::SeqCst);
            Ok(())
        }

        // No resources to release in the fake.
        async fn disconnect(&self) -> NetworkResult<()> {
            Ok(())
        }

        async fn send_register_request(
            &self,
            _request: RegisterRequest,
        ) -> NetworkResult<RegisterResponse> {
            unimplemented!("not needed in tests");
        }

        async fn send_unregister_request(
            &self,
            _actor_id: ActrId,
            _credential: AIdCredential,
            _reason: Option<String>,
        ) -> NetworkResult<UnregisterResponse> {
            unimplemented!("not needed in tests");
        }

        async fn send_heartbeat(
            &self,
            _actor_id: ActrId,
            _credential: AIdCredential,
            _availability: ServiceAvailabilityState,
            _power_reserve: f32,
            _mailbox_backlog: f32,
        ) -> NetworkResult<Pong> {
            unimplemented!("not needed in tests");
        }

        async fn send_route_candidates_request(
            &self,
            _actor_id: ActrId,
            _credential: AIdCredential,
            _request: RouteCandidatesRequest,
        ) -> NetworkResult<RouteCandidatesResponse> {
            unimplemented!("not needed in tests");
        }

        async fn send_credential_update_request(
            &self,
            _actor_id: ActrId,
            _credential: AIdCredential,
        ) -> NetworkResult<RegisterResponse> {
            unimplemented!("not needed in tests");
        }

        async fn send_envelope(&self, _envelope: SignalingEnvelope) -> NetworkResult<()> {
            unimplemented!("not needed in tests");
        }

        async fn receive_envelope(&self) -> NetworkResult<Option<SignalingEnvelope>> {
            unimplemented!("not needed in tests");
        }

        // Connected state is derived from the watch channel's current value.
        fn is_connected(&self) -> bool {
            *self.state_tx.borrow() == ConnectionState::Connected
        }

        fn get_stats(&self) -> SignalingStats {
            SignalingStats::default()
        }

        fn subscribe_state(&self) -> watch::Receiver<ConnectionState> {
            self.state_tx.subscribe()
        }

        async fn set_actor_id(&self, actor_id: ActrId) {
            *self.actor_id.lock().await = Some(actor_id);
        }

        async fn set_credential_state(&self, credential_state: CredentialState) {
            *self.credential_state.lock().await = Some(credential_state);
        }
    }

    /// Builds a fake client that starts disconnected, plus a clone of its
    /// state sender so a test can publish state changes.
    fn make_fake_client() -> (Arc<FakeSignalingClient>, watch::Sender<ConnectionState>) {
        let (state_tx, _rx) = watch::channel(ConnectionState::Disconnected);
        let client = Arc::new(FakeSignalingClient {
            state_tx: state_tx.clone(),
            connect_calls: Arc::new(AtomicUsize::new(0)),
            actor_id: tokio::sync::Mutex::new(None),
            credential_state: tokio::sync::Mutex::new(None),
        });
        (client, state_tx)
    }

    /// A freshly constructed client must report `Disconnected` to subscribers.
    #[test]
    fn test_websocket_signaling_client_initial_state_disconnected() {
        let config = SignalingConfig {
            server_url: Url::parse("ws://example.com/signaling/ws").unwrap(),
            connection_timeout: 30,
            heartbeat_interval: 30,
            reconnect_config: ReconnectConfig::default(),
            auth_config: None,
        };

        let client = WebSocketSignalingClient::new(config);
        let state_rx = client.subscribe_state();
        assert_eq!(*state_rx.borrow(), ConnectionState::Disconnected);
    }
}