Skip to main content

cloud_detect/
lib.rs

1//! # Cloud Detect
2//!
3//! A library to detect the cloud service provider of a host.
4//!
5//! ## Usage
6//!
7//! Add the following to your `Cargo.toml`:
8//!
9//! ```toml
10//! [dependencies]
11//! # ...
12//! cloud_detect = "2"
13//! tokio = { version = "1", features = ["full"] }
14//! tracing-subscriber = { version = "0.3", features = ["env-filter"] } # Optional; for logging
15//! ```
16//!
17//! ## Examples
18//!
19//! Detect the cloud provider and print the result (with default timeout).
20//!
21//! ```rust
22//! use cloud_detect::detect;
23//!
24//! #[tokio::main]
25//! async fn main() {
26//!     tracing_subscriber::fmt::init(); // Optional; for logging
27//!
28//!     let provider = detect(None).await;
29//!     println!("Detected provider: {}", provider);
30//! }
31//! ```
32//!
33//! Detect the cloud provider and print the result (with custom timeout).
34//!
35//! ```rust
36//! use std::time::Duration;
37//!
38//! use cloud_detect::detect;
39//!
40//! #[tokio::main]
41//! async fn main() {
42//!     tracing_subscriber::fmt::init(); // Optional; for logging
43//!
44//!     let provider = detect(Some(Duration::from_secs(10))).await;
45//!     println!("Detected provider: {}", provider);
46//! }
47//! ```
48
49use std::fmt::Debug;
50use std::sync::atomic::{AtomicUsize, Ordering};
51use std::sync::{Arc, LazyLock};
52use std::time::Duration;
53
54use async_trait::async_trait;
55use strum::Display;
56use tokio::sync::mpsc::Sender;
57use tokio::sync::{mpsc, Mutex, Notify};
58use tracing::{debug, instrument};
59
60use crate::providers::*;
61
62#[cfg(feature = "blocking")]
63pub mod blocking;
64pub(crate) mod providers;
65
66/// Maximum time allowed for detection.
67pub const DEFAULT_DETECTION_TIMEOUT: Duration = Duration::from_secs(5);
68
69/// Represents an identifier for a cloud service provider.
70#[non_exhaustive]
71#[derive(Debug, Default, Display, Eq, PartialEq)]
72pub enum ProviderId {
73    /// Unknown cloud service provider.
74    #[default]
75    #[strum(serialize = "unknown")]
76    Unknown,
77    /// Akamai Cloud.
78    #[strum(serialize = "akamai")]
79    Akamai,
80    /// Alibaba Cloud.
81    #[strum(serialize = "alibaba")]
82    Alibaba,
83    /// Amazon Web Services (AWS).
84    #[strum(serialize = "aws")]
85    AWS,
86    /// Microsoft Azure.
87    #[strum(serialize = "azure")]
88    Azure,
89    /// DigitalOcean.
90    #[strum(serialize = "digitalocean")]
91    DigitalOcean,
92    /// Google Cloud Platform (GCP).
93    #[strum(serialize = "gcp")]
94    GCP,
95    /// Oracle Cloud Infrastructure (OCI).
96    #[strum(serialize = "oci")]
97    OCI,
98    /// OpenStack.
99    #[strum(serialize = "openstack")]
100    OpenStack,
101    /// Vultr.
102    #[strum(serialize = "vultr")]
103    Vultr,
104}
105
106/// Represents a cloud service provider.
107#[async_trait]
108pub(crate) trait Provider: Send + Sync {
109    fn identifier(&self) -> ProviderId;
110    async fn identify(&self, tx: Sender<ProviderId>, timeout: Duration);
111}
112
113type P = Arc<dyn Provider>;
114
115static PROVIDERS: LazyLock<Mutex<Vec<P>>> = LazyLock::new(|| {
116    Mutex::new(vec![
117        Arc::new(akamai::Akamai) as P,
118        Arc::new(alibaba::Alibaba) as P,
119        Arc::new(aws::Aws) as P,
120        Arc::new(azure::Azure) as P,
121        Arc::new(digitalocean::DigitalOcean) as P,
122        Arc::new(gcp::Gcp) as P,
123        Arc::new(oci::Oci) as P,
124        Arc::new(openstack::OpenStack) as P,
125        Arc::new(vultr::Vultr) as P,
126    ])
127});
128
129/// Returns a list of currently supported providers.
130///
131/// # Examples
132///
133/// Print the list of supported providers.
134///
135/// ```
136/// use cloud_detect::supported_providers;
137///
138/// #[tokio::main]
139/// async fn main() {
140///     let providers = supported_providers().await;
141///     println!("Supported providers: {:?}", providers);
142/// }
143/// ```
144pub async fn supported_providers() -> Vec<String> {
145    let guard = PROVIDERS.lock().await;
146    let providers: Vec<String> = guard.iter().map(|p| p.identifier().to_string()).collect();
147
148    drop(guard);
149
150    providers
151}
152
153/// Detects the host's cloud provider.
154///
155/// Returns [ProviderId::Unknown] if the detection failed or timed out. If the detection was successful, it returns
156/// a value from [ProviderId](enum.ProviderId.html).
157///
158/// # Arguments
159///
160/// * `timeout` - Maximum time (seconds) allowed for detection. Defaults to [DEFAULT_DETECTION_TIMEOUT](constant.DEFAULT_DETECTION_TIMEOUT.html) if `None`.
161///
162/// # Examples
163///
164/// Detect the cloud provider and print the result (with default timeout).
165///
166/// ```
167/// use cloud_detect::detect;
168///
169/// #[tokio::main]
170/// async fn main() {
171///     let provider = detect(None).await;
172///     println!("Detected provider: {}", provider);
173/// }
174/// ```
175///
176/// Detect the cloud provider and print the result (with custom timeout).
177///
178/// ```
179/// use std::time::Duration;
180///
181/// use cloud_detect::detect;
182///
183/// #[tokio::main]
184/// async fn main() {
185///     let provider = detect(Some(Duration::from_secs(10))).await;
186///     println!("Detected provider: {}", provider);
187/// }
188/// ```
189#[instrument]
190pub async fn detect(timeout: Option<Duration>) -> ProviderId {
191    let timeout = timeout.unwrap_or(DEFAULT_DETECTION_TIMEOUT);
192    let (tx, mut rx) = mpsc::channel::<ProviderId>(1);
193    let guard = PROVIDERS.lock().await;
194    let provider_entries: Vec<P> = guard.iter().cloned().collect();
195    let providers_count = provider_entries.len();
196    let mut handles = Vec::with_capacity(providers_count);
197
198    // Create a counter that will be decremented as tasks complete
199    let counter = Arc::new(AtomicUsize::new(providers_count));
200    let complete = Arc::new(Notify::new());
201
202    for provider in provider_entries {
203        let tx = tx.clone();
204        let counter = counter.clone();
205        let complete = complete.clone();
206
207        handles.push(tokio::spawn(async move {
208            debug!("Spawning task for provider: {}", provider.identifier());
209            provider.identify(tx, timeout).await;
210
211            // Decrement counter and notify if we're the last task
212            if counter.fetch_sub(1, Ordering::SeqCst) == 1 {
213                complete.notify_one();
214            }
215        }));
216    }
217
218    tokio::select! {
219        biased;
220
221        // Priority 1: If we receive an identifier, return it immediately
222        res = rx.recv() => {
223            debug!("Received result from channel: {:?}", res);
224            res.unwrap_or_default()
225        }
226
227        // Priority 2: If all tasks complete without finding an identifier
228        _ = complete.notified() => {
229            debug!("All providers have finished identifying");
230            Default::default()
231        }
232
233        // Priority 3: If we time out
234        _ = tokio::time::sleep(timeout) => {
235            debug!("Detection timed out");
236            Default::default()
237        }
238    }
239}
240
241#[cfg(test)]
242mod tests {
243    use super::*;
244
245    #[tokio::test]
246    async fn test_supported_providers() {
247        let providers = supported_providers().await;
248        assert_eq!(providers.len(), 9);
249        assert!(providers.contains(&akamai::IDENTIFIER.to_string()));
250        assert!(providers.contains(&alibaba::IDENTIFIER.to_string()));
251        assert!(providers.contains(&aws::IDENTIFIER.to_string()));
252        assert!(providers.contains(&azure::IDENTIFIER.to_string()));
253        assert!(providers.contains(&digitalocean::IDENTIFIER.to_string()));
254        assert!(providers.contains(&gcp::IDENTIFIER.to_string()));
255        assert!(providers.contains(&oci::IDENTIFIER.to_string()));
256        assert!(providers.contains(&openstack::IDENTIFIER.to_string()));
257        assert!(providers.contains(&vultr::IDENTIFIER.to_string()));
258    }
259}