1use std::collections::HashMap;
2use std::sync::atomic::{AtomicBool, Ordering};
3
4use serde::Deserialize;
5
6use super::{Provider, ProviderError, ProviderHost, map_ureq_error};
7
/// Azure provider configuration.
///
/// Holds the subscription IDs whose virtual machines are enumerated by the
/// `Provider` implementation in this file.
pub struct Azure {
    /// Subscription IDs to query. Each one is checked with
    /// `is_valid_subscription_id` before any request is made.
    pub subscriptions: Vec<String>,
}
11
12#[derive(Deserialize)]
15#[cfg_attr(not(test), allow(dead_code))]
16struct VmListResponse {
17 #[serde(default)]
18 value: Vec<VirtualMachine>,
19 #[serde(rename = "nextLink")]
20 next_link: Option<String>,
21}
22
23#[derive(Deserialize)]
24struct VirtualMachine {
25 name: String,
26 #[serde(default)]
27 location: String,
28 #[serde(default)]
29 tags: Option<HashMap<String, String>>,
30 #[serde(default)]
31 properties: VmProperties,
32}
33
34#[derive(Deserialize, Default)]
35struct VmProperties {
36 #[serde(rename = "vmId", default)]
37 vm_id: String,
38 #[serde(rename = "hardwareProfile")]
39 hardware_profile: Option<HardwareProfile>,
40 #[serde(rename = "storageProfile")]
41 storage_profile: Option<StorageProfile>,
42 #[serde(rename = "networkProfile")]
43 network_profile: Option<NetworkProfile>,
44 #[serde(rename = "instanceView")]
45 instance_view: Option<InstanceView>,
46}
47
48#[derive(Deserialize)]
49struct HardwareProfile {
50 #[serde(rename = "vmSize")]
51 vm_size: String,
52}
53
54#[derive(Deserialize)]
55struct StorageProfile {
56 #[serde(rename = "imageReference")]
57 image_reference: Option<ImageReference>,
58}
59
60#[derive(Deserialize)]
61struct ImageReference {
62 offer: Option<String>,
63 sku: Option<String>,
64 #[allow(dead_code)]
65 id: Option<String>,
66}
67
68#[derive(Deserialize)]
69struct NetworkProfile {
70 #[serde(rename = "networkInterfaces", default)]
71 network_interfaces: Vec<NetworkInterfaceRef>,
72}
73
74#[derive(Deserialize)]
75struct NetworkInterfaceRef {
76 id: String,
77 properties: Option<NicRefProperties>,
78}
79
80#[derive(Deserialize)]
81struct NicRefProperties {
82 primary: Option<bool>,
83}
84
85#[derive(Deserialize)]
86struct InstanceView {
87 #[serde(default)]
88 statuses: Vec<InstanceViewStatus>,
89}
90
91#[derive(Deserialize)]
92struct InstanceViewStatus {
93 code: String,
94}
95
96#[derive(Deserialize)]
99#[cfg_attr(not(test), allow(dead_code))]
100struct NicListResponse {
101 #[serde(default)]
102 value: Vec<Nic>,
103 #[serde(rename = "nextLink")]
104 #[allow(dead_code)]
105 next_link: Option<String>,
106}
107
108#[derive(Deserialize)]
109struct Nic {
110 id: String,
111 #[serde(default)]
112 properties: NicProperties,
113}
114
115#[derive(Deserialize, Default)]
116struct NicProperties {
117 #[serde(rename = "ipConfigurations", default)]
118 ip_configurations: Vec<IpConfiguration>,
119}
120
121#[derive(Deserialize)]
122struct IpConfiguration {
123 #[serde(default)]
124 properties: IpConfigProperties,
125}
126
127#[derive(Deserialize, Default)]
128struct IpConfigProperties {
129 #[serde(rename = "privateIPAddress")]
130 private_ip_address: Option<String>,
131 #[serde(rename = "publicIPAddress")]
132 public_ip_address: Option<PublicIpRef>,
133 primary: Option<bool>,
134}
135
136#[derive(Deserialize)]
137struct PublicIpRef {
138 id: String,
139}
140
141#[derive(Deserialize)]
144#[cfg_attr(not(test), allow(dead_code))]
145struct PublicIpListResponse {
146 #[serde(default)]
147 value: Vec<PublicIp>,
148 #[serde(rename = "nextLink")]
149 #[allow(dead_code)]
150 next_link: Option<String>,
151}
152
153#[derive(Deserialize)]
154struct PublicIp {
155 id: String,
156 #[serde(default)]
157 properties: PublicIpProperties,
158}
159
160#[derive(Deserialize, Default)]
161struct PublicIpProperties {
162 #[serde(rename = "ipAddress")]
163 ip_address: Option<String>,
164}
165
166#[derive(Deserialize)]
172struct ServicePrincipal {
173 #[serde(alias = "tenantId", alias = "tenant")]
174 tenant_id: String,
175 #[serde(alias = "clientId", alias = "appId")]
176 client_id: String,
177 #[serde(alias = "clientSecret", alias = "password")]
178 client_secret: String,
179}
180
181#[derive(Deserialize)]
182struct TokenResponse {
183 access_token: String,
184}
185
/// Returns `true` when `id` has the textual UUID shape `8-4-4-4-12`.
///
/// The string must consist of exactly five `-`-separated groups with those
/// lengths, each made up solely of ASCII hex digits (either case).
pub fn is_valid_subscription_id(id: &str) -> bool {
    const GROUP_LENS: [usize; 5] = [8, 4, 4, 4, 12];

    let mut groups = id.split('-');
    for &expected in GROUP_LENS.iter() {
        match groups.next() {
            Some(group)
                if group.len() == expected && group.bytes().all(|b| b.is_ascii_hexdigit()) => {}
            _ => return false,
        }
    }
    // A sixth group means there were too many dashes.
    groups.next().is_none()
}
198
/// Returns `true` when `token` ends in `.json`, compared ASCII
/// case-insensitively.
///
/// `resolve_token` uses this to treat the value as a path to a
/// service-principal file instead of a raw token.
fn is_sp_file(token: &str) -> bool {
    const SUFFIX: &[u8] = b".json";

    // Compare on bytes so the tail slice can never split a multi-byte char.
    let raw = token.as_bytes();
    raw.len() >= SUFFIX.len() && raw[raw.len() - SUFFIX.len()..].eq_ignore_ascii_case(SUFFIX)
}
203
204fn resolve_sp_token(path: &str) -> Result<String, ProviderError> {
206 let content = std::fs::read_to_string(path)
207 .map_err(|e| ProviderError::Http(format!("Failed to read SP file {}: {}", path, e)))?;
208 let sp: ServicePrincipal = serde_json::from_str(&content)
209 .map_err(|e| ProviderError::Http(format!(
210 "Failed to parse SP file: {}. Expected JSON with appId/password/tenant (az CLI) or clientId/clientSecret/tenantId.", e
211 )))?;
212
213 let agent = super::http_agent();
214 let url = format!(
215 "https://login.microsoftonline.com/{}/oauth2/v2.0/token",
216 sp.tenant_id
217 );
218 let mut resp = agent
219 .post(&url)
220 .send_form([
221 ("grant_type", "client_credentials"),
222 ("client_id", sp.client_id.as_str()),
223 ("client_secret", sp.client_secret.as_str()),
224 ("scope", "https://management.azure.com/.default"),
225 ])
226 .map_err(map_ureq_error)?;
227
228 let token_resp: TokenResponse = resp
229 .body_mut()
230 .read_json()
231 .map_err(|e| ProviderError::Parse(format!("Token response: {}", e)))?;
232
233 Ok(token_resp.access_token)
234}
235
236fn resolve_token(token: &str) -> Result<String, ProviderError> {
239 if is_sp_file(token) {
240 resolve_sp_token(token)
241 } else {
242 let t = token.strip_prefix("Bearer ").unwrap_or(token);
243 if t.is_empty() {
244 return Err(ProviderError::AuthFailed);
245 }
246 Ok(t.to_string())
247 }
248}
249
250fn select_ip(
253 vm: &VirtualMachine,
254 nic_map: &HashMap<String, &Nic>,
255 public_ip_map: &HashMap<String, String>,
256) -> Option<String> {
257 let net_profile = vm.properties.network_profile.as_ref()?;
258 if net_profile.network_interfaces.is_empty() {
259 return None;
260 }
261
262 let nic_ref = net_profile
264 .network_interfaces
265 .iter()
266 .find(|n| {
267 n.properties
268 .as_ref()
269 .and_then(|p| p.primary)
270 .unwrap_or(false)
271 })
272 .or_else(|| net_profile.network_interfaces.first())?;
273
274 let nic_id_lower = nic_ref.id.to_ascii_lowercase();
275 let nic = nic_map.get(&nic_id_lower)?;
276
277 let ip_config = nic
279 .properties
280 .ip_configurations
281 .iter()
282 .find(|c| c.properties.primary.unwrap_or(false))
283 .or_else(|| nic.properties.ip_configurations.first())?;
284
285 if let Some(ref pub_ref) = ip_config.properties.public_ip_address {
287 let pub_id_lower = pub_ref.id.to_ascii_lowercase();
288 if let Some(addr) = public_ip_map.get(&pub_id_lower) {
289 if !addr.is_empty() {
290 return Some(addr.clone());
291 }
292 }
293 }
294
295 if let Some(ref private) = ip_config.properties.private_ip_address {
297 if !private.is_empty() {
298 return Some(private.clone());
299 }
300 }
301
302 None
303}
304
305fn extract_power_state(instance_view: &Option<InstanceView>) -> Option<String> {
307 let iv = instance_view.as_ref()?;
308 for status in &iv.statuses {
309 if let Some(suffix) = status.code.strip_prefix("PowerState/") {
310 return Some(suffix.to_string());
311 }
312 }
313 None
314}
315
316fn build_os_string(image_ref: &Option<ImageReference>) -> Option<String> {
318 let img = image_ref.as_ref()?;
319 let offer = img.offer.as_deref()?;
320 let sku = img.sku.as_deref()?;
321 if offer.is_empty() || sku.is_empty() {
322 return None;
323 }
324 Some(format!("{}-{}", offer, sku))
325}
326
327fn build_metadata(vm: &VirtualMachine) -> Vec<(String, String)> {
329 let mut metadata = Vec::new();
330 if !vm.location.is_empty() {
331 metadata.push(("region".to_string(), vm.location.to_ascii_lowercase()));
332 }
333 if let Some(ref hw) = vm.properties.hardware_profile {
334 if !hw.vm_size.is_empty() {
335 metadata.push(("vm_size".to_string(), hw.vm_size.clone()));
336 }
337 }
338 if let Some(ref sp) = vm.properties.storage_profile {
339 if let Some(os) = build_os_string(&sp.image_reference) {
340 metadata.push(("image".to_string(), os));
341 }
342 }
343 if let Some(state) = extract_power_state(&vm.properties.instance_view) {
344 metadata.push(("status".to_string(), state));
345 }
346 metadata
347}
348
349fn build_tags(vm: &VirtualMachine) -> Vec<String> {
351 let mut tags = Vec::new();
352 if let Some(ref vm_tags) = vm.tags {
353 for (k, v) in vm_tags {
354 if v.is_empty() {
355 tags.push(k.clone());
356 } else {
357 tags.push(format!("{}:{}", k, v));
358 }
359 }
360 }
361 tags
362}
363
364fn fetch_paginated<T: serde::de::DeserializeOwned>(
366 agent: &ureq::Agent,
367 initial_url: &str,
368 access_token: &str,
369 cancel: &AtomicBool,
370 resource_name: &str,
371 progress: &dyn Fn(&str),
372) -> Result<Vec<T>, ProviderError> {
373 let mut all_items = Vec::new();
376 let mut next_url: Option<String> = Some(initial_url.to_string());
377
378 for page in 0u32.. {
379 if cancel.load(Ordering::Relaxed) {
380 return Err(ProviderError::Cancelled);
381 }
382 if page > 500 {
383 break;
384 }
385
386 let url = match next_url.take() {
387 Some(u) => u,
388 None => break,
389 };
390
391 progress(&format!(
392 "Fetching {} ({} so far)...",
393 resource_name,
394 all_items.len()
395 ));
396
397 let mut response = match agent
398 .get(&url)
399 .header("Authorization", &format!("Bearer {}", access_token))
400 .call()
401 {
402 Ok(r) => r,
403 Err(e) => {
404 let err = map_ureq_error(e);
405 if matches!(err, ProviderError::AuthFailed | ProviderError::RateLimited) {
407 return Err(err);
408 }
409 if !all_items.is_empty() {
411 break;
412 }
413 return Err(err);
414 }
415 };
416
417 let body: serde_json::Value = match response.body_mut().read_json() {
418 Ok(v) => v,
419 Err(e) => {
420 if !all_items.is_empty() {
421 break;
422 }
423 return Err(ProviderError::Parse(format!(
424 "{} response: {}",
425 resource_name, e
426 )));
427 }
428 };
429
430 if let Some(value_array) = body.get("value").and_then(|v| v.as_array()) {
431 for item in value_array {
432 match serde_json::from_value(item.clone()) {
433 Ok(parsed) => all_items.push(parsed),
434 Err(_) => continue, }
436 }
437 }
438
439 next_url = body
440 .get("nextLink")
441 .and_then(|v| v.as_str())
442 .filter(|s| !s.is_empty())
443 .filter(|s| s.starts_with("https://management.azure.com/"))
444 .map(|s| s.to_string());
445 }
446
447 Ok(all_items)
448}
449
450impl Provider for Azure {
451 fn name(&self) -> &str {
452 "azure"
453 }
454
455 fn short_label(&self) -> &str {
456 "az"
457 }
458
459 fn fetch_hosts_cancellable(
460 &self,
461 token: &str,
462 cancel: &AtomicBool,
463 ) -> Result<Vec<ProviderHost>, ProviderError> {
464 self.fetch_hosts_with_progress(token, cancel, &|_| {})
465 }
466
467 fn fetch_hosts_with_progress(
468 &self,
469 token: &str,
470 cancel: &AtomicBool,
471 progress: &dyn Fn(&str),
472 ) -> Result<Vec<ProviderHost>, ProviderError> {
473 if self.subscriptions.is_empty() {
474 return Err(ProviderError::Http(
475 "No Azure subscriptions configured. Set at least one subscription ID.".to_string(),
476 ));
477 }
478
479 for sub in &self.subscriptions {
481 if !is_valid_subscription_id(sub) {
482 return Err(ProviderError::Http(format!(
483 "Invalid subscription ID '{}'. Expected UUID format (e.g. 12345678-1234-1234-1234-123456789012).",
484 sub
485 )));
486 }
487 }
488
489 progress("Authenticating...");
490 let access_token = resolve_token(token)?;
491
492 if cancel.load(Ordering::Relaxed) {
493 return Err(ProviderError::Cancelled);
494 }
495
496 let agent = super::http_agent();
497 let mut all_hosts = Vec::new();
498 let mut failures = 0usize;
499 let total = self.subscriptions.len();
500
501 for (i, sub) in self.subscriptions.iter().enumerate() {
502 if cancel.load(Ordering::Relaxed) {
503 return Err(ProviderError::Cancelled);
504 }
505
506 progress(&format!("Subscription {}/{} ({})...", i + 1, total, sub));
507
508 match self.fetch_subscription(&agent, &access_token, sub, cancel, progress) {
509 Ok(hosts) => all_hosts.extend(hosts),
510 Err(ProviderError::Cancelled) => return Err(ProviderError::Cancelled),
511 Err(ProviderError::AuthFailed) => return Err(ProviderError::AuthFailed),
512 Err(ProviderError::RateLimited) => return Err(ProviderError::RateLimited),
513 Err(_) => {
514 failures += 1;
515 }
516 }
517 }
518
519 if failures > 0 && !all_hosts.is_empty() {
520 return Err(ProviderError::PartialResult {
521 hosts: all_hosts,
522 failures,
523 total,
524 });
525 }
526 if failures > 0 && all_hosts.is_empty() {
527 return Err(ProviderError::Http(format!(
528 "All {} subscription(s) failed.",
529 total
530 )));
531 }
532
533 progress(&format!("{} VMs", all_hosts.len()));
534 Ok(all_hosts)
535 }
536}
537
538impl Azure {
539 fn fetch_subscription(
540 &self,
541 agent: &ureq::Agent,
542 access_token: &str,
543 subscription_id: &str,
544 cancel: &AtomicBool,
545 progress: &dyn Fn(&str),
546 ) -> Result<Vec<ProviderHost>, ProviderError> {
547 let vm_url = format!(
549 "https://management.azure.com/subscriptions/{}/providers/Microsoft.Compute/virtualMachines?api-version=2024-07-01&$expand=instanceView",
550 subscription_id
551 );
552 let vms: Vec<VirtualMachine> =
553 fetch_paginated(agent, &vm_url, access_token, cancel, "VMs", progress)?;
554
555 if cancel.load(Ordering::Relaxed) {
556 return Err(ProviderError::Cancelled);
557 }
558
559 let nic_url = format!(
561 "https://management.azure.com/subscriptions/{}/providers/Microsoft.Network/networkInterfaces?api-version=2024-05-01",
562 subscription_id
563 );
564 let nics: Vec<Nic> =
565 fetch_paginated(agent, &nic_url, access_token, cancel, "NICs", progress)?;
566
567 if cancel.load(Ordering::Relaxed) {
568 return Err(ProviderError::Cancelled);
569 }
570
571 let pip_url = format!(
573 "https://management.azure.com/subscriptions/{}/providers/Microsoft.Network/publicIPAddresses?api-version=2024-05-01",
574 subscription_id
575 );
576 let public_ips: Vec<PublicIp> = fetch_paginated(
577 agent,
578 &pip_url,
579 access_token,
580 cancel,
581 "public IPs",
582 progress,
583 )?;
584
585 let nic_map: HashMap<String, &Nic> = nics
587 .iter()
588 .map(|n| (n.id.to_ascii_lowercase(), n))
589 .collect();
590
591 let public_ip_map: HashMap<String, String> = public_ips
592 .iter()
593 .filter_map(|p| {
594 p.properties
595 .ip_address
596 .as_ref()
597 .map(|addr| (p.id.to_ascii_lowercase(), addr.clone()))
598 })
599 .collect();
600
601 let mut hosts = Vec::new();
603 for vm in &vms {
604 if vm.properties.vm_id.is_empty() {
606 continue;
607 }
608 if let Some(ip) = select_ip(vm, &nic_map, &public_ip_map) {
609 hosts.push(ProviderHost {
610 server_id: vm.properties.vm_id.clone(),
611 name: vm.name.clone(),
612 ip,
613 tags: build_tags(vm),
614 metadata: build_metadata(vm),
615 });
616 }
617 }
618
619 Ok(hosts)
620 }
621}
622
623#[cfg(test)]
624#[path = "azure_tests.rs"]
625mod tests;