cloud_detect/providers/
azure.rs1use std::path::Path;
4use std::time::Duration;
5
6use async_trait::async_trait;
7use serde::{Deserialize, Serialize};
8use tokio::fs;
9use tokio::sync::mpsc::Sender;
10use tracing::{debug, error, info, instrument};
11
12use crate::{Provider, ProviderId};
13
14const METADATA_URI: &str = "http://169.254.169.254";
15const METADATA_PATH: &str = "/metadata/instance?api-version=2017-12-01";
16const VENDOR_FILE: &str = "/sys/class/dmi/id/sys_vendor";
17pub(crate) const IDENTIFIER: ProviderId = ProviderId::Azure;
18
19#[derive(Serialize, Deserialize)]
20struct Compute {
21 #[serde(rename = "vmId")]
22 vm_id: String,
23}
24
25#[derive(Serialize, Deserialize)]
26struct MetadataResponse {
27 compute: Compute,
28}
29
30pub(crate) struct Azure;
31
32#[async_trait]
33impl Provider for Azure {
34 fn identifier(&self) -> ProviderId {
35 IDENTIFIER
36 }
37
38 #[instrument(skip_all)]
40 async fn identify(&self, tx: Sender<ProviderId>, timeout: Duration) {
41 info!("Checking Microsoft Azure");
42 if self.check_vendor_file(VENDOR_FILE).await
43 || self.check_metadata_server(METADATA_URI, timeout).await
44 {
45 info!("Identified Microsoft Azure");
46 let res = tx.send(IDENTIFIER).await;
47
48 if let Err(err) = res {
49 error!("Error sending message: {:?}", err);
50 }
51 }
52 }
53}
54
55impl Azure {
56 #[instrument(skip_all)]
58 async fn check_metadata_server(&self, metadata_uri: &str, timeout: Duration) -> bool {
59 let url = format!("{metadata_uri}{METADATA_PATH}");
60 debug!("Checking {} metadata using url: {}", IDENTIFIER, url);
61
62 let client = if let Ok(client) = reqwest::Client::builder().timeout(timeout).build() {
63 client
64 } else {
65 error!("Error creating client");
66 return false;
67 };
68 let req = client.get(url).header("Metadata", "true");
69
70 match req.send().await {
71 Ok(resp) => match resp.json::<MetadataResponse>().await {
72 Ok(resp) => !resp.compute.vm_id.is_empty(),
73 Err(err) => {
74 error!("Error reading response: {:?}", err);
75 false
76 }
77 },
78 Err(err) => {
79 error!("Error making request: {:?}", err);
80 false
81 }
82 }
83 }
84
85 #[instrument(skip_all)]
87 async fn check_vendor_file<P: AsRef<Path>>(&self, vendor_file: P) -> bool {
88 debug!(
89 "Checking {} vendor file: {}",
90 IDENTIFIER,
91 vendor_file.as_ref().display()
92 );
93
94 if vendor_file.as_ref().is_file() {
95 return match fs::read_to_string(vendor_file).await {
96 Ok(content) => content.contains("Microsoft Corporation"),
97 Err(err) => {
98 error!("Error reading file: {:?}", err);
99 false
100 }
101 };
102 }
103
104 false
105 }
106}
107
#[cfg(test)]
mod tests {
    use std::io::Write;

    use anyhow::Result;
    use tempfile::NamedTempFile;
    use wiremock::matchers::{header, method, path, query_param};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    use super::*;

    /// Mounts a mock IMDS endpoint that answers with `response`, expecting one call.
    ///
    /// It only matches a well-formed Azure request: a GET to the instance path
    /// with the pinned api-version and the `Metadata: true` header. A request
    /// that drops any of these gets wiremock's default 404, and the `expect(1)`
    /// check fails when the server is dropped.
    async fn mount_imds(mock_server: &MockServer, response: ResponseTemplate) {
        Mock::given(method("GET"))
            .and(path("/metadata/instance"))
            .and(query_param("api-version", "2017-12-01"))
            .and(header("Metadata", "true"))
            .respond_with(response)
            .expect(1)
            .mount(mock_server)
            .await;
    }

    /// Builds a 200 response carrying an IMDS body with the given VM id.
    fn imds_body(vm_id: &str) -> ResponseTemplate {
        ResponseTemplate::new(200).set_body_json(MetadataResponse {
            compute: Compute {
                vm_id: vm_id.to_string(),
            },
        })
    }

    #[tokio::test]
    async fn test_check_metadata_server_success() {
        let mock_server = MockServer::start().await;
        mount_imds(&mock_server, imds_body("vm-123abc")).await;

        let provider = Azure;
        let metadata_uri = mock_server.uri();
        let result = provider
            .check_metadata_server(&metadata_uri, Duration::from_secs(1))
            .await;

        assert!(result);
    }

    #[tokio::test]
    async fn test_check_metadata_server_failure() {
        let mock_server = MockServer::start().await;
        mount_imds(&mock_server, imds_body("")).await;

        let provider = Azure;
        let metadata_uri = mock_server.uri();
        let result = provider
            .check_metadata_server(&metadata_uri, Duration::from_secs(1))
            .await;

        assert!(!result);
    }

    #[tokio::test]
    async fn test_check_metadata_server_error_status() {
        let mock_server = MockServer::start().await;
        mount_imds(&mock_server, ResponseTemplate::new(404)).await;

        let provider = Azure;
        let metadata_uri = mock_server.uri();
        let result = provider
            .check_metadata_server(&metadata_uri, Duration::from_secs(1))
            .await;

        assert!(!result);
    }

    #[tokio::test]
    async fn test_check_vendor_file_success() -> Result<()> {
        let mut vendor_file = NamedTempFile::new()?;
        vendor_file.write_all(b"Microsoft Corporation")?;

        let provider = Azure;
        let result = provider.check_vendor_file(vendor_file.path()).await;

        assert!(result);

        Ok(())
    }

    #[tokio::test]
    async fn test_check_vendor_file_failure() -> Result<()> {
        let vendor_file = NamedTempFile::new()?;

        let provider = Azure;
        let result = provider.check_vendor_file(vendor_file.path()).await;

        assert!(!result);

        Ok(())
    }

    #[tokio::test]
    async fn test_check_vendor_file_other_vendor() -> Result<()> {
        let mut vendor_file = NamedTempFile::new()?;
        vendor_file.write_all(b"Amazon EC2")?;

        let provider = Azure;
        let result = provider.check_vendor_file(vendor_file.path()).await;

        assert!(!result);

        Ok(())
    }

    #[tokio::test]
    async fn test_check_vendor_file_missing_or_not_a_file() -> Result<()> {
        let dir = tempfile::tempdir()?;
        let provider = Azure;

        // A path that does not exist.
        let missing = dir.path().join("sys_vendor");
        assert!(!provider.check_vendor_file(&missing).await);

        // A directory is not a regular file.
        assert!(!provider.check_vendor_file(dir.path()).await);

        Ok(())
    }
}