cloud_detect/providers/
digitalocean.rs1use std::path::Path;
4use std::time::Duration;
5
6use async_trait::async_trait;
7use serde::{Deserialize, Serialize};
8use tokio::fs;
9use tokio::sync::mpsc::Sender;
10use tracing::{debug, error, info, instrument};
11
12use crate::{Provider, ProviderId};
13
14const METADATA_URI: &str = "http://169.254.169.254";
15const METADATA_PATH: &str = "/metadata/v1.json";
16const VENDOR_FILE: &str = "/sys/class/dmi/id/sys_vendor";
17pub(crate) const IDENTIFIER: ProviderId = ProviderId::DigitalOcean;
18
/// Stateless detector for DigitalOcean droplets.
///
/// Carries no data; all detection logic lives in its trait and inherent impls.
pub(crate) struct DigitalOcean;
20
21#[derive(Serialize, Deserialize)]
22struct MetadataResponse {
23 droplet_id: usize,
24}
25
26#[async_trait]
27impl Provider for DigitalOcean {
28 fn identifier(&self) -> ProviderId {
29 IDENTIFIER
30 }
31
32 #[instrument(skip_all)]
34 async fn identify(&self, tx: Sender<ProviderId>, timeout: Duration) {
35 info!("Checking DigitalOcean");
36 if self.check_vendor_file(VENDOR_FILE).await
37 || self.check_metadata_server(METADATA_URI, timeout).await
38 {
39 info!("Identified DigitalOcean");
40 let res = tx.send(IDENTIFIER).await;
41
42 if let Err(err) = res {
43 error!("Error sending message: {:?}", err);
44 }
45 }
46 }
47}
48
49impl DigitalOcean {
50 #[instrument(skip_all)]
52 async fn check_metadata_server(&self, metadata_uri: &str, timeout: Duration) -> bool {
53 let url = format!("{metadata_uri}{METADATA_PATH}");
54 debug!("Checking {} metadata using url: {}", IDENTIFIER, url);
55
56 let client = if let Ok(client) = reqwest::Client::builder().timeout(timeout).build() {
57 client
58 } else {
59 error!("Error creating client");
60 return false;
61 };
62
63 match client.get(url).send().await {
64 Ok(resp) => match resp.json::<MetadataResponse>().await {
65 Ok(resp) => resp.droplet_id > 0,
66 Err(err) => {
67 error!("Error reading response: {:?}", err);
68 false
69 }
70 },
71 Err(err) => {
72 error!("Error making request: {:?}", err);
73 false
74 }
75 }
76 }
77
78 #[instrument(skip_all)]
80 async fn check_vendor_file<P: AsRef<Path>>(&self, vendor_file: P) -> bool {
81 debug!(
82 "Checking {} vendor file: {}",
83 IDENTIFIER,
84 vendor_file.as_ref().display()
85 );
86
87 if vendor_file.as_ref().is_file() {
88 return match fs::read_to_string(vendor_file).await {
89 Ok(content) => content.contains("DigitalOcean"),
90 Err(err) => {
91 error!("Error reading file: {:?}", err);
92 false
93 }
94 };
95 }
96
97 false
98 }
99}
100
#[cfg(test)]
mod tests {
    use std::io::Write;

    use anyhow::Result;
    use tempfile::{tempdir, NamedTempFile};
    use wiremock::matchers::path;
    use wiremock::{Mock, MockServer, ResponseTemplate};

    use super::*;

    /// Serves `response` at `METADATA_PATH` on a fresh mock server, expecting
    /// exactly one request, and returns the result of the metadata check.
    async fn probe_metadata(response: ResponseTemplate) -> bool {
        let mock_server = MockServer::start().await;
        Mock::given(path(METADATA_PATH))
            .respond_with(response)
            .expect(1)
            .mount(&mock_server)
            .await;

        let metadata_uri = mock_server.uri();
        DigitalOcean
            .check_metadata_server(&metadata_uri, Duration::from_secs(1))
            .await
    }

    #[tokio::test]
    async fn test_check_metadata_server_success() {
        let response =
            ResponseTemplate::new(200).set_body_json(MetadataResponse { droplet_id: 123 });
        assert!(probe_metadata(response).await);
    }

    #[tokio::test]
    async fn test_check_metadata_server_failure() {
        let response =
            ResponseTemplate::new(200).set_body_json(MetadataResponse { droplet_id: 0 });
        assert!(!probe_metadata(response).await);
    }

    #[tokio::test]
    async fn test_check_metadata_server_error_status() {
        assert!(!probe_metadata(ResponseTemplate::new(404)).await);
    }

    #[tokio::test]
    async fn test_check_metadata_server_malformed_body() {
        let response = ResponseTemplate::new(200).set_body_string("not json");
        assert!(!probe_metadata(response).await);
    }

    #[tokio::test]
    async fn test_check_vendor_file_success() -> Result<()> {
        let mut vendor_file = NamedTempFile::new()?;
        vendor_file.write_all(b"DigitalOcean")?;

        let provider = DigitalOcean;
        let result = provider.check_vendor_file(vendor_file.path()).await;

        assert!(result);

        Ok(())
    }

    #[tokio::test]
    async fn test_check_vendor_file_failure() -> Result<()> {
        let vendor_file = NamedTempFile::new()?;

        let provider = DigitalOcean;
        let result = provider.check_vendor_file(vendor_file.path()).await;

        assert!(!result);

        Ok(())
    }

    #[tokio::test]
    async fn test_check_vendor_file_other_vendor() -> Result<()> {
        let mut vendor_file = NamedTempFile::new()?;
        vendor_file.write_all(b"Amazon EC2")?;

        let result = DigitalOcean.check_vendor_file(vendor_file.path()).await;

        assert!(!result);

        Ok(())
    }

    #[tokio::test]
    async fn test_check_vendor_file_missing() -> Result<()> {
        let dir = tempdir()?;
        let missing = dir.path().join("sys_vendor");

        let result = DigitalOcean.check_vendor_file(&missing).await;

        assert!(!result);

        Ok(())
    }

    #[tokio::test]
    async fn test_check_vendor_file_directory() -> Result<()> {
        let dir = tempdir()?;

        let result = DigitalOcean.check_vendor_file(dir.path()).await;

        assert!(!result);

        Ok(())
    }
}