Skip to main content

azure_lite_rs/api/
graph.rs

1//! Microsoft Graph API client.
2//!
3//! Provides access to Entra ID (Azure AD) user objects via the Graph API.
4//! Unlike ARM APIs, Graph uses a different OAuth2 scope and base URL:
5//!   - Base URL: `https://graph.microsoft.com/v1.0`
6//!   - Scope:    `https://graph.microsoft.com/.default`
7//!
8//! HTTP is handled via `AzureHttpClient::graph_get` / `graph_post`, which
9//! acquire Graph-scoped tokens automatically.
10
11use std::collections::HashMap;
12
13use serde::{Deserialize, Serialize};
14
15use crate::{
16    AzureHttpClient, Result,
17    types::graph::{GraphBatchRequest, GraphBatchRequestItem, GraphBatchResponse, GraphUser},
18};
19
20/// An Entra ID service principal (enterprise application) returned by Microsoft Graph.
21///
22/// Used to detect provisioning connectors (e.g. Databricks SCIM) by querying
23/// `GET /v1.0/servicePrincipals?$filter=displayName eq '...'`.
24#[derive(Debug, Clone, Default, Serialize, Deserialize)]
25#[serde(rename_all = "camelCase")]
26pub struct GraphServicePrincipal {
27    /// The unique object ID of the service principal.
28    #[serde(skip_serializing_if = "Option::is_none")]
29    pub id: Option<String>,
30
31    /// The application (client) ID associated with this service principal.
32    #[serde(skip_serializing_if = "Option::is_none")]
33    pub app_id: Option<String>,
34
35    /// Display name of the service principal.
36    #[serde(skip_serializing_if = "Option::is_none")]
37    pub display_name: Option<String>,
38
39    /// The type of service principal: "Application", "ManagedIdentity", or "SocialIdp".
40    #[serde(skip_serializing_if = "Option::is_none")]
41    pub service_principal_type: Option<String>,
42
43    /// Whether the service principal is enabled for sign-in.
44    #[serde(skip_serializing_if = "Option::is_none")]
45    pub account_enabled: Option<bool>,
46}
47
48/// Response wrapper for Graph API list queries (`{ "value": [...] }`).
49#[derive(Debug, Clone, Default, Deserialize)]
50struct GraphListResponse<T> {
51    #[serde(default)]
52    value: Vec<T>,
53}
54
/// Root URL of the Microsoft Graph v1.0 endpoint; every request URL is built on top of it.
const GRAPH_BASE: &str = "https://graph.microsoft.com/v1.0";

/// `$select` query fragment appended to user lookups.
///
/// Graph does not return `userType` or `accountEnabled` unless they are selected
/// explicitly, so both are listed here.
const USER_SELECT: &str = "$select=id,displayName,userPrincipalName,userType,accountEnabled";

/// Upper bound on the number of requests in one `$batch` call (Graph API limit).
const BATCH_MAX: usize = 20;
64
65/// Client for the Microsoft Graph API.
66///
67/// Provides Entra ID user lookups needed to distinguish internal users
68/// ("Member") from external/guest users ("Guest") in RBAC policy evaluation.
69pub struct GraphClient<'a> {
70    client: &'a AzureHttpClient,
71}
72
73impl<'a> GraphClient<'a> {
74    /// Create a new Microsoft Graph API client.
75    pub(crate) fn new(client: &'a AzureHttpClient) -> Self {
76        Self { client }
77    }
78
79    /// Get a single user by their object ID (principal ID).
80    ///
81    /// Returns `None` if the user is not found (404).
82    pub async fn get_user(&self, principal_id: &str) -> Result<Option<GraphUser>> {
83        let url = format!("{GRAPH_BASE}/users/{principal_id}?{USER_SELECT}");
84        let response = match self.client.graph_get(&url).await {
85            Ok(r) => r,
86            Err(crate::AzureError::NotFound { .. }) => return Ok(None),
87            Err(e) => return Err(e),
88        };
89        let response = response.error_for_status().await?;
90        let bytes = response.bytes().await?;
91        let user =
92            serde_json::from_slice(&bytes).map_err(|e| crate::AzureError::InvalidResponse {
93                message: format!("Failed to parse GraphUser: {e}"),
94                body: Some(String::from_utf8_lossy(&bytes).to_string()),
95            })?;
96        Ok(Some(user))
97    }
98
99    /// List service principals matching an OData `$filter` expression.
100    ///
101    /// Example: find the Databricks SCIM provisioning connector:
102    /// ```ignore
103    /// let results = graph
104    ///     .list_service_principals("displayName eq 'Azure Databricks SCIM Provisioning Connector'")
105    ///     .await?;
106    /// ```
107    pub async fn list_service_principals(
108        &self,
109        filter: &str,
110    ) -> Result<Vec<GraphServicePrincipal>> {
111        let encoded_filter = urlencoding::encode(filter);
112        let url = format!(
113            "{GRAPH_BASE}/servicePrincipals?$filter={encoded_filter}&$select=id,appId,displayName,servicePrincipalType,accountEnabled"
114        );
115        let response = self.client.graph_get(&url).await?;
116        let response = response.error_for_status().await?;
117        let bytes = response.bytes().await?;
118        let list: GraphListResponse<GraphServicePrincipal> =
119            serde_json::from_slice(&bytes).map_err(|e| crate::AzureError::InvalidResponse {
120                message: format!("Failed to parse servicePrincipals response: {e}"),
121                body: Some(String::from_utf8_lossy(&bytes).to_string()),
122            })?;
123        Ok(list.value)
124    }
125
126    /// Batch-fetch multiple users by their object IDs.
127    ///
128    /// Returns a map of `principal_id → GraphUser` for users that were found.
129    /// Not-found (404) entries are silently omitted from the result.
130    ///
131    /// Internally uses `$batch` to resolve up to 20 IDs per HTTP call,
132    /// handling chunking automatically for larger sets.
133    pub async fn batch_get_users(
134        &self,
135        principal_ids: &[&str],
136    ) -> Result<HashMap<String, GraphUser>> {
137        let mut result: HashMap<String, GraphUser> = HashMap::new();
138
139        for chunk in principal_ids.chunks(BATCH_MAX) {
140            let requests: Vec<GraphBatchRequestItem> = chunk
141                .iter()
142                .enumerate()
143                .map(|(i, id)| GraphBatchRequestItem {
144                    id: i.to_string(),
145                    method: "GET".to_string(),
146                    url: format!("/users/{id}?{USER_SELECT}"),
147                })
148                .collect();
149
150            // Map batch request index → principalId for result correlation
151            let index_to_id: Vec<&str> = chunk.to_vec();
152
153            let body = GraphBatchRequest { requests };
154            let body_bytes =
155                serde_json::to_vec(&body).map_err(|e| crate::AzureError::InvalidResponse {
156                    message: format!("Failed to serialize GraphBatchRequest: {e}"),
157                    body: None,
158                })?;
159
160            let url = format!("{GRAPH_BASE}/$batch");
161            let response = self.client.graph_post(&url, &body_bytes).await?;
162            let response = response.error_for_status().await?;
163            let bytes = response.bytes().await?;
164
165            let batch_resp: GraphBatchResponse =
166                serde_json::from_slice(&bytes).map_err(|e| crate::AzureError::InvalidResponse {
167                    message: format!("Failed to parse GraphBatchResponse: {e}"),
168                    body: Some(String::from_utf8_lossy(&bytes).to_string()),
169                })?;
170
171            for item in batch_resp.responses {
172                let status = item.status.unwrap_or(0);
173                if status == 404 {
174                    continue;
175                }
176                if !(200..300).contains(&status) {
177                    continue;
178                }
179                let idx: usize = item
180                    .id
181                    .as_deref()
182                    .and_then(|s| s.parse().ok())
183                    .unwrap_or(usize::MAX);
184                if idx >= index_to_id.len() {
185                    continue;
186                }
187                let principal_id = index_to_id[idx];
188
189                if let Some(body_val) = item.body
190                    && let Ok(user) = serde_json::from_value::<GraphUser>(body_val)
191                {
192                    result.insert(principal_id.to_string(), user);
193                }
194            }
195        }
196
197        Ok(result)
198    }
199}
200
#[cfg(test)]
mod tests {
    use super::*;
    use crate::MockClient;

    /// Wrap a configured mock in an `AzureHttpClient`.
    fn make_client(mock: MockClient) -> AzureHttpClient {
        AzureHttpClient::from_mock(mock)
    }

    /// Build a Graph user payload; the UPN is the lower-cased display name at contoso.com.
    fn user_json(id: &str, display_name: &str, user_type: &str) -> serde_json::Value {
        let upn = format!("{}@contoso.com", display_name.to_lowercase());
        serde_json::json!({
            "id": id,
            "displayName": display_name,
            "userPrincipalName": upn,
            "userType": user_type
        })
    }

    /// Build one successful (status 200) entry of a `$batch` response.
    fn batch_ok(index: usize, body: serde_json::Value) -> serde_json::Value {
        serde_json::json!({
            "id": index.to_string(),
            "status": 200,
            "body": body
        })
    }

    // NOTE: Graph only returns `userType` (and `accountEnabled`) when they are listed in
    // `$select` — they are NOT part of the default response. This is proven by
    // integration tests. Mock expectations use URL path matching, so the `?$select=`
    // suffix of the real URL is matched by StartsWith.

    #[tokio::test]
    async fn get_user_returns_member_user() {
        let mut http = MockClient::new();
        // Matched on the path prefix; the real URL also carries the $select suffix.
        http.expect_get("/v1.0/users/abc-123")
            .returning_json(user_json("abc-123", "Alice", "Member"));
        let azure = make_client(http);

        let fetched = azure
            .graph()
            .get_user("abc-123")
            .await
            .expect("get_user failed")
            .expect("user should be found");

        assert_eq!(fetched.id.as_deref(), Some("abc-123"));
        assert_eq!(fetched.display_name.as_deref(), Some("Alice"));
        assert_eq!(fetched.user_type.as_deref(), Some("Member"));
    }

    #[tokio::test]
    async fn get_user_returns_guest_user() {
        let guest_payload = serde_json::json!({
            "id": "ext-456",
            "displayName": "Bob External",
            "userPrincipalName": "bob_external#EXT#@contoso.onmicrosoft.com",
            "userType": "Guest"
        });
        let mut http = MockClient::new();
        http.expect_get("/v1.0/users/ext-456")
            .returning_json(guest_payload);
        let azure = make_client(http);

        let fetched = azure
            .graph()
            .get_user("ext-456")
            .await
            .expect("get_user failed")
            .expect("user should be found");

        assert_eq!(fetched.user_type.as_deref(), Some("Guest"));
        let upn_marks_guest = fetched
            .user_principal_name
            .as_deref()
            .is_some_and(|upn| upn.contains("#EXT#"));
        assert!(upn_marks_guest, "guest UPN should contain #EXT#");
    }

    #[tokio::test]
    async fn get_user_returns_none_for_404() {
        // Integration-proven: a Graph 404 surfaces as AzureError::NotFound (not as a
        // 200 carrying a status), so get_user has to translate NotFound into Ok(None).
        let mut http = MockClient::new();
        http.expect_get("/v1.0/users/not-found")
            .returning_error(crate::AzureError::NotFound {
                resource: "User not-found".into(),
            });
        let azure = make_client(http);

        let result = azure.graph().get_user("not-found").await;

        assert!(result.is_ok(), "NotFound should become Ok(None), not Err");
        assert!(result.unwrap().is_none(), "should return None for 404");
    }

    #[tokio::test]
    async fn batch_get_users_returns_map() {
        let batch_payload = serde_json::json!({
            "responses": [
                batch_ok(0, user_json("user-a", "Alice", "Member")),
                batch_ok(1, user_json("user-b", "Bob", "Guest"))
            ]
        });
        let mut http = MockClient::new();
        http.expect_post("/v1.0/$batch")
            .returning_json(batch_payload);
        let azure = make_client(http);

        let users = azure
            .graph()
            .batch_get_users(&["user-a", "user-b"])
            .await
            .expect("batch_get_users failed");

        assert_eq!(users.len(), 2);
        assert_eq!(
            users.get("user-a").and_then(|u| u.user_type.as_deref()),
            Some("Member")
        );
        assert_eq!(
            users.get("user-b").and_then(|u| u.user_type.as_deref()),
            Some("Guest")
        );
    }

    #[tokio::test]
    async fn batch_get_users_omits_not_found() {
        let batch_payload = serde_json::json!({
            "responses": [
                batch_ok(0, user_json("user-a", "Alice", "Member")),
                { "id": "1", "status": 404, "body": null }
            ]
        });
        let mut http = MockClient::new();
        http.expect_post("/v1.0/$batch")
            .returning_json(batch_payload);
        let azure = make_client(http);

        let users = azure
            .graph()
            .batch_get_users(&["user-a", "user-b"])
            .await
            .expect("batch_get_users failed");

        assert_eq!(users.len(), 1);
        assert!(users.contains_key("user-a"));
        assert!(!users.contains_key("user-b"));
    }

    #[tokio::test]
    async fn list_service_principals_returns_results() {
        let list_payload = serde_json::json!({
            "value": [
                {
                    "id": "sp-001",
                    "appId": "app-001",
                    "displayName": "Azure Databricks SCIM Provisioning Connector",
                    "servicePrincipalType": "Application",
                    "accountEnabled": true
                }
            ]
        });
        let mut http = MockClient::new();
        http.expect_get("/v1.0/servicePrincipals")
            .returning_json(list_payload);
        let azure = make_client(http);

        let found = azure
            .graph()
            .list_service_principals("displayName eq 'Azure Databricks SCIM Provisioning Connector'")
            .await
            .expect("list_service_principals failed");

        assert_eq!(found.len(), 1);
        let connector = &found[0];
        assert_eq!(connector.id.as_deref(), Some("sp-001"));
        assert_eq!(
            connector.display_name.as_deref(),
            Some("Azure Databricks SCIM Provisioning Connector")
        );
        assert_eq!(connector.account_enabled, Some(true));
        assert_eq!(
            connector.service_principal_type.as_deref(),
            Some("Application")
        );
    }

    #[tokio::test]
    async fn list_service_principals_returns_empty() {
        let mut http = MockClient::new();
        http.expect_get("/v1.0/servicePrincipals")
            .returning_json(serde_json::json!({ "value": [] }));
        let azure = make_client(http);

        let found = azure
            .graph()
            .list_service_principals("displayName eq 'nonexistent'")
            .await
            .expect("list_service_principals failed");

        assert!(found.is_empty());
    }

    #[tokio::test]
    async fn batch_get_users_handles_chunking() {
        // 21 users → 2 batch calls (20 + 1)
        let ids: Vec<String> = (0..21).map(|i| format!("user-{i}")).collect();
        let id_refs: Vec<&str> = ids.iter().map(String::as_str).collect();

        // First call answers request ids 0..=19 for user-0..=user-19.
        let first_items: Vec<serde_json::Value> = (0..20_usize)
            .map(|i| {
                batch_ok(
                    i,
                    user_json(&format!("user-{i}"), &format!("User {i}"), "Member"),
                )
            })
            .collect();
        let first_batch_response = serde_json::json!({ "responses": first_items });
        // Second call restarts numbering at 0 for the single remaining user.
        let second_batch_response = serde_json::json!({
            "responses": [batch_ok(0, user_json("user-20", "User 20", "Member"))]
        });

        let mut http = MockClient::new();
        http.expect_post("/v1.0/$batch")
            .returning_json_sequence(vec![first_batch_response, second_batch_response])
            .times(2);
        let azure = make_client(http);

        let users = azure
            .graph()
            .batch_get_users(&id_refs)
            .await
            .expect("batch_get_users failed");

        assert_eq!(users.len(), 21, "all 21 users should be in result");
    }
}