use async_trait::async_trait;
use serde_json::json;
use crate::{
error::Result,
oidc_provider::OidcProvider,
provider::{OAuthProvider, TokenResponse, UserInfo},
};
/// OAuth provider for Ory, layered on a generic [`OidcProvider`].
///
/// The OIDC protocol steps are delegated to `oidc`; this type adds Ory-specific
/// post-processing of user-info claims (group → role mapping, issuer, org id).
#[derive(Debug)]
pub struct OryOAuth {
    /// Underlying OIDC client that every protocol operation is delegated to.
    oidc: OidcProvider,
    /// Issuer URL this provider was configured with; echoed into user claims as `ory_issuer`.
    issuer_url: String,
}
impl OryOAuth {
    /// Creates an Ory OAuth provider backed by a generic OIDC provider.
    ///
    /// `ory_issuer_url` is handed to [`OidcProvider::new`] and also retained so it
    /// can later be exposed as the `ory_issuer` claim.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by [`OidcProvider::new`].
    pub async fn new(
        client_id: String,
        client_secret: String,
        ory_issuer_url: String,
        redirect_uri: String,
    ) -> Result<Self> {
        let oidc =
            OidcProvider::new("ory", &ory_issuer_url, &client_id, &client_secret, &redirect_uri)
                .await?;
        // The owned issuer URL is moved in directly; no clone is needed.
        Ok(Self {
            oidc,
            issuer_url: ory_issuer_url,
        })
    }

    /// Reads the `groups` claim as a list of group names.
    ///
    /// Accepts either a JSON array (non-string elements are skipped) or a single
    /// string (treated as a one-element list). A missing claim, or any other JSON
    /// shape, yields an empty list.
    fn extract_groups(raw_claims: &serde_json::Value) -> Vec<String> {
        match raw_claims.get("groups") {
            Some(serde_json::Value::Array(items)) => items
                .iter()
                .filter_map(serde_json::Value::as_str)
                .map(str::to_string)
                .collect(),
            Some(serde_json::Value::String(single)) => vec![single.clone()],
            _ => Vec::new(),
        }
    }

    /// Maps Ory group names to FraiseQL roles (`admin`, `operator`, `viewer`).
    ///
    /// Matching is case-insensitive. Well-known group names map directly; any
    /// other group containing `fraiseql` is mapped by the first of `admin`,
    /// `operator`, `viewer` that it contains. Unrecognised groups are dropped.
    ///
    /// Each role appears at most once in the result, in the order it was first
    /// produced from `ory_groups`.
    pub fn map_ory_groups_to_fraiseql(ory_groups: Vec<String>) -> Vec<String> {
        let mut roles: Vec<String> = Vec::new();
        for role in ory_groups.iter().filter_map(|group| Self::map_ory_group(group)) {
            // Several groups can map to the same role (e.g. `admin` and
            // `administrators`); keep the role list free of duplicates.
            if !roles.iter().any(|existing| existing == role) {
                roles.push(role.to_string());
            }
        }
        roles
    }

    /// Maps a single Ory group name to a FraiseQL role, if it is recognised.
    fn map_ory_group(group: &str) -> Option<&'static str> {
        let lower = group.to_lowercase();
        match lower.as_str() {
            "admin" | "ory-admin" | "administrators" => Some("admin"),
            "operator" | "ory-operator" | "operators" => Some("operator"),
            "viewer" | "ory-viewer" | "viewers" | "user" => Some("viewer"),
            // NOTE(review): substring matching is permissive — e.g.
            // `fraiseql:viewer-not-admin` maps to `admin` because `admin` is
            // checked first. Consider matching exact `fraiseql:<role>` forms.
            other if other.contains("fraiseql") => {
                if other.contains("admin") {
                    Some("admin")
                } else if other.contains("operator") {
                    Some("operator")
                } else if other.contains("viewer") {
                    Some("viewer")
                } else {
                    None
                }
            }
            _ => None,
        }
    }

    /// Determines the organisation id for a user.
    ///
    /// The claim is tried first. A non-empty string `org_id` claim is used as-is,
    /// and a numeric `org_id` claim is used in its decimal text form.
    ///
    /// Otherwise the id falls back to the domain of `email`, taken as the text
    /// after the last `@`. Returns `None` when neither source yields a non-empty
    /// value (e.g. empty email, no `@`, or `user@`).
    fn extract_org_id(raw_claims: &serde_json::Value, email: &str) -> Option<String> {
        let from_claim = match raw_claims.get("org_id") {
            Some(serde_json::Value::String(id)) if !id.is_empty() => Some(id.clone()),
            // Keep numeric ids instead of silently replacing them with the
            // email domain.
            Some(serde_json::Value::Number(id)) => Some(id.to_string()),
            _ => None,
        };
        from_claim.or_else(|| {
            // The domain of an address follows its last `@`; an empty domain
            // is treated as absent rather than producing an empty org id.
            email
                .rsplit_once('@')
                .map(|(_, domain)| domain)
                .filter(|domain| !domain.is_empty())
                .map(str::to_string)
        })
    }
}
#[async_trait]
impl OAuthProvider for OryOAuth {
    /// Provider identifier; always `"ory"`.
    fn name(&self) -> &'static str {
        "ory"
    }

    /// Builds the authorization URL for `state` by delegating to the OIDC client.
    fn authorization_url(&self, state: &str) -> String {
        self.oidc.authorization_url(state)
    }

    /// Exchanges an authorization `code` for tokens via the OIDC client.
    async fn exchange_code(&self, code: &str) -> Result<TokenResponse> {
        self.oidc.exchange_code(code).await
    }

    /// Fetches user info through the OIDC client and enriches its raw claims.
    ///
    /// The following claims are written into `raw_claims`:
    /// - `ory_groups`: the groups read from the `groups` claim.
    /// - `ory_roles`: those groups mapped to FraiseQL roles.
    /// - `ory_issuer`: the configured issuer URL.
    /// - `org_id`: written whenever an organisation id can be determined.
    ///
    /// NOTE(review): the `value["key"] = ...` assignments follow serde_json's
    /// `IndexMut` rules and would panic if `raw_claims` were neither an object
    /// nor null — confirm `OidcProvider::user_info` always yields an object.
    async fn user_info(&self, access_token: &str) -> Result<UserInfo> {
        let mut info = self.oidc.user_info(access_token).await?;

        let ory_groups = Self::extract_groups(&info.raw_claims);
        // Serialize the groups before handing ownership to the role mapper,
        // which avoids cloning the list.
        let groups_claim = json!(ory_groups);
        let ory_roles = Self::map_ory_groups_to_fraiseql(ory_groups);
        let roles_claim = json!(ory_roles);

        let claims = &mut info.raw_claims;
        claims["ory_groups"] = groups_claim;
        claims["ory_roles"] = roles_claim;
        claims["ory_issuer"] = json!(&self.issuer_url);

        let derived_org = Self::extract_org_id(&info.raw_claims, &info.email);
        if let Some(org_id) = derived_org {
            info.raw_claims["org_id"] = json!(org_id);
        }

        Ok(info)
    }

    /// Refreshes tokens using `refresh_token` via the OIDC client.
    async fn refresh_token(&self, refresh_token: &str) -> Result<TokenResponse> {
        self.oidc.refresh_token(refresh_token).await
    }

    /// Revokes `token` via the OIDC client.
    async fn revoke_token(&self, token: &str) -> Result<()> {
        self.oidc.revoke_token(token).await
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Converts string literals into the owned form the mapping API takes.
    fn owned(items: &[&str]) -> Vec<String> {
        items.iter().map(|item| String::from(*item)).collect()
    }

    /// True when `list` holds an entry equal to `wanted`.
    fn has(list: &[String], wanted: &str) -> bool {
        list.iter().any(|entry| entry == wanted)
    }

    /// Asserts that `roles` has exactly three entries including admin, operator and viewer.
    fn assert_three_core_roles(roles: &[String]) {
        assert_eq!(roles.len(), 3);
        for role in &["admin", "operator", "viewer"] {
            assert!(has(roles, role), "expected role {:?} in {:?}", role, roles);
        }
    }

    #[test]
    fn test_extract_groups_from_array() {
        let groups =
            OryOAuth::extract_groups(&json!({ "groups": ["admin", "operators", "viewers"] }));
        assert_eq!(groups.len(), 3);
        assert!(has(&groups, "admin"));
        assert!(has(&groups, "operators"));
    }

    #[test]
    fn test_extract_groups_from_string() {
        let groups = OryOAuth::extract_groups(&json!({ "groups": "admin" }));
        assert_eq!(groups, owned(&["admin"]));
    }

    #[test]
    fn test_extract_groups_missing() {
        assert!(OryOAuth::extract_groups(&json!({})).is_empty());
    }

    #[test]
    fn test_map_ory_groups_to_fraiseql() {
        let input = owned(&["admin", "ory-operator", "user", "unknown"]);
        assert_three_core_roles(&OryOAuth::map_ory_groups_to_fraiseql(input));
    }

    #[test]
    fn test_map_ory_groups_case_insensitive() {
        let input = owned(&["ADMIN", "Operator", "VIEWER"]);
        assert_three_core_roles(&OryOAuth::map_ory_groups_to_fraiseql(input));
    }

    #[test]
    fn test_map_ory_groups_keto_patterns() {
        let input = owned(&[
            "fraiseql:admin",
            "fraiseql:operator",
            "fraiseql:viewer",
            "other:role",
        ]);
        assert_three_core_roles(&OryOAuth::map_ory_groups_to_fraiseql(input));
    }

    #[test]
    fn test_extract_org_id_from_claim() {
        let org = OryOAuth::extract_org_id(&json!({ "org_id": "acme-corp" }), "user@example.com");
        assert_eq!(org.as_deref(), Some("acme-corp"));
    }

    #[test]
    fn test_extract_org_id_from_email_domain() {
        let org = OryOAuth::extract_org_id(&json!({}), "user@example.com");
        assert_eq!(org.as_deref(), Some("example.com"));
    }

    #[test]
    fn test_extract_org_id_missing() {
        assert_eq!(OryOAuth::extract_org_id(&json!({}), ""), None);
    }

    #[test]
    fn test_extract_all_roles_and_org() {
        let claims = json!({
            "groups": ["admin", "operators"],
            "org_id": "my-org"
        });

        let roles = OryOAuth::map_ory_groups_to_fraiseql(OryOAuth::extract_groups(&claims));
        assert_eq!(roles.len(), 2);
        assert!(has(&roles, "admin"));
        assert!(has(&roles, "operator"));

        let org = OryOAuth::extract_org_id(&claims, "user@example.com");
        assert_eq!(org.as_deref(), Some("my-org"));
    }
}