Skip to main content

purple_ssh/providers/
azure.rs

1use std::collections::HashMap;
2use std::sync::atomic::{AtomicBool, Ordering};
3
4use serde::Deserialize;
5
6use super::{Provider, ProviderError, ProviderHost, map_ureq_error};
7
/// Azure provider: lists virtual machines across one or more subscriptions
/// through the Azure Resource Manager REST API.
pub struct Azure {
    /// Subscription IDs to query. Each must be a UUID
    /// (checked with `is_valid_subscription_id` before any request is made).
    pub subscriptions: Vec<String>,
}
11
// --- VM response models ---
//
// Typed views of the ARM `Microsoft.Compute/virtualMachines` list payload.
// Only the fields the provider reads are modelled; serde ignores the rest.

/// One page of the VM list response (`value` plus an optional `nextLink`).
///
/// `fetch_paginated` reads pages as untyped JSON instead, so outside of tests
/// this type is unused — hence the non-test `dead_code` allowance.
#[derive(Deserialize)]
#[cfg_attr(not(test), allow(dead_code))]
struct VmListResponse {
    #[serde(default)]
    value: Vec<VirtualMachine>,
    #[serde(rename = "nextLink")]
    next_link: Option<String>,
}

/// A single virtual machine entry from the list response.
#[derive(Deserialize)]
struct VirtualMachine {
    /// VM name; becomes the host name.
    name: String,
    /// Azure region; emitted (lowercased) as `region` metadata when non-empty.
    #[serde(default)]
    location: String,
    /// Resource tags; rendered as `key` or `key:value` host tags.
    #[serde(default)]
    tags: Option<HashMap<String, String>>,
    /// Falls back to an empty `VmProperties` when the field is absent.
    #[serde(default)]
    properties: VmProperties,
}

/// The `properties` object of a VM.
#[derive(Deserialize, Default)]
struct VmProperties {
    /// Azure-assigned VM ID, used as `server_id`. VMs where this is empty are
    /// skipped during the join.
    #[serde(rename = "vmId", default)]
    vm_id: String,
    /// Source of the `vm_size` metadata.
    #[serde(rename = "hardwareProfile")]
    hardware_profile: Option<HardwareProfile>,
    /// Source of the `image` metadata.
    #[serde(rename = "storageProfile")]
    storage_profile: Option<StorageProfile>,
    /// NIC references used to resolve the host's IP address.
    #[serde(rename = "networkProfile")]
    network_profile: Option<NetworkProfile>,
    /// Included because the VM list request uses `$expand=instanceView`;
    /// source of the `status` metadata.
    #[serde(rename = "instanceView")]
    instance_view: Option<InstanceView>,
}

/// VM hardware settings.
#[derive(Deserialize)]
struct HardwareProfile {
    /// VM size name; emitted as `vm_size` metadata when non-empty.
    #[serde(rename = "vmSize")]
    vm_size: String,
}

/// VM storage settings; only the image reference is read.
#[derive(Deserialize)]
struct StorageProfile {
    #[serde(rename = "imageReference")]
    image_reference: Option<ImageReference>,
}

/// Image the VM was created from. `offer` and `sku` are combined into the
/// `image` metadata as `"{offer}-{sku}"`; `id` is not read.
#[derive(Deserialize)]
struct ImageReference {
    offer: Option<String>,
    sku: Option<String>,
    #[allow(dead_code)]
    id: Option<String>,
}

/// The VM's attached network interfaces.
#[derive(Deserialize)]
struct NetworkProfile {
    #[serde(rename = "networkInterfaces", default)]
    network_interfaces: Vec<NetworkInterfaceRef>,
}

/// Reference from a VM to one of its NICs.
#[derive(Deserialize)]
struct NetworkInterfaceRef {
    /// NIC resource ID; matched case-insensitively against `Nic::id`.
    id: String,
    properties: Option<NicRefProperties>,
}

/// Properties on a VM-to-NIC reference.
#[derive(Deserialize)]
struct NicRefProperties {
    /// Whether this is the VM's primary NIC; a missing value counts as `false`.
    primary: Option<bool>,
}

/// Runtime view of the VM (present only when expanded).
#[derive(Deserialize)]
struct InstanceView {
    #[serde(default)]
    statuses: Vec<InstanceViewStatus>,
}

/// One status entry. The power state is the entry whose code starts with
/// `PowerState/` (e.g. `PowerState/<state>`).
#[derive(Deserialize)]
struct InstanceViewStatus {
    code: String,
}
95
// --- NIC response models ---
//
// Typed views of the ARM `Microsoft.Network/networkInterfaces` list payload.

/// One page of the NIC list response.
///
/// Outside of tests this type is unused (`fetch_paginated` parses pages as
/// untyped JSON), hence the `dead_code` allowances.
#[derive(Deserialize)]
#[cfg_attr(not(test), allow(dead_code))]
struct NicListResponse {
    #[serde(default)]
    value: Vec<Nic>,
    #[serde(rename = "nextLink")]
    #[allow(dead_code)]
    next_link: Option<String>,
}

/// A network interface resource.
#[derive(Deserialize)]
struct Nic {
    /// NIC resource ID; lowercased as the key of the NIC lookup map.
    id: String,
    #[serde(default)]
    properties: NicProperties,
}

/// The `properties` object of a NIC.
#[derive(Deserialize, Default)]
struct NicProperties {
    #[serde(rename = "ipConfigurations", default)]
    ip_configurations: Vec<IpConfiguration>,
}

/// One IP configuration of a NIC.
#[derive(Deserialize)]
struct IpConfiguration {
    #[serde(default)]
    properties: IpConfigProperties,
}

/// Addresses attached to an IP configuration.
#[derive(Deserialize, Default)]
struct IpConfigProperties {
    /// Private address; used when no public address is available.
    #[serde(rename = "privateIPAddress")]
    private_ip_address: Option<String>,
    /// Reference to a public IP resource. It carries only the resource ID;
    /// the address itself is looked up in the public IP listing.
    #[serde(rename = "publicIPAddress")]
    public_ip_address: Option<PublicIpRef>,
    /// Whether this is the NIC's primary configuration; missing counts as `false`.
    primary: Option<bool>,
}

/// Resource-ID reference to a public IP address resource.
#[derive(Deserialize)]
struct PublicIpRef {
    id: String,
}
140
// --- Public IP response models ---
//
// Typed views of the ARM `Microsoft.Network/publicIPAddresses` list payload.

/// One page of the public IP list response.
///
/// Outside of tests this type is unused (`fetch_paginated` parses pages as
/// untyped JSON), hence the `dead_code` allowances.
#[derive(Deserialize)]
#[cfg_attr(not(test), allow(dead_code))]
struct PublicIpListResponse {
    #[serde(default)]
    value: Vec<PublicIp>,
    #[serde(rename = "nextLink")]
    #[allow(dead_code)]
    next_link: Option<String>,
}

/// A public IP address resource.
#[derive(Deserialize)]
struct PublicIp {
    /// Resource ID; lowercased as the key of the public IP lookup map.
    id: String,
    #[serde(default)]
    properties: PublicIpProperties,
}

/// The `properties` object of a public IP resource.
#[derive(Deserialize, Default)]
struct PublicIpProperties {
    /// The address, if any. Resources without one are left out of the lookup map.
    #[serde(rename = "ipAddress")]
    ip_address: Option<String>,
}
165
// --- Auth models ---

/// Service principal credentials. Supports two JSON formats:
/// - Azure CLI output (`az ad sp create-for-rbac`): `appId`, `password`, `tenant`
/// - Manual/portal format: `clientId`, `clientSecret`, `tenantId`
///
/// serde aliases are accepted in addition to the field's own name, so the
/// snake_case keys (`tenant_id`, `client_id`, `client_secret`) also work.
///
/// Only `Deserialize` is derived (no `Debug`), so the secret cannot be
/// printed with `{:?}`.
#[derive(Deserialize)]
struct ServicePrincipal {
    #[serde(alias = "tenantId", alias = "tenant")]
    tenant_id: String,
    #[serde(alias = "clientId", alias = "appId")]
    client_id: String,
    #[serde(alias = "clientSecret", alias = "password")]
    client_secret: String,
}

/// Successful OAuth2 token endpoint response. Only `access_token` is read;
/// other fields are ignored.
#[derive(Deserialize)]
struct TokenResponse {
    access_token: String,
}
185
/// Return `true` if `id` has the shape of an Azure subscription ID.
///
/// That is a UUID written as five hyphen-separated groups of exactly 8, 4, 4,
/// 4 and 12 hexadecimal digits (upper or lower case), with nothing else.
pub fn is_valid_subscription_id(id: &str) -> bool {
    const GROUP_LENS: [usize; 5] = [8, 4, 4, 4, 12];

    let groups: Vec<&str> = id.split('-').collect();
    if groups.len() != GROUP_LENS.len() {
        return false;
    }
    for (group, want) in groups.iter().zip(GROUP_LENS) {
        // Byte-wise check: any non-ASCII byte is rejected, so `len()` (bytes)
        // equals the digit count whenever this passes.
        if group.len() != want || !group.bytes().all(|b| b.is_ascii_hexdigit()) {
            return false;
        }
    }
    true
}
198
/// Return `true` when `token` names a service principal JSON file rather than
/// a raw access token. The test is a `.json` suffix, compared ASCII
/// case-insensitively.
fn is_sp_file(token: &str) -> bool {
    const SUFFIX: &[u8] = b".json";
    let bytes = token.as_bytes();
    // Comparing raw bytes is fine: the suffix is pure ASCII, so a multi-byte
    // UTF-8 tail can never match it.
    bytes.len() >= SUFFIX.len()
        && bytes[bytes.len() - SUFFIX.len()..].eq_ignore_ascii_case(SUFFIX)
}
203
204/// Exchange service principal credentials for an access token.
205fn resolve_sp_token(path: &str) -> Result<String, ProviderError> {
206    let content = std::fs::read_to_string(path)
207        .map_err(|e| ProviderError::Http(format!("Failed to read SP file {}: {}", path, e)))?;
208    let sp: ServicePrincipal = serde_json::from_str(&content)
209        .map_err(|e| ProviderError::Http(format!(
210            "Failed to parse SP file: {}. Expected JSON with appId/password/tenant (az CLI) or clientId/clientSecret/tenantId.", e
211        )))?;
212
213    let agent = super::http_agent();
214    let url = format!(
215        "https://login.microsoftonline.com/{}/oauth2/v2.0/token",
216        sp.tenant_id
217    );
218    let mut resp = agent
219        .post(&url)
220        .send_form([
221            ("grant_type", "client_credentials"),
222            ("client_id", sp.client_id.as_str()),
223            ("client_secret", sp.client_secret.as_str()),
224            ("scope", "https://management.azure.com/.default"),
225        ])
226        .map_err(map_ureq_error)?;
227
228    let token_resp: TokenResponse = resp
229        .body_mut()
230        .read_json()
231        .map_err(|e| ProviderError::Parse(format!("Token response: {}", e)))?;
232
233    Ok(token_resp.access_token)
234}
235
236/// Resolve token: if it's a path to a SP JSON file, exchange it for an access token.
237/// Otherwise, use it as a raw access token. Strips "Bearer " prefix if present.
238fn resolve_token(token: &str) -> Result<String, ProviderError> {
239    if is_sp_file(token) {
240        resolve_sp_token(token)
241    } else {
242        let t = token.strip_prefix("Bearer ").unwrap_or(token);
243        if t.is_empty() {
244            return Err(ProviderError::AuthFailed);
245        }
246        Ok(t.to_string())
247    }
248}
249
250/// Select the best IP for a VM by looking up its primary NIC and IP configuration.
251/// Priority: public IP > private IP > None.
252fn select_ip(
253    vm: &VirtualMachine,
254    nic_map: &HashMap<String, &Nic>,
255    public_ip_map: &HashMap<String, String>,
256) -> Option<String> {
257    let net_profile = vm.properties.network_profile.as_ref()?;
258    if net_profile.network_interfaces.is_empty() {
259        return None;
260    }
261
262    // Find primary NIC, fallback to first
263    let nic_ref = net_profile
264        .network_interfaces
265        .iter()
266        .find(|n| {
267            n.properties
268                .as_ref()
269                .and_then(|p| p.primary)
270                .unwrap_or(false)
271        })
272        .or_else(|| net_profile.network_interfaces.first())?;
273
274    let nic_id_lower = nic_ref.id.to_ascii_lowercase();
275    let nic = nic_map.get(&nic_id_lower)?;
276
277    // Find primary IP config, fallback to first
278    let ip_config = nic
279        .properties
280        .ip_configurations
281        .iter()
282        .find(|c| c.properties.primary.unwrap_or(false))
283        .or_else(|| nic.properties.ip_configurations.first())?;
284
285    // Try public IP first
286    if let Some(ref pub_ref) = ip_config.properties.public_ip_address {
287        let pub_id_lower = pub_ref.id.to_ascii_lowercase();
288        if let Some(addr) = public_ip_map.get(&pub_id_lower) {
289            if !addr.is_empty() {
290                return Some(addr.clone());
291            }
292        }
293    }
294
295    // Fallback to private IP
296    if let Some(ref private) = ip_config.properties.private_ip_address {
297        if !private.is_empty() {
298            return Some(private.clone());
299        }
300    }
301
302    None
303}
304
305/// Extract power state from instanceView statuses.
306fn extract_power_state(instance_view: &Option<InstanceView>) -> Option<String> {
307    let iv = instance_view.as_ref()?;
308    for status in &iv.statuses {
309        if let Some(suffix) = status.code.strip_prefix("PowerState/") {
310            return Some(suffix.to_string());
311        }
312    }
313    None
314}
315
316/// Build OS string from image reference: "{offer}-{sku}".
317fn build_os_string(image_ref: &Option<ImageReference>) -> Option<String> {
318    let img = image_ref.as_ref()?;
319    let offer = img.offer.as_deref()?;
320    let sku = img.sku.as_deref()?;
321    if offer.is_empty() || sku.is_empty() {
322        return None;
323    }
324    Some(format!("{}-{}", offer, sku))
325}
326
327/// Build metadata key-value pairs for a VM.
328fn build_metadata(vm: &VirtualMachine) -> Vec<(String, String)> {
329    let mut metadata = super::ProviderMetadata::new();
330    if !vm.location.is_empty() {
331        metadata.push("region", vm.location.to_ascii_lowercase());
332    }
333    if let Some(ref hw) = vm.properties.hardware_profile {
334        if !hw.vm_size.is_empty() {
335            metadata.push("vm_size", hw.vm_size.clone());
336        }
337    }
338    if let Some(ref sp) = vm.properties.storage_profile {
339        if let Some(os) = build_os_string(&sp.image_reference) {
340            metadata.push("image", os);
341        }
342    }
343    if let Some(state) = extract_power_state(&vm.properties.instance_view) {
344        metadata.push("status", state);
345    }
346    metadata.finish()
347}
348
349/// Build tags from Azure VM tags (key:value map).
350fn build_tags(vm: &VirtualMachine) -> Vec<String> {
351    let mut tags = Vec::new();
352    if let Some(ref vm_tags) = vm.tags {
353        for (k, v) in vm_tags {
354            if v.is_empty() {
355                tags.push(k.clone());
356            } else {
357                tags.push(format!("{}:{}", k, v));
358            }
359        }
360    }
361    tags
362}
363
364/// Fetch a paginated Azure API list endpoint. Returns the deserialized items.
365fn fetch_paginated<T: serde::de::DeserializeOwned>(
366    agent: &ureq::Agent,
367    initial_url: &str,
368    access_token: &str,
369    api_base: &str,
370    cancel: &AtomicBool,
371    resource_name: &str,
372    progress: &dyn Fn(&str),
373) -> Result<Vec<T>, ProviderError> {
374    // We need to deserialize a response that has `value: Vec<T>` and `nextLink: Option<String>`.
375    // Since we can't use generics with serde easily, we'll use serde_json::Value.
376    let mut all_items = Vec::new();
377    let mut next_url: Option<String> = Some(initial_url.to_string());
378
379    for page in 0u32.. {
380        if cancel.load(Ordering::Relaxed) {
381            return Err(ProviderError::Cancelled);
382        }
383        if page > 500 {
384            break;
385        }
386
387        let url = match next_url.take() {
388            Some(u) => u,
389            None => break,
390        };
391
392        progress(&format!(
393            "Fetching {} ({} so far)...",
394            resource_name,
395            all_items.len()
396        ));
397
398        let mut response = match agent
399            .get(&url)
400            .header("Authorization", &super::bearer_auth(access_token))
401            .call()
402        {
403            Ok(r) => r,
404            Err(e) => {
405                let err = map_ureq_error(e);
406                // AuthFailed and RateLimited always propagate immediately
407                if matches!(err, ProviderError::AuthFailed | ProviderError::RateLimited) {
408                    return Err(err);
409                }
410                // On later pages, return what we have so far instead of losing it
411                if !all_items.is_empty() {
412                    break;
413                }
414                return Err(err);
415            }
416        };
417
418        let body: serde_json::Value = match response.body_mut().read_json() {
419            Ok(v) => v,
420            Err(e) => {
421                if !all_items.is_empty() {
422                    break;
423                }
424                return Err(ProviderError::Parse(format!(
425                    "{} response: {}",
426                    resource_name, e
427                )));
428            }
429        };
430
431        if let Some(value_array) = body.get("value").and_then(|v| v.as_array()) {
432            for item in value_array {
433                match serde_json::from_value(item.clone()) {
434                    Ok(parsed) => all_items.push(parsed),
435                    Err(_) => continue, // skip malformed items
436                }
437            }
438        }
439
440        // Only follow nextLinks that point back at the same API host we were
441        // told to use. In production that pins follow-on requests to the real
442        // Azure management host; in tests it pins them to the mock server.
443        // The byte after `api_base` must be '/' so a look-alike host like
444        // `https://management.azure.com.evil/` cannot smuggle the Bearer token
445        // off to an attacker (`api_base` carries no trailing slash).
446        next_url = body
447            .get("nextLink")
448            .and_then(|v| v.as_str())
449            .filter(|s| !s.is_empty())
450            .filter(|s| {
451                s.strip_prefix(api_base)
452                    .is_some_and(|rest| rest.starts_with('/'))
453            })
454            .map(|s| s.to_string());
455    }
456
457    Ok(all_items)
458}
459
460impl Azure {
461    /// Real ARM management host. Overridable via `fetch_with_endpoint` so
462    /// tests can drive the VM/NIC/public-IP join against a mock server.
463    const API_BASE: &'static str = "https://management.azure.com";
464
465    /// Fetch VMs across subscriptions against an explicit management API base.
466    /// Production passes `API_BASE`; tests pass a mock URL (plus a plain bearer
467    /// token so the OAuth exchange is skipped), exercising URL construction,
468    /// the auth header, the VM/NIC/public-IP join and `ProviderHost` mapping.
469    fn fetch_with_endpoint(
470        &self,
471        api_base: &str,
472        token: &str,
473        cancel: &AtomicBool,
474        _env: &crate::runtime::env::Env,
475        progress: &dyn Fn(&str),
476    ) -> Result<Vec<ProviderHost>, ProviderError> {
477        if self.subscriptions.is_empty() {
478            return Err(ProviderError::Http(
479                "No Azure subscriptions configured. Set at least one subscription ID.".to_string(),
480            ));
481        }
482
483        // Validate subscription ID format (UUID: 8-4-4-4-12 hex chars)
484        for sub in &self.subscriptions {
485            if !is_valid_subscription_id(sub) {
486                return Err(ProviderError::Http(format!(
487                    "Invalid subscription ID '{}'. Expected UUID format (e.g. 12345678-1234-1234-1234-123456789012).",
488                    sub
489                )));
490            }
491        }
492
493        progress("Authenticating...");
494        let access_token = resolve_token(token)?;
495
496        if cancel.load(Ordering::Relaxed) {
497            return Err(ProviderError::Cancelled);
498        }
499
500        let agent = super::http_agent();
501        let mut all_hosts = Vec::new();
502        let mut failures = 0usize;
503        let total = self.subscriptions.len();
504
505        for (i, sub) in self.subscriptions.iter().enumerate() {
506            if cancel.load(Ordering::Relaxed) {
507                return Err(ProviderError::Cancelled);
508            }
509
510            progress(&format!("Subscription {}/{} ({})...", i + 1, total, sub));
511
512            match self.fetch_subscription(&agent, &access_token, sub, api_base, cancel, progress) {
513                Ok(hosts) => all_hosts.extend(hosts),
514                Err(ProviderError::Cancelled) => return Err(ProviderError::Cancelled),
515                Err(ProviderError::AuthFailed) => return Err(ProviderError::AuthFailed),
516                Err(ProviderError::RateLimited) => return Err(ProviderError::RateLimited),
517                Err(_) => {
518                    failures += 1;
519                }
520            }
521        }
522
523        if failures > 0 && !all_hosts.is_empty() {
524            return Err(ProviderError::PartialResult {
525                hosts: all_hosts,
526                failures,
527                total,
528            });
529        }
530        if failures > 0 && all_hosts.is_empty() {
531            return Err(ProviderError::Http(format!(
532                "All {} subscription(s) failed.",
533                total
534            )));
535        }
536
537        progress(&format!("{} VMs", all_hosts.len()));
538        Ok(all_hosts)
539    }
540
541    fn fetch_subscription(
542        &self,
543        agent: &ureq::Agent,
544        access_token: &str,
545        subscription_id: &str,
546        api_base: &str,
547        cancel: &AtomicBool,
548        progress: &dyn Fn(&str),
549    ) -> Result<Vec<ProviderHost>, ProviderError> {
550        // 1. Fetch all VMs (with instanceView expanded for power state)
551        let vm_url = format!(
552            "{}/subscriptions/{}/providers/Microsoft.Compute/virtualMachines?api-version=2024-07-01&$expand=instanceView",
553            api_base, subscription_id
554        );
555        let vms: Vec<VirtualMachine> = fetch_paginated(
556            agent,
557            &vm_url,
558            access_token,
559            api_base,
560            cancel,
561            "VMs",
562            progress,
563        )?;
564
565        if cancel.load(Ordering::Relaxed) {
566            return Err(ProviderError::Cancelled);
567        }
568
569        // 2. Fetch all NICs
570        let nic_url = format!(
571            "{}/subscriptions/{}/providers/Microsoft.Network/networkInterfaces?api-version=2024-05-01",
572            api_base, subscription_id
573        );
574        let nics: Vec<Nic> = fetch_paginated(
575            agent,
576            &nic_url,
577            access_token,
578            api_base,
579            cancel,
580            "NICs",
581            progress,
582        )?;
583
584        if cancel.load(Ordering::Relaxed) {
585            return Err(ProviderError::Cancelled);
586        }
587
588        // 3. Fetch all public IPs
589        let pip_url = format!(
590            "{}/subscriptions/{}/providers/Microsoft.Network/publicIPAddresses?api-version=2024-05-01",
591            api_base, subscription_id
592        );
593        let public_ips: Vec<PublicIp> = fetch_paginated(
594            agent,
595            &pip_url,
596            access_token,
597            api_base,
598            cancel,
599            "public IPs",
600            progress,
601        )?;
602
603        // Build lookup maps (case-insensitive Azure resource IDs)
604        let nic_map: HashMap<String, &Nic> = nics
605            .iter()
606            .map(|n| (n.id.to_ascii_lowercase(), n))
607            .collect();
608
609        let public_ip_map: HashMap<String, String> = public_ips
610            .iter()
611            .filter_map(|p| {
612                p.properties
613                    .ip_address
614                    .as_ref()
615                    .map(|addr| (p.id.to_ascii_lowercase(), addr.clone()))
616            })
617            .collect();
618
619        // 4. Join: VM -> NIC -> public IP
620        let mut hosts = Vec::new();
621        for vm in &vms {
622            // Skip VMs with empty vm_id (would collide in sync engine)
623            if vm.properties.vm_id.is_empty() {
624                continue;
625            }
626            if let Some(ip) = select_ip(vm, &nic_map, &public_ip_map) {
627                hosts.push(ProviderHost {
628                    server_id: vm.properties.vm_id.clone(),
629                    name: vm.name.clone(),
630                    ip,
631                    tags: build_tags(vm),
632                    metadata: build_metadata(vm),
633                });
634            }
635        }
636
637        Ok(hosts)
638    }
639}
640
641impl Provider for Azure {
642    fn name(&self) -> &str {
643        "azure"
644    }
645
646    fn short_label(&self) -> &str {
647        "az"
648    }
649
650    fn fetch_hosts_cancellable(
651        &self,
652        token: &str,
653        cancel: &AtomicBool,
654        env: &crate::runtime::env::Env,
655    ) -> Result<Vec<ProviderHost>, ProviderError> {
656        self.fetch_hosts_with_progress(token, cancel, env, &|_| {})
657    }
658
659    fn fetch_hosts_with_progress(
660        &self,
661        token: &str,
662        cancel: &AtomicBool,
663        env: &crate::runtime::env::Env,
664        progress: &dyn Fn(&str),
665    ) -> Result<Vec<ProviderHost>, ProviderError> {
666        self.fetch_with_endpoint(Self::API_BASE, token, cancel, env, progress)
667    }
668}
669
// Unit tests for this provider live in the sibling file `azure_tests.rs`.
#[cfg(test)]
#[path = "azure_tests.rs"]
mod tests;