Skip to main content

purple_ssh/providers/
mod.rs

1pub mod aws;
2pub mod azure;
3pub mod config;
4mod digitalocean;
5pub mod gcp;
6mod hetzner;
7mod i3d;
8pub mod kind;
9mod leaseweb;
10mod linode;
11pub mod oracle;
12pub mod ovh;
13mod proxmox;
14pub mod scaleway;
15pub mod sync;
16mod tailscale;
17mod transip;
18mod upcloud;
19mod vultr;
20
21pub use kind::ProviderKind;
22
23use std::sync::atomic::{AtomicBool, Ordering};
24
25use log::{debug, error, warn};
26use thiserror::Error;
27
/// One server as reported by a cloud provider's API.
///
/// All fields are owned values, so the record can be cloned and moved
/// independently of the response it was parsed from.
#[derive(Debug, Clone)]
pub struct ProviderHost {
    /// Identifier the provider assigned to the server.
    pub server_id: String,
    /// Name or label of the server.
    pub name: String,
    /// Public IP address; may be IPv4 or IPv6.
    pub ip: String,
    /// Tags/labels attached to the server by the provider.
    pub tags: Vec<String>,
    /// Extra provider details (region, plan, etc.) as `(key, value)` pairs.
    pub metadata: Vec<(String, String)>,
}

impl ProviderHost {
    /// Build a host record whose `metadata` list starts out empty.
    #[allow(dead_code)]
    pub fn new(server_id: String, name: String, ip: String, tags: Vec<String>) -> Self {
        let metadata = Vec::new();
        Self {
            metadata,
            tags,
            ip,
            name,
            server_id,
        }
    }
}
56
/// Errors from provider API calls.
///
/// The `#[error(...)]` strings are the `Display` output of each variant.
#[derive(Debug, Error)]
pub enum ProviderError {
    /// Transport failure or an HTTP status other than 401/403/429
    /// (see `map_ureq_error`). Carries a short description.
    #[error("HTTP error: {0}")]
    Http(String),
    /// A response could not be parsed. Carries a description of the problem.
    #[error("Failed to parse response: {0}")]
    Parse(String),
    /// The API rejected the credentials (`map_ureq_error` maps HTTP 401 and
    /// 403 here). `paginate` always propagates this immediately.
    #[error("Authentication failed. Check your API token.")]
    AuthFailed,
    /// The API throttled the request (`map_ureq_error` maps HTTP 429 here).
    /// `paginate` always propagates this immediately.
    #[error("Rate limited. Try again in a moment.")]
    RateLimited,
    /// Free-form error whose message is displayed verbatim.
    #[error("{0}")]
    Execute(String),
    /// The cancellation flag was set; `paginate` returns this before
    /// requesting the next page.
    #[error("Cancelled.")]
    Cancelled,
    /// Some hosts were fetched but others failed. The caller should use the
    /// hosts but suppress destructive operations like --remove.
    #[error("Partial result: {failures} of {total} failed")]
    PartialResult {
        /// Hosts that were successfully collected before the failure.
        hosts: Vec<ProviderHost>,
        /// Number of failed units (`paginate` reports failed pages).
        failures: usize,
        /// Number of units attempted in total.
        total: usize,
    },
}
81
82/// Trait implemented by each cloud provider.
83pub trait Provider {
84    /// Full provider name (e.g. "digitalocean").
85    fn name(&self) -> &str;
86    /// Short label for aliases (e.g. "do").
87    fn short_label(&self) -> &str;
88    /// Fetch hosts with cancellation support. `env` carries the resolved
89    /// process environment (home directory, credential env vars) so the few
90    /// providers that read AWS credentials or expand `~` in key paths take
91    /// them from the injected snapshot instead of ambient `std::env` /
92    /// `dirs::home_dir`. Most providers ignore it.
93    #[allow(dead_code)]
94    fn fetch_hosts_cancellable(
95        &self,
96        token: &str,
97        cancel: &AtomicBool,
98        env: &crate::runtime::env::Env,
99    ) -> Result<Vec<ProviderHost>, ProviderError>;
100    /// Fetch all servers from the provider API.
101    #[allow(dead_code)]
102    fn fetch_hosts(
103        &self,
104        token: &str,
105        env: &crate::runtime::env::Env,
106    ) -> Result<Vec<ProviderHost>, ProviderError> {
107        self.fetch_hosts_cancellable(token, &AtomicBool::new(false), env)
108    }
109    /// Fetch hosts with progress reporting. Default delegates to fetch_hosts_cancellable.
110    #[allow(dead_code)]
111    fn fetch_hosts_with_progress(
112        &self,
113        token: &str,
114        cancel: &AtomicBool,
115        env: &crate::runtime::env::Env,
116        _progress: &dyn Fn(&str),
117    ) -> Result<Vec<ProviderHost>, ProviderError> {
118        self.fetch_hosts_cancellable(token, cancel, env)
119    }
120}
121
/// Split a comma-separated provider config field into its entries.
///
/// Each entry is trimmed of surrounding whitespace and entries that are empty
/// after trimming are dropped. Used for regions/zones/subscriptions.
fn parse_csv(s: &str) -> Vec<String> {
    let mut entries = Vec::new();
    for raw in s.split(',') {
        let entry = raw.trim();
        if !entry.is_empty() {
            entries.push(entry.to_string());
        }
    }
    entries
}
130
/// Factory for a provider implementation from an optional config section.
/// `None` yields a default-constructed instance; `Some(section)` wires the
/// section's fields into the provider struct.
///
/// This is a plain `fn` pointer, so the closures stored in `PROVIDERS` must
/// not capture any state.
type ProviderBuild = fn(Option<&config::ProviderSection>) -> Box<dyn Provider>;
135
/// Static registry entry describing one provider. Adding a provider means
/// adding exactly one `ProviderDescriptor` to `PROVIDERS` below.
pub struct ProviderDescriptor {
    /// Slug used in config files and aliases.
    pub name: &'static str,
    /// Human-readable name shown in the UI.
    pub display: &'static str,
    /// Builder. Must not fail (the signature has no error channel). It does
    /// allocate: it returns a `Box`, and the section-aware entries in
    /// `PROVIDERS` clone the section.
    pub build: ProviderBuild,
}
146
/// Single source of truth for the provider registry. Adding a new provider
/// means one entry here plus the provider module itself.
///
/// Entries whose builder is `|_| ...` ignore the config section entirely.
/// The others clone the section (or fall back to its `Default`) and copy the
/// relevant fields into the provider struct.
pub const PROVIDERS: &[ProviderDescriptor] = &[
    ProviderDescriptor {
        name: "digitalocean",
        display: "DigitalOcean",
        build: |_| Box::new(digitalocean::DigitalOcean),
    },
    ProviderDescriptor {
        name: "vultr",
        display: "Vultr",
        build: |_| Box::new(vultr::Vultr),
    },
    ProviderDescriptor {
        name: "linode",
        display: "Linode",
        build: |_| Box::new(linode::Linode),
    },
    ProviderDescriptor {
        name: "hetzner",
        display: "Hetzner",
        build: |_| Box::new(hetzner::Hetzner),
    },
    ProviderDescriptor {
        name: "upcloud",
        display: "UpCloud",
        build: |_| Box::new(upcloud::UpCloud),
    },
    ProviderDescriptor {
        name: "proxmox",
        display: "Proxmox VE",
        build: |section| {
            let s = section.cloned().unwrap_or_default();
            Box::new(proxmox::Proxmox {
                base_url: s.url,
                verify_tls: s.verify_tls,
            })
        },
    },
    ProviderDescriptor {
        name: "aws",
        display: "AWS EC2",
        build: |section| {
            let s = section.cloned().unwrap_or_default();
            Box::new(aws::Aws {
                regions: parse_csv(&s.regions),
                profile: s.profile,
            })
        },
    },
    ProviderDescriptor {
        name: "scaleway",
        display: "Scaleway",
        // The section's `regions` field carries the Scaleway zone list.
        build: |section| {
            let s = section.cloned().unwrap_or_default();
            Box::new(scaleway::Scaleway {
                zones: parse_csv(&s.regions),
            })
        },
    },
    ProviderDescriptor {
        name: "gcp",
        display: "GCP",
        // The section's `regions` field carries the GCP zone list.
        build: |section| {
            let s = section.cloned().unwrap_or_default();
            Box::new(gcp::Gcp {
                zones: parse_csv(&s.regions),
                project: s.project,
            })
        },
    },
    ProviderDescriptor {
        name: "azure",
        display: "Azure",
        // The section's `regions` field carries the Azure subscription list.
        build: |section| {
            let s = section.cloned().unwrap_or_default();
            Box::new(azure::Azure {
                subscriptions: parse_csv(&s.regions),
            })
        },
    },
    ProviderDescriptor {
        name: "tailscale",
        display: "Tailscale",
        build: |_| Box::new(tailscale::Tailscale),
    },
    ProviderDescriptor {
        name: "oracle",
        display: "Oracle Cloud",
        build: |section| {
            let s = section.cloned().unwrap_or_default();
            Box::new(oracle::Oracle {
                regions: parse_csv(&s.regions),
                compartment: s.compartment,
            })
        },
    },
    ProviderDescriptor {
        name: "ovh",
        display: "OVHcloud",
        // OVH overloads `regions` as the API endpoint (e.g. "ovh-eu").
        // Known quirk flagged in the architecture review; kept as-is to
        // avoid schema migration in this refactor.
        // Unlike the entries above, the value is passed through verbatim
        // (not split with `parse_csv`).
        build: |section| {
            let s = section.cloned().unwrap_or_default();
            Box::new(ovh::Ovh {
                project: s.project,
                endpoint: s.regions,
            })
        },
    },
    ProviderDescriptor {
        name: "leaseweb",
        display: "Leaseweb",
        build: |_| Box::new(leaseweb::Leaseweb),
    },
    ProviderDescriptor {
        name: "i3d",
        display: "i3D.net",
        build: |_| Box::new(i3d::I3d),
    },
    ProviderDescriptor {
        name: "transip",
        display: "TransIP",
        build: |_| Box::new(transip::TransIp),
    },
];
274
275/// Look up a descriptor by bare provider name. Internal helper; public wrappers
276/// below strip any `:label` suffix before calling this, so callers cannot
277/// accidentally pass a labeled id (`proxmox:server1`) and silently miss.
278fn descriptor(name: &str) -> Option<&'static ProviderDescriptor> {
279    PROVIDERS.iter().find(|p| p.name == name)
280}
281
/// Return the part of `name` before the first `:`, or all of `name` when it
/// contains no `:`.
///
/// `ProviderConfigId` is the canonical home for this split; this helper lets
/// the `&str`-only public APIs accept labeled ids (`provider:label`) without
/// forcing every caller to parse first.
fn bare_provider_name(name: &str) -> &str {
    match name.split_once(':') {
        Some((provider, _label)) => provider,
        None => name,
    }
}
289
/// All known provider names, in registration order.
///
/// Must list the same slugs as the `name` fields of `PROVIDERS`, in the same
/// order; update both together when adding a provider.
pub const PROVIDER_NAMES: &[&str] = &[
    "digitalocean",
    "vultr",
    "linode",
    "hetzner",
    "upcloud",
    "proxmox",
    "aws",
    "scaleway",
    "gcp",
    "azure",
    "tailscale",
    "oracle",
    "ovh",
    "leaseweb",
    "i3d",
    "transip",
];
309
310// Compile-time guard: PROVIDER_NAMES and PROVIDERS must stay in lockstep.
311const _: () = {
312    assert!(
313        PROVIDER_NAMES.len() == PROVIDERS.len(),
314        "PROVIDER_NAMES and PROVIDERS length must match",
315    );
316};
317
318/// Get a provider implementation by name with default configuration. Accepts
319/// either a bare provider name (`"proxmox"`) or a labeled id (`"proxmox:server1"`).
320pub fn get_provider(name: &str) -> Option<Box<dyn Provider>> {
321    descriptor(bare_provider_name(name)).map(|d| (d.build)(None))
322}
323
324/// Get a provider implementation configured from a provider section. The bare
325/// provider name comes from `section.id.provider`, so labeled configs resolve
326/// the right descriptor by construction; passing a separate `name` was
327/// historically a foot-gun (issue #51) where callers handed in the labeled id
328/// string and the lookup missed the registry.
329pub fn get_provider_with_config(section: &config::ProviderSection) -> Option<Box<dyn Provider>> {
330    descriptor(section.provider()).map(|d| (d.build)(Some(section)))
331}
332
333/// Display name for a provider (e.g. "digitalocean" -> "DigitalOcean"). Accepts
334/// either a bare name or a labeled id; unknown names fall back to the input.
335pub fn provider_display_name(name: &str) -> &str {
336    descriptor(bare_provider_name(name))
337        .map(|d| d.display)
338        .unwrap_or(name)
339}
340
341/// Create an HTTP agent with explicit timeouts.
342pub(crate) fn http_agent() -> ureq::Agent {
343    ureq::Agent::config_builder()
344        .timeout_global(Some(std::time::Duration::from_secs(30)))
345        .max_redirects(0)
346        .build()
347        .new_agent()
348}
349
350/// Create an HTTP agent that accepts invalid/self-signed TLS certificates.
351pub(crate) fn http_agent_insecure() -> Result<ureq::Agent, ProviderError> {
352    Ok(ureq::Agent::config_builder()
353        .timeout_global(Some(std::time::Duration::from_secs(30)))
354        .max_redirects(0)
355        .tls_config(
356            ureq::tls::TlsConfig::builder()
357                .provider(ureq::tls::TlsProvider::NativeTls)
358                .disable_verification(true)
359                .build(),
360        )
361        .build()
362        .new_agent())
363}
364
/// Remove a trailing CIDR prefix length (`/64`, `/128`, ...) from an address.
///
/// Some provider APIs return IPv6 addresses with a prefix length
/// (e.g. "2600:3c00::1/128"), while SSH needs the bare address. The input is
/// returned unchanged unless the text after the last `/` is non-empty and
/// consists solely of ASCII digits.
pub(crate) fn strip_cidr(ip: &str) -> &str {
    match ip.rsplit_once('/') {
        Some((address, prefix_len))
            if !prefix_len.is_empty() && prefix_len.bytes().all(|b| b.is_ascii_digit()) =>
        {
            address
        }
        _ => ip,
    }
}
377
/// Percent-encode a string for use in a URL query parameter (RFC 3986).
///
/// The unreserved characters (A-Z, a-z, 0-9, '-', '_', '.', '~') are copied
/// through; every other byte of the UTF-8 input becomes `%XX` with uppercase
/// hex digits.
pub(crate) fn percent_encode(s: &str) -> String {
    const HEX_UPPER: &[u8; 16] = b"0123456789ABCDEF";

    let mut encoded = String::with_capacity(s.len());
    for byte in s.bytes() {
        let unreserved =
            byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'_' | b'.' | b'~');
        if unreserved {
            encoded.push(char::from(byte));
        } else {
            // High nibble first, then low nibble.
            encoded.push('%');
            encoded.push(char::from(HEX_UPPER[usize::from(byte >> 4)]));
            encoded.push(char::from(HEX_UPPER[usize::from(byte & 0x0F)]));
        }
    }
    encoded
}
394
/// Calendar date and time-of-day derived from a Unix epoch timestamp
/// (computed by hand, so no chrono dependency is needed).
pub(crate) struct EpochDate {
    pub year: u64,
    pub month: u64, // 1-based
    pub day: u64,   // 1-based
    pub hours: u64,
    pub minutes: u64,
    pub seconds: u64,
    /// Whole days elapsed since the epoch (for weekday calculation).
    pub epoch_days: u64,
}

/// Break Unix epoch seconds into year/month/day and hours/minutes/seconds.
///
/// Walks forward from 1970 one year at a time, then one month at a time,
/// using the Gregorian leap-year rule.
pub(crate) fn epoch_to_date(epoch_secs: u64) -> EpochDate {
    const SECS_PER_DAY: u64 = 86400;

    /// Gregorian rule: divisible by 4, except centuries not divisible by 400.
    fn is_leap_year(year: u64) -> bool {
        year % 4 == 0 && (year % 100 != 0 || year % 400 == 0)
    }

    let epoch_days = epoch_secs / SECS_PER_DAY;
    let secs_of_day = epoch_secs % SECS_PER_DAY;

    // Peel off whole years until the remainder fits inside `year`.
    let mut year = 1970u64;
    let mut days_left = epoch_days;
    loop {
        let year_len: u64 = if is_leap_year(year) { 366 } else { 365 };
        if days_left < year_len {
            break;
        }
        days_left -= year_len;
        year += 1;
    }

    // Peel off whole months; `month` ends up 1-based.
    let february: u64 = if is_leap_year(year) { 29 } else { 28 };
    let month_lengths: [u64; 12] = [31, february, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31];
    let mut month = 1u64;
    for &len in &month_lengths {
        if days_left < len {
            break;
        }
        days_left -= len;
        month += 1;
    }

    EpochDate {
        year,
        month,
        day: days_left + 1,
        hours: secs_of_day / 3600,
        minutes: (secs_of_day % 3600) / 60,
        seconds: secs_of_day % 60,
        epoch_days,
    }
}
456
457/// Map a ureq error to a ProviderError.
458fn map_ureq_error(err: ureq::Error) -> ProviderError {
459    match err {
460        ureq::Error::StatusCode(code) => match code {
461            401 | 403 => {
462                error!("[external] HTTP {code}: authentication failed");
463                ProviderError::AuthFailed
464            }
465            429 => {
466                warn!("[external] HTTP 429: rate limited");
467                ProviderError::RateLimited
468            }
469            _ => {
470                error!("[external] HTTP {code}");
471                ProviderError::Http(format!("HTTP {}", code))
472            }
473        },
474        other => {
475            error!("[external] Request failed: {other}");
476            ProviderError::Http(other.to_string())
477        }
478    }
479}
480
/// Upper bound on pages fetched from one paginated list endpoint. A safety
/// valve for a provider that never signals its last page; 500 pages covers any
/// realistic account.
///
/// When `paginate` has fetched this many pages it stops and returns the hosts
/// collected so far as a normal `Ok` result (logged at debug level).
pub(crate) const MAX_PAGES: u64 = 500;
485
/// One mapped page from a paginated list endpoint, as returned by the
/// `fetch_page` closure passed to `paginate`.
pub(crate) struct PageResult {
    /// Hosts mapped from this page, appended to the running total.
    pub hosts: Vec<ProviderHost>,
    /// Whether another page should be fetched. `false` ends pagination
    /// successfully.
    pub more: bool,
}
493
494/// Drive a paginated list endpoint under one shared cancellation, runaway-guard
495/// and partial-failure contract, so every JSON list provider behaves the same.
496///
497/// `fetch_page(index)` performs one request (0-based page index), parses it and
498/// maps entries into a `PageResult`. The closure owns its own cursor or
499/// page-number state across calls.
500///
501/// Failure policy, matching what `sync` relies on: a failure on the first page
502/// (nothing collected) propagates as a hard error so the provider is skipped
503/// and the config left untouched. A failure on a later page returns the hosts
504/// gathered so far as `PartialResult`, so add and update still run while remove
505/// and stale marking are suppressed upstream. `AuthFailed`, `RateLimited` and
506/// `Cancelled` always propagate immediately, even mid-run, because they
507/// invalidate the whole sync.
508pub(crate) fn paginate<F>(
509    cancel: &AtomicBool,
510    mut fetch_page: F,
511) -> Result<Vec<ProviderHost>, ProviderError>
512where
513    F: FnMut(u64) -> Result<PageResult, ProviderError>,
514{
515    let mut hosts = Vec::new();
516    let mut index = 0u64;
517    loop {
518        if cancel.load(Ordering::Relaxed) {
519            return Err(ProviderError::Cancelled);
520        }
521        match fetch_page(index) {
522            Ok(page) => {
523                hosts.extend(page.hosts);
524                if !page.more {
525                    return Ok(hosts);
526                }
527            }
528            Err(
529                e @ (ProviderError::Cancelled
530                | ProviderError::AuthFailed
531                | ProviderError::RateLimited),
532            ) => return Err(e),
533            Err(e) => {
534                if hosts.is_empty() {
535                    return Err(e);
536                }
537                debug!(
538                    "[external] paginate: page {} failed after {} hosts collected, returning partial result ({e})",
539                    index + 1,
540                    hosts.len()
541                );
542                return Err(ProviderError::PartialResult {
543                    hosts,
544                    failures: 1,
545                    total: (index + 1) as usize,
546                });
547            }
548        }
549        index += 1;
550        if index >= MAX_PAGES {
551            debug!(
552                "[purple] paginate: reached MAX_PAGES ({MAX_PAGES}) guard, stopping with {} hosts",
553                hosts.len()
554            );
555            return Ok(hosts);
556        }
557    }
558}
559
560#[cfg(test)]
561#[path = "mod_tests.rs"]
562mod tests;