1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
///! Detect a host's cloud service provider.
mod consts;

#[cfg(feature = "blocking")]
pub mod blocking;

use std::collections::HashMap;
use std::error::Error;
use std::sync::mpsc::{self, TryRecvError};
use std::time::{Duration, Instant};

use lazy_static::lazy_static;

use consts::*;

lazy_static! {
    /// A mapping of supported cloud providers with their metadata URLs.
    pub(crate) static ref PROVIDER_METADATA_MAP: HashMap<&'static str, &'static str> = {
        let mut map = HashMap::new();
        map.insert(AMAZON_WEB_SERVICES, "http://169.254.169.254/latest/");
        map.insert(
            MICROSOFT_AZURE,
            "http://169.254.169.254/metadata/v1/InstanceInfo",
        );
        map.insert(
            GOOGLE_CLOUD_PLATFORM,
            "http://metadata.google.internal/computeMetadata/",
        );
        map
    };
}

/// Makes a GET request to the specified metadata URL and returns true if successful.
///
/// # Arguments
///
/// * `metadata_url` - The metadata URL for the cloud service provider.
async fn ping(metadata_url: &str) -> bool {
    match reqwest::get(metadata_url).await {
        Ok(resp) => resp.status() == reqwest::StatusCode::OK,
        Err(_) => false,
    }
}

// TODO: add test(s)
/// Returns a list of the currently supported cloud service providers.
pub fn supported_providers() -> Vec<&'static str> {
    PROVIDER_METADATA_MAP
        .keys()
        .copied()
        .collect::<Vec<&'static str>>()
}

// TODO: add test(s)
/// Detects the current host's cloud service provider.
/// Returns "unknown" if the detection failed, if the current cloud service provider is unsupported, or if minor errors occurred during detection.
///
/// # Arguments
///
/// * `timeout` - How long to attempt detection for (in seconds). Defaults to 3 seconds.
pub async fn detect(timeout: Option<u64>) -> String {
    // Set default timeout if none specified.
    let timeout_duration = Duration::from_secs(timeout.unwrap_or(DETECTION_TIMEOUT));

    // Concurrently check if the current host belongs to any of the supported providers and write the detected provider
    // to a channel.
    let (tx, rx) = mpsc::sync_channel::<String>(1);
    for (provider, metadata_url) in PROVIDER_METADATA_MAP.iter() {
        let tx = tx.clone();
        tokio::spawn(async move {
            if ping(metadata_url).await {
                tx.send(provider.to_string()).unwrap();
            }
        });
    }

    // Wait for a value from the channel or timeout.
    let start_time = Instant::now();
    let provider = loop {
        match rx.try_recv() {
            Ok(value) => break value,
            Err(TryRecvError::Empty) => {
                if start_time.elapsed() >= timeout_duration {
                    break "unknown".to_string();
                }
            }
            Err(_) => break "unknown".to_string(),
        }
    };

    provider
}

// TODO: add test(s)
/// Attempts to detect the current host's cloud service provider.
/// If we encounter an error, we return it rather than unwrapping or assuming the provider as "unknown".
///
/// **NOTE**: This also means that this function returns an error if the current host's provider is unsupported.
///
/// # Arguments
///
/// * `timeout` - How long to attempt detection for (in seconds). Defaults to 3 seconds.
pub async fn try_detect(timeout: Option<u64>) -> Result<String, Box<dyn Error>> {
    // Set default timeout if none specified.
    let timeout_duration = Duration::from_secs(timeout.unwrap_or(DETECTION_TIMEOUT));

    // Concurrently check if the current host belongs to any of the supported providers and write the detected provider
    // to a channel.
    let (tx, rx) = mpsc::sync_channel::<String>(1);
    for (provider, metadata_url) in PROVIDER_METADATA_MAP.iter() {
        let tx = tx.clone();
        tokio::spawn(async move {
            if ping(metadata_url).await {
                tx.send(provider.to_string())?;
            }

            Ok::<(), Box<dyn Error + Send + Sync>>(())
        });
    }

    // Wait for a value from the channel or timeout.
    let start_time = Instant::now();
    let provider = loop {
        match rx.try_recv() {
            Ok(value) => break Ok(value),
            Err(TryRecvError::Empty) => {
                if start_time.elapsed() >= timeout_duration {
                    break Err("Timed out when attempting to detect provider".to_string())?;
                }
            }
            Err(err) => break Err(err),
        }
    }?;

    Ok(provider)
}