use std::fmt::Debug;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, LazyLock};
use std::time::Duration;
use async_trait::async_trait;
use strum::Display;
use tokio::sync::mpsc::Sender;
use tokio::sync::{mpsc, Mutex, Notify};
use tracing::{debug, instrument};
use crate::providers::*;
#[cfg(feature = "blocking")]
pub mod blocking;
pub(crate) mod providers;
/// Detection timeout, in seconds, that [`detect`] falls back to when called with `timeout: None`.
pub const DEFAULT_DETECTION_TIMEOUT: u64 = 5;
/// Identifier of a cloud provider, as returned by [`detect`].
///
/// The derived [`Display`](std::fmt::Display) implementation (via `strum`) renders each variant
/// as the string in its `serialize` attribute, e.g. `ProviderId::DigitalOcean` displays as
/// `digitalocean`. [`ProviderId::Unknown`] is the [`Default`] and is what [`detect`] yields when
/// no provider could be identified.
///
/// `Clone` and `Hash` are derived so callers can duplicate a detected identifier or use it as a
/// key in hashed collections.
#[non_exhaustive]
#[derive(Clone, Debug, Default, Display, Eq, Hash, PartialEq)]
pub enum ProviderId {
    /// No provider was identified (displayed as `unknown`).
    #[default]
    #[strum(serialize = "unknown")]
    Unknown,
    /// Akamai (displayed as `akamai`).
    #[strum(serialize = "akamai")]
    Akamai,
    /// Alibaba (displayed as `alibaba`).
    #[strum(serialize = "alibaba")]
    Alibaba,
    /// AWS (displayed as `aws`).
    #[strum(serialize = "aws")]
    AWS,
    /// Azure (displayed as `azure`).
    #[strum(serialize = "azure")]
    Azure,
    /// DigitalOcean (displayed as `digitalocean`).
    #[strum(serialize = "digitalocean")]
    DigitalOcean,
    /// GCP (displayed as `gcp`).
    #[strum(serialize = "gcp")]
    GCP,
    /// OCI (displayed as `oci`).
    #[strum(serialize = "oci")]
    OCI,
    /// OpenStack (displayed as `openstack`).
    #[strum(serialize = "openstack")]
    OpenStack,
    /// Vultr (displayed as `vultr`).
    #[strum(serialize = "vultr")]
    Vultr,
}
/// A probe for one specific cloud provider.
///
/// Implementations are shared between spawned Tokio tasks through `Arc<dyn Provider>`, which is
/// why they must be both `Sync` and `Send`.
#[async_trait]
pub(crate) trait Provider: Sync + Send {
    /// Returns the [`ProviderId`] this probe is responsible for.
    fn identifier(&self) -> ProviderId;

    /// Checks whether the host runs on this provider.
    ///
    /// A probe that recognises its provider is expected to send its [`ProviderId`] through
    /// `sender`; `detect` treats the first value received as the result. `max_wait` is the
    /// detection timeout handed down by the caller. NOTE(review): presumably each probe bounds
    /// its own I/O by it, so confirm this against the implementations.
    async fn identify(&self, sender: Sender<ProviderId>, max_wait: Duration);
}
/// Shared, type-erased handle to a provider probe.
type P = Arc<dyn Provider>;

/// Registry of every provider probe, in the order they are listed and spawned.
static PROVIDERS: LazyLock<Mutex<Vec<P>>> = LazyLock::new(|| {
    // Each element of an explicitly typed array is a coercion site, so every concrete
    // `Arc<Probe>` is unsized to `Arc<dyn Provider>` without per-element casts.
    let registry: [P; 9] = [
        Arc::new(akamai::Akamai),
        Arc::new(alibaba::Alibaba),
        Arc::new(aws::Aws),
        Arc::new(azure::Azure),
        Arc::new(digitalocean::DigitalOcean),
        Arc::new(gcp::Gcp),
        Arc::new(oci::Oci),
        Arc::new(openstack::OpenStack),
        Arc::new(vultr::Vultr),
    ];
    Mutex::new(Vec::from(registry))
});
/// Returns the display name of every registered provider, in registry order.
///
/// The registry lock is held only while the names are being collected.
pub async fn supported_providers() -> Vec<String> {
    PROVIDERS
        .lock()
        .await
        .iter()
        .map(|entry| entry.identifier().to_string())
        .collect()
}
#[instrument]
pub async fn detect(timeout: Option<u64>) -> ProviderId {
let timeout = Duration::from_secs(timeout.unwrap_or(DEFAULT_DETECTION_TIMEOUT));
let (tx, mut rx) = mpsc::channel::<ProviderId>(1);
let guard = PROVIDERS.lock().await;
let provider_entries: Vec<P> = guard.iter().cloned().collect();
let providers_count = provider_entries.len();
let mut handles = Vec::with_capacity(providers_count);
let counter = Arc::new(AtomicUsize::new(providers_count));
let complete = Arc::new(Notify::new());
for provider in provider_entries {
let tx = tx.clone();
let counter = counter.clone();
let complete = complete.clone();
handles.push(tokio::spawn(async move {
debug!("Spawning task for provider: {}", provider.identifier());
provider.identify(tx, timeout).await;
if counter.fetch_sub(1, Ordering::SeqCst) == 1 {
complete.notify_one();
}
}));
}
tokio::select! {
biased;
res = rx.recv() => {
debug!("Received result from channel: {:?}", res);
res.unwrap_or_default()
}
_ = complete.notified() => {
debug!("All providers have finished identifying");
Default::default()
}
_ = tokio::time::sleep(timeout) => {
debug!("Detection timed out");
Default::default()
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The registry exposes exactly the nine built-in providers, each under its identifier.
    #[tokio::test]
    async fn test_supported_providers() {
        let names = supported_providers().await;
        assert_eq!(names.len(), 9);

        let expected = [
            akamai::IDENTIFIER.to_string(),
            alibaba::IDENTIFIER.to_string(),
            aws::IDENTIFIER.to_string(),
            azure::IDENTIFIER.to_string(),
            digitalocean::IDENTIFIER.to_string(),
            gcp::IDENTIFIER.to_string(),
            oci::IDENTIFIER.to_string(),
            openstack::IDENTIFIER.to_string(),
            vultr::IDENTIFIER.to_string(),
        ];
        for name in &expected {
            assert!(names.contains(name));
        }
    }
}