Skip to main content

purple_ssh/providers/
mod.rs

1pub mod aws;
2pub mod azure;
3pub mod config;
4mod digitalocean;
5pub mod gcp;
6mod hetzner;
7mod i3d;
8mod leaseweb;
9mod linode;
10pub mod oracle;
11pub mod ovh;
12mod proxmox;
13pub mod scaleway;
14pub mod sync;
15mod tailscale;
16mod transip;
17mod upcloud;
18mod vultr;
19
20use std::sync::atomic::AtomicBool;
21
22use log::{error, warn};
23use thiserror::Error;
24
/// One server/VM reported by a cloud provider's inventory API.
#[derive(Debug, Clone)]
pub struct ProviderHost {
    /// Identifier the provider assigned to the server.
    pub server_id: String,
    /// Server name or label.
    pub name: String,
    /// Public IP address (IPv4 or IPv6).
    pub ip: String,
    /// Tags/labels attached on the provider side.
    pub tags: Vec<String>,
    /// Extra provider metadata (region, plan, etc.) as ordered key/value pairs.
    pub metadata: Vec<(String, String)>,
}

impl ProviderHost {
    /// Build a host whose `metadata` list starts out empty.
    // Silences `dead_code` in case no caller in the crate uses this helper.
    #[allow(dead_code)]
    pub fn new(server_id: String, name: String, ip: String, tags: Vec<String>) -> Self {
        let metadata: Vec<(String, String)> = Vec::new();
        Self {
            server_id,
            name,
            ip,
            tags,
            metadata,
        }
    }
}
53
54/// Errors from provider API calls.
55#[derive(Debug, Error)]
56pub enum ProviderError {
57    #[error("HTTP error: {0}")]
58    Http(String),
59    #[error("Failed to parse response: {0}")]
60    Parse(String),
61    #[error("Authentication failed. Check your API token.")]
62    AuthFailed,
63    #[error("Rate limited. Try again in a moment.")]
64    RateLimited,
65    #[error("{0}")]
66    Execute(String),
67    #[error("Cancelled.")]
68    Cancelled,
69    /// Some hosts were fetched but others failed. The caller should use the
70    /// hosts but suppress destructive operations like --remove.
71    #[error("Partial result: {failures} of {total} failed")]
72    PartialResult {
73        hosts: Vec<ProviderHost>,
74        failures: usize,
75        total: usize,
76    },
77}
78
79/// Trait implemented by each cloud provider.
80pub trait Provider {
81    /// Full provider name (e.g. "digitalocean").
82    fn name(&self) -> &str;
83    /// Short label for aliases (e.g. "do").
84    fn short_label(&self) -> &str;
85    /// Fetch hosts with cancellation support.
86    #[allow(dead_code)]
87    fn fetch_hosts_cancellable(
88        &self,
89        token: &str,
90        cancel: &AtomicBool,
91    ) -> Result<Vec<ProviderHost>, ProviderError>;
92    /// Fetch all servers from the provider API.
93    #[allow(dead_code)]
94    fn fetch_hosts(&self, token: &str) -> Result<Vec<ProviderHost>, ProviderError> {
95        self.fetch_hosts_cancellable(token, &AtomicBool::new(false))
96    }
97    /// Fetch hosts with progress reporting. Default delegates to fetch_hosts_cancellable.
98    #[allow(dead_code)]
99    fn fetch_hosts_with_progress(
100        &self,
101        token: &str,
102        cancel: &AtomicBool,
103        _progress: &dyn Fn(&str),
104    ) -> Result<Vec<ProviderHost>, ProviderError> {
105        self.fetch_hosts_cancellable(token, cancel)
106    }
107}
108
/// Split a comma-separated config value (regions, zones, subscriptions)
/// into trimmed entries, dropping any that end up empty.
fn parse_csv(s: &str) -> Vec<String> {
    let mut entries = Vec::new();
    for piece in s.split(',') {
        let piece = piece.trim();
        if !piece.is_empty() {
            entries.push(piece.to_owned());
        }
    }
    entries
}
117
118/// Factory for a provider implementation from an optional config section.
119/// `None` yields a default-constructed instance; `Some(section)` wires the
120/// section's fields into the provider struct.
121type ProviderBuild = fn(Option<&config::ProviderSection>) -> Box<dyn Provider>;
122
123/// Static registry entry describing one provider. Adding a provider means
124/// adding exactly one `ProviderDescriptor` to `PROVIDERS` below.
125pub struct ProviderDescriptor {
126    /// Slug used in config files and aliases.
127    pub name: &'static str,
128    /// Human-readable name shown in the UI.
129    pub display: &'static str,
130    /// Builder. Must not allocate or fail.
131    pub build: ProviderBuild,
132}
133
134/// Single source of truth for the provider registry. Adding a new provider
135/// means one entry here plus the provider module itself.
136pub const PROVIDERS: &[ProviderDescriptor] = &[
137    ProviderDescriptor {
138        name: "digitalocean",
139        display: "DigitalOcean",
140        build: |_| Box::new(digitalocean::DigitalOcean),
141    },
142    ProviderDescriptor {
143        name: "vultr",
144        display: "Vultr",
145        build: |_| Box::new(vultr::Vultr),
146    },
147    ProviderDescriptor {
148        name: "linode",
149        display: "Linode",
150        build: |_| Box::new(linode::Linode),
151    },
152    ProviderDescriptor {
153        name: "hetzner",
154        display: "Hetzner",
155        build: |_| Box::new(hetzner::Hetzner),
156    },
157    ProviderDescriptor {
158        name: "upcloud",
159        display: "UpCloud",
160        build: |_| Box::new(upcloud::UpCloud),
161    },
162    ProviderDescriptor {
163        name: "proxmox",
164        display: "Proxmox VE",
165        build: |section| {
166            let s = section.cloned().unwrap_or_default();
167            Box::new(proxmox::Proxmox {
168                base_url: s.url,
169                verify_tls: s.verify_tls,
170            })
171        },
172    },
173    ProviderDescriptor {
174        name: "aws",
175        display: "AWS EC2",
176        build: |section| {
177            let s = section.cloned().unwrap_or_default();
178            Box::new(aws::Aws {
179                regions: parse_csv(&s.regions),
180                profile: s.profile,
181            })
182        },
183    },
184    ProviderDescriptor {
185        name: "scaleway",
186        display: "Scaleway",
187        build: |section| {
188            let s = section.cloned().unwrap_or_default();
189            Box::new(scaleway::Scaleway {
190                zones: parse_csv(&s.regions),
191            })
192        },
193    },
194    ProviderDescriptor {
195        name: "gcp",
196        display: "GCP",
197        build: |section| {
198            let s = section.cloned().unwrap_or_default();
199            Box::new(gcp::Gcp {
200                zones: parse_csv(&s.regions),
201                project: s.project,
202            })
203        },
204    },
205    ProviderDescriptor {
206        name: "azure",
207        display: "Azure",
208        build: |section| {
209            let s = section.cloned().unwrap_or_default();
210            Box::new(azure::Azure {
211                subscriptions: parse_csv(&s.regions),
212            })
213        },
214    },
215    ProviderDescriptor {
216        name: "tailscale",
217        display: "Tailscale",
218        build: |_| Box::new(tailscale::Tailscale),
219    },
220    ProviderDescriptor {
221        name: "oracle",
222        display: "Oracle Cloud",
223        build: |section| {
224            let s = section.cloned().unwrap_or_default();
225            Box::new(oracle::Oracle {
226                regions: parse_csv(&s.regions),
227                compartment: s.compartment,
228            })
229        },
230    },
231    ProviderDescriptor {
232        name: "ovh",
233        display: "OVHcloud",
234        // OVH overloads `regions` as the API endpoint (e.g. "ovh-eu").
235        // Known quirk flagged in the architecture review; kept as-is to
236        // avoid schema migration in this refactor.
237        build: |section| {
238            let s = section.cloned().unwrap_or_default();
239            Box::new(ovh::Ovh {
240                project: s.project,
241                endpoint: s.regions,
242            })
243        },
244    },
245    ProviderDescriptor {
246        name: "leaseweb",
247        display: "Leaseweb",
248        build: |_| Box::new(leaseweb::Leaseweb),
249    },
250    ProviderDescriptor {
251        name: "i3d",
252        display: "i3D.net",
253        build: |_| Box::new(i3d::I3d),
254    },
255    ProviderDescriptor {
256        name: "transip",
257        display: "TransIP",
258        build: |_| Box::new(transip::TransIp),
259    },
260];
261
262/// Look up a descriptor by name.
263fn descriptor(name: &str) -> Option<&'static ProviderDescriptor> {
264    PROVIDERS.iter().find(|p| p.name == name)
265}
266
267/// All known provider names, in registration order.
268pub const PROVIDER_NAMES: &[&str] = &[
269    "digitalocean",
270    "vultr",
271    "linode",
272    "hetzner",
273    "upcloud",
274    "proxmox",
275    "aws",
276    "scaleway",
277    "gcp",
278    "azure",
279    "tailscale",
280    "oracle",
281    "ovh",
282    "leaseweb",
283    "i3d",
284    "transip",
285];
286
287// Compile-time guard: PROVIDER_NAMES and PROVIDERS must stay in lockstep.
288const _: () = {
289    assert!(
290        PROVIDER_NAMES.len() == PROVIDERS.len(),
291        "PROVIDER_NAMES and PROVIDERS length must match",
292    );
293};
294
295/// Get a provider implementation by name with default configuration.
296pub fn get_provider(name: &str) -> Option<Box<dyn Provider>> {
297    descriptor(name).map(|d| (d.build)(None))
298}
299
300/// Get a provider implementation configured from a provider section.
301pub fn get_provider_with_config(
302    name: &str,
303    section: &config::ProviderSection,
304) -> Option<Box<dyn Provider>> {
305    descriptor(name).map(|d| (d.build)(Some(section)))
306}
307
308/// Display name for a provider (e.g. "digitalocean" -> "DigitalOcean").
309pub fn provider_display_name(name: &str) -> &str {
310    descriptor(name).map(|d| d.display).unwrap_or(name)
311}
312
313/// Create an HTTP agent with explicit timeouts.
314pub(crate) fn http_agent() -> ureq::Agent {
315    ureq::Agent::config_builder()
316        .timeout_global(Some(std::time::Duration::from_secs(30)))
317        .max_redirects(0)
318        .build()
319        .new_agent()
320}
321
322/// Create an HTTP agent that accepts invalid/self-signed TLS certificates.
323pub(crate) fn http_agent_insecure() -> Result<ureq::Agent, ProviderError> {
324    Ok(ureq::Agent::config_builder()
325        .timeout_global(Some(std::time::Duration::from_secs(30)))
326        .max_redirects(0)
327        .tls_config(
328            ureq::tls::TlsConfig::builder()
329                .provider(ureq::tls::TlsProvider::NativeTls)
330                .disable_verification(true)
331                .build(),
332        )
333        .build()
334        .new_agent())
335}
336
/// Strip a CIDR prefix-length suffix (`/64`, `/128`, ...) from an IP address.
///
/// Some provider APIs return IPv6 addresses with a prefix length
/// (e.g. "2600:3c00::1/128"), but SSH needs the bare address. Only a final
/// `/` followed by one or more ASCII digits is removed; anything else is
/// returned unchanged.
pub(crate) fn strip_cidr(ip: &str) -> &str {
    match ip.rsplit_once('/') {
        Some((address, prefix))
            if !prefix.is_empty() && prefix.bytes().all(|b| b.is_ascii_digit()) =>
        {
            address
        }
        _ => ip,
    }
}
349
/// RFC 3986 percent-encoding for URL query parameters.
///
/// Every byte of the UTF-8 input is encoded as `%XX` (uppercase hex) except
/// the unreserved set: `A-Z`, `a-z`, `0-9`, `-`, `_`, `.`, `~`.
pub(crate) fn percent_encode(s: &str) -> String {
    const HEX: &[u8; 16] = b"0123456789ABCDEF";
    // Unreserved input maps 1:1, so the input length is a lower bound.
    let mut out = String::with_capacity(s.len());
    for byte in s.bytes() {
        if byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'_' | b'.' | b'~') {
            out.push(char::from(byte));
        } else {
            // Push the escape directly instead of allocating via `format!`.
            out.push('%');
            out.push(char::from(HEX[usize::from(byte >> 4)]));
            out.push(char::from(HEX[usize::from(byte & 0x0F)]));
        }
    }
    out
}
366
/// Date components from a Unix epoch timestamp (no chrono dependency).
/// All components are in UTC on the proleptic Gregorian calendar.
pub(crate) struct EpochDate {
    pub year: u64,
    pub month: u64, // 1-based
    pub day: u64,   // 1-based
    pub hours: u64,
    pub minutes: u64,
    pub seconds: u64,
    /// Days since epoch (for weekday calculation)
    pub epoch_days: u64,
}

/// Convert Unix epoch seconds to UTC date components.
///
/// Uses Howard Hinnant's `civil_from_days` algorithm, which runs in constant
/// time for any input. A year-by-year loop would need billions of iterations
/// for very large or garbage timestamps.
pub(crate) fn epoch_to_date(epoch_secs: u64) -> EpochDate {
    const SECS_PER_DAY: u64 = 86_400;
    // 400-year Gregorian cycle length in days.
    const DAYS_PER_ERA: u64 = 146_097;
    // Days from 0000-03-01 to 1970-01-01. Counting years from March puts
    // the leap day at the end of each year.
    const EPOCH_SHIFT: u64 = 719_468;

    let epoch_days = epoch_secs / SECS_PER_DAY;
    let day_secs = epoch_secs % SECS_PER_DAY;

    let z = epoch_days + EPOCH_SHIFT;
    let era = z / DAYS_PER_ERA;
    let doe = z % DAYS_PER_ERA; // day of era, [0, 146096]
    let yoe = (doe - doe / 1_460 + doe / 36_524 - doe / 146_096) / 365; // year of era, [0, 399]
    let doy = doe - (365 * yoe + yoe / 4 - yoe / 100); // day of March-based year, [0, 365]
    let mp = (5 * doy + 2) / 153; // March-based month, [0, 11]
    let day = doy - (153 * mp + 2) / 5 + 1;
    let month = if mp < 10 { mp + 3 } else { mp - 9 };
    // January and February belong to the following civil year.
    let year = era * 400 + yoe + u64::from(month <= 2);

    EpochDate {
        year,
        month,
        day,
        hours: day_secs / 3600,
        minutes: (day_secs % 3600) / 60,
        seconds: day_secs % 60,
        epoch_days,
    }
}
428
429/// Map a ureq error to a ProviderError.
430fn map_ureq_error(err: ureq::Error) -> ProviderError {
431    match err {
432        ureq::Error::StatusCode(code) => match code {
433            401 | 403 => {
434                error!("[external] HTTP {code}: authentication failed");
435                ProviderError::AuthFailed
436            }
437            429 => {
438                warn!("[external] HTTP 429: rate limited");
439                ProviderError::RateLimited
440            }
441            _ => {
442                error!("[external] HTTP {code}");
443                ProviderError::Http(format!("HTTP {}", code))
444            }
445        },
446        other => {
447            error!("[external] Request failed: {other}");
448            ProviderError::Http(other.to_string())
449        }
450    }
451}
452
453#[cfg(test)]
454#[path = "mod_tests.rs"]
455mod tests;