// purple_ssh/providers/mod.rs

1pub mod config;
2mod digitalocean;
3mod hetzner;
4mod linode;
5mod proxmox;
6pub mod sync;
7mod upcloud;
8mod vultr;
9
10use std::sync::atomic::AtomicBool;
11
12use thiserror::Error;
13
/// One server as reported by a cloud provider's API.
///
/// All fields are plain owned strings so the value can be cloned freely and
/// handed across threads or stored in sync results.
#[derive(Debug, Clone)]
// NOTE(review): the reason for silencing dead_code is not visible in this
// file — confirm it is still needed.
#[allow(dead_code)]
pub struct ProviderHost {
    /// Identifier the provider assigned to the server (opaque string).
    pub server_id: String,
    /// Human-readable server name or label.
    pub name: String,
    /// Public address used for SSH (IPv4 or IPv6).
    pub ip: String,
    /// Tags/labels attached to the server on the provider side.
    pub tags: Vec<String>,
    /// Extra provider details (region, plan, ...) as ordered key/value pairs.
    pub metadata: Vec<(String, String)>,
}

impl ProviderHost {
    /// Builds a host record whose `metadata` list starts out empty.
    #[allow(dead_code)]
    pub fn new(server_id: String, name: String, ip: String, tags: Vec<String>) -> Self {
        let metadata = Vec::new();
        Self {
            server_id,
            name,
            ip,
            tags,
            metadata,
        }
    }
}
43
/// Errors from provider API calls.
///
/// Display messages come from the `#[error(...)]` attributes (via `thiserror`).
#[derive(Debug, Error)]
pub enum ProviderError {
    /// Generic HTTP or transport failure carrying a description.
    /// `map_ureq_error` produces this for transport errors and for any status
    /// other than 401/403/429; `http_agent_insecure` uses it for TLS setup errors.
    #[error("HTTP error: {0}")]
    Http(String),
    /// A response body could not be decoded.
    /// NOTE(review): not constructed in this file — presumably raised by the
    /// per-provider modules; confirm.
    #[error("Failed to parse response: {0}")]
    Parse(String),
    /// The provider rejected the credentials (mapped from HTTP 401 and 403).
    #[error("Authentication failed. Check your API token.")]
    AuthFailed,
    /// The provider throttled the request (mapped from HTTP 429).
    #[error("Rate limited. Try again in a moment.")]
    RateLimited,
    /// The fetch was aborted.
    /// NOTE(review): not constructed in this file — presumably returned when
    /// the caller's cancel flag is set; confirm in the provider modules.
    #[error("Cancelled.")]
    Cancelled,
    /// Some hosts were fetched but others failed. The caller should use the
    /// hosts but suppress destructive operations like --remove.
    #[error("Partial result: {failures} of {total} failed")]
    PartialResult {
        /// Hosts that were fetched successfully.
        hosts: Vec<ProviderHost>,
        /// How many units failed (what a "unit" is depends on the provider — TODO confirm).
        failures: usize,
        /// How many units were attempted in total.
        total: usize,
    },
}
66
67/// Trait implemented by each cloud provider.
68pub trait Provider {
69    /// Full provider name (e.g. "digitalocean").
70    fn name(&self) -> &str;
71    /// Short label for aliases (e.g. "do").
72    fn short_label(&self) -> &str;
73    /// Fetch hosts with cancellation support.
74    fn fetch_hosts_cancellable(
75        &self,
76        token: &str,
77        cancel: &AtomicBool,
78    ) -> Result<Vec<ProviderHost>, ProviderError>;
79    /// Fetch all servers from the provider API.
80    #[allow(dead_code)]
81    fn fetch_hosts(&self, token: &str) -> Result<Vec<ProviderHost>, ProviderError> {
82        self.fetch_hosts_cancellable(token, &AtomicBool::new(false))
83    }
84    /// Fetch hosts with progress reporting. Default delegates to fetch_hosts_cancellable.
85    fn fetch_hosts_with_progress(
86        &self,
87        token: &str,
88        cancel: &AtomicBool,
89        _progress: &dyn Fn(&str),
90    ) -> Result<Vec<ProviderHost>, ProviderError> {
91        self.fetch_hosts_cancellable(token, cancel)
92    }
93}
94
/// Canonical names of every built-in provider, in display order.
///
/// Each entry is accepted by `get_provider` and has a `provider_display_name`.
pub const PROVIDER_NAMES: &[&str] = &[
    "digitalocean",
    "vultr",
    "linode",
    "hetzner",
    "upcloud",
    "proxmox",
];
97
98/// Get a provider implementation by name.
99pub fn get_provider(name: &str) -> Option<Box<dyn Provider>> {
100    match name {
101        "digitalocean" => Some(Box::new(digitalocean::DigitalOcean)),
102        "vultr" => Some(Box::new(vultr::Vultr)),
103        "linode" => Some(Box::new(linode::Linode)),
104        "hetzner" => Some(Box::new(hetzner::Hetzner)),
105        "upcloud" => Some(Box::new(upcloud::UpCloud)),
106        "proxmox" => Some(Box::new(proxmox::Proxmox {
107            base_url: String::new(),
108            verify_tls: true,
109        })),
110        _ => None,
111    }
112}
113
114/// Get a provider implementation configured from a provider section.
115/// For providers that need extra config (e.g. Proxmox base URL), this
116/// creates a properly configured instance.
117pub fn get_provider_with_config(name: &str, section: &config::ProviderSection) -> Option<Box<dyn Provider>> {
118    match name {
119        "proxmox" => Some(Box::new(proxmox::Proxmox {
120            base_url: section.url.clone(),
121            verify_tls: section.verify_tls,
122        })),
123        _ => get_provider(name),
124    }
125}
126
/// Human-friendly name for a provider, e.g. `"digitalocean"` -> `"DigitalOcean"`.
///
/// Unknown names are returned unchanged.
pub fn provider_display_name(name: &str) -> &str {
    const DISPLAY_NAMES: &[(&str, &str)] = &[
        ("digitalocean", "DigitalOcean"),
        ("vultr", "Vultr"),
        ("linode", "Linode"),
        ("hetzner", "Hetzner"),
        ("upcloud", "UpCloud"),
        ("proxmox", "Proxmox VE"),
    ];
    DISPLAY_NAMES
        .iter()
        .find(|&&(key, _)| key == name)
        .map_or(name, |&(_, display)| display)
}
139
140/// Create an HTTP agent with explicit timeouts.
141pub(crate) fn http_agent() -> ureq::Agent {
142    ureq::AgentBuilder::new()
143        .timeout(std::time::Duration::from_secs(30))
144        .redirects(0)
145        .build()
146}
147
148/// Create an HTTP agent that accepts invalid/self-signed TLS certificates.
149pub(crate) fn http_agent_insecure() -> Result<ureq::Agent, ProviderError> {
150    let tls = ureq::native_tls::TlsConnector::builder()
151        .danger_accept_invalid_certs(true)
152        .danger_accept_invalid_hostnames(true)
153        .build()
154        .map_err(|e| ProviderError::Http(format!("TLS setup failed: {}", e)))?;
155    Ok(ureq::AgentBuilder::new()
156        .timeout(std::time::Duration::from_secs(30))
157        .redirects(0)
158        .tls_connector(std::sync::Arc::new(tls))
159        .build())
160}
161
/// Removes a trailing CIDR prefix length (`/64`, `/128`, ...) from an address.
///
/// Some provider APIs report addresses such as `"2600:3c00::1/128"`, but SSH
/// needs the bare address. Only the text after the *last* `/` is considered.
/// It is stripped when it is a non-empty run of ASCII digits; the numeric range
/// is not validated. Anything else is returned untouched.
pub(crate) fn strip_cidr(ip: &str) -> &str {
    match ip.rsplit_once('/') {
        Some((addr, prefix_len))
            if !prefix_len.is_empty() && prefix_len.bytes().all(|b| b.is_ascii_digit()) =>
        {
            addr
        }
        _ => ip,
    }
}
174
175/// Map a ureq error to a ProviderError.
176fn map_ureq_error(err: ureq::Error) -> ProviderError {
177    match err {
178        ureq::Error::Status(401, _) | ureq::Error::Status(403, _) => ProviderError::AuthFailed,
179        ureq::Error::Status(429, _) => ProviderError::RateLimited,
180        ureq::Error::Status(code, _) => ProviderError::Http(format!("HTTP {}", code)),
181        ureq::Error::Transport(t) => ProviderError::Http(t.to_string()),
182    }
183}
184
#[cfg(test)]
mod tests {
    use super::*;
    use std::cell::Cell;
    use std::sync::atomic::{AtomicBool, Ordering};

    // =========================================================================
    // Helpers
    // =========================================================================

    /// Every built-in provider (in `PROVIDER_NAMES` order) with its expected short label.
    const EXPECTED_LABELS: &[(&str, &str)] = &[
        ("digitalocean", "do"),
        ("vultr", "vultr"),
        ("linode", "linode"),
        ("hetzner", "hetzner"),
        ("upcloud", "uc"),
        ("proxmox", "pve"),
    ];

    /// Builds a `ProviderSection` for `provider`. Only `url` and `verify_tls`
    /// are read by `get_provider_with_config`; the other fields are filler.
    fn section(provider: &str, url: &str, verify_tls: bool) -> config::ProviderSection {
        config::ProviderSection {
            provider: provider.to_string(),
            token: "tok".to_string(),
            alias_prefix: format!("{}-", provider),
            user: String::new(),
            identity_file: String::new(),
            url: url.to_string(),
            verify_tls,
            auto_sync: true,
        }
    }

    /// In-memory provider used to exercise the trait's default methods without
    /// any network access. It returns one host named after the token, or
    /// `Cancelled` when the flag is set.
    struct FakeProvider;

    impl Provider for FakeProvider {
        fn name(&self) -> &str {
            "fake"
        }

        fn short_label(&self) -> &str {
            "fake"
        }

        fn fetch_hosts_cancellable(
            &self,
            token: &str,
            cancel: &AtomicBool,
        ) -> Result<Vec<ProviderHost>, ProviderError> {
            if cancel.load(Ordering::SeqCst) {
                return Err(ProviderError::Cancelled);
            }
            Ok(vec![ProviderHost::new(
                "1".to_string(),
                token.to_string(),
                "1.2.3.4".to_string(),
                vec![],
            )])
        }
    }

    // =========================================================================
    // strip_cidr tests
    // =========================================================================

    #[test]
    fn test_strip_cidr_ipv6_with_prefix() {
        assert_eq!(strip_cidr("2600:3c00::1/128"), "2600:3c00::1");
        assert_eq!(strip_cidr("2a01:4f8::1/64"), "2a01:4f8::1");
    }

    #[test]
    fn test_strip_cidr_bare_ipv6() {
        assert_eq!(strip_cidr("2600:3c00::1"), "2600:3c00::1");
    }

    #[test]
    fn test_strip_cidr_ipv6_full_notation() {
        assert_eq!(
            strip_cidr("2001:0db8:85a3:0000:0000:8a2e:0370:7334/128"),
            "2001:0db8:85a3:0000:0000:8a2e:0370:7334"
        );
    }

    #[test]
    fn test_strip_cidr_ipv4_passthrough() {
        assert_eq!(strip_cidr("1.2.3.4"), "1.2.3.4");
        assert_eq!(strip_cidr("10.0.0.1/24"), "10.0.0.1");
    }

    #[test]
    fn test_strip_cidr_ipv4_with_32() {
        assert_eq!(strip_cidr("1.2.3.4/32"), "1.2.3.4");
    }

    #[test]
    fn test_strip_cidr_ipv4_with_8() {
        assert_eq!(strip_cidr("10.0.0.1/8"), "10.0.0.1");
    }

    #[test]
    fn test_strip_cidr_empty() {
        assert_eq!(strip_cidr(""), "");
    }

    #[test]
    fn test_strip_cidr_slash_without_digits() {
        // Shouldn't strip if after slash there are non-digits
        assert_eq!(strip_cidr("path/to/something"), "path/to/something");
    }

    #[test]
    fn test_strip_cidr_slash_with_letters() {
        assert_eq!(strip_cidr("10.0.0.1/abc"), "10.0.0.1/abc");
    }

    #[test]
    fn test_strip_cidr_digit_then_letters_not_stripped() {
        assert_eq!(strip_cidr("10.0.0.1/24abc"), "10.0.0.1/24abc");
    }

    #[test]
    fn test_strip_cidr_trailing_slash() {
        // Trailing slash with nothing after it must not be stripped.
        assert_eq!(strip_cidr("1.2.3.4/"), "1.2.3.4/");
    }

    #[test]
    fn test_strip_cidr_just_slash() {
        assert_eq!(strip_cidr("/"), "/");
    }

    #[test]
    fn test_strip_cidr_multiple_slashes() {
        // Only the suffix after the last slash is considered.
        assert_eq!(strip_cidr("10.0.0.1/24/48"), "10.0.0.1/24");
    }

    // =========================================================================
    // get_provider factory tests
    // =========================================================================

    #[test]
    fn test_get_provider_name_and_short_label() {
        for &(name, label) in EXPECTED_LABELS {
            let p = get_provider(name)
                .unwrap_or_else(|| panic!("get_provider({}) returned None", name));
            assert_eq!(p.name(), name, "name() for {}", name);
            assert_eq!(p.short_label(), label, "short_label() for {}", name);
        }
    }

    #[test]
    fn test_get_provider_all_names_resolve() {
        for name in PROVIDER_NAMES {
            assert!(get_provider(name).is_some(), "Provider '{}' should resolve", name);
        }
    }

    #[test]
    fn test_get_provider_unknown_returns_none() {
        assert!(get_provider("aws").is_none());
        assert!(get_provider("").is_none());
        // Lookup is case-sensitive.
        assert!(get_provider("DigitalOcean").is_none());
        assert!(get_provider("VULTR").is_none());
    }

    // =========================================================================
    // get_provider_with_config tests
    // =========================================================================

    #[test]
    fn test_get_provider_with_config_proxmox() {
        // `base_url`/`verify_tls` are not observable through `dyn Provider`, so
        // this only checks that a Proxmox backend comes back from the config path.
        let s = section("proxmox", "https://pve.example.com:8006", false);
        let p = get_provider_with_config("proxmox", &s).unwrap();
        assert_eq!(p.name(), "proxmox");
        assert_eq!(p.short_label(), "pve");
    }

    #[test]
    fn test_get_provider_with_config_non_proxmox_delegates() {
        let s = section("digitalocean", "", true);
        let p = get_provider_with_config("digitalocean", &s).unwrap();
        assert_eq!(p.name(), "digitalocean");
    }

    #[test]
    fn test_get_provider_with_config_unknown_returns_none() {
        let s = section("aws", "", true);
        assert!(get_provider_with_config("aws", &s).is_none());
    }

    #[test]
    fn test_get_provider_with_config_all_providers() {
        for &name in PROVIDER_NAMES {
            let url = if name == "proxmox" { "https://pve:8006" } else { "" };
            let p = get_provider_with_config(name, &section(name, url, true));
            assert!(p.is_some(), "get_provider_with_config({}) should return Some", name);
            assert_eq!(p.unwrap().name(), name);
        }
    }

    // =========================================================================
    // provider_display_name tests
    // =========================================================================

    #[test]
    fn test_display_name_all_providers() {
        assert_eq!(provider_display_name("digitalocean"), "DigitalOcean");
        assert_eq!(provider_display_name("vultr"), "Vultr");
        assert_eq!(provider_display_name("linode"), "Linode");
        assert_eq!(provider_display_name("hetzner"), "Hetzner");
        assert_eq!(provider_display_name("upcloud"), "UpCloud");
        assert_eq!(provider_display_name("proxmox"), "Proxmox VE");
    }

    #[test]
    fn test_display_name_unknown_returns_input() {
        assert_eq!(provider_display_name("aws"), "aws");
        assert_eq!(provider_display_name("foobar"), "foobar");
        assert_eq!(provider_display_name(""), "");
    }

    #[test]
    fn test_display_name_every_known_provider_has_entry() {
        // Unknown names fall through unchanged, so an identical result means
        // the display-name table is missing this provider.
        for &name in PROVIDER_NAMES {
            assert_ne!(provider_display_name(name), name, "missing display name for {}", name);
        }
    }

    // =========================================================================
    // PROVIDER_NAMES constant tests
    // =========================================================================

    #[test]
    fn test_provider_names_contains_all() {
        assert_eq!(PROVIDER_NAMES.len(), 6);
        assert!(PROVIDER_NAMES.contains(&"digitalocean"));
        assert!(PROVIDER_NAMES.contains(&"vultr"));
        assert!(PROVIDER_NAMES.contains(&"linode"));
        assert!(PROVIDER_NAMES.contains(&"hetzner"));
        assert!(PROVIDER_NAMES.contains(&"upcloud"));
        assert!(PROVIDER_NAMES.contains(&"proxmox"));
    }

    #[test]
    fn test_provider_names_match_expected_labels_table() {
        // Keeps the test table in sync with the public list (same names, same order).
        let labelled: Vec<&str> = EXPECTED_LABELS.iter().map(|&(name, _)| name).collect();
        assert_eq!(labelled, PROVIDER_NAMES.to_vec());
    }

    // =========================================================================
    // ProviderError display tests
    // =========================================================================

    #[test]
    fn test_provider_error_display_http() {
        let err = ProviderError::Http("connection refused".to_string());
        assert_eq!(format!("{}", err), "HTTP error: connection refused");
    }

    #[test]
    fn test_provider_error_display_parse() {
        let err = ProviderError::Parse("invalid JSON".to_string());
        assert_eq!(format!("{}", err), "Failed to parse response: invalid JSON");
    }

    #[test]
    fn test_provider_error_display_auth() {
        let err = ProviderError::AuthFailed;
        assert!(format!("{}", err).contains("Authentication failed"));
    }

    #[test]
    fn test_provider_error_display_rate_limited() {
        let err = ProviderError::RateLimited;
        assert!(format!("{}", err).contains("Rate limited"));
    }

    #[test]
    fn test_provider_error_display_cancelled() {
        let err = ProviderError::Cancelled;
        assert_eq!(format!("{}", err), "Cancelled.");
    }

    #[test]
    fn test_provider_error_display_partial_result() {
        let err = ProviderError::PartialResult {
            hosts: vec![],
            failures: 3,
            total: 10,
        };
        assert!(format!("{}", err).contains("3 of 10 failed"));
    }

    // =========================================================================
    // ProviderError Debug
    // =========================================================================

    #[test]
    fn test_provider_error_debug_http() {
        let err = ProviderError::Http("timeout".to_string());
        let debug = format!("{:?}", err);
        assert!(debug.contains("Http"));
        assert!(debug.contains("timeout"));
    }

    #[test]
    fn test_provider_error_debug_partial_result() {
        let err = ProviderError::PartialResult {
            hosts: vec![ProviderHost::new(
                "1".to_string(),
                "web".to_string(),
                "1.2.3.4".to_string(),
                vec![],
            )],
            failures: 2,
            total: 5,
        };
        let debug = format!("{:?}", err);
        assert!(debug.contains("PartialResult"));
        assert!(debug.contains("failures: 2"));
    }

    // =========================================================================
    // ProviderHost struct tests
    // =========================================================================

    #[test]
    fn test_provider_host_construction() {
        let host = ProviderHost::new(
            "12345".to_string(),
            "web-01".to_string(),
            "1.2.3.4".to_string(),
            vec!["prod".to_string(), "web".to_string()],
        );
        assert_eq!(host.server_id, "12345");
        assert_eq!(host.name, "web-01");
        assert_eq!(host.ip, "1.2.3.4");
        assert_eq!(host.tags.len(), 2);
        assert!(host.metadata.is_empty());
    }

    #[test]
    fn test_provider_host_clone() {
        let host = ProviderHost::new("1".to_string(), "a".to_string(), "1.1.1.1".to_string(), vec![]);
        let cloned = host.clone();
        assert_eq!(cloned.server_id, host.server_id);
        assert_eq!(cloned.name, host.name);
    }

    #[test]
    fn test_provider_host_empty_fields() {
        let host = ProviderHost::new(String::new(), String::new(), String::new(), vec![]);
        assert!(host.server_id.is_empty());
        assert!(host.name.is_empty());
        assert!(host.ip.is_empty());
    }

    // =========================================================================
    // Provider trait default methods (no network: uses FakeProvider)
    // =========================================================================

    #[test]
    fn test_provider_fetch_hosts_delegates_to_cancellable() {
        let hosts = FakeProvider.fetch_hosts("tok").unwrap();
        assert_eq!(hosts.len(), 1);
        assert_eq!(hosts[0].name, "tok");
    }

    #[test]
    fn test_provider_fetch_hosts_with_progress_delegates_and_ignores_progress() {
        let calls = Cell::new(0usize);
        let progress = |_: &str| calls.set(calls.get() + 1);

        let running = AtomicBool::new(false);
        let hosts = FakeProvider
            .fetch_hosts_with_progress("tok", &running, &progress)
            .unwrap();
        assert_eq!(hosts[0].name, "tok");

        // The cancel flag must be forwarded unchanged.
        let cancelled = AtomicBool::new(true);
        let result = FakeProvider.fetch_hosts_with_progress("tok", &cancelled, &progress);
        assert!(matches!(result, Err(ProviderError::Cancelled)));

        // The default implementation never reports progress.
        assert_eq!(calls.get(), 0);
    }
}