Skip to main content

purple_ssh/providers/
azure.rs

1use std::collections::HashMap;
2use std::sync::atomic::{AtomicBool, Ordering};
3
4use serde::Deserialize;
5
6use super::{Provider, ProviderError, ProviderHost, map_ureq_error};
7
/// Azure Resource Manager provider that lists virtual machines across one or
/// more subscriptions.
pub struct Azure {
    /// Subscription IDs to query. Each must pass `is_valid_subscription_id`
    /// (UUID format) or fetching fails before any request is made.
    pub subscriptions: Vec<String>,
}
11
12// --- VM response models ---
13
14#[derive(Deserialize)]
15#[cfg_attr(not(test), allow(dead_code))]
16struct VmListResponse {
17    #[serde(default)]
18    value: Vec<VirtualMachine>,
19    #[serde(rename = "nextLink")]
20    next_link: Option<String>,
21}
22
/// One VM from the Compute `virtualMachines` list response.
#[derive(Deserialize)]
struct VirtualMachine {
    /// VM name. It has no serde default, so an entry without it fails to
    /// deserialize and is skipped by `fetch_paginated`.
    name: String,
    /// Region; empty when absent. Lowercased into the `region` metadata.
    #[serde(default)]
    location: String,
    /// Resource tags, converted to provider tags by `build_tags`.
    #[serde(default)]
    tags: Option<HashMap<String, String>>,
    /// Nested `properties` object; an empty default when absent.
    #[serde(default)]
    properties: VmProperties,
}
33
34#[derive(Deserialize, Default)]
35struct VmProperties {
36    #[serde(rename = "vmId", default)]
37    vm_id: String,
38    #[serde(rename = "hardwareProfile")]
39    hardware_profile: Option<HardwareProfile>,
40    #[serde(rename = "storageProfile")]
41    storage_profile: Option<StorageProfile>,
42    #[serde(rename = "networkProfile")]
43    network_profile: Option<NetworkProfile>,
44    #[serde(rename = "instanceView")]
45    instance_view: Option<InstanceView>,
46}
47
48#[derive(Deserialize)]
49struct HardwareProfile {
50    #[serde(rename = "vmSize")]
51    vm_size: String,
52}
53
54#[derive(Deserialize)]
55struct StorageProfile {
56    #[serde(rename = "imageReference")]
57    image_reference: Option<ImageReference>,
58}
59
/// Image reference from a VM's storage profile. `offer` and `sku` are combined
/// into the `image` metadata by `build_os_string`.
#[derive(Deserialize)]
struct ImageReference {
    offer: Option<String>,
    sku: Option<String>,
    // Deserialized but not read anywhere in this file, hence the allowance.
    #[allow(dead_code)]
    id: Option<String>,
}
67
68#[derive(Deserialize)]
69struct NetworkProfile {
70    #[serde(rename = "networkInterfaces", default)]
71    network_interfaces: Vec<NetworkInterfaceRef>,
72}
73
/// Reference from a VM to one of its network interfaces.
#[derive(Deserialize)]
struct NetworkInterfaceRef {
    /// NIC resource ID; `select_ip` lowercases it to look up the NIC.
    id: String,
    /// Optional properties carrying the `primary` flag.
    properties: Option<NicRefProperties>,
}
79
/// Properties of a VM's NIC reference.
#[derive(Deserialize)]
struct NicRefProperties {
    /// Whether this is the VM's primary NIC; `select_ip` treats a missing flag as false.
    primary: Option<bool>,
}
84
/// Runtime view of a VM, requested via `$expand=instanceView` in `fetch_subscription`.
#[derive(Deserialize)]
struct InstanceView {
    /// Status entries (empty when absent); `extract_power_state` scans these.
    #[serde(default)]
    statuses: Vec<InstanceViewStatus>,
}
90
91#[derive(Deserialize)]
92struct InstanceViewStatus {
93    code: String,
94}
95
96// --- NIC response models ---
97
98#[derive(Deserialize)]
99#[cfg_attr(not(test), allow(dead_code))]
100struct NicListResponse {
101    #[serde(default)]
102    value: Vec<Nic>,
103    #[serde(rename = "nextLink")]
104    #[allow(dead_code)]
105    next_link: Option<String>,
106}
107
/// A network interface from the Network `networkInterfaces` list.
#[derive(Deserialize)]
struct Nic {
    /// Resource ID; lowercased as the key of the NIC lookup map in `fetch_subscription`.
    id: String,
    /// Nested `properties` object; an empty default when absent.
    #[serde(default)]
    properties: NicProperties,
}
114
115#[derive(Deserialize, Default)]
116struct NicProperties {
117    #[serde(rename = "ipConfigurations", default)]
118    ip_configurations: Vec<IpConfiguration>,
119}
120
/// One IP configuration of a NIC.
#[derive(Deserialize)]
struct IpConfiguration {
    /// Nested `properties` object; an empty default when absent.
    #[serde(default)]
    properties: IpConfigProperties,
}
126
/// Addresses and flags of a NIC IP configuration. The JSON keys use an
/// uppercase `IP`, so they are renamed individually.
#[derive(Deserialize, Default)]
struct IpConfigProperties {
    /// Private address; `select_ip` falls back to it when no public address is usable.
    #[serde(rename = "privateIPAddress")]
    private_ip_address: Option<String>,
    /// Reference to a public IP resource; its address is looked up by resource ID.
    #[serde(rename = "publicIPAddress")]
    public_ip_address: Option<PublicIpRef>,
    /// Whether this is the NIC's primary IP configuration.
    primary: Option<bool>,
}
135
136#[derive(Deserialize)]
137struct PublicIpRef {
138    id: String,
139}
140
141// --- Public IP response models ---
142
143#[derive(Deserialize)]
144#[cfg_attr(not(test), allow(dead_code))]
145struct PublicIpListResponse {
146    #[serde(default)]
147    value: Vec<PublicIp>,
148    #[serde(rename = "nextLink")]
149    #[allow(dead_code)]
150    next_link: Option<String>,
151}
152
/// A public IP address resource from the Network `publicIPAddresses` list.
#[derive(Deserialize)]
struct PublicIp {
    /// Resource ID; lowercased as the key of the public IP lookup map.
    id: String,
    /// Nested `properties` object; an empty default when absent.
    #[serde(default)]
    properties: PublicIpProperties,
}
159
160#[derive(Deserialize, Default)]
161struct PublicIpProperties {
162    #[serde(rename = "ipAddress")]
163    ip_address: Option<String>,
164}
165
166// --- Auth models ---
167
/// Service principal credentials. Supports two JSON formats:
/// - Azure CLI output (`az ad sp create-for-rbac`): `appId`, `password`, `tenant`
/// - Manual/portal format: `clientId`, `clientSecret`, `tenantId`
///
/// The snake_case field names (`tenant_id`, `client_id`, `client_secret`) are
/// accepted as well, and any other keys in the file are ignored. All three
/// values must be present.
#[derive(Deserialize)]
struct ServicePrincipal {
    /// Tenant, inserted into the token endpoint URL by `resolve_sp_token`.
    #[serde(alias = "tenantId", alias = "tenant")]
    tenant_id: String,
    /// Sent as `client_id` in the token request.
    #[serde(alias = "clientId", alias = "appId")]
    client_id: String,
    /// Sent as `client_secret` in the token request.
    #[serde(alias = "clientSecret", alias = "password")]
    client_secret: String,
}
180
/// Body of a successful token endpoint response. Only `access_token` is read;
/// other keys are ignored.
#[derive(Deserialize)]
struct TokenResponse {
    access_token: String,
}
185
/// Check that `id` is shaped like a UUID: exactly five dash-separated groups
/// of 8, 4, 4, 4 and 12 ASCII hex digits (either letter case).
pub fn is_valid_subscription_id(id: &str) -> bool {
    const GROUP_LENS: [usize; 5] = [8, 4, 4, 4, 12];
    let mut groups = id.split('-');
    for &want in GROUP_LENS.iter() {
        match groups.next() {
            Some(group) if group.len() == want && group.bytes().all(|b| b.is_ascii_hexdigit()) => {}
            _ => return false,
        }
    }
    // A sixth group means the ID contained more than four dashes.
    groups.next().is_none()
}
198
/// Decide whether the configured credential names a service principal JSON
/// file: true when it ends in `.json`, compared ASCII case-insensitively.
fn is_sp_file(token: &str) -> bool {
    const SUFFIX: &[u8] = b".json";
    let bytes = token.as_bytes();
    bytes.len() >= SUFFIX.len() && bytes[bytes.len() - SUFFIX.len()..].eq_ignore_ascii_case(SUFFIX)
}
203
204/// Exchange service principal credentials for an access token.
205fn resolve_sp_token(path: &str) -> Result<String, ProviderError> {
206    let content = std::fs::read_to_string(path)
207        .map_err(|e| ProviderError::Http(format!("Failed to read SP file {}: {}", path, e)))?;
208    let sp: ServicePrincipal = serde_json::from_str(&content)
209        .map_err(|e| ProviderError::Http(format!(
210            "Failed to parse SP file: {}. Expected JSON with appId/password/tenant (az CLI) or clientId/clientSecret/tenantId.", e
211        )))?;
212
213    let agent = super::http_agent();
214    let url = format!(
215        "https://login.microsoftonline.com/{}/oauth2/v2.0/token",
216        sp.tenant_id
217    );
218    let mut resp = agent
219        .post(&url)
220        .send_form([
221            ("grant_type", "client_credentials"),
222            ("client_id", sp.client_id.as_str()),
223            ("client_secret", sp.client_secret.as_str()),
224            ("scope", "https://management.azure.com/.default"),
225        ])
226        .map_err(map_ureq_error)?;
227
228    let token_resp: TokenResponse = resp
229        .body_mut()
230        .read_json()
231        .map_err(|e| ProviderError::Parse(format!("Token response: {}", e)))?;
232
233    Ok(token_resp.access_token)
234}
235
236/// Resolve token: if it's a path to a SP JSON file, exchange it for an access token.
237/// Otherwise, use it as a raw access token. Strips "Bearer " prefix if present.
238fn resolve_token(token: &str) -> Result<String, ProviderError> {
239    if is_sp_file(token) {
240        resolve_sp_token(token)
241    } else {
242        let t = token.strip_prefix("Bearer ").unwrap_or(token);
243        if t.is_empty() {
244            return Err(ProviderError::AuthFailed);
245        }
246        Ok(t.to_string())
247    }
248}
249
250/// Select the best IP for a VM by looking up its primary NIC and IP configuration.
251/// Priority: public IP > private IP > None.
252fn select_ip(
253    vm: &VirtualMachine,
254    nic_map: &HashMap<String, &Nic>,
255    public_ip_map: &HashMap<String, String>,
256) -> Option<String> {
257    let net_profile = vm.properties.network_profile.as_ref()?;
258    if net_profile.network_interfaces.is_empty() {
259        return None;
260    }
261
262    // Find primary NIC, fallback to first
263    let nic_ref = net_profile
264        .network_interfaces
265        .iter()
266        .find(|n| {
267            n.properties
268                .as_ref()
269                .and_then(|p| p.primary)
270                .unwrap_or(false)
271        })
272        .or_else(|| net_profile.network_interfaces.first())?;
273
274    let nic_id_lower = nic_ref.id.to_ascii_lowercase();
275    let nic = nic_map.get(&nic_id_lower)?;
276
277    // Find primary IP config, fallback to first
278    let ip_config = nic
279        .properties
280        .ip_configurations
281        .iter()
282        .find(|c| c.properties.primary.unwrap_or(false))
283        .or_else(|| nic.properties.ip_configurations.first())?;
284
285    // Try public IP first
286    if let Some(ref pub_ref) = ip_config.properties.public_ip_address {
287        let pub_id_lower = pub_ref.id.to_ascii_lowercase();
288        if let Some(addr) = public_ip_map.get(&pub_id_lower) {
289            if !addr.is_empty() {
290                return Some(addr.clone());
291            }
292        }
293    }
294
295    // Fallback to private IP
296    if let Some(ref private) = ip_config.properties.private_ip_address {
297        if !private.is_empty() {
298            return Some(private.clone());
299        }
300    }
301
302    None
303}
304
305/// Extract power state from instanceView statuses.
306fn extract_power_state(instance_view: &Option<InstanceView>) -> Option<String> {
307    let iv = instance_view.as_ref()?;
308    for status in &iv.statuses {
309        if let Some(suffix) = status.code.strip_prefix("PowerState/") {
310            return Some(suffix.to_string());
311        }
312    }
313    None
314}
315
316/// Build OS string from image reference: "{offer}-{sku}".
317fn build_os_string(image_ref: &Option<ImageReference>) -> Option<String> {
318    let img = image_ref.as_ref()?;
319    let offer = img.offer.as_deref()?;
320    let sku = img.sku.as_deref()?;
321    if offer.is_empty() || sku.is_empty() {
322        return None;
323    }
324    Some(format!("{}-{}", offer, sku))
325}
326
327/// Build metadata key-value pairs for a VM.
328fn build_metadata(vm: &VirtualMachine) -> Vec<(String, String)> {
329    let mut metadata = Vec::new();
330    if !vm.location.is_empty() {
331        metadata.push(("region".to_string(), vm.location.to_ascii_lowercase()));
332    }
333    if let Some(ref hw) = vm.properties.hardware_profile {
334        if !hw.vm_size.is_empty() {
335            metadata.push(("vm_size".to_string(), hw.vm_size.clone()));
336        }
337    }
338    if let Some(ref sp) = vm.properties.storage_profile {
339        if let Some(os) = build_os_string(&sp.image_reference) {
340            metadata.push(("image".to_string(), os));
341        }
342    }
343    if let Some(state) = extract_power_state(&vm.properties.instance_view) {
344        metadata.push(("status".to_string(), state));
345    }
346    metadata
347}
348
349/// Build tags from Azure VM tags (key:value map).
350fn build_tags(vm: &VirtualMachine) -> Vec<String> {
351    let mut tags = Vec::new();
352    if let Some(ref vm_tags) = vm.tags {
353        for (k, v) in vm_tags {
354            if v.is_empty() {
355                tags.push(k.clone());
356            } else {
357                tags.push(format!("{}:{}", k, v));
358            }
359        }
360    }
361    tags
362}
363
364/// Fetch a paginated Azure API list endpoint. Returns the deserialized items.
365fn fetch_paginated<T: serde::de::DeserializeOwned>(
366    agent: &ureq::Agent,
367    initial_url: &str,
368    access_token: &str,
369    cancel: &AtomicBool,
370    resource_name: &str,
371    progress: &dyn Fn(&str),
372) -> Result<Vec<T>, ProviderError> {
373    // We need to deserialize a response that has `value: Vec<T>` and `nextLink: Option<String>`.
374    // Since we can't use generics with serde easily, we'll use serde_json::Value.
375    let mut all_items = Vec::new();
376    let mut next_url: Option<String> = Some(initial_url.to_string());
377
378    for page in 0u32.. {
379        if cancel.load(Ordering::Relaxed) {
380            return Err(ProviderError::Cancelled);
381        }
382        if page > 500 {
383            break;
384        }
385
386        let url = match next_url.take() {
387            Some(u) => u,
388            None => break,
389        };
390
391        progress(&format!(
392            "Fetching {} ({} so far)...",
393            resource_name,
394            all_items.len()
395        ));
396
397        let mut response = match agent
398            .get(&url)
399            .header("Authorization", &format!("Bearer {}", access_token))
400            .call()
401        {
402            Ok(r) => r,
403            Err(e) => {
404                let err = map_ureq_error(e);
405                // AuthFailed and RateLimited always propagate immediately
406                if matches!(err, ProviderError::AuthFailed | ProviderError::RateLimited) {
407                    return Err(err);
408                }
409                // On later pages, return what we have so far instead of losing it
410                if !all_items.is_empty() {
411                    break;
412                }
413                return Err(err);
414            }
415        };
416
417        let body: serde_json::Value = match response.body_mut().read_json() {
418            Ok(v) => v,
419            Err(e) => {
420                if !all_items.is_empty() {
421                    break;
422                }
423                return Err(ProviderError::Parse(format!(
424                    "{} response: {}",
425                    resource_name, e
426                )));
427            }
428        };
429
430        if let Some(value_array) = body.get("value").and_then(|v| v.as_array()) {
431            for item in value_array {
432                match serde_json::from_value(item.clone()) {
433                    Ok(parsed) => all_items.push(parsed),
434                    Err(_) => continue, // skip malformed items
435                }
436            }
437        }
438
439        next_url = body
440            .get("nextLink")
441            .and_then(|v| v.as_str())
442            .filter(|s| !s.is_empty())
443            .filter(|s| s.starts_with("https://management.azure.com/"))
444            .map(|s| s.to_string());
445    }
446
447    Ok(all_items)
448}
449
450impl Provider for Azure {
451    fn name(&self) -> &str {
452        "azure"
453    }
454
455    fn short_label(&self) -> &str {
456        "az"
457    }
458
459    fn fetch_hosts_cancellable(
460        &self,
461        token: &str,
462        cancel: &AtomicBool,
463    ) -> Result<Vec<ProviderHost>, ProviderError> {
464        self.fetch_hosts_with_progress(token, cancel, &|_| {})
465    }
466
467    fn fetch_hosts_with_progress(
468        &self,
469        token: &str,
470        cancel: &AtomicBool,
471        progress: &dyn Fn(&str),
472    ) -> Result<Vec<ProviderHost>, ProviderError> {
473        if self.subscriptions.is_empty() {
474            return Err(ProviderError::Http(
475                "No Azure subscriptions configured. Set at least one subscription ID.".to_string(),
476            ));
477        }
478
479        // Validate subscription ID format (UUID: 8-4-4-4-12 hex chars)
480        for sub in &self.subscriptions {
481            if !is_valid_subscription_id(sub) {
482                return Err(ProviderError::Http(format!(
483                    "Invalid subscription ID '{}'. Expected UUID format (e.g. 12345678-1234-1234-1234-123456789012).",
484                    sub
485                )));
486            }
487        }
488
489        progress("Authenticating...");
490        let access_token = resolve_token(token)?;
491
492        if cancel.load(Ordering::Relaxed) {
493            return Err(ProviderError::Cancelled);
494        }
495
496        let agent = super::http_agent();
497        let mut all_hosts = Vec::new();
498        let mut failures = 0usize;
499        let total = self.subscriptions.len();
500
501        for (i, sub) in self.subscriptions.iter().enumerate() {
502            if cancel.load(Ordering::Relaxed) {
503                return Err(ProviderError::Cancelled);
504            }
505
506            progress(&format!("Subscription {}/{} ({})...", i + 1, total, sub));
507
508            match self.fetch_subscription(&agent, &access_token, sub, cancel, progress) {
509                Ok(hosts) => all_hosts.extend(hosts),
510                Err(ProviderError::Cancelled) => return Err(ProviderError::Cancelled),
511                Err(ProviderError::AuthFailed) => return Err(ProviderError::AuthFailed),
512                Err(ProviderError::RateLimited) => return Err(ProviderError::RateLimited),
513                Err(_) => {
514                    failures += 1;
515                }
516            }
517        }
518
519        if failures > 0 && !all_hosts.is_empty() {
520            return Err(ProviderError::PartialResult {
521                hosts: all_hosts,
522                failures,
523                total,
524            });
525        }
526        if failures > 0 && all_hosts.is_empty() {
527            return Err(ProviderError::Http(format!(
528                "All {} subscription(s) failed.",
529                total
530            )));
531        }
532
533        progress(&format!("{} VMs", all_hosts.len()));
534        Ok(all_hosts)
535    }
536}
537
538impl Azure {
539    fn fetch_subscription(
540        &self,
541        agent: &ureq::Agent,
542        access_token: &str,
543        subscription_id: &str,
544        cancel: &AtomicBool,
545        progress: &dyn Fn(&str),
546    ) -> Result<Vec<ProviderHost>, ProviderError> {
547        // 1. Fetch all VMs (with instanceView expanded for power state)
548        let vm_url = format!(
549            "https://management.azure.com/subscriptions/{}/providers/Microsoft.Compute/virtualMachines?api-version=2024-07-01&$expand=instanceView",
550            subscription_id
551        );
552        let vms: Vec<VirtualMachine> =
553            fetch_paginated(agent, &vm_url, access_token, cancel, "VMs", progress)?;
554
555        if cancel.load(Ordering::Relaxed) {
556            return Err(ProviderError::Cancelled);
557        }
558
559        // 2. Fetch all NICs
560        let nic_url = format!(
561            "https://management.azure.com/subscriptions/{}/providers/Microsoft.Network/networkInterfaces?api-version=2024-05-01",
562            subscription_id
563        );
564        let nics: Vec<Nic> =
565            fetch_paginated(agent, &nic_url, access_token, cancel, "NICs", progress)?;
566
567        if cancel.load(Ordering::Relaxed) {
568            return Err(ProviderError::Cancelled);
569        }
570
571        // 3. Fetch all public IPs
572        let pip_url = format!(
573            "https://management.azure.com/subscriptions/{}/providers/Microsoft.Network/publicIPAddresses?api-version=2024-05-01",
574            subscription_id
575        );
576        let public_ips: Vec<PublicIp> = fetch_paginated(
577            agent,
578            &pip_url,
579            access_token,
580            cancel,
581            "public IPs",
582            progress,
583        )?;
584
585        // Build lookup maps (case-insensitive Azure resource IDs)
586        let nic_map: HashMap<String, &Nic> = nics
587            .iter()
588            .map(|n| (n.id.to_ascii_lowercase(), n))
589            .collect();
590
591        let public_ip_map: HashMap<String, String> = public_ips
592            .iter()
593            .filter_map(|p| {
594                p.properties
595                    .ip_address
596                    .as_ref()
597                    .map(|addr| (p.id.to_ascii_lowercase(), addr.clone()))
598            })
599            .collect();
600
601        // 4. Join: VM -> NIC -> public IP
602        let mut hosts = Vec::new();
603        for vm in &vms {
604            // Skip VMs with empty vm_id (would collide in sync engine)
605            if vm.properties.vm_id.is_empty() {
606                continue;
607            }
608            if let Some(ip) = select_ip(vm, &nic_map, &public_ip_map) {
609                hosts.push(ProviderHost {
610                    server_id: vm.properties.vm_id.clone(),
611                    name: vm.name.clone(),
612                    ip,
613                    tags: build_tags(vm),
614                    metadata: build_metadata(vm),
615                });
616            }
617        }
618
619        Ok(hosts)
620    }
621}
622
// Unit tests for this provider live in the sibling file `azure_tests.rs` and
// are compiled only for test builds.
#[cfg(test)]
#[path = "azure_tests.rs"]
mod tests;