cloud_detect/lib.rs
1//! # Cloud Detect
2//!
3//! A library to detect the cloud service provider of a host.
4//!
5//! ## Usage
6//!
7//! Add the following to your `Cargo.toml`:
8//!
9//! ```toml
10//! [dependencies]
11//! # ...
12//! cloud_detect = "2"
13//! tokio = { version = "1", features = ["full"] }
14//! tracing-subscriber = { version = "0.3", features = ["env-filter"] } # Optional; for logging
15//! ```
16//!
17//! ## Examples
18//!
19//! Detect the cloud provider and print the result (with default timeout).
20//!
21//! ```rust
22//! use cloud_detect::detect;
23//!
24//! #[tokio::main]
25//! async fn main() {
26//! tracing_subscriber::fmt::init(); // Optional; for logging
27//!
28//! let provider = detect(None).await;
29//! println!("Detected provider: {}", provider);
30//! }
31//! ```
32//!
33//! Detect the cloud provider and print the result (with custom timeout).
34//!
35//! ```rust
36//! use cloud_detect::detect;
37//!
38//! #[tokio::main]
39//! async fn main() {
40//! tracing_subscriber::fmt::init(); // Optional; for logging
41//!
42//! let provider = detect(Some(10)).await;
43//! println!("Detected provider: {}", provider);
44//! }
45//! ```
46
47use std::fmt::Debug;
48use std::sync::atomic::{AtomicUsize, Ordering};
49use std::sync::{Arc, LazyLock};
50use std::time::Duration;
51
52use async_trait::async_trait;
53use strum::Display;
54use tokio::sync::mpsc::Sender;
55use tokio::sync::{mpsc, Mutex, Notify};
56use tracing::{debug, instrument};
57
58use crate::providers::*;
59
60#[cfg(feature = "blocking")]
61pub mod blocking;
62pub(crate) mod providers;
63
/// Default maximum time, in seconds, allowed for detection.
///
/// [`detect`] falls back to this value when it is called with `None` as the timeout.
pub const DEFAULT_DETECTION_TIMEOUT: u64 = 5; // seconds
66
67/// Represents an identifier for a cloud service provider.
68#[non_exhaustive]
69#[derive(Debug, Default, Display, Eq, PartialEq)]
70pub enum ProviderId {
71 /// Unknown cloud service provider.
72 #[default]
73 #[strum(serialize = "unknown")]
74 Unknown,
75 /// Alibaba Cloud.
76 #[strum(serialize = "alibaba")]
77 Alibaba,
78 /// Amazon Web Services (AWS).
79 #[strum(serialize = "aws")]
80 AWS,
81 /// Microsoft Azure.
82 #[strum(serialize = "azure")]
83 Azure,
84 /// DigitalOcean.
85 #[strum(serialize = "digitalocean")]
86 DigitalOcean,
87 /// Google Cloud Platform (GCP).
88 #[strum(serialize = "gcp")]
89 GCP,
90 /// Oracle Cloud Infrastructure (OCI).
91 #[strum(serialize = "oci")]
92 OCI,
93 /// OpenStack.
94 #[strum(serialize = "openstack")]
95 OpenStack,
96 /// Vultr.
97 #[strum(serialize = "vultr")]
98 Vultr,
99}
100
/// Represents a cloud service provider.
///
/// Implementations are stored as `Arc<dyn Provider>` and moved into spawned tasks by
/// [`detect`], hence the `Send + Sync` bound.
#[async_trait]
pub(crate) trait Provider: Send + Sync {
    /// Returns the identifier associated with this provider.
    fn identifier(&self) -> ProviderId;

    /// Runs this provider's detection logic.
    ///
    /// * `tx` - Sending half of the channel whose receiver [`detect`] waits on; presumably an
    ///   implementation sends its own identifier when it recognises the host (confirm against
    ///   the implementations in `providers`).
    /// * `timeout` - The detection timeout that [`detect`] was invoked with.
    async fn identify(&self, tx: Sender<ProviderId>, timeout: Duration);
}
107
108type P = Arc<dyn Provider>;
109
110static PROVIDERS: LazyLock<Mutex<Vec<P>>> = LazyLock::new(|| {
111 Mutex::new(vec![
112 Arc::new(alibaba::Alibaba) as P,
113 Arc::new(aws::Aws) as P,
114 Arc::new(azure::Azure) as P,
115 Arc::new(digitalocean::DigitalOcean) as P,
116 Arc::new(gcp::Gcp) as P,
117 Arc::new(oci::Oci) as P,
118 Arc::new(openstack::OpenStack) as P,
119 Arc::new(vultr::Vultr) as P,
120 ])
121});
122
123/// Returns a list of currently supported providers.
124///
125/// # Examples
126///
127/// Print the list of supported providers.
128///
129/// ```
130/// use cloud_detect::supported_providers;
131///
132/// #[tokio::main]
133/// async fn main() {
134/// let providers = supported_providers().await;
135/// println!("Supported providers: {:?}", providers);
136/// }
137/// ```
138pub async fn supported_providers() -> Vec<String> {
139 let guard = PROVIDERS.lock().await;
140 let providers: Vec<String> = guard.iter().map(|p| p.identifier().to_string()).collect();
141
142 drop(guard);
143
144 providers
145}
146
147/// Detects the host's cloud provider.
148///
149/// Returns [ProviderId::Unknown] if the detection failed or timed out. If the detection was successful, it returns
150/// a value from [ProviderId](enum.ProviderId.html).
151///
152/// # Arguments
153///
154/// * `timeout` - Maximum time (seconds) allowed for detection. Defaults to [DEFAULT_DETECTION_TIMEOUT](constant.DEFAULT_DETECTION_TIMEOUT.html) if `None`.
155///
156/// # Examples
157///
158/// Detect the cloud provider and print the result (with default timeout).
159///
160/// ```
161/// use cloud_detect::detect;
162///
163/// #[tokio::main]
164/// async fn main() {
165/// let provider = detect(None).await;
166/// println!("Detected provider: {}", provider);
167/// }
168/// ```
169///
170/// Detect the cloud provider and print the result (with custom timeout).
171///
172/// ```
173/// use cloud_detect::detect;
174///
175/// #[tokio::main]
176/// async fn main() {
177/// let provider = detect(Some(10)).await;
178/// println!("Detected provider: {}", provider);
179/// }
180/// ```
181#[instrument]
182pub async fn detect(timeout: Option<u64>) -> ProviderId {
183 let timeout = Duration::from_secs(timeout.unwrap_or(DEFAULT_DETECTION_TIMEOUT));
184 let (tx, mut rx) = mpsc::channel::<ProviderId>(1);
185 let guard = PROVIDERS.lock().await;
186 let provider_entries: Vec<P> = guard.iter().cloned().collect();
187 let providers_count = provider_entries.len();
188 let mut handles = Vec::with_capacity(providers_count);
189
190 // Create a counter that will be decremented as tasks complete
191 let counter = Arc::new(AtomicUsize::new(providers_count));
192 let complete = Arc::new(Notify::new());
193
194 for provider in provider_entries {
195 let tx = tx.clone();
196 let counter = counter.clone();
197 let complete = complete.clone();
198
199 handles.push(tokio::spawn(async move {
200 debug!("Spawning task for provider: {}", provider.identifier());
201 provider.identify(tx, timeout).await;
202
203 // Decrement counter and notify if we're the last task
204 if counter.fetch_sub(1, Ordering::SeqCst) == 1 {
205 complete.notify_one();
206 }
207 }));
208 }
209
210 tokio::select! {
211 biased;
212
213 // Priority 1: If we receive an identifier, return it immediately
214 res = rx.recv() => {
215 debug!("Received result from channel: {:?}", res);
216 res.unwrap_or_default()
217 }
218
219 // Priority 2: If all tasks complete without finding an identifier
220 _ = complete.notified() => {
221 debug!("All providers have finished identifying");
222 Default::default()
223 }
224
225 // Priority 3: If we time out
226 _ = tokio::time::sleep(timeout) => {
227 debug!("Detection timed out");
228 Default::default()
229 }
230 }
231}
232
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks that every provider module's identifier is reported by `supported_providers`.
    #[tokio::test]
    async fn test_supported_providers() {
        let expected = [
            alibaba::IDENTIFIER.to_string(),
            aws::IDENTIFIER.to_string(),
            azure::IDENTIFIER.to_string(),
            digitalocean::IDENTIFIER.to_string(),
            gcp::IDENTIFIER.to_string(),
            oci::IDENTIFIER.to_string(),
            openstack::IDENTIFIER.to_string(),
            vultr::IDENTIFIER.to_string(),
        ];

        let providers = supported_providers().await;

        assert_eq!(providers.len(), 8);
        for identifier in &expected {
            assert!(providers.contains(identifier));
        }
    }
}