pub(crate) mod providers;
use std::sync::mpsc::RecvTimeoutError;
use std::sync::mpsc::SyncSender;
use std::sync::{mpsc, Arc, LazyLock, Mutex};
use std::time::Duration;
use anyhow::Result;
use crate::blocking::providers::*;
use crate::{ProviderId, DEFAULT_DETECTION_TIMEOUT};
/// A cloud provider that can check whether the current host runs on it.
///
/// `Send + Sync` is required because providers are shared through `Arc` and
/// moved into the threads spawned by `detect`.
// NOTE(review): the reason for `allow(dead_code)` is not visible here —
// `identifier` is used by `supported_providers`; confirm whether some method
// or implementation is unused under certain build configurations.
#[allow(dead_code)]
pub(crate) trait Provider: Send + Sync {
    /// Returns this provider's identifier; `supported_providers` stringifies
    /// it to build the list of supported providers.
    fn identifier(&self) -> ProviderId;
    /// Probes whether the host runs on this provider.
    ///
    /// `detect` calls this on a dedicated thread, passing the same `timeout`
    /// it uses for its own wait, and takes the first value received on `tx`
    /// as the detection result — so implementations should send only on a
    /// positive match.
    // NOTE(review): the actual probing mechanism lives in the provider
    // implementations, which are not visible here.
    fn identify(&self, tx: SyncSender<ProviderId>, timeout: Duration);
}
/// Shared, type-erased provider handle; `Arc` lets `detect` give every
/// probing thread its own reference to a provider.
type P = Arc<dyn Provider>;

/// Registry of every supported provider.
///
/// Registration order below is the order reported by `supported_providers`
/// and the order in which `detect` spawns its probes.
// NOTE(review): nothing in this module mutates the list; if that holds
// crate-wide, the `Mutex` could be replaced by a plain `Vec<P>`.
static PROVIDERS: LazyLock<Mutex<Vec<P>>> = LazyLock::new(|| {
    let mut registry: Vec<P> = Vec::new();
    // Each `Arc<Concrete>` is coerced to `Arc<dyn Provider>` at the `push` call.
    registry.push(Arc::new(akamai::Akamai));
    registry.push(Arc::new(alibaba::Alibaba));
    registry.push(Arc::new(aws::Aws));
    registry.push(Arc::new(azure::Azure));
    registry.push(Arc::new(digitalocean::DigitalOcean));
    registry.push(Arc::new(gcp::Gcp));
    registry.push(Arc::new(oci::Oci));
    registry.push(Arc::new(openstack::OpenStack));
    registry.push(Arc::new(vultr::Vultr));
    Mutex::new(registry)
});
pub fn supported_providers() -> Result<Vec<String>> {
let guard = PROVIDERS
.lock()
.map_err(|_| anyhow::anyhow!("Error locking providers"))?;
let providers: Vec<String> = guard.iter().map(|p| p.identifier().to_string()).collect();
drop(guard);
Ok(providers)
}
/// Detects the cloud provider the current host is running on.
///
/// Every registered provider probes concurrently on its own thread; the first
/// provider to report a match wins.
///
/// # Arguments
///
/// * `timeout` - Maximum time to wait, in seconds. Defaults to
///   `DEFAULT_DETECTION_TIMEOUT` when `None`. The same duration is passed to
///   each provider's `identify`.
///
/// # Returns
///
/// The detected `ProviderId`, or `ProviderId::Unknown` if no provider reports
/// a match before the timeout elapses or every provider finishes without one.
///
/// # Errors
///
/// Returns an error if the provider registry lock is poisoned.
pub fn detect(timeout: Option<u64>) -> Result<ProviderId> {
    let timeout = Duration::from_secs(timeout.unwrap_or(DEFAULT_DETECTION_TIMEOUT));
    // Only the first reported match is consumed, so a one-slot buffer suffices.
    let (tx, rx) = mpsc::sync_channel::<ProviderId>(1);

    // Snapshot the registry inside a scope so the lock is released before the
    // (potentially long) wait below; holding it would block concurrent callers
    // of `detect` / `supported_providers` for the whole detection timeout.
    let providers: Vec<P> = {
        let guard = PROVIDERS
            .lock()
            .map_err(|_| anyhow::anyhow!("Error locking providers"))?;
        guard.iter().cloned().collect()
    };

    for provider in providers {
        let tx = tx.clone();
        std::thread::spawn(move || provider.identify(tx, timeout));
    }
    // Drop our own sender: the channel now disconnects as soon as every probe
    // thread has finished (dropping its clone), so an all-negative result
    // returns immediately instead of waiting out the full timeout.
    drop(tx);

    match rx.recv_timeout(timeout) {
        Ok(provider_id) => Ok(provider_id),
        // Either nobody answered in time, or every provider finished without
        // a match — both mean the platform could not be identified.
        Err(RecvTimeoutError::Timeout | RecvTimeoutError::Disconnected) => {
            Ok(ProviderId::Unknown)
        }
    }
}
#[cfg(test)]
mod tests {
    use anyhow::Result;

    use super::*;

    /// Every registered provider is reported, and nothing else.
    #[test]
    fn test_supported_providers() -> Result<()> {
        let providers = supported_providers()?;
        let expected = [
            akamai::IDENTIFIER.to_string(),
            alibaba::IDENTIFIER.to_string(),
            aws::IDENTIFIER.to_string(),
            azure::IDENTIFIER.to_string(),
            digitalocean::IDENTIFIER.to_string(),
            gcp::IDENTIFIER.to_string(),
            oci::IDENTIFIER.to_string(),
            openstack::IDENTIFIER.to_string(),
            vultr::IDENTIFIER.to_string(),
        ];
        assert_eq!(providers.len(), expected.len());
        for identifier in &expected {
            assert!(providers.contains(identifier));
        }
        Ok(())
    }
}