cloud_detect/
lib.rs

1//! # Cloud Detect
2//!
3//! A library to detect the cloud service provider of a host.
4//!
5//! ## Usage
6//!
7//! Add the following to your `Cargo.toml`:
8//!
9//! ```toml
10//! [dependencies]
11//! # ...
12//! cloud_detect = "2"
13//! tokio = { version = "1", features = ["full"] }
14//! tracing-subscriber = { version = "0.3", features = ["env-filter"] } # Optional; for logging
15//! ```
16//!
17//! ## Examples
18//!
19//! Detect the cloud provider and print the result (with default timeout).
20//!
21//! ```rust
22//! use cloud_detect::detect;
23//!
24//! #[tokio::main]
25//! async fn main() {
26//!     tracing_subscriber::fmt::init(); // Optional; for logging
27//!
28//!     let provider = detect(None).await;
29//!     println!("Detected provider: {}", provider);
30//! }
31//! ```
32//!
33//! Detect the cloud provider and print the result (with custom timeout).
34//!
35//! ```rust
36//! use cloud_detect::detect;
37//!
38//! #[tokio::main]
39//! async fn main() {
40//!     tracing_subscriber::fmt::init(); // Optional; for logging
41//!
42//!     let provider = detect(Some(10)).await;
43//!     println!("Detected provider: {}", provider);
44//! }
45//! ```
46
47use std::fmt::Debug;
48use std::sync::atomic::{AtomicUsize, Ordering};
49use std::sync::{Arc, LazyLock};
50use std::time::Duration;
51
52use async_trait::async_trait;
53use strum::Display;
54use tokio::sync::mpsc::Sender;
55use tokio::sync::{mpsc, Mutex, Notify};
56use tracing::{debug, instrument};
57
58use crate::providers::*;
59
60#[cfg(feature = "blocking")]
61pub mod blocking;
62pub(crate) mod providers;
63
/// Default maximum time, in seconds, allowed for detection.
///
/// [`detect`] falls back to this value when it is called with `None` as its timeout.
pub const DEFAULT_DETECTION_TIMEOUT: u64 = 5; // seconds
66
67/// Represents an identifier for a cloud service provider.
68#[non_exhaustive]
69#[derive(Debug, Default, Display, Eq, PartialEq)]
70pub enum ProviderId {
71    /// Unknown cloud service provider.
72    #[default]
73    #[strum(serialize = "unknown")]
74    Unknown,
75    /// Alibaba Cloud.
76    #[strum(serialize = "alibaba")]
77    Alibaba,
78    /// Amazon Web Services (AWS).
79    #[strum(serialize = "aws")]
80    AWS,
81    /// Microsoft Azure.
82    #[strum(serialize = "azure")]
83    Azure,
84    /// DigitalOcean.
85    #[strum(serialize = "digitalocean")]
86    DigitalOcean,
87    /// Google Cloud Platform (GCP).
88    #[strum(serialize = "gcp")]
89    GCP,
90    /// Oracle Cloud Infrastructure (OCI).
91    #[strum(serialize = "oci")]
92    OCI,
93    /// OpenStack.
94    #[strum(serialize = "openstack")]
95    OpenStack,
96    /// Vultr.
97    #[strum(serialize = "vultr")]
98    Vultr,
99}
100
/// Represents a cloud service provider.
///
/// Implementations must be `Send + Sync`: `detect` holds them as `Arc<dyn Provider>` and runs
/// each one inside its own spawned task.
#[async_trait]
pub(crate) trait Provider: Send + Sync {
    /// Returns the [`ProviderId`] that names this provider.
    fn identifier(&self) -> ProviderId;
    /// Runs this provider's detection routine.
    ///
    /// * `tx` - Sending half of the channel from which `detect` reads its result.
    /// * `timeout` - Time budget passed down from `detect`.
    ///
    /// NOTE(review): presumably an implementation sends its own identifier on `tx` when the
    /// host matches and sends nothing otherwise; confirm against the provider modules.
    async fn identify(&self, tx: Sender<ProviderId>, timeout: Duration);
}
107
/// Shared, type-erased handle to a [`Provider`] implementation.
type P = Arc<dyn Provider>;
109
110static PROVIDERS: LazyLock<Mutex<Vec<P>>> = LazyLock::new(|| {
111    Mutex::new(vec![
112        Arc::new(alibaba::Alibaba) as P,
113        Arc::new(aws::Aws) as P,
114        Arc::new(azure::Azure) as P,
115        Arc::new(digitalocean::DigitalOcean) as P,
116        Arc::new(gcp::Gcp) as P,
117        Arc::new(oci::Oci) as P,
118        Arc::new(openstack::OpenStack) as P,
119        Arc::new(vultr::Vultr) as P,
120    ])
121});
122
123/// Returns a list of currently supported providers.
124///
125/// # Examples
126///
127/// Print the list of supported providers.
128///
129/// ```
130/// use cloud_detect::supported_providers;
131///
132/// #[tokio::main]
133/// async fn main() {
134///     let providers = supported_providers().await;
135///     println!("Supported providers: {:?}", providers);
136/// }
137/// ```
138pub async fn supported_providers() -> Vec<String> {
139    let guard = PROVIDERS.lock().await;
140    let providers: Vec<String> = guard.iter().map(|p| p.identifier().to_string()).collect();
141
142    drop(guard);
143
144    providers
145}
146
147/// Detects the host's cloud provider.
148///
149/// Returns [ProviderId::Unknown] if the detection failed or timed out. If the detection was successful, it returns
150/// a value from [ProviderId](enum.ProviderId.html).
151///
152/// # Arguments
153///
154/// * `timeout` - Maximum time (seconds) allowed for detection. Defaults to [DEFAULT_DETECTION_TIMEOUT](constant.DEFAULT_DETECTION_TIMEOUT.html) if `None`.
155///
156/// # Examples
157///
158/// Detect the cloud provider and print the result (with default timeout).
159///
160/// ```
161/// use cloud_detect::detect;
162///
163/// #[tokio::main]
164/// async fn main() {
165///     let provider = detect(None).await;
166///     println!("Detected provider: {}", provider);
167/// }
168/// ```
169///
170/// Detect the cloud provider and print the result (with custom timeout).
171///
172/// ```
173/// use cloud_detect::detect;
174///
175/// #[tokio::main]
176/// async fn main() {
177///     let provider = detect(Some(10)).await;
178///     println!("Detected provider: {}", provider);
179/// }
180/// ```
181#[instrument]
182pub async fn detect(timeout: Option<u64>) -> ProviderId {
183    let timeout = Duration::from_secs(timeout.unwrap_or(DEFAULT_DETECTION_TIMEOUT));
184    let (tx, mut rx) = mpsc::channel::<ProviderId>(1);
185    let guard = PROVIDERS.lock().await;
186    let provider_entries: Vec<P> = guard.iter().cloned().collect();
187    let providers_count = provider_entries.len();
188    let mut handles = Vec::with_capacity(providers_count);
189
190    // Create a counter that will be decremented as tasks complete
191    let counter = Arc::new(AtomicUsize::new(providers_count));
192    let complete = Arc::new(Notify::new());
193
194    for provider in provider_entries {
195        let tx = tx.clone();
196        let counter = counter.clone();
197        let complete = complete.clone();
198
199        handles.push(tokio::spawn(async move {
200            debug!("Spawning task for provider: {}", provider.identifier());
201            provider.identify(tx, timeout).await;
202
203            // Decrement counter and notify if we're the last task
204            if counter.fetch_sub(1, Ordering::SeqCst) == 1 {
205                complete.notify_one();
206            }
207        }));
208    }
209
210    tokio::select! {
211        biased;
212
213        // Priority 1: If we receive an identifier, return it immediately
214        res = rx.recv() => {
215            debug!("Received result from channel: {:?}", res);
216            res.unwrap_or_default()
217        }
218
219        // Priority 2: If all tasks complete without finding an identifier
220        _ = complete.notified() => {
221            debug!("All providers have finished identifying");
222            Default::default()
223        }
224
225        // Priority 3: If we time out
226        _ = tokio::time::sleep(timeout) => {
227            debug!("Detection timed out");
228            Default::default()
229        }
230    }
231}
232
#[cfg(test)]
mod tests {
    use super::*;

    // Verifies that the registry reports exactly the eight built-in providers.
    #[tokio::test]
    async fn test_supported_providers() {
        let reported = supported_providers().await;
        let expected = [
            alibaba::IDENTIFIER.to_string(),
            aws::IDENTIFIER.to_string(),
            azure::IDENTIFIER.to_string(),
            digitalocean::IDENTIFIER.to_string(),
            gcp::IDENTIFIER.to_string(),
            oci::IDENTIFIER.to_string(),
            openstack::IDENTIFIER.to_string(),
            vultr::IDENTIFIER.to_string(),
        ];

        assert_eq!(reported.len(), 8);

        for identifier in &expected {
            assert!(reported.contains(identifier));
        }
    }
}