use std::collections::HashMap;
use serde::{Deserialize, Serialize};
use crate::{
AzureHttpClient, Result,
types::graph::{GraphBatchRequest, GraphBatchRequestItem, GraphBatchResponse, GraphUser},
};
/// A Microsoft Graph `servicePrincipal` object, restricted to the properties
/// that `GraphClient::list_service_principals` asks for via `$select`.
///
/// Every property is optional. A key missing from the JSON deserializes to
/// `None`, and `None` properties are omitted entirely when serializing.
/// Rust field names are mapped to Graph's camelCase property names.
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct GraphServicePrincipal {
    /// Graph `id` of the service principal object.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub id: Option<String>,

    /// Graph `appId` property.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub app_id: Option<String>,

    /// Graph `displayName` property.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub display_name: Option<String>,

    /// Graph `servicePrincipalType` property (e.g. `"Application"`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub service_principal_type: Option<String>,

    /// Graph `accountEnabled` property.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub account_enabled: Option<bool>,
}
#[derive(Debug, Clone, Default, Deserialize)]
struct GraphListResponse<T> {
#[serde(default)]
value: Vec<T>,
}
const GRAPH_BASE: &str = "https://graph.microsoft.com/v1.0";
const USER_SELECT: &str = "$select=id,displayName,userPrincipalName,userType,accountEnabled";
const BATCH_MAX: usize = 20;
pub struct GraphClient<'a> {
client: &'a AzureHttpClient,
}
impl<'a> GraphClient<'a> {
pub(crate) fn new(client: &'a AzureHttpClient) -> Self {
Self { client }
}
pub async fn get_user(&self, principal_id: &str) -> Result<Option<GraphUser>> {
let url = format!("{GRAPH_BASE}/users/{principal_id}?{USER_SELECT}");
let response = match self.client.graph_get(&url).await {
Ok(r) => r,
Err(crate::AzureError::NotFound { .. }) => return Ok(None),
Err(e) => return Err(e),
};
let response = response.error_for_status().await?;
let bytes = response.bytes().await?;
let user =
serde_json::from_slice(&bytes).map_err(|e| crate::AzureError::InvalidResponse {
message: format!("Failed to parse GraphUser: {e}"),
body: Some(String::from_utf8_lossy(&bytes).to_string()),
})?;
Ok(Some(user))
}
pub async fn list_service_principals(
&self,
filter: &str,
) -> Result<Vec<GraphServicePrincipal>> {
let encoded_filter = urlencoding::encode(filter);
let url = format!(
"{GRAPH_BASE}/servicePrincipals?$filter={encoded_filter}&$select=id,appId,displayName,servicePrincipalType,accountEnabled"
);
let response = self.client.graph_get(&url).await?;
let response = response.error_for_status().await?;
let bytes = response.bytes().await?;
let list: GraphListResponse<GraphServicePrincipal> =
serde_json::from_slice(&bytes).map_err(|e| crate::AzureError::InvalidResponse {
message: format!("Failed to parse servicePrincipals response: {e}"),
body: Some(String::from_utf8_lossy(&bytes).to_string()),
})?;
Ok(list.value)
}
pub async fn batch_get_users(
&self,
principal_ids: &[&str],
) -> Result<HashMap<String, GraphUser>> {
let mut result: HashMap<String, GraphUser> = HashMap::new();
for chunk in principal_ids.chunks(BATCH_MAX) {
let requests: Vec<GraphBatchRequestItem> = chunk
.iter()
.enumerate()
.map(|(i, id)| GraphBatchRequestItem {
id: i.to_string(),
method: "GET".to_string(),
url: format!("/users/{id}?{USER_SELECT}"),
})
.collect();
let index_to_id: Vec<&str> = chunk.to_vec();
let body = GraphBatchRequest { requests };
let body_bytes =
serde_json::to_vec(&body).map_err(|e| crate::AzureError::InvalidResponse {
message: format!("Failed to serialize GraphBatchRequest: {e}"),
body: None,
})?;
let url = format!("{GRAPH_BASE}/$batch");
let response = self.client.graph_post(&url, &body_bytes).await?;
let response = response.error_for_status().await?;
let bytes = response.bytes().await?;
let batch_resp: GraphBatchResponse =
serde_json::from_slice(&bytes).map_err(|e| crate::AzureError::InvalidResponse {
message: format!("Failed to parse GraphBatchResponse: {e}"),
body: Some(String::from_utf8_lossy(&bytes).to_string()),
})?;
for item in batch_resp.responses {
let status = item.status.unwrap_or(0);
if status == 404 {
continue;
}
if !(200..300).contains(&status) {
continue;
}
let idx: usize = item
.id
.as_deref()
.and_then(|s| s.parse().ok())
.unwrap_or(usize::MAX);
if idx >= index_to_id.len() {
continue;
}
let principal_id = index_to_id[idx];
if let Some(body_val) = item.body
&& let Ok(user) = serde_json::from_value::<GraphUser>(body_val)
{
result.insert(principal_id.to_string(), user);
}
}
}
Ok(result)
}
}
#[cfg(test)]
mod tests {
    // Unit tests for `GraphClient`. Each test registers canned responses on a
    // `MockClient`, wraps it in an `AzureHttpClient` via `from_mock`, and calls
    // the Graph API through `client.graph()`.
    use super::*;
    use crate::MockClient;

    // Wraps a configured mock in the client type the Graph API is reached through.
    fn make_client(mock: MockClient) -> AzureHttpClient {
        AzureHttpClient::from_mock(mock)
    }

    // Builds a minimal Graph user payload. The UPN is the lowercased display
    // name followed by `@contoso.com`.
    fn user_json(id: &str, display_name: &str, user_type: &str) -> serde_json::Value {
        serde_json::json!({
            "id": id,
            "displayName": display_name,
            "userPrincipalName": format!("{}@contoso.com", display_name.to_lowercase()),
            "userType": user_type
        })
    }

    // Happy path: a member user is returned with its id, display name and user type.
    #[tokio::test]
    async fn get_user_returns_member_user() {
        let mut mock = MockClient::new();
        mock.expect_get("/v1.0/users/abc-123")
            .returning_json(user_json("abc-123", "Alice", "Member"));
        let client = make_client(mock);
        let user = client
            .graph()
            .get_user("abc-123")
            .await
            .expect("get_user failed")
            .expect("user should be found");
        assert_eq!(user.id.as_deref(), Some("abc-123"));
        assert_eq!(user.display_name.as_deref(), Some("Alice"));
        assert_eq!(user.user_type.as_deref(), Some("Member"));
    }

    // Guest users come back with userType "Guest" and a `#EXT#` UPN.
    #[tokio::test]
    async fn get_user_returns_guest_user() {
        let mut mock = MockClient::new();
        mock.expect_get("/v1.0/users/ext-456")
            .returning_json(serde_json::json!({
                "id": "ext-456",
                "displayName": "Bob External",
                "userPrincipalName": "bob_external#EXT#@contoso.onmicrosoft.com",
                "userType": "Guest"
            }));
        let client = make_client(mock);
        let user = client
            .graph()
            .get_user("ext-456")
            .await
            .expect("get_user failed")
            .expect("user should be found");
        assert_eq!(user.user_type.as_deref(), Some("Guest"));
        assert!(
            user.user_principal_name
                .as_deref()
                .unwrap_or("")
                .contains("#EXT#"),
            "guest UPN should contain #EXT#"
        );
    }

    // A transport-level `AzureError::NotFound` must surface as `Ok(None)`, not as an error.
    #[tokio::test]
    async fn get_user_returns_none_for_404() {
        let mut mock = MockClient::new();
        mock.expect_get("/v1.0/users/not-found")
            .returning_error(crate::AzureError::NotFound {
                resource: "User not-found".into(),
            });
        let client = make_client(mock);
        let result = client.graph().get_user("not-found").await;
        assert!(result.is_ok(), "NotFound should become Ok(None), not Err");
        assert!(result.unwrap().is_none(), "should return None for 404");
    }

    // Batch sub-response ids "0"/"1" map back to the input ids by position.
    #[tokio::test]
    async fn batch_get_users_returns_map() {
        let mut mock = MockClient::new();
        mock.expect_post("/v1.0/$batch")
            .returning_json(serde_json::json!({
                "responses": [
                    {
                        "id": "0",
                        "status": 200,
                        "body": user_json("user-a", "Alice", "Member")
                    },
                    {
                        "id": "1",
                        "status": 200,
                        "body": user_json("user-b", "Bob", "Guest")
                    }
                ]
            }));
        let client = make_client(mock);
        let map = client
            .graph()
            .batch_get_users(&["user-a", "user-b"])
            .await
            .expect("batch_get_users failed");
        assert_eq!(map.len(), 2);
        assert_eq!(map["user-a"].user_type.as_deref(), Some("Member"));
        assert_eq!(map["user-b"].user_type.as_deref(), Some("Guest"));
    }

    // A 404 sub-response is skipped (the id is absent) rather than failing the whole batch.
    #[tokio::test]
    async fn batch_get_users_omits_not_found() {
        let mut mock = MockClient::new();
        mock.expect_post("/v1.0/$batch")
            .returning_json(serde_json::json!({
                "responses": [
                    { "id": "0", "status": 200, "body": user_json("user-a", "Alice", "Member") },
                    { "id": "1", "status": 404, "body": null }
                ]
            }));
        let client = make_client(mock);
        let map = client
            .graph()
            .batch_get_users(&["user-a", "user-b"])
            .await
            .expect("batch_get_users failed");
        assert_eq!(map.len(), 1);
        assert!(map.contains_key("user-a"));
        assert!(!map.contains_key("user-b"));
    }

    // All selected service-principal properties are deserialized from a one-item page.
    #[tokio::test]
    async fn list_service_principals_returns_results() {
        let mut mock = MockClient::new();
        mock.expect_get("/v1.0/servicePrincipals")
            .returning_json(serde_json::json!({
                "value": [
                    {
                        "id": "sp-001",
                        "appId": "app-001",
                        "displayName": "Azure Databricks SCIM Provisioning Connector",
                        "servicePrincipalType": "Application",
                        "accountEnabled": true
                    }
                ]
            }));
        let client = make_client(mock);
        let results = client
            .graph()
            .list_service_principals("displayName eq 'Azure Databricks SCIM Provisioning Connector'")
            .await
            .expect("list_service_principals failed");
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id.as_deref(), Some("sp-001"));
        assert_eq!(
            results[0].display_name.as_deref(),
            Some("Azure Databricks SCIM Provisioning Connector")
        );
        assert_eq!(results[0].account_enabled, Some(true));
        assert_eq!(
            results[0].service_principal_type.as_deref(),
            Some("Application")
        );
    }

    // An empty `value` array yields an empty Vec rather than an error.
    #[tokio::test]
    async fn list_service_principals_returns_empty() {
        let mut mock = MockClient::new();
        mock.expect_get("/v1.0/servicePrincipals")
            .returning_json(serde_json::json!({ "value": [] }));
        let client = make_client(mock);
        let results = client
            .graph()
            .list_service_principals("displayName eq 'nonexistent'")
            .await
            .expect("list_service_principals failed");
        assert!(results.is_empty());
    }

    // 21 ids exceed BATCH_MAX (20), so two `$batch` calls are expected.
    // Sub-request ids are positions within each chunk, so the second
    // response's single entry uses id "0" again. This assumes
    // `returning_json_sequence` serves the responses in order, one per call.
    #[tokio::test]
    async fn batch_get_users_handles_chunking() {
        let ids: Vec<String> = (0..21).map(|i| format!("user-{i}")).collect();
        let id_refs: Vec<&str> = ids.iter().map(|s| s.as_str()).collect();
        let first_batch_response = serde_json::json!({
            "responses": (0..20_usize).map(|i| serde_json::json!({
                "id": i.to_string(),
                "status": 200,
                "body": user_json(&format!("user-{i}"), &format!("User {i}"), "Member")
            })).collect::<Vec<_>>()
        });
        let second_batch_response = serde_json::json!({
            "responses": [
                { "id": "0", "status": 200, "body": user_json("user-20", "User 20", "Member") }
            ]
        });
        let mut mock = MockClient::new();
        mock.expect_post("/v1.0/$batch")
            .returning_json_sequence(vec![first_batch_response, second_batch_response])
            .times(2);
        let client = make_client(mock);
        let map = client
            .graph()
            .batch_get_users(&id_refs)
            .await
            .expect("batch_get_users failed");
        assert_eq!(map.len(), 21, "all 21 users should be in result");
    }
}