1use crate::errors::{AuthError, Result, StorageError};
17use crate::storage::AuthStorage;
18use base64::{Engine as _, engine::general_purpose};
19use chrono::{DateTime, Duration, Utc};
20use governor;
21use serde::{Deserialize, Serialize};
22use serde_json::Value;
23use std::collections::HashMap;
24use std::sync::Arc;
25use url;
26use uuid::Uuid;
27
/// Client metadata supplied when registering or updating an OAuth client.
///
/// Field names match the RFC 7591 client-metadata parameter names and are
/// serialized as-is (there are no `rename` attributes). Unset (`None`) fields
/// serialize as JSON `null`, because no `skip_serializing_if` is configured.
///
/// `ClientRegistrationManager` only checks `redirect_uris`, `grant_types`,
/// `response_types` and `token_endpoint_auth_method`. The remaining fields are
/// stored and echoed back without validation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClientRegistrationRequest {
    /// Redirect URIs for the client. The manager rejects lists longer than
    /// `ClientRegistrationConfig::max_redirect_uris` and any URI that fails
    /// its URI check.
    pub redirect_uris: Option<Vec<String>>,

    /// Token endpoint authentication method, e.g. `"none"` (what
    /// `ClientRegistrationRequestBuilder::public_client` sets). Must be listed
    /// in `ClientRegistrationConfig::allowed_auth_methods`. When absent, the
    /// manager treats the client as needing a client secret.
    pub token_endpoint_auth_method: Option<String>,

    /// Requested grant types; each must appear in
    /// `ClientRegistrationConfig::allowed_grant_types`.
    pub grant_types: Option<Vec<String>>,

    /// Requested response types; each must appear in
    /// `ClientRegistrationConfig::allowed_response_types`.
    pub response_types: Option<Vec<String>>,

    /// Client name as supplied by the registrant (not validated).
    pub client_name: Option<String>,

    /// Client home-page URI (not validated).
    pub client_uri: Option<String>,

    /// Logo URI (not validated).
    pub logo_uri: Option<String>,

    /// Requested scope string, kept verbatim (not parsed or validated here).
    pub scope: Option<String>,

    /// Contact entries for the client (not validated).
    pub contacts: Option<Vec<String>>,

    /// Terms-of-service URI (not validated).
    pub tos_uri: Option<String>,

    /// Privacy-policy URI (not validated).
    pub policy_uri: Option<String>,

    /// URI of the client's JSON Web Key Set (not validated or fetched here).
    pub jwks_uri: Option<String>,

    /// Inline JSON Web Key Set, kept as raw JSON (not parsed or validated here).
    pub jwks: Option<Value>,

    /// Software identifier; set together with `software_version` by
    /// `ClientRegistrationRequestBuilder::software`.
    pub software_id: Option<String>,

    /// Software version; set together with `software_id` by
    /// `ClientRegistrationRequestBuilder::software`.
    pub software_version: Option<String>,

    /// Any additional JSON members. Because of `#[serde(flatten)]`, unknown
    /// keys are collected here on deserialization and written back as
    /// top-level members on serialization.
    #[serde(flatten)]
    pub additional_metadata: HashMap<String, Value>,
}
80
81impl ClientRegistrationRequest {
82 pub fn builder(redirect_uri: impl Into<String>) -> ClientRegistrationRequestBuilder {
84 ClientRegistrationRequestBuilder {
85 redirect_uris: Some(vec![redirect_uri.into()]),
86 token_endpoint_auth_method: None,
87 grant_types: None,
88 response_types: None,
89 client_name: None,
90 client_uri: None,
91 logo_uri: None,
92 scope: None,
93 contacts: None,
94 tos_uri: None,
95 policy_uri: None,
96 jwks_uri: None,
97 jwks: None,
98 software_id: None,
99 software_version: None,
100 additional_metadata: HashMap::new(),
101 }
102 }
103}
104
105pub struct ClientRegistrationRequestBuilder {
107 redirect_uris: Option<Vec<String>>,
108 token_endpoint_auth_method: Option<String>,
109 grant_types: Option<Vec<String>>,
110 response_types: Option<Vec<String>>,
111 client_name: Option<String>,
112 client_uri: Option<String>,
113 logo_uri: Option<String>,
114 scope: Option<String>,
115 contacts: Option<Vec<String>>,
116 tos_uri: Option<String>,
117 policy_uri: Option<String>,
118 jwks_uri: Option<String>,
119 jwks: Option<Value>,
120 software_id: Option<String>,
121 software_version: Option<String>,
122 additional_metadata: HashMap<String, Value>,
123}
124
125impl ClientRegistrationRequestBuilder {
126 pub fn redirect_uris<I, S>(mut self, redirect_uris: I) -> Self
128 where
129 I: IntoIterator<Item = S>,
130 S: Into<String>,
131 {
132 self.redirect_uris = Some(redirect_uris.into_iter().map(Into::into).collect());
133 self
134 }
135
136 pub fn auth_method(mut self, method: impl Into<String>) -> Self {
138 self.token_endpoint_auth_method = Some(method.into());
139 self
140 }
141
142 pub fn public_client(self) -> Self {
144 self.auth_method("none")
145 }
146
147 pub fn grant_types<I, S>(mut self, grant_types: I) -> Self
149 where
150 I: IntoIterator<Item = S>,
151 S: Into<String>,
152 {
153 self.grant_types = Some(grant_types.into_iter().map(Into::into).collect());
154 self
155 }
156
157 pub fn response_types<I, S>(mut self, response_types: I) -> Self
159 where
160 I: IntoIterator<Item = S>,
161 S: Into<String>,
162 {
163 self.response_types = Some(response_types.into_iter().map(Into::into).collect());
164 self
165 }
166
167 pub fn client_name(mut self, client_name: impl Into<String>) -> Self {
169 self.client_name = Some(client_name.into());
170 self
171 }
172
173 pub fn client_uri(mut self, client_uri: impl Into<String>) -> Self {
175 self.client_uri = Some(client_uri.into());
176 self
177 }
178
179 pub fn logo_uri(mut self, logo_uri: impl Into<String>) -> Self {
181 self.logo_uri = Some(logo_uri.into());
182 self
183 }
184
185 pub fn scope(mut self, scope: impl Into<String>) -> Self {
187 self.scope = Some(scope.into());
188 self
189 }
190
191 pub fn contacts<I, S>(mut self, contacts: I) -> Self
193 where
194 I: IntoIterator<Item = S>,
195 S: Into<String>,
196 {
197 self.contacts = Some(contacts.into_iter().map(Into::into).collect());
198 self
199 }
200
201 pub fn tos_uri(mut self, tos_uri: impl Into<String>) -> Self {
203 self.tos_uri = Some(tos_uri.into());
204 self
205 }
206
207 pub fn policy_uri(mut self, policy_uri: impl Into<String>) -> Self {
209 self.policy_uri = Some(policy_uri.into());
210 self
211 }
212
213 pub fn jwks_uri(mut self, jwks_uri: impl Into<String>) -> Self {
215 self.jwks_uri = Some(jwks_uri.into());
216 self
217 }
218
219 pub fn jwks(mut self, jwks: Value) -> Self {
221 self.jwks = Some(jwks);
222 self
223 }
224
225 pub fn software(mut self, software_id: impl Into<String>, software_version: impl Into<String>) -> Self {
227 self.software_id = Some(software_id.into());
228 self.software_version = Some(software_version.into());
229 self
230 }
231
232 pub fn metadata(mut self, key: impl Into<String>, value: Value) -> Self {
234 self.additional_metadata.insert(key.into(), value);
235 self
236 }
237
238 pub fn build(self) -> ClientRegistrationRequest {
240 ClientRegistrationRequest {
241 redirect_uris: self.redirect_uris,
242 token_endpoint_auth_method: self.token_endpoint_auth_method,
243 grant_types: self.grant_types,
244 response_types: self.response_types,
245 client_name: self.client_name,
246 client_uri: self.client_uri,
247 logo_uri: self.logo_uri,
248 scope: self.scope,
249 contacts: self.contacts,
250 tos_uri: self.tos_uri,
251 policy_uri: self.policy_uri,
252 jwks_uri: self.jwks_uri,
253 jwks: self.jwks,
254 software_id: self.software_id,
255 software_version: self.software_version,
256 additional_metadata: self.additional_metadata,
257 }
258 }
259}
260
/// Client information returned by `ClientRegistrationManager` from
/// registration, read and update calls.
///
/// Unset (`None`) fields serialize as JSON `null` (no `skip_serializing_if`).
/// Unknown members are flattened into `additional_metadata`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClientRegistrationResponse {
    /// Server-generated client identifier.
    pub client_id: String,

    /// Plaintext client secret when one is issued in this response, otherwise
    /// `None`. Only a hash of the secret is persisted, so the plaintext cannot
    /// be recovered later.
    pub client_secret: Option<String>,

    /// Token for the read/update/delete operations. Its plaintext is only
    /// known when it is generated at registration; afterwards only its hash is
    /// stored.
    pub registration_access_token: String,

    /// Management URI, built as `{base_url}/register/{client_id}`.
    pub registration_client_uri: String,

    /// Registration time as a Unix timestamp in seconds.
    pub client_id_issued_at: Option<i64>,

    /// Secret expiry as a Unix timestamp in seconds. `None` when there is no
    /// secret or no expiration is configured.
    pub client_secret_expires_at: Option<i64>,

    /// Registered redirect URIs (echoed from the stored metadata).
    pub redirect_uris: Option<Vec<String>>,

    /// Registered token endpoint authentication method.
    pub token_endpoint_auth_method: Option<String>,

    /// Registered grant types.
    pub grant_types: Option<Vec<String>>,

    /// Registered response types.
    pub response_types: Option<Vec<String>>,

    /// Registered client name.
    pub client_name: Option<String>,

    /// Registered client home-page URI.
    pub client_uri: Option<String>,

    /// Registered logo URI.
    pub logo_uri: Option<String>,

    /// Registered scope string.
    pub scope: Option<String>,

    /// Registered contacts.
    pub contacts: Option<Vec<String>>,

    /// Registered terms-of-service URI.
    pub tos_uri: Option<String>,

    /// Registered privacy-policy URI.
    pub policy_uri: Option<String>,

    /// Registered JWKS URI.
    pub jwks_uri: Option<String>,

    /// Registered inline JWKS (raw JSON).
    pub jwks: Option<Value>,

    /// Registered software identifier.
    pub software_id: Option<String>,

    /// Registered software version.
    pub software_version: Option<String>,

    /// Extra metadata, flattened into top-level JSON members.
    #[serde(flatten)]
    pub additional_metadata: HashMap<String, Value>,
}
331
/// Server-side record of a registered client, persisted as JSON by
/// `ClientRegistrationManager` under the key `client_registration:{client_id}`.
///
/// Secrets and tokens are never stored in plaintext. Only their hex-encoded
/// SHA-256 digests are kept.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RegisteredClient {
    /// Server-generated client identifier.
    pub client_id: String,

    /// Hex SHA-256 digest of the client secret; `None` when no secret was issued.
    pub client_secret_hash: Option<String>,

    /// Hex SHA-256 digest of the registration access token. It is compared in
    /// constant time when the token is verified.
    pub registration_access_token_hash: String,

    /// Metadata as last submitted through registration or update.
    pub metadata: ClientRegistrationRequest,

    /// When the client was registered (UTC).
    pub registered_at: DateTime<Utc>,

    /// When the record was last written by a register/update call (UTC).
    pub updated_at: DateTime<Utc>,

    /// When the client secret expires; `None` if there is no secret or no
    /// expiration is configured.
    pub client_secret_expires_at: Option<DateTime<Utc>>,

    /// Set to `true` at registration.
    /// NOTE(review): nothing in this file reads or clears it — confirm where
    /// deactivation is enforced.
    pub is_active: bool,
}
359
/// Policy and limits applied by `ClientRegistrationManager`.
#[derive(Debug, Clone)]
pub struct ClientRegistrationConfig {
    /// Base URL used to build each client's `registration_client_uri`
    /// (`{base_url}/register/{client_id}`).
    pub base_url: String,

    /// NOTE(review): not read anywhere in this file — confirm where (or
    /// whether) authenticated registration is enforced.
    pub require_authentication: bool,

    /// Lifetime of newly issued client secrets, in seconds; `None` means
    /// secrets do not expire.
    pub default_secret_expiration: Option<i64>,

    /// Maximum number of redirect URIs accepted per request.
    pub max_redirect_uris: usize,

    /// Grant types a client may request.
    pub allowed_grant_types: Vec<String>,

    /// Response types a client may request.
    pub allowed_response_types: Vec<String>,

    /// Token endpoint authentication methods a client may request.
    pub allowed_auth_methods: Vec<String>,

    /// When `false`, every successfully registered client is issued a
    /// client secret.
    pub allow_public_clients: bool,

    /// Registration limit fed to the manager's rate limiter (`0` is clamped
    /// to `1`).
    /// NOTE(review): despite the name, the manager's limiter is a single
    /// unkeyed (global) limiter, not one bucket per IP — confirm intent.
    pub rate_limit_per_ip: u32,

    /// Window intended for `rate_limit_per_ip`.
    /// NOTE(review): confirm that `ClientRegistrationManager::new` actually
    /// derives the quota from this value.
    pub rate_limit_window: std::time::Duration,
}
391
392impl Default for ClientRegistrationConfig {
393 fn default() -> Self {
394 Self {
395 base_url: "https://auth.example.com".to_string(),
396 require_authentication: false,
397 default_secret_expiration: Some(86400 * 365), max_redirect_uris: 10,
399 allowed_grant_types: vec![
400 "authorization_code".to_string(),
401 "client_credentials".to_string(),
402 "refresh_token".to_string(),
403 "urn:ietf:params:oauth:grant-type:device_code".to_string(),
404 ],
405 allowed_response_types: vec![
406 "code".to_string(),
407 "token".to_string(),
408 "id_token".to_string(),
409 ],
410 allowed_auth_methods: vec![
411 "client_secret_basic".to_string(),
412 "client_secret_post".to_string(),
413 "private_key_jwt".to_string(),
414 "none".to_string(),
415 ],
416 allow_public_clients: true,
417 rate_limit_per_ip: 10,
418 rate_limit_window: std::time::Duration::from_secs(3600),
419 }
420 }
421}
422
/// Handles client registration and the read/update/delete operations on
/// registered clients, persisting records through an [`AuthStorage`]
/// key-value store.
pub struct ClientRegistrationManager {
    /// Registration policy and limits.
    config: ClientRegistrationConfig,
    /// Key-value backend holding JSON-serialized `RegisteredClient` records.
    storage: Arc<dyn AuthStorage>,
    /// A single unkeyed (`NotKeyed`) limiter shared by all callers. It is not
    /// per IP, despite `ClientRegistrationConfig::rate_limit_per_ip`'s name.
    rate_limiter: Arc<
        governor::RateLimiter<
            governor::state::direct::NotKeyed,
            governor::state::InMemoryState,
            governor::clock::DefaultClock,
        >,
    >,
}
435
436impl ClientRegistrationManager {
437 pub fn new(config: ClientRegistrationConfig, storage: Arc<dyn AuthStorage>) -> Self {
439 let quota = governor::Quota::per_hour(
440 std::num::NonZeroU32::new(config.rate_limit_per_ip.max(1))
441 .expect("clamped to at least 1"),
442 );
443 let rate_limiter = Arc::new(governor::RateLimiter::direct(quota));
444
445 Self {
446 config,
447 storage,
448 rate_limiter,
449 }
450 }
451
452 pub async fn register_client(
454 &self,
455 request: ClientRegistrationRequest,
456 client_ip: Option<std::net::IpAddr>,
457 ) -> Result<ClientRegistrationResponse> {
458 if let Some(_ip) = client_ip
460 && self.rate_limiter.check().is_err()
461 {
462 return Err(AuthError::rate_limit(
463 "Client registration rate limit exceeded",
464 ));
465 }
466
467 self.validate_registration_request(&request)?;
469
470 let client_id = self.generate_client_id();
472 let (client_secret, client_secret_hash) = if self.requires_client_secret(&request) {
473 let secret = self.generate_client_secret();
474 let hash = self.hash_secret(&secret)?;
475 (Some(secret), Some(hash))
476 } else {
477 (None, None)
478 };
479
480 let registration_access_token = self.generate_registration_access_token();
482 let registration_access_token_hash = self.hash_secret(®istration_access_token)?;
483
484 let client_secret_expires_at = if client_secret.is_some() {
486 self.config
487 .default_secret_expiration
488 .map(|seconds| Utc::now() + Duration::seconds(seconds))
489 } else {
490 None
491 };
492
493 let registered_client = RegisteredClient {
495 client_id: client_id.clone(),
496 client_secret_hash,
497 registration_access_token_hash,
498 metadata: request.clone(),
499 registered_at: Utc::now(),
500 updated_at: Utc::now(),
501 client_secret_expires_at,
502 is_active: true,
503 };
504
505 self.store_client(®istered_client).await?;
507
508 let response = ClientRegistrationResponse {
510 client_id: client_id.clone(),
511 client_secret,
512 registration_access_token,
513 registration_client_uri: format!("{}/register/{}", self.config.base_url, client_id),
514 client_id_issued_at: Some(Utc::now().timestamp()),
515 client_secret_expires_at: client_secret_expires_at.map(|dt| dt.timestamp()),
516 redirect_uris: request.redirect_uris,
517 token_endpoint_auth_method: request.token_endpoint_auth_method,
518 grant_types: request.grant_types,
519 response_types: request.response_types,
520 client_name: request.client_name,
521 client_uri: request.client_uri,
522 logo_uri: request.logo_uri,
523 scope: request.scope,
524 contacts: request.contacts,
525 tos_uri: request.tos_uri,
526 policy_uri: request.policy_uri,
527 jwks_uri: request.jwks_uri,
528 jwks: request.jwks,
529 software_id: request.software_id,
530 software_version: request.software_version,
531 additional_metadata: request.additional_metadata,
532 };
533
534 Ok(response)
535 }
536
537 pub async fn read_client(
539 &self,
540 client_id: &str,
541 registration_access_token: &str,
542 ) -> Result<ClientRegistrationResponse> {
543 let client = self.get_client(client_id).await?;
544
545 if !self.verify_registration_token(&client, registration_access_token)? {
547 return Err(AuthError::auth_method(
548 "client_registration",
549 "Invalid registration access token",
550 ));
551 }
552
553 self.client_to_response(&client)
554 }
555
556 pub async fn update_client(
558 &self,
559 client_id: &str,
560 registration_access_token: &str,
561 request: ClientRegistrationRequest,
562 ) -> Result<ClientRegistrationResponse> {
563 let mut client = self.get_client(client_id).await?;
564
565 if !self.verify_registration_token(&client, registration_access_token)? {
567 return Err(AuthError::auth_method(
568 "client_registration",
569 "Invalid registration access token",
570 ));
571 }
572
573 self.validate_registration_request(&request)?;
575
576 client.metadata = request;
578 client.updated_at = Utc::now();
579
580 self.store_client(&client).await?;
582
583 self.client_to_response(&client)
584 }
585
586 pub async fn delete_client(
588 &self,
589 client_id: &str,
590 registration_access_token: &str,
591 ) -> Result<()> {
592 let client = self.get_client(client_id).await?;
593
594 if !self.verify_registration_token(&client, registration_access_token)? {
596 return Err(AuthError::auth_method(
597 "client_registration",
598 "Invalid registration access token",
599 ));
600 }
601
602 let key = format!("client_registration:{}", client_id);
604 self.storage.delete_kv(&key).await?;
605
606 Ok(())
607 }
608
609 fn validate_registration_request(&self, request: &ClientRegistrationRequest) -> Result<()> {
611 if let Some(redirect_uris) = &request.redirect_uris {
613 if redirect_uris.len() > self.config.max_redirect_uris {
614 return Err(AuthError::auth_method(
615 "client_registration",
616 "Too many redirect URIs",
617 ));
618 }
619
620 for uri in redirect_uris {
621 if !self.is_valid_uri(uri) {
622 return Err(AuthError::auth_method(
623 "client_registration",
624 format!("Invalid redirect URI: {}", uri),
625 ));
626 }
627 }
628 }
629
630 if let Some(grant_types) = &request.grant_types {
632 for grant_type in grant_types {
633 if !self.config.allowed_grant_types.contains(grant_type) {
634 return Err(AuthError::auth_method(
635 "client_registration",
636 format!("Unsupported grant type: {}", grant_type),
637 ));
638 }
639 }
640 }
641
642 if let Some(response_types) = &request.response_types {
644 for response_type in response_types {
645 if !self.config.allowed_response_types.contains(response_type) {
646 return Err(AuthError::auth_method(
647 "client_registration",
648 format!("Unsupported response type: {}", response_type),
649 ));
650 }
651 }
652 }
653
654 if let Some(auth_method) = &request.token_endpoint_auth_method
656 && !self.config.allowed_auth_methods.contains(auth_method)
657 {
658 return Err(AuthError::auth_method(
659 "client_registration",
660 format!("Unsupported authentication method: {}", auth_method),
661 ));
662 }
663
664 Ok(())
665 }
666
667 fn requires_client_secret(&self, request: &ClientRegistrationRequest) -> bool {
669 if !self.config.allow_public_clients {
670 return true;
671 }
672
673 !matches!(request.token_endpoint_auth_method.as_deref(), Some("none"))
674 }
675
676 fn generate_client_id(&self) -> String {
678 format!("client_{}", Uuid::new_v4().simple())
679 }
680
681 fn generate_client_secret(&self) -> String {
683 use rand::Rng;
684 let mut rng = rand::rng();
685 let mut bytes = [0u8; 32];
686 rng.fill_bytes(&mut bytes);
687 general_purpose::URL_SAFE_NO_PAD.encode(bytes)
688 }
689
690 fn generate_registration_access_token(&self) -> String {
692 use rand::Rng;
693 let mut rng = rand::rng();
694 let mut bytes = [0u8; 32];
695 rng.fill_bytes(&mut bytes);
696 general_purpose::URL_SAFE_NO_PAD.encode(bytes)
697 }
698
699 fn hash_secret(&self, secret: &str) -> Result<String> {
701 use sha2::{Digest, Sha256};
702 let mut hasher = Sha256::new();
703 hasher.update(secret.as_bytes());
704 Ok(format!("{:x}", hasher.finalize()))
705 }
706
707 fn verify_registration_token(&self, client: &RegisteredClient, token: &str) -> Result<bool> {
709 use subtle::ConstantTimeEq;
710 let token_hash = self.hash_secret(token)?;
711 Ok(client.registration_access_token_hash.as_bytes().ct_eq(token_hash.as_bytes()).into())
712 }
713
714 fn is_valid_uri(&self, uri: &str) -> bool {
718 let parsed = match url::Url::parse(uri) {
719 Ok(u) => u,
720 Err(_) => return false,
721 };
722 match parsed.scheme() {
723 "https" => true,
724 "http" => {
725 matches!(parsed.host_str(), Some("localhost" | "127.0.0.1" | "[::1]"))
727 }
728 _ => false,
729 }
730 }
731
732 async fn store_client(&self, client: &RegisteredClient) -> Result<()> {
734 let key = format!("client_registration:{}", client.client_id);
735 let value = serde_json::to_string(client)?;
736 self.storage.store_kv(&key, value.as_bytes(), None).await?;
737 Ok(())
738 }
739
740 async fn get_client(&self, client_id: &str) -> Result<RegisteredClient> {
742 let key = format!("client_registration:{}", client_id);
743 let value = match self.storage.get_kv(&key).await? {
744 Some(value) => value,
745 None => {
746 return Err(AuthError::auth_method(
747 "client_registration",
748 "Client not found",
749 ));
750 }
751 };
752 let value_str = String::from_utf8(value).map_err(|e| {
753 AuthError::Storage(StorageError::Serialization {
754 message: format!("Invalid UTF-8 data: {}", e),
755 })
756 })?;
757 let client: RegisteredClient = serde_json::from_str(&value_str)?;
758 Ok(client)
759 }
760
761 fn client_to_response(&self, client: &RegisteredClient) -> Result<ClientRegistrationResponse> {
763 Ok(ClientRegistrationResponse {
764 client_id: client.client_id.clone(),
765 client_secret: None, registration_access_token: "***".to_string(), registration_client_uri: format!(
768 "{}/register/{}",
769 self.config.base_url, client.client_id
770 ),
771 client_id_issued_at: Some(client.registered_at.timestamp()),
772 client_secret_expires_at: client.client_secret_expires_at.map(|dt| dt.timestamp()),
773 redirect_uris: client.metadata.redirect_uris.clone(),
774 token_endpoint_auth_method: client.metadata.token_endpoint_auth_method.clone(),
775 grant_types: client.metadata.grant_types.clone(),
776 response_types: client.metadata.response_types.clone(),
777 client_name: client.metadata.client_name.clone(),
778 client_uri: client.metadata.client_uri.clone(),
779 logo_uri: client.metadata.logo_uri.clone(),
780 scope: client.metadata.scope.clone(),
781 contacts: client.metadata.contacts.clone(),
782 tos_uri: client.metadata.tos_uri.clone(),
783 policy_uri: client.metadata.policy_uri.clone(),
784 jwks_uri: client.metadata.jwks_uri.clone(),
785 jwks: client.metadata.jwks.clone(),
786 software_id: client.metadata.software_id.clone(),
787 software_version: client.metadata.software_version.clone(),
788 additional_metadata: client.metadata.additional_metadata.clone(),
789 })
790 }
791}
792
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::MemoryStorage;

    /// Builds a manager over fresh in-memory storage.
    fn manager_with(config: ClientRegistrationConfig) -> ClientRegistrationManager {
        ClientRegistrationManager::new(config, Arc::new(MemoryStorage::new()))
    }

    #[tokio::test]
    async fn test_client_registration() {
        let storage = Arc::new(MemoryStorage::new());
        let config = ClientRegistrationConfig::default();
        let manager = ClientRegistrationManager::new(config, storage);

        let request = ClientRegistrationRequest::builder("https://client.example.com/callback")
            .auth_method("client_secret_basic")
            .grant_types(["authorization_code"])
            .response_types(["code"])
            .client_name("Test Client")
            .client_uri("https://client.example.com")
            .logo_uri("https://client.example.com/logo.png")
            .scope("read write")
            .contacts(["admin@client.example.com"])
            .tos_uri("https://client.example.com/tos")
            .policy_uri("https://client.example.com/privacy")
            .jwks_uri("https://client.example.com/jwks")
            .software("test-client", "1.0.0")
            .build();

        let response = manager
            .register_client(request.clone(), None)
            .await
            .unwrap();

        assert!(!response.client_id.is_empty());
        assert!(response.client_secret.is_some());
        assert!(!response.registration_access_token.is_empty());
        assert_eq!(response.client_name, Some("Test Client".to_string()));
        assert_eq!(
            response.redirect_uris,
            Some(vec!["https://client.example.com/callback".to_string()])
        );
    }

    #[tokio::test]
    async fn test_public_client_registration() {
        let storage = Arc::new(MemoryStorage::new());
        let config = ClientRegistrationConfig::default();
        let manager = ClientRegistrationManager::new(config, storage);

        let request = ClientRegistrationRequest::builder("https://client.example.com/callback")
            .public_client()
            .grant_types(["authorization_code"])
            .response_types(["code"])
            .client_name("Public Client")
            .scope("read")
            .build();

        let response = manager.register_client(request, None).await.unwrap();

        assert!(!response.client_id.is_empty());
        assert!(response.client_secret.is_none());
        assert!(!response.registration_access_token.is_empty());
        assert_eq!(response.client_name, Some("Public Client".to_string()));
    }

    #[tokio::test]
    async fn test_read_update_delete_lifecycle() {
        let manager = manager_with(ClientRegistrationConfig::default());
        let registered = manager
            .register_client(
                ClientRegistrationRequest::builder("https://client.example.com/callback")
                    .client_name("Lifecycle Client")
                    .build(),
                None,
            )
            .await
            .unwrap();
        let client_id = registered.client_id.as_str();
        let token = registered.registration_access_token.as_str();

        // A wrong token is rejected by every management operation.
        assert!(manager.read_client(client_id, "wrong-token").await.is_err());
        assert!(
            manager
                .update_client(
                    client_id,
                    "wrong-token",
                    ClientRegistrationRequest::builder("https://client.example.com/callback")
                        .build(),
                )
                .await
                .is_err()
        );
        assert!(manager.delete_client(client_id, "wrong-token").await.is_err());

        let read = manager.read_client(client_id, token).await.unwrap();
        assert_eq!(read.client_id, registered.client_id);
        assert_eq!(read.client_name, Some("Lifecycle Client".to_string()));
        assert!(read.client_secret.is_none());

        let updated = manager
            .update_client(
                client_id,
                token,
                ClientRegistrationRequest::builder("https://client.example.com/callback")
                    .client_name("Renamed Client")
                    .build(),
            )
            .await
            .unwrap();
        assert_eq!(updated.client_name, Some("Renamed Client".to_string()));

        manager.delete_client(client_id, token).await.unwrap();
        assert!(manager.read_client(client_id, token).await.is_err());
    }

    #[tokio::test]
    async fn test_redirect_uri_validation() {
        let manager = manager_with(ClientRegistrationConfig::default());

        for bad in [
            "http://client.example.com/callback",
            "ftp://client.example.com/callback",
            "not a uri",
        ] {
            let request = ClientRegistrationRequest::builder(bad).build();
            assert!(
                manager.register_client(request, None).await.is_err(),
                "{bad} should be rejected"
            );
        }

        let too_many = ClientRegistrationRequest::builder("https://client.example.com/cb")
            .redirect_uris((0..11).map(|i| format!("https://client.example.com/cb{i}")))
            .build();
        assert!(manager.register_client(too_many, None).await.is_err());

        let loopback = ClientRegistrationRequest::builder("http://localhost:8080/callback").build();
        assert!(manager.register_client(loopback, None).await.is_ok());
    }

    #[tokio::test]
    async fn test_rejects_unsupported_grant_type() {
        let manager = manager_with(ClientRegistrationConfig::default());
        let request = ClientRegistrationRequest::builder("https://client.example.com/callback")
            .grant_types(["password"])
            .build();
        assert!(manager.register_client(request, None).await.is_err());
    }

    #[tokio::test]
    async fn test_rate_limit_applies_to_requests_with_ip() {
        let config = ClientRegistrationConfig {
            rate_limit_per_ip: 1,
            ..ClientRegistrationConfig::default()
        };
        let manager = manager_with(config);
        let ip = Some(std::net::IpAddr::from([127, 0, 0, 1]));
        let request = ClientRegistrationRequest::builder("https://client.example.com/callback")
            .build();

        assert!(manager.register_client(request.clone(), ip).await.is_ok());
        assert!(manager.register_client(request, ip).await.is_err());
    }

    #[test]
    fn test_client_registration_request_builder() {
        let request = ClientRegistrationRequest::builder("https://client.example.com/callback")
            .redirect_uris([
                "https://client.example.com/callback",
                "https://client.example.com/alt",
            ])
            .auth_method("private_key_jwt")
            .grant_types(["authorization_code", "refresh_token"])
            .response_types(["code"])
            .client_name("Builder Client")
            .metadata("tenant", serde_json::json!("acme"))
            .build();

        assert_eq!(request.redirect_uris.as_ref().map(Vec::len), Some(2));
        assert_eq!(
            request.token_endpoint_auth_method.as_deref(),
            Some("private_key_jwt")
        );
        assert_eq!(request.additional_metadata["tenant"], "acme");
    }
}