Skip to main content

purple_ssh/providers/
mod.rs

1pub mod config;
2mod digitalocean;
3mod hetzner;
4mod linode;
5mod proxmox;
6pub mod sync;
7mod upcloud;
8mod vultr;
9
10use std::sync::atomic::AtomicBool;
11
12use thiserror::Error;
13
14/// A host discovered from a cloud provider API.
15#[derive(Debug, Clone)]
16#[allow(dead_code)]
17pub struct ProviderHost {
18    /// Provider-assigned server ID.
19    pub server_id: String,
20    /// Server name/label.
21    pub name: String,
22    /// Public IP address (IPv4 or IPv6).
23    pub ip: String,
24    /// Provider tags/labels.
25    pub tags: Vec<String>,
26}
27
28/// Errors from provider API calls.
29#[derive(Debug, Error)]
30pub enum ProviderError {
31    #[error("HTTP error: {0}")]
32    Http(String),
33    #[error("Failed to parse response: {0}")]
34    Parse(String),
35    #[error("Authentication failed. Check your API token.")]
36    AuthFailed,
37    #[error("Rate limited. Try again in a moment.")]
38    RateLimited,
39    #[error("Cancelled.")]
40    Cancelled,
41    /// Some hosts were fetched but others failed. The caller should use the
42    /// hosts but suppress destructive operations like --remove.
43    #[error("Partial result: {failures} of {total} failed")]
44    PartialResult {
45        hosts: Vec<ProviderHost>,
46        failures: usize,
47        total: usize,
48    },
49}
50
51/// Trait implemented by each cloud provider.
52pub trait Provider {
53    /// Full provider name (e.g. "digitalocean").
54    fn name(&self) -> &str;
55    /// Short label for aliases (e.g. "do").
56    fn short_label(&self) -> &str;
57    /// Fetch hosts with cancellation support.
58    fn fetch_hosts_cancellable(
59        &self,
60        token: &str,
61        cancel: &AtomicBool,
62    ) -> Result<Vec<ProviderHost>, ProviderError>;
63    /// Fetch all servers from the provider API.
64    #[allow(dead_code)]
65    fn fetch_hosts(&self, token: &str) -> Result<Vec<ProviderHost>, ProviderError> {
66        self.fetch_hosts_cancellable(token, &AtomicBool::new(false))
67    }
68    /// Fetch hosts with progress reporting. Default delegates to fetch_hosts_cancellable.
69    fn fetch_hosts_with_progress(
70        &self,
71        token: &str,
72        cancel: &AtomicBool,
73        _progress: &dyn Fn(&str),
74    ) -> Result<Vec<ProviderHost>, ProviderError> {
75        self.fetch_hosts_cancellable(token, cancel)
76    }
77}
78
79/// All known provider names.
80pub const PROVIDER_NAMES: &[&str] = &["digitalocean", "vultr", "linode", "hetzner", "upcloud", "proxmox"];
81
82/// Get a provider implementation by name.
83pub fn get_provider(name: &str) -> Option<Box<dyn Provider>> {
84    match name {
85        "digitalocean" => Some(Box::new(digitalocean::DigitalOcean)),
86        "vultr" => Some(Box::new(vultr::Vultr)),
87        "linode" => Some(Box::new(linode::Linode)),
88        "hetzner" => Some(Box::new(hetzner::Hetzner)),
89        "upcloud" => Some(Box::new(upcloud::UpCloud)),
90        "proxmox" => Some(Box::new(proxmox::Proxmox {
91            base_url: String::new(),
92            verify_tls: true,
93        })),
94        _ => None,
95    }
96}
97
98/// Get a provider implementation configured from a provider section.
99/// For providers that need extra config (e.g. Proxmox base URL), this
100/// creates a properly configured instance.
101pub fn get_provider_with_config(name: &str, section: &config::ProviderSection) -> Option<Box<dyn Provider>> {
102    match name {
103        "proxmox" => Some(Box::new(proxmox::Proxmox {
104            base_url: section.url.clone(),
105            verify_tls: section.verify_tls,
106        })),
107        _ => get_provider(name),
108    }
109}
110
111/// Display name for a provider (e.g. "digitalocean" -> "DigitalOcean").
112pub fn provider_display_name(name: &str) -> &str {
113    match name {
114        "digitalocean" => "DigitalOcean",
115        "vultr" => "Vultr",
116        "linode" => "Linode",
117        "hetzner" => "Hetzner",
118        "upcloud" => "UpCloud",
119        "proxmox" => "Proxmox VE",
120        other => other,
121    }
122}
123
124/// Create an HTTP agent with explicit timeouts.
125pub(crate) fn http_agent() -> ureq::Agent {
126    ureq::AgentBuilder::new()
127        .timeout(std::time::Duration::from_secs(30))
128        .redirects(0)
129        .build()
130}
131
132/// Create an HTTP agent that accepts invalid/self-signed TLS certificates.
133pub(crate) fn http_agent_insecure() -> Result<ureq::Agent, ProviderError> {
134    let tls = ureq::native_tls::TlsConnector::builder()
135        .danger_accept_invalid_certs(true)
136        .danger_accept_invalid_hostnames(true)
137        .build()
138        .map_err(|e| ProviderError::Http(format!("TLS setup failed: {}", e)))?;
139    Ok(ureq::AgentBuilder::new()
140        .timeout(std::time::Duration::from_secs(30))
141        .redirects(0)
142        .tls_connector(std::sync::Arc::new(tls))
143        .build())
144}
145
146/// Strip CIDR suffix (/64, /128, etc.) from an IP address.
147/// Some provider APIs return IPv6 addresses with prefix length (e.g. "2600:3c00::1/128").
148/// SSH requires bare addresses without CIDR notation.
149pub(crate) fn strip_cidr(ip: &str) -> &str {
150    // Only strip if it looks like a CIDR suffix (slash followed by digits)
151    if let Some(pos) = ip.rfind('/') {
152        if ip[pos + 1..].bytes().all(|b| b.is_ascii_digit()) && pos + 1 < ip.len() {
153            return &ip[..pos];
154        }
155    }
156    ip
157}
158
159/// Map a ureq error to a ProviderError.
160fn map_ureq_error(err: ureq::Error) -> ProviderError {
161    match err {
162        ureq::Error::Status(401, _) | ureq::Error::Status(403, _) => ProviderError::AuthFailed,
163        ureq::Error::Status(429, _) => ProviderError::RateLimited,
164        ureq::Error::Status(code, _) => ProviderError::Http(format!("HTTP {}", code)),
165        ureq::Error::Transport(t) => ProviderError::Http(t.to_string()),
166    }
167}
168
169#[cfg(test)]
170mod tests {
171    use super::*;
172
173    // =========================================================================
174    // strip_cidr tests
175    // =========================================================================
176
177    #[test]
178    fn test_strip_cidr_ipv6_with_prefix() {
179        assert_eq!(strip_cidr("2600:3c00::1/128"), "2600:3c00::1");
180        assert_eq!(strip_cidr("2a01:4f8::1/64"), "2a01:4f8::1");
181    }
182
183    #[test]
184    fn test_strip_cidr_bare_ipv6() {
185        assert_eq!(strip_cidr("2600:3c00::1"), "2600:3c00::1");
186    }
187
188    #[test]
189    fn test_strip_cidr_ipv4_passthrough() {
190        assert_eq!(strip_cidr("1.2.3.4"), "1.2.3.4");
191        assert_eq!(strip_cidr("10.0.0.1/24"), "10.0.0.1");
192    }
193
194    #[test]
195    fn test_strip_cidr_empty() {
196        assert_eq!(strip_cidr(""), "");
197    }
198
199    #[test]
200    fn test_strip_cidr_slash_without_digits() {
201        // Shouldn't strip if after slash there are non-digits
202        assert_eq!(strip_cidr("path/to/something"), "path/to/something");
203    }
204
205    #[test]
206    fn test_strip_cidr_trailing_slash() {
207        // Trailing slash with nothing after: pos+1 == ip.len(), should NOT strip
208        assert_eq!(strip_cidr("1.2.3.4/"), "1.2.3.4/");
209    }
210
211    // =========================================================================
212    // get_provider factory tests
213    // =========================================================================
214
215    #[test]
216    fn test_get_provider_digitalocean() {
217        let p = get_provider("digitalocean").unwrap();
218        assert_eq!(p.name(), "digitalocean");
219        assert_eq!(p.short_label(), "do");
220    }
221
222    #[test]
223    fn test_get_provider_vultr() {
224        let p = get_provider("vultr").unwrap();
225        assert_eq!(p.name(), "vultr");
226        assert_eq!(p.short_label(), "vultr");
227    }
228
229    #[test]
230    fn test_get_provider_linode() {
231        let p = get_provider("linode").unwrap();
232        assert_eq!(p.name(), "linode");
233        assert_eq!(p.short_label(), "linode");
234    }
235
236    #[test]
237    fn test_get_provider_hetzner() {
238        let p = get_provider("hetzner").unwrap();
239        assert_eq!(p.name(), "hetzner");
240        assert_eq!(p.short_label(), "hetzner");
241    }
242
243    #[test]
244    fn test_get_provider_upcloud() {
245        let p = get_provider("upcloud").unwrap();
246        assert_eq!(p.name(), "upcloud");
247        assert_eq!(p.short_label(), "uc");
248    }
249
250    #[test]
251    fn test_get_provider_proxmox() {
252        let p = get_provider("proxmox").unwrap();
253        assert_eq!(p.name(), "proxmox");
254        assert_eq!(p.short_label(), "pve");
255    }
256
257    #[test]
258    fn test_get_provider_unknown_returns_none() {
259        assert!(get_provider("aws").is_none());
260        assert!(get_provider("").is_none());
261        assert!(get_provider("DigitalOcean").is_none()); // case-sensitive
262    }
263
264    #[test]
265    fn test_get_provider_all_names_resolve() {
266        for name in PROVIDER_NAMES {
267            assert!(get_provider(name).is_some(), "Provider '{}' should resolve", name);
268        }
269    }
270
271    // =========================================================================
272    // get_provider_with_config tests
273    // =========================================================================
274
275    #[test]
276    fn test_get_provider_with_config_proxmox_uses_url() {
277        let section = config::ProviderSection {
278            provider: "proxmox".to_string(),
279            token: "user@pam!token=secret".to_string(),
280            alias_prefix: "pve-".to_string(),
281            user: String::new(),
282            identity_file: String::new(),
283            url: "https://pve.example.com:8006".to_string(),
284            verify_tls: false,
285            auto_sync: false,
286        };
287        let p = get_provider_with_config("proxmox", &section).unwrap();
288        assert_eq!(p.name(), "proxmox");
289    }
290
291    #[test]
292    fn test_get_provider_with_config_non_proxmox_delegates() {
293        let section = config::ProviderSection {
294            provider: "digitalocean".to_string(),
295            token: "do-token".to_string(),
296            alias_prefix: "do-".to_string(),
297            user: String::new(),
298            identity_file: String::new(),
299            url: String::new(),
300            verify_tls: true,
301            auto_sync: true,
302        };
303        let p = get_provider_with_config("digitalocean", &section).unwrap();
304        assert_eq!(p.name(), "digitalocean");
305    }
306
307    #[test]
308    fn test_get_provider_with_config_unknown_returns_none() {
309        let section = config::ProviderSection {
310            provider: "aws".to_string(),
311            token: String::new(),
312            alias_prefix: String::new(),
313            user: String::new(),
314            identity_file: String::new(),
315            url: String::new(),
316            verify_tls: true,
317            auto_sync: true,
318        };
319        assert!(get_provider_with_config("aws", &section).is_none());
320    }
321
322    // =========================================================================
323    // provider_display_name tests
324    // =========================================================================
325
326    #[test]
327    fn test_display_name_all_providers() {
328        assert_eq!(provider_display_name("digitalocean"), "DigitalOcean");
329        assert_eq!(provider_display_name("vultr"), "Vultr");
330        assert_eq!(provider_display_name("linode"), "Linode");
331        assert_eq!(provider_display_name("hetzner"), "Hetzner");
332        assert_eq!(provider_display_name("upcloud"), "UpCloud");
333        assert_eq!(provider_display_name("proxmox"), "Proxmox VE");
334    }
335
336    #[test]
337    fn test_display_name_unknown_returns_input() {
338        assert_eq!(provider_display_name("aws"), "aws");
339        assert_eq!(provider_display_name(""), "");
340    }
341
342    // =========================================================================
343    // PROVIDER_NAMES constant tests
344    // =========================================================================
345
346    #[test]
347    fn test_provider_names_count() {
348        assert_eq!(PROVIDER_NAMES.len(), 6);
349    }
350
351    #[test]
352    fn test_provider_names_contains_all() {
353        assert!(PROVIDER_NAMES.contains(&"digitalocean"));
354        assert!(PROVIDER_NAMES.contains(&"vultr"));
355        assert!(PROVIDER_NAMES.contains(&"linode"));
356        assert!(PROVIDER_NAMES.contains(&"hetzner"));
357        assert!(PROVIDER_NAMES.contains(&"upcloud"));
358        assert!(PROVIDER_NAMES.contains(&"proxmox"));
359    }
360
361    // =========================================================================
362    // ProviderError display tests
363    // =========================================================================
364
365    #[test]
366    fn test_provider_error_display_http() {
367        let err = ProviderError::Http("connection refused".to_string());
368        assert_eq!(format!("{}", err), "HTTP error: connection refused");
369    }
370
371    #[test]
372    fn test_provider_error_display_parse() {
373        let err = ProviderError::Parse("invalid JSON".to_string());
374        assert_eq!(format!("{}", err), "Failed to parse response: invalid JSON");
375    }
376
377    #[test]
378    fn test_provider_error_display_auth() {
379        let err = ProviderError::AuthFailed;
380        assert!(format!("{}", err).contains("Authentication failed"));
381    }
382
383    #[test]
384    fn test_provider_error_display_rate_limited() {
385        let err = ProviderError::RateLimited;
386        assert!(format!("{}", err).contains("Rate limited"));
387    }
388
389    #[test]
390    fn test_provider_error_display_cancelled() {
391        let err = ProviderError::Cancelled;
392        assert_eq!(format!("{}", err), "Cancelled.");
393    }
394
395    #[test]
396    fn test_provider_error_display_partial_result() {
397        let err = ProviderError::PartialResult {
398            hosts: vec![],
399            failures: 3,
400            total: 10,
401        };
402        assert!(format!("{}", err).contains("3 of 10 failed"));
403    }
404
405    // =========================================================================
406    // ProviderHost struct tests
407    // =========================================================================
408
409    #[test]
410    fn test_provider_host_construction() {
411        let host = ProviderHost {
412            server_id: "12345".to_string(),
413            name: "web-01".to_string(),
414            ip: "1.2.3.4".to_string(),
415            tags: vec!["prod".to_string(), "web".to_string()],
416        };
417        assert_eq!(host.server_id, "12345");
418        assert_eq!(host.name, "web-01");
419        assert_eq!(host.ip, "1.2.3.4");
420        assert_eq!(host.tags.len(), 2);
421    }
422
423    #[test]
424    fn test_provider_host_clone() {
425        let host = ProviderHost {
426            server_id: "1".to_string(),
427            name: "a".to_string(),
428            ip: "1.1.1.1".to_string(),
429            tags: vec![],
430        };
431        let cloned = host.clone();
432        assert_eq!(cloned.server_id, host.server_id);
433        assert_eq!(cloned.name, host.name);
434    }
435
436    // =========================================================================
437    // strip_cidr additional edge cases
438    // =========================================================================
439
440    #[test]
441    fn test_strip_cidr_ipv6_with_64() {
442        assert_eq!(strip_cidr("2a01:4f8::1/64"), "2a01:4f8::1");
443    }
444
445    #[test]
446    fn test_strip_cidr_ipv4_with_32() {
447        assert_eq!(strip_cidr("1.2.3.4/32"), "1.2.3.4");
448    }
449
450    #[test]
451    fn test_strip_cidr_ipv4_with_8() {
452        assert_eq!(strip_cidr("10.0.0.1/8"), "10.0.0.1");
453    }
454
455    #[test]
456    fn test_strip_cidr_just_slash() {
457        // "/" alone: pos=0, pos+1=1=len -> condition fails
458        assert_eq!(strip_cidr("/"), "/");
459    }
460
461    #[test]
462    fn test_strip_cidr_slash_with_letters() {
463        assert_eq!(strip_cidr("10.0.0.1/abc"), "10.0.0.1/abc");
464    }
465
466    #[test]
467    fn test_strip_cidr_multiple_slashes() {
468        // rfind gets last slash: "48" is digits, so it strips the last /48
469        assert_eq!(strip_cidr("10.0.0.1/24/48"), "10.0.0.1/24");
470    }
471
472    #[test]
473    fn test_strip_cidr_ipv6_full_notation() {
474        assert_eq!(
475            strip_cidr("2001:0db8:85a3:0000:0000:8a2e:0370:7334/128"),
476            "2001:0db8:85a3:0000:0000:8a2e:0370:7334"
477        );
478    }
479
480    // =========================================================================
481    // ProviderError Debug
482    // =========================================================================
483
484    #[test]
485    fn test_provider_error_debug_http() {
486        let err = ProviderError::Http("timeout".to_string());
487        let debug = format!("{:?}", err);
488        assert!(debug.contains("Http"));
489        assert!(debug.contains("timeout"));
490    }
491
492    #[test]
493    fn test_provider_error_debug_partial_result() {
494        let err = ProviderError::PartialResult {
495            hosts: vec![ProviderHost {
496                server_id: "1".to_string(),
497                name: "web".to_string(),
498                ip: "1.2.3.4".to_string(),
499                tags: vec![],
500            }],
501            failures: 2,
502            total: 5,
503        };
504        let debug = format!("{:?}", err);
505        assert!(debug.contains("PartialResult"));
506        assert!(debug.contains("failures: 2"));
507    }
508
509    // =========================================================================
510    // ProviderHost with empty fields
511    // =========================================================================
512
513    #[test]
514    fn test_provider_host_empty_fields() {
515        let host = ProviderHost {
516            server_id: String::new(),
517            name: String::new(),
518            ip: String::new(),
519            tags: vec![],
520        };
521        assert!(host.server_id.is_empty());
522        assert!(host.name.is_empty());
523        assert!(host.ip.is_empty());
524    }
525
526    // =========================================================================
527    // get_provider_with_config for all non-proxmox providers
528    // =========================================================================
529
530    #[test]
531    fn test_get_provider_with_config_all_providers() {
532        for &name in PROVIDER_NAMES {
533            let section = config::ProviderSection {
534                provider: name.to_string(),
535                token: "tok".to_string(),
536                alias_prefix: "test".to_string(),
537                user: String::new(),
538                identity_file: String::new(),
539                url: if name == "proxmox" {
540                    "https://pve:8006".to_string()
541                } else {
542                    String::new()
543                },
544                verify_tls: true,
545                auto_sync: true,
546            };
547            let p = get_provider_with_config(name, &section);
548            assert!(p.is_some(), "get_provider_with_config({}) should return Some", name);
549            assert_eq!(p.unwrap().name(), name);
550        }
551    }
552
553    // =========================================================================
554    // Provider trait default methods
555    // =========================================================================
556
557    #[test]
558    fn test_provider_fetch_hosts_delegates_to_cancellable() {
559        let provider = get_provider("digitalocean").unwrap();
560        // fetch_hosts delegates to fetch_hosts_cancellable with AtomicBool(false)
561        // We can't actually test this without a server, but we verify the method exists
562        // by calling it (will fail with network error, which is fine for this test)
563        let result = provider.fetch_hosts("fake-token");
564        assert!(result.is_err()); // Expected: no network
565    }
566
567    // =========================================================================
568    // strip_cidr: suffix starts with digit but contains letters
569    // =========================================================================
570
571    #[test]
572    fn test_strip_cidr_digit_then_letters_not_stripped() {
573        assert_eq!(strip_cidr("10.0.0.1/24abc"), "10.0.0.1/24abc");
574    }
575
576    // =========================================================================
577    // provider_display_name: all known providers
578    // =========================================================================
579
580    #[test]
581    fn test_provider_display_name_all() {
582        assert_eq!(provider_display_name("digitalocean"), "DigitalOcean");
583        assert_eq!(provider_display_name("vultr"), "Vultr");
584        assert_eq!(provider_display_name("linode"), "Linode");
585        assert_eq!(provider_display_name("hetzner"), "Hetzner");
586        assert_eq!(provider_display_name("upcloud"), "UpCloud");
587        assert_eq!(provider_display_name("proxmox"), "Proxmox VE");
588    }
589
590    #[test]
591    fn test_provider_display_name_unknown() {
592        assert_eq!(provider_display_name("foobar"), "foobar");
593    }
594
595    // =========================================================================
596    // get_provider: all known + unknown
597    // =========================================================================
598
599    #[test]
600    fn test_get_provider_all_known() {
601        for name in PROVIDER_NAMES {
602            assert!(get_provider(name).is_some(), "get_provider({}) should return Some", name);
603        }
604    }
605
606    #[test]
607    fn test_get_provider_case_sensitive_and_unknown() {
608        assert!(get_provider("aws").is_none());
609        assert!(get_provider("DigitalOcean").is_none()); // Case-sensitive
610        assert!(get_provider("VULTR").is_none());
611        assert!(get_provider("").is_none());
612    }
613
614    // =========================================================================
615    // PROVIDER_NAMES constant
616    // =========================================================================
617
618    #[test]
619    fn test_provider_names_has_all_six() {
620        assert_eq!(PROVIDER_NAMES.len(), 6);
621        assert!(PROVIDER_NAMES.contains(&"digitalocean"));
622        assert!(PROVIDER_NAMES.contains(&"proxmox"));
623    }
624
625    // =========================================================================
626    // Provider short_label via get_provider
627    // =========================================================================
628
629    #[test]
630    fn test_provider_short_labels() {
631        let cases = [
632            ("digitalocean", "do"),
633            ("vultr", "vultr"),
634            ("linode", "linode"),
635            ("hetzner", "hetzner"),
636            ("upcloud", "uc"),
637            ("proxmox", "pve"),
638        ];
639        for (name, expected_label) in &cases {
640            let p = get_provider(name).unwrap();
641            assert_eq!(p.short_label(), *expected_label, "short_label for {}", name);
642        }
643    }
644}