auth_framework/server/core/
client_registry.rs1use crate::errors::{AuthError, Result};
7use crate::storage::core::AuthStorage;
8use crate::storage::memory::InMemoryStorage;
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use std::sync::Arc;
12use uuid::Uuid;
13
14#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
16pub enum ClientType {
17 Confidential,
19 Public,
21}
22
23#[derive(Debug, Clone, Serialize, Deserialize)]
25pub struct ClientConfig {
26 pub client_id: String,
28 pub client_secret: Option<String>,
30 pub client_type: ClientType,
32 pub redirect_uris: Vec<String>,
34 pub authorized_scopes: Vec<String>,
36 pub authorized_grant_types: Vec<String>,
38 pub authorized_response_types: Vec<String>,
40 pub client_name: Option<String>,
42 pub client_description: Option<String>,
44 pub metadata: HashMap<String, serde_json::Value>,
46}
47
48impl Default for ClientConfig {
49 fn default() -> Self {
50 Self {
51 client_id: Uuid::new_v4().to_string(),
52 client_secret: None,
53 client_type: ClientType::Public,
54 redirect_uris: Vec::new(),
55 authorized_scopes: vec!["read".to_string()],
56 authorized_grant_types: vec!["authorization_code".to_string()],
57 authorized_response_types: vec!["code".to_string()],
58 client_name: None,
59 client_description: None,
60 metadata: HashMap::new(),
61 }
62 }
63}
64
65#[derive(Clone)]
67pub struct ClientRegistry {
68 storage: Arc<dyn AuthStorage>,
69}
70
71impl ClientRegistry {
72 pub async fn new(storage: Arc<dyn AuthStorage>) -> Result<Self> {
74 Ok(Self { storage })
75 }
76
77 pub async fn register_client(&self, config: ClientConfig) -> Result<ClientConfig> {
79 self.validate_client_config(&config)?;
81
82 let client_key = format!("oauth_client:{}", config.client_id);
84 let client_data = serde_json::to_string(&config)
85 .map_err(|e| AuthError::internal(format!("Failed to serialize client: {}", e)))?;
86
87 self.storage
88 .store_kv(&client_key, client_data.as_bytes(), None)
89 .await?;
90
91 Ok(config)
92 }
93
94 pub async fn get_client(&self, client_id: &str) -> Result<Option<ClientConfig>> {
96 let client_key = format!("oauth_client:{}", client_id);
97
98 if let Some(client_data) = self.storage.get_kv(&client_key).await? {
99 let client_str = std::str::from_utf8(&client_data)
100 .map_err(|e| AuthError::internal(format!("Invalid UTF-8 in client data: {}", e)))?;
101 let config: ClientConfig = serde_json::from_str(client_str)
102 .map_err(|e| AuthError::internal(format!("Failed to deserialize client: {}", e)))?;
103 Ok(Some(config))
104 } else {
105 Ok(None)
106 }
107 }
108
109 pub async fn update_client(&self, client_id: &str, config: ClientConfig) -> Result<()> {
111 if config.client_id != client_id {
113 return Err(AuthError::validation("Client ID mismatch"));
114 }
115
116 self.validate_client_config(&config)?;
118
119 let client_key = format!("oauth_client:{}", client_id);
121 let client_data = serde_json::to_string(&config)
122 .map_err(|e| AuthError::internal(format!("Failed to serialize client: {}", e)))?;
123
124 self.storage
125 .store_kv(&client_key, client_data.as_bytes(), None)
126 .await?;
127
128 Ok(())
129 }
130
131 pub async fn delete_client(&self, client_id: &str) -> Result<()> {
133 let client_key = format!("oauth_client:{}", client_id);
134 self.storage.delete_kv(&client_key).await?;
135 Ok(())
136 }
137
138 pub async fn validate_redirect_uri(&self, client_id: &str, redirect_uri: &str) -> Result<bool> {
140 if let Some(client) = self.get_client(client_id).await? {
141 Ok(client.redirect_uris.contains(&redirect_uri.to_string()))
142 } else {
143 Ok(false)
144 }
145 }
146
147 pub async fn validate_scope(&self, client_id: &str, scope: &str) -> Result<bool> {
149 if let Some(client) = self.get_client(client_id).await? {
150 Ok(client.authorized_scopes.contains(&scope.to_string()))
151 } else {
152 Ok(false)
153 }
154 }
155
156 pub async fn validate_grant_type(&self, client_id: &str, grant_type: &str) -> Result<bool> {
158 if let Some(client) = self.get_client(client_id).await? {
159 Ok(client
160 .authorized_grant_types
161 .contains(&grant_type.to_string()))
162 } else {
163 Ok(false)
164 }
165 }
166
167 pub async fn authenticate_client(&self, client_id: &str, client_secret: &str) -> Result<bool> {
169 if let Some(client) = self.get_client(client_id).await? {
170 match (&client.client_type, &client.client_secret) {
171 (ClientType::Confidential, Some(stored_secret)) => {
172 Ok(crate::security::secure_utils::constant_time_compare(
174 client_secret.as_bytes(),
175 stored_secret.as_bytes(),
176 ))
177 }
178 (ClientType::Public, None) => {
179 Ok(true)
181 }
182 _ => Ok(false),
183 }
184 } else {
185 Ok(false)
186 }
187 }
188
189 fn validate_client_config(&self, config: &ClientConfig) -> Result<()> {
191 if config.client_id.is_empty() {
193 return Err(AuthError::validation("Client ID cannot be empty"));
194 }
195
196 if config.client_type == ClientType::Confidential && config.client_secret.is_none() {
198 return Err(AuthError::validation(
199 "Confidential clients must have a client secret",
200 ));
201 }
202
203 if config.client_type == ClientType::Public && config.client_secret.is_some() {
205 return Err(AuthError::validation(
206 "Public clients must not have a client secret",
207 ));
208 }
209
210 if config.redirect_uris.is_empty() {
212 return Err(AuthError::validation(
213 "At least one redirect URI must be provided",
214 ));
215 }
216
217 for uri in &config.redirect_uris {
219 if uri.is_empty() {
220 return Err(AuthError::validation("Redirect URI cannot be empty"));
221 }
222
223 if !uri.starts_with("https://") && !uri.starts_with("http://localhost") {
225 return Err(AuthError::validation(
226 "Redirect URIs must use HTTPS (except localhost)",
227 ));
228 }
229 }
230
231 if config.authorized_scopes.is_empty() {
233 return Err(AuthError::validation(
234 "At least one authorized scope must be provided",
235 ));
236 }
237
238 if config.authorized_grant_types.is_empty() {
240 return Err(AuthError::validation(
241 "At least one authorized grant type must be provided",
242 ));
243 }
244
245 if config.authorized_response_types.is_empty() {
247 return Err(AuthError::validation(
248 "At least one authorized response type must be provided",
249 ));
250 }
251
252 Ok(())
253 }
254}
255
256impl Default for ClientRegistry {
257 fn default() -> Self {
258 let storage =
260 if std::env::var("CLIENT_REGISTRY_STORAGE").unwrap_or_default() == "persistent" {
261 Arc::new(InMemoryStorage::new())
263 } else {
264 Arc::new(InMemoryStorage::new())
266 };
267
268 Self { storage }
269 }
270}
271
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::memory::InMemoryStorage;

    /// Builds a registry over a fresh, empty in-memory store.
    async fn new_registry() -> ClientRegistry {
        ClientRegistry::new(Arc::new(InMemoryStorage::new()))
            .await
            .expect("constructing a registry does not fail")
    }

    /// A redirect-URI list that passes validation, so each negative case below
    /// fails only because of the rule it targets.
    fn valid_redirects() -> Vec<String> {
        vec!["https://example.com/callback".to_string()]
    }

    #[tokio::test]
    async fn test_client_registry_operations() {
        let registry = new_registry().await;

        let client_config = ClientConfig {
            client_id: "test_client".to_string(),
            client_type: ClientType::Confidential,
            client_secret: Some("test_secret".to_string()),
            redirect_uris: valid_redirects(),
            ..Default::default()
        };

        let registered_client = registry
            .register_client(client_config.clone())
            .await
            .unwrap();
        assert_eq!(registered_client.client_id, "test_client");

        let retrieved_client = registry.get_client("test_client").await.unwrap().unwrap();
        assert_eq!(retrieved_client.client_id, "test_client");
        assert_eq!(retrieved_client.client_type, ClientType::Confidential);

        assert!(registry
            .authenticate_client("test_client", "test_secret")
            .await
            .unwrap());
        assert!(!registry
            .authenticate_client("test_client", "wrong_secret")
            .await
            .unwrap());

        assert!(registry
            .validate_redirect_uri("test_client", "https://example.com/callback")
            .await
            .unwrap());
        assert!(!registry
            .validate_redirect_uri("test_client", "https://malicious.com/callback")
            .await
            .unwrap());

        // Defaults grant only the `read` scope and the `authorization_code` grant.
        assert!(registry.validate_scope("test_client", "read").await.unwrap());
        assert!(!registry.validate_scope("test_client", "write").await.unwrap());
        assert!(registry
            .validate_grant_type("test_client", "authorization_code")
            .await
            .unwrap());
        assert!(!registry
            .validate_grant_type("test_client", "client_credentials")
            .await
            .unwrap());

        registry.delete_client("test_client").await.unwrap();
        assert!(registry.get_client("test_client").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_client_validation() {
        let registry = new_registry().await;

        // Empty client ID.
        let invalid_config = ClientConfig {
            client_id: String::new(),
            redirect_uris: valid_redirects(),
            ..Default::default()
        };
        assert!(registry.register_client(invalid_config).await.is_err());

        // Confidential client without a secret.
        let invalid_config = ClientConfig {
            client_type: ClientType::Confidential,
            client_secret: None,
            redirect_uris: valid_redirects(),
            ..Default::default()
        };
        assert!(registry.register_client(invalid_config).await.is_err());

        // Public client with a secret.
        let invalid_config = ClientConfig {
            client_type: ClientType::Public,
            client_secret: Some("secret".to_string()),
            redirect_uris: valid_redirects(),
            ..Default::default()
        };
        assert!(registry.register_client(invalid_config).await.is_err());

        // No redirect URIs at all.
        let invalid_config = ClientConfig {
            redirect_uris: vec![],
            ..Default::default()
        };
        assert!(registry.register_client(invalid_config).await.is_err());

        // Plain HTTP to a non-localhost host.
        let invalid_config = ClientConfig {
            redirect_uris: vec!["http://example.com/callback".to_string()],
            ..Default::default()
        };
        assert!(registry.register_client(invalid_config).await.is_err());
    }

    #[tokio::test]
    async fn test_localhost_http_redirect_is_accepted() {
        let registry = new_registry().await;

        let config = ClientConfig {
            redirect_uris: vec!["http://localhost:8080/callback".to_string()],
            ..Default::default()
        };
        assert!(registry.register_client(config).await.is_ok());
    }

    #[tokio::test]
    async fn test_update_client_rejects_id_mismatch() {
        let registry = new_registry().await;

        let config = ClientConfig {
            client_id: "client_a".to_string(),
            redirect_uris: valid_redirects(),
            ..Default::default()
        };
        assert!(registry.update_client("client_b", config).await.is_err());
    }

    #[tokio::test]
    async fn test_unknown_client_is_rejected() {
        let registry = new_registry().await;

        assert!(registry.get_client("missing").await.unwrap().is_none());
        assert!(!registry
            .authenticate_client("missing", "secret")
            .await
            .unwrap());
        assert!(!registry
            .validate_redirect_uri("missing", "https://example.com/callback")
            .await
            .unwrap());
        assert!(!registry.validate_scope("missing", "read").await.unwrap());
        assert!(!registry
            .validate_grant_type("missing", "authorization_code")
            .await
            .unwrap());
    }
}
378
379