1use std::collections::HashMap;
12
13use serde::{Deserialize, Serialize};
14
15use crate::{
16 AzureHttpClient, Result,
17 types::graph::{GraphBatchRequest, GraphBatchRequestItem, GraphBatchResponse, GraphUser},
18};
19
20#[derive(Debug, Clone, Default, Serialize, Deserialize)]
25#[serde(rename_all = "camelCase")]
26pub struct GraphServicePrincipal {
27 #[serde(skip_serializing_if = "Option::is_none")]
29 pub id: Option<String>,
30
31 #[serde(skip_serializing_if = "Option::is_none")]
33 pub app_id: Option<String>,
34
35 #[serde(skip_serializing_if = "Option::is_none")]
37 pub display_name: Option<String>,
38
39 #[serde(skip_serializing_if = "Option::is_none")]
41 pub service_principal_type: Option<String>,
42
43 #[serde(skip_serializing_if = "Option::is_none")]
45 pub account_enabled: Option<bool>,
46}
47
48#[derive(Debug, Clone, Default, Deserialize)]
50struct GraphListResponse<T> {
51 #[serde(default)]
52 value: Vec<T>,
53}
54
55const GRAPH_BASE: &str = "https://graph.microsoft.com/v1.0";
56
57const USER_SELECT: &str = "$select=id,displayName,userPrincipalName,userType,accountEnabled";
61
62const BATCH_MAX: usize = 20;
64
65pub struct GraphClient<'a> {
70 client: &'a AzureHttpClient,
71}
72
73impl<'a> GraphClient<'a> {
74 pub(crate) fn new(client: &'a AzureHttpClient) -> Self {
76 Self { client }
77 }
78
79 pub async fn get_user(&self, principal_id: &str) -> Result<Option<GraphUser>> {
83 let url = format!("{GRAPH_BASE}/users/{principal_id}?{USER_SELECT}");
84 let response = match self.client.graph_get(&url).await {
85 Ok(r) => r,
86 Err(crate::AzureError::NotFound { .. }) => return Ok(None),
87 Err(e) => return Err(e),
88 };
89 let response = response.error_for_status().await?;
90 let bytes = response.bytes().await?;
91 let user =
92 serde_json::from_slice(&bytes).map_err(|e| crate::AzureError::InvalidResponse {
93 message: format!("Failed to parse GraphUser: {e}"),
94 body: Some(String::from_utf8_lossy(&bytes).to_string()),
95 })?;
96 Ok(Some(user))
97 }
98
99 pub async fn list_service_principals(
108 &self,
109 filter: &str,
110 ) -> Result<Vec<GraphServicePrincipal>> {
111 let encoded_filter = urlencoding::encode(filter);
112 let url = format!(
113 "{GRAPH_BASE}/servicePrincipals?$filter={encoded_filter}&$select=id,appId,displayName,servicePrincipalType,accountEnabled"
114 );
115 let response = self.client.graph_get(&url).await?;
116 let response = response.error_for_status().await?;
117 let bytes = response.bytes().await?;
118 let list: GraphListResponse<GraphServicePrincipal> =
119 serde_json::from_slice(&bytes).map_err(|e| crate::AzureError::InvalidResponse {
120 message: format!("Failed to parse servicePrincipals response: {e}"),
121 body: Some(String::from_utf8_lossy(&bytes).to_string()),
122 })?;
123 Ok(list.value)
124 }
125
126 pub async fn batch_get_users(
134 &self,
135 principal_ids: &[&str],
136 ) -> Result<HashMap<String, GraphUser>> {
137 let mut result: HashMap<String, GraphUser> = HashMap::new();
138
139 for chunk in principal_ids.chunks(BATCH_MAX) {
140 let requests: Vec<GraphBatchRequestItem> = chunk
141 .iter()
142 .enumerate()
143 .map(|(i, id)| GraphBatchRequestItem {
144 id: i.to_string(),
145 method: "GET".to_string(),
146 url: format!("/users/{id}?{USER_SELECT}"),
147 })
148 .collect();
149
150 let index_to_id: Vec<&str> = chunk.to_vec();
152
153 let body = GraphBatchRequest { requests };
154 let body_bytes =
155 serde_json::to_vec(&body).map_err(|e| crate::AzureError::InvalidResponse {
156 message: format!("Failed to serialize GraphBatchRequest: {e}"),
157 body: None,
158 })?;
159
160 let url = format!("{GRAPH_BASE}/$batch");
161 let response = self.client.graph_post(&url, &body_bytes).await?;
162 let response = response.error_for_status().await?;
163 let bytes = response.bytes().await?;
164
165 let batch_resp: GraphBatchResponse =
166 serde_json::from_slice(&bytes).map_err(|e| crate::AzureError::InvalidResponse {
167 message: format!("Failed to parse GraphBatchResponse: {e}"),
168 body: Some(String::from_utf8_lossy(&bytes).to_string()),
169 })?;
170
171 for item in batch_resp.responses {
172 let status = item.status.unwrap_or(0);
173 if status == 404 {
174 continue;
175 }
176 if !(200..300).contains(&status) {
177 continue;
178 }
179 let idx: usize = item
180 .id
181 .as_deref()
182 .and_then(|s| s.parse().ok())
183 .unwrap_or(usize::MAX);
184 if idx >= index_to_id.len() {
185 continue;
186 }
187 let principal_id = index_to_id[idx];
188
189 if let Some(body_val) = item.body
190 && let Ok(user) = serde_json::from_value::<GraphUser>(body_val)
191 {
192 result.insert(principal_id.to_string(), user);
193 }
194 }
195 }
196
197 Ok(result)
198 }
199}
200
#[cfg(test)]
mod tests {
    use super::*;
    use crate::MockClient;

    /// Builds an `AzureHttpClient` whose transport is the given mock.
    fn make_client(mock: MockClient) -> AzureHttpClient {
        AzureHttpClient::from_mock(mock)
    }

    /// Minimal Graph user payload; the UPN is derived from the lower-cased display name.
    fn user_json(id: &str, display_name: &str, user_type: &str) -> serde_json::Value {
        serde_json::json!({
            "id": id,
            "displayName": display_name,
            "userPrincipalName": format!("{}@contoso.com", display_name.to_lowercase()),
            "userType": user_type
        })
    }

    #[tokio::test]
    async fn get_user_returns_member_user() {
        let mut mock = MockClient::new();
        mock.expect_get("/v1.0/users/abc-123")
            .returning_json(user_json("abc-123", "Alice", "Member"));
        let client = make_client(mock);
        let user = client
            .graph()
            .get_user("abc-123")
            .await
            .expect("get_user failed")
            .expect("user should be found");
        assert_eq!(user.id.as_deref(), Some("abc-123"));
        assert_eq!(user.display_name.as_deref(), Some("Alice"));
        assert_eq!(user.user_type.as_deref(), Some("Member"));
    }

    #[tokio::test]
    async fn get_user_returns_guest_user() {
        let mut mock = MockClient::new();
        mock.expect_get("/v1.0/users/ext-456")
            .returning_json(serde_json::json!({
                "id": "ext-456",
                "displayName": "Bob External",
                "userPrincipalName": "bob_external#EXT#@contoso.onmicrosoft.com",
                "userType": "Guest"
            }));
        let client = make_client(mock);
        let user = client
            .graph()
            .get_user("ext-456")
            .await
            .expect("get_user failed")
            .expect("user should be found");
        assert_eq!(user.user_type.as_deref(), Some("Guest"));
        assert!(
            user.user_principal_name
                .as_deref()
                .unwrap_or("")
                .contains("#EXT#"),
            "guest UPN should contain #EXT#"
        );
    }

    #[tokio::test]
    async fn get_user_returns_none_for_404() {
        let mut mock = MockClient::new();
        mock.expect_get("/v1.0/users/not-found")
            .returning_error(crate::AzureError::NotFound {
                resource: "User not-found".into(),
            });
        let client = make_client(mock);
        let user = client
            .graph()
            .get_user("not-found")
            .await
            .expect("NotFound should become Ok(None), not Err");
        assert!(user.is_none(), "should return None for 404");
    }

    #[tokio::test]
    async fn batch_get_users_returns_map() {
        let mut mock = MockClient::new();
        mock.expect_post("/v1.0/$batch")
            .returning_json(serde_json::json!({
                "responses": [
                    {
                        "id": "0",
                        "status": 200,
                        "body": user_json("user-a", "Alice", "Member")
                    },
                    {
                        "id": "1",
                        "status": 200,
                        "body": user_json("user-b", "Bob", "Guest")
                    }
                ]
            }));
        let client = make_client(mock);
        let map = client
            .graph()
            .batch_get_users(&["user-a", "user-b"])
            .await
            .expect("batch_get_users failed");
        assert_eq!(map.len(), 2);
        assert_eq!(map["user-a"].user_type.as_deref(), Some("Member"));
        assert_eq!(map["user-b"].user_type.as_deref(), Some("Guest"));
    }

    #[tokio::test]
    async fn batch_get_users_omits_not_found() {
        let mut mock = MockClient::new();
        mock.expect_post("/v1.0/$batch")
            .returning_json(serde_json::json!({
                "responses": [
                    { "id": "0", "status": 200, "body": user_json("user-a", "Alice", "Member") },
                    { "id": "1", "status": 404, "body": null }
                ]
            }));
        let client = make_client(mock);
        let map = client
            .graph()
            .batch_get_users(&["user-a", "user-b"])
            .await
            .expect("batch_get_users failed");
        assert_eq!(map.len(), 1);
        assert!(map.contains_key("user-a"));
        assert!(!map.contains_key("user-b"));
    }

    #[tokio::test]
    async fn batch_get_users_skips_failed_and_unmatched_items() {
        let mut mock = MockClient::new();
        mock.expect_post("/v1.0/$batch")
            .returning_json(serde_json::json!({
                "responses": [
                    { "id": "0", "status": 200, "body": user_json("user-a", "Alice", "Member") },
                    { "id": "1", "status": 429, "body": { "error": { "code": "TooManyRequests" } } },
                    { "id": "7", "status": 200, "body": user_json("ghost", "Ghost", "Member") }
                ]
            }));
        let client = make_client(mock);
        let map = client
            .graph()
            .batch_get_users(&["user-a", "user-b"])
            .await
            .expect("batch_get_users failed");
        assert_eq!(map.len(), 1, "throttled and out-of-range items must be skipped");
        assert!(map.contains_key("user-a"));
        assert!(!map.contains_key("user-b"));
    }

    #[tokio::test]
    async fn batch_get_users_empty_input_returns_empty_map() {
        let client = make_client(MockClient::new());
        let map = client
            .graph()
            .batch_get_users(&[])
            .await
            .expect("batch_get_users failed");
        assert!(map.is_empty());
    }

    #[tokio::test]
    async fn list_service_principals_returns_results() {
        let mut mock = MockClient::new();
        mock.expect_get("/v1.0/servicePrincipals")
            .returning_json(serde_json::json!({
                "value": [
                    {
                        "id": "sp-001",
                        "appId": "app-001",
                        "displayName": "Azure Databricks SCIM Provisioning Connector",
                        "servicePrincipalType": "Application",
                        "accountEnabled": true
                    }
                ]
            }));
        let client = make_client(mock);
        let results = client
            .graph()
            .list_service_principals("displayName eq 'Azure Databricks SCIM Provisioning Connector'")
            .await
            .expect("list_service_principals failed");
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id.as_deref(), Some("sp-001"));
        assert_eq!(
            results[0].display_name.as_deref(),
            Some("Azure Databricks SCIM Provisioning Connector")
        );
        assert_eq!(results[0].account_enabled, Some(true));
        assert_eq!(
            results[0].service_principal_type.as_deref(),
            Some("Application")
        );
    }

    #[tokio::test]
    async fn list_service_principals_returns_empty() {
        let mut mock = MockClient::new();
        mock.expect_get("/v1.0/servicePrincipals")
            .returning_json(serde_json::json!({ "value": [] }));
        let client = make_client(mock);
        let results = client
            .graph()
            .list_service_principals("displayName eq 'nonexistent'")
            .await
            .expect("list_service_principals failed");
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn batch_get_users_handles_chunking() {
        // 21 ids force two `$batch` calls: 20 in the first chunk, 1 in the second.
        let ids: Vec<String> = (0..21).map(|i| format!("user-{i}")).collect();
        let id_refs: Vec<&str> = ids.iter().map(|s| s.as_str()).collect();

        let first_batch_response = serde_json::json!({
            "responses": (0..20_usize).map(|i| serde_json::json!({
                "id": i.to_string(),
                "status": 200,
                "body": user_json(&format!("user-{i}"), &format!("User {i}"), "Member")
            })).collect::<Vec<_>>()
        });
        // Sub-request ids restart at 0 for every chunk.
        let second_batch_response = serde_json::json!({
            "responses": [
                { "id": "0", "status": 200, "body": user_json("user-20", "User 20", "Member") }
            ]
        });

        let mut mock = MockClient::new();
        mock.expect_post("/v1.0/$batch")
            .returning_json_sequence(vec![first_batch_response, second_batch_response])
            .times(2);

        let client = make_client(mock);
        let map = client
            .graph()
            .batch_get_users(&id_refs)
            .await
            .expect("batch_get_users failed");
        assert_eq!(map.len(), 21, "all 21 users should be in result");
        assert_eq!(map["user-0"].id.as_deref(), Some("user-0"));
        assert_eq!(
            map["user-20"].id.as_deref(),
            Some("user-20"),
            "index 0 of the second chunk must map to the 21st requested id"
        );
    }
}