auth_framework/server/oauth/
oauth21.rs1use crate::errors::{AuthError, Result};
8use crate::server::core::client_registry::ClientConfig;
9use crate::server::oauth::oauth2::OAuth2Server;
10use crate::storage::core::AuthStorage;
11use std::sync::Arc;
12
/// Security policy that [`OAuth21Server`] layers on top of a plain OAuth 2.0 server.
///
/// Every switch defaults to the strict OAuth 2.1 behaviour (see the [`Default`] impl).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OAuth21SecurityConfig {
    /// Whether PKCE is required for public clients; advertised as `pkce_required` in
    /// the server configuration metadata.
    /// NOTE(review): this module only advertises the flag — confirm where it is enforced.
    pub require_pkce_for_public_clients: bool,
    /// Strip the `implicit` grant and token-only response types from registered clients
    /// and from the advertised server configuration.
    pub disallow_implicit_grant: bool,
    /// Require exact (non-pattern) redirect URI matching.
    /// NOTE(review): not consulted in this module — confirm where it is enforced.
    pub require_exact_redirect_uri_matching: bool,
    /// Reject non-HTTPS redirect URIs at registration, except local development hosts.
    pub require_secure_redirect_uris: bool,
    /// Maximum authorization code lifetime (default 600; presumably seconds — confirm).
    pub max_auth_code_lifetime: u64,
    /// Maximum access token lifetime (default 3600; presumably seconds — confirm).
    pub max_access_token_lifetime: u64,
    /// Require client authentication.
    /// NOTE(review): not consulted in this module — confirm where it is enforced.
    pub require_client_authentication: bool,
    /// Strip the resource-owner `password` grant from registered clients and from the
    /// advertised server configuration.
    pub disallow_password_grant: bool,
}

impl Default for OAuth21SecurityConfig {
    /// Strict profile: every protection enabled, authorization codes capped at 600 and
    /// access tokens at 3600 (units as interpreted by the consumer of these fields).
    fn default() -> Self {
        Self {
            require_pkce_for_public_clients: true,
            disallow_implicit_grant: true,
            require_exact_redirect_uri_matching: true,
            require_secure_redirect_uris: true,
            max_auth_code_lifetime: 600,
            max_access_token_lifetime: 3600,
            require_client_authentication: true,
            disallow_password_grant: true,
        }
    }
}
48
49#[derive(Clone)]
54pub struct OAuth21Server {
55 oauth2_server: Arc<OAuth2Server>,
57 security_config: OAuth21SecurityConfig,
59}
60
61impl OAuth21Server {
62 pub async fn new(
64 security_config: Option<OAuth21SecurityConfig>,
65 storage: Arc<dyn AuthStorage>,
66 ) -> Result<Self> {
67 let oauth2_server = Arc::new(OAuth2Server::new(storage).await?);
68 let security_config = security_config.unwrap_or_default();
69
70 Ok(Self {
71 oauth2_server,
72 security_config,
73 })
74 }
75
76 pub async fn register_client(&self, mut config: ClientConfig) -> Result<ClientConfig> {
78 self.validate_oauth21_client_config(&mut config)?;
80
81 self.oauth2_server.register_client(config).await
83 }
84
85 pub async fn get_client(&self, client_id: &str) -> Result<Option<ClientConfig>> {
87 self.oauth2_server.get_client(client_id).await
88 }
89
90 pub async fn update_client(&self, client_id: &str, config: ClientConfig) -> Result<()> {
92 self.oauth2_server.update_client(client_id, config).await
93 }
94
95 pub async fn delete_client(&self, client_id: &str) -> Result<()> {
97 self.oauth2_server.delete_client(client_id).await
98 }
99
100 pub async fn get_server_configuration(&self) -> Result<serde_json::Value> {
102 let mut config = self.oauth2_server.get_server_configuration().await?;
103
104 if let Some(obj) = config.as_object_mut() {
106 if self.security_config.disallow_implicit_grant {
108 if let Some(grant_types) = obj.get_mut("grant_types_supported")
109 && let Some(grants) = grant_types.as_array_mut()
110 {
111 grants.retain(|g| g.as_str() != Some("implicit"));
112 }
113
114 if let Some(response_types) = obj.get_mut("response_types_supported")
115 && let Some(types) = response_types.as_array_mut()
116 {
117 types.retain(|t| {
118 if let Some(type_str) = t.as_str() {
119 !type_str.contains("token") || type_str.contains("code")
120 } else {
121 true
122 }
123 });
124 }
125 }
126
127 if self.security_config.disallow_password_grant
129 && let Some(grant_types) = obj.get_mut("grant_types_supported")
130 && let Some(grants) = grant_types.as_array_mut()
131 {
132 grants.retain(|g| g.as_str() != Some("password"));
133 }
134
135 obj.insert(
137 "oauth21_compliant".to_string(),
138 serde_json::Value::Bool(true),
139 );
140 obj.insert(
141 "pkce_required".to_string(),
142 serde_json::Value::Bool(self.security_config.require_pkce_for_public_clients),
143 );
144 obj.insert(
145 "implicit_grant_disabled".to_string(),
146 serde_json::Value::Bool(self.security_config.disallow_implicit_grant),
147 );
148 obj.insert(
149 "password_grant_disabled".to_string(),
150 serde_json::Value::Bool(self.security_config.disallow_password_grant),
151 );
152 }
153
154 Ok(config)
155 }
156
157 fn validate_oauth21_client_config(&self, config: &mut ClientConfig) -> Result<()> {
159 if self.security_config.require_secure_redirect_uris {
161 for uri in &config.redirect_uris {
162 if !uri.starts_with("https://")
163 && !uri.starts_with("http://localhost")
164 && !uri.starts_with("http://127.0.0.1")
165 {
166 return Err(AuthError::validation(
167 "OAuth 2.1 requires HTTPS redirect URIs (except localhost)",
168 ));
169 }
170 }
171 }
172
173 if self.security_config.disallow_implicit_grant {
175 config.authorized_grant_types.retain(|g| g != "implicit");
176 config
177 .authorized_response_types
178 .retain(|r| !r.contains("token") || r.contains("code"));
179 }
180
181 if self.security_config.disallow_password_grant {
182 config.authorized_grant_types.retain(|g| g != "password");
183 }
184
185 if config.authorized_grant_types.is_empty() {
187 config
188 .authorized_grant_types
189 .push("authorization_code".to_string());
190 }
191
192 if config.authorized_response_types.is_empty() {
193 config.authorized_response_types.push("code".to_string());
194 }
195
196 Ok(())
197 }
198
199 pub fn get_security_config(&self) -> &OAuth21SecurityConfig {
201 &self.security_config
202 }
203
204 pub fn update_security_config(&mut self, config: OAuth21SecurityConfig) {
206 self.security_config = config;
207 }
208}
209
#[cfg(test)]
mod tests {
    use super::*;
    use crate::server::core::client_registry::{ClientConfig, ClientType};
    use crate::storage::memory::InMemoryStorage;

    /// Builds a server on fresh in-memory storage using the default (strict) policy.
    async fn strict_server() -> OAuth21Server {
        let backing = Arc::new(InMemoryStorage::new());
        OAuth21Server::new(Some(OAuth21SecurityConfig::default()), backing)
            .await
            .unwrap()
    }

    /// Public client named `test_client` with a single redirect URI.
    fn public_client(redirect_uri: &str) -> ClientConfig {
        ClientConfig {
            client_id: "test_client".to_string(),
            client_type: ClientType::Public,
            redirect_uris: vec![redirect_uri.to_string()],
            ..Default::default()
        }
    }

    #[tokio::test]
    async fn test_oauth21_server_creation() {
        let server = strict_server().await;

        let registered = server
            .register_client(public_client("https://example.com/callback"))
            .await
            .unwrap();
        assert_eq!(registered.client_id, "test_client");

        // The implicit grant must be stripped during registration.
        assert!(
            !registered
                .authorized_grant_types
                .iter()
                .any(|grant| grant == "implicit")
        );
        // Every remaining response type either avoids `token` or includes `code`.
        assert!(
            registered
                .authorized_response_types
                .iter()
                .all(|kind| !kind.contains("token") || kind.contains("code"))
        );
    }

    #[tokio::test]
    async fn test_oauth21_security_validations() {
        let server = strict_server().await;

        // Plain HTTP to a non-loopback host is refused.
        let insecure = server
            .register_client(public_client("http://example.com/callback"))
            .await;
        assert!(insecure.is_err());

        // The same client over HTTPS is accepted.
        let secure = server
            .register_client(public_client("https://example.com/callback"))
            .await;
        assert!(secure.is_ok());
    }

    #[tokio::test]
    async fn test_oauth21_server_configuration() {
        let server = strict_server().await;
        let metadata = server.get_server_configuration().await.unwrap();

        for flag in [
            "oauth21_compliant",
            "pkce_required",
            "implicit_grant_disabled",
            "password_grant_disabled",
        ] {
            assert_eq!(metadata[flag], true, "{flag} should be true");
        }

        let advertised = metadata["grant_types_supported"].as_array().unwrap();
        for forbidden in ["implicit", "password"] {
            assert!(
                !advertised.iter().any(|g| g.as_str() == Some(forbidden)),
                "{forbidden} grant must not be advertised"
            );
        }
    }
}
303
304