cloud_detect/blocking/mod.rs
1//! Blocking API for cloud provider detection.
2//!
3//! This module provides a blocking API for detecting the host's cloud provider. It is built on top of the asynchronous API
4//! and executes the blocking provider identification within threads.
5//!
6//! This module is intended for use in synchronous applications or in situations where the asynchronous API is not suitable.
7//! While not guaranteed, the performance of this module should be comparable to the asynchronous API.
8//!
9//! ## Optional
10//!
11//! This requires the `blocking` feature to be enabled.
12//!
13//! ## Usage
14//!
15//! Add the following to your `Cargo.toml`:
16//!
17//! ```toml
18//! # ...
19//! cloud_detect = { version = "2", features = ["blocking"] }
20//! tracing-subscriber = { version = "0.3", features = ["env-filter"] } # Optional; for logging
21//! ```
22//!
23//! ## Examples
24//!
25//! Detect the cloud provider and print the result (with default timeout).
26//!
27//! ```rust
28//! use cloud_detect::blocking::detect;
29//!
30//! tracing_subscriber::fmt::init(); // Optional; for logging
31//!
32//! let provider = detect(None).unwrap();
33//! println!("Detected provider: {:?}", provider);
34//! ```
35//!
36//! Detect the cloud provider and print the result (with custom timeout).
37//!
38//! ```rust
39//! use cloud_detect::blocking::detect;
40//!
41//! tracing_subscriber::fmt::init(); // Optional; for logging
42//!
43//! let provider = detect(Some(10)).unwrap();
44//! println!("Detected provider: {:?}", provider);
45//! ```
46
47pub(crate) mod providers;
48
49use std::sync::mpsc::RecvTimeoutError;
50use std::sync::mpsc::SyncSender;
51use std::sync::{mpsc, Arc, LazyLock, Mutex};
52use std::time::Duration;
53
54use anyhow::Result;
55
56use crate::blocking::providers::*;
57use crate::{ProviderId, DEFAULT_DETECTION_TIMEOUT};
58
59/// Represents a cloud service provider.
60#[allow(dead_code)]
61pub(crate) trait Provider: Send + Sync {
62 fn identifier(&self) -> ProviderId;
63 fn identify(&self, tx: SyncSender<ProviderId>, timeout: Duration);
64}
65
66type P = Arc<dyn Provider>;
67
68static PROVIDERS: LazyLock<Mutex<Vec<P>>> = LazyLock::new(|| {
69 Mutex::new(vec![
70 Arc::new(alibaba::Alibaba) as P,
71 Arc::new(aws::Aws) as P,
72 Arc::new(azure::Azure) as P,
73 Arc::new(digitalocean::DigitalOcean) as P,
74 Arc::new(gcp::Gcp) as P,
75 Arc::new(oci::Oci) as P,
76 Arc::new(openstack::OpenStack) as P,
77 Arc::new(vultr::Vultr) as P,
78 ])
79});
80
81/// Returns a list of currently supported providers.
82///
83/// # Examples
84///
85/// Print the list of supported providers.
86///
87/// ```
88/// use cloud_detect::blocking::supported_providers;
89///
90/// let providers = supported_providers().unwrap();
91/// println!("Supported providers: {:?}", providers);
92/// ```
93pub fn supported_providers() -> Result<Vec<String>> {
94 let guard = PROVIDERS
95 .lock()
96 .map_err(|_| anyhow::anyhow!("Error locking providers"))?;
97 let providers: Vec<String> = guard.iter().map(|p| p.identifier().to_string()).collect();
98
99 drop(guard);
100
101 Ok(providers)
102}
103
104/// Detects the host's cloud provider.
105///
106/// Returns [ProviderId::Unknown] if the detection failed or timed out. If the detection was successful, it returns
107/// a value from [ProviderId](enum.ProviderId.html).
108///
109/// # Arguments
110///
111/// * `timeout` - Maximum time (seconds) allowed for detection. Defaults to [DEFAULT_DETECTION_TIMEOUT](constant.DEFAULT_DETECTION_TIMEOUT.html) if `None`.
112///
113/// # Examples
114///
115/// Detect the cloud provider and print the result (with default timeout).
116///
117/// ```
118/// use cloud_detect::blocking::detect;
119///
120/// let provider = detect(None).unwrap();
121/// println!("Detected provider: {:?}", provider);
122/// ```
123///
124/// Detect the cloud provider and print the result (with custom timeout).
125///
126/// ```
127/// use cloud_detect::blocking::detect;
128///
129/// let provider = detect(Some(10)).unwrap();
130/// println!("Detected provider: {:?}", provider);
131/// ```
132pub fn detect(timeout: Option<u64>) -> Result<ProviderId> {
133 let timeout = Duration::from_secs(timeout.unwrap_or(DEFAULT_DETECTION_TIMEOUT));
134 let (tx, rx) = mpsc::sync_channel::<ProviderId>(1);
135 let guard = PROVIDERS
136 .lock()
137 .map_err(|_| anyhow::anyhow!("Error locking providers"))?;
138 let provider_entries: Vec<P> = guard.iter().cloned().collect();
139
140 for provider in provider_entries {
141 let tx = tx.clone();
142 std::thread::spawn(move || provider.identify(tx, timeout));
143 }
144
145 match rx.recv_timeout(timeout) {
146 Ok(provider_id) => Ok(provider_id),
147 Err(err) => match err {
148 RecvTimeoutError::Timeout => Ok(ProviderId::Unknown),
149 RecvTimeoutError::Disconnected => Err(anyhow::anyhow!("Error receiving message")),
150 },
151 }
152}
153
#[cfg(test)]
mod tests {
    use anyhow::Result;

    use super::*;

    #[test]
    fn test_supported_providers() -> Result<()> {
        // Every provider registered in `PROVIDERS`, in registration order.
        let expected = [
            alibaba::IDENTIFIER,
            aws::IDENTIFIER,
            azure::IDENTIFIER,
            digitalocean::IDENTIFIER,
            gcp::IDENTIFIER,
            oci::IDENTIFIER,
            openstack::IDENTIFIER,
            vultr::IDENTIFIER,
        ];

        let providers = supported_providers()?;
        assert_eq!(providers.len(), expected.len());
        for identifier in expected {
            assert!(
                providers.contains(&identifier.to_string()),
                "missing provider: {identifier}"
            );
        }

        Ok(())
    }
}