1use std::collections::HashMap;
2use std::sync::atomic::{AtomicBool, Ordering};
3
4use serde::Deserialize;
5
6use super::{Provider, ProviderError, ProviderHost, map_ureq_error};
7
/// Azure VM inventory provider.
///
/// Lists virtual machines in each configured subscription via Azure Resource Manager and
/// turns them into [`ProviderHost`] entries.
pub struct Azure {
    /// Subscription IDs to scan. Each one must be a UUID
    /// (`xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx`); at least one is required.
    pub subscriptions: Vec<String>,
}
11
12#[derive(Deserialize)]
15#[cfg_attr(not(test), allow(dead_code))]
16struct VmListResponse {
17 #[serde(default)]
18 value: Vec<VirtualMachine>,
19 #[serde(rename = "nextLink")]
20 next_link: Option<String>,
21}
22
23#[derive(Deserialize)]
24struct VirtualMachine {
25 name: String,
26 #[serde(default)]
27 location: String,
28 #[serde(default)]
29 tags: Option<HashMap<String, String>>,
30 #[serde(default)]
31 properties: VmProperties,
32}
33
34#[derive(Deserialize, Default)]
35struct VmProperties {
36 #[serde(rename = "vmId", default)]
37 vm_id: String,
38 #[serde(rename = "hardwareProfile")]
39 hardware_profile: Option<HardwareProfile>,
40 #[serde(rename = "storageProfile")]
41 storage_profile: Option<StorageProfile>,
42 #[serde(rename = "networkProfile")]
43 network_profile: Option<NetworkProfile>,
44 #[serde(rename = "instanceView")]
45 instance_view: Option<InstanceView>,
46}
47
48#[derive(Deserialize)]
49struct HardwareProfile {
50 #[serde(rename = "vmSize")]
51 vm_size: String,
52}
53
54#[derive(Deserialize)]
55struct StorageProfile {
56 #[serde(rename = "imageReference")]
57 image_reference: Option<ImageReference>,
58}
59
60#[derive(Deserialize)]
61struct ImageReference {
62 offer: Option<String>,
63 sku: Option<String>,
64 #[allow(dead_code)]
65 id: Option<String>,
66}
67
68#[derive(Deserialize)]
69struct NetworkProfile {
70 #[serde(rename = "networkInterfaces", default)]
71 network_interfaces: Vec<NetworkInterfaceRef>,
72}
73
74#[derive(Deserialize)]
75struct NetworkInterfaceRef {
76 id: String,
77 properties: Option<NicRefProperties>,
78}
79
80#[derive(Deserialize)]
81struct NicRefProperties {
82 primary: Option<bool>,
83}
84
85#[derive(Deserialize)]
86struct InstanceView {
87 #[serde(default)]
88 statuses: Vec<InstanceViewStatus>,
89}
90
91#[derive(Deserialize)]
92struct InstanceViewStatus {
93 code: String,
94}
95
96#[derive(Deserialize)]
99#[cfg_attr(not(test), allow(dead_code))]
100struct NicListResponse {
101 #[serde(default)]
102 value: Vec<Nic>,
103 #[serde(rename = "nextLink")]
104 #[allow(dead_code)]
105 next_link: Option<String>,
106}
107
108#[derive(Deserialize)]
109struct Nic {
110 id: String,
111 #[serde(default)]
112 properties: NicProperties,
113}
114
115#[derive(Deserialize, Default)]
116struct NicProperties {
117 #[serde(rename = "ipConfigurations", default)]
118 ip_configurations: Vec<IpConfiguration>,
119}
120
121#[derive(Deserialize)]
122struct IpConfiguration {
123 #[serde(default)]
124 properties: IpConfigProperties,
125}
126
127#[derive(Deserialize, Default)]
128struct IpConfigProperties {
129 #[serde(rename = "privateIPAddress")]
130 private_ip_address: Option<String>,
131 #[serde(rename = "publicIPAddress")]
132 public_ip_address: Option<PublicIpRef>,
133 primary: Option<bool>,
134}
135
136#[derive(Deserialize)]
137struct PublicIpRef {
138 id: String,
139}
140
141#[derive(Deserialize)]
144#[cfg_attr(not(test), allow(dead_code))]
145struct PublicIpListResponse {
146 #[serde(default)]
147 value: Vec<PublicIp>,
148 #[serde(rename = "nextLink")]
149 #[allow(dead_code)]
150 next_link: Option<String>,
151}
152
153#[derive(Deserialize)]
154struct PublicIp {
155 id: String,
156 #[serde(default)]
157 properties: PublicIpProperties,
158}
159
160#[derive(Deserialize, Default)]
161struct PublicIpProperties {
162 #[serde(rename = "ipAddress")]
163 ip_address: Option<String>,
164}
165
166#[derive(Deserialize)]
172struct ServicePrincipal {
173 #[serde(alias = "tenantId", alias = "tenant")]
174 tenant_id: String,
175 #[serde(alias = "clientId", alias = "appId")]
176 client_id: String,
177 #[serde(alias = "clientSecret", alias = "password")]
178 client_secret: String,
179}
180
181#[derive(Deserialize)]
182struct TokenResponse {
183 access_token: String,
184}
185
/// Returns `true` when `id` has the canonical UUID shape used for Azure subscription IDs:
/// five hyphen-separated groups of 8-4-4-4-12 hexadecimal digits (either case).
pub fn is_valid_subscription_id(id: &str) -> bool {
    const GROUP_LENGTHS: [usize; 5] = [8, 4, 4, 4, 12];

    let mut groups = id.split('-');
    let shape_ok = GROUP_LENGTHS.iter().all(|&expected| {
        groups.next().is_some_and(|group| {
            group.len() == expected && group.bytes().all(|b| b.is_ascii_hexdigit())
        })
    });

    // Any sixth group means the ID is too long.
    shape_ok && groups.next().is_none()
}
198
/// Returns `true` when the configured token looks like a path to a service-principal JSON
/// file, i.e. it ends in `.json` (ASCII case-insensitive).
fn is_sp_file(token: &str) -> bool {
    const SUFFIX: &[u8] = b".json";
    let bytes = token.as_bytes();
    // Compare raw bytes so no lowercase copy of the token is allocated.
    bytes
        .len()
        .checked_sub(SUFFIX.len())
        .is_some_and(|start| bytes[start..].eq_ignore_ascii_case(SUFFIX))
}
203
204fn resolve_sp_token(path: &str) -> Result<String, ProviderError> {
206 let content = std::fs::read_to_string(path)
207 .map_err(|e| ProviderError::Http(format!("Failed to read SP file {}: {}", path, e)))?;
208 let sp: ServicePrincipal = serde_json::from_str(&content)
209 .map_err(|e| ProviderError::Http(format!(
210 "Failed to parse SP file: {}. Expected JSON with appId/password/tenant (az CLI) or clientId/clientSecret/tenantId.", e
211 )))?;
212
213 let agent = super::http_agent();
214 let url = format!(
215 "https://login.microsoftonline.com/{}/oauth2/v2.0/token",
216 sp.tenant_id
217 );
218 let mut resp = agent
219 .post(&url)
220 .send_form([
221 ("grant_type", "client_credentials"),
222 ("client_id", sp.client_id.as_str()),
223 ("client_secret", sp.client_secret.as_str()),
224 ("scope", "https://management.azure.com/.default"),
225 ])
226 .map_err(map_ureq_error)?;
227
228 let token_resp: TokenResponse = resp
229 .body_mut()
230 .read_json()
231 .map_err(|e| ProviderError::Parse(format!("Token response: {}", e)))?;
232
233 Ok(token_resp.access_token)
234}
235
236fn resolve_token(token: &str) -> Result<String, ProviderError> {
239 if is_sp_file(token) {
240 resolve_sp_token(token)
241 } else {
242 let t = token.strip_prefix("Bearer ").unwrap_or(token);
243 if t.is_empty() {
244 return Err(ProviderError::AuthFailed);
245 }
246 Ok(t.to_string())
247 }
248}
249
250fn select_ip(
253 vm: &VirtualMachine,
254 nic_map: &HashMap<String, &Nic>,
255 public_ip_map: &HashMap<String, String>,
256) -> Option<String> {
257 let net_profile = vm.properties.network_profile.as_ref()?;
258 if net_profile.network_interfaces.is_empty() {
259 return None;
260 }
261
262 let nic_ref = net_profile
264 .network_interfaces
265 .iter()
266 .find(|n| {
267 n.properties
268 .as_ref()
269 .and_then(|p| p.primary)
270 .unwrap_or(false)
271 })
272 .or_else(|| net_profile.network_interfaces.first())?;
273
274 let nic_id_lower = nic_ref.id.to_ascii_lowercase();
275 let nic = nic_map.get(&nic_id_lower)?;
276
277 let ip_config = nic
279 .properties
280 .ip_configurations
281 .iter()
282 .find(|c| c.properties.primary.unwrap_or(false))
283 .or_else(|| nic.properties.ip_configurations.first())?;
284
285 if let Some(ref pub_ref) = ip_config.properties.public_ip_address {
287 let pub_id_lower = pub_ref.id.to_ascii_lowercase();
288 if let Some(addr) = public_ip_map.get(&pub_id_lower) {
289 if !addr.is_empty() {
290 return Some(addr.clone());
291 }
292 }
293 }
294
295 if let Some(ref private) = ip_config.properties.private_ip_address {
297 if !private.is_empty() {
298 return Some(private.clone());
299 }
300 }
301
302 None
303}
304
305fn extract_power_state(instance_view: &Option<InstanceView>) -> Option<String> {
307 let iv = instance_view.as_ref()?;
308 for status in &iv.statuses {
309 if let Some(suffix) = status.code.strip_prefix("PowerState/") {
310 return Some(suffix.to_string());
311 }
312 }
313 None
314}
315
316fn build_os_string(image_ref: &Option<ImageReference>) -> Option<String> {
318 let img = image_ref.as_ref()?;
319 let offer = img.offer.as_deref()?;
320 let sku = img.sku.as_deref()?;
321 if offer.is_empty() || sku.is_empty() {
322 return None;
323 }
324 Some(format!("{}-{}", offer, sku))
325}
326
327fn build_metadata(vm: &VirtualMachine) -> Vec<(String, String)> {
329 let mut metadata = super::ProviderMetadata::new();
330 if !vm.location.is_empty() {
331 metadata.push("region", vm.location.to_ascii_lowercase());
332 }
333 if let Some(ref hw) = vm.properties.hardware_profile {
334 if !hw.vm_size.is_empty() {
335 metadata.push("vm_size", hw.vm_size.clone());
336 }
337 }
338 if let Some(ref sp) = vm.properties.storage_profile {
339 if let Some(os) = build_os_string(&sp.image_reference) {
340 metadata.push("image", os);
341 }
342 }
343 if let Some(state) = extract_power_state(&vm.properties.instance_view) {
344 metadata.push("status", state);
345 }
346 metadata.finish()
347}
348
349fn build_tags(vm: &VirtualMachine) -> Vec<String> {
351 let mut tags = Vec::new();
352 if let Some(ref vm_tags) = vm.tags {
353 for (k, v) in vm_tags {
354 if v.is_empty() {
355 tags.push(k.clone());
356 } else {
357 tags.push(format!("{}:{}", k, v));
358 }
359 }
360 }
361 tags
362}
363
364fn fetch_paginated<T: serde::de::DeserializeOwned>(
366 agent: &ureq::Agent,
367 initial_url: &str,
368 access_token: &str,
369 api_base: &str,
370 cancel: &AtomicBool,
371 resource_name: &str,
372 progress: &dyn Fn(&str),
373) -> Result<Vec<T>, ProviderError> {
374 let mut all_items = Vec::new();
377 let mut next_url: Option<String> = Some(initial_url.to_string());
378
379 for page in 0u32.. {
380 if cancel.load(Ordering::Relaxed) {
381 return Err(ProviderError::Cancelled);
382 }
383 if page > 500 {
384 break;
385 }
386
387 let url = match next_url.take() {
388 Some(u) => u,
389 None => break,
390 };
391
392 progress(&format!(
393 "Fetching {} ({} so far)...",
394 resource_name,
395 all_items.len()
396 ));
397
398 let mut response = match agent
399 .get(&url)
400 .header("Authorization", &super::bearer_auth(access_token))
401 .call()
402 {
403 Ok(r) => r,
404 Err(e) => {
405 let err = map_ureq_error(e);
406 if matches!(err, ProviderError::AuthFailed | ProviderError::RateLimited) {
408 return Err(err);
409 }
410 if !all_items.is_empty() {
412 break;
413 }
414 return Err(err);
415 }
416 };
417
418 let body: serde_json::Value = match response.body_mut().read_json() {
419 Ok(v) => v,
420 Err(e) => {
421 if !all_items.is_empty() {
422 break;
423 }
424 return Err(ProviderError::Parse(format!(
425 "{} response: {}",
426 resource_name, e
427 )));
428 }
429 };
430
431 if let Some(value_array) = body.get("value").and_then(|v| v.as_array()) {
432 for item in value_array {
433 match serde_json::from_value(item.clone()) {
434 Ok(parsed) => all_items.push(parsed),
435 Err(_) => continue, }
437 }
438 }
439
440 next_url = body
447 .get("nextLink")
448 .and_then(|v| v.as_str())
449 .filter(|s| !s.is_empty())
450 .filter(|s| {
451 s.strip_prefix(api_base)
452 .is_some_and(|rest| rest.starts_with('/'))
453 })
454 .map(|s| s.to_string());
455 }
456
457 Ok(all_items)
458}
459
460impl Azure {
461 const API_BASE: &'static str = "https://management.azure.com";
464
465 fn fetch_with_endpoint(
470 &self,
471 api_base: &str,
472 token: &str,
473 cancel: &AtomicBool,
474 _env: &crate::runtime::env::Env,
475 progress: &dyn Fn(&str),
476 ) -> Result<Vec<ProviderHost>, ProviderError> {
477 if self.subscriptions.is_empty() {
478 return Err(ProviderError::Http(
479 "No Azure subscriptions configured. Set at least one subscription ID.".to_string(),
480 ));
481 }
482
483 for sub in &self.subscriptions {
485 if !is_valid_subscription_id(sub) {
486 return Err(ProviderError::Http(format!(
487 "Invalid subscription ID '{}'. Expected UUID format (e.g. 12345678-1234-1234-1234-123456789012).",
488 sub
489 )));
490 }
491 }
492
493 progress("Authenticating...");
494 let access_token = resolve_token(token)?;
495
496 if cancel.load(Ordering::Relaxed) {
497 return Err(ProviderError::Cancelled);
498 }
499
500 let agent = super::http_agent();
501 let mut all_hosts = Vec::new();
502 let mut failures = 0usize;
503 let total = self.subscriptions.len();
504
505 for (i, sub) in self.subscriptions.iter().enumerate() {
506 if cancel.load(Ordering::Relaxed) {
507 return Err(ProviderError::Cancelled);
508 }
509
510 progress(&format!("Subscription {}/{} ({})...", i + 1, total, sub));
511
512 match self.fetch_subscription(&agent, &access_token, sub, api_base, cancel, progress) {
513 Ok(hosts) => all_hosts.extend(hosts),
514 Err(ProviderError::Cancelled) => return Err(ProviderError::Cancelled),
515 Err(ProviderError::AuthFailed) => return Err(ProviderError::AuthFailed),
516 Err(ProviderError::RateLimited) => return Err(ProviderError::RateLimited),
517 Err(_) => {
518 failures += 1;
519 }
520 }
521 }
522
523 if failures > 0 && !all_hosts.is_empty() {
524 return Err(ProviderError::PartialResult {
525 hosts: all_hosts,
526 failures,
527 total,
528 });
529 }
530 if failures > 0 && all_hosts.is_empty() {
531 return Err(ProviderError::Http(format!(
532 "All {} subscription(s) failed.",
533 total
534 )));
535 }
536
537 progress(&format!("{} VMs", all_hosts.len()));
538 Ok(all_hosts)
539 }
540
541 fn fetch_subscription(
542 &self,
543 agent: &ureq::Agent,
544 access_token: &str,
545 subscription_id: &str,
546 api_base: &str,
547 cancel: &AtomicBool,
548 progress: &dyn Fn(&str),
549 ) -> Result<Vec<ProviderHost>, ProviderError> {
550 let vm_url = format!(
552 "{}/subscriptions/{}/providers/Microsoft.Compute/virtualMachines?api-version=2024-07-01&$expand=instanceView",
553 api_base, subscription_id
554 );
555 let vms: Vec<VirtualMachine> = fetch_paginated(
556 agent,
557 &vm_url,
558 access_token,
559 api_base,
560 cancel,
561 "VMs",
562 progress,
563 )?;
564
565 if cancel.load(Ordering::Relaxed) {
566 return Err(ProviderError::Cancelled);
567 }
568
569 let nic_url = format!(
571 "{}/subscriptions/{}/providers/Microsoft.Network/networkInterfaces?api-version=2024-05-01",
572 api_base, subscription_id
573 );
574 let nics: Vec<Nic> = fetch_paginated(
575 agent,
576 &nic_url,
577 access_token,
578 api_base,
579 cancel,
580 "NICs",
581 progress,
582 )?;
583
584 if cancel.load(Ordering::Relaxed) {
585 return Err(ProviderError::Cancelled);
586 }
587
588 let pip_url = format!(
590 "{}/subscriptions/{}/providers/Microsoft.Network/publicIPAddresses?api-version=2024-05-01",
591 api_base, subscription_id
592 );
593 let public_ips: Vec<PublicIp> = fetch_paginated(
594 agent,
595 &pip_url,
596 access_token,
597 api_base,
598 cancel,
599 "public IPs",
600 progress,
601 )?;
602
603 let nic_map: HashMap<String, &Nic> = nics
605 .iter()
606 .map(|n| (n.id.to_ascii_lowercase(), n))
607 .collect();
608
609 let public_ip_map: HashMap<String, String> = public_ips
610 .iter()
611 .filter_map(|p| {
612 p.properties
613 .ip_address
614 .as_ref()
615 .map(|addr| (p.id.to_ascii_lowercase(), addr.clone()))
616 })
617 .collect();
618
619 let mut hosts = Vec::new();
621 for vm in &vms {
622 if vm.properties.vm_id.is_empty() {
624 continue;
625 }
626 if let Some(ip) = select_ip(vm, &nic_map, &public_ip_map) {
627 hosts.push(ProviderHost {
628 server_id: vm.properties.vm_id.clone(),
629 name: vm.name.clone(),
630 ip,
631 tags: build_tags(vm),
632 metadata: build_metadata(vm),
633 });
634 }
635 }
636
637 Ok(hosts)
638 }
639}
640
641impl Provider for Azure {
642 fn name(&self) -> &str {
643 "azure"
644 }
645
646 fn short_label(&self) -> &str {
647 "az"
648 }
649
650 fn fetch_hosts_cancellable(
651 &self,
652 token: &str,
653 cancel: &AtomicBool,
654 env: &crate::runtime::env::Env,
655 ) -> Result<Vec<ProviderHost>, ProviderError> {
656 self.fetch_hosts_with_progress(token, cancel, env, &|_| {})
657 }
658
659 fn fetch_hosts_with_progress(
660 &self,
661 token: &str,
662 cancel: &AtomicBool,
663 env: &crate::runtime::env::Env,
664 progress: &dyn Fn(&str),
665 ) -> Result<Vec<ProviderHost>, ProviderError> {
666 self.fetch_with_endpoint(Self::API_BASE, token, cancel, env, progress)
667 }
668}
669
670#[cfg(test)]
671#[path = "azure_tests.rs"]
672mod tests;