1use serde::{Deserialize, Serialize};
4use tracing::debug;
5
6use crate::auth::{ARM_RESOURCE, AzCliAuth};
7use crate::error::ClientError;
8
/// `api-version` query parameter sent with the ARM subscriptions list request.
const ARM_SUBSCRIPTIONS_API_VERSION: &str = "2024-11-01";
/// `api-version` query parameter sent with every `Microsoft.DocumentDB` request
/// (account listing and SQL role assignment reads/writes).
const COSMOS_DB_API_VERSION: &str = "2025-04-15";
/// Base URL of the Azure Resource Manager endpoint; every request URL is built on it.
/// NOTE(review): this is the public-cloud host — sovereign clouds would need another value.
const ARM_BASE_URL: &str = "https://management.azure.com";
12
13#[derive(Debug, Clone, Deserialize)]
15#[serde(rename_all = "camelCase")]
16pub struct Subscription {
17 pub subscription_id: String,
18 pub display_name: String,
19 pub state: String,
20}
21
/// Envelope of the ARM subscriptions list response; only the `value` array is read.
#[derive(Debug, Deserialize)]
struct SubscriptionListResponse {
    /// Subscriptions contained in this response.
    value: Vec<Subscription>,
}
26
/// A Cosmos DB account discovered in a subscription, flattened from its ARM resource.
///
/// Derives `PartialEq`/`Eq` so callers can compare, deduplicate, or assert on accounts.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CosmosAccount {
    /// Account name.
    pub name: String,
    /// Azure region reported by ARM for the account.
    pub location: String,
    /// Account kind as reported by ARM, if present.
    pub kind: Option<String>,
    /// Document (data-plane) endpoint URL; empty when ARM did not report one.
    pub endpoint: String,
    /// Resource group parsed from `id`; empty when `id` has no `resourceGroups` segment.
    pub resource_group: String,
    /// Full ARM resource ID of the account.
    pub id: String,
}
37
/// Envelope of the ARM `databaseAccounts` list response for a subscription.
#[derive(Debug, Deserialize)]
struct CosmosAccountListResponse {
    /// Raw account resources; converted into `CosmosAccount` by the client.
    value: Vec<CosmosAccountResource>,
}
42
/// One Cosmos DB account resource as it appears on the wire in the ARM list response.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct CosmosAccountResource {
    /// Full ARM resource ID; the resource group is parsed out of it.
    id: String,
    /// Account name.
    name: String,
    /// Region reported by ARM.
    location: String,
    /// Account kind, absent in some responses.
    kind: Option<String>,
    /// Nested account properties; only the document endpoint is read.
    properties: CosmosAccountProperties,
}
52
53#[derive(Debug, Deserialize)]
54#[serde(rename_all = "camelCase")]
55struct CosmosAccountProperties {
56 document_endpoint: Option<String>,
57}
58
/// Azure Resource Manager (ARM) REST client used to discover subscriptions and Cosmos DB
/// accounts and to read/create Cosmos DB SQL (data-plane) role assignments.
///
/// The bearer token is fetched once by `ArmClient::new` and reused for every request;
/// nothing in this type refreshes it.
/// NOTE(review): ARM access tokens expire, so a long-lived client presumably has to be
/// recreated — confirm the expected lifetime with callers.
// `Debug` is not derived, so the token cannot be printed through `{:?}` formatting.
pub struct ArmClient {
    /// HTTP client shared by all ARM calls.
    http: reqwest::Client,
    /// ARM bearer token obtained from `AzCliAuth::get_token(ARM_RESOURCE)`.
    token: String,
}
64
65impl ArmClient {
66 pub async fn new() -> Result<Self, ClientError> {
68 let token = AzCliAuth::get_token(ARM_RESOURCE).await?;
69 Ok(Self {
70 http: reqwest::Client::new(),
71 token,
72 })
73 }
74
75 pub async fn list_subscriptions(&self) -> Result<Vec<Subscription>, ClientError> {
77 debug!("listing Azure subscriptions");
78
79 let url =
80 format!("{ARM_BASE_URL}/subscriptions?api-version={ARM_SUBSCRIPTIONS_API_VERSION}");
81 let resp = self.http.get(&url).bearer_auth(&self.token).send().await?;
82
83 let status = resp.status();
84 if !status.is_success() {
85 let body = resp.text().await.unwrap_or_default();
86 return Err(ClientError::api(status.as_u16(), body));
87 }
88
89 let list: SubscriptionListResponse = resp.json().await?;
90 let enabled: Vec<Subscription> = list
91 .value
92 .into_iter()
93 .filter(|s| s.state == "Enabled")
94 .collect();
95
96 debug!(count = enabled.len(), "found enabled subscriptions");
97 Ok(enabled)
98 }
99
100 pub async fn list_cosmos_accounts(
102 &self,
103 subscription_id: &str,
104 ) -> Result<Vec<CosmosAccount>, ClientError> {
105 debug!(subscription_id, "listing Cosmos DB accounts");
106
107 let url = format!(
108 "{ARM_BASE_URL}/subscriptions/{subscription_id}/providers/Microsoft.DocumentDB/databaseAccounts?api-version={COSMOS_DB_API_VERSION}"
109 );
110
111 let resp = self.http.get(&url).bearer_auth(&self.token).send().await?;
112
113 let status = resp.status();
114 if !status.is_success() {
115 let body = resp.text().await.unwrap_or_default();
116 if status.as_u16() == 403 {
117 return Err(ClientError::forbidden(
118 body,
119 "You may not have Reader access on this subscription. Check your Azure RBAC roles.",
120 ));
121 }
122 return Err(ClientError::api(status.as_u16(), body));
123 }
124
125 let list: CosmosAccountListResponse = resp.json().await?;
126 let accounts: Vec<CosmosAccount> = list
127 .value
128 .into_iter()
129 .map(|r| {
130 let resource_group =
133 r.id.split('/')
134 .collect::<Vec<_>>()
135 .windows(2)
136 .find(|w| w[0].eq_ignore_ascii_case("resourceGroups"))
137 .map(|w| w[1].to_string())
138 .unwrap_or_default();
139
140 CosmosAccount {
141 name: r.name,
142 location: r.location,
143 kind: r.kind,
144 endpoint: r.properties.document_endpoint.unwrap_or_default(),
145 resource_group,
146 id: r.id,
147 }
148 })
149 .collect();
150
151 debug!(count = accounts.len(), "found Cosmos DB accounts");
152 Ok(accounts)
153 }
154
155 pub async fn has_cosmos_data_role(
157 &self,
158 account_resource_id: &str,
159 principal_id: &str,
160 ) -> Result<bool, ClientError> {
161 debug!(principal_id, "checking Cosmos DB SQL role assignments");
162
163 let url = format!(
164 "{ARM_BASE_URL}{account_resource_id}/sqlRoleAssignments?api-version={COSMOS_DB_API_VERSION}"
165 );
166 let resp = self.http.get(&url).bearer_auth(&self.token).send().await?;
167
168 let status = resp.status();
169 if !status.is_success() {
170 let body = resp.text().await.unwrap_or_default();
171 return Err(ClientError::api(status.as_u16(), body));
172 }
173
174 let list: SqlRoleAssignmentListResponse = resp.json().await?;
175 let has_role = list
176 .value
177 .iter()
178 .any(|a| a.properties.principal_id == principal_id);
179
180 debug!(has_role, "data plane role check complete");
181 Ok(has_role)
182 }
183
184 pub async fn assign_cosmos_data_contributor(
186 &self,
187 account_resource_id: &str,
188 principal_id: &str,
189 ) -> Result<(), ClientError> {
190 debug!(principal_id, "assigning Cosmos DB data contributor role");
191
192 let assignment_id = uuid::Uuid::new_v4().to_string();
193 let url = format!(
194 "{ARM_BASE_URL}{account_resource_id}/sqlRoleAssignments/{assignment_id}?api-version={COSMOS_DB_API_VERSION}"
195 );
196
197 let body = SqlRoleAssignmentCreateBody {
198 properties: SqlRoleAssignmentCreateProperties {
199 role_definition_id: format!(
200 "{account_resource_id}/sqlRoleDefinitions/{COSMOS_DATA_CONTRIBUTOR_ROLE}"
201 ),
202 scope: account_resource_id.to_string(),
203 principal_id: principal_id.to_string(),
204 },
205 };
206
207 let resp = self
208 .http
209 .put(&url)
210 .bearer_auth(&self.token)
211 .json(&body)
212 .send()
213 .await?;
214
215 let status = resp.status();
216 if !status.is_success() {
217 let resp_body = resp.text().await.unwrap_or_default();
218 if status.as_u16() == 403 {
219 return Err(ClientError::forbidden(
220 resp_body,
221 "You need Owner or User Access Administrator role on the Cosmos DB account to assign data plane roles.",
222 ));
223 }
224 return Err(ClientError::api(status.as_u16(), resp_body));
225 }
226
227 debug!("data contributor role assigned successfully");
228 Ok(())
229 }
230}
231
/// Role definition ID of the Cosmos DB built-in data contributor SQL role; it is appended
/// to `{account_resource_id}/sqlRoleDefinitions/` when creating an assignment.
const COSMOS_DATA_CONTRIBUTOR_ROLE: &str = "00000000-0000-0000-0000-000000000002";
234
/// Envelope of the response listing an account's Cosmos DB SQL role assignments.
#[derive(Debug, Deserialize)]
struct SqlRoleAssignmentListResponse {
    /// Existing assignments on the account.
    value: Vec<SqlRoleAssignment>,
}
239
/// One existing SQL role assignment; only its `properties` object is deserialized.
#[derive(Debug, Deserialize)]
struct SqlRoleAssignment {
    /// Assignment properties (JSON key `properties`).
    properties: SqlRoleAssignmentProperties,
}
244
245#[derive(Debug, Deserialize)]
246#[serde(rename_all = "camelCase")]
247struct SqlRoleAssignmentProperties {
248 principal_id: String,
249}
250
/// JSON request body sent when creating a SQL role assignment via PUT.
#[derive(Debug, Serialize)]
struct SqlRoleAssignmentCreateBody {
    /// Serialized under the JSON key `properties`.
    properties: SqlRoleAssignmentCreateProperties,
}
255
256#[derive(Debug, Serialize)]
257#[serde(rename_all = "camelCase")]
258struct SqlRoleAssignmentCreateProperties {
259 role_definition_id: String,
260 scope: String,
261 principal_id: String,
262}