Skip to main content

hoist_client/
arm.rs

//! Azure Resource Manager client for discovering subscriptions, Search services, and storage accounts
2
3use reqwest::Client;
4use serde::Deserialize;
5use tracing::debug;
6
7use crate::auth::AzCliAuth;
8use crate::error::ClientError;
9
/// Root of the Azure Resource Manager endpoint; every request URL built in
/// this module starts with it.
const ARM_BASE_URL: &str = "https://management.azure.com";
11
12/// Azure Resource Manager client for subscription/service discovery
13pub struct ArmClient {
14    http: Client,
15    token: String,
16}
17
18/// Azure subscription
19#[derive(Debug, Clone, Deserialize)]
20#[serde(rename_all = "camelCase")]
21pub struct Subscription {
22    pub subscription_id: String,
23    pub display_name: String,
24    pub state: String,
25}
26
27impl std::fmt::Display for Subscription {
28    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
29        write!(f, "{} ({})", self.display_name, self.subscription_id)
30    }
31}
32
33/// Azure AI Search service
34#[derive(Debug, Clone, Deserialize)]
35pub struct SearchService {
36    pub name: String,
37    pub location: String,
38    pub sku: SearchServiceSku,
39    #[serde(default)]
40    pub id: String,
41}
42
43#[derive(Debug, Clone, Deserialize)]
44pub struct SearchServiceSku {
45    pub name: String,
46}
47
48impl std::fmt::Display for SearchService {
49    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
50        write!(
51            f,
52            "{} ({}, {})",
53            self.name,
54            self.location,
55            self.sku.name.to_uppercase()
56        )
57    }
58}
59
/// Result of the discovery flow.
///
/// Plain data carrier; it is not constructed anywhere in this file.
#[derive(Debug, Clone)]
pub struct DiscoveredService {
    /// Name of the discovered service.
    pub name: String,
    /// ID of the subscription associated with the service.
    pub subscription_id: String,
    /// Location of the service.
    pub location: String,
}
67
68/// Azure Storage account
69#[derive(Debug, Clone, Deserialize)]
70pub struct StorageAccount {
71    pub name: String,
72    pub location: String,
73    #[serde(default)]
74    pub id: String,
75}
76
77impl std::fmt::Display for StorageAccount {
78    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
79        write!(f, "{} ({})", self.name, self.location)
80    }
81}
82
83/// Storage account key
84#[derive(Debug, Clone, Deserialize)]
85struct StorageKey {
86    value: String,
87}
88
89/// Storage account key list response
90#[derive(Debug, Deserialize)]
91struct StorageKeyList {
92    keys: Vec<StorageKey>,
93}
94
95/// ARM list response envelope
96#[derive(Debug, Deserialize)]
97struct ArmListResponse<T> {
98    value: Vec<T>,
99}
100
101impl ArmClient {
102    /// Create a new ARM client using Azure CLI credentials
103    pub fn new() -> Result<Self, ClientError> {
104        let token = AzCliAuth::get_arm_token()?;
105        let http = Client::builder()
106            .timeout(std::time::Duration::from_secs(30))
107            .build()?;
108
109        Ok(Self { http, token })
110    }
111
112    /// List subscriptions the user has access to
113    pub async fn list_subscriptions(&self) -> Result<Vec<Subscription>, ClientError> {
114        let url = format!("{}/subscriptions?api-version=2022-12-01", ARM_BASE_URL);
115        debug!("Listing subscriptions: {}", url);
116
117        let response = self
118            .http
119            .get(&url)
120            .header("Authorization", format!("Bearer {}", self.token))
121            .send()
122            .await?;
123
124        let status = response.status();
125        if !status.is_success() {
126            let body = response.text().await?;
127            return Err(ClientError::from_response(status.as_u16(), &body));
128        }
129
130        let result: ArmListResponse<Subscription> = response.json().await?;
131        // Only return enabled subscriptions
132        Ok(result
133            .value
134            .into_iter()
135            .filter(|s| s.state == "Enabled")
136            .collect())
137    }
138
139    /// List Azure AI Search services in a subscription
140    pub async fn list_search_services(
141        &self,
142        subscription_id: &str,
143    ) -> Result<Vec<SearchService>, ClientError> {
144        let url = format!(
145            "{}/subscriptions/{}/providers/Microsoft.Search/searchServices?api-version=2023-11-01",
146            ARM_BASE_URL, subscription_id
147        );
148        debug!("Listing search services: {}", url);
149
150        let response = self
151            .http
152            .get(&url)
153            .header("Authorization", format!("Bearer {}", self.token))
154            .send()
155            .await?;
156
157        let status = response.status();
158        if !status.is_success() {
159            let body = response.text().await?;
160            return Err(ClientError::from_response(status.as_u16(), &body));
161        }
162
163        let result: ArmListResponse<SearchService> = response.json().await?;
164        Ok(result.value)
165    }
166
167    /// Find the resource group of a search service by scanning the subscription.
168    ///
169    /// Returns the resource group name extracted from the service's ARM resource ID.
170    pub async fn find_resource_group(
171        &self,
172        subscription_id: &str,
173        service_name: &str,
174    ) -> Result<String, ClientError> {
175        let services = self.list_search_services(subscription_id).await?;
176
177        for svc in &services {
178            if svc.name.eq_ignore_ascii_case(service_name) {
179                // Parse resource group from ARM ID:
180                // /subscriptions/{sub}/resourceGroups/{rg}/providers/...
181                return parse_resource_group(&svc.id).ok_or_else(|| ClientError::Api {
182                    status: 0,
183                    message: format!("Could not parse resource group from ARM ID: {}", svc.id),
184                });
185            }
186        }
187
188        Err(ClientError::NotFound {
189            kind: "Search service".to_string(),
190            name: service_name.to_string(),
191        })
192    }
193
194    /// List storage accounts in a resource group.
195    pub async fn list_storage_accounts(
196        &self,
197        subscription_id: &str,
198        resource_group: &str,
199    ) -> Result<Vec<StorageAccount>, ClientError> {
200        let url = format!(
201            "{}/subscriptions/{}/resourceGroups/{}/providers/Microsoft.Storage/storageAccounts?api-version=2023-05-01",
202            ARM_BASE_URL, subscription_id, resource_group
203        );
204        debug!("Listing storage accounts: {}", url);
205
206        let response = self
207            .http
208            .get(&url)
209            .header("Authorization", format!("Bearer {}", self.token))
210            .send()
211            .await?;
212
213        let status = response.status();
214        if !status.is_success() {
215            let body = response.text().await?;
216            return Err(ClientError::from_response(status.as_u16(), &body));
217        }
218
219        let result: ArmListResponse<StorageAccount> = response.json().await?;
220        Ok(result.value)
221    }
222
223    /// Get the primary access key for a storage account.
224    pub async fn get_storage_account_key(
225        &self,
226        subscription_id: &str,
227        resource_group: &str,
228        account_name: &str,
229    ) -> Result<String, ClientError> {
230        let url = format!(
231            "{}/subscriptions/{}/resourceGroups/{}/providers/Microsoft.Storage/storageAccounts/{}/listKeys?api-version=2023-05-01",
232            ARM_BASE_URL, subscription_id, resource_group, account_name
233        );
234        debug!("Getting storage account keys: {}", url);
235
236        let response = self
237            .http
238            .post(&url)
239            .header("Authorization", format!("Bearer {}", self.token))
240            .header("Content-Length", "0")
241            .send()
242            .await?;
243
244        let status = response.status();
245        if !status.is_success() {
246            let body = response.text().await?;
247            return Err(ClientError::from_response(status.as_u16(), &body));
248        }
249
250        let key_list: StorageKeyList = response.json().await?;
251        key_list
252            .keys
253            .into_iter()
254            .next()
255            .map(|k| k.value)
256            .ok_or_else(|| ClientError::Api {
257                status: 0,
258                message: "No keys found for storage account".to_string(),
259            })
260    }
261
262    /// Build a full connection string for a storage account.
263    pub async fn get_storage_connection_string(
264        &self,
265        subscription_id: &str,
266        resource_group: &str,
267        account_name: &str,
268    ) -> Result<String, ClientError> {
269        let key = self
270            .get_storage_account_key(subscription_id, resource_group, account_name)
271            .await?;
272
273        Ok(format!(
274            "DefaultEndpointsProtocol=https;AccountName={};AccountKey={};EndpointSuffix=core.windows.net",
275            account_name, key
276        ))
277    }
278}
279
/// Parse resource group from an ARM resource ID.
///
/// ARM IDs look like: `/subscriptions/{sub}/resourceGroups/{rg}/providers/...`
///
/// The `resourceGroups` marker is matched ASCII-case-insensitively, and the
/// segment right after its first occurrence is returned with its original
/// casing. Returns `None` when the marker is absent, is the last segment,
/// or is followed by an empty segment.
fn parse_resource_group(arm_id: &str) -> Option<String> {
    let mut segments = arm_id.split('/');

    // Consume segments up to and including the marker; bail out if it is absent.
    segments.find(|segment| segment.eq_ignore_ascii_case("resourceGroups"))?;

    // The iterator now sits on the segment after the marker. An empty one
    // (e.g. a trailing slash) is not a usable resource group name.
    segments
        .next()
        .filter(|name| !name.is_empty())
        .map(String::from)
}
294
#[cfg(test)]
mod tests {
    use super::parse_resource_group;

    /// A well-formed ARM ID yields the segment after `resourceGroups`.
    #[test]
    fn test_parse_resource_group() {
        let arm_id = "/subscriptions/abc-123/resourceGroups/my-rg/providers/Microsoft.Search/searchServices/my-svc";
        let parsed = parse_resource_group(arm_id);
        assert_eq!(parsed.as_deref(), Some("my-rg"));
    }

    /// The marker is matched regardless of case; the name keeps its casing.
    #[test]
    fn test_parse_resource_group_case_insensitive() {
        let arm_id = "/subscriptions/abc/resourcegroups/MyRG/providers/Something";
        let parsed = parse_resource_group(arm_id);
        assert_eq!(parsed.as_deref(), Some("MyRG"));
    }

    /// An ID without a resource group segment yields `None`.
    #[test]
    fn test_parse_resource_group_missing() {
        let arm_id = "/subscriptions/abc/providers/Something";
        assert!(parse_resource_group(arm_id).is_none());
    }
}