Skip to main content

openauth_plugins/jwt/
options.rs

1use std::fmt;
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::Arc;
5
6use openauth_core::context::AuthContext;
7use openauth_core::db::{Session, User};
8use openauth_core::error::OpenAuthError;
9
10use super::{Jwk, JwkAlgorithm, JwtClaims};
11
12pub type JwtClaimsFuture<'a> =
13    Pin<Box<dyn Future<Output = Result<JwtClaims, OpenAuthError>> + Send + 'a>>;
14pub type JwtStringFuture<'a> =
15    Pin<Box<dyn Future<Output = Result<String, OpenAuthError>> + Send + 'a>>;
16pub type JwtJwksFuture<'a> =
17    Pin<Box<dyn Future<Output = Result<Vec<Jwk>, OpenAuthError>> + Send + 'a>>;
18pub type JwtJwkFuture<'a> = Pin<Box<dyn Future<Output = Result<Jwk, OpenAuthError>> + Send + 'a>>;
19
20pub type JwtDefinePayloadHandler =
21    Arc<dyn for<'a> Fn(&'a JwtSessionContext) -> JwtClaimsFuture<'a> + Send + Sync>;
22pub type JwtGetSubjectHandler =
23    Arc<dyn for<'a> Fn(&'a JwtSessionContext) -> JwtStringFuture<'a> + Send + Sync>;
24pub type JwtSignHandler = Arc<dyn Fn(JwtClaims) -> JwtStringFuture<'static> + Send + Sync>;
25pub type JwtGetJwksHandler =
26    Arc<dyn for<'a> Fn(&'a AuthContext) -> JwtJwksFuture<'a> + Send + Sync>;
27pub type JwtCreateJwkHandler =
28    Arc<dyn for<'a> Fn(&'a AuthContext, Jwk) -> JwtJwkFuture<'a> + Send + Sync>;
29
30#[derive(Debug, Clone, PartialEq, Eq)]
31pub struct JwtSessionContext {
32    pub session: Session,
33    pub user: User,
34}
35
36#[derive(Clone, Default)]
37pub struct JwtOptions {
38    pub jwks: JwtJwksOptions,
39    pub jwt: JwtSigningOptions,
40    pub adapter: JwtAdapterOptions,
41    pub disable_setting_jwt_header: bool,
42}
43
44#[derive(Debug, Clone)]
45pub struct JwtJwksOptions {
46    pub remote_url: Option<String>,
47    pub key_pair_algorithm: Option<JwkAlgorithm>,
48    pub rsa_modulus_length: Option<u32>,
49    pub disable_private_key_encryption: bool,
50    pub rotation_interval: Option<i64>,
51    pub grace_period: i64,
52    pub jwks_path: String,
53}
54
55impl Default for JwtJwksOptions {
56    fn default() -> Self {
57        Self {
58            remote_url: None,
59            key_pair_algorithm: Some(JwkAlgorithm::EdDsa),
60            rsa_modulus_length: None,
61            disable_private_key_encryption: false,
62            rotation_interval: None,
63            grace_period: 60 * 60 * 24 * 30,
64            jwks_path: "/jwks".to_owned(),
65        }
66    }
67}
68
69#[derive(Clone, Default)]
70pub struct JwtSigningOptions {
71    pub issuer: Option<String>,
72    pub audience: Option<Vec<String>>,
73    pub expiration_time: Option<super::TimeInput>,
74    pub define_payload: Option<JwtDefinePayloadHandler>,
75    pub get_subject: Option<JwtGetSubjectHandler>,
76    pub sign: Option<JwtSignHandler>,
77}
78
79impl fmt::Debug for JwtSigningOptions {
80    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
81        formatter
82            .debug_struct("JwtSigningOptions")
83            .field("issuer", &self.issuer)
84            .field("audience", &self.audience)
85            .field("expiration_time", &self.expiration_time)
86            .field(
87                "define_payload",
88                &self.define_payload.as_ref().map(|_| "<define-payload>"),
89            )
90            .field(
91                "get_subject",
92                &self.get_subject.as_ref().map(|_| "<get-subject>"),
93            )
94            .field("sign", &self.sign.as_ref().map(|_| "<sign-handler>"))
95            .finish()
96    }
97}
98
99#[derive(Clone, Default)]
100pub struct JwtAdapterOptions {
101    pub get_jwks: Option<JwtGetJwksHandler>,
102    pub create_jwk: Option<JwtCreateJwkHandler>,
103}
104
105impl fmt::Debug for JwtAdapterOptions {
106    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
107        formatter
108            .debug_struct("JwtAdapterOptions")
109            .field("get_jwks", &self.get_jwks.as_ref().map(|_| "<get-jwks>"))
110            .field(
111                "create_jwk",
112                &self.create_jwk.as_ref().map(|_| "<create-jwk>"),
113            )
114            .finish()
115    }
116}
117
118impl JwtOptions {
119    pub fn validate(&self) -> Result<(), OpenAuthError> {
120        if self.jwt.sign.is_some() && self.jwks.remote_url.is_none() {
121            return Err(OpenAuthError::InvalidConfig(
122                "options.jwks.remoteUrl must be set when using options.jwt.sign".to_owned(),
123            ));
124        }
125        if self.jwks.remote_url.is_some() && self.jwks.key_pair_algorithm.is_none() {
126            return Err(OpenAuthError::InvalidConfig(
127                "options.jwks.keyPairConfig.alg must be specified when using remoteUrl".to_owned(),
128            ));
129        }
130        if let Some(modulus_length) = self.jwks.rsa_modulus_length {
131            if modulus_length < 2048 {
132                return Err(OpenAuthError::InvalidConfig(
133                    "options.jwks.keyPairConfig.modulusLength must be at least 2048".to_owned(),
134                ));
135            }
136        }
137        let path = &self.jwks.jwks_path;
138        if path.is_empty() || !path.starts_with('/') || path.contains("..") {
139            return Err(OpenAuthError::InvalidConfig(
140                "options.jwks.jwksPath must be a non-empty string starting with '/' and not contain '..'"
141                    .to_owned(),
142            ));
143        }
144        Ok(())
145    }
146
147    pub(crate) fn algorithm(&self) -> JwkAlgorithm {
148        self.jwks.key_pair_algorithm.unwrap_or(JwkAlgorithm::EdDsa)
149    }
150}