use anyhow::{Context, Result};
use serde::Deserialize;
use sqlx::PgPool;
use std::collections::HashSet;
use std::time::Duration;
use uuid::Uuid;
use crate::db::{models::TeamRole, teams, users};
/// Fields of the OAuth2 token endpoint's JSON response that this module deserializes.
#[derive(Debug, Deserialize)]
struct TokenResponse {
    /// Bearer token attached to subsequent Graph API requests.
    access_token: String,
    // Must be present for deserialization to succeed, but is never read in this file.
    #[allow(dead_code)]
    expires_in: u64,
}
/// One page of a Graph list response: the page's items plus an optional
/// link to the following page.
#[derive(Debug, Deserialize)]
struct GraphListResponse<T> {
    /// Items contained in this page.
    value: Vec<T>,
    /// URL of the next page, read from the `@odata.nextLink` key.
    /// `None` when the key is absent, which ends pagination.
    #[serde(rename = "@odata.nextLink")]
    next_link: Option<String>,
}
/// Service principal entry as returned by the `servicePrincipals` query;
/// only the `id` field is selected and deserialized.
#[derive(Debug, Deserialize)]
struct ServicePrincipal {
    /// Identifier used to address the service principal in later requests.
    id: String,
}
/// One entry of a service principal's `appRoleAssignedTo` list.
///
/// JSON keys are camelCase (`principalId`, `principalType`,
/// `principalDisplayName`).
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct AppRoleAssignment {
    /// Identifier of the assigned principal; for groups it is used as the group id.
    principal_id: String,
    /// Kind of principal; the sync only processes entries equal to `"Group"`.
    principal_type: String,
    /// Display name of the principal, if the response carries one.
    principal_display_name: Option<String>,
}
/// User entry from a group's member listing (camelCase JSON keys).
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct GraphUser {
    // Must be present for deserialization to succeed, but is never read in this file.
    #[allow(dead_code)]
    id: String,
    /// Preferred address for the user; takes precedence over `user_principal_name`.
    mail: Option<String>,
    /// Fallback identifier used when `mail` is absent.
    user_principal_name: Option<String>,
}
/// Minimal Microsoft Graph client authenticated with the client-credentials grant.
struct GraphClient {
    /// HTTP client shared by all requests made through this value.
    http: reqwest::Client,
    /// Tenant segment inserted into the token endpoint URL.
    tenant_id: String,
    /// Application (client) id; also used to look up the service principal.
    client_id: String,
    /// Secret sent to the token endpoint together with `client_id`.
    client_secret: String,
    /// Most recently obtained bearer token; `None` until a token request succeeds.
    access_token: Option<String>,
}
impl GraphClient {
fn new(tenant_id: &str, client_id: &str, client_secret: &str) -> Self {
Self {
http: reqwest::Client::new(),
tenant_id: tenant_id.to_string(),
client_id: client_id.to_string(),
client_secret: client_secret.to_string(),
access_token: None,
}
}
async fn ensure_token(&mut self) -> Result<()> {
let token_url = format!(
"https://login.microsoftonline.com/{}/oauth2/v2.0/token",
self.tenant_id
);
let resp = self
.http
.post(&token_url)
.form(&[
("grant_type", "client_credentials"),
("client_id", &self.client_id),
("client_secret", &self.client_secret),
("scope", "https://graph.microsoft.com/.default"),
])
.send()
.await
.context("Failed to request Graph API token")?;
if !resp.status().is_success() {
let status = resp.status();
let body = resp.text().await.unwrap_or_default();
anyhow::bail!(
"Token request failed (HTTP {}): {}",
status,
truncate_string(&body, 500)
);
}
let token: TokenResponse = resp
.json()
.await
.context("Failed to parse token response")?;
self.access_token = Some(token.access_token);
Ok(())
}
fn token(&self) -> Result<&str> {
self.access_token
.as_deref()
.ok_or_else(|| anyhow::anyhow!("No access token available"))
}
async fn get_paginated<T: serde::de::DeserializeOwned>(
&self,
initial_url: &str,
) -> Result<Vec<T>> {
let mut results = Vec::new();
let mut url = initial_url.to_string();
loop {
let resp = self
.http
.get(&url)
.bearer_auth(self.token()?)
.header("ConsistencyLevel", "eventual")
.send()
.await
.with_context(|| format!("Graph API request failed for {}", url))?;
if !resp.status().is_success() {
let status = resp.status();
let body = resp.text().await.unwrap_or_default();
anyhow::bail!(
"Graph API error (HTTP {}) for {}: {}",
status,
url,
truncate_string(&body, 500)
);
}
let page: GraphListResponse<T> = resp
.json()
.await
.with_context(|| format!("Failed to parse Graph API response from {}", url))?;
results.extend(page.value);
match page.next_link {
Some(next) => url = next,
None => break,
}
}
Ok(results)
}
async fn get_service_principal_id(&self) -> Result<String> {
let encoded_client_id = urlencoding::encode(&self.client_id);
let url = format!(
"https://graph.microsoft.com/v1.0/servicePrincipals?$filter=appId eq '{}'&$select=id",
encoded_client_id
);
let sps: Vec<ServicePrincipal> = self.get_paginated(&url).await?;
sps.into_iter()
.next()
.map(|sp| sp.id)
.ok_or_else(|| {
anyhow::anyhow!(
"No service principal found for appId '{}'. Ensure the app is registered as an Enterprise Application in Entra.",
self.client_id
)
})
}
async fn get_app_role_assignments(&self, sp_id: &str) -> Result<Vec<AppRoleAssignment>> {
let encoded_sp_id = urlencoding::encode(sp_id);
let url = format!(
"https://graph.microsoft.com/v1.0/servicePrincipals/{}/appRoleAssignedTo?$select=principalId,principalType,principalDisplayName",
encoded_sp_id
);
self.get_paginated(&url).await
}
async fn get_group_user_members(&self, group_id: &str) -> Result<Vec<GraphUser>> {
let encoded_group_id = urlencoding::encode(group_id);
let url = format!(
"https://graph.microsoft.com/v1.0/groups/{}/transitiveMembers/microsoft.graph.user?$select=id,mail,userPrincipalName",
encoded_group_id
);
self.get_paginated(&url).await
}
}
/// A group assignment resolved into everything needed to sync one team.
struct GroupToSync {
    /// Name of the group as shown in logs (the principal id when no display name exists).
    display_name: String,
    /// Team name derived from `display_name` by `sanitize_team_name`.
    team_name: String,
    /// Addresses of the group's members; entries without a usable address are omitted.
    member_emails: Vec<String>,
}
/// Extracts the tenant identifier from an OIDC issuer URL.
///
/// Trailing slashes and a trailing `/v2.0` segment are ignored; the tenant
/// is the last remaining path segment, e.g.
/// `https://login.microsoftonline.com/{tenant}/v2.0` yields `{tenant}`.
///
/// When the issuer carries a scheme (`...://`), the first component after
/// it is treated as the host and is never returned as the tenant. An issuer
/// such as `https://login.microsoftonline.com` is therefore rejected instead
/// of yielding the host name. Inputs without a scheme keep the plain
/// "last segment" behaviour.
///
/// # Errors
///
/// Returns an error when no non-empty tenant segment can be found.
pub fn extract_tenant_id(issuer: &str) -> Result<String> {
    let (has_scheme, remainder) = match issuer.split_once("://") {
        Some((_scheme, rest)) => (true, rest),
        None => (false, issuer),
    };
    let normalized = remainder
        .trim_end_matches('/')
        .trim_end_matches("/v2.0")
        .trim_end_matches('/');
    // With a scheme, everything up to the first '/' is the host; a tenant
    // can only live in what follows it.
    let path = if has_scheme {
        normalized.split_once('/').map(|(_host, path)| path)
    } else {
        Some(normalized)
    };
    let tenant = path
        .and_then(|p| p.rsplit('/').next())
        .filter(|s| !s.is_empty())
        .ok_or_else(|| {
            anyhow::anyhow!(
                "Cannot extract tenant ID from issuer URL '{}'. \
                 Expected format: https://login.microsoftonline.com/{{tenant}}/v2.0",
                issuer
            )
        })?;
    Ok(tenant.to_string())
}
/// Turns a group display name into a lowercase, hyphen-separated team name.
///
/// * ASCII letters are lowercased; ASCII digits are kept.
/// * Space, `_`, `.` and `-` all act as separators and become `-`.
/// * Every other character (including non-ASCII) is dropped.
/// * Runs of separators collapse into a single `-`, and the result never
///   starts or ends with `-`.
///
/// The result may be empty when the input contains no ASCII letters or digits.
pub fn sanitize_team_name(display_name: &str) -> String {
    let mut slug = String::with_capacity(display_name.len());
    for ch in display_name.chars() {
        if ch.is_ascii_alphanumeric() {
            slug.push(ch.to_ascii_lowercase());
        } else if matches!(ch, ' ' | '_' | '.' | '-') {
            // Skipping the separator when the slug is empty removes leading
            // hyphens; skipping it after a hyphen collapses runs.
            if !slug.is_empty() && !slug.ends_with('-') {
                slug.push('-');
            }
        }
    }
    // At most one trailing hyphen can remain at this point.
    if slug.ends_with('-') {
        slug.pop();
    }
    slug
}
/// Runs one full synchronization pass from Entra to the local teams tables.
///
/// Steps:
/// 1. Obtain a token, resolve the service principal and list its app role
///    assignments.
/// 2. For every assignment whose principal type is `"Group"`, derive a team
///    name and fetch the group's user members. Groups with an empty or
///    already-seen sanitized name are skipped. A group whose member fetch
///    fails is logged and left untouched for this pass.
/// 3. Inside a single transaction, sync each fetched group and then remove
///    all members from IdP-managed teams that no longer have any group
///    assignment in Entra.
///
/// # Errors
///
/// Any Graph failure in step 1 or any database failure in step 3 aborts the
/// pass. The transaction is not committed in that case.
async fn sync_once(pool: &PgPool, client: &mut GraphClient) -> Result<()> {
    client.ensure_token().await?;
    let sp_id = client.get_service_principal_id().await?;
    tracing::debug!("Found service principal: {}", sp_id);
    let assignments = client.get_app_role_assignments(&sp_id).await?;
    tracing::info!(
        "Fetched {} app role assignments from Entra",
        assignments.len()
    );

    let mut groups_to_sync: Vec<GroupToSync> = Vec::new();
    // Every team name that still has a group assignment in Entra, whether or
    // not its members could be fetched. The cleanup pass below must use this
    // set: keying it on successfully fetched groups only would wipe the
    // membership of a team after a transient member-fetch failure.
    let mut assigned_team_names: HashSet<String> = HashSet::new();

    for assignment in &assignments {
        if assignment.principal_type != "Group" {
            continue;
        }
        let display_name = assignment
            .principal_display_name
            .clone()
            .unwrap_or_else(|| assignment.principal_id.clone());
        let team_name = sanitize_team_name(&display_name);
        if team_name.is_empty() {
            tracing::warn!(
                "Skipping group '{}' — sanitized name is empty",
                display_name
            );
            continue;
        }
        if !assigned_team_names.insert(team_name.clone()) {
            tracing::warn!(
                "Skipping duplicate group '{}' (sanitized to '{}' which was already seen)",
                display_name,
                team_name
            );
            continue;
        }
        let members = match client
            .get_group_user_members(&assignment.principal_id)
            .await
        {
            Ok(m) => m,
            Err(e) => {
                tracing::error!(
                    "Failed to fetch members for group '{}' ({}): {:?}",
                    display_name,
                    assignment.principal_id,
                    e
                );
                // The team stays in `assigned_team_names`, so its current
                // membership is preserved until a later pass succeeds.
                continue;
            }
        };
        // Prefer `mail`, fall back to the user principal name; drop users with neither.
        let member_emails: Vec<String> = members
            .into_iter()
            .filter_map(|u| u.mail.or(u.user_principal_name).filter(|e| !e.is_empty()))
            .collect();
        tracing::debug!(
            "Group '{}' → team '{}' with {} members",
            display_name,
            team_name,
            member_emails.len()
        );
        groups_to_sync.push(GroupToSync {
            display_name,
            team_name,
            member_emails,
        });
    }

    let mut tx = pool
        .begin()
        .await
        .context("Failed to begin transaction for Entra sync")?;

    for group in &groups_to_sync {
        if let Err(e) = sync_group(&mut tx, group).await {
            tracing::error!(
                "Failed to sync group '{}' (team '{}'): {:?}",
                group.display_name,
                group.team_name,
                e
            );
            // Returning here drops `tx` without committing.
            return Err(e);
        }
    }

    let all_idp_teams = teams::list_idp_managed(&mut *tx)
        .await
        .context("Failed to list IdP-managed teams")?;
    for team in all_idp_teams {
        if assigned_team_names.contains(&team.name) {
            continue;
        }
        let removed = teams::remove_all_team_members(&mut *tx, team.id)
            .await
            .context("Failed to remove members during cleanup")?;
        if removed > 0 {
            tracing::info!(
                "Removed {} members from IdP-managed team '{}' (no longer assigned in Entra)",
                removed,
                team.name
            );
        }
    }

    tx.commit()
        .await
        .context("Failed to commit Entra sync transaction")?;
    tracing::info!(
        "Entra active sync completed: {} groups synced",
        groups_to_sync.len()
    );
    Ok(())
}
/// Brings one team in line with its Entra group, inside the caller's transaction.
///
/// * Looks the team up by name. An existing team that is not yet
///   IdP-managed is flagged as such and all of its owners are removed. A
///   missing team is created and flagged as IdP-managed.
/// * Resolves every member address to a user (creating users as needed).
/// * Adds users that are expected but not yet members (role `Member`) and
///   removes all roles of members that are no longer expected.
///
/// # Errors
///
/// Returns the first database error encountered. Nothing is committed here;
/// the caller owns the transaction.
async fn sync_group(
    tx: &mut sqlx::Transaction<'_, sqlx::Postgres>,
    group: &GroupToSync,
) -> Result<()> {
    let existing = teams::find_by_name(&mut **tx, &group.team_name)
        .await
        .context("Failed to look up team")?;
    let team_id = if let Some(team) = existing {
        // NOTE(review): this only triggers if `find_by_name` can match a name
        // that is not byte-identical (presumably case-insensitive) — confirm.
        if team.name != group.team_name {
            teams::update_name(&mut **tx, team.id, &group.team_name)
                .await
                .context("Failed to update team name")?;
        }
        if !team.idp_managed {
            tracing::info!(
                "Converting team '{}' to IdP-managed (Entra sync)",
                group.team_name
            );
            teams::set_idp_managed(&mut **tx, team.id, true)
                .await
                .context("Failed to set IdP-managed flag")?;
            teams::remove_all_owners(&mut **tx, team.id)
                .await
                .context("Failed to remove owners on IdP takeover")?;
        }
        team.id
    } else {
        tracing::info!(
            "Creating new IdP-managed team '{}' (from Entra group '{}')",
            group.team_name,
            group.display_name
        );
        let new_team = teams::create(&mut **tx, &group.team_name)
            .await
            .context("Failed to create team")?;
        teams::set_idp_managed(&mut **tx, new_team.id, true)
            .await
            .context("Failed to set IdP-managed flag on new team")?;
        new_team.id
    };

    let mut expected_user_ids: HashSet<Uuid> = HashSet::new();
    for email in &group.member_emails {
        let user = users::find_or_create_with_executor(tx, email)
            .await
            .with_context(|| format!("Failed to find/create user '{}'", email))?;
        expected_user_ids.insert(user.id);
    }

    let current_member_ids: HashSet<Uuid> = teams::get_all_member_user_ids(&mut **tx, team_id)
        .await
        .context("Failed to get current team members")?
        .into_iter()
        .collect();

    // Expected but not yet a member: add.
    for uid in expected_user_ids.difference(&current_member_ids) {
        teams::add_member(&mut **tx, team_id, *uid, TeamRole::Member)
            .await
            .with_context(|| format!("Failed to add member {} to team", uid))?;
    }
    // Currently a member but no longer expected: remove.
    for uid in current_member_ids.difference(&expected_user_ids) {
        teams::remove_all_user_roles(&mut **tx, team_id, *uid)
            .await
            .with_context(|| format!("Failed to remove member {} from team", uid))?;
    }
    Ok(())
}
pub async fn run_entra_sync_loop(
pool: PgPool,
auth_settings: crate::server::settings::AuthSettings,
) {
let tenant_id = match extract_tenant_id(&auth_settings.issuer) {
Ok(t) => t,
Err(e) => {
tracing::error!("Entra active sync cannot start: {:?}", e);
return;
}
};
let interval_secs = auth_settings.active_sync_interval_secs;
let mut client = GraphClient::new(
&tenant_id,
&auth_settings.client_id,
&auth_settings.client_secret,
);
let mut interval = tokio::time::interval(Duration::from_secs(interval_secs));
tracing::info!(
"Entra active sync started (tenant={}, interval={}s)",
tenant_id,
interval_secs
);
let mut shutdown = std::pin::pin!(async {
let ctrl_c = tokio::signal::ctrl_c();
#[cfg(unix)]
let terminate = async {
tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
.expect("failed to install SIGTERM handler")
.recv()
.await;
};
#[cfg(not(unix))]
let terminate = std::future::pending::<()>();
tokio::select! {
_ = ctrl_c => {}
_ = terminate => {}
}
});
loop {
tokio::select! {
_ = interval.tick() => {}
_ = &mut shutdown => {
tracing::info!("Entra active sync shutting down");
break;
}
}
tracing::debug!("Running Entra active sync cycle");
if let Err(e) = sync_once(&pool, &mut client).await {
tracing::error!("Entra active sync failed: {:?}", e);
}
tracing::info!("Next Entra active sync in {}s", interval_secs);
}
}
/// Returns the longest prefix of `s` that is at most `max_len` bytes long
/// and ends on a UTF-8 character boundary.
///
/// `s` is returned unchanged when it already fits.
fn truncate_string(s: &str, max_len: usize) -> &str {
    if max_len >= s.len() {
        return s;
    }
    // Walk backwards from `max_len` to the nearest boundary. Index 0 is
    // always a boundary, so the search always finds one.
    let cut = (0..=max_len)
        .rev()
        .find(|&idx| s.is_char_boundary(idx))
        .unwrap_or(0);
    &s[..cut]
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Extracts the tenant from `issuer`, panicking if extraction fails.
    #[track_caller]
    fn tenant_of(issuer: &str) -> String {
        extract_tenant_id(issuer).unwrap()
    }

    /// Asserts that `input` sanitizes to exactly `expected`.
    #[track_caller]
    fn assert_sanitized(input: &str, expected: &str) {
        assert_eq!(sanitize_team_name(input), expected);
    }

    // ---- extract_tenant_id -------------------------------------------------

    #[test]
    fn test_extract_tenant_id_v2() {
        let tenant = tenant_of("https://login.microsoftonline.com/my-tenant-id/v2.0");
        assert_eq!(tenant, "my-tenant-id");
    }

    #[test]
    fn test_extract_tenant_id_no_v2() {
        let tenant = tenant_of("https://login.microsoftonline.com/my-tenant-id");
        assert_eq!(tenant, "my-tenant-id");
    }

    #[test]
    fn test_extract_tenant_id_sts() {
        let tenant = tenant_of("https://sts.windows.net/my-tenant-id/");
        assert_eq!(tenant, "my-tenant-id");
    }

    #[test]
    fn test_extract_tenant_id_with_trailing_slashes() {
        let tenant = tenant_of("https://login.microsoftonline.com/abc-123/v2.0/");
        assert_eq!(tenant, "abc-123");
    }

    #[test]
    fn test_extract_tenant_id_guid() {
        let tenant =
            tenant_of("https://login.microsoftonline.com/550e8400-e29b-41d4-a716-446655440000/v2.0");
        assert_eq!(tenant, "550e8400-e29b-41d4-a716-446655440000");
    }

    #[test]
    fn test_extract_tenant_id_custom_domain() {
        let tenant = tenant_of("https://login.microsoftonline.com/contoso.onmicrosoft.com/v2.0");
        assert_eq!(tenant, "contoso.onmicrosoft.com");
    }

    // ---- sanitize_team_name ------------------------------------------------

    #[test]
    fn test_sanitize_simple() {
        assert_sanitized("engineering", "engineering");
    }

    #[test]
    fn test_sanitize_uppercase() {
        assert_sanitized("Engineering", "engineering");
    }

    #[test]
    fn test_sanitize_spaces() {
        assert_sanitized("Engineering Team", "engineering-team");
    }

    #[test]
    fn test_sanitize_underscores() {
        assert_sanitized("dev_ops", "dev-ops");
    }

    #[test]
    fn test_sanitize_dots() {
        assert_sanitized("team.name", "team-name");
    }

    #[test]
    fn test_sanitize_special_chars() {
        assert_sanitized("Team (Special) #1!", "team-special-1");
    }

    #[test]
    fn test_sanitize_multiple_hyphens() {
        assert_sanitized("a---b", "a-b");
    }

    #[test]
    fn test_sanitize_leading_trailing_hyphens() {
        assert_sanitized("-team-", "team");
    }

    #[test]
    fn test_sanitize_mixed() {
        assert_sanitized("Rise DevOps - Production", "rise-devops-production");
    }

    #[test]
    fn test_sanitize_empty_result() {
        assert_sanitized("!!!", "");
    }

    #[test]
    fn test_sanitize_numbers() {
        assert_sanitized("Team 42", "team-42");
    }
}