cloud_detect/blocking/
mod.rs

1//! Blocking API for cloud provider detection.
2//!
3//! This module provides a blocking API for detecting the host's cloud provider. It is built on top of the asynchronous API
4//! and executes the blocking provider identification within threads.
5//!
6//! This module is intended for use in synchronous applications or in situations where the asynchronous API is not suitable.
7//! While not guaranteed, the performance of this module should be comparable to the asynchronous API.
8//!
9//! ## Optional
10//!
11//! This requires the `blocking` feature to be enabled.
12//!
13//! ## Usage
14//!
15//! Add the following to your `Cargo.toml`:
16//!
17//! ```toml
18//! # ...
19//! cloud_detect = { version = "2", features = ["blocking"] }
20//! tracing-subscriber = { version = "0.3", features = ["env-filter"] } # Optional; for logging
21//! ```
22//!
23//! ## Examples
24//!
25//! Detect the cloud provider and print the result (with default timeout).
26//!
27//! ```rust
28//! use cloud_detect::blocking::detect;
29//!
30//! tracing_subscriber::fmt::init(); // Optional; for logging
31//!
32//! let provider = detect(None).unwrap();
33//! println!("Detected provider: {:?}", provider);
34//! ```
35//!
36//! Detect the cloud provider and print the result (with custom timeout).
37//!
38//! ```rust
39//! use cloud_detect::blocking::detect;
40//!
41//! tracing_subscriber::fmt::init(); // Optional; for logging
42//!
43//! let provider = detect(Some(10)).unwrap();
44//! println!("Detected provider: {:?}", provider);
45//! ```
46
47pub(crate) mod providers;
48
49use std::sync::mpsc::RecvTimeoutError;
50use std::sync::mpsc::SyncSender;
51use std::sync::{mpsc, Arc, LazyLock, Mutex};
52use std::time::Duration;
53
54use anyhow::Result;
55
56use crate::blocking::providers::*;
57use crate::{ProviderId, DEFAULT_DETECTION_TIMEOUT};
58
/// Represents a cloud service provider.
///
/// The `Send + Sync` bounds are required because `detect` stores providers behind `Arc` and moves a clone of each
/// one into a freshly spawned thread before calling `identify` on it.
// NOTE(review): `dead_code` is silenced without a stated reason; confirm which item is unused and say why.
#[allow(dead_code)]
pub(crate) trait Provider: Send + Sync {
    /// Returns the [`ProviderId`] that names this provider.
    fn identifier(&self) -> ProviderId;
    /// Runs this provider's identification check.
    ///
    /// `tx` is the sending half of the channel that `detect` listens on, and `timeout` is the overall detection
    /// budget that `detect` passes through unchanged. Presumably an implementation sends its own identifier on
    /// `tx` only when it recognises the host — confirm against the provider implementations.
    fn identify(&self, tx: SyncSender<ProviderId>, timeout: Duration);
}
65
66type P = Arc<dyn Provider>;
67
68static PROVIDERS: LazyLock<Mutex<Vec<P>>> = LazyLock::new(|| {
69    Mutex::new(vec![
70        Arc::new(alibaba::Alibaba) as P,
71        Arc::new(aws::Aws) as P,
72        Arc::new(azure::Azure) as P,
73        Arc::new(digitalocean::DigitalOcean) as P,
74        Arc::new(gcp::Gcp) as P,
75        Arc::new(oci::Oci) as P,
76        Arc::new(openstack::OpenStack) as P,
77        Arc::new(vultr::Vultr) as P,
78    ])
79});
80
81/// Returns a list of currently supported providers.
82///
83/// # Examples
84///
85/// Print the list of supported providers.
86///
87/// ```
88/// use cloud_detect::blocking::supported_providers;
89///
90/// let providers = supported_providers().unwrap();
91/// println!("Supported providers: {:?}", providers);
92/// ```
93pub fn supported_providers() -> Result<Vec<String>> {
94    let guard = PROVIDERS
95        .lock()
96        .map_err(|_| anyhow::anyhow!("Error locking providers"))?;
97    let providers: Vec<String> = guard.iter().map(|p| p.identifier().to_string()).collect();
98
99    drop(guard);
100
101    Ok(providers)
102}
103
104/// Detects the host's cloud provider.
105///
106/// Returns [ProviderId::Unknown] if the detection failed or timed out. If the detection was successful, it returns
107/// a value from [ProviderId](enum.ProviderId.html).
108///
109/// # Arguments
110///
111/// * `timeout` - Maximum time (seconds) allowed for detection. Defaults to [DEFAULT_DETECTION_TIMEOUT](constant.DEFAULT_DETECTION_TIMEOUT.html) if `None`.
112///
113/// # Examples
114///
115/// Detect the cloud provider and print the result (with default timeout).
116///
117/// ```
118/// use cloud_detect::blocking::detect;
119///
120/// let provider = detect(None).unwrap();
121/// println!("Detected provider: {:?}", provider);
122/// ```
123///
124/// Detect the cloud provider and print the result (with custom timeout).
125///
126/// ```
127/// use cloud_detect::blocking::detect;
128///
129/// let provider = detect(Some(10)).unwrap();
130/// println!("Detected provider: {:?}", provider);
131/// ```
132pub fn detect(timeout: Option<u64>) -> Result<ProviderId> {
133    let timeout = Duration::from_secs(timeout.unwrap_or(DEFAULT_DETECTION_TIMEOUT));
134    let (tx, rx) = mpsc::sync_channel::<ProviderId>(1);
135    let guard = PROVIDERS
136        .lock()
137        .map_err(|_| anyhow::anyhow!("Error locking providers"))?;
138    let provider_entries: Vec<P> = guard.iter().cloned().collect();
139
140    for provider in provider_entries {
141        let tx = tx.clone();
142        std::thread::spawn(move || provider.identify(tx, timeout));
143    }
144
145    match rx.recv_timeout(timeout) {
146        Ok(provider_id) => Ok(provider_id),
147        Err(err) => match err {
148            RecvTimeoutError::Timeout => Ok(ProviderId::Unknown),
149            RecvTimeoutError::Disconnected => Err(anyhow::anyhow!("Error receiving message")),
150        },
151    }
152}
153
#[cfg(test)]
mod tests {
    use anyhow::Result;

    use super::*;

    /// The registry must expose exactly the eight known provider identifiers.
    #[test]
    fn test_supported_providers() -> Result<()> {
        let listed = supported_providers()?;
        let expected = [
            alibaba::IDENTIFIER.to_string(),
            aws::IDENTIFIER.to_string(),
            azure::IDENTIFIER.to_string(),
            digitalocean::IDENTIFIER.to_string(),
            gcp::IDENTIFIER.to_string(),
            oci::IDENTIFIER.to_string(),
            openstack::IDENTIFIER.to_string(),
            vultr::IDENTIFIER.to_string(),
        ];

        assert_eq!(listed.len(), 8);
        for identifier in &expected {
            assert!(listed.contains(identifier));
        }

        Ok(())
    }
}