Skip to main content

actr_cli/core/components/
service_discovery.rs

1use crate::core::{
2    AvailabilityStatus, HealthStatus, ProtoFile, ServiceDetails, ServiceDiscovery, ServiceFilter,
3    ServiceInfo,
4};
5use actr_hyper::AisClient;
6use actr_protocol::{
7    AIdCredential, ActrId, ActrToSignaling, ActrType, DiscoveryRequest, ErrorResponse,
8    GetServiceSpecRequest, Realm, RegisterAuthMode, RegisterRequest, SignalingEnvelope,
9    actr_to_signaling, discovery_response, get_service_spec_response, register_response,
10    signaling_envelope, signaling_to_actr,
11};
12use anyhow::{Context, Result, anyhow};
13use async_trait::async_trait;
14use base64::Engine as _;
15use futures_util::{SinkExt, StreamExt};
16use prost::Message;
17use std::path::PathBuf;
18use std::time::SystemTime;
19use tokio::{
20    sync::Mutex,
21    time::{Duration, sleep},
22};
23use tokio_tungstenite::{connect_async, tungstenite::Message as WsMessage};
24use url::Url;
25
26type SignalingSocket =
27    tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>;
28
29struct SignalingState {
30    socket: SignalingSocket,
31    actr_id: ActrId,
32    credential: AIdCredential,
33}
34
35/// Discovery context for CLI service discovery operations
36///
37/// This contains the minimal information needed for CLI to perform service discovery,
38/// separate from runtime configuration (actr.toml).
39#[derive(Debug, Clone)]
40pub struct DiscoveryContext {
41    /// Package actor type (from manifest.toml)
42    pub package_actr_type: ActrType,
43
44    /// Signaling server URL
45    pub signaling_url: Url,
46
47    /// AIS endpoint URL
48    pub ais_endpoint: String,
49
50    /// Realm for temporary actor registration
51    pub realm: Realm,
52
53    /// Optional realm secret for authentication
54    pub realm_secret: Option<String>,
55}
56
57pub struct NetworkServiceDiscovery {
58    context: DiscoveryContext,
59    state: Mutex<Option<SignalingState>>,
60}
61
62impl NetworkServiceDiscovery {
63    const LOOKUP_RETRY_ATTEMPTS: usize = 45;
64    const LOOKUP_RETRY_DELAY: Duration = Duration::from_secs(2);
65
66    pub fn new(context: DiscoveryContext) -> Self {
67        Self {
68            context,
69            state: Mutex::new(None),
70        }
71    }
72
73    fn format_actr_type(actr_type: &ActrType) -> String {
74        actr_type.to_string_repr()
75    }
76
77    async fn ensure_connected(&self) -> Result<()> {
78        let mut state_guard = self.state.lock().await;
79        if state_guard.is_some() {
80            return Ok(());
81        }
82
83        let state = self.connect_and_register().await?;
84        *state_guard = Some(state);
85        Ok(())
86    }
87
88    // TODO: add filter support
89    async fn discover_entries(
90        &self,
91        _filter: Option<&ServiceFilter>,
92    ) -> Result<Vec<discovery_response::TypeEntry>> {
93        self.ensure_connected().await?;
94        let mut state_guard = self.state.lock().await;
95        let state = state_guard
96            .as_mut()
97            .context("Signaling state not initialized")?;
98
99        // TODO: add filter support
100        let request = DiscoveryRequest {
101            manufacturer: None,
102            limit: None,
103        };
104        let payload = actr_to_signaling::Payload::DiscoveryRequest(request);
105        let envelope =
106            Self::build_envelope(signaling_envelope::Flow::ActrToServer(ActrToSignaling {
107                source: state.actr_id.clone(),
108                credential: state.credential.clone(),
109                payload: Some(payload),
110            }))?;
111
112        let result = match Self::send_envelope(&mut state.socket, envelope).await {
113            Ok(()) => loop {
114                let envelope = Self::read_envelope(&mut state.socket).await?;
115                match envelope.flow {
116                    Some(signaling_envelope::Flow::ServerToActr(server)) => match server.payload {
117                        Some(signaling_to_actr::Payload::DiscoveryResponse(response)) => {
118                            break Self::handle_discovery_response(response);
119                        }
120                        Some(signaling_to_actr::Payload::Error(error)) => {
121                            break Err(Self::as_error("Discovery failed", &error));
122                        }
123                        _ => {}
124                    },
125                    Some(signaling_envelope::Flow::EnvelopeError(error)) => {
126                        break Err(Self::as_error("Discovery failed", &error));
127                    }
128                    _ => {}
129                }
130            },
131            Err(err) => Err(err),
132        };
133        if result.is_err() {
134            *state_guard = None;
135        }
136        result
137    }
138
139    fn handle_discovery_response(
140        response: actr_protocol::DiscoveryResponse,
141    ) -> Result<Vec<discovery_response::TypeEntry>> {
142        match response.result {
143            Some(discovery_response::Result::Success(success)) => Ok(success.entries),
144            Some(discovery_response::Result::Error(error)) => {
145                Err(Self::as_error("Discovery failed", &error))
146            }
147            None => Err(anyhow!("Discovery response is missing result")),
148        }
149    }
150
151    async fn connect_and_register(&self) -> Result<SignalingState> {
152        let realm_secret = self.required_realm_secret()?.to_string();
153        let register_request = self.build_linked_register_request();
154
155        let ais_client = AisClient::new(&self.context.ais_endpoint).with_realm_secret(realm_secret);
156
157        let register_response = ais_client
158            .register_linked(register_request)
159            .await
160            .map_err(|err| anyhow!("AIS HTTP registration failed: {err}"))?;
161
162        let (actr_id, credential) = match register_response.result {
163            Some(register_response::Result::Success(success)) => {
164                (success.actr_id, success.credential)
165            }
166            Some(register_response::Result::Error(error)) => {
167                return Err(Self::as_error("AIS registration failed", &error));
168            }
169            None => return Err(anyhow!("AIS registration response is missing result")),
170        };
171
172        let signaling_url = Self::build_signaling_url_with_identity(
173            &self.context.signaling_url,
174            &actr_id,
175            &credential,
176        );
177        let (socket, _) = connect_async(signaling_url.as_str())
178            .await
179            .with_context(|| format!("Failed to connect to signaling: {signaling_url}"))?;
180
181        Ok(SignalingState {
182            socket,
183            actr_id,
184            credential,
185        })
186    }
187
188    fn build_signaling_url_with_identity(
189        signaling_url: &Url,
190        actr_id: &ActrId,
191        credential: &AIdCredential,
192    ) -> Url {
193        let mut url = signaling_url.clone();
194        let claims_b64 = base64::engine::general_purpose::STANDARD.encode(&credential.claims);
195        let signature_b64 = base64::engine::general_purpose::STANDARD.encode(&credential.signature);
196
197        url.query_pairs_mut()
198            .append_pair("actor_id", &actr_id.to_string_repr())
199            .append_pair("key_id", &credential.key_id.to_string())
200            .append_pair("claims", &claims_b64)
201            .append_pair("signature", &signature_b64);
202
203        url
204    }
205
206    fn as_error(context: &str, error: &ErrorResponse) -> anyhow::Error {
207        anyhow!("{context}: {} ({})", error.message, error.code)
208    }
209
210    async fn retry_lookup<T, F, Fut>(&self, context: &str, mut lookup: F) -> Result<T>
211    where
212        F: FnMut() -> Fut,
213        Fut: std::future::Future<Output = Result<Option<T>>>,
214    {
215        let mut last_error = None;
216
217        for attempt in 0..Self::LOOKUP_RETRY_ATTEMPTS {
218            match lookup().await {
219                Ok(Some(value)) => return Ok(value),
220                Ok(None) => last_error = Some(anyhow!("{context}")),
221                Err(err) => last_error = Some(err),
222            }
223
224            if attempt + 1 < Self::LOOKUP_RETRY_ATTEMPTS {
225                sleep(Self::LOOKUP_RETRY_DELAY).await;
226            }
227        }
228
229        Err(last_error.unwrap_or_else(|| anyhow!("{context}")))
230    }
231
232    async fn send_envelope(
233        socket: &mut SignalingSocket,
234        envelope: SignalingEnvelope,
235    ) -> Result<()> {
236        let mut buf = Vec::new();
237        envelope
238            .encode(&mut buf)
239            .context("Failed to encode signaling envelope")?;
240        socket
241            .send(WsMessage::Binary(buf.into()))
242            .await
243            .context("Failed to send signaling envelope")?;
244        Ok(())
245    }
246
247    async fn read_envelope(socket: &mut SignalingSocket) -> Result<SignalingEnvelope> {
248        while let Some(message) = socket.next().await {
249            match message.context("Failed to read signaling response")? {
250                WsMessage::Binary(bytes) => {
251                    return SignalingEnvelope::decode(bytes)
252                        .context("Failed to decode signaling envelope");
253                }
254                WsMessage::Close(_) => {
255                    return Err(anyhow!("Signaling connection closed"));
256                }
257                WsMessage::Ping(_) | WsMessage::Pong(_) => {}
258                WsMessage::Text(text) => {
259                    return Err(anyhow!("Unexpected text message from signaling: {text}"));
260                }
261                WsMessage::Frame(_) => {}
262            }
263        }
264
265        Err(anyhow!("Signaling connection closed"))
266    }
267
268    fn build_envelope(flow: signaling_envelope::Flow) -> Result<SignalingEnvelope> {
269        Ok(SignalingEnvelope {
270            envelope_version: 1,
271            envelope_id: uuid::Uuid::new_v4().to_string(),
272            reply_for: None,
273            timestamp: prost_types::Timestamp {
274                seconds: chrono::Utc::now().timestamp(),
275                nanos: 0,
276            },
277            traceparent: None,
278            tracestate: None,
279            flow: Some(flow),
280        })
281    }
282
283    fn select_version(entry: &discovery_response::TypeEntry) -> String {
284        entry
285            .tags
286            .iter()
287            .find(|tag| tag.as_str() == "latest")
288            .cloned()
289            .or_else(|| entry.tags.first().cloned())
290            .unwrap_or_else(|| "unknown".to_string())
291    }
292
293    fn matches_filter(entry: &discovery_response::TypeEntry, filter: &ServiceFilter) -> bool {
294        if let Some(pattern) = &filter.name_pattern {
295            let full_name = Self::format_actr_type(&entry.actr_type);
296            let matches = Self::matches_pattern(&entry.name, pattern)
297                || Self::matches_pattern(&full_name, pattern);
298            if !matches {
299                return false;
300            }
301        }
302
303        if let Some(version_range) = &filter.version_range
304            && Self::select_version(entry) != *version_range
305            && !entry.tags.iter().any(|tag| tag == version_range)
306        {
307            return false;
308        }
309
310        if let Some(tags) = &filter.tags {
311            let has_all = tags.iter().all(|tag| entry.tags.iter().any(|t| t == tag));
312            if !has_all {
313                return false;
314            }
315        }
316
317        true
318    }
319
320    fn matches_pattern(value: &str, pattern: &str) -> bool {
321        if pattern == "*" {
322            return true;
323        }
324
325        let segments: Vec<&str> = pattern.split('*').collect();
326        if segments.len() == 1 {
327            return value == pattern;
328        }
329
330        if !pattern.starts_with('*')
331            && let Some(first) = segments.first()
332            && !value.starts_with(first)
333        {
334            return false;
335        }
336
337        if !pattern.ends_with('*')
338            && let Some(last) = segments.last()
339            && !value.ends_with(last)
340        {
341            return false;
342        }
343
344        let mut search_start = 0;
345        let end_limit = if !pattern.ends_with('*') {
346            value
347                .len()
348                .saturating_sub(segments.last().unwrap_or(&"").len())
349        } else {
350            value.len()
351        };
352
353        for (index, segment) in segments.iter().enumerate() {
354            if segment.is_empty() {
355                continue;
356            }
357            if index == 0 && !pattern.starts_with('*') {
358                search_start = segment.len();
359                continue;
360            }
361            if index == segments.len() - 1 && !pattern.ends_with('*') {
362                continue;
363            }
364            if let Some(found) = value[search_start..end_limit].find(segment) {
365                search_start += found + segment.len();
366            } else {
367                return false;
368            }
369        }
370
371        true
372    }
373
374    fn matches_lookup_name(entry: &discovery_response::TypeEntry, name: &str) -> bool {
375        if entry.name == name || Self::format_actr_type(&entry.actr_type) == name {
376            return true;
377        }
378
379        let Ok(lookup_type) = ActrType::from_string_repr(name) else {
380            return false;
381        };
382
383        entry.actr_type == lookup_type
384    }
385
386    fn required_realm_secret(&self) -> Result<&str> {
387        self.context
388            .realm_secret
389            .as_deref()
390            .map(str::trim)
391            .filter(|secret| !secret.is_empty())
392            .ok_or_else(|| {
393                anyhow!("network.realm_secret is required for CLI service discovery registration")
394            })
395    }
396
397    fn build_linked_register_request(&self) -> RegisterRequest {
398        RegisterRequest {
399            actr_type: self.context.package_actr_type.clone(),
400            realm: self.context.realm,
401            service_spec: None,
402            service: None,
403            acl: None,
404            ws_address: None,
405            manifest_raw: None,
406            mfr_signature: None,
407            psk_token: None,
408            target: None,
409            auth_mode: Some(RegisterAuthMode::Linked as i32),
410        }
411    }
412}
413
414#[async_trait]
415impl ServiceDiscovery for NetworkServiceDiscovery {
416    async fn discover_services(&self, filter: Option<&ServiceFilter>) -> Result<Vec<ServiceInfo>> {
417        let entries = self.discover_entries(filter).await?;
418        let services = entries
419            .into_iter()
420            .filter(|entry| match filter {
421                Some(filter) => Self::matches_filter(entry, filter),
422                None => true,
423            })
424            .map(ServiceInfo::from)
425            .collect();
426        Ok(services)
427    }
428
429    async fn get_service_details(&self, name: &str) -> Result<ServiceDetails> {
430        let entry = self
431            .retry_lookup(&format!("Service not found: {name}"), || async {
432                let entries = self.discover_entries(None).await?;
433                Ok(entries
434                    .into_iter()
435                    .find(|entry| Self::matches_lookup_name(entry, name)))
436            })
437            .await?;
438        let info = ServiceInfo::from(entry.clone());
439
440        // Try to get ServiceSpec with proto files
441        // Use actr_type.name (e.g., "EchoService") as the lookup key,
442        // matching package ServiceSpec.name = package.name
443        let spec_lookup_name = &entry.actr_type.name;
444        let proto_files = match self.get_service_proto(spec_lookup_name).await {
445            Ok(proto_files) => proto_files,
446            Err(e) => {
447                tracing::warn!("Failed to get ServiceSpec for {name}: {e}");
448                Vec::new()
449            }
450        };
451
452        Ok(ServiceDetails {
453            info,
454            proto_files,
455            dependencies: Vec::new(),
456        })
457    }
458
459    // TODO: improve the performance of this method
460    async fn check_service_availability(&self, name: &str) -> Result<AvailabilityStatus> {
461        let available = self
462            .retry_lookup(&format!("Service not found: {name}"), || async {
463                let entries = self.discover_entries(None).await?;
464                Ok(entries
465                    .into_iter()
466                    .any(|entry| Self::matches_lookup_name(&entry, name))
467                    .then_some(true))
468            })
469            .await
470            .unwrap_or(false);
471
472        Ok(AvailabilityStatus {
473            is_available: available,
474            last_seen: available.then(SystemTime::now),
475            health: if available {
476                HealthStatus::Healthy
477            } else {
478                HealthStatus::Unknown
479            },
480        })
481    }
482
483    async fn get_service_proto(&self, name: &str) -> Result<Vec<ProtoFile>> {
484        self.retry_lookup(&format!("Get service spec failed: {name}"), || async {
485            self.ensure_connected().await?;
486            let mut state_guard = self.state.lock().await;
487            let state = state_guard
488                .as_mut()
489                .context("Signaling state not initialized")?;
490
491            let request = GetServiceSpecRequest {
492                name: name.to_string(),
493            };
494            let payload = actr_to_signaling::Payload::GetServiceSpecRequest(request);
495            let envelope =
496                Self::build_envelope(signaling_envelope::Flow::ActrToServer(ActrToSignaling {
497                    source: state.actr_id.clone(),
498                    credential: state.credential.clone(),
499                    payload: Some(payload),
500                }))?;
501
502            let result = match Self::send_envelope(&mut state.socket, envelope).await {
503                Ok(()) => loop {
504                    let envelope = Self::read_envelope(&mut state.socket).await?;
505                    match envelope.flow {
506                        Some(signaling_envelope::Flow::ServerToActr(server)) => {
507                            match server.payload {
508                                Some(signaling_to_actr::Payload::GetServiceSpecResponse(
509                                    response,
510                                )) => {
511                                    let proto_files = match response.result {
512                                        Some(get_service_spec_response::Result::Success(
513                                            success,
514                                        )) => success
515                                            .protobufs
516                                            .into_iter()
517                                            .map(|p| ProtoFile {
518                                                name: format!("{}.proto", p.package),
519                                                path: PathBuf::new(),
520                                                content: p.content,
521                                                services: Vec::new(),
522                                            })
523                                            .collect::<Vec<_>>(),
524                                        Some(get_service_spec_response::Result::Error(error)) => {
525                                            break Err(Self::as_error(
526                                                "Get service spec failed",
527                                                &error,
528                                            ));
529                                        }
530                                        None => {
531                                            break Err(anyhow!(
532                                                "Get service spec response is missing result"
533                                            ));
534                                        }
535                                    };
536                                    break Ok(Some(proto_files));
537                                }
538                                Some(signaling_to_actr::Payload::Error(error)) => {
539                                    break Err(Self::as_error("Get service spec failed", &error));
540                                }
541                                _ => {}
542                            }
543                        }
544                        Some(signaling_envelope::Flow::EnvelopeError(error)) => {
545                            break Err(Self::as_error("Get service spec failed", &error));
546                        }
547                        _ => {}
548                    }
549                },
550                Err(err) => Err(err),
551            };
552
553            if result.is_err() {
554                *state_guard = None;
555            }
556
557            result
558        })
559        .await
560    }
561}
562
#[cfg(test)]
mod tests {
    use super::*;
    use actr_protocol::Realm;

    /// Context pointing at a local signaling/AIS pair; `realm_secret` is configurable.
    fn sample_context(realm_secret: Option<&str>) -> DiscoveryContext {
        DiscoveryContext {
            package_actr_type: ActrType {
                manufacturer: "acme".to_string(),
                name: "cli-client".to_string(),
                version: "1.0.0".to_string(),
            },
            signaling_url: Url::parse("ws://localhost:8081/signaling/ws").unwrap(),
            ais_endpoint: "http://localhost:8081/ais".to_string(),
            realm: Realm { realm_id: 1001 },
            realm_secret: realm_secret.map(str::to_string),
        }
    }

    fn sample_actor_id() -> ActrId {
        ActrId {
            serial_number: 42,
            r#type: ActrType {
                manufacturer: "acme".to_string(),
                name: "echo".to_string(),
                version: "1.0.0".to_string(),
            },
            realm: Realm { realm_id: 1001 },
        }
    }

    fn sample_credential() -> AIdCredential {
        AIdCredential {
            key_id: 7,
            claims: vec![1, 2, 3, 4].into(),
            signature: vec![5, 6, 7, 8].into(),
        }
    }

    #[test]
    fn build_signaling_url_with_identity_appends_auth_query() {
        let signaling_url = Url::parse("ws://localhost:8081/signaling/ws?existing=1").unwrap();
        let actor_id = sample_actor_id();
        let credential = sample_credential();

        let authenticated_url = NetworkServiceDiscovery::build_signaling_url_with_identity(
            &signaling_url,
            &actor_id,
            &credential,
        );
        let query_pairs: std::collections::HashMap<_, _> =
            authenticated_url.query_pairs().into_owned().collect();

        assert_eq!(query_pairs.get("existing"), Some(&"1".to_string()));
        assert_eq!(
            query_pairs.get("actor_id"),
            Some(&actor_id.to_string_repr())
        );
        assert_eq!(query_pairs.get("key_id"), Some(&"7".to_string()));
        assert_eq!(
            query_pairs.get("claims"),
            Some(&base64::engine::general_purpose::STANDARD.encode([1, 2, 3, 4]))
        );
        assert_eq!(
            query_pairs.get("signature"),
            Some(&base64::engine::general_purpose::STANDARD.encode([5, 6, 7, 8]))
        );
    }

    #[test]
    fn cli_discovery_register_request_uses_linked_auth_mode() {
        let discovery = NetworkServiceDiscovery::new(sample_context(Some("rs_test_secret")));
        let request = discovery.build_linked_register_request();

        assert_eq!(request.auth_mode, Some(RegisterAuthMode::Linked as i32));
        assert_eq!(request.manifest_raw, None);
        assert_eq!(request.mfr_signature, None);
        assert_eq!(request.psk_token, None);
        assert_eq!(request.target, None);
        assert_eq!(request.actr_type.name, "cli-client");
        assert_eq!(request.realm.realm_id, 1001);
    }

    #[test]
    fn cli_discovery_requires_realm_secret() {
        let missing = NetworkServiceDiscovery::new(sample_context(None));
        let err = missing.required_realm_secret().unwrap_err();
        assert!(err.to_string().contains("network.realm_secret is required"));

        let blank = NetworkServiceDiscovery::new(sample_context(Some("   ")));
        let err = blank.required_realm_secret().unwrap_err();
        assert!(err.to_string().contains("network.realm_secret is required"));
    }

    #[test]
    fn matches_pattern_supports_exact_and_wildcard_patterns() {
        let matches = NetworkServiceDiscovery::matches_pattern;

        assert!(matches("echo-service", "*"));
        assert!(matches("echo-service", "echo-service"));
        assert!(!matches("echo-service", "echo"));
        assert!(matches("echo-service", "echo*"));
        assert!(matches("echo", "echo*"));
        assert!(matches("echo-service", "*service"));
        assert!(matches("echo-service", "*-*"));
        assert!(matches("echo-service", "e*o*e"));
        assert!(!matches("echo-service", "web*"));
        assert!(!matches("echo-service", "*gateway*"));
    }

    #[test]
    fn select_version_prefers_latest_then_first_tag() {
        let untagged = discovery_response::TypeEntry::default();
        assert_eq!(NetworkServiceDiscovery::select_version(&untagged), "unknown");

        let with_latest = discovery_response::TypeEntry {
            tags: vec!["1.0.0".to_string(), "latest".to_string()],
            ..Default::default()
        };
        assert_eq!(NetworkServiceDiscovery::select_version(&with_latest), "latest");

        let without_latest = discovery_response::TypeEntry {
            tags: vec!["1.0.0".to_string(), "2.0.0".to_string()],
            ..Default::default()
        };
        assert_eq!(NetworkServiceDiscovery::select_version(&without_latest), "1.0.0");
    }

    #[test]
    fn matches_lookup_name_accepts_entry_name_or_type_repr() {
        let entry = discovery_response::TypeEntry {
            name: "echo".to_string(),
            actr_type: sample_actor_id().r#type,
            ..Default::default()
        };
        let repr = entry.actr_type.to_string_repr();

        assert!(NetworkServiceDiscovery::matches_lookup_name(&entry, "echo"));
        assert!(NetworkServiceDiscovery::matches_lookup_name(&entry, &repr));
        assert!(!NetworkServiceDiscovery::matches_lookup_name(&entry, "other"));
    }

    #[test]
    fn discovery_response_without_result_is_an_error() {
        let err = NetworkServiceDiscovery::handle_discovery_response(
            actr_protocol::DiscoveryResponse::default(),
        )
        .unwrap_err();
        assert!(err.to_string().contains("missing result"));
    }
}