1use dactor::{ClusterDiscovery, DiscoveryError};
8use std::fmt;
9
10#[derive(Debug)]
16pub enum AzureDiscoveryError {
17 ImdsError(String),
19 ArmApiError(String),
21 HttpError(reqwest::Error),
23 ParseError(String),
25 Config(String),
27}
28
29impl fmt::Display for AzureDiscoveryError {
30 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
31 match self {
32 AzureDiscoveryError::ImdsError(e) => write!(f, "IMDS error: {e}"),
33 AzureDiscoveryError::ArmApiError(e) => write!(f, "ARM API error: {e}"),
34 AzureDiscoveryError::HttpError(e) => write!(f, "HTTP error: {e}"),
35 AzureDiscoveryError::ParseError(e) => write!(f, "parse error: {e}"),
36 AzureDiscoveryError::Config(e) => write!(f, "configuration error: {e}"),
37 }
38 }
39}
40
41impl std::error::Error for AzureDiscoveryError {
42 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
43 match self {
44 AzureDiscoveryError::HttpError(e) => Some(e),
45 _ => None,
46 }
47 }
48}
49
50impl From<reqwest::Error> for AzureDiscoveryError {
51 fn from(e: reqwest::Error) -> Self {
52 AzureDiscoveryError::HttpError(e)
53 }
54}
55
56#[derive(Debug, Clone, serde::Deserialize)]
62#[serde(rename_all = "camelCase")]
63struct ImdsResponse {
64 compute: ImdsCompute,
65}
66
67#[derive(Debug, Clone, serde::Deserialize)]
69#[serde(rename_all = "camelCase")]
70struct ImdsCompute {
71 subscription_id: String,
72 resource_group_name: String,
73 #[serde(default, rename = "vmScaleSetName")]
74 vmss_name: Option<String>,
75}
76
77#[derive(Debug, Clone, serde::Deserialize)]
83struct ArmListResponse<T> {
84 value: Vec<T>,
85 #[serde(default, rename = "nextLink")]
86 next_link: Option<String>,
87}
88
89#[derive(Debug, Clone, serde::Deserialize)]
91struct ArmNetworkInterface {
92 properties: ArmNicProperties,
93}
94
95#[derive(Debug, Clone, serde::Deserialize)]
97#[serde(rename_all = "camelCase")]
98struct ArmNicProperties {
99 ip_configurations: Vec<ArmIpConfiguration>,
100}
101
102#[derive(Debug, Clone, serde::Deserialize)]
104struct ArmIpConfiguration {
105 properties: ArmIpConfigProperties,
106}
107
108#[derive(Debug, Clone, serde::Deserialize)]
110#[serde(rename_all = "camelCase")]
111struct ArmIpConfigProperties {
112 private_ip_address: Option<String>,
113}
114
115#[derive(Debug, Clone, serde::Deserialize)]
121struct TokenResponse {
122 access_token: String,
123}
124
125async fn acquire_managed_identity_token(
127 client: &reqwest::Client,
128) -> Result<String, AzureDiscoveryError> {
129 let resp = client
130 .get("http://169.254.169.254/metadata/identity/oauth2/token")
131 .header("Metadata", "true")
132 .query(&[
133 ("api-version", "2019-08-01"),
134 ("resource", "https://management.azure.com/"),
135 ])
136 .send()
137 .await?;
138
139 if !resp.status().is_success() {
140 let status = resp.status();
141 let body = resp.text().await.unwrap_or_default();
142 return Err(AzureDiscoveryError::ImdsError(format!(
143 "token request failed ({status}): {body}"
144 )));
145 }
146
147 let token: TokenResponse = resp
148 .json()
149 .await
150 .map_err(|e| AzureDiscoveryError::ParseError(format!("token response: {e}")))?;
151
152 Ok(token.access_token)
153}
154
/// Base URL of the Azure Instance Metadata Service (a link-local address).
const IMDS_BASE: &str = "http://169.254.169.254";
/// api-version sent with IMDS `/metadata/instance` queries.
const IMDS_API_VERSION: &str = "2021-02-01";
/// api-version sent with ARM network-interface requests.
const ARM_API_VERSION_NIC: &str = "2023-09-01";
/// api-version sent with ARM virtual-machine list requests.
const ARM_API_VERSION_VM: &str = "2023-09-01";
163
164async fn query_imds(client: &reqwest::Client) -> Result<ImdsResponse, AzureDiscoveryError> {
166 let resp = client
167 .get(format!("{IMDS_BASE}/metadata/instance"))
168 .header("Metadata", "true")
169 .query(&[("api-version", IMDS_API_VERSION)])
170 .send()
171 .await?;
172
173 if !resp.status().is_success() {
174 let status = resp.status();
175 let body = resp.text().await.unwrap_or_default();
176 return Err(AzureDiscoveryError::ImdsError(format!(
177 "IMDS returned {status}: {body}"
178 )));
179 }
180
181 resp.json()
182 .await
183 .map_err(|e| AzureDiscoveryError::ParseError(format!("IMDS response: {e}")))
184}
185
186pub async fn current_subscription_id() -> Option<String> {
188 let client = reqwest::Client::builder().timeout(std::time::Duration::from_secs(10)).build().unwrap_or_default();
189 query_imds(&client)
190 .await
191 .ok()
192 .map(|r| r.compute.subscription_id)
193}
194
195pub async fn current_resource_group() -> Option<String> {
197 let client = reqwest::Client::builder().timeout(std::time::Duration::from_secs(10)).build().unwrap_or_default();
198 query_imds(&client)
199 .await
200 .ok()
201 .map(|r| r.compute.resource_group_name)
202}
203
204pub fn imds_instance_url() -> String {
206 format!(
207 "{IMDS_BASE}/metadata/instance?api-version={IMDS_API_VERSION}"
208 )
209}
210
/// Settings for [`VmssDiscovery`].
#[derive(Debug, Clone)]
pub struct VmssDiscoveryConfig {
    /// Port appended to every discovered private IP (`"<ip>:<port>"`).
    pub port: u16,
    /// Whether identifiers that are not configured may be looked up in IMDS.
    pub use_imds: bool,
    /// Subscription containing the scale set; `None` means "ask IMDS".
    pub subscription_id: Option<String>,
    /// Resource group containing the scale set; `None` means "ask IMDS".
    pub resource_group: Option<String>,
    /// Name of the scale set; `None` means "ask IMDS".
    pub vmss_name: Option<String>,
}

impl Default for VmssDiscoveryConfig {
    /// Port 9000, IMDS lookups enabled, no explicit identifiers.
    fn default() -> Self {
        VmssDiscoveryConfig {
            subscription_id: None,
            resource_group: None,
            vmss_name: None,
            use_imds: true,
            port: 9000,
        }
    }
}
242
243pub struct VmssDiscovery {
255 config: VmssDiscoveryConfig,
256 client: reqwest::Client,
257}
258
259impl VmssDiscovery {
260 pub fn builder() -> VmssDiscoveryBuilder {
262 VmssDiscoveryBuilder {
263 config: VmssDiscoveryConfig::default(),
264 }
265 }
266
267 pub fn config(&self) -> &VmssDiscoveryConfig {
269 &self.config
270 }
271
272 async fn resolve_vmss_info(
275 &self,
276 ) -> Result<(String, String, String), AzureDiscoveryError> {
277 if let (Some(sub), Some(rg), Some(vmss)) = (
278 self.config.subscription_id.clone(),
279 self.config.resource_group.clone(),
280 self.config.vmss_name.clone(),
281 ) {
282 return Ok((sub, rg, vmss));
283 }
284
285 if !self.config.use_imds {
286 return Err(AzureDiscoveryError::Config(
287 "use_imds is false but subscription_id, resource_group, or vmss_name is missing"
288 .to_string(),
289 ));
290 }
291
292 let imds = query_imds(&self.client).await?;
293 let sub = self
294 .config
295 .subscription_id
296 .clone()
297 .unwrap_or(imds.compute.subscription_id);
298 let rg = self
299 .config
300 .resource_group
301 .clone()
302 .unwrap_or(imds.compute.resource_group_name);
303 let vmss = self.config.vmss_name.clone().or(imds.compute.vmss_name).ok_or_else(
304 || {
305 AzureDiscoveryError::ImdsError(
306 "current VM is not part of a VMSS".to_string(),
307 )
308 },
309 )?;
310
311 Ok((sub, rg, vmss))
312 }
313
314 pub async fn discover_instances(&self) -> Result<Vec<String>, AzureDiscoveryError> {
316 let (subscription_id, resource_group, vmss_name) =
317 self.resolve_vmss_info().await?;
318
319 let token = acquire_managed_identity_token(&self.client).await?;
320
321 let url = format!(
322 "https://management.azure.com/subscriptions/{subscription_id}\
323 /resourceGroups/{resource_group}\
324 /providers/Microsoft.Compute/virtualMachineScaleSets/{vmss_name}\
325 /networkInterfaces?api-version={ARM_API_VERSION_NIC}"
326 );
327
328 let mut addresses = Vec::new();
329 let mut next_url: Option<String> = Some(url);
330
331 while let Some(page_url) = next_url.take() {
332 let resp = self
333 .client
334 .get(&page_url)
335 .bearer_auth(&token)
336 .send()
337 .await?;
338
339 if !resp.status().is_success() {
340 let status = resp.status();
341 let body = resp.text().await.unwrap_or_default();
342 return Err(AzureDiscoveryError::ArmApiError(format!(
343 "list NICs failed ({status}): {body}"
344 )));
345 }
346
347 let page: ArmListResponse<ArmNetworkInterface> = resp
348 .json()
349 .await
350 .map_err(|e| AzureDiscoveryError::ParseError(format!("NIC list: {e}")))?;
351
352 for nic in &page.value {
353 for ip_config in &nic.properties.ip_configurations {
354 if let Some(ip) = &ip_config.properties.private_ip_address {
355 addresses.push(format!("{ip}:{}", self.config.port));
356 }
357 }
358 }
359
360 next_url = page.next_link;
361 }
362
363 tracing::debug!(count = addresses.len(), "VMSS discovery complete");
364 Ok(addresses)
365 }
366}
367
368#[async_trait::async_trait]
369impl ClusterDiscovery for VmssDiscovery {
370 async fn discover(&self) -> Result<Vec<dactor::DiscoveredPeer>, DiscoveryError> {
371 self.discover_instances()
372 .await
373 .map(|addrs| addrs.into_iter().map(dactor::DiscoveredPeer::from_address).collect())
374 .map_err(|e| DiscoveryError::new(e.to_string()))
375 }
376}
377
378pub struct VmssDiscoveryBuilder {
384 config: VmssDiscoveryConfig,
385}
386
387impl VmssDiscoveryBuilder {
388 pub fn port(mut self, port: u16) -> Self {
390 self.config.port = port;
391 self
392 }
393
394 pub fn use_imds(mut self, yes: bool) -> Self {
396 self.config.use_imds = yes;
397 self
398 }
399
400 pub fn subscription_id(mut self, id: &str) -> Self {
402 self.config.subscription_id = Some(id.to_string());
403 self
404 }
405
406 pub fn resource_group(mut self, rg: &str) -> Self {
408 self.config.resource_group = Some(rg.to_string());
409 self
410 }
411
412 pub fn vmss_name(mut self, name: &str) -> Self {
414 self.config.vmss_name = Some(name.to_string());
415 self
416 }
417
418 pub fn build(self) -> VmssDiscovery {
420 VmssDiscovery {
421 config: self.config,
422 client: reqwest::Client::builder().timeout(std::time::Duration::from_secs(10)).build().unwrap_or_default(),
423 }
424 }
425}
426
/// Settings for [`AzureTagDiscovery`].
#[derive(Debug, Clone)]
pub struct AzureTagConfig {
    /// Tag name a VM must carry; discovery refuses an empty key.
    pub tag_key: String,
    /// Value `tag_key` must have (exact string comparison).
    pub tag_value: String,
    /// Port appended to every discovered private IP (`"<ip>:<port>"`).
    pub port: u16,
    /// Subscription to search; `None` means "use the current VM's (IMDS)".
    pub subscription_id: Option<String>,
    /// Restrict the VM listing to this resource group; `None` searches the
    /// whole subscription.
    pub resource_group: Option<String>,
}

impl Default for AzureTagConfig {
    /// Empty tag key/value, port 9000, no explicit subscription or group.
    fn default() -> Self {
        AzureTagConfig {
            port: 9000,
            tag_key: String::default(),
            tag_value: String::default(),
            resource_group: None,
            subscription_id: None,
        }
    }
}
457
458pub struct AzureTagDiscovery {
470 config: AzureTagConfig,
471 client: reqwest::Client,
472}
473
474impl AzureTagDiscovery {
475 pub fn builder() -> AzureTagDiscoveryBuilder {
477 AzureTagDiscoveryBuilder {
478 config: AzureTagConfig::default(),
479 }
480 }
481
482 pub fn config(&self) -> &AzureTagConfig {
484 &self.config
485 }
486
487 async fn resolve_subscription(&self) -> Result<String, AzureDiscoveryError> {
489 if let Some(sub) = &self.config.subscription_id {
490 return Ok(sub.clone());
491 }
492
493 let imds = query_imds(&self.client).await?;
494 Ok(imds.compute.subscription_id)
495 }
496
497 pub async fn discover_by_tag(&self) -> Result<Vec<String>, AzureDiscoveryError> {
499 if self.config.tag_key.is_empty() {
500 return Err(AzureDiscoveryError::Config(
501 "tag_key must not be empty".to_string(),
502 ));
503 }
504
505 let subscription_id = self.resolve_subscription().await?;
506 let token = acquire_managed_identity_token(&self.client).await?;
507
508 let base_url = if let Some(rg) = &self.config.resource_group {
510 format!(
511 "https://management.azure.com/subscriptions/{subscription_id}\
512 /resourceGroups/{rg}\
513 /providers/Microsoft.Compute/virtualMachines\
514 ?api-version={ARM_API_VERSION_VM}"
515 )
516 } else {
517 format!(
518 "https://management.azure.com/subscriptions/{subscription_id}\
519 /providers/Microsoft.Compute/virtualMachines\
520 ?api-version={ARM_API_VERSION_VM}"
521 )
522 };
523
524 let mut addresses = Vec::new();
525 let mut next_url: Option<String> = Some(base_url);
526
527 while let Some(page_url) = next_url.take() {
528 let resp = self
529 .client
530 .get(&page_url)
531 .bearer_auth(&token)
532 .send()
533 .await?;
534
535 if !resp.status().is_success() {
536 let status = resp.status();
537 let body = resp.text().await.unwrap_or_default();
538 return Err(AzureDiscoveryError::ArmApiError(format!(
539 "list VMs failed ({status}): {body}"
540 )));
541 }
542
543 let page: ArmListResponse<serde_json::Value> = resp
544 .json()
545 .await
546 .map_err(|e| AzureDiscoveryError::ParseError(format!("VM list: {e}")))?;
547
548 for vm in &page.value {
549 let tags = vm.get("tags").and_then(|t| t.as_object());
551 let matches = tags
552 .and_then(|t| t.get(&self.config.tag_key))
553 .and_then(|v| v.as_str())
554 .map(|v| v == self.config.tag_value)
555 .unwrap_or(false);
556
557 if !matches {
558 continue;
559 }
560
561 if let Some(nic_id) = vm
563 .pointer("/properties/networkProfile/networkInterfaces/0/id")
564 .and_then(|v| v.as_str())
565 {
566 if let Ok(ip) = self.fetch_nic_private_ip(nic_id, &token).await {
567 addresses.push(format!("{ip}:{}", self.config.port));
568 }
569 }
570 }
571
572 next_url = page.next_link;
573 }
574
575 tracing::debug!(count = addresses.len(), "Azure tag discovery complete");
576 Ok(addresses)
577 }
578
579 async fn fetch_nic_private_ip(
581 &self,
582 nic_id: &str,
583 token: &str,
584 ) -> Result<String, AzureDiscoveryError> {
585 let url = format!(
586 "https://management.azure.com{nic_id}?api-version={ARM_API_VERSION_NIC}"
587 );
588
589 let resp = self
590 .client
591 .get(&url)
592 .bearer_auth(token)
593 .send()
594 .await?;
595
596 if !resp.status().is_success() {
597 let status = resp.status();
598 let body = resp.text().await.unwrap_or_default();
599 return Err(AzureDiscoveryError::ArmApiError(format!(
600 "get NIC failed ({status}): {body}"
601 )));
602 }
603
604 let nic: ArmNetworkInterface = resp
605 .json()
606 .await
607 .map_err(|e| AzureDiscoveryError::ParseError(format!("NIC details: {e}")))?;
608
609 nic.properties
610 .ip_configurations
611 .first()
612 .and_then(|c| c.properties.private_ip_address.clone())
613 .ok_or_else(|| {
614 AzureDiscoveryError::ArmApiError(
615 "NIC has no private IP configuration".to_string(),
616 )
617 })
618 }
619}
620
621#[async_trait::async_trait]
622impl ClusterDiscovery for AzureTagDiscovery {
623 async fn discover(&self) -> Result<Vec<dactor::DiscoveredPeer>, DiscoveryError> {
624 self.discover_by_tag()
625 .await
626 .map(|addrs| addrs.into_iter().map(dactor::DiscoveredPeer::from_address).collect())
627 .map_err(|e| DiscoveryError::new(e.to_string()))
628 }
629}
630
631pub struct AzureTagDiscoveryBuilder {
637 config: AzureTagConfig,
638}
639
640impl AzureTagDiscoveryBuilder {
641 pub fn tag_key(mut self, key: &str) -> Self {
643 self.config.tag_key = key.to_string();
644 self
645 }
646
647 pub fn tag_value(mut self, value: &str) -> Self {
649 self.config.tag_value = value.to_string();
650 self
651 }
652
653 pub fn port(mut self, port: u16) -> Self {
655 self.config.port = port;
656 self
657 }
658
659 pub fn subscription_id(mut self, id: &str) -> Self {
661 self.config.subscription_id = Some(id.to_string());
662 self
663 }
664
665 pub fn resource_group(mut self, rg: &str) -> Self {
667 self.config.resource_group = Some(rg.to_string());
668 self
669 }
670
671 pub fn build(self) -> AzureTagDiscovery {
673 AzureTagDiscovery {
674 config: self.config,
675 client: reqwest::Client::builder().timeout(std::time::Duration::from_secs(10)).build().unwrap_or_default(),
676 }
677 }
678}
679
/// This VM's private IP, taken from the `DACTOR_VM_IP` environment variable.
///
/// `None` if the variable is unset or its value is not valid Unicode.
pub fn vm_private_ip() -> Option<String> {
    std::env::var_os("DACTOR_VM_IP").and_then(|raw| raw.into_string().ok())
}

/// Subscription ID from the `AZURE_SUBSCRIPTION_ID` environment variable.
///
/// `None` if the variable is unset or its value is not valid Unicode.
pub fn subscription_id() -> Option<String> {
    std::env::var_os("AZURE_SUBSCRIPTION_ID").and_then(|raw| raw.into_string().ok())
}

/// Resource group from the `AZURE_RESOURCE_GROUP` environment variable.
///
/// `None` if the variable is unset or its value is not valid Unicode.
pub fn resource_group() -> Option<String> {
    std::env::var_os("AZURE_RESOURCE_GROUP").and_then(|raw| raw.into_string().ok())
}
698
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks the defaults shared by `VmssDiscoveryConfig::default()` and an
    /// untouched `VmssDiscovery::builder()`.
    fn assert_vmss_defaults(cfg: &VmssDiscoveryConfig) {
        assert_eq!(cfg.port, 9000);
        assert!(cfg.use_imds);
        assert!(cfg.subscription_id.is_none());
        assert!(cfg.resource_group.is_none());
        assert!(cfg.vmss_name.is_none());
    }

    /// Checks that `err` renders exactly as `expected` through `Display`.
    fn assert_display(err: AzureDiscoveryError, expected: &str) {
        assert_eq!(err.to_string(), expected);
    }

    #[test]
    fn vmss_builder_creates_valid_config() {
        let discovery = VmssDiscovery::builder()
            .port(8080)
            .use_imds(false)
            .subscription_id("sub-123")
            .resource_group("my-rg")
            .vmss_name("my-vmss")
            .build();
        let cfg = discovery.config();

        assert_eq!(cfg.port, 8080);
        assert!(!cfg.use_imds);
        assert_eq!(cfg.subscription_id.as_deref(), Some("sub-123"));
        assert_eq!(cfg.resource_group.as_deref(), Some("my-rg"));
        assert_eq!(cfg.vmss_name.as_deref(), Some("my-vmss"));
    }

    #[test]
    fn vmss_builder_default_values() {
        let discovery = VmssDiscovery::builder().build();
        assert_vmss_defaults(discovery.config());
    }

    #[test]
    fn vmss_default_config() {
        assert_vmss_defaults(&VmssDiscoveryConfig::default());
    }

    #[test]
    fn tag_builder_creates_valid_config() {
        let discovery = AzureTagDiscovery::builder()
            .tag_key("dactor-cluster")
            .tag_value("production")
            .port(7000)
            .subscription_id("sub-456")
            .resource_group("prod-rg")
            .build();
        let cfg = discovery.config();

        assert_eq!(cfg.tag_key, "dactor-cluster");
        assert_eq!(cfg.tag_value, "production");
        assert_eq!(cfg.port, 7000);
        assert_eq!(cfg.subscription_id.as_deref(), Some("sub-456"));
        assert_eq!(cfg.resource_group.as_deref(), Some("prod-rg"));
    }

    #[test]
    fn tag_builder_default_values() {
        let discovery = AzureTagDiscovery::builder()
            .tag_key("cluster")
            .tag_value("dev")
            .build();
        let cfg = discovery.config();

        assert_eq!(cfg.port, 9000);
        assert!(cfg.subscription_id.is_none());
        assert!(cfg.resource_group.is_none());
    }

    #[test]
    fn tag_default_config() {
        let cfg = AzureTagConfig::default();

        assert!(cfg.tag_key.is_empty());
        assert!(cfg.tag_value.is_empty());
        assert_eq!(cfg.port, 9000);
        assert!(cfg.subscription_id.is_none());
        assert!(cfg.resource_group.is_none());
    }

    // Each env test clears a different variable, so running them in
    // parallel does not let one test disturb another.
    #[test]
    fn vm_private_ip_returns_none_outside_azure() {
        std::env::remove_var("DACTOR_VM_IP");
        assert!(vm_private_ip().is_none());
    }

    #[test]
    fn subscription_id_returns_none_outside_azure() {
        std::env::remove_var("AZURE_SUBSCRIPTION_ID");
        assert!(subscription_id().is_none());
    }

    #[test]
    fn resource_group_returns_none_outside_azure() {
        std::env::remove_var("AZURE_RESOURCE_GROUP");
        assert!(resource_group().is_none());
    }

    #[test]
    fn error_display_imds() {
        assert_display(
            AzureDiscoveryError::ImdsError("timeout".to_string()),
            "IMDS error: timeout",
        );
    }

    #[test]
    fn error_display_arm_api() {
        assert_display(
            AzureDiscoveryError::ArmApiError("403 forbidden".to_string()),
            "ARM API error: 403 forbidden",
        );
    }

    #[test]
    fn error_display_parse() {
        assert_display(
            AzureDiscoveryError::ParseError("invalid json".to_string()),
            "parse error: invalid json",
        );
    }

    #[test]
    fn error_display_config() {
        assert_display(
            AzureDiscoveryError::Config("missing subscription".to_string()),
            "configuration error: missing subscription",
        );
    }

    #[test]
    fn imds_url_contains_api_version() {
        let url = imds_instance_url();
        assert!(url.starts_with("http://169.254.169.254/metadata/instance"));
        assert!(url.contains("api-version=2021-02-01"));
    }
}
857}