Skip to main content

auth_framework/server/core/
client_registry.rs

1//! OAuth 2.0 Client Registry Module
2//!
3//! This module implements a client registry for managing OAuth 2.0 clients
4//! including registration, retrieval, and validation.
5
6use crate::errors::{AuthError, Result};
7use crate::storage::core::AuthStorage;
8use crate::storage::memory::InMemoryStorage;
9use std::sync::Arc;
10
11// Re-export canonical types so existing imports like
12// `server::core::client_registry::{ClientType, ClientConfig}` keep working.
13pub use crate::client::{ClientConfig, ClientType};
14
15/// Client Registry for managing OAuth 2.0 clients
16#[derive(Clone)]
17pub struct ClientRegistry {
18    storage: Arc<dyn AuthStorage>,
19}
20
21impl ClientRegistry {
22    /// Create a new client registry
23    pub async fn new(storage: Arc<dyn AuthStorage>) -> Result<Self> {
24        Ok(Self { storage })
25    }
26
27    /// Register a new OAuth 2.0 client
28    pub async fn register_client(&self, config: ClientConfig) -> Result<ClientConfig> {
29        // Validate the client configuration
30        self.validate_client_config(&config)?;
31
32        // Store the client in the storage backend
33        let client_key = format!("oauth_client:{}", config.client_id);
34        let client_data = serde_json::to_string(&config)
35            .map_err(|e| AuthError::internal(format!("Failed to serialize client: {}", e)))?;
36
37        self.storage
38            .store_kv(&client_key, client_data.as_bytes(), None)
39            .await?;
40
41        Ok(config)
42    }
43
44    /// Retrieve a client by ID
45    pub async fn get_client(&self, client_id: &str) -> Result<Option<ClientConfig>> {
46        let client_key = format!("oauth_client:{}", client_id);
47
48        if let Some(client_data) = self.storage.get_kv(&client_key).await? {
49            let client_str = std::str::from_utf8(&client_data)
50                .map_err(|e| AuthError::internal(format!("Invalid UTF-8 in client data: {}", e)))?;
51            let config: ClientConfig = serde_json::from_str(client_str)
52                .map_err(|e| AuthError::internal(format!("Failed to deserialize client: {}", e)))?;
53            Ok(Some(config))
54        } else {
55            Ok(None)
56        }
57    }
58
59    /// Update a client configuration
60    pub async fn update_client(&self, client_id: &str, config: ClientConfig) -> Result<()> {
61        // Ensure the client ID matches
62        if config.client_id != client_id {
63            return Err(AuthError::validation("Client ID mismatch"));
64        }
65
66        // Validate the updated configuration
67        self.validate_client_config(&config)?;
68
69        // Store the updated client
70        let client_key = format!("oauth_client:{}", client_id);
71        let client_data = serde_json::to_string(&config)
72            .map_err(|e| AuthError::internal(format!("Failed to serialize client: {}", e)))?;
73
74        self.storage
75            .store_kv(&client_key, client_data.as_bytes(), None)
76            .await?;
77
78        Ok(())
79    }
80
81    /// Delete a client
82    pub async fn delete_client(&self, client_id: &str) -> Result<()> {
83        let client_key = format!("oauth_client:{}", client_id);
84        self.storage.delete_kv(&client_key).await?;
85        Ok(())
86    }
87
88    /// Validate that a redirect URI is authorized for a client
89    pub async fn validate_redirect_uri(&self, client_id: &str, redirect_uri: &str) -> Result<bool> {
90        if let Some(client) = self.get_client(client_id).await? {
91            Ok(client.redirect_uris.contains(&redirect_uri.to_string()))
92        } else {
93            Ok(false)
94        }
95    }
96
97    /// Validate that a scope is authorized for a client
98    pub async fn validate_scope(&self, client_id: &str, scope: &str) -> Result<bool> {
99        if let Some(client) = self.get_client(client_id).await? {
100            Ok(client.authorized_scopes.contains(&scope.to_string()))
101        } else {
102            Ok(false)
103        }
104    }
105
106    /// Validate that a grant type is authorized for a client
107    pub async fn validate_grant_type(&self, client_id: &str, grant_type: &str) -> Result<bool> {
108        if let Some(client) = self.get_client(client_id).await? {
109            Ok(client
110                .authorized_grant_types
111                .contains(&grant_type.to_string()))
112        } else {
113            Ok(false)
114        }
115    }
116
117    /// Authenticate a confidential client using client credentials
118    pub async fn authenticate_client(&self, client_id: &str, client_secret: &str) -> Result<bool> {
119        if let Some(client) = self.get_client(client_id).await? {
120            match (&client.client_type, &client.client_secret) {
121                (ClientType::Confidential, Some(stored_secret)) => {
122                    // Use constant-time comparison to prevent timing attacks
123                    Ok(crate::security::secure_utils::constant_time_compare(
124                        client_secret.as_bytes(),
125                        stored_secret.as_bytes(),
126                    ))
127                }
128                (ClientType::Public, None) => {
129                    // Public clients don't have secrets
130                    Ok(true)
131                }
132                _ => Ok(false),
133            }
134        } else {
135            Ok(false)
136        }
137    }
138
139    /// Validate client configuration
140    fn validate_client_config(&self, config: &ClientConfig) -> Result<()> {
141        // Client ID must not be empty
142        if config.client_id.is_empty() {
143            return Err(AuthError::validation("Client ID cannot be empty"));
144        }
145
146        // Confidential clients must have a secret
147        if config.client_type == ClientType::Confidential && config.client_secret.is_none() {
148            return Err(AuthError::validation(
149                "Confidential clients must have a client secret",
150            ));
151        }
152
153        // Public clients must not have a secret
154        if config.client_type == ClientType::Public && config.client_secret.is_some() {
155            return Err(AuthError::validation(
156                "Public clients must not have a client secret",
157            ));
158        }
159
160        // At least one redirect URI must be provided
161        if config.redirect_uris.is_empty() {
162            return Err(AuthError::validation(
163                "At least one redirect URI must be provided",
164            ));
165        }
166
167        // Validate redirect URIs
168        for uri in &config.redirect_uris {
169            if uri.is_empty() {
170                return Err(AuthError::validation("Redirect URI cannot be empty"));
171            }
172
173            // Basic URI validation (in production, use a proper URI parser)
174            if !uri.starts_with("https://") && !uri.starts_with("http://localhost") {
175                return Err(AuthError::validation(
176                    "Redirect URIs must use HTTPS (except localhost)",
177                ));
178            }
179        }
180
181        // At least one scope must be provided
182        if config.authorized_scopes.is_empty() {
183            return Err(AuthError::validation(
184                "At least one authorized scope must be provided",
185            ));
186        }
187
188        // At least one grant type must be provided
189        if config.authorized_grant_types.is_empty() {
190            return Err(AuthError::validation(
191                "At least one authorized grant type must be provided",
192            ));
193        }
194
195        // At least one response type must be provided
196        if config.authorized_response_types.is_empty() {
197            return Err(AuthError::validation(
198                "At least one authorized response type must be provided",
199            ));
200        }
201
202        Ok(())
203    }
204}
205
206impl Default for ClientRegistry {
207    fn default() -> Self {
208        // Create default registry with environment-based storage configuration
209        let storage =
210            if std::env::var("CLIENT_REGISTRY_STORAGE").unwrap_or_default() == "persistent" {
211                // In production, this could be database or file-based storage
212                Arc::new(InMemoryStorage::new())
213            } else {
214                // Default to in-memory storage for development/testing
215                Arc::new(InMemoryStorage::new())
216            };
217
218        Self { storage }
219    }
220}
221
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::memory::InMemoryStorage;

    /// Build a registry backed by a fresh in-memory store.
    async fn new_registry() -> ClientRegistry {
        let storage = Arc::new(InMemoryStorage::new());
        ClientRegistry::new(storage).await.unwrap()
    }

    /// A confidential client that should pass validation. The remaining fields
    /// come from `ClientConfig::default()`; the baseline assertion in
    /// `test_client_validation` checks that this combination is accepted.
    fn valid_config(client_id: &str) -> ClientConfig {
        ClientConfig {
            client_id: client_id.to_string(),
            client_type: ClientType::Confidential,
            client_secret: Some("test_secret".to_string()),
            redirect_uris: vec!["https://example.com/callback".to_string()].into(),
            ..Default::default()
        }
    }

    #[tokio::test]
    async fn test_client_registry_operations() {
        let registry = new_registry().await;

        let registered_client = registry
            .register_client(valid_config("test_client"))
            .await
            .unwrap();
        assert_eq!(registered_client.client_id, "test_client");

        let retrieved_client = registry.get_client("test_client").await.unwrap().unwrap();
        assert_eq!(retrieved_client.client_id, "test_client");
        assert_eq!(retrieved_client.client_type, ClientType::Confidential);

        assert!(registry
            .authenticate_client("test_client", "test_secret")
            .await
            .unwrap());
        assert!(!registry
            .authenticate_client("test_client", "wrong_secret")
            .await
            .unwrap());

        assert!(registry
            .validate_redirect_uri("test_client", "https://example.com/callback")
            .await
            .unwrap());
        assert!(!registry
            .validate_redirect_uri("test_client", "https://malicious.com/callback")
            .await
            .unwrap());

        // Updates require the addressed ID and the config's own ID to agree.
        assert!(registry
            .update_client("other_client", valid_config("test_client"))
            .await
            .is_err());
        assert!(registry
            .update_client("test_client", valid_config("test_client"))
            .await
            .is_ok());

        registry.delete_client("test_client").await.unwrap();
        assert!(registry.get_client("test_client").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_unknown_client_is_rejected() {
        let registry = new_registry().await;

        assert!(registry.get_client("missing").await.unwrap().is_none());
        assert!(!registry
            .authenticate_client("missing", "secret")
            .await
            .unwrap());
        assert!(!registry
            .validate_redirect_uri("missing", "https://example.com/callback")
            .await
            .unwrap());
        assert!(!registry.validate_scope("missing", "read").await.unwrap());
        assert!(!registry
            .validate_grant_type("missing", "authorization_code")
            .await
            .unwrap());
    }

    #[tokio::test]
    async fn test_public_client_authentication() {
        let registry = new_registry().await;

        let public_client = ClientConfig {
            client_type: ClientType::Public,
            client_secret: None,
            ..valid_config("public_client")
        };
        registry.register_client(public_client).await.unwrap();

        assert!(registry
            .authenticate_client("public_client", "")
            .await
            .unwrap());
    }

    #[tokio::test]
    async fn test_localhost_redirect_is_allowed() {
        let registry = new_registry().await;

        let config = ClientConfig {
            redirect_uris: vec!["http://localhost:8080/callback".to_string()].into(),
            ..valid_config("local_client")
        };
        assert!(registry.register_client(config).await.is_ok());
    }

    #[tokio::test]
    async fn test_client_validation() {
        let registry = new_registry().await;

        // Without a valid baseline, every rejection below could be caused by
        // an unrelated field rather than the rule under test.
        assert!(
            registry
                .register_client(valid_config("baseline_client"))
                .await
                .is_ok(),
            "baseline config must be valid"
        );

        // Each case breaks exactly one rule of an otherwise valid config.
        let invalid_cases: Vec<(&str, ClientConfig)> = vec![
            (
                "empty client ID",
                ClientConfig {
                    client_id: String::new(),
                    ..valid_config("unused")
                },
            ),
            (
                "confidential client without secret",
                ClientConfig {
                    client_secret: None,
                    ..valid_config("no_secret_client")
                },
            ),
            (
                "public client with secret",
                ClientConfig {
                    client_type: ClientType::Public,
                    ..valid_config("public_with_secret")
                },
            ),
            (
                "no redirect URIs",
                ClientConfig {
                    redirect_uris: vec![].into(),
                    ..valid_config("no_redirect_client")
                },
            ),
            (
                "empty redirect URI",
                ClientConfig {
                    redirect_uris: vec![String::new()].into(),
                    ..valid_config("empty_redirect_client")
                },
            ),
            (
                "insecure redirect URI",
                ClientConfig {
                    redirect_uris: vec!["http://example.com/callback".to_string()].into(),
                    ..valid_config("insecure_redirect_client")
                },
            ),
        ];

        for (case, config) in invalid_cases {
            assert!(
                registry.register_client(config).await.is_err(),
                "expected `{}` to be rejected",
                case
            );
        }
    }
}