Skip to main content

dispatcher_usage/
dispatcher_usage.rs

1use async_trait::async_trait;
2use jsonwebtoken::Header;
3use modkit_auth::claims::Permission;
4use modkit_auth::plugin_traits::{ClaimsPlugin, KeyProvider};
5use modkit_auth::validation::{
6    extract_audiences, extract_string, parse_timestamp, parse_uuid_from_value,
7};
8use modkit_auth::{
9    AuthConfig, AuthDispatcher, AuthModeConfig, Claims, ClaimsError, PluginConfig, PluginRegistry,
10    ValidationConfig,
11};
12use serde_json::Value;
13use std::collections::HashMap;
14use std::sync::Arc;
15use time::{Duration, OffsetDateTime};
16use uuid::Uuid;
17
18/// Minimal claims plugin that converts raw JSON into strongly typed `Claims`.
19struct DemoClaimsPlugin;
20
21impl ClaimsPlugin for DemoClaimsPlugin {
22    fn name(&self) -> &'static str {
23        "demo"
24    }
25
26    fn normalize(&self, raw: &Value) -> Result<Claims, ClaimsError> {
27        let issuer_value = raw
28            .get("iss")
29            .ok_or_else(|| ClaimsError::MissingClaim("iss".to_owned()))?;
30        let issuer = extract_string(issuer_value, "iss")?;
31
32        let sub_value = raw
33            .get("sub")
34            .ok_or_else(|| ClaimsError::MissingClaim("sub".to_owned()))?;
35        let subject = parse_uuid_from_value(sub_value, "sub")?;
36
37        let audiences = raw.get("aud").map(extract_audiences).unwrap_or_default();
38
39        let expires_at = raw
40            .get("exp")
41            .map(|value| parse_timestamp(value, "exp"))
42            .transpose()?;
43
44        let not_before = raw
45            .get("nbf")
46            .map(|value| parse_timestamp(value, "nbf"))
47            .transpose()?;
48
49        let issued_at = raw
50            .get("iat")
51            .map(|value| parse_timestamp(value, "iat"))
52            .transpose()?;
53
54        let jwt_id = raw
55            .get("jti")
56            .and_then(|v| v.as_str())
57            .map(ToString::to_string);
58
59        let tenant_id_value = raw
60            .get("tenant_id")
61            .ok_or_else(|| ClaimsError::MissingClaim("tenant_id".to_owned()))?;
62        let tenant_id = parse_uuid_from_value(tenant_id_value, "tenant_id")?;
63
64        let permissions: Vec<Permission> = raw
65            .get("roles")
66            .and_then(Value::as_array)
67            .map(|arr| {
68                arr.iter()
69                    .filter_map(|value| value.as_str())
70                    .filter_map(|role| {
71                        if let Some(pos) = role.rfind(':') {
72                            Permission::builder()
73                                .resource_pattern(&role[..pos])
74                                .action(&role[pos + 1..])
75                                .build()
76                                .ok()
77                        } else {
78                            Permission::builder()
79                                .resource_pattern(role)
80                                .action("*")
81                                .build()
82                                .ok()
83                        }
84                    })
85                    .collect()
86            })
87            .unwrap_or_default();
88
89        Ok(Claims {
90            issuer,
91            subject,
92            audiences,
93            expires_at,
94            not_before,
95            issued_at,
96            jwt_id,
97            tenant_id,
98            permissions,
99            extras: serde_json::Map::new(),
100        })
101    }
102}
103
104/// Static key provider that skips signature validation for demonstration purposes.
105struct StaticKeyProvider {
106    claims: Value,
107}
108
109impl StaticKeyProvider {
110    fn new(claims: Value) -> Self {
111        Self { claims }
112    }
113}
114
115#[async_trait]
116impl KeyProvider for StaticKeyProvider {
117    fn name(&self) -> &'static str {
118        "static"
119    }
120
121    async fn validate_and_decode(&self, _token: &str) -> Result<(Header, Value), ClaimsError> {
122        Ok((Header::default(), self.claims.clone()))
123    }
124}
125
126#[tokio::main]
127async fn main() -> Result<(), Box<dyn std::error::Error>> {
128    let mut plugins = PluginRegistry::default();
129    plugins.register("demo", Arc::new(DemoClaimsPlugin));
130
131    let mut plugin_configs = HashMap::new();
132    plugin_configs.insert(
133        "demo".to_owned(),
134        PluginConfig::Oidc {
135            tenant_claim: "tenants".to_owned(),
136            roles_claim: "roles".to_owned(),
137        },
138    );
139
140    let config = AuthConfig {
141        mode: AuthModeConfig {
142            provider: "demo".to_owned(),
143        },
144        issuers: vec!["https://issuer.local".to_owned()],
145        audiences: vec!["demo-api".to_owned()],
146        plugins: plugin_configs,
147        ..AuthConfig::default()
148    };
149
150    let validation = ValidationConfig {
151        allowed_issuers: config.issuers.clone(),
152        allowed_audiences: config.audiences.clone(),
153        leeway_seconds: config.leeway_seconds,
154        require_uuid_subject: true,
155        require_uuid_tenants: true,
156    };
157
158    let subject = Uuid::new_v4();
159    let tenant = Uuid::new_v4();
160    let expires_at = OffsetDateTime::now_utc() + Duration::minutes(15);
161
162    let raw_claims = serde_json::json!({
163        "iss": "https://issuer.local",
164        "sub": subject.to_string(),
165        "aud": ["demo-api"],
166        "exp": expires_at.unix_timestamp(),
167        "tenant_id": tenant.to_string(),
168        "roles": ["viewer:read"]
169    });
170
171    let dispatcher = AuthDispatcher::new(validation, &config, &plugins)?
172        .with_key_provider(Arc::new(StaticKeyProvider::new(raw_claims)));
173
174    let claims = dispatcher.validate_jwt("demo-token").await?;
175    let perm_list = if claims.permissions.is_empty() {
176        "none".to_owned()
177    } else {
178        claims
179            .permissions
180            .iter()
181            .map(|p| format!("{}:{}", p.resource_pattern(), p.action()))
182            .collect::<Vec<_>>()
183            .join(", ")
184    };
185    println!(
186        "Validated token for subject {} with permissions {}",
187        claims.subject, perm_list
188    );
189
190    Ok(())
191}