1use std::fmt::Debug;
48use std::sync::atomic::{AtomicUsize, Ordering};
49use std::sync::{Arc, LazyLock};
50use std::time::Duration;
51
52use async_trait::async_trait;
53use strum::Display;
54use tokio::sync::mpsc::Sender;
55use tokio::sync::{mpsc, Mutex, Notify};
56use tracing::{debug, instrument};
57
58use crate::providers::*;
59
60#[cfg(feature = "blocking")]
61pub mod blocking;
62pub(crate) mod providers;
63
64pub const DEFAULT_DETECTION_TIMEOUT: u64 = 5; #[non_exhaustive]
69#[derive(Debug, Default, Display, Eq, PartialEq)]
70pub enum ProviderId {
71 #[default]
73 #[strum(serialize = "unknown")]
74 Unknown,
75 #[strum(serialize = "akamai")]
77 Akamai,
78 #[strum(serialize = "alibaba")]
80 Alibaba,
81 #[strum(serialize = "aws")]
83 AWS,
84 #[strum(serialize = "azure")]
86 Azure,
87 #[strum(serialize = "digitalocean")]
89 DigitalOcean,
90 #[strum(serialize = "gcp")]
92 GCP,
93 #[strum(serialize = "oci")]
95 OCI,
96 #[strum(serialize = "openstack")]
98 OpenStack,
99 #[strum(serialize = "vultr")]
101 Vultr,
102}
103
104#[async_trait]
106pub(crate) trait Provider: Send + Sync {
107 fn identifier(&self) -> ProviderId;
108 async fn identify(&self, tx: Sender<ProviderId>, timeout: Duration);
109}
110
111type P = Arc<dyn Provider>;
112
113static PROVIDERS: LazyLock<Mutex<Vec<P>>> = LazyLock::new(|| {
114 Mutex::new(vec![
115 Arc::new(akamai::Akamai) as P,
116 Arc::new(alibaba::Alibaba) as P,
117 Arc::new(aws::Aws) as P,
118 Arc::new(azure::Azure) as P,
119 Arc::new(digitalocean::DigitalOcean) as P,
120 Arc::new(gcp::Gcp) as P,
121 Arc::new(oci::Oci) as P,
122 Arc::new(openstack::OpenStack) as P,
123 Arc::new(vultr::Vultr) as P,
124 ])
125});
126
127pub async fn supported_providers() -> Vec<String> {
143 let guard = PROVIDERS.lock().await;
144 let providers: Vec<String> = guard.iter().map(|p| p.identifier().to_string()).collect();
145
146 drop(guard);
147
148 providers
149}
150
151#[instrument]
186pub async fn detect(timeout: Option<u64>) -> ProviderId {
187 let timeout = Duration::from_secs(timeout.unwrap_or(DEFAULT_DETECTION_TIMEOUT));
188 let (tx, mut rx) = mpsc::channel::<ProviderId>(1);
189 let guard = PROVIDERS.lock().await;
190 let provider_entries: Vec<P> = guard.iter().cloned().collect();
191 let providers_count = provider_entries.len();
192 let mut handles = Vec::with_capacity(providers_count);
193
194 let counter = Arc::new(AtomicUsize::new(providers_count));
196 let complete = Arc::new(Notify::new());
197
198 for provider in provider_entries {
199 let tx = tx.clone();
200 let counter = counter.clone();
201 let complete = complete.clone();
202
203 handles.push(tokio::spawn(async move {
204 debug!("Spawning task for provider: {}", provider.identifier());
205 provider.identify(tx, timeout).await;
206
207 if counter.fetch_sub(1, Ordering::SeqCst) == 1 {
209 complete.notify_one();
210 }
211 }));
212 }
213
214 tokio::select! {
215 biased;
216
217 res = rx.recv() => {
219 debug!("Received result from channel: {:?}", res);
220 res.unwrap_or_default()
221 }
222
223 _ = complete.notified() => {
225 debug!("All providers have finished identifying");
226 Default::default()
227 }
228
229 _ = tokio::time::sleep(timeout) => {
231 debug!("Detection timed out");
232 Default::default()
233 }
234 }
235}
236
#[cfg(test)]
mod tests {
    use super::*;

    /// Every provider module's identifier must appear in the supported list, and nothing else.
    #[tokio::test]
    async fn test_supported_providers() {
        let providers = supported_providers().await;
        assert_eq!(providers.len(), 9);

        let expected = [
            &akamai::IDENTIFIER,
            &alibaba::IDENTIFIER,
            &aws::IDENTIFIER,
            &azure::IDENTIFIER,
            &digitalocean::IDENTIFIER,
            &gcp::IDENTIFIER,
            &oci::IDENTIFIER,
            &openstack::IDENTIFIER,
            &vultr::IDENTIFIER,
        ];

        for identifier in expected {
            assert!(providers.contains(&identifier.to_string()));
        }
    }
}