cloud_detect/providers/
azure.rs

1//! Microsoft Azure.
2
3use std::path::Path;
4use std::time::Duration;
5
6use async_trait::async_trait;
7use serde::{Deserialize, Serialize};
8use tokio::fs;
9use tokio::sync::mpsc::Sender;
10use tracing::{debug, error, info, instrument};
11
12use crate::{Provider, ProviderId};
13
14const METADATA_URI: &str = "http://169.254.169.254";
15const METADATA_PATH: &str = "/metadata/instance?api-version=2017-12-01";
16const VENDOR_FILE: &str = "/sys/class/dmi/id/sys_vendor";
17pub(crate) const IDENTIFIER: ProviderId = ProviderId::Azure;
18
19#[derive(Serialize, Deserialize)]
20struct Compute {
21    #[serde(rename = "vmId")]
22    vm_id: String,
23}
24
25#[derive(Serialize, Deserialize)]
26struct MetadataResponse {
27    compute: Compute,
28}
29
/// Detector for the Microsoft Azure cloud provider.
///
/// Zero-sized: all configuration lives in module constants, so the value is
/// freely copyable and has a trivial `Default`.
#[derive(Debug, Clone, Copy, Default)]
pub(crate) struct Azure;
31
32#[async_trait]
33impl Provider for Azure {
34    fn identifier(&self) -> ProviderId {
35        IDENTIFIER
36    }
37
38    /// Tries to identify Azure using all the implemented options.
39    #[instrument(skip_all)]
40    async fn identify(&self, tx: Sender<ProviderId>, timeout: Duration) {
41        info!("Checking Microsoft Azure");
42        if self.check_vendor_file(VENDOR_FILE).await
43            || self.check_metadata_server(METADATA_URI, timeout).await
44        {
45            info!("Identified Microsoft Azure");
46            let res = tx.send(IDENTIFIER).await;
47
48            if let Err(err) = res {
49                error!("Error sending message: {:?}", err);
50            }
51        }
52    }
53}
54
55impl Azure {
56    /// Tries to identify Azure via metadata server.
57    #[instrument(skip_all)]
58    async fn check_metadata_server(&self, metadata_uri: &str, timeout: Duration) -> bool {
59        let url = format!("{metadata_uri}{METADATA_PATH}");
60        debug!("Checking {} metadata using url: {}", IDENTIFIER, url);
61
62        let client = if let Ok(client) = reqwest::Client::builder().timeout(timeout).build() {
63            client
64        } else {
65            error!("Error creating client");
66            return false;
67        };
68        let req = client.get(url).header("Metadata", "true");
69
70        match req.send().await {
71            Ok(resp) => match resp.json::<MetadataResponse>().await {
72                Ok(resp) => !resp.compute.vm_id.is_empty(),
73                Err(err) => {
74                    error!("Error reading response: {:?}", err);
75                    false
76                }
77            },
78            Err(err) => {
79                error!("Error making request: {:?}", err);
80                false
81            }
82        }
83    }
84
85    /// Tries to identify Azure using vendor file(s).
86    #[instrument(skip_all)]
87    async fn check_vendor_file<P: AsRef<Path>>(&self, vendor_file: P) -> bool {
88        debug!(
89            "Checking {} vendor file: {}",
90            IDENTIFIER,
91            vendor_file.as_ref().display()
92        );
93
94        if vendor_file.as_ref().is_file() {
95            return match fs::read_to_string(vendor_file).await {
96                Ok(content) => content.contains("Microsoft Corporation"),
97                Err(err) => {
98                    error!("Error reading file: {:?}", err);
99                    false
100                }
101            };
102        }
103
104        false
105    }
106}
107
#[cfg(test)]
mod tests {
    use std::io::Write;

    use anyhow::Result;
    use tempfile::NamedTempFile;
    use wiremock::matchers::{header, path, query_param};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    use super::*;

    /// Mounts a mock IMDS endpoint that only matches well-formed requests
    /// (correct path, API version and `Metadata: true` header), expected exactly once.
    async fn mount_metadata(mock_server: &MockServer, response: ResponseTemplate) {
        Mock::given(path("/metadata/instance"))
            .and(query_param("api-version", "2017-12-01"))
            .and(header("Metadata", "true"))
            .respond_with(response)
            .expect(1)
            .mount(mock_server)
            .await;
    }

    #[tokio::test]
    async fn test_check_metadata_server_success() {
        let mock_server = MockServer::start().await;
        mount_metadata(
            &mock_server,
            ResponseTemplate::new(200).set_body_json(MetadataResponse {
                compute: Compute {
                    vm_id: "vm-123abc".to_string(),
                },
            }),
        )
        .await;

        let provider = Azure;
        let metadata_uri = mock_server.uri();
        let result = provider
            .check_metadata_server(&metadata_uri, Duration::from_secs(1))
            .await;

        assert!(result);
    }

    #[tokio::test]
    async fn test_check_metadata_server_failure() {
        let mock_server = MockServer::start().await;
        mount_metadata(
            &mock_server,
            ResponseTemplate::new(200).set_body_json(MetadataResponse {
                compute: Compute {
                    vm_id: "".to_string(),
                },
            }),
        )
        .await;

        let provider = Azure;
        let metadata_uri = mock_server.uri();
        let result = provider
            .check_metadata_server(&metadata_uri, Duration::from_secs(1))
            .await;

        assert!(!result);
    }

    #[tokio::test]
    async fn test_check_metadata_server_error_status() {
        let mock_server = MockServer::start().await;
        mount_metadata(&mock_server, ResponseTemplate::new(500)).await;

        let provider = Azure;
        let metadata_uri = mock_server.uri();
        let result = provider
            .check_metadata_server(&metadata_uri, Duration::from_secs(1))
            .await;

        assert!(!result);
    }

    #[tokio::test]
    async fn test_check_vendor_file_success() -> Result<()> {
        let mut vendor_file = NamedTempFile::new()?;
        vendor_file.write_all(b"Microsoft Corporation")?;

        let provider = Azure;
        let result = provider.check_vendor_file(vendor_file.path()).await;

        assert!(result);

        Ok(())
    }

    #[tokio::test]
    async fn test_check_vendor_file_failure() -> Result<()> {
        let vendor_file = NamedTempFile::new()?;

        let provider = Azure;
        let result = provider.check_vendor_file(vendor_file.path()).await;

        assert!(!result);

        Ok(())
    }

    #[tokio::test]
    async fn test_check_vendor_file_other_vendor() -> Result<()> {
        let mut vendor_file = NamedTempFile::new()?;
        vendor_file.write_all(b"Amazon EC2")?;

        let provider = Azure;
        let result = provider.check_vendor_file(vendor_file.path()).await;

        assert!(!result);

        Ok(())
    }

    #[tokio::test]
    async fn test_check_vendor_file_missing() -> Result<()> {
        let dir = tempfile::tempdir()?;
        let missing = dir.path().join("sys_vendor");

        let provider = Azure;
        let result = provider.check_vendor_file(&missing).await;

        assert!(!result);

        Ok(())
    }
}