cloud_detect/
lib.rs

//! # Cloud Detect
//!
//! A library to detect the cloud service provider of a host.
//!
//! ## Usage
//!
//! Add the following to your `Cargo.toml`:
//!
//! ```toml
//! [dependencies]
//! # ...
//! cloud_detect = "2"
//! tokio = { version = "1", features = ["full"] }
//! tracing-subscriber = { version = "0.3", features = ["env-filter"] } # Optional; for logging
//! ```
//!
//! ## Examples
//!
//! Detect the cloud provider and print the result (with the default timeout).
//!
//! ```rust
//! use cloud_detect::detect;
//!
//! #[tokio::main]
//! async fn main() {
//!     tracing_subscriber::fmt::init(); // Optional; for logging
//!
//!     let provider = detect(None).await;
//!     println!("Detected provider: {}", provider);
//! }
//! ```
//!
//! Detect the cloud provider and print the result (with a custom timeout of 10 seconds).
//!
//! ```rust
//! use cloud_detect::detect;
//!
//! #[tokio::main]
//! async fn main() {
//!     tracing_subscriber::fmt::init(); // Optional; for logging
//!
//!     let provider = detect(Some(10)).await;
//!     println!("Detected provider: {}", provider);
//! }
//! ```
46
47use std::fmt::Debug;
48use std::sync::atomic::{AtomicUsize, Ordering};
49use std::sync::{Arc, LazyLock};
50use std::time::Duration;
51
52use async_trait::async_trait;
53use strum::Display;
54use tokio::sync::mpsc::Sender;
55use tokio::sync::{mpsc, Mutex, Notify};
56use tracing::{debug, instrument};
57
58use crate::providers::*;
59
60#[cfg(feature = "blocking")]
61pub mod blocking;
62pub(crate) mod providers;
63
/// Upper bound on detection time, in seconds, applied when the caller passes no explicit timeout.
pub const DEFAULT_DETECTION_TIMEOUT: u64 = 5;
66
67/// Represents an identifier for a cloud service provider.
68#[non_exhaustive]
69#[derive(Debug, Default, Display, Eq, PartialEq)]
70pub enum ProviderId {
71    /// Unknown cloud service provider.
72    #[default]
73    #[strum(serialize = "unknown")]
74    Unknown,
75    /// Akamai Cloud.
76    #[strum(serialize = "akamai")]
77    Akamai,
78    /// Alibaba Cloud.
79    #[strum(serialize = "alibaba")]
80    Alibaba,
81    /// Amazon Web Services (AWS).
82    #[strum(serialize = "aws")]
83    AWS,
84    /// Microsoft Azure.
85    #[strum(serialize = "azure")]
86    Azure,
87    /// DigitalOcean.
88    #[strum(serialize = "digitalocean")]
89    DigitalOcean,
90    /// Google Cloud Platform (GCP).
91    #[strum(serialize = "gcp")]
92    GCP,
93    /// Oracle Cloud Infrastructure (OCI).
94    #[strum(serialize = "oci")]
95    OCI,
96    /// OpenStack.
97    #[strum(serialize = "openstack")]
98    OpenStack,
99    /// Vultr.
100    #[strum(serialize = "vultr")]
101    Vultr,
102}
103
104/// Represents a cloud service provider.
105#[async_trait]
106pub(crate) trait Provider: Send + Sync {
107    fn identifier(&self) -> ProviderId;
108    async fn identify(&self, tx: Sender<ProviderId>, timeout: Duration);
109}
110
111type P = Arc<dyn Provider>;
112
113static PROVIDERS: LazyLock<Mutex<Vec<P>>> = LazyLock::new(|| {
114    Mutex::new(vec![
115        Arc::new(akamai::Akamai) as P,
116        Arc::new(alibaba::Alibaba) as P,
117        Arc::new(aws::Aws) as P,
118        Arc::new(azure::Azure) as P,
119        Arc::new(digitalocean::DigitalOcean) as P,
120        Arc::new(gcp::Gcp) as P,
121        Arc::new(oci::Oci) as P,
122        Arc::new(openstack::OpenStack) as P,
123        Arc::new(vultr::Vultr) as P,
124    ])
125});
126
127/// Returns a list of currently supported providers.
128///
129/// # Examples
130///
131/// Print the list of supported providers.
132///
133/// ```
134/// use cloud_detect::supported_providers;
135///
136/// #[tokio::main]
137/// async fn main() {
138///     let providers = supported_providers().await;
139///     println!("Supported providers: {:?}", providers);
140/// }
141/// ```
142pub async fn supported_providers() -> Vec<String> {
143    let guard = PROVIDERS.lock().await;
144    let providers: Vec<String> = guard.iter().map(|p| p.identifier().to_string()).collect();
145
146    drop(guard);
147
148    providers
149}
150
151/// Detects the host's cloud provider.
152///
153/// Returns [ProviderId::Unknown] if the detection failed or timed out. If the detection was successful, it returns
154/// a value from [ProviderId](enum.ProviderId.html).
155///
156/// # Arguments
157///
158/// * `timeout` - Maximum time (seconds) allowed for detection. Defaults to [DEFAULT_DETECTION_TIMEOUT](constant.DEFAULT_DETECTION_TIMEOUT.html) if `None`.
159///
160/// # Examples
161///
162/// Detect the cloud provider and print the result (with default timeout).
163///
164/// ```
165/// use cloud_detect::detect;
166///
167/// #[tokio::main]
168/// async fn main() {
169///     let provider = detect(None).await;
170///     println!("Detected provider: {}", provider);
171/// }
172/// ```
173///
174/// Detect the cloud provider and print the result (with custom timeout).
175///
176/// ```
177/// use cloud_detect::detect;
178///
179/// #[tokio::main]
180/// async fn main() {
181///     let provider = detect(Some(10)).await;
182///     println!("Detected provider: {}", provider);
183/// }
184/// ```
185#[instrument]
186pub async fn detect(timeout: Option<u64>) -> ProviderId {
187    let timeout = Duration::from_secs(timeout.unwrap_or(DEFAULT_DETECTION_TIMEOUT));
188    let (tx, mut rx) = mpsc::channel::<ProviderId>(1);
189    let guard = PROVIDERS.lock().await;
190    let provider_entries: Vec<P> = guard.iter().cloned().collect();
191    let providers_count = provider_entries.len();
192    let mut handles = Vec::with_capacity(providers_count);
193
194    // Create a counter that will be decremented as tasks complete
195    let counter = Arc::new(AtomicUsize::new(providers_count));
196    let complete = Arc::new(Notify::new());
197
198    for provider in provider_entries {
199        let tx = tx.clone();
200        let counter = counter.clone();
201        let complete = complete.clone();
202
203        handles.push(tokio::spawn(async move {
204            debug!("Spawning task for provider: {}", provider.identifier());
205            provider.identify(tx, timeout).await;
206
207            // Decrement counter and notify if we're the last task
208            if counter.fetch_sub(1, Ordering::SeqCst) == 1 {
209                complete.notify_one();
210            }
211        }));
212    }
213
214    tokio::select! {
215        biased;
216
217        // Priority 1: If we receive an identifier, return it immediately
218        res = rx.recv() => {
219            debug!("Received result from channel: {:?}", res);
220            res.unwrap_or_default()
221        }
222
223        // Priority 2: If all tasks complete without finding an identifier
224        _ = complete.notified() => {
225            debug!("All providers have finished identifying");
226            Default::default()
227        }
228
229        // Priority 3: If we time out
230        _ = tokio::time::sleep(timeout) => {
231            debug!("Detection timed out");
232            Default::default()
233        }
234    }
235}
236
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_supported_providers() {
        // Every provider module's identifier must appear in the public listing.
        let expected: Vec<String> = vec![
            akamai::IDENTIFIER.to_string(),
            alibaba::IDENTIFIER.to_string(),
            aws::IDENTIFIER.to_string(),
            azure::IDENTIFIER.to_string(),
            digitalocean::IDENTIFIER.to_string(),
            gcp::IDENTIFIER.to_string(),
            oci::IDENTIFIER.to_string(),
            openstack::IDENTIFIER.to_string(),
            vultr::IDENTIFIER.to_string(),
        ];

        let providers = supported_providers().await;

        assert_eq!(providers.len(), 9);
        for identifier in &expected {
            assert!(providers.contains(identifier));
        }
    }
}