Skip to main content

purple_ssh/providers/
mod.rs

1pub mod aws;
2pub mod config;
3mod digitalocean;
4mod hetzner;
5mod linode;
6mod proxmox;
7pub mod scaleway;
8pub mod sync;
9mod upcloud;
10mod vultr;
11
12use std::sync::atomic::AtomicBool;
13
14use thiserror::Error;
15
/// A single server discovered through a cloud provider's API.
///
/// All data is owned (`String`s and `Vec`s), so a host can be cloned and kept
/// independently of the API response it was parsed from.
#[derive(Debug, Clone)]
// NOTE(review): presumably some fields are not read in every build; confirm
// whether this allow is still required.
#[allow(dead_code)]
pub struct ProviderHost {
    /// Identifier the provider assigned to this server.
    pub server_id: String,
    /// Human-readable server name or label.
    pub name: String,
    /// Public address of the server, IPv4 or IPv6.
    pub ip: String,
    /// Tags or labels the provider attaches to the server.
    pub tags: Vec<String>,
    /// Additional provider details (region, plan, ...) as key/value pairs.
    pub metadata: Vec<(String, String)>,
}
31
32impl ProviderHost {
33    /// Create a ProviderHost with no metadata.
34    #[allow(dead_code)]
35    pub fn new(server_id: String, name: String, ip: String, tags: Vec<String>) -> Self {
36        Self {
37            server_id,
38            name,
39            ip,
40            tags,
41            metadata: Vec::new(),
42        }
43    }
44}
45
46/// Errors from provider API calls.
47#[derive(Debug, Error)]
48pub enum ProviderError {
49    #[error("HTTP error: {0}")]
50    Http(String),
51    #[error("Failed to parse response: {0}")]
52    Parse(String),
53    #[error("Authentication failed. Check your API token.")]
54    AuthFailed,
55    #[error("Rate limited. Try again in a moment.")]
56    RateLimited,
57    #[error("Cancelled.")]
58    Cancelled,
59    /// Some hosts were fetched but others failed. The caller should use the
60    /// hosts but suppress destructive operations like --remove.
61    #[error("Partial result: {failures} of {total} failed")]
62    PartialResult {
63        hosts: Vec<ProviderHost>,
64        failures: usize,
65        total: usize,
66    },
67}
68
69/// Trait implemented by each cloud provider.
70pub trait Provider {
71    /// Full provider name (e.g. "digitalocean").
72    fn name(&self) -> &str;
73    /// Short label for aliases (e.g. "do").
74    fn short_label(&self) -> &str;
75    /// Fetch hosts with cancellation support.
76    fn fetch_hosts_cancellable(
77        &self,
78        token: &str,
79        cancel: &AtomicBool,
80    ) -> Result<Vec<ProviderHost>, ProviderError>;
81    /// Fetch all servers from the provider API.
82    #[allow(dead_code)]
83    fn fetch_hosts(&self, token: &str) -> Result<Vec<ProviderHost>, ProviderError> {
84        self.fetch_hosts_cancellable(token, &AtomicBool::new(false))
85    }
86    /// Fetch hosts with progress reporting. Default delegates to fetch_hosts_cancellable.
87    fn fetch_hosts_with_progress(
88        &self,
89        token: &str,
90        cancel: &AtomicBool,
91        _progress: &dyn Fn(&str),
92    ) -> Result<Vec<ProviderHost>, ProviderError> {
93        self.fetch_hosts_cancellable(token, cancel)
94    }
95}
96
/// Canonical names of every provider this module can construct.
pub const PROVIDER_NAMES: &[&str] = &[
    "digitalocean",
    "vultr",
    "linode",
    "hetzner",
    "upcloud",
    "proxmox",
    "aws",
    "scaleway",
];
99
100/// Get a provider implementation by name.
101pub fn get_provider(name: &str) -> Option<Box<dyn Provider>> {
102    match name {
103        "digitalocean" => Some(Box::new(digitalocean::DigitalOcean)),
104        "vultr" => Some(Box::new(vultr::Vultr)),
105        "linode" => Some(Box::new(linode::Linode)),
106        "hetzner" => Some(Box::new(hetzner::Hetzner)),
107        "upcloud" => Some(Box::new(upcloud::UpCloud)),
108        "proxmox" => Some(Box::new(proxmox::Proxmox {
109            base_url: String::new(),
110            verify_tls: true,
111        })),
112        "aws" => Some(Box::new(aws::Aws {
113            regions: Vec::new(),
114            profile: String::new(),
115        })),
116        "scaleway" => Some(Box::new(scaleway::Scaleway {
117            zones: Vec::new(),
118        })),
119        _ => None,
120    }
121}
122
123/// Get a provider implementation configured from a provider section.
124/// For providers that need extra config (e.g. Proxmox base URL), this
125/// creates a properly configured instance.
126pub fn get_provider_with_config(name: &str, section: &config::ProviderSection) -> Option<Box<dyn Provider>> {
127    match name {
128        "proxmox" => Some(Box::new(proxmox::Proxmox {
129            base_url: section.url.clone(),
130            verify_tls: section.verify_tls,
131        })),
132        "aws" => Some(Box::new(aws::Aws {
133            regions: section.regions.split(',')
134                .map(|s| s.trim().to_string())
135                .filter(|s| !s.is_empty())
136                .collect(),
137            profile: section.profile.clone(),
138        })),
139        "scaleway" => Some(Box::new(scaleway::Scaleway {
140            zones: section.regions.split(',')
141                .map(|s| s.trim().to_string())
142                .filter(|s| !s.is_empty())
143                .collect(),
144        })),
145        _ => get_provider(name),
146    }
147}
148
/// Human-friendly name for a provider, e.g. `"digitalocean"` -> `"DigitalOcean"`.
///
/// Names without an entry in the table are returned unchanged.
pub fn provider_display_name(name: &str) -> &str {
    const DISPLAY_NAMES: &[(&str, &str)] = &[
        ("digitalocean", "DigitalOcean"),
        ("vultr", "Vultr"),
        ("linode", "Linode"),
        ("hetzner", "Hetzner"),
        ("upcloud", "UpCloud"),
        ("proxmox", "Proxmox VE"),
        ("aws", "AWS EC2"),
        ("scaleway", "Scaleway"),
    ];
    DISPLAY_NAMES
        .iter()
        .find(|&&(key, _)| key == name)
        .map(|&(_, display)| display)
        .unwrap_or(name)
}
163
164/// Create an HTTP agent with explicit timeouts.
165pub(crate) fn http_agent() -> ureq::Agent {
166    ureq::AgentBuilder::new()
167        .timeout(std::time::Duration::from_secs(30))
168        .redirects(0)
169        .build()
170}
171
172/// Create an HTTP agent that accepts invalid/self-signed TLS certificates.
173pub(crate) fn http_agent_insecure() -> Result<ureq::Agent, ProviderError> {
174    let tls = ureq::native_tls::TlsConnector::builder()
175        .danger_accept_invalid_certs(true)
176        .danger_accept_invalid_hostnames(true)
177        .build()
178        .map_err(|e| ProviderError::Http(format!("TLS setup failed: {}", e)))?;
179    Ok(ureq::AgentBuilder::new()
180        .timeout(std::time::Duration::from_secs(30))
181        .redirects(0)
182        .tls_connector(std::sync::Arc::new(tls))
183        .build())
184}
185
/// Remove a trailing CIDR prefix length (`/64`, `/128`, ...) from an address.
///
/// Some provider APIs report IPv6 addresses with a prefix length attached
/// (e.g. `"2600:3c00::1/128"`), but SSH needs the bare address. Only the text
/// after the *last* `/` is examined. It is stripped only when it is a
/// non-empty run of ASCII digits; anything else is returned unchanged.
pub(crate) fn strip_cidr(ip: &str) -> &str {
    match ip.rsplit_once('/') {
        Some((address, prefix_len))
            if !prefix_len.is_empty() && prefix_len.bytes().all(|b| b.is_ascii_digit()) =>
        {
            address
        }
        _ => ip,
    }
}
198
199/// Map a ureq error to a ProviderError.
200fn map_ureq_error(err: ureq::Error) -> ProviderError {
201    match err {
202        ureq::Error::Status(401, _) | ureq::Error::Status(403, _) => ProviderError::AuthFailed,
203        ureq::Error::Status(429, _) => ProviderError::RateLimited,
204        ureq::Error::Status(code, _) => ProviderError::Http(format!("HTTP {}", code)),
205        ureq::Error::Transport(t) => ProviderError::Http(t.to_string()),
206    }
207}
208
#[cfg(test)]
mod tests {
    use super::*;
    use std::cell::Cell;
    use std::sync::atomic::{AtomicBool, Ordering};

    // =========================================================================
    // Shared fixtures
    // =========================================================================

    /// Config section for `provider` with neutral defaults. Tests override
    /// the fields they care about with struct-update syntax.
    fn section(provider: &str) -> config::ProviderSection {
        config::ProviderSection {
            provider: provider.to_string(),
            token: "tok".to_string(),
            alias_prefix: "test".to_string(),
            user: String::new(),
            identity_file: String::new(),
            url: String::new(),
            verify_tls: true,
            auto_sync: true,
            profile: String::new(),
            regions: String::new(),
        }
    }

    /// Provider test double. It records the cancel flag it receives and
    /// echoes the token back as the host name, so the trait's default
    /// methods can be exercised without any network access.
    struct RecordingProvider {
        seen_cancel: Cell<Option<bool>>,
    }

    impl RecordingProvider {
        fn new() -> Self {
            Self {
                seen_cancel: Cell::new(None),
            }
        }
    }

    impl Provider for RecordingProvider {
        fn name(&self) -> &str {
            "recording"
        }

        fn short_label(&self) -> &str {
            "rec"
        }

        fn fetch_hosts_cancellable(
            &self,
            token: &str,
            cancel: &AtomicBool,
        ) -> Result<Vec<ProviderHost>, ProviderError> {
            let cancelled = cancel.load(Ordering::SeqCst);
            self.seen_cancel.set(Some(cancelled));
            if cancelled {
                return Err(ProviderError::Cancelled);
            }
            Ok(vec![ProviderHost::new(
                "1".to_string(),
                token.to_string(),
                "1.2.3.4".to_string(),
                Vec::new(),
            )])
        }
    }

    /// A `ureq::Error::Status` for `code`, built without performing any I/O.
    fn status_error(code: u16) -> ureq::Error {
        let response =
            ureq::Response::new(code, "status", "").expect("synthetic response should parse");
        ureq::Error::Status(code, response)
    }

    // =========================================================================
    // strip_cidr
    // =========================================================================

    #[test]
    fn test_strip_cidr_ipv6_with_prefix() {
        assert_eq!(strip_cidr("2600:3c00::1/128"), "2600:3c00::1");
        assert_eq!(strip_cidr("2a01:4f8::1/64"), "2a01:4f8::1");
        assert_eq!(
            strip_cidr("2001:0db8:85a3:0000:0000:8a2e:0370:7334/128"),
            "2001:0db8:85a3:0000:0000:8a2e:0370:7334"
        );
    }

    #[test]
    fn test_strip_cidr_bare_ipv6() {
        assert_eq!(strip_cidr("2600:3c00::1"), "2600:3c00::1");
    }

    #[test]
    fn test_strip_cidr_ipv4() {
        assert_eq!(strip_cidr("1.2.3.4"), "1.2.3.4");
        assert_eq!(strip_cidr("10.0.0.1/24"), "10.0.0.1");
        assert_eq!(strip_cidr("1.2.3.4/32"), "1.2.3.4");
        assert_eq!(strip_cidr("10.0.0.1/8"), "10.0.0.1");
    }

    #[test]
    fn test_strip_cidr_empty() {
        assert_eq!(strip_cidr(""), "");
    }

    #[test]
    fn test_strip_cidr_non_digit_suffix_not_stripped() {
        assert_eq!(strip_cidr("path/to/something"), "path/to/something");
        assert_eq!(strip_cidr("10.0.0.1/abc"), "10.0.0.1/abc");
        // Leading digits are not enough: the whole suffix must be numeric.
        assert_eq!(strip_cidr("10.0.0.1/24abc"), "10.0.0.1/24abc");
    }

    #[test]
    fn test_strip_cidr_empty_suffix_not_stripped() {
        // A slash with nothing after it is not a CIDR suffix.
        assert_eq!(strip_cidr("1.2.3.4/"), "1.2.3.4/");
        assert_eq!(strip_cidr("/"), "/");
    }

    #[test]
    fn test_strip_cidr_multiple_slashes() {
        // Only the suffix after the last slash is removed.
        assert_eq!(strip_cidr("10.0.0.1/24/48"), "10.0.0.1/24");
    }

    // =========================================================================
    // get_provider
    // =========================================================================

    #[test]
    fn test_provider_names_and_short_labels() {
        let cases = [
            ("digitalocean", "do"),
            ("vultr", "vultr"),
            ("linode", "linode"),
            ("hetzner", "hetzner"),
            ("upcloud", "uc"),
            ("proxmox", "pve"),
            ("aws", "aws"),
            ("scaleway", "scw"),
        ];
        for (name, expected_label) in &cases {
            let p = get_provider(name).unwrap();
            assert_eq!(p.name(), *name, "name for {}", name);
            assert_eq!(p.short_label(), *expected_label, "short_label for {}", name);
        }
    }

    #[test]
    fn test_get_provider_all_names_resolve() {
        for name in PROVIDER_NAMES {
            assert!(get_provider(name).is_some(), "Provider '{}' should resolve", name);
        }
    }

    #[test]
    fn test_get_provider_unknown_or_wrong_case_returns_none() {
        assert!(get_provider("gcp").is_none());
        assert!(get_provider("").is_none());
        assert!(get_provider("DigitalOcean").is_none()); // case-sensitive
        assert!(get_provider("VULTR").is_none());
    }

    // =========================================================================
    // get_provider_with_config
    // =========================================================================

    #[test]
    fn test_get_provider_with_config_proxmox_uses_url() {
        let sec = config::ProviderSection {
            token: "user@pam!token=secret".to_string(),
            alias_prefix: "pve-".to_string(),
            url: "https://pve.example.com:8006".to_string(),
            verify_tls: false,
            auto_sync: false,
            ..section("proxmox")
        };
        let p = get_provider_with_config("proxmox", &sec).unwrap();
        assert_eq!(p.name(), "proxmox");
    }

    #[test]
    fn test_get_provider_with_config_non_proxmox_delegates() {
        let sec = config::ProviderSection {
            token: "do-token".to_string(),
            alias_prefix: "do-".to_string(),
            ..section("digitalocean")
        };
        let p = get_provider_with_config("digitalocean", &sec).unwrap();
        assert_eq!(p.name(), "digitalocean");
    }

    #[test]
    fn test_get_provider_with_config_unknown_returns_none() {
        assert!(get_provider_with_config("gcp", &section("gcp")).is_none());
    }

    #[test]
    fn test_get_provider_with_config_region_lists() {
        // Blank entries and surrounding whitespace must not break parsing.
        for &name in &["aws", "scaleway"] {
            let sec = config::ProviderSection {
                regions: " eu-west-1, ,us-east-1 ,".to_string(),
                ..section(name)
            };
            let p = get_provider_with_config(name, &sec).expect("region-based provider resolves");
            assert_eq!(p.name(), name);
        }
    }

    #[test]
    fn test_get_provider_with_config_all_providers() {
        for &name in PROVIDER_NAMES {
            let mut sec = section(name);
            if name == "proxmox" {
                sec.url = "https://pve:8006".to_string();
            }
            let p = get_provider_with_config(name, &sec).unwrap_or_else(|| {
                panic!("get_provider_with_config({}) should return Some", name)
            });
            assert_eq!(p.name(), name);
        }
    }

    // =========================================================================
    // provider_display_name
    // =========================================================================

    #[test]
    fn test_display_name_all_providers() {
        assert_eq!(provider_display_name("digitalocean"), "DigitalOcean");
        assert_eq!(provider_display_name("vultr"), "Vultr");
        assert_eq!(provider_display_name("linode"), "Linode");
        assert_eq!(provider_display_name("hetzner"), "Hetzner");
        assert_eq!(provider_display_name("upcloud"), "UpCloud");
        assert_eq!(provider_display_name("proxmox"), "Proxmox VE");
        assert_eq!(provider_display_name("aws"), "AWS EC2");
        assert_eq!(provider_display_name("scaleway"), "Scaleway");
    }

    #[test]
    fn test_display_name_unknown_returns_input() {
        assert_eq!(provider_display_name("gcp"), "gcp");
        assert_eq!(provider_display_name(""), "");
    }

    #[test]
    fn test_every_provider_has_display_name() {
        for &name in PROVIDER_NAMES {
            assert_ne!(provider_display_name(name), name, "missing display name for {}", name);
        }
    }

    // =========================================================================
    // PROVIDER_NAMES
    // =========================================================================

    #[test]
    fn test_provider_names_contains_all_eight() {
        assert_eq!(PROVIDER_NAMES.len(), 8);
        for expected in &[
            "digitalocean",
            "vultr",
            "linode",
            "hetzner",
            "upcloud",
            "proxmox",
            "aws",
            "scaleway",
        ] {
            assert!(PROVIDER_NAMES.contains(expected), "missing {}", expected);
        }
    }

    #[test]
    fn test_provider_names_are_unique() {
        for (i, a) in PROVIDER_NAMES.iter().enumerate() {
            assert!(!PROVIDER_NAMES[i + 1..].contains(a), "duplicate provider name {}", a);
        }
    }

    // =========================================================================
    // ProviderError Display / Debug
    // =========================================================================

    #[test]
    fn test_provider_error_display_http() {
        let err = ProviderError::Http("connection refused".to_string());
        assert_eq!(format!("{}", err), "HTTP error: connection refused");
    }

    #[test]
    fn test_provider_error_display_parse() {
        let err = ProviderError::Parse("invalid JSON".to_string());
        assert_eq!(format!("{}", err), "Failed to parse response: invalid JSON");
    }

    #[test]
    fn test_provider_error_display_auth() {
        assert!(format!("{}", ProviderError::AuthFailed).contains("Authentication failed"));
    }

    #[test]
    fn test_provider_error_display_rate_limited() {
        assert!(format!("{}", ProviderError::RateLimited).contains("Rate limited"));
    }

    #[test]
    fn test_provider_error_display_cancelled() {
        assert_eq!(format!("{}", ProviderError::Cancelled), "Cancelled.");
    }

    #[test]
    fn test_provider_error_display_partial_result() {
        let err = ProviderError::PartialResult {
            hosts: vec![],
            failures: 3,
            total: 10,
        };
        assert!(format!("{}", err).contains("3 of 10 failed"));
    }

    #[test]
    fn test_provider_error_debug_http() {
        let debug = format!("{:?}", ProviderError::Http("timeout".to_string()));
        assert!(debug.contains("Http"));
        assert!(debug.contains("timeout"));
    }

    #[test]
    fn test_provider_error_debug_partial_result() {
        let err = ProviderError::PartialResult {
            hosts: vec![ProviderHost::new(
                "1".to_string(),
                "web".to_string(),
                "1.2.3.4".to_string(),
                vec![],
            )],
            failures: 2,
            total: 5,
        };
        let debug = format!("{:?}", err);
        assert!(debug.contains("PartialResult"));
        assert!(debug.contains("failures: 2"));
    }

    // =========================================================================
    // map_ureq_error
    // =========================================================================

    #[test]
    fn test_map_ureq_error_statuses() {
        assert!(matches!(map_ureq_error(status_error(401)), ProviderError::AuthFailed));
        assert!(matches!(map_ureq_error(status_error(403)), ProviderError::AuthFailed));
        assert!(matches!(map_ureq_error(status_error(429)), ProviderError::RateLimited));
        match map_ureq_error(status_error(500)) {
            ProviderError::Http(msg) => assert_eq!(msg, "HTTP 500"),
            other => panic!("expected Http error, got {:?}", other),
        }
    }

    // =========================================================================
    // ProviderHost
    // =========================================================================

    #[test]
    fn test_provider_host_construction() {
        let host = ProviderHost::new(
            "12345".to_string(),
            "web-01".to_string(),
            "1.2.3.4".to_string(),
            vec!["prod".to_string(), "web".to_string()],
        );
        assert_eq!(host.server_id, "12345");
        assert_eq!(host.name, "web-01");
        assert_eq!(host.ip, "1.2.3.4");
        assert_eq!(host.tags.len(), 2);
        assert!(host.metadata.is_empty());
    }

    #[test]
    fn test_provider_host_clone() {
        let host = ProviderHost::new("1".to_string(), "a".to_string(), "1.1.1.1".to_string(), vec![]);
        let cloned = host.clone();
        assert_eq!(cloned.server_id, host.server_id);
        assert_eq!(cloned.name, host.name);
    }

    #[test]
    fn test_provider_host_empty_fields() {
        let host = ProviderHost::new(String::new(), String::new(), String::new(), vec![]);
        assert!(host.server_id.is_empty());
        assert!(host.name.is_empty());
        assert!(host.ip.is_empty());
    }

    // =========================================================================
    // Provider trait default methods (offline, via RecordingProvider)
    // =========================================================================

    #[test]
    fn test_fetch_hosts_default_passes_uncancelled_flag() {
        let provider = RecordingProvider::new();
        let hosts = provider.fetch_hosts("my-token").expect("fetch should succeed");
        assert_eq!(provider.seen_cancel.get(), Some(false));
        assert_eq!(hosts.len(), 1);
        assert_eq!(hosts[0].name, "my-token");
    }

    #[test]
    fn test_fetch_hosts_with_progress_default_forwards_cancel_and_ignores_progress() {
        let provider = RecordingProvider::new();
        let progress_calls = Cell::new(0usize);
        let progress = |_: &str| progress_calls.set(progress_calls.get() + 1);

        let cancelled = AtomicBool::new(true);
        let result = provider.fetch_hosts_with_progress("tok", &cancelled, &progress);
        assert!(matches!(result, Err(ProviderError::Cancelled)));
        assert_eq!(provider.seen_cancel.get(), Some(true));

        let not_cancelled = AtomicBool::new(false);
        let hosts = provider
            .fetch_hosts_with_progress("tok", &not_cancelled, &progress)
            .expect("uncancelled fetch should succeed");
        assert_eq!(hosts.len(), 1);
        assert_eq!(provider.seen_cancel.get(), Some(false));
        // The default implementation never invokes the progress callback.
        assert_eq!(progress_calls.get(), 0);
    }
}