Skip to main content

purple_ssh/providers/
mod.rs

1pub mod aws;
2pub mod azure;
3pub mod config;
4mod digitalocean;
5pub mod gcp;
6mod hetzner;
7mod linode;
8mod proxmox;
9pub mod scaleway;
10pub mod sync;
11mod tailscale;
12mod upcloud;
13mod vultr;
14
15use std::sync::atomic::AtomicBool;
16
17use thiserror::Error;
18
/// One server/VM as reported by a cloud provider's API.
///
/// Every field is an owned `String`/`Vec` so hosts can be cloned freely and
/// carried inside [`ProviderError::PartialResult`].
#[derive(Debug, Clone)]
#[allow(dead_code)] // NOTE(review): some fields may be unread in non-test builds — confirm before removing.
pub struct ProviderHost {
    /// Identifier assigned to the server by the provider.
    pub server_id: String,
    /// Human-readable server name or label.
    pub name: String,
    /// Public address of the server (IPv4 or IPv6).
    pub ip: String,
    /// Tags/labels attached to the server on the provider side.
    pub tags: Vec<String>,
    /// Extra provider details (region, plan, ...) as ordered key/value pairs.
    pub metadata: Vec<(String, String)>,
}

impl ProviderHost {
    /// Build a host from its core fields; `metadata` starts out empty.
    #[allow(dead_code)] // Convenience constructor; not every provider uses it.
    pub fn new(server_id: String, name: String, ip: String, tags: Vec<String>) -> Self {
        let metadata: Vec<(String, String)> = Vec::new();
        ProviderHost {
            server_id,
            name,
            ip,
            tags,
            metadata,
        }
    }
}
48
49/// Errors from provider API calls.
50#[derive(Debug, Error)]
51pub enum ProviderError {
52    #[error("HTTP error: {0}")]
53    Http(String),
54    #[error("Failed to parse response: {0}")]
55    Parse(String),
56    #[error("Authentication failed. Check your API token.")]
57    AuthFailed,
58    #[error("Rate limited. Try again in a moment.")]
59    RateLimited,
60    #[error("{0}")]
61    Execute(String),
62    #[error("Cancelled.")]
63    Cancelled,
64    /// Some hosts were fetched but others failed. The caller should use the
65    /// hosts but suppress destructive operations like --remove.
66    #[error("Partial result: {failures} of {total} failed")]
67    PartialResult {
68        hosts: Vec<ProviderHost>,
69        failures: usize,
70        total: usize,
71    },
72}
73
74/// Trait implemented by each cloud provider.
75pub trait Provider {
76    /// Full provider name (e.g. "digitalocean").
77    fn name(&self) -> &str;
78    /// Short label for aliases (e.g. "do").
79    fn short_label(&self) -> &str;
80    /// Fetch hosts with cancellation support.
81    fn fetch_hosts_cancellable(
82        &self,
83        token: &str,
84        cancel: &AtomicBool,
85    ) -> Result<Vec<ProviderHost>, ProviderError>;
86    /// Fetch all servers from the provider API.
87    #[allow(dead_code)]
88    fn fetch_hosts(&self, token: &str) -> Result<Vec<ProviderHost>, ProviderError> {
89        self.fetch_hosts_cancellable(token, &AtomicBool::new(false))
90    }
91    /// Fetch hosts with progress reporting. Default delegates to fetch_hosts_cancellable.
92    fn fetch_hosts_with_progress(
93        &self,
94        token: &str,
95        cancel: &AtomicBool,
96        _progress: &dyn Fn(&str),
97    ) -> Result<Vec<ProviderHost>, ProviderError> {
98        self.fetch_hosts_cancellable(token, cancel)
99    }
100}
101
/// Identifier of every supported provider; each one is accepted by `get_provider`.
pub const PROVIDER_NAMES: &[&str] = &[
    "digitalocean",
    "vultr",
    "linode",
    "hetzner",
    "upcloud",
    "proxmox",
    "aws",
    "scaleway",
    "gcp",
    "azure",
    "tailscale",
];
104
105/// Get a provider implementation by name.
106pub fn get_provider(name: &str) -> Option<Box<dyn Provider>> {
107    match name {
108        "digitalocean" => Some(Box::new(digitalocean::DigitalOcean)),
109        "vultr" => Some(Box::new(vultr::Vultr)),
110        "linode" => Some(Box::new(linode::Linode)),
111        "hetzner" => Some(Box::new(hetzner::Hetzner)),
112        "upcloud" => Some(Box::new(upcloud::UpCloud)),
113        "proxmox" => Some(Box::new(proxmox::Proxmox {
114            base_url: String::new(),
115            verify_tls: true,
116        })),
117        "aws" => Some(Box::new(aws::Aws {
118            regions: Vec::new(),
119            profile: String::new(),
120        })),
121        "scaleway" => Some(Box::new(scaleway::Scaleway {
122            zones: Vec::new(),
123        })),
124        "gcp" => Some(Box::new(gcp::Gcp {
125            zones: Vec::new(),
126            project: String::new(),
127        })),
128        "azure" => Some(Box::new(azure::Azure {
129            subscriptions: Vec::new(),
130        })),
131        "tailscale" => Some(Box::new(tailscale::Tailscale)),
132        _ => None,
133    }
134}
135
136/// Get a provider implementation configured from a provider section.
137/// For providers that need extra config (e.g. Proxmox base URL), this
138/// creates a properly configured instance.
139pub fn get_provider_with_config(name: &str, section: &config::ProviderSection) -> Option<Box<dyn Provider>> {
140    match name {
141        "proxmox" => Some(Box::new(proxmox::Proxmox {
142            base_url: section.url.clone(),
143            verify_tls: section.verify_tls,
144        })),
145        "aws" => Some(Box::new(aws::Aws {
146            regions: section.regions.split(',')
147                .map(|s| s.trim().to_string())
148                .filter(|s| !s.is_empty())
149                .collect(),
150            profile: section.profile.clone(),
151        })),
152        "scaleway" => Some(Box::new(scaleway::Scaleway {
153            zones: section.regions.split(',')
154                .map(|s| s.trim().to_string())
155                .filter(|s| !s.is_empty())
156                .collect(),
157        })),
158        "gcp" => Some(Box::new(gcp::Gcp {
159            zones: section.regions.split(',')
160                .map(|s| s.trim().to_string())
161                .filter(|s| !s.is_empty())
162                .collect(),
163            project: section.project.clone(),
164        })),
165        "azure" => Some(Box::new(azure::Azure {
166            subscriptions: section.regions.split(',')
167                .map(|s| s.trim().to_string())
168                .filter(|s| !s.is_empty())
169                .collect(),
170        })),
171        _ => get_provider(name),
172    }
173}
174
/// Human-friendly name for a provider identifier, e.g. `"digitalocean"` ->
/// `"DigitalOcean"`. Unknown identifiers are returned unchanged.
pub fn provider_display_name(name: &str) -> &str {
    // (identifier, display name) pairs; lookup is an exact, case-sensitive match.
    const DISPLAY_NAMES: &[(&str, &str)] = &[
        ("digitalocean", "DigitalOcean"),
        ("vultr", "Vultr"),
        ("linode", "Linode"),
        ("hetzner", "Hetzner"),
        ("upcloud", "UpCloud"),
        ("proxmox", "Proxmox VE"),
        ("aws", "AWS EC2"),
        ("scaleway", "Scaleway"),
        ("gcp", "GCP"),
        ("azure", "Azure"),
        ("tailscale", "Tailscale"),
    ];
    DISPLAY_NAMES
        .iter()
        .find(|&&(id, _)| id == name)
        .map_or(name, |&(_, display)| display)
}
192
193/// Create an HTTP agent with explicit timeouts.
194pub(crate) fn http_agent() -> ureq::Agent {
195    ureq::AgentBuilder::new()
196        .timeout(std::time::Duration::from_secs(30))
197        .redirects(0)
198        .build()
199}
200
201/// Create an HTTP agent that accepts invalid/self-signed TLS certificates.
202pub(crate) fn http_agent_insecure() -> Result<ureq::Agent, ProviderError> {
203    let tls = ureq::native_tls::TlsConnector::builder()
204        .danger_accept_invalid_certs(true)
205        .danger_accept_invalid_hostnames(true)
206        .build()
207        .map_err(|e| ProviderError::Http(format!("TLS setup failed: {}", e)))?;
208    Ok(ureq::AgentBuilder::new()
209        .timeout(std::time::Duration::from_secs(30))
210        .redirects(0)
211        .tls_connector(std::sync::Arc::new(tls))
212        .build())
213}
214
/// Remove a trailing CIDR prefix length (`/64`, `/128`, ...) from an address.
///
/// Some provider APIs report IPv6 addresses with a prefix length
/// (e.g. `"2600:3c00::1/128"`), but SSH needs the bare address. Only the text
/// after the *last* `/` is considered, and it is stripped only when it is a
/// non-empty run of ASCII digits; anything else is returned untouched.
pub(crate) fn strip_cidr(ip: &str) -> &str {
    match ip.rsplit_once('/') {
        Some((address, prefix_len))
            if !prefix_len.is_empty() && prefix_len.bytes().all(|b| b.is_ascii_digit()) =>
        {
            address
        }
        _ => ip,
    }
}
227
228/// Map a ureq error to a ProviderError.
229fn map_ureq_error(err: ureq::Error) -> ProviderError {
230    match err {
231        ureq::Error::Status(401, _) | ureq::Error::Status(403, _) => ProviderError::AuthFailed,
232        ureq::Error::Status(429, _) => ProviderError::RateLimited,
233        ureq::Error::Status(code, _) => ProviderError::Http(format!("HTTP {}", code)),
234        ureq::Error::Transport(t) => ProviderError::Http(t.to_string()),
235    }
236}
237
#[cfg(test)]
mod tests {
    use super::*;
    use std::cell::Cell;
    use std::collections::HashSet;
    use std::sync::atomic::{AtomicBool, Ordering};

    // =========================================================================
    // Shared helpers
    // =========================================================================

    /// Build a `ProviderSection` for `provider` with neutral defaults; each
    /// test overrides only the fields it cares about.
    fn section_for(provider: &str) -> config::ProviderSection {
        config::ProviderSection {
            provider: provider.to_string(),
            token: "tok".to_string(),
            alias_prefix: "test".to_string(),
            user: String::new(),
            identity_file: String::new(),
            url: String::new(),
            verify_tls: true,
            auto_sync: true,
            profile: String::new(),
            regions: String::new(),
            project: String::new(),
        }
    }

    /// In-memory provider used to exercise the trait's default methods
    /// without any network access.
    struct MockProvider {
        /// Cancel-flag value observed by the last `fetch_hosts_cancellable` call.
        seen_cancel: Cell<Option<bool>>,
    }

    impl MockProvider {
        fn new() -> Self {
            MockProvider {
                seen_cancel: Cell::new(None),
            }
        }
    }

    impl Provider for MockProvider {
        fn name(&self) -> &str {
            "mock"
        }

        fn short_label(&self) -> &str {
            "mock"
        }

        fn fetch_hosts_cancellable(
            &self,
            token: &str,
            cancel: &AtomicBool,
        ) -> Result<Vec<ProviderHost>, ProviderError> {
            self.seen_cancel.set(Some(cancel.load(Ordering::SeqCst)));
            // Echo the token back as the host name so callers can verify it was forwarded.
            Ok(vec![ProviderHost::new(
                "1".to_string(),
                token.to_string(),
                "10.0.0.1".to_string(),
                vec![],
            )])
        }
    }

    /// Build a synthetic HTTP status error for `map_ureq_error` tests.
    fn status_error(code: u16) -> ureq::Error {
        let response = ureq::Response::new(code, "status", "").expect("synthetic response is valid");
        ureq::Error::Status(code, response)
    }

    // =========================================================================
    // strip_cidr
    // =========================================================================

    #[test]
    fn test_strip_cidr_ipv6_with_prefix() {
        assert_eq!(strip_cidr("2600:3c00::1/128"), "2600:3c00::1");
        assert_eq!(strip_cidr("2a01:4f8::1/64"), "2a01:4f8::1");
    }

    #[test]
    fn test_strip_cidr_ipv6_full_notation() {
        assert_eq!(
            strip_cidr("2001:0db8:85a3:0000:0000:8a2e:0370:7334/128"),
            "2001:0db8:85a3:0000:0000:8a2e:0370:7334"
        );
    }

    #[test]
    fn test_strip_cidr_bare_ipv6() {
        assert_eq!(strip_cidr("2600:3c00::1"), "2600:3c00::1");
    }

    #[test]
    fn test_strip_cidr_ipv4() {
        assert_eq!(strip_cidr("1.2.3.4"), "1.2.3.4");
        assert_eq!(strip_cidr("10.0.0.1/24"), "10.0.0.1");
        assert_eq!(strip_cidr("1.2.3.4/32"), "1.2.3.4");
        assert_eq!(strip_cidr("10.0.0.1/8"), "10.0.0.1");
    }

    #[test]
    fn test_strip_cidr_empty() {
        assert_eq!(strip_cidr(""), "");
    }

    #[test]
    fn test_strip_cidr_slash_without_digits() {
        // Non-digit text after the slash is not a CIDR suffix.
        assert_eq!(strip_cidr("path/to/something"), "path/to/something");
        assert_eq!(strip_cidr("10.0.0.1/abc"), "10.0.0.1/abc");
    }

    #[test]
    fn test_strip_cidr_digit_then_letters_not_stripped() {
        assert_eq!(strip_cidr("10.0.0.1/24abc"), "10.0.0.1/24abc");
    }

    #[test]
    fn test_strip_cidr_trailing_slash() {
        // An empty suffix is not a prefix length.
        assert_eq!(strip_cidr("1.2.3.4/"), "1.2.3.4/");
        assert_eq!(strip_cidr("/"), "/");
    }

    #[test]
    fn test_strip_cidr_multiple_slashes() {
        // Only the last slash is considered: "/48" is stripped, "/24" kept.
        assert_eq!(strip_cidr("10.0.0.1/24/48"), "10.0.0.1/24");
    }

    // =========================================================================
    // get_provider
    // =========================================================================

    #[test]
    fn test_get_provider_all_names_resolve() {
        for &name in PROVIDER_NAMES {
            let p = get_provider(name);
            assert!(p.is_some(), "Provider '{}' should resolve", name);
            assert_eq!(p.unwrap().name(), name);
        }
    }

    #[test]
    fn test_provider_short_labels() {
        let cases = [
            ("digitalocean", "do"),
            ("vultr", "vultr"),
            ("linode", "linode"),
            ("hetzner", "hetzner"),
            ("upcloud", "uc"),
            ("proxmox", "pve"),
            ("aws", "aws"),
            ("scaleway", "scw"),
            ("gcp", "gcp"),
            ("azure", "az"),
            ("tailscale", "ts"),
        ];
        for (name, expected_label) in &cases {
            let p = get_provider(name).unwrap();
            assert_eq!(p.name(), *name);
            assert_eq!(p.short_label(), *expected_label, "short_label for {}", name);
        }
    }

    #[test]
    fn test_get_provider_unknown_returns_none() {
        assert!(get_provider("oracle").is_none());
        assert!(get_provider("").is_none());
        // Lookup is case-sensitive.
        assert!(get_provider("DigitalOcean").is_none());
        assert!(get_provider("VULTR").is_none());
    }

    // =========================================================================
    // get_provider_with_config
    // =========================================================================

    #[test]
    fn test_get_provider_with_config_proxmox_uses_url() {
        let mut section = section_for("proxmox");
        section.token = "user@pam!token=secret".to_string();
        section.alias_prefix = "pve-".to_string();
        section.url = "https://pve.example.com:8006".to_string();
        section.verify_tls = false;
        section.auto_sync = false;
        let p = get_provider_with_config("proxmox", &section).unwrap();
        assert_eq!(p.name(), "proxmox");
    }

    #[test]
    fn test_get_provider_with_config_non_proxmox_delegates() {
        let mut section = section_for("digitalocean");
        section.token = "do-token".to_string();
        section.alias_prefix = "do-".to_string();
        let p = get_provider_with_config("digitalocean", &section).unwrap();
        assert_eq!(p.name(), "digitalocean");
    }

    #[test]
    fn test_get_provider_with_config_gcp_uses_project_and_zones() {
        let mut section = section_for("gcp");
        section.token = "sa.json".to_string();
        section.alias_prefix = "gcp".to_string();
        section.regions = "us-central1-a, europe-west1-b".to_string();
        section.project = "my-project".to_string();
        let p = get_provider_with_config("gcp", &section).unwrap();
        assert_eq!(p.name(), "gcp");
    }

    #[test]
    fn test_get_provider_with_config_tolerates_messy_region_lists() {
        for &name in &["aws", "scaleway", "gcp", "azure"] {
            let mut section = section_for(name);
            section.regions = " , a ,, b , ".to_string();
            let p = get_provider_with_config(name, &section).unwrap();
            assert_eq!(p.name(), name);
        }
    }

    #[test]
    fn test_get_provider_with_config_unknown_returns_none() {
        let mut section = section_for("oracle");
        section.token = String::new();
        section.alias_prefix = String::new();
        assert!(get_provider_with_config("oracle", &section).is_none());
    }

    #[test]
    fn test_get_provider_with_config_all_providers() {
        for &name in PROVIDER_NAMES {
            let mut section = section_for(name);
            if name == "proxmox" {
                section.url = "https://pve:8006".to_string();
            }
            let p = get_provider_with_config(name, &section);
            assert!(p.is_some(), "get_provider_with_config({}) should return Some", name);
            assert_eq!(p.unwrap().name(), name);
        }
    }

    // =========================================================================
    // provider_display_name
    // =========================================================================

    #[test]
    fn test_display_name_all_providers() {
        assert_eq!(provider_display_name("digitalocean"), "DigitalOcean");
        assert_eq!(provider_display_name("vultr"), "Vultr");
        assert_eq!(provider_display_name("linode"), "Linode");
        assert_eq!(provider_display_name("hetzner"), "Hetzner");
        assert_eq!(provider_display_name("upcloud"), "UpCloud");
        assert_eq!(provider_display_name("proxmox"), "Proxmox VE");
        assert_eq!(provider_display_name("aws"), "AWS EC2");
        assert_eq!(provider_display_name("scaleway"), "Scaleway");
        assert_eq!(provider_display_name("gcp"), "GCP");
        assert_eq!(provider_display_name("azure"), "Azure");
        assert_eq!(provider_display_name("tailscale"), "Tailscale");
    }

    #[test]
    fn test_every_provider_has_display_name() {
        // Guards against adding a provider to PROVIDER_NAMES without a display name.
        for &name in PROVIDER_NAMES {
            assert_ne!(
                provider_display_name(name),
                name,
                "provider '{}' has no display name mapping",
                name
            );
        }
    }

    #[test]
    fn test_display_name_unknown_returns_input() {
        assert_eq!(provider_display_name("oracle"), "oracle");
        assert_eq!(provider_display_name(""), "");
    }

    // =========================================================================
    // PROVIDER_NAMES
    // =========================================================================

    #[test]
    fn test_provider_names_count() {
        assert_eq!(PROVIDER_NAMES.len(), 11);
    }

    #[test]
    fn test_provider_names_contains_all() {
        for expected in &[
            "digitalocean",
            "vultr",
            "linode",
            "hetzner",
            "upcloud",
            "proxmox",
            "aws",
            "scaleway",
            "gcp",
            "azure",
            "tailscale",
        ] {
            assert!(PROVIDER_NAMES.contains(expected), "missing provider '{}'", expected);
        }
    }

    #[test]
    fn test_provider_names_are_unique() {
        let mut seen = HashSet::new();
        for &name in PROVIDER_NAMES {
            assert!(seen.insert(name), "duplicate provider name '{}'", name);
        }
    }

    // =========================================================================
    // ProviderError Display / Debug
    // =========================================================================

    #[test]
    fn test_provider_error_display_http() {
        let err = ProviderError::Http("connection refused".to_string());
        assert_eq!(format!("{}", err), "HTTP error: connection refused");
    }

    #[test]
    fn test_provider_error_display_parse() {
        let err = ProviderError::Parse("invalid JSON".to_string());
        assert_eq!(format!("{}", err), "Failed to parse response: invalid JSON");
    }

    #[test]
    fn test_provider_error_display_auth() {
        let err = ProviderError::AuthFailed;
        assert!(format!("{}", err).contains("Authentication failed"));
    }

    #[test]
    fn test_provider_error_display_rate_limited() {
        let err = ProviderError::RateLimited;
        assert!(format!("{}", err).contains("Rate limited"));
    }

    #[test]
    fn test_provider_error_display_execute_is_verbatim() {
        let err = ProviderError::Execute("command not found".to_string());
        assert_eq!(format!("{}", err), "command not found");
    }

    #[test]
    fn test_provider_error_display_cancelled() {
        let err = ProviderError::Cancelled;
        assert_eq!(format!("{}", err), "Cancelled.");
    }

    #[test]
    fn test_provider_error_display_partial_result() {
        let err = ProviderError::PartialResult {
            hosts: vec![],
            failures: 3,
            total: 10,
        };
        assert!(format!("{}", err).contains("3 of 10 failed"));
    }

    #[test]
    fn test_provider_error_debug_http() {
        let err = ProviderError::Http("timeout".to_string());
        let debug = format!("{:?}", err);
        assert!(debug.contains("Http"));
        assert!(debug.contains("timeout"));
    }

    #[test]
    fn test_provider_error_debug_partial_result() {
        let err = ProviderError::PartialResult {
            hosts: vec![ProviderHost::new("1".to_string(), "web".to_string(), "1.2.3.4".to_string(), vec![])],
            failures: 2,
            total: 5,
        };
        let debug = format!("{:?}", err);
        assert!(debug.contains("PartialResult"));
        assert!(debug.contains("failures: 2"));
    }

    // =========================================================================
    // map_ureq_error
    // =========================================================================

    #[test]
    fn test_map_ureq_error_status_codes() {
        assert!(matches!(map_ureq_error(status_error(401)), ProviderError::AuthFailed));
        assert!(matches!(map_ureq_error(status_error(403)), ProviderError::AuthFailed));
        assert!(matches!(map_ureq_error(status_error(429)), ProviderError::RateLimited));
        match map_ureq_error(status_error(500)) {
            ProviderError::Http(msg) => assert_eq!(msg, "HTTP 500"),
            other => panic!("expected Http error, got {:?}", other),
        }
    }

    // =========================================================================
    // ProviderHost
    // =========================================================================

    #[test]
    fn test_provider_host_construction() {
        let host = ProviderHost::new("12345".to_string(), "web-01".to_string(), "1.2.3.4".to_string(), vec!["prod".to_string(), "web".to_string()]);
        assert_eq!(host.server_id, "12345");
        assert_eq!(host.name, "web-01");
        assert_eq!(host.ip, "1.2.3.4");
        assert_eq!(host.tags.len(), 2);
        assert!(host.metadata.is_empty());
    }

    #[test]
    fn test_provider_host_clone() {
        let host = ProviderHost::new("1".to_string(), "a".to_string(), "1.1.1.1".to_string(), vec![]);
        let cloned = host.clone();
        assert_eq!(cloned.server_id, host.server_id);
        assert_eq!(cloned.name, host.name);
    }

    #[test]
    fn test_provider_host_empty_fields() {
        let host = ProviderHost::new(String::new(), String::new(), String::new(), vec![]);
        assert!(host.server_id.is_empty());
        assert!(host.name.is_empty());
        assert!(host.ip.is_empty());
    }

    // =========================================================================
    // Provider trait default methods (offline, via MockProvider)
    // =========================================================================

    #[test]
    fn test_provider_fetch_hosts_delegates_to_cancellable() {
        let provider = MockProvider::new();
        let hosts = provider.fetch_hosts("fake-token").unwrap();
        assert_eq!(hosts.len(), 1);
        // The token must be forwarded unchanged.
        assert_eq!(hosts[0].name, "fake-token");
        // fetch_hosts must pass a cancel flag that is not set.
        assert_eq!(provider.seen_cancel.get(), Some(false));
    }

    #[test]
    fn test_provider_fetch_hosts_with_progress_delegates_to_cancellable() {
        let provider = MockProvider::new();
        let cancel = AtomicBool::new(true);
        let progress_calls = Cell::new(0usize);
        let hosts = provider
            .fetch_hosts_with_progress("tok", &cancel, &|_msg: &str| {
                progress_calls.set(progress_calls.get() + 1)
            })
            .unwrap();
        assert_eq!(hosts.len(), 1);
        // The caller's cancel flag is passed through as-is.
        assert_eq!(provider.seen_cancel.get(), Some(true));
        // The default implementation reports no progress.
        assert_eq!(progress_calls.get(), 0);
    }
}