Skip to main content

oauth2_test_server/
config.rs

1use serde_json::{json, Value};
2use std::collections::HashSet;
3
/// Server-level configuration for the OAuth2 / OIDC issuer.
///
/// Construct via [`IssuerConfig::default()`] and override individual fields,
/// or build one from scratch for full control.
///
/// The struct is also deserializable (see [`IssuerConfig::from_env`] and, with
/// the `config` feature, `IssuerConfig::from_file` / `from_yaml` / `from_toml`).
/// Every field carries a `#[serde(default ...)]` attribute, so any field that is
/// absent from the input falls back to its default value.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct IssuerConfig {
    /// URL scheme of the issuer. Default: `"http"`.
    #[serde(default = "default_scheme")]
    pub scheme: String,
    /// Host name of the issuer. Default: `"localhost"`.
    #[serde(default = "default_host")]
    pub host: String,
    /// Port of the issuer. Default: `0`; the `Default` impl notes that 0 means
    /// "OS assigns a random free port".
    #[serde(default)]
    pub port: u16,

    // OIDC / OAuth capabilities
    /// Scopes accepted by [`IssuerConfig::validate_scope`] and advertised as
    /// `scopes_supported` in the discovery document.
    #[serde(default = "default_scopes_supported")]
    pub scopes_supported: HashSet<String>,
    /// Advertised as `claims_supported` in the discovery document (order kept).
    #[serde(default = "default_claims_supported")]
    pub claims_supported: Vec<String>,
    /// Grant types accepted by [`IssuerConfig::validate_grant_type`] and
    /// advertised as `grant_types_supported` in the discovery document.
    #[serde(default = "default_grant_types_supported")]
    pub grant_types_supported: HashSet<String>,
    /// Advertised as `response_types_supported` in the discovery document.
    #[serde(default = "default_response_types_supported")]
    pub response_types_supported: HashSet<String>,
    /// Advertised as `token_endpoint_auth_methods_supported` in the discovery
    /// document.
    #[serde(default = "default_token_endpoint_auth_methods_supported")]
    pub token_endpoint_auth_methods_supported: HashSet<String>,
    /// PKCE methods, advertised as `code_challenge_methods_supported` in the
    /// discovery document.
    #[serde(default = "default_code_challenge_methods_supported")]
    pub code_challenge_methods_supported: HashSet<String>,
    /// Advertised as `subject_types_supported` in the discovery document.
    #[serde(default = "default_subject_types_supported")]
    pub subject_types_supported: Vec<String>,
    /// Advertised as `id_token_signing_alg_values_supported` in the discovery
    /// document.
    #[serde(default = "default_id_token_signing_alg_values_supported")]
    pub id_token_signing_alg_values_supported: Vec<String>,

    /// Presumably controls whether dynamic client registration (DCR) issues a
    /// client secret; the consuming handler is not in this file — confirm there.
    /// Default: true
    #[serde(default = "default_generate_client_secret")]
    pub generate_client_secret_for_dcr: bool,

    /// CORS origins to allow. If empty, all origins are allowed.
    #[serde(default)]
    pub allowed_origins: Vec<String>,

    /// Default `sub` claim value used when no user is logged in.
    /// Default: `"test-user-123"`
    #[serde(default = "default_user_id")]
    pub default_user_id: String,

    /// Require `state` parameter in authorization requests (RFC 6749 compliance).
    /// Default: true
    #[serde(default = "default_true")]
    pub require_state: bool,

    /// Access token expiration time in seconds.
    /// Default: 3600 (1 hour)
    #[serde(default = "default_access_token_expires")]
    pub access_token_expires_in: u64,

    /// Refresh token expiration time in seconds.
    /// Default: 86400 * 30 (30 days)
    #[serde(default = "default_refresh_token_expires")]
    pub refresh_token_expires_in: u64,

    /// Authorization code expiration time in seconds.
    /// Default: 600 (10 minutes)
    #[serde(default = "default_code_expires")]
    pub authorization_code_expires_in: u64,

    /// Cleanup interval for expired tokens/codes in seconds.
    /// Default: 300 (5 minutes). Set to 0 to disable.
    #[serde(default = "default_cleanup_interval")]
    pub cleanup_interval_secs: u64,
}
71
/// Serde default for boolean flags that stay on unless explicitly disabled.
fn default_true() -> bool {
    true
}

/// Default access-token lifetime in seconds: one hour.
fn default_access_token_expires() -> u64 {
    60 * 60
}

/// Default refresh-token lifetime in seconds: thirty days.
fn default_refresh_token_expires() -> u64 {
    30 * 24 * 60 * 60
}

/// Default authorization-code lifetime in seconds: ten minutes.
fn default_code_expires() -> u64 {
    10 * 60
}

/// Default interval between cleanups of expired tokens/codes, in seconds:
/// five minutes.
fn default_cleanup_interval() -> u64 {
    5 * 60
}
87
/// Serde default for `IssuerConfig::scheme`.
fn default_scheme() -> String {
    String::from("http")
}

/// Serde default for `IssuerConfig::host`.
fn default_host() -> String {
    String::from("localhost")
}

/// Serde default for `IssuerConfig::generate_client_secret_for_dcr`: enabled.
fn default_generate_client_secret() -> bool {
    true
}

/// Serde default for `IssuerConfig::default_user_id`.
fn default_user_id() -> String {
    String::from("test-user-123")
}
100
/// Serde default for `IssuerConfig::scopes_supported`: the standard OIDC
/// scopes plus `offline_access`.
fn default_scopes_supported() -> HashSet<String> {
    let scopes = [
        "openid",
        "profile",
        "email",
        "offline_access",
        "address",
        "phone",
    ];
    HashSet::from(scopes.map(String::from))
}

/// Serde default for `IssuerConfig::claims_supported` (order is preserved).
fn default_claims_supported() -> Vec<String> {
    let claims = [
        "sub",
        "name",
        "given_name",
        "family_name",
        "email",
        "email_verified",
        "picture",
        "locale",
    ];
    Vec::from(claims.map(String::from))
}

/// Serde default for `IssuerConfig::grant_types_supported`.
fn default_grant_types_supported() -> HashSet<String> {
    let grants = ["authorization_code", "refresh_token", "client_credentials"];
    HashSet::from(grants.map(String::from))
}

/// Serde default for `IssuerConfig::response_types_supported`.
fn default_response_types_supported() -> HashSet<String> {
    let responses = ["code", "token", "id_token"];
    HashSet::from(responses.map(String::from))
}

/// Serde default for `IssuerConfig::token_endpoint_auth_methods_supported`.
fn default_token_endpoint_auth_methods_supported() -> HashSet<String> {
    let methods = [
        "client_secret_basic",
        "client_secret_post",
        "none",
        "private_key_jwt",
    ];
    HashSet::from(methods.map(String::from))
}

/// Serde default for `IssuerConfig::code_challenge_methods_supported` (PKCE).
fn default_code_challenge_methods_supported() -> HashSet<String> {
    HashSet::from(["plain", "S256"].map(String::from))
}

/// Serde default for `IssuerConfig::subject_types_supported`.
fn default_subject_types_supported() -> Vec<String> {
    vec![String::from("public")]
}

/// Serde default for `IssuerConfig::id_token_signing_alg_values_supported`.
fn default_id_token_signing_alg_values_supported() -> Vec<String> {
    vec![String::from("RS256")]
}
158
159impl Default for IssuerConfig {
160    fn default() -> Self {
161        let mut scopes = HashSet::new();
162        scopes.extend([
163            "openid".into(),
164            "profile".into(),
165            "email".into(),
166            "offline_access".into(),
167            "address".into(),
168            "phone".into(),
169        ]);
170
171        let mut grants = HashSet::new();
172        grants.extend([
173            "authorization_code".into(),
174            "refresh_token".into(),
175            "client_credentials".into(),
176        ]);
177
178        let mut auth_methods = HashSet::new();
179        auth_methods.extend([
180            "client_secret_basic".into(),
181            "client_secret_post".into(),
182            "none".into(),
183            "private_key_jwt".into(),
184        ]);
185
186        Self {
187            scheme: "http".into(),
188            host: "localhost".into(),
189            port: 0, // 0 = OS assigns a random free port
190            scopes_supported: scopes,
191            claims_supported: vec![
192                "sub".into(),
193                "name".into(),
194                "given_name".into(),
195                "family_name".into(),
196                "email".into(),
197                "email_verified".into(),
198                "picture".into(),
199                "locale".into(),
200            ],
201            generate_client_secret_for_dcr: true,
202            grant_types_supported: grants,
203            response_types_supported: ["code".into(), "token".into(), "id_token".into()].into(),
204            token_endpoint_auth_methods_supported: auth_methods,
205            code_challenge_methods_supported: ["plain".into(), "S256".into()].into(),
206            subject_types_supported: vec!["public".into()],
207            id_token_signing_alg_values_supported: vec!["RS256".into()],
208            // Empty by default → CorsLayer uses AllowOrigin::any()
209            allowed_origins: vec![],
210            default_user_id: "test-user-123".into(),
211            require_state: true,
212            access_token_expires_in: 3600,
213            refresh_token_expires_in: 86400 * 30,
214            authorization_code_expires_in: 600,
215            cleanup_interval_secs: 300,
216        }
217    }
218}
219
220impl IssuerConfig {
221    /// Load configuration from environment variables, prefixed with `OAUTH_`.
222    pub fn from_env() -> Result<Self, envy::Error> {
223        dotenvy::dotenv().ok();
224        envy::prefixed("OAUTH_").from_env::<Self>()
225    }
226
227    /// Build the OpenID Connect Discovery document for this issuer.
228    pub fn to_discovery_document(&self, issuer: String) -> Value {
229        let iss = issuer;
230        json!({
231            "issuer": iss,
232            "authorization_endpoint": format!("{}/authorize", iss),
233            "token_endpoint": format!("{}/token", iss),
234            "userinfo_endpoint": format!("{}/userinfo", iss),
235            "jwks_uri": format!("{}/.well-known/jwks.json", iss),
236            "registration_endpoint": format!("{}/register", iss),
237            "revocation_endpoint": format!("{}/revoke", iss),
238            "introspection_endpoint": format!("{}/introspect", iss),
239            "scopes_supported": self.scopes_supported.iter().collect::<Vec<_>>(),
240            "claims_supported": &self.claims_supported,
241            "grant_types_supported": self.grant_types_supported.iter().collect::<Vec<_>>(),
242            "response_types_supported": self.response_types_supported.iter().collect::<Vec<_>>(),
243            "token_endpoint_auth_methods_supported": self.token_endpoint_auth_methods_supported.iter().collect::<Vec<_>>(),
244            "code_challenge_methods_supported": self.code_challenge_methods_supported.iter().collect::<Vec<_>>(),
245            "subject_types_supported": &self.subject_types_supported,
246            "id_token_signing_alg_values_supported": &self.id_token_signing_alg_values_supported,
247        })
248    }
249
250    /// Validates that all requested scopes are in `scopes_supported`.
251    /// Returns the original scope string on success, or an error message on failure.
252    pub fn validate_scope(&self, scope: &str) -> Result<String, String> {
253        let requested: HashSet<_> = scope.split_whitespace().map(|s| s.to_string()).collect();
254        let unknown: Vec<_> = requested
255            .difference(&self.scopes_supported)
256            .cloned()
257            .collect();
258        if unknown.is_empty() {
259            Ok(scope.to_string())
260        } else {
261            Err(format!("invalid_scope: {}", unknown.join(" ")))
262        }
263    }
264
265    /// Returns `true` if the given grant type is in `grant_types_supported`.
266    pub fn validate_grant_type(&self, grant: &str) -> bool {
267        self.grant_types_supported.contains(grant)
268    }
269
270    /// Load configuration from a file (YAML or TOML).
271    /// The format is detected from the file extension.
272    ///
273    /// # Example
274    ///
275    /// ```ignore
276    /// let config = IssuerConfig::from_file("config.yaml").unwrap();
277    /// let config = IssuerConfig::from_file("config.toml").unwrap();
278    /// ```
279    #[cfg(feature = "config")]
280    pub fn from_file(path: &std::path::Path) -> Result<Self, ConfigError> {
281        use std::fs;
282
283        let content = fs::read_to_string(path)?;
284        let ext = path.extension().and_then(|e| e.to_str()).unwrap_or("");
285
286        match ext.to_lowercase().as_str() {
287            "yaml" | "yml" => Self::from_yaml(&content),
288            "toml" => Self::from_toml(&content),
289            _ => Err(ConfigError::UnsupportedFormat(ext.to_string())),
290        }
291    }
292
293    /// Load configuration from YAML string.
294    #[cfg(feature = "config")]
295    pub fn from_yaml(yaml: &str) -> Result<Self, ConfigError> {
296        serde_yaml::from_str(yaml).map_err(ConfigError::YamlParseError)
297    }
298
299    /// Load configuration from TOML string.
300    #[cfg(feature = "config")]
301    pub fn from_toml(toml_str: &str) -> Result<Self, ConfigError> {
302        toml::from_str(toml_str).map_err(ConfigError::TomlParseError)
303    }
304}
305
/// Errors produced while loading an [`IssuerConfig`] from a file or string
/// (`from_file`, `from_yaml`, `from_toml`).
///
/// Every variant that wraps another error exposes it through
/// [`std::error::Error::source`], so error-chain reporters can walk to the
/// root cause.
#[cfg(feature = "config")]
#[derive(Debug, thiserror::Error)]
pub enum ConfigError {
    /// The configuration file could not be read.
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),
    /// The content could not be deserialized as YAML into an `IssuerConfig`.
    #[error("YAML parse error: {0}")]
    YamlParseError(#[source] serde_yaml::Error),
    /// The content could not be deserialized as TOML into an `IssuerConfig`.
    #[error("TOML parse error: {0}")]
    TomlParseError(#[source] toml::de::Error),
    /// The file extension is missing (empty string) or is not `yaml`, `yml`
    /// or `toml`; holds the extension as found.
    #[error("Unsupported config format: {0}")]
    UnsupportedFormat(String),
}