1use std::fmt::Debug;
50use std::sync::atomic::{AtomicUsize, Ordering};
51use std::sync::{Arc, LazyLock};
52use std::time::Duration;
53
54use async_trait::async_trait;
55use strum::Display;
56use tokio::sync::mpsc::Sender;
57use tokio::sync::{mpsc, Mutex, Notify};
58use tracing::{debug, instrument};
59
60use crate::providers::*;
61
62#[cfg(feature = "blocking")]
63pub mod blocking;
64pub(crate) mod providers;
65
/// Time budget used by [`detect`] when the caller passes `None` (five seconds).
pub const DEFAULT_DETECTION_TIMEOUT: Duration = Duration::from_millis(5_000);
68
69#[non_exhaustive]
71#[derive(Debug, Default, Display, Eq, PartialEq)]
72pub enum ProviderId {
73 #[default]
75 #[strum(serialize = "unknown")]
76 Unknown,
77 #[strum(serialize = "akamai")]
79 Akamai,
80 #[strum(serialize = "alibaba")]
82 Alibaba,
83 #[strum(serialize = "aws")]
85 AWS,
86 #[strum(serialize = "azure")]
88 Azure,
89 #[strum(serialize = "digitalocean")]
91 DigitalOcean,
92 #[strum(serialize = "gcp")]
94 GCP,
95 #[strum(serialize = "oci")]
97 OCI,
98 #[strum(serialize = "openstack")]
100 OpenStack,
101 #[strum(serialize = "vultr")]
103 Vultr,
104}
105
106#[async_trait]
108pub(crate) trait Provider: Send + Sync {
109 fn identifier(&self) -> ProviderId;
110 async fn identify(&self, tx: Sender<ProviderId>, timeout: Duration);
111}
112
113type P = Arc<dyn Provider>;
114
115static PROVIDERS: LazyLock<Mutex<Vec<P>>> = LazyLock::new(|| {
116 Mutex::new(vec![
117 Arc::new(akamai::Akamai) as P,
118 Arc::new(alibaba::Alibaba) as P,
119 Arc::new(aws::Aws) as P,
120 Arc::new(azure::Azure) as P,
121 Arc::new(digitalocean::DigitalOcean) as P,
122 Arc::new(gcp::Gcp) as P,
123 Arc::new(oci::Oci) as P,
124 Arc::new(openstack::OpenStack) as P,
125 Arc::new(vultr::Vultr) as P,
126 ])
127});
128
129pub async fn supported_providers() -> Vec<String> {
145 let guard = PROVIDERS.lock().await;
146 let providers: Vec<String> = guard.iter().map(|p| p.identifier().to_string()).collect();
147
148 drop(guard);
149
150 providers
151}
152
153#[instrument]
190pub async fn detect(timeout: Option<Duration>) -> ProviderId {
191 let timeout = timeout.unwrap_or(DEFAULT_DETECTION_TIMEOUT);
192 let (tx, mut rx) = mpsc::channel::<ProviderId>(1);
193 let guard = PROVIDERS.lock().await;
194 let provider_entries: Vec<P> = guard.iter().cloned().collect();
195 let providers_count = provider_entries.len();
196 let mut handles = Vec::with_capacity(providers_count);
197
198 let counter = Arc::new(AtomicUsize::new(providers_count));
200 let complete = Arc::new(Notify::new());
201
202 for provider in provider_entries {
203 let tx = tx.clone();
204 let counter = counter.clone();
205 let complete = complete.clone();
206
207 handles.push(tokio::spawn(async move {
208 debug!("Spawning task for provider: {}", provider.identifier());
209 provider.identify(tx, timeout).await;
210
211 if counter.fetch_sub(1, Ordering::SeqCst) == 1 {
213 complete.notify_one();
214 }
215 }));
216 }
217
218 tokio::select! {
219 biased;
220
221 res = rx.recv() => {
223 debug!("Received result from channel: {:?}", res);
224 res.unwrap_or_default()
225 }
226
227 _ = complete.notified() => {
229 debug!("All providers have finished identifying");
230 Default::default()
231 }
232
233 _ = tokio::time::sleep(timeout) => {
235 debug!("Detection timed out");
236 Default::default()
237 }
238 }
239}
240
#[cfg(test)]
mod tests {
    use super::*;

    /// The registry must list exactly the nine built-in providers by their identifiers.
    #[tokio::test]
    async fn test_supported_providers() {
        let expected = [
            &akamai::IDENTIFIER,
            &alibaba::IDENTIFIER,
            &aws::IDENTIFIER,
            &azure::IDENTIFIER,
            &digitalocean::IDENTIFIER,
            &gcp::IDENTIFIER,
            &oci::IDENTIFIER,
            &openstack::IDENTIFIER,
            &vultr::IDENTIFIER,
        ];

        let providers = supported_providers().await;
        assert_eq!(providers.len(), expected.len());

        for identifier in expected {
            assert!(providers.contains(&identifier.to_string()));
        }
    }
}