1use crate::errors::{AuthError, Result, StorageError};
17use crate::storage::AuthStorage;
18use base64::{Engine as _, engine::general_purpose};
19use chrono::{DateTime, Duration, Utc};
20use governor;
21use serde::{Deserialize, Serialize};
22use serde_json::Value;
23use std::collections::HashMap;
24use std::sync::Arc;
25use url;
26use uuid::Uuid;
27
/// Client metadata submitted to the registration endpoint (RFC 7591 §2).
///
/// The same shape is accepted for updates and is persisted as
/// [`RegisteredClient::metadata`]. The named fields are RFC 7591 metadata members;
/// every other top-level JSON member is collected into `additional_metadata`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClientRegistrationRequest {
    /// RFC 7591 `redirect_uris`. The manager rejects the request when there are more
    /// than `ClientRegistrationConfig::max_redirect_uris` entries or an entry fails URI
    /// validation.
    pub redirect_uris: Option<Vec<String>>,

    /// RFC 7591 `token_endpoint_auth_method`; must be one of
    /// `ClientRegistrationConfig::allowed_auth_methods`. `"none"` asks for a public
    /// client without a secret, subject to `allow_public_clients`.
    pub token_endpoint_auth_method: Option<String>,

    /// RFC 7591 `grant_types`; every entry must be in `allowed_grant_types`.
    pub grant_types: Option<Vec<String>>,

    /// RFC 7591 `response_types`; every entry must be in `allowed_response_types`.
    pub response_types: Option<Vec<String>>,

    /// RFC 7591 `client_name`; stored and echoed back as-is.
    pub client_name: Option<String>,

    /// RFC 7591 `client_uri`; stored and echoed back without validation.
    pub client_uri: Option<String>,

    /// RFC 7591 `logo_uri`; stored and echoed back without validation.
    pub logo_uri: Option<String>,

    /// RFC 7591 `scope` string (e.g. `"read write"`); stored as-is.
    pub scope: Option<String>,

    /// RFC 7591 `contacts` (e.g. e-mail addresses); stored as-is.
    pub contacts: Option<Vec<String>>,

    /// RFC 7591 `tos_uri`; stored and echoed back without validation.
    pub tos_uri: Option<String>,

    /// RFC 7591 `policy_uri`; stored and echoed back without validation.
    pub policy_uri: Option<String>,

    /// RFC 7591 `jwks_uri`.
    /// NOTE(review): RFC 7591 forbids sending both `jwks_uri` and `jwks`; not enforced here.
    pub jwks_uri: Option<String>,

    /// RFC 7591 `jwks`, an inline JWK Set kept as raw JSON.
    pub jwks: Option<Value>,

    /// RFC 7591 `software_id`; stored as-is.
    pub software_id: Option<String>,

    /// RFC 7591 `software_version`; stored as-is.
    pub software_version: Option<String>,

    /// Every other top-level JSON member, captured via `#[serde(flatten)]` and stored
    /// without validation.
    /// NOTE(review): server-assigned names such as `client_id` also land here if a
    /// client sends them.
    #[serde(flatten)]
    pub additional_metadata: HashMap<String, Value>,
}
80
/// Body returned by the register, read and update operations
/// (RFC 7591 §3.2.1 / RFC 7592 §3).
///
/// The body contains the server-assigned credentials and timestamps, followed by the
/// registered metadata echoed from [`ClientRegistrationRequest`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClientRegistrationResponse {
    /// Server-generated client identifier (`client_` followed by a simple-form UUID).
    pub client_id: String,

    /// Plaintext client secret. Populated only when a secret is freshly issued.
    /// Responses built from storage carry `None`, because only the secret's hash is
    /// persisted.
    pub client_secret: Option<String>,

    /// Plaintext token for the read/update/delete operations, returned at
    /// registration. Responses built from storage carry the placeholder `"***"`.
    pub registration_access_token: String,

    /// Client configuration endpoint, built as `{base_url}/register/{client_id}`.
    pub registration_client_uri: String,

    /// Registration time as a Unix timestamp in seconds.
    pub client_id_issued_at: Option<i64>,

    /// Client secret expiry as a Unix timestamp in seconds.
    pub client_secret_expires_at: Option<i64>,

    /// Registered `redirect_uris`.
    pub redirect_uris: Option<Vec<String>>,

    /// Registered `token_endpoint_auth_method`.
    pub token_endpoint_auth_method: Option<String>,

    /// Registered `grant_types`.
    pub grant_types: Option<Vec<String>>,

    /// Registered `response_types`.
    pub response_types: Option<Vec<String>>,

    /// Registered `client_name`.
    pub client_name: Option<String>,

    /// Registered `client_uri`.
    pub client_uri: Option<String>,

    /// Registered `logo_uri`.
    pub logo_uri: Option<String>,

    /// Registered `scope`.
    pub scope: Option<String>,

    /// Registered `contacts`.
    pub contacts: Option<Vec<String>>,

    /// Registered `tos_uri`.
    pub tos_uri: Option<String>,

    /// Registered `policy_uri`.
    pub policy_uri: Option<String>,

    /// Registered `jwks_uri`.
    pub jwks_uri: Option<String>,

    /// Registered inline `jwks`.
    pub jwks: Option<Value>,

    /// Registered `software_id`.
    pub software_id: Option<String>,

    /// Registered `software_version`.
    pub software_version: Option<String>,

    /// Remaining metadata members, serialized inline via `#[serde(flatten)]`.
    /// NOTE(review): a key equal to one of the named fields above would be emitted
    /// as a duplicate JSON member.
    #[serde(flatten)]
    pub additional_metadata: HashMap<String, Value>,
}
151
/// Server-side record of a registered client, persisted as JSON by
/// `ClientRegistrationManager`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RegisteredClient {
    /// Server-generated client identifier; also part of the storage key.
    pub client_id: String,

    /// Lowercase hex SHA-256 digest of the client secret. `None` when no secret was
    /// issued (public clients).
    pub client_secret_hash: Option<String>,

    /// Lowercase hex SHA-256 digest of the registration access token. It is compared
    /// with the hash of the token presented to read/update/delete.
    pub registration_access_token_hash: String,

    /// Metadata as last submitted on registration or update.
    pub metadata: ClientRegistrationRequest,

    /// When the client was registered.
    pub registered_at: DateTime<Utc>,

    /// When the record was last written by a registration or update.
    pub updated_at: DateTime<Utc>,

    /// When the client secret expires. `None` when no secret was issued or secrets
    /// are configured not to expire.
    pub client_secret_expires_at: Option<DateTime<Utc>>,

    /// Set to `true` on registration.
    /// NOTE(review): nothing in this file sets it to `false` or checks it.
    pub is_active: bool,
}
179
/// Policy settings for `ClientRegistrationManager`.
#[derive(Debug, Clone)]
pub struct ClientRegistrationConfig {
    /// Public base URL. `registration_client_uri` values are built as
    /// `{base_url}/register/{client_id}`.
    pub base_url: String,

    /// NOTE(review): not read anywhere in this file; confirm whether callers enforce it.
    pub require_authentication: bool,

    /// Lifetime, in seconds, of newly issued client secrets. `None` means secrets do
    /// not expire.
    pub default_secret_expiration: Option<i64>,

    /// Maximum number of `redirect_uris` a client may register.
    pub max_redirect_uris: usize,

    /// Grant types a client may request; any other value rejects the registration.
    pub allowed_grant_types: Vec<String>,

    /// Response types a client may request; any other value rejects the registration.
    pub allowed_response_types: Vec<String>,

    /// Accepted `token_endpoint_auth_method` values.
    pub allowed_auth_methods: Vec<String>,

    /// Whether clients requesting `token_endpoint_auth_method: "none"` may go without
    /// a secret. When `false`, every registered client is issued a secret.
    pub allow_public_clients: bool,

    /// Registration quota for the rate limiter. Must be non-zero, because
    /// `ClientRegistrationManager::new` panics on zero.
    pub rate_limit_per_ip: u32,
    /// Intended length of the rate-limiting window (one hour by default).
    /// NOTE(review): confirm `ClientRegistrationManager::new` builds its quota from this.
    pub rate_limit_window: std::time::Duration,
}
211
212impl Default for ClientRegistrationConfig {
213 fn default() -> Self {
214 Self {
215 base_url: "https://auth.example.com".to_string(),
216 require_authentication: false,
217 default_secret_expiration: Some(86400 * 365), max_redirect_uris: 10,
219 allowed_grant_types: vec![
220 "authorization_code".to_string(),
221 "client_credentials".to_string(),
222 "refresh_token".to_string(),
223 "urn:ietf:params:oauth:grant-type:device_code".to_string(),
224 ],
225 allowed_response_types: vec![
226 "code".to_string(),
227 "token".to_string(),
228 "id_token".to_string(),
229 ],
230 allowed_auth_methods: vec![
231 "client_secret_basic".to_string(),
232 "client_secret_post".to_string(),
233 "private_key_jwt".to_string(),
234 "none".to_string(),
235 ],
236 allow_public_clients: true,
237 rate_limit_per_ip: 10,
238 rate_limit_window: std::time::Duration::from_secs(3600),
239 }
240 }
241}
242
243pub struct ClientRegistrationManager {
245 config: ClientRegistrationConfig,
246 storage: Arc<dyn AuthStorage>,
247 rate_limiter: Arc<
248 governor::RateLimiter<
249 governor::state::direct::NotKeyed,
250 governor::state::InMemoryState,
251 governor::clock::DefaultClock,
252 >,
253 >,
254}
255
256impl ClientRegistrationManager {
257 pub fn new(config: ClientRegistrationConfig, storage: Arc<dyn AuthStorage>) -> Self {
259 let quota =
260 governor::Quota::per_hour(std::num::NonZeroU32::new(config.rate_limit_per_ip).unwrap());
261 let rate_limiter = Arc::new(governor::RateLimiter::direct(quota));
262
263 Self {
264 config,
265 storage,
266 rate_limiter,
267 }
268 }
269
270 pub async fn register_client(
272 &self,
273 request: ClientRegistrationRequest,
274 client_ip: Option<std::net::IpAddr>,
275 ) -> Result<ClientRegistrationResponse> {
276 if let Some(_ip) = client_ip
278 && self.rate_limiter.check().is_err()
279 {
280 return Err(AuthError::rate_limit(
281 "Client registration rate limit exceeded",
282 ));
283 }
284
285 self.validate_registration_request(&request)?;
287
288 let client_id = self.generate_client_id();
290 let (client_secret, client_secret_hash) = if self.requires_client_secret(&request) {
291 let secret = self.generate_client_secret();
292 let hash = self.hash_secret(&secret)?;
293 (Some(secret), Some(hash))
294 } else {
295 (None, None)
296 };
297
298 let registration_access_token = self.generate_registration_access_token();
300 let registration_access_token_hash = self.hash_secret(®istration_access_token)?;
301
302 let client_secret_expires_at = if client_secret.is_some() {
304 self.config
305 .default_secret_expiration
306 .map(|seconds| Utc::now() + Duration::seconds(seconds))
307 } else {
308 None
309 };
310
311 let registered_client = RegisteredClient {
313 client_id: client_id.clone(),
314 client_secret_hash,
315 registration_access_token_hash,
316 metadata: request.clone(),
317 registered_at: Utc::now(),
318 updated_at: Utc::now(),
319 client_secret_expires_at,
320 is_active: true,
321 };
322
323 self.store_client(®istered_client).await?;
325
326 let response = ClientRegistrationResponse {
328 client_id: client_id.clone(),
329 client_secret,
330 registration_access_token,
331 registration_client_uri: format!("{}/register/{}", self.config.base_url, client_id),
332 client_id_issued_at: Some(Utc::now().timestamp()),
333 client_secret_expires_at: client_secret_expires_at.map(|dt| dt.timestamp()),
334 redirect_uris: request.redirect_uris,
335 token_endpoint_auth_method: request.token_endpoint_auth_method,
336 grant_types: request.grant_types,
337 response_types: request.response_types,
338 client_name: request.client_name,
339 client_uri: request.client_uri,
340 logo_uri: request.logo_uri,
341 scope: request.scope,
342 contacts: request.contacts,
343 tos_uri: request.tos_uri,
344 policy_uri: request.policy_uri,
345 jwks_uri: request.jwks_uri,
346 jwks: request.jwks,
347 software_id: request.software_id,
348 software_version: request.software_version,
349 additional_metadata: request.additional_metadata,
350 };
351
352 Ok(response)
353 }
354
355 pub async fn read_client(
357 &self,
358 client_id: &str,
359 registration_access_token: &str,
360 ) -> Result<ClientRegistrationResponse> {
361 let client = self.get_client(client_id).await?;
362
363 if !self.verify_registration_token(&client, registration_access_token)? {
365 return Err(AuthError::auth_method(
366 "client_registration",
367 "Invalid registration access token",
368 ));
369 }
370
371 self.client_to_response(&client)
372 }
373
374 pub async fn update_client(
376 &self,
377 client_id: &str,
378 registration_access_token: &str,
379 request: ClientRegistrationRequest,
380 ) -> Result<ClientRegistrationResponse> {
381 let mut client = self.get_client(client_id).await?;
382
383 if !self.verify_registration_token(&client, registration_access_token)? {
385 return Err(AuthError::auth_method(
386 "client_registration",
387 "Invalid registration access token",
388 ));
389 }
390
391 self.validate_registration_request(&request)?;
393
394 client.metadata = request;
396 client.updated_at = Utc::now();
397
398 self.store_client(&client).await?;
400
401 self.client_to_response(&client)
402 }
403
404 pub async fn delete_client(
406 &self,
407 client_id: &str,
408 registration_access_token: &str,
409 ) -> Result<()> {
410 let client = self.get_client(client_id).await?;
411
412 if !self.verify_registration_token(&client, registration_access_token)? {
414 return Err(AuthError::auth_method(
415 "client_registration",
416 "Invalid registration access token",
417 ));
418 }
419
420 let key = format!("client_registration:{}", client_id);
422 self.storage.delete_kv(&key).await?;
423
424 Ok(())
425 }
426
427 fn validate_registration_request(&self, request: &ClientRegistrationRequest) -> Result<()> {
429 if let Some(redirect_uris) = &request.redirect_uris {
431 if redirect_uris.len() > self.config.max_redirect_uris {
432 return Err(AuthError::auth_method(
433 "client_registration",
434 "Too many redirect URIs",
435 ));
436 }
437
438 for uri in redirect_uris {
439 if !self.is_valid_uri(uri) {
440 return Err(AuthError::auth_method(
441 "client_registration",
442 format!("Invalid redirect URI: {}", uri),
443 ));
444 }
445 }
446 }
447
448 if let Some(grant_types) = &request.grant_types {
450 for grant_type in grant_types {
451 if !self.config.allowed_grant_types.contains(grant_type) {
452 return Err(AuthError::auth_method(
453 "client_registration",
454 format!("Unsupported grant type: {}", grant_type),
455 ));
456 }
457 }
458 }
459
460 if let Some(response_types) = &request.response_types {
462 for response_type in response_types {
463 if !self.config.allowed_response_types.contains(response_type) {
464 return Err(AuthError::auth_method(
465 "client_registration",
466 format!("Unsupported response type: {}", response_type),
467 ));
468 }
469 }
470 }
471
472 if let Some(auth_method) = &request.token_endpoint_auth_method
474 && !self.config.allowed_auth_methods.contains(auth_method)
475 {
476 return Err(AuthError::auth_method(
477 "client_registration",
478 format!("Unsupported authentication method: {}", auth_method),
479 ));
480 }
481
482 Ok(())
483 }
484
485 fn requires_client_secret(&self, request: &ClientRegistrationRequest) -> bool {
487 if !self.config.allow_public_clients {
488 return true;
489 }
490
491 !matches!(request.token_endpoint_auth_method.as_deref(), Some("none"))
492 }
493
494 fn generate_client_id(&self) -> String {
496 format!("client_{}", Uuid::new_v4().simple())
497 }
498
499 fn generate_client_secret(&self) -> String {
501 use rand::RngCore;
502 let mut rng = rand::rng();
503 let mut bytes = [0u8; 32];
504 rng.fill_bytes(&mut bytes);
505 general_purpose::URL_SAFE_NO_PAD.encode(bytes)
506 }
507
508 fn generate_registration_access_token(&self) -> String {
510 use rand::RngCore;
511 let mut rng = rand::rng();
512 let mut bytes = [0u8; 32];
513 rng.fill_bytes(&mut bytes);
514 general_purpose::URL_SAFE_NO_PAD.encode(bytes)
515 }
516
517 fn hash_secret(&self, secret: &str) -> Result<String> {
519 use sha2::{Digest, Sha256};
520 let mut hasher = Sha256::new();
521 hasher.update(secret.as_bytes());
522 Ok(format!("{:x}", hasher.finalize()))
523 }
524
525 fn verify_registration_token(&self, client: &RegisteredClient, token: &str) -> Result<bool> {
527 let token_hash = self.hash_secret(token)?;
528 Ok(client.registration_access_token_hash == token_hash)
529 }
530
531 fn is_valid_uri(&self, uri: &str) -> bool {
533 url::Url::parse(uri).is_ok()
534 }
535
536 async fn store_client(&self, client: &RegisteredClient) -> Result<()> {
538 let key = format!("client_registration:{}", client.client_id);
539 let value = serde_json::to_string(client)?;
540 self.storage.store_kv(&key, value.as_bytes(), None).await?;
541 Ok(())
542 }
543
544 async fn get_client(&self, client_id: &str) -> Result<RegisteredClient> {
546 let key = format!("client_registration:{}", client_id);
547 let value = match self.storage.get_kv(&key).await? {
548 Some(value) => value,
549 None => {
550 return Err(AuthError::auth_method(
551 "client_registration",
552 "Client not found",
553 ));
554 }
555 };
556 let value_str = String::from_utf8(value).map_err(|e| {
557 AuthError::Storage(StorageError::Serialization {
558 message: format!("Invalid UTF-8 data: {}", e),
559 })
560 })?;
561 let client: RegisteredClient = serde_json::from_str(&value_str)?;
562 Ok(client)
563 }
564
565 fn client_to_response(&self, client: &RegisteredClient) -> Result<ClientRegistrationResponse> {
567 Ok(ClientRegistrationResponse {
568 client_id: client.client_id.clone(),
569 client_secret: None, registration_access_token: "***".to_string(), registration_client_uri: format!(
572 "{}/register/{}",
573 self.config.base_url, client.client_id
574 ),
575 client_id_issued_at: Some(client.registered_at.timestamp()),
576 client_secret_expires_at: client.client_secret_expires_at.map(|dt| dt.timestamp()),
577 redirect_uris: client.metadata.redirect_uris.clone(),
578 token_endpoint_auth_method: client.metadata.token_endpoint_auth_method.clone(),
579 grant_types: client.metadata.grant_types.clone(),
580 response_types: client.metadata.response_types.clone(),
581 client_name: client.metadata.client_name.clone(),
582 client_uri: client.metadata.client_uri.clone(),
583 logo_uri: client.metadata.logo_uri.clone(),
584 scope: client.metadata.scope.clone(),
585 contacts: client.metadata.contacts.clone(),
586 tos_uri: client.metadata.tos_uri.clone(),
587 policy_uri: client.metadata.policy_uri.clone(),
588 jwks_uri: client.metadata.jwks_uri.clone(),
589 jwks: client.metadata.jwks.clone(),
590 software_id: client.metadata.software_id.clone(),
591 software_version: client.metadata.software_version.clone(),
592 additional_metadata: client.metadata.additional_metadata.clone(),
593 })
594 }
595}
596
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::MemoryStorage;

    /// Registers a confidential client (`client_secret_basic`) with full metadata.
    /// Checks that it receives a client ID, a secret and a registration access token,
    /// and that its metadata is echoed back.
    #[tokio::test]
    async fn test_client_registration() {
        let storage = Arc::new(MemoryStorage::new());
        let config = ClientRegistrationConfig::default();
        let manager = ClientRegistrationManager::new(config, storage);

        let request = ClientRegistrationRequest {
            redirect_uris: Some(vec!["https://client.example.com/callback".to_string()]),
            token_endpoint_auth_method: Some("client_secret_basic".to_string()),
            grant_types: Some(vec!["authorization_code".to_string()]),
            response_types: Some(vec!["code".to_string()]),
            client_name: Some("Test Client".to_string()),
            client_uri: Some("https://client.example.com".to_string()),
            logo_uri: Some("https://client.example.com/logo.png".to_string()),
            scope: Some("read write".to_string()),
            contacts: Some(vec!["admin@client.example.com".to_string()]),
            tos_uri: Some("https://client.example.com/tos".to_string()),
            policy_uri: Some("https://client.example.com/privacy".to_string()),
            jwks_uri: Some("https://client.example.com/jwks".to_string()),
            jwks: None,
            software_id: Some("test-client".to_string()),
            software_version: Some("1.0.0".to_string()),
            additional_metadata: HashMap::new(),
        };

        // No client IP is supplied.
        let response = manager
            .register_client(request.clone(), None)
            .await
            .unwrap();

        assert!(!response.client_id.is_empty());
        assert!(response.client_secret.is_some());
        assert!(!response.registration_access_token.is_empty());
        assert_eq!(response.client_name, Some("Test Client".to_string()));
        assert_eq!(
            response.redirect_uris,
            Some(vec!["https://client.example.com/callback".to_string()])
        );
    }

    /// Registers a public client (`token_endpoint_auth_method: "none"`) under the
    /// default config, which sets `allow_public_clients: true`. Checks that no secret
    /// is issued but a registration access token still is.
    #[tokio::test]
    async fn test_public_client_registration() {
        let storage = Arc::new(MemoryStorage::new());
        let config = ClientRegistrationConfig::default();
        let manager = ClientRegistrationManager::new(config, storage);

        let request = ClientRegistrationRequest {
            redirect_uris: Some(vec!["https://client.example.com/callback".to_string()]),
            token_endpoint_auth_method: Some("none".to_string()),
            grant_types: Some(vec!["authorization_code".to_string()]),
            response_types: Some(vec!["code".to_string()]),
            client_name: Some("Public Client".to_string()),
            client_uri: None,
            logo_uri: None,
            scope: Some("read".to_string()),
            contacts: None,
            tos_uri: None,
            policy_uri: None,
            jwks_uri: None,
            jwks: None,
            software_id: None,
            software_version: None,
            additional_metadata: HashMap::new(),
        };

        let response = manager.register_client(request, None).await.unwrap();

        assert!(!response.client_id.is_empty());
        assert!(response.client_secret.is_none());
        assert!(!response.registration_access_token.is_empty());
        assert_eq!(response.client_name, Some("Public Client".to_string()));
    }
}