1use reqwest::Client;
4use serde::Deserialize;
5use tracing::debug;
6
7use crate::auth::AzCliAuth;
8use crate::error::ClientError;
9
/// Root of the Azure Resource Manager REST endpoint.
///
/// Every request URL built in this file is this value followed by a
/// `/subscriptions/...` path and an explicit `api-version` query parameter.
const ARM_BASE_URL: &str = "https://management.azure.com";
11
/// Minimal Azure Resource Manager REST client used for discovery
/// (subscriptions, search services, storage accounts and their keys).
///
/// The bearer token is obtained once, from the Azure CLI, in [`ArmClient::new`].
/// Nothing in this file refreshes it, so a long-lived instance may start
/// failing with auth errors after the token expires.
/// NOTE(review): keep this type free of `derive(Debug)` so the token cannot be
/// printed accidentally.
pub struct ArmClient {
    /// Shared HTTP client; `new` configures a 30-second request timeout.
    http: Client,
    /// ARM access token sent as `Authorization: Bearer <token>`.
    token: String,
}
17
/// An Azure subscription entry from the ARM `GET /subscriptions` list call.
///
/// Field names map from ARM's camelCase JSON keys (`subscriptionId`,
/// `displayName`, `state`).
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Subscription {
    /// Subscription identifier, interpolated into ARM resource paths.
    pub subscription_id: String,
    /// Human-readable subscription name.
    pub display_name: String,
    /// Lifecycle state string; `list_subscriptions` keeps only `"Enabled"`.
    pub state: String,
}
26
27impl std::fmt::Display for Subscription {
28 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
29 write!(f, "{} ({})", self.display_name, self.subscription_id)
30 }
31}
32
/// A `Microsoft.Search/searchServices` resource, as returned by the ARM
/// subscription-wide list call.
#[derive(Debug, Clone, Deserialize)]
pub struct SearchService {
    /// Service name; `find_resource_group` matches it ASCII case-insensitively.
    pub name: String,
    /// Region string as reported by ARM.
    pub location: String,
    /// Pricing tier object.
    pub sku: SearchServiceSku,
    /// Full ARM resource ID. Falls back to an empty string when the field is
    /// absent; in that case the resource group cannot be parsed from it.
    #[serde(default)]
    pub id: String,
}
42
/// The `sku` object nested inside a search service payload.
#[derive(Debug, Clone, Deserialize)]
pub struct SearchServiceSku {
    /// SKU name exactly as ARM returns it; the `Display` impl for
    /// `SearchService` upper-cases it.
    pub name: String,
}
47
48impl std::fmt::Display for SearchService {
49 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
50 write!(
51 f,
52 "{} ({}, {})",
53 self.name,
54 self.location,
55 self.sku.name.to_uppercase()
56 )
57 }
58}
59
/// A search service paired with the subscription it was found in.
///
/// This type is not deserialized from ARM, and nothing in this file constructs
/// it. NOTE(review): presumably built by discovery code elsewhere — confirm
/// against callers.
#[derive(Debug, Clone)]
pub struct DiscoveredService {
    /// Search service name.
    pub name: String,
    /// Subscription that contains the service.
    pub subscription_id: String,
    /// Region string of the service.
    pub location: String,
}
67
/// A `Microsoft.Storage/storageAccounts` entry from the ARM list call, which
/// `list_storage_accounts` scopes to a single resource group.
#[derive(Debug, Clone, Deserialize)]
pub struct StorageAccount {
    /// Storage account name; also used as `AccountName` in connection strings.
    pub name: String,
    /// Region string as reported by ARM.
    pub location: String,
    /// Full ARM resource ID; an empty string if the field is absent.
    #[serde(default)]
    pub id: String,
}
76
77impl std::fmt::Display for StorageAccount {
78 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
79 write!(f, "{} ({})", self.name, self.location)
80 }
81}
82
/// One element of the `keys` array returned by the storage `listKeys` action.
/// Only the key material (`value`) is deserialized; other fields are ignored.
/// NOTE(review): `Debug` will print the secret key — make sure nothing logs it.
#[derive(Debug, Clone, Deserialize)]
struct StorageKey {
    /// The account key secret.
    value: String,
}
88
/// Response body of the storage `listKeys` action.
#[derive(Debug, Deserialize)]
struct StorageKeyList {
    /// Keys in the order ARM returns them; `get_storage_account_key` uses the first.
    keys: Vec<StorageKey>,
}
94
95#[derive(Debug, Deserialize)]
97struct ArmListResponse<T> {
98 value: Vec<T>,
99}
100
101impl ArmClient {
102 pub fn new() -> Result<Self, ClientError> {
104 let token = AzCliAuth::get_arm_token()?;
105 let http = Client::builder()
106 .timeout(std::time::Duration::from_secs(30))
107 .build()?;
108
109 Ok(Self { http, token })
110 }
111
112 pub async fn list_subscriptions(&self) -> Result<Vec<Subscription>, ClientError> {
114 let url = format!("{}/subscriptions?api-version=2022-12-01", ARM_BASE_URL);
115 debug!("Listing subscriptions: {}", url);
116
117 let response = self
118 .http
119 .get(&url)
120 .header("Authorization", format!("Bearer {}", self.token))
121 .send()
122 .await?;
123
124 let status = response.status();
125 if !status.is_success() {
126 let body = response.text().await?;
127 return Err(ClientError::from_response(status.as_u16(), &body));
128 }
129
130 let result: ArmListResponse<Subscription> = response.json().await?;
131 Ok(result
133 .value
134 .into_iter()
135 .filter(|s| s.state == "Enabled")
136 .collect())
137 }
138
139 pub async fn list_search_services(
141 &self,
142 subscription_id: &str,
143 ) -> Result<Vec<SearchService>, ClientError> {
144 let url = format!(
145 "{}/subscriptions/{}/providers/Microsoft.Search/searchServices?api-version=2023-11-01",
146 ARM_BASE_URL, subscription_id
147 );
148 debug!("Listing search services: {}", url);
149
150 let response = self
151 .http
152 .get(&url)
153 .header("Authorization", format!("Bearer {}", self.token))
154 .send()
155 .await?;
156
157 let status = response.status();
158 if !status.is_success() {
159 let body = response.text().await?;
160 return Err(ClientError::from_response(status.as_u16(), &body));
161 }
162
163 let result: ArmListResponse<SearchService> = response.json().await?;
164 Ok(result.value)
165 }
166
167 pub async fn find_resource_group(
171 &self,
172 subscription_id: &str,
173 service_name: &str,
174 ) -> Result<String, ClientError> {
175 let services = self.list_search_services(subscription_id).await?;
176
177 for svc in &services {
178 if svc.name.eq_ignore_ascii_case(service_name) {
179 return parse_resource_group(&svc.id).ok_or_else(|| ClientError::Api {
182 status: 0,
183 message: format!("Could not parse resource group from ARM ID: {}", svc.id),
184 });
185 }
186 }
187
188 Err(ClientError::NotFound {
189 kind: "Search service".to_string(),
190 name: service_name.to_string(),
191 })
192 }
193
194 pub async fn list_storage_accounts(
196 &self,
197 subscription_id: &str,
198 resource_group: &str,
199 ) -> Result<Vec<StorageAccount>, ClientError> {
200 let url = format!(
201 "{}/subscriptions/{}/resourceGroups/{}/providers/Microsoft.Storage/storageAccounts?api-version=2023-05-01",
202 ARM_BASE_URL, subscription_id, resource_group
203 );
204 debug!("Listing storage accounts: {}", url);
205
206 let response = self
207 .http
208 .get(&url)
209 .header("Authorization", format!("Bearer {}", self.token))
210 .send()
211 .await?;
212
213 let status = response.status();
214 if !status.is_success() {
215 let body = response.text().await?;
216 return Err(ClientError::from_response(status.as_u16(), &body));
217 }
218
219 let result: ArmListResponse<StorageAccount> = response.json().await?;
220 Ok(result.value)
221 }
222
223 pub async fn get_storage_account_key(
225 &self,
226 subscription_id: &str,
227 resource_group: &str,
228 account_name: &str,
229 ) -> Result<String, ClientError> {
230 let url = format!(
231 "{}/subscriptions/{}/resourceGroups/{}/providers/Microsoft.Storage/storageAccounts/{}/listKeys?api-version=2023-05-01",
232 ARM_BASE_URL, subscription_id, resource_group, account_name
233 );
234 debug!("Getting storage account keys: {}", url);
235
236 let response = self
237 .http
238 .post(&url)
239 .header("Authorization", format!("Bearer {}", self.token))
240 .header("Content-Length", "0")
241 .send()
242 .await?;
243
244 let status = response.status();
245 if !status.is_success() {
246 let body = response.text().await?;
247 return Err(ClientError::from_response(status.as_u16(), &body));
248 }
249
250 let key_list: StorageKeyList = response.json().await?;
251 key_list
252 .keys
253 .into_iter()
254 .next()
255 .map(|k| k.value)
256 .ok_or_else(|| ClientError::Api {
257 status: 0,
258 message: "No keys found for storage account".to_string(),
259 })
260 }
261
262 pub async fn get_storage_connection_string(
264 &self,
265 subscription_id: &str,
266 resource_group: &str,
267 account_name: &str,
268 ) -> Result<String, ClientError> {
269 let key = self
270 .get_storage_account_key(subscription_id, resource_group, account_name)
271 .await?;
272
273 Ok(format!(
274 "DefaultEndpointsProtocol=https;AccountName={};AccountKey={};EndpointSuffix=core.windows.net",
275 account_name, key
276 ))
277 }
278}
279
/// Extracts the resource-group name from an ARM resource ID such as
/// `/subscriptions/<sub>/resourceGroups/<rg>/providers/...`.
///
/// The `resourceGroups` marker is matched ASCII case-insensitively, because
/// ARM IDs are not consistent about its casing. The segment right after the
/// first marker is returned.
///
/// Returns `None` in any of these cases:
/// - there is no marker;
/// - the marker is the last segment;
/// - the following segment is empty (e.g. a trailing `/resourceGroups/`).
fn parse_resource_group(arm_id: &str) -> Option<String> {
    let mut segments = arm_id.split('/');
    // Advance the iterator past the marker; `?` bails out if it never appears.
    segments.find(|segment| segment.eq_ignore_ascii_case("resourceGroups"))?;
    segments
        .next()
        .filter(|name| !name.is_empty())
        .map(str::to_owned)
}
294
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_resource_group() {
        let id = "/subscriptions/abc-123/resourceGroups/my-rg/providers/Microsoft.Search/searchServices/my-svc";
        assert_eq!(parse_resource_group(id), Some("my-rg".to_string()));
    }

    #[test]
    fn test_parse_resource_group_case_insensitive() {
        let id = "/subscriptions/abc/resourcegroups/MyRG/providers/Something";
        assert_eq!(parse_resource_group(id), Some("MyRG".to_string()));
    }

    #[test]
    fn test_parse_resource_group_all_caps_marker() {
        let id = "/subscriptions/abc/RESOURCEGROUPS/rg1/providers/Something";
        assert_eq!(parse_resource_group(id), Some("rg1".to_string()));
    }

    #[test]
    fn test_parse_resource_group_missing() {
        let id = "/subscriptions/abc/providers/Something";
        assert_eq!(parse_resource_group(id), None);
    }

    #[test]
    fn test_parse_resource_group_marker_is_last_segment() {
        // A marker with nothing after it has no name to return.
        let id = "/subscriptions/abc/resourceGroups";
        assert_eq!(parse_resource_group(id), None);
    }

    #[test]
    fn test_display_formats() {
        let sub = Subscription {
            subscription_id: "sub-1".to_string(),
            display_name: "Dev".to_string(),
            state: "Enabled".to_string(),
        };
        assert_eq!(sub.to_string(), "Dev (sub-1)");

        let svc = SearchService {
            name: "search-a".to_string(),
            location: "eastus".to_string(),
            sku: SearchServiceSku {
                name: "basic".to_string(),
            },
            id: String::new(),
        };
        assert_eq!(svc.to_string(), "search-a (eastus, BASIC)");

        let acct = StorageAccount {
            name: "store1".to_string(),
            location: "westus".to_string(),
            id: String::new(),
        };
        assert_eq!(acct.to_string(), "store1 (westus)");
    }
}