Skip to main content

purple_ssh/providers/
mod.rs

1pub mod aws;
2pub mod azure;
3pub mod config;
4mod digitalocean;
5pub mod gcp;
6mod hetzner;
7mod linode;
8mod proxmox;
9pub mod scaleway;
10pub mod sync;
11mod upcloud;
12mod vultr;
13
14use std::sync::atomic::AtomicBool;
15
16use thiserror::Error;
17
/// A host discovered from a cloud provider API.
///
/// Equality (`PartialEq`/`Eq`) compares every field, including tags and
/// metadata order, so two records are equal only if they are identical.
// NOTE(review): `dead_code` is allowed presumably because not every field is
// read on every code path; confirm it is still needed.
#[derive(Debug, Clone, PartialEq, Eq)]
#[allow(dead_code)]
pub struct ProviderHost {
    /// Provider-assigned server ID.
    pub server_id: String,
    /// Server name/label.
    pub name: String,
    /// Public IP address (IPv4 or IPv6).
    pub ip: String,
    /// Provider tags/labels.
    pub tags: Vec<String>,
    /// Provider metadata (region, plan, etc.) as key-value pairs.
    pub metadata: Vec<(String, String)>,
}
33
34impl ProviderHost {
35    /// Create a ProviderHost with no metadata.
36    #[allow(dead_code)]
37    pub fn new(server_id: String, name: String, ip: String, tags: Vec<String>) -> Self {
38        Self {
39            server_id,
40            name,
41            ip,
42            tags,
43            metadata: Vec::new(),
44        }
45    }
46}
47
48/// Errors from provider API calls.
49#[derive(Debug, Error)]
50pub enum ProviderError {
51    #[error("HTTP error: {0}")]
52    Http(String),
53    #[error("Failed to parse response: {0}")]
54    Parse(String),
55    #[error("Authentication failed. Check your API token.")]
56    AuthFailed,
57    #[error("Rate limited. Try again in a moment.")]
58    RateLimited,
59    #[error("Cancelled.")]
60    Cancelled,
61    /// Some hosts were fetched but others failed. The caller should use the
62    /// hosts but suppress destructive operations like --remove.
63    #[error("Partial result: {failures} of {total} failed")]
64    PartialResult {
65        hosts: Vec<ProviderHost>,
66        failures: usize,
67        total: usize,
68    },
69}
70
71/// Trait implemented by each cloud provider.
72pub trait Provider {
73    /// Full provider name (e.g. "digitalocean").
74    fn name(&self) -> &str;
75    /// Short label for aliases (e.g. "do").
76    fn short_label(&self) -> &str;
77    /// Fetch hosts with cancellation support.
78    fn fetch_hosts_cancellable(
79        &self,
80        token: &str,
81        cancel: &AtomicBool,
82    ) -> Result<Vec<ProviderHost>, ProviderError>;
83    /// Fetch all servers from the provider API.
84    #[allow(dead_code)]
85    fn fetch_hosts(&self, token: &str) -> Result<Vec<ProviderHost>, ProviderError> {
86        self.fetch_hosts_cancellable(token, &AtomicBool::new(false))
87    }
88    /// Fetch hosts with progress reporting. Default delegates to fetch_hosts_cancellable.
89    fn fetch_hosts_with_progress(
90        &self,
91        token: &str,
92        cancel: &AtomicBool,
93        _progress: &dyn Fn(&str),
94    ) -> Result<Vec<ProviderHost>, ProviderError> {
95        self.fetch_hosts_cancellable(token, cancel)
96    }
97}
98
/// Canonical names of every supported provider; each one is accepted by
/// [`get_provider`].
pub const PROVIDER_NAMES: &[&str] = &[
    "digitalocean",
    "vultr",
    "linode",
    "hetzner",
    "upcloud",
    "proxmox",
    "aws",
    "scaleway",
    "gcp",
    "azure",
];
101
102/// Get a provider implementation by name.
103pub fn get_provider(name: &str) -> Option<Box<dyn Provider>> {
104    match name {
105        "digitalocean" => Some(Box::new(digitalocean::DigitalOcean)),
106        "vultr" => Some(Box::new(vultr::Vultr)),
107        "linode" => Some(Box::new(linode::Linode)),
108        "hetzner" => Some(Box::new(hetzner::Hetzner)),
109        "upcloud" => Some(Box::new(upcloud::UpCloud)),
110        "proxmox" => Some(Box::new(proxmox::Proxmox {
111            base_url: String::new(),
112            verify_tls: true,
113        })),
114        "aws" => Some(Box::new(aws::Aws {
115            regions: Vec::new(),
116            profile: String::new(),
117        })),
118        "scaleway" => Some(Box::new(scaleway::Scaleway {
119            zones: Vec::new(),
120        })),
121        "gcp" => Some(Box::new(gcp::Gcp {
122            zones: Vec::new(),
123            project: String::new(),
124        })),
125        "azure" => Some(Box::new(azure::Azure {
126            subscriptions: Vec::new(),
127        })),
128        _ => None,
129    }
130}
131
132/// Get a provider implementation configured from a provider section.
133/// For providers that need extra config (e.g. Proxmox base URL), this
134/// creates a properly configured instance.
135pub fn get_provider_with_config(name: &str, section: &config::ProviderSection) -> Option<Box<dyn Provider>> {
136    match name {
137        "proxmox" => Some(Box::new(proxmox::Proxmox {
138            base_url: section.url.clone(),
139            verify_tls: section.verify_tls,
140        })),
141        "aws" => Some(Box::new(aws::Aws {
142            regions: section.regions.split(',')
143                .map(|s| s.trim().to_string())
144                .filter(|s| !s.is_empty())
145                .collect(),
146            profile: section.profile.clone(),
147        })),
148        "scaleway" => Some(Box::new(scaleway::Scaleway {
149            zones: section.regions.split(',')
150                .map(|s| s.trim().to_string())
151                .filter(|s| !s.is_empty())
152                .collect(),
153        })),
154        "gcp" => Some(Box::new(gcp::Gcp {
155            zones: section.regions.split(',')
156                .map(|s| s.trim().to_string())
157                .filter(|s| !s.is_empty())
158                .collect(),
159            project: section.project.clone(),
160        })),
161        "azure" => Some(Box::new(azure::Azure {
162            subscriptions: section.regions.split(',')
163                .map(|s| s.trim().to_string())
164                .filter(|s| !s.is_empty())
165                .collect(),
166        })),
167        _ => get_provider(name),
168    }
169}
170
/// Human-friendly name for a provider id, e.g. `"digitalocean"` -> `"DigitalOcean"`.
///
/// Unknown ids are returned unchanged.
pub fn provider_display_name(name: &str) -> &str {
    const DISPLAY_NAMES: [(&str, &str); 10] = [
        ("digitalocean", "DigitalOcean"),
        ("vultr", "Vultr"),
        ("linode", "Linode"),
        ("hetzner", "Hetzner"),
        ("upcloud", "UpCloud"),
        ("proxmox", "Proxmox VE"),
        ("aws", "AWS EC2"),
        ("scaleway", "Scaleway"),
        ("gcp", "GCP"),
        ("azure", "Azure"),
    ];

    match DISPLAY_NAMES.iter().find(|(id, _)| *id == name) {
        Some(&(_, pretty)) => pretty,
        None => name,
    }
}
187
188/// Create an HTTP agent with explicit timeouts.
189pub(crate) fn http_agent() -> ureq::Agent {
190    ureq::AgentBuilder::new()
191        .timeout(std::time::Duration::from_secs(30))
192        .redirects(0)
193        .build()
194}
195
196/// Create an HTTP agent that accepts invalid/self-signed TLS certificates.
197pub(crate) fn http_agent_insecure() -> Result<ureq::Agent, ProviderError> {
198    let tls = ureq::native_tls::TlsConnector::builder()
199        .danger_accept_invalid_certs(true)
200        .danger_accept_invalid_hostnames(true)
201        .build()
202        .map_err(|e| ProviderError::Http(format!("TLS setup failed: {}", e)))?;
203    Ok(ureq::AgentBuilder::new()
204        .timeout(std::time::Duration::from_secs(30))
205        .redirects(0)
206        .tls_connector(std::sync::Arc::new(tls))
207        .build())
208}
209
/// Remove a trailing CIDR prefix length (e.g. `/64`, `/128`) from an address.
///
/// Some provider APIs report IPv6 addresses with a prefix length attached
/// (e.g. `"2600:3c00::1/128"`), but SSH needs the bare address. Only the text
/// after the last `/` is considered. It is stripped only when it is a
/// non-empty run of ASCII digits; otherwise the input is returned unchanged.
pub(crate) fn strip_cidr(ip: &str) -> &str {
    match ip.rsplit_once('/') {
        Some((addr, prefix_len))
            if !prefix_len.is_empty() && prefix_len.bytes().all(|b| b.is_ascii_digit()) =>
        {
            addr
        }
        _ => ip,
    }
}
222
223/// Map a ureq error to a ProviderError.
224fn map_ureq_error(err: ureq::Error) -> ProviderError {
225    match err {
226        ureq::Error::Status(401, _) | ureq::Error::Status(403, _) => ProviderError::AuthFailed,
227        ureq::Error::Status(429, _) => ProviderError::RateLimited,
228        ureq::Error::Status(code, _) => ProviderError::Http(format!("HTTP {}", code)),
229        ureq::Error::Transport(t) => ProviderError::Http(t.to_string()),
230    }
231}
232
#[cfg(test)]
mod tests {
    use super::*;
    use std::cell::Cell;
    use std::sync::atomic::{AtomicBool, Ordering};

    /// Build a `ProviderSection` for `provider` with every string field empty
    /// and `verify_tls`/`auto_sync` enabled. Tests override individual fields
    /// with struct-update syntax (`..blank_section("x")`).
    fn blank_section(provider: &str) -> config::ProviderSection {
        config::ProviderSection {
            provider: provider.to_string(),
            token: String::new(),
            alias_prefix: String::new(),
            user: String::new(),
            identity_file: String::new(),
            url: String::new(),
            verify_tls: true,
            auto_sync: true,
            profile: String::new(),
            regions: String::new(),
            project: String::new(),
        }
    }

    /// Every provider with its expected `name()` and `short_label()`.
    const EXPECTED_LABELS: &[(&str, &str)] = &[
        ("digitalocean", "do"),
        ("vultr", "vultr"),
        ("linode", "linode"),
        ("hetzner", "hetzner"),
        ("upcloud", "uc"),
        ("proxmox", "pve"),
        ("aws", "aws"),
        ("scaleway", "scw"),
        ("gcp", "gcp"),
        ("azure", "az"),
    ];

    // =========================================================================
    // strip_cidr
    // =========================================================================

    #[test]
    fn test_strip_cidr_ipv6_with_prefix() {
        assert_eq!(strip_cidr("2600:3c00::1/128"), "2600:3c00::1");
        assert_eq!(strip_cidr("2a01:4f8::1/64"), "2a01:4f8::1");
    }

    #[test]
    fn test_strip_cidr_ipv6_full_notation() {
        assert_eq!(
            strip_cidr("2001:0db8:85a3:0000:0000:8a2e:0370:7334/128"),
            "2001:0db8:85a3:0000:0000:8a2e:0370:7334"
        );
    }

    #[test]
    fn test_strip_cidr_bare_ipv6() {
        assert_eq!(strip_cidr("2600:3c00::1"), "2600:3c00::1");
    }

    #[test]
    fn test_strip_cidr_ipv4() {
        assert_eq!(strip_cidr("1.2.3.4"), "1.2.3.4");
        assert_eq!(strip_cidr("10.0.0.1/24"), "10.0.0.1");
        assert_eq!(strip_cidr("1.2.3.4/32"), "1.2.3.4");
        assert_eq!(strip_cidr("10.0.0.1/8"), "10.0.0.1");
    }

    #[test]
    fn test_strip_cidr_empty() {
        assert_eq!(strip_cidr(""), "");
    }

    #[test]
    fn test_strip_cidr_non_digit_suffix_not_stripped() {
        assert_eq!(strip_cidr("path/to/something"), "path/to/something");
        assert_eq!(strip_cidr("10.0.0.1/abc"), "10.0.0.1/abc");
        // A suffix that merely starts with digits is not a prefix length.
        assert_eq!(strip_cidr("10.0.0.1/24abc"), "10.0.0.1/24abc");
    }

    #[test]
    fn test_strip_cidr_empty_suffix_not_stripped() {
        // A slash with nothing after it is not a prefix length.
        assert_eq!(strip_cidr("1.2.3.4/"), "1.2.3.4/");
        assert_eq!(strip_cidr("/"), "/");
    }

    #[test]
    fn test_strip_cidr_multiple_slashes() {
        // Only the text after the last slash ("48") is considered.
        assert_eq!(strip_cidr("10.0.0.1/24/48"), "10.0.0.1/24");
    }

    // =========================================================================
    // get_provider
    // =========================================================================

    #[test]
    fn test_get_provider_name_and_short_label() {
        for &(name, label) in EXPECTED_LABELS {
            let p = get_provider(name)
                .unwrap_or_else(|| panic!("get_provider({}) should return Some", name));
            assert_eq!(p.name(), name);
            assert_eq!(p.short_label(), label, "short_label for {}", name);
        }
    }

    #[test]
    fn test_get_provider_all_names_resolve() {
        for name in PROVIDER_NAMES {
            assert!(get_provider(name).is_some(), "Provider '{}' should resolve", name);
        }
    }

    #[test]
    fn test_get_provider_unknown_returns_none() {
        assert!(get_provider("oracle").is_none());
        assert!(get_provider("").is_none());
        // Matching is case-sensitive.
        assert!(get_provider("DigitalOcean").is_none());
        assert!(get_provider("VULTR").is_none());
    }

    // =========================================================================
    // get_provider_with_config
    //
    // The configured fields are not observable through `dyn Provider`, so
    // these tests check that construction succeeds and yields the right
    // provider.
    // =========================================================================

    #[test]
    fn test_get_provider_with_config_proxmox_uses_url() {
        let section = config::ProviderSection {
            token: "user@pam!token=secret".to_string(),
            alias_prefix: "pve-".to_string(),
            url: "https://pve.example.com:8006".to_string(),
            verify_tls: false,
            auto_sync: false,
            ..blank_section("proxmox")
        };
        let p = get_provider_with_config("proxmox", &section).unwrap();
        assert_eq!(p.name(), "proxmox");
    }

    #[test]
    fn test_get_provider_with_config_non_proxmox_delegates() {
        let section = config::ProviderSection {
            token: "do-token".to_string(),
            alias_prefix: "do-".to_string(),
            ..blank_section("digitalocean")
        };
        let p = get_provider_with_config("digitalocean", &section).unwrap();
        assert_eq!(p.name(), "digitalocean");
    }

    #[test]
    fn test_get_provider_with_config_gcp_uses_project_and_zones() {
        let section = config::ProviderSection {
            token: "sa.json".to_string(),
            alias_prefix: "gcp".to_string(),
            regions: "us-central1-a, europe-west1-b".to_string(),
            project: "my-project".to_string(),
            ..blank_section("gcp")
        };
        let p = get_provider_with_config("gcp", &section).unwrap();
        assert_eq!(p.name(), "gcp");
    }

    #[test]
    fn test_get_provider_with_config_unknown_returns_none() {
        assert!(get_provider_with_config("oracle", &blank_section("oracle")).is_none());
    }

    #[test]
    fn test_get_provider_with_config_all_providers() {
        for &name in PROVIDER_NAMES {
            let section = config::ProviderSection {
                token: "tok".to_string(),
                alias_prefix: "test".to_string(),
                url: if name == "proxmox" {
                    "https://pve:8006".to_string()
                } else {
                    String::new()
                },
                ..blank_section(name)
            };
            let p = get_provider_with_config(name, &section).unwrap_or_else(|| {
                panic!("get_provider_with_config({}) should return Some", name)
            });
            assert_eq!(p.name(), name);
        }
    }

    // =========================================================================
    // provider_display_name
    // =========================================================================

    #[test]
    fn test_display_name_all_providers() {
        assert_eq!(provider_display_name("digitalocean"), "DigitalOcean");
        assert_eq!(provider_display_name("vultr"), "Vultr");
        assert_eq!(provider_display_name("linode"), "Linode");
        assert_eq!(provider_display_name("hetzner"), "Hetzner");
        assert_eq!(provider_display_name("upcloud"), "UpCloud");
        assert_eq!(provider_display_name("proxmox"), "Proxmox VE");
        assert_eq!(provider_display_name("aws"), "AWS EC2");
        assert_eq!(provider_display_name("scaleway"), "Scaleway");
        assert_eq!(provider_display_name("gcp"), "GCP");
        assert_eq!(provider_display_name("azure"), "Azure");
    }

    #[test]
    fn test_display_name_unknown_returns_input() {
        assert_eq!(provider_display_name("oracle"), "oracle");
        assert_eq!(provider_display_name(""), "");
    }

    #[test]
    fn test_display_name_defined_for_every_provider() {
        // The fallback echoes unknown ids back unchanged, so a provider added
        // without a display name would show up here.
        for name in PROVIDER_NAMES {
            assert_ne!(provider_display_name(name), *name, "missing display name for {}", name);
        }
    }

    // =========================================================================
    // PROVIDER_NAMES
    // =========================================================================

    #[test]
    fn test_provider_names_matches_expected_set() {
        assert_eq!(PROVIDER_NAMES.len(), 10);
        assert_eq!(PROVIDER_NAMES.len(), EXPECTED_LABELS.len());
        for &(name, _) in EXPECTED_LABELS {
            assert!(PROVIDER_NAMES.contains(&name), "PROVIDER_NAMES is missing {}", name);
        }
    }

    // =========================================================================
    // ProviderError Display / Debug
    // =========================================================================

    #[test]
    fn test_provider_error_display_http() {
        let err = ProviderError::Http("connection refused".to_string());
        assert_eq!(format!("{}", err), "HTTP error: connection refused");
    }

    #[test]
    fn test_provider_error_display_parse() {
        let err = ProviderError::Parse("invalid JSON".to_string());
        assert_eq!(format!("{}", err), "Failed to parse response: invalid JSON");
    }

    #[test]
    fn test_provider_error_display_auth() {
        let err = ProviderError::AuthFailed;
        assert!(format!("{}", err).contains("Authentication failed"));
    }

    #[test]
    fn test_provider_error_display_rate_limited() {
        let err = ProviderError::RateLimited;
        assert!(format!("{}", err).contains("Rate limited"));
    }

    #[test]
    fn test_provider_error_display_cancelled() {
        let err = ProviderError::Cancelled;
        assert_eq!(format!("{}", err), "Cancelled.");
    }

    #[test]
    fn test_provider_error_display_partial_result() {
        let err = ProviderError::PartialResult {
            hosts: vec![],
            failures: 3,
            total: 10,
        };
        assert!(format!("{}", err).contains("3 of 10 failed"));
    }

    #[test]
    fn test_provider_error_debug_http() {
        let err = ProviderError::Http("timeout".to_string());
        let debug = format!("{:?}", err);
        assert!(debug.contains("Http"));
        assert!(debug.contains("timeout"));
    }

    #[test]
    fn test_provider_error_debug_partial_result() {
        let err = ProviderError::PartialResult {
            hosts: vec![ProviderHost::new(
                "1".to_string(),
                "web".to_string(),
                "1.2.3.4".to_string(),
                vec![],
            )],
            failures: 2,
            total: 5,
        };
        let debug = format!("{:?}", err);
        assert!(debug.contains("PartialResult"));
        assert!(debug.contains("failures: 2"));
    }

    // =========================================================================
    // ProviderHost
    // =========================================================================

    #[test]
    fn test_provider_host_construction() {
        let host = ProviderHost::new(
            "12345".to_string(),
            "web-01".to_string(),
            "1.2.3.4".to_string(),
            vec!["prod".to_string(), "web".to_string()],
        );
        assert_eq!(host.server_id, "12345");
        assert_eq!(host.name, "web-01");
        assert_eq!(host.ip, "1.2.3.4");
        assert_eq!(host.tags, vec!["prod".to_string(), "web".to_string()]);
        assert!(host.metadata.is_empty());
    }

    #[test]
    fn test_provider_host_clone() {
        let host = ProviderHost::new("1".to_string(), "a".to_string(), "1.1.1.1".to_string(), vec![]);
        let cloned = host.clone();
        assert_eq!(cloned.server_id, host.server_id);
        assert_eq!(cloned.name, host.name);
        assert_eq!(cloned.ip, host.ip);
    }

    #[test]
    fn test_provider_host_empty_fields() {
        let host = ProviderHost::new(String::new(), String::new(), String::new(), vec![]);
        assert!(host.server_id.is_empty());
        assert!(host.name.is_empty());
        assert!(host.ip.is_empty());
        assert!(host.tags.is_empty());
    }

    // =========================================================================
    // Provider trait default methods (no network access)
    // =========================================================================

    /// In-memory `Provider` that records the cancel flag it receives, so the
    /// trait's default methods can be tested without real API calls.
    #[derive(Default)]
    struct RecordingProvider {
        seen_cancel: Cell<Option<bool>>,
    }

    impl Provider for RecordingProvider {
        fn name(&self) -> &str {
            "recording"
        }

        fn short_label(&self) -> &str {
            "rec"
        }

        fn fetch_hosts_cancellable(
            &self,
            token: &str,
            cancel: &AtomicBool,
        ) -> Result<Vec<ProviderHost>, ProviderError> {
            let cancelled = cancel.load(Ordering::SeqCst);
            self.seen_cancel.set(Some(cancelled));
            if cancelled {
                return Err(ProviderError::Cancelled);
            }
            // Echo the token back as the host name so delegation is visible.
            Ok(vec![ProviderHost::new(
                "1".to_string(),
                token.to_string(),
                "1.2.3.4".to_string(),
                vec![],
            )])
        }
    }

    #[test]
    fn test_provider_fetch_hosts_delegates_to_cancellable() {
        let provider = RecordingProvider::default();
        let hosts = provider.fetch_hosts("tok").unwrap();
        assert_eq!(hosts.len(), 1);
        assert_eq!(hosts[0].name, "tok");
        // `fetch_hosts` must pass a flag that is not set.
        assert_eq!(provider.seen_cancel.get(), Some(false));
    }

    #[test]
    fn test_provider_fetch_hosts_with_progress_forwards_cancel_flag() {
        let provider = RecordingProvider::default();
        let cancel = AtomicBool::new(true);
        let result = provider.fetch_hosts_with_progress("tok", &cancel, &|_msg: &str| {});
        assert!(matches!(result, Err(ProviderError::Cancelled)));
        assert_eq!(provider.seen_cancel.get(), Some(true));
    }
}