cloud_detect/blocking/
mod.rs

1//! Blocking API for cloud provider detection.
2//!
//! This module provides a blocking API for detecting the host's cloud provider. It mirrors the asynchronous API
//! and runs each provider's blocking identification routine on its own thread.
5//!
6//! This module is intended for use in synchronous applications or in situations where the asynchronous API is not suitable.
7//! While not guaranteed, the performance of this module should be comparable to the asynchronous API.
8//!
//! ## Feature flag
10//!
11//! This requires the `blocking` feature to be enabled.
12//!
13//! ## Usage
14//!
15//! Add the following to your `Cargo.toml`:
16//!
17//! ```toml
18//! # ...
19//! cloud_detect = { version = "2", features = ["blocking"] }
20//! tracing-subscriber = { version = "0.3", features = ["env-filter"] } # Optional; for logging
21//! ```
22//!
23//! ## Examples
24//!
25//! Detect the cloud provider and print the result (with default timeout).
26//!
27//! ```rust
28//! use cloud_detect::blocking::detect;
29//!
30//! tracing_subscriber::fmt::init(); // Optional; for logging
31//!
32//! let provider = detect(None).unwrap();
33//! println!("Detected provider: {:?}", provider);
34//! ```
35//!
36//! Detect the cloud provider and print the result (with custom timeout).
37//!
38//! ```rust
39//! use cloud_detect::blocking::detect;
40//!
41//! tracing_subscriber::fmt::init(); // Optional; for logging
42//!
43//! let provider = detect(Some(10)).unwrap();
44//! println!("Detected provider: {:?}", provider);
45//! ```
46
47pub(crate) mod providers;
48
49use std::sync::mpsc::RecvTimeoutError;
50use std::sync::mpsc::SyncSender;
51use std::sync::{mpsc, Arc, LazyLock, Mutex};
52use std::time::Duration;
53
54use anyhow::Result;
55
56use crate::blocking::providers::*;
57use crate::{ProviderId, DEFAULT_DETECTION_TIMEOUT};
58
59/// Represents a cloud service provider.
60#[allow(dead_code)]
61pub(crate) trait Provider: Send + Sync {
62    fn identifier(&self) -> ProviderId;
63    fn identify(&self, tx: SyncSender<ProviderId>, timeout: Duration);
64}
65
66type P = Arc<dyn Provider>;
67
68static PROVIDERS: LazyLock<Mutex<Vec<P>>> = LazyLock::new(|| {
69    Mutex::new(vec![
70        Arc::new(akamai::Akamai) as P,
71        Arc::new(alibaba::Alibaba) as P,
72        Arc::new(aws::Aws) as P,
73        Arc::new(azure::Azure) as P,
74        Arc::new(digitalocean::DigitalOcean) as P,
75        Arc::new(gcp::Gcp) as P,
76        Arc::new(oci::Oci) as P,
77        Arc::new(openstack::OpenStack) as P,
78        Arc::new(vultr::Vultr) as P,
79    ])
80});
81
82/// Returns a list of currently supported providers.
83///
84/// # Examples
85///
86/// Print the list of supported providers.
87///
88/// ```
89/// use cloud_detect::blocking::supported_providers;
90///
91/// let providers = supported_providers().unwrap();
92/// println!("Supported providers: {:?}", providers);
93/// ```
94pub fn supported_providers() -> Result<Vec<String>> {
95    let guard = PROVIDERS
96        .lock()
97        .map_err(|_| anyhow::anyhow!("Error locking providers"))?;
98    let providers: Vec<String> = guard.iter().map(|p| p.identifier().to_string()).collect();
99
100    drop(guard);
101
102    Ok(providers)
103}
104
105/// Detects the host's cloud provider.
106///
107/// Returns [ProviderId::Unknown] if the detection failed or timed out. If the detection was successful, it returns
108/// a value from [ProviderId](enum.ProviderId.html).
109///
110/// # Arguments
111///
112/// * `timeout` - Maximum time (seconds) allowed for detection. Defaults to [DEFAULT_DETECTION_TIMEOUT](constant.DEFAULT_DETECTION_TIMEOUT.html) if `None`.
113///
114/// # Examples
115///
116/// Detect the cloud provider and print the result (with default timeout).
117///
118/// ```
119/// use cloud_detect::blocking::detect;
120///
121/// let provider = detect(None).unwrap();
122/// println!("Detected provider: {:?}", provider);
123/// ```
124///
125/// Detect the cloud provider and print the result (with custom timeout).
126///
127/// ```
128/// use cloud_detect::blocking::detect;
129///
130/// let provider = detect(Some(10)).unwrap();
131/// println!("Detected provider: {:?}", provider);
132/// ```
133pub fn detect(timeout: Option<u64>) -> Result<ProviderId> {
134    let timeout = Duration::from_secs(timeout.unwrap_or(DEFAULT_DETECTION_TIMEOUT));
135    let (tx, rx) = mpsc::sync_channel::<ProviderId>(1);
136    let guard = PROVIDERS
137        .lock()
138        .map_err(|_| anyhow::anyhow!("Error locking providers"))?;
139    let provider_entries: Vec<P> = guard.iter().cloned().collect();
140
141    for provider in provider_entries {
142        let tx = tx.clone();
143        std::thread::spawn(move || provider.identify(tx, timeout));
144    }
145
146    match rx.recv_timeout(timeout) {
147        Ok(provider_id) => Ok(provider_id),
148        Err(err) => match err {
149            RecvTimeoutError::Timeout => Ok(ProviderId::Unknown),
150            RecvTimeoutError::Disconnected => Err(anyhow::anyhow!("Error receiving message")),
151        },
152    }
153}
154
#[cfg(test)]
mod tests {
    use anyhow::Result;

    use super::*;

    #[test]
    fn test_supported_providers() -> Result<()> {
        // Every provider registered in `PROVIDERS` must be reported exactly once.
        let expected = [
            akamai::IDENTIFIER,
            alibaba::IDENTIFIER,
            aws::IDENTIFIER,
            azure::IDENTIFIER,
            digitalocean::IDENTIFIER,
            gcp::IDENTIFIER,
            oci::IDENTIFIER,
            openstack::IDENTIFIER,
            vultr::IDENTIFIER,
        ];
        let providers = supported_providers()?;

        assert_eq!(providers.len(), 9);
        for identifier in expected {
            assert!(providers.contains(&identifier.to_string()));
        }

        Ok(())
    }
}