auth_framework/server/core/client_registry.rs

//! OAuth 2.0 Client Registry Module
//!
//! This module implements a client registry for managing OAuth 2.0 clients,
//! including registration, retrieval, and validation.
5
6use crate::errors::{AuthError, Result};
7use crate::storage::core::AuthStorage;
8use crate::storage::memory::InMemoryStorage;
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use std::sync::Arc;
12use uuid::Uuid;
13
14/// OAuth 2.0 Client Types as defined in RFC 6749
15#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
16pub enum ClientType {
17    /// Confidential clients capable of securely storing credentials
18    Confidential,
19    /// Public clients unable to securely store credentials
20    Public,
21}
22
23/// OAuth 2.0 Client Configuration
24#[derive(Debug, Clone, Serialize, Deserialize)]
25pub struct ClientConfig {
26    /// Unique client identifier
27    pub client_id: String,
28    /// Client secret (only for confidential clients)
29    pub client_secret: Option<String>,
30    /// Client type
31    pub client_type: ClientType,
32    /// Authorized redirect URIs
33    pub redirect_uris: Vec<String>,
34    /// Authorized scopes
35    pub authorized_scopes: Vec<String>,
36    /// Grant types the client is authorized to use
37    pub authorized_grant_types: Vec<String>,
38    /// Response types the client is authorized to use
39    pub authorized_response_types: Vec<String>,
40    /// Client name for display purposes
41    pub client_name: Option<String>,
42    /// Client description
43    pub client_description: Option<String>,
44    /// Client metadata
45    pub metadata: HashMap<String, serde_json::Value>,
46}
47
48impl Default for ClientConfig {
49    fn default() -> Self {
50        Self {
51            client_id: Uuid::new_v4().to_string(),
52            client_secret: None,
53            client_type: ClientType::Public,
54            redirect_uris: Vec::new(),
55            authorized_scopes: vec!["read".to_string()],
56            authorized_grant_types: vec!["authorization_code".to_string()],
57            authorized_response_types: vec!["code".to_string()],
58            client_name: None,
59            client_description: None,
60            metadata: HashMap::new(),
61        }
62    }
63}
64
65/// Client Registry for managing OAuth 2.0 clients
66#[derive(Clone)]
67pub struct ClientRegistry {
68    storage: Arc<dyn AuthStorage>,
69}
70
71impl ClientRegistry {
72    /// Create a new client registry
73    pub async fn new(storage: Arc<dyn AuthStorage>) -> Result<Self> {
74        Ok(Self { storage })
75    }
76
77    /// Register a new OAuth 2.0 client
78    pub async fn register_client(&self, config: ClientConfig) -> Result<ClientConfig> {
79        // Validate the client configuration
80        self.validate_client_config(&config)?;
81
82        // Store the client in the storage backend
83        let client_key = format!("oauth_client:{}", config.client_id);
84        let client_data = serde_json::to_string(&config)
85            .map_err(|e| AuthError::internal(format!("Failed to serialize client: {}", e)))?;
86
87        self.storage
88            .store_kv(&client_key, client_data.as_bytes(), None)
89            .await?;
90
91        Ok(config)
92    }
93
94    /// Retrieve a client by ID
95    pub async fn get_client(&self, client_id: &str) -> Result<Option<ClientConfig>> {
96        let client_key = format!("oauth_client:{}", client_id);
97
98        if let Some(client_data) = self.storage.get_kv(&client_key).await? {
99            let client_str = std::str::from_utf8(&client_data)
100                .map_err(|e| AuthError::internal(format!("Invalid UTF-8 in client data: {}", e)))?;
101            let config: ClientConfig = serde_json::from_str(client_str)
102                .map_err(|e| AuthError::internal(format!("Failed to deserialize client: {}", e)))?;
103            Ok(Some(config))
104        } else {
105            Ok(None)
106        }
107    }
108
109    /// Update a client configuration
110    pub async fn update_client(&self, client_id: &str, config: ClientConfig) -> Result<()> {
111        // Ensure the client ID matches
112        if config.client_id != client_id {
113            return Err(AuthError::validation("Client ID mismatch"));
114        }
115
116        // Validate the updated configuration
117        self.validate_client_config(&config)?;
118
119        // Store the updated client
120        let client_key = format!("oauth_client:{}", client_id);
121        let client_data = serde_json::to_string(&config)
122            .map_err(|e| AuthError::internal(format!("Failed to serialize client: {}", e)))?;
123
124        self.storage
125            .store_kv(&client_key, client_data.as_bytes(), None)
126            .await?;
127
128        Ok(())
129    }
130
131    /// Delete a client
132    pub async fn delete_client(&self, client_id: &str) -> Result<()> {
133        let client_key = format!("oauth_client:{}", client_id);
134        self.storage.delete_kv(&client_key).await?;
135        Ok(())
136    }
137
138    /// Validate that a redirect URI is authorized for a client
139    pub async fn validate_redirect_uri(&self, client_id: &str, redirect_uri: &str) -> Result<bool> {
140        if let Some(client) = self.get_client(client_id).await? {
141            Ok(client.redirect_uris.contains(&redirect_uri.to_string()))
142        } else {
143            Ok(false)
144        }
145    }
146
147    /// Validate that a scope is authorized for a client
148    pub async fn validate_scope(&self, client_id: &str, scope: &str) -> Result<bool> {
149        if let Some(client) = self.get_client(client_id).await? {
150            Ok(client.authorized_scopes.contains(&scope.to_string()))
151        } else {
152            Ok(false)
153        }
154    }
155
156    /// Validate that a grant type is authorized for a client
157    pub async fn validate_grant_type(&self, client_id: &str, grant_type: &str) -> Result<bool> {
158        if let Some(client) = self.get_client(client_id).await? {
159            Ok(client
160                .authorized_grant_types
161                .contains(&grant_type.to_string()))
162        } else {
163            Ok(false)
164        }
165    }
166
167    /// Authenticate a confidential client using client credentials
168    pub async fn authenticate_client(&self, client_id: &str, client_secret: &str) -> Result<bool> {
169        if let Some(client) = self.get_client(client_id).await? {
170            match (&client.client_type, &client.client_secret) {
171                (ClientType::Confidential, Some(stored_secret)) => {
172                    // Use constant-time comparison to prevent timing attacks
173                    Ok(crate::security::secure_utils::constant_time_compare(
174                        client_secret.as_bytes(),
175                        stored_secret.as_bytes(),
176                    ))
177                }
178                (ClientType::Public, None) => {
179                    // Public clients don't have secrets
180                    Ok(true)
181                }
182                _ => Ok(false),
183            }
184        } else {
185            Ok(false)
186        }
187    }
188
189    /// Validate client configuration
190    fn validate_client_config(&self, config: &ClientConfig) -> Result<()> {
191        // Client ID must not be empty
192        if config.client_id.is_empty() {
193            return Err(AuthError::validation("Client ID cannot be empty"));
194        }
195
196        // Confidential clients must have a secret
197        if config.client_type == ClientType::Confidential && config.client_secret.is_none() {
198            return Err(AuthError::validation(
199                "Confidential clients must have a client secret",
200            ));
201        }
202
203        // Public clients must not have a secret
204        if config.client_type == ClientType::Public && config.client_secret.is_some() {
205            return Err(AuthError::validation(
206                "Public clients must not have a client secret",
207            ));
208        }
209
210        // At least one redirect URI must be provided
211        if config.redirect_uris.is_empty() {
212            return Err(AuthError::validation(
213                "At least one redirect URI must be provided",
214            ));
215        }
216
217        // Validate redirect URIs
218        for uri in &config.redirect_uris {
219            if uri.is_empty() {
220                return Err(AuthError::validation("Redirect URI cannot be empty"));
221            }
222
223            // Basic URI validation (in production, use a proper URI parser)
224            if !uri.starts_with("https://") && !uri.starts_with("http://localhost") {
225                return Err(AuthError::validation(
226                    "Redirect URIs must use HTTPS (except localhost)",
227                ));
228            }
229        }
230
231        // At least one scope must be provided
232        if config.authorized_scopes.is_empty() {
233            return Err(AuthError::validation(
234                "At least one authorized scope must be provided",
235            ));
236        }
237
238        // At least one grant type must be provided
239        if config.authorized_grant_types.is_empty() {
240            return Err(AuthError::validation(
241                "At least one authorized grant type must be provided",
242            ));
243        }
244
245        // At least one response type must be provided
246        if config.authorized_response_types.is_empty() {
247            return Err(AuthError::validation(
248                "At least one authorized response type must be provided",
249            ));
250        }
251
252        Ok(())
253    }
254}
255
256impl Default for ClientRegistry {
257    fn default() -> Self {
258        // Create default registry with environment-based storage configuration
259        let storage =
260            if std::env::var("CLIENT_REGISTRY_STORAGE").unwrap_or_default() == "persistent" {
261                // In production, this could be database or file-based storage
262                Arc::new(InMemoryStorage::new())
263            } else {
264                // Default to in-memory storage for development/testing
265                Arc::new(InMemoryStorage::new())
266            };
267
268        Self { storage }
269    }
270}
271
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::memory::InMemoryStorage;

    /// Builds a registry over a fresh, empty in-memory store.
    async fn fresh_registry() -> ClientRegistry {
        ClientRegistry::new(Arc::new(InMemoryStorage::new()))
            .await
            .unwrap()
    }

    #[tokio::test]
    async fn test_client_registry_operations() {
        const ID: &str = "test_client";
        let registry = fresh_registry().await;

        // A confidential client with one HTTPS callback.
        let config = ClientConfig {
            client_id: ID.to_string(),
            client_type: ClientType::Confidential,
            client_secret: Some("test_secret".to_string()),
            redirect_uris: vec!["https://example.com/callback".to_string()],
            ..Default::default()
        };

        // Registration echoes the config back.
        let registered = registry.register_client(config).await.unwrap();
        assert_eq!(registered.client_id, ID);

        // The stored record round-trips.
        let stored = registry.get_client(ID).await.unwrap().unwrap();
        assert_eq!(stored.client_id, ID);
        assert_eq!(stored.client_type, ClientType::Confidential);

        // Only the correct secret authenticates.
        assert!(registry.authenticate_client(ID, "test_secret").await.unwrap());
        assert!(!registry.authenticate_client(ID, "wrong_secret").await.unwrap());

        // Only the registered callback is accepted.
        assert!(registry
            .validate_redirect_uri(ID, "https://example.com/callback")
            .await
            .unwrap());
        assert!(!registry
            .validate_redirect_uri(ID, "https://malicious.com/callback")
            .await
            .unwrap());

        // Deletion removes the record.
        registry.delete_client(ID).await.unwrap();
        assert!(registry.get_client(ID).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_client_validation() {
        let registry = fresh_registry().await;

        let rejected_cases = vec![
            (
                "empty client ID",
                ClientConfig {
                    client_id: "".to_string(),
                    ..Default::default()
                },
            ),
            (
                "confidential client without secret",
                ClientConfig {
                    client_type: ClientType::Confidential,
                    client_secret: None,
                    ..Default::default()
                },
            ),
            (
                "public client with secret",
                ClientConfig {
                    client_type: ClientType::Public,
                    client_secret: Some("secret".to_string()),
                    ..Default::default()
                },
            ),
            (
                "no redirect URIs",
                ClientConfig {
                    redirect_uris: vec![],
                    ..Default::default()
                },
            ),
            (
                "insecure redirect URI",
                ClientConfig {
                    redirect_uris: vec!["http://example.com/callback".to_string()],
                    ..Default::default()
                },
            ),
        ];

        for (case, config) in rejected_cases {
            assert!(
                registry.register_client(config).await.is_err(),
                "expected registration to be rejected: {}",
                case
            );
        }
    }
}
378
379