1use crate::core::{
2 AvailabilityStatus, HealthStatus, ProtoFile, ServiceDetails, ServiceDiscovery, ServiceFilter,
3 ServiceInfo,
4};
5use actr_hyper::AisClient;
6use actr_protocol::{
7 AIdCredential, ActrId, ActrToSignaling, ActrType, DiscoveryRequest, ErrorResponse,
8 GetServiceSpecRequest, Realm, RegisterAuthMode, RegisterRequest, SignalingEnvelope,
9 actr_to_signaling, discovery_response, get_service_spec_response, register_response,
10 signaling_envelope, signaling_to_actr,
11};
12use anyhow::{Context, Result, anyhow};
13use async_trait::async_trait;
14use base64::Engine as _;
15use futures_util::{SinkExt, StreamExt};
16use prost::Message;
17use std::path::PathBuf;
18use std::time::SystemTime;
19use tokio::{
20 sync::Mutex,
21 time::{Duration, sleep},
22};
23use tokio_tungstenite::{connect_async, tungstenite::Message as WsMessage};
24use url::Url;
25
26type SignalingSocket =
27 tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>;
28
29struct SignalingState {
30 socket: SignalingSocket,
31 actr_id: ActrId,
32 credential: AIdCredential,
33}
34
35#[derive(Debug, Clone)]
40pub struct DiscoveryContext {
41 pub package_actr_type: ActrType,
43
44 pub signaling_url: Url,
46
47 pub ais_endpoint: String,
49
50 pub realm: Realm,
52
53 pub realm_secret: Option<String>,
55}
56
57pub struct NetworkServiceDiscovery {
58 context: DiscoveryContext,
59 state: Mutex<Option<SignalingState>>,
60}
61
62impl NetworkServiceDiscovery {
/// Maximum number of attempts `retry_lookup` makes before giving up.
const LOOKUP_RETRY_ATTEMPTS: usize = 45;
/// Pause inserted between two consecutive lookup attempts.
const LOOKUP_RETRY_DELAY: Duration = Duration::from_secs(2);
65
66 pub fn new(context: DiscoveryContext) -> Self {
67 Self {
68 context,
69 state: Mutex::new(None),
70 }
71 }
72
73 fn format_actr_type(actr_type: &ActrType) -> String {
74 actr_type.to_string_repr()
75 }
76
77 async fn ensure_connected(&self) -> Result<()> {
78 let mut state_guard = self.state.lock().await;
79 if state_guard.is_some() {
80 return Ok(());
81 }
82
83 let state = self.connect_and_register().await?;
84 *state_guard = Some(state);
85 Ok(())
86 }
87
88 async fn discover_entries(
90 &self,
91 _filter: Option<&ServiceFilter>,
92 ) -> Result<Vec<discovery_response::TypeEntry>> {
93 self.ensure_connected().await?;
94 let mut state_guard = self.state.lock().await;
95 let state = state_guard
96 .as_mut()
97 .context("Signaling state not initialized")?;
98
99 let request = DiscoveryRequest {
101 manufacturer: None,
102 limit: None,
103 };
104 let payload = actr_to_signaling::Payload::DiscoveryRequest(request);
105 let envelope =
106 Self::build_envelope(signaling_envelope::Flow::ActrToServer(ActrToSignaling {
107 source: state.actr_id.clone(),
108 credential: state.credential.clone(),
109 payload: Some(payload),
110 }))?;
111
112 let result = match Self::send_envelope(&mut state.socket, envelope).await {
113 Ok(()) => loop {
114 let envelope = Self::read_envelope(&mut state.socket).await?;
115 match envelope.flow {
116 Some(signaling_envelope::Flow::ServerToActr(server)) => match server.payload {
117 Some(signaling_to_actr::Payload::DiscoveryResponse(response)) => {
118 break Self::handle_discovery_response(response);
119 }
120 Some(signaling_to_actr::Payload::Error(error)) => {
121 break Err(Self::as_error("Discovery failed", &error));
122 }
123 _ => {}
124 },
125 Some(signaling_envelope::Flow::EnvelopeError(error)) => {
126 break Err(Self::as_error("Discovery failed", &error));
127 }
128 _ => {}
129 }
130 },
131 Err(err) => Err(err),
132 };
133 if result.is_err() {
134 *state_guard = None;
135 }
136 result
137 }
138
139 fn handle_discovery_response(
140 response: actr_protocol::DiscoveryResponse,
141 ) -> Result<Vec<discovery_response::TypeEntry>> {
142 match response.result {
143 Some(discovery_response::Result::Success(success)) => Ok(success.entries),
144 Some(discovery_response::Result::Error(error)) => {
145 Err(Self::as_error("Discovery failed", &error))
146 }
147 None => Err(anyhow!("Discovery response is missing result")),
148 }
149 }
150
151 async fn connect_and_register(&self) -> Result<SignalingState> {
152 let realm_secret = self.required_realm_secret()?.to_string();
153 let register_request = self.build_linked_register_request();
154
155 let ais_client = AisClient::new(&self.context.ais_endpoint).with_realm_secret(realm_secret);
156
157 let register_response = ais_client
158 .register_linked(register_request)
159 .await
160 .map_err(|err| anyhow!("AIS HTTP registration failed: {err}"))?;
161
162 let (actr_id, credential) = match register_response.result {
163 Some(register_response::Result::Success(success)) => {
164 (success.actr_id, success.credential)
165 }
166 Some(register_response::Result::Error(error)) => {
167 return Err(Self::as_error("AIS registration failed", &error));
168 }
169 None => return Err(anyhow!("AIS registration response is missing result")),
170 };
171
172 let signaling_url = Self::build_signaling_url_with_identity(
173 &self.context.signaling_url,
174 &actr_id,
175 &credential,
176 );
177 let (socket, _) = connect_async(signaling_url.as_str())
178 .await
179 .with_context(|| format!("Failed to connect to signaling: {signaling_url}"))?;
180
181 Ok(SignalingState {
182 socket,
183 actr_id,
184 credential,
185 })
186 }
187
188 fn build_signaling_url_with_identity(
189 signaling_url: &Url,
190 actr_id: &ActrId,
191 credential: &AIdCredential,
192 ) -> Url {
193 let mut url = signaling_url.clone();
194 let claims_b64 = base64::engine::general_purpose::STANDARD.encode(&credential.claims);
195 let signature_b64 = base64::engine::general_purpose::STANDARD.encode(&credential.signature);
196
197 url.query_pairs_mut()
198 .append_pair("actor_id", &actr_id.to_string_repr())
199 .append_pair("key_id", &credential.key_id.to_string())
200 .append_pair("claims", &claims_b64)
201 .append_pair("signature", &signature_b64);
202
203 url
204 }
205
206 fn as_error(context: &str, error: &ErrorResponse) -> anyhow::Error {
207 anyhow!("{context}: {} ({})", error.message, error.code)
208 }
209
210 async fn retry_lookup<T, F, Fut>(&self, context: &str, mut lookup: F) -> Result<T>
211 where
212 F: FnMut() -> Fut,
213 Fut: std::future::Future<Output = Result<Option<T>>>,
214 {
215 let mut last_error = None;
216
217 for attempt in 0..Self::LOOKUP_RETRY_ATTEMPTS {
218 match lookup().await {
219 Ok(Some(value)) => return Ok(value),
220 Ok(None) => last_error = Some(anyhow!("{context}")),
221 Err(err) => last_error = Some(err),
222 }
223
224 if attempt + 1 < Self::LOOKUP_RETRY_ATTEMPTS {
225 sleep(Self::LOOKUP_RETRY_DELAY).await;
226 }
227 }
228
229 Err(last_error.unwrap_or_else(|| anyhow!("{context}")))
230 }
231
232 async fn send_envelope(
233 socket: &mut SignalingSocket,
234 envelope: SignalingEnvelope,
235 ) -> Result<()> {
236 let mut buf = Vec::new();
237 envelope
238 .encode(&mut buf)
239 .context("Failed to encode signaling envelope")?;
240 socket
241 .send(WsMessage::Binary(buf.into()))
242 .await
243 .context("Failed to send signaling envelope")?;
244 Ok(())
245 }
246
247 async fn read_envelope(socket: &mut SignalingSocket) -> Result<SignalingEnvelope> {
248 while let Some(message) = socket.next().await {
249 match message.context("Failed to read signaling response")? {
250 WsMessage::Binary(bytes) => {
251 return SignalingEnvelope::decode(bytes)
252 .context("Failed to decode signaling envelope");
253 }
254 WsMessage::Close(_) => {
255 return Err(anyhow!("Signaling connection closed"));
256 }
257 WsMessage::Ping(_) | WsMessage::Pong(_) => {}
258 WsMessage::Text(text) => {
259 return Err(anyhow!("Unexpected text message from signaling: {text}"));
260 }
261 WsMessage::Frame(_) => {}
262 }
263 }
264
265 Err(anyhow!("Signaling connection closed"))
266 }
267
268 fn build_envelope(flow: signaling_envelope::Flow) -> Result<SignalingEnvelope> {
269 Ok(SignalingEnvelope {
270 envelope_version: 1,
271 envelope_id: uuid::Uuid::new_v4().to_string(),
272 reply_for: None,
273 timestamp: prost_types::Timestamp {
274 seconds: chrono::Utc::now().timestamp(),
275 nanos: 0,
276 },
277 traceparent: None,
278 tracestate: None,
279 flow: Some(flow),
280 })
281 }
282
283 fn select_version(entry: &discovery_response::TypeEntry) -> String {
284 entry
285 .tags
286 .iter()
287 .find(|tag| tag.as_str() == "latest")
288 .cloned()
289 .or_else(|| entry.tags.first().cloned())
290 .unwrap_or_else(|| "unknown".to_string())
291 }
292
293 fn matches_filter(entry: &discovery_response::TypeEntry, filter: &ServiceFilter) -> bool {
294 if let Some(pattern) = &filter.name_pattern {
295 let full_name = Self::format_actr_type(&entry.actr_type);
296 let matches = Self::matches_pattern(&entry.name, pattern)
297 || Self::matches_pattern(&full_name, pattern);
298 if !matches {
299 return false;
300 }
301 }
302
303 if let Some(version_range) = &filter.version_range
304 && Self::select_version(entry) != *version_range
305 && !entry.tags.iter().any(|tag| tag == version_range)
306 {
307 return false;
308 }
309
310 if let Some(tags) = &filter.tags {
311 let has_all = tags.iter().all(|tag| entry.tags.iter().any(|t| t == tag));
312 if !has_all {
313 return false;
314 }
315 }
316
317 true
318 }
319
/// Glob-style match where `*` in `pattern` stands for any (possibly empty) substring.
///
/// A pattern without `*` must equal `value` exactly. Otherwise the text before the
/// first `*` must be a prefix of `value`, the text after the last `*` must be a suffix,
/// and the segments in between must occur in order within what is left.
///
/// The prefix and suffix may not overlap: `"aba"` does not match `"ab*ba"`. Matching
/// never panics, including when `value` is shorter than the pattern's literal parts.
fn matches_pattern(value: &str, pattern: &str) -> bool {
    let mut segments = pattern.split('*');

    // `split` always yields at least one segment: the text before the first `*`.
    let prefix = segments.next().unwrap_or_default();
    let Some(mut remaining) = value.strip_prefix(prefix) else {
        return false;
    };

    // No further segment means the pattern had no wildcard: require an exact match.
    let Some(suffix) = segments.next_back() else {
        return remaining.is_empty();
    };

    // Consume middle segments leftmost-first, which leaves the most room for the suffix.
    for segment in segments {
        match remaining.find(segment) {
            Some(at) => remaining = &remaining[at + segment.len()..],
            None => return false,
        }
    }

    // The suffix is checked against the unconsumed tail only, so it cannot overlap
    // with the prefix or any middle segment.
    remaining.ends_with(suffix)
}
373
374 fn matches_lookup_name(entry: &discovery_response::TypeEntry, name: &str) -> bool {
375 if entry.name == name || Self::format_actr_type(&entry.actr_type) == name {
376 return true;
377 }
378
379 let Ok(lookup_type) = ActrType::from_string_repr(name) else {
380 return false;
381 };
382
383 entry.actr_type == lookup_type
384 }
385
386 fn required_realm_secret(&self) -> Result<&str> {
387 self.context
388 .realm_secret
389 .as_deref()
390 .map(str::trim)
391 .filter(|secret| !secret.is_empty())
392 .ok_or_else(|| {
393 anyhow!("network.realm_secret is required for CLI service discovery registration")
394 })
395 }
396
397 fn build_linked_register_request(&self) -> RegisterRequest {
398 RegisterRequest {
399 actr_type: self.context.package_actr_type.clone(),
400 realm: self.context.realm,
401 service_spec: None,
402 service: None,
403 acl: None,
404 ws_address: None,
405 manifest_raw: None,
406 mfr_signature: None,
407 psk_token: None,
408 target: None,
409 auth_mode: Some(RegisterAuthMode::Linked as i32),
410 }
411 }
412}
413
414#[async_trait]
415impl ServiceDiscovery for NetworkServiceDiscovery {
416 async fn discover_services(&self, filter: Option<&ServiceFilter>) -> Result<Vec<ServiceInfo>> {
417 let entries = self.discover_entries(filter).await?;
418 let services = entries
419 .into_iter()
420 .filter(|entry| match filter {
421 Some(filter) => Self::matches_filter(entry, filter),
422 None => true,
423 })
424 .map(ServiceInfo::from)
425 .collect();
426 Ok(services)
427 }
428
429 async fn get_service_details(&self, name: &str) -> Result<ServiceDetails> {
430 let entry = self
431 .retry_lookup(&format!("Service not found: {name}"), || async {
432 let entries = self.discover_entries(None).await?;
433 Ok(entries
434 .into_iter()
435 .find(|entry| Self::matches_lookup_name(entry, name)))
436 })
437 .await?;
438 let info = ServiceInfo::from(entry.clone());
439
440 let spec_lookup_name = &entry.actr_type.name;
444 let proto_files = match self.get_service_proto(spec_lookup_name).await {
445 Ok(proto_files) => proto_files,
446 Err(e) => {
447 tracing::warn!("Failed to get ServiceSpec for {name}: {e}");
448 Vec::new()
449 }
450 };
451
452 Ok(ServiceDetails {
453 info,
454 proto_files,
455 dependencies: Vec::new(),
456 })
457 }
458
459 async fn check_service_availability(&self, name: &str) -> Result<AvailabilityStatus> {
461 let available = self
462 .retry_lookup(&format!("Service not found: {name}"), || async {
463 let entries = self.discover_entries(None).await?;
464 Ok(entries
465 .into_iter()
466 .any(|entry| Self::matches_lookup_name(&entry, name))
467 .then_some(true))
468 })
469 .await
470 .unwrap_or(false);
471
472 Ok(AvailabilityStatus {
473 is_available: available,
474 last_seen: available.then(SystemTime::now),
475 health: if available {
476 HealthStatus::Healthy
477 } else {
478 HealthStatus::Unknown
479 },
480 })
481 }
482
483 async fn get_service_proto(&self, name: &str) -> Result<Vec<ProtoFile>> {
484 self.retry_lookup(&format!("Get service spec failed: {name}"), || async {
485 self.ensure_connected().await?;
486 let mut state_guard = self.state.lock().await;
487 let state = state_guard
488 .as_mut()
489 .context("Signaling state not initialized")?;
490
491 let request = GetServiceSpecRequest {
492 name: name.to_string(),
493 };
494 let payload = actr_to_signaling::Payload::GetServiceSpecRequest(request);
495 let envelope =
496 Self::build_envelope(signaling_envelope::Flow::ActrToServer(ActrToSignaling {
497 source: state.actr_id.clone(),
498 credential: state.credential.clone(),
499 payload: Some(payload),
500 }))?;
501
502 let result = match Self::send_envelope(&mut state.socket, envelope).await {
503 Ok(()) => loop {
504 let envelope = Self::read_envelope(&mut state.socket).await?;
505 match envelope.flow {
506 Some(signaling_envelope::Flow::ServerToActr(server)) => {
507 match server.payload {
508 Some(signaling_to_actr::Payload::GetServiceSpecResponse(
509 response,
510 )) => {
511 let proto_files = match response.result {
512 Some(get_service_spec_response::Result::Success(
513 success,
514 )) => success
515 .protobufs
516 .into_iter()
517 .map(|p| ProtoFile {
518 name: format!("{}.proto", p.package),
519 path: PathBuf::new(),
520 content: p.content,
521 services: Vec::new(),
522 })
523 .collect::<Vec<_>>(),
524 Some(get_service_spec_response::Result::Error(error)) => {
525 break Err(Self::as_error(
526 "Get service spec failed",
527 &error,
528 ));
529 }
530 None => {
531 break Err(anyhow!(
532 "Get service spec response is missing result"
533 ));
534 }
535 };
536 break Ok(Some(proto_files));
537 }
538 Some(signaling_to_actr::Payload::Error(error)) => {
539 break Err(Self::as_error("Get service spec failed", &error));
540 }
541 _ => {}
542 }
543 }
544 Some(signaling_envelope::Flow::EnvelopeError(error)) => {
545 break Err(Self::as_error("Get service spec failed", &error));
546 }
547 _ => {}
548 }
549 },
550 Err(err) => Err(err),
551 };
552
553 if result.is_err() {
554 *state_guard = None;
555 }
556
557 result
558 })
559 .await
560 }
561}
562
#[cfg(test)]
mod tests {
    use super::*;
    use actr_protocol::Realm;

    /// Actor type for manufacturer `acme` at version `1.0.0` with the given name.
    fn acme_type(name: &str) -> ActrType {
        ActrType {
            manufacturer: "acme".to_string(),
            name: name.to_string(),
            version: "1.0.0".to_string(),
        }
    }

    /// Discovery context pointing at localhost with an optional realm secret.
    fn sample_context(realm_secret: Option<&str>) -> DiscoveryContext {
        DiscoveryContext {
            package_actr_type: acme_type("cli-client"),
            signaling_url: Url::parse("ws://localhost:8081/signaling/ws").unwrap(),
            ais_endpoint: "http://localhost:8081/ais".to_string(),
            realm: Realm { realm_id: 1001 },
            realm_secret: realm_secret.map(str::to_string),
        }
    }

    /// Actor id of the `echo` actor with serial number 42 in realm 1001.
    fn sample_actor_id() -> ActrId {
        ActrId {
            serial_number: 42,
            r#type: acme_type("echo"),
            realm: Realm { realm_id: 1001 },
        }
    }

    /// Credential with key id 7 and small fixed claims/signature payloads.
    fn sample_credential() -> AIdCredential {
        AIdCredential {
            key_id: 7,
            claims: vec![1, 2, 3, 4].into(),
            signature: vec![5, 6, 7, 8].into(),
        }
    }

    #[test]
    fn build_signaling_url_with_identity_appends_auth_query() {
        let base_url = Url::parse("ws://localhost:8081/signaling/ws?existing=1").unwrap();
        let actor_id = sample_actor_id();
        let credential = sample_credential();

        let authenticated_url = NetworkServiceDiscovery::build_signaling_url_with_identity(
            &base_url,
            &actor_id,
            &credential,
        );
        let query: std::collections::HashMap<_, _> =
            authenticated_url.query_pairs().into_owned().collect();

        let expected_claims = base64::engine::general_purpose::STANDARD.encode([1, 2, 3, 4]);
        let expected_signature = base64::engine::general_purpose::STANDARD.encode([5, 6, 7, 8]);

        assert_eq!(query.get("existing"), Some(&"1".to_string()));
        assert_eq!(query.get("actor_id"), Some(&actor_id.to_string_repr()));
        assert_eq!(query.get("key_id"), Some(&"7".to_string()));
        assert_eq!(query.get("claims"), Some(&expected_claims));
        assert_eq!(query.get("signature"), Some(&expected_signature));
    }

    #[test]
    fn cli_discovery_register_request_uses_linked_auth_mode() {
        let discovery = NetworkServiceDiscovery::new(sample_context(Some("rs_test_secret")));
        let request = discovery.build_linked_register_request();

        assert_eq!(request.auth_mode, Some(RegisterAuthMode::Linked as i32));
        assert_eq!(request.manifest_raw, None);
        assert_eq!(request.mfr_signature, None);
        assert_eq!(request.psk_token, None);
        assert_eq!(request.target, None);
        assert_eq!(request.actr_type.name, "cli-client");
        assert_eq!(request.realm.realm_id, 1001);
    }

    #[test]
    fn cli_discovery_requires_realm_secret() {
        // Both an absent secret and a whitespace-only secret must be rejected.
        for realm_secret in [None, Some(" ")] {
            let discovery = NetworkServiceDiscovery::new(sample_context(realm_secret));
            let err = discovery.required_realm_secret().unwrap_err();
            assert!(err.to_string().contains("network.realm_secret is required"));
        }
    }
}