1use crate::errors::{AuthError, Result};
56use crate::server::oidc::oidc_error_extensions::{
57 OidcErrorCode, OidcErrorManager, OidcErrorResponse,
58};
59use crate::storage::AuthStorage;
60use serde::{Deserialize, Serialize};
61use std::collections::HashMap;
62use std::sync::Arc;
63use tracing::warn;
64use uuid::Uuid;
65
/// Authorization-request parameters for a registration flow (`prompt=create`).
///
/// Build one with [`RegistrationRequest::builder`] and hand it to
/// [`RegistrationManager::initiate_registration`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RegistrationRequest {
    /// Client identifier from the authorization request.
    pub client_id: String,
    /// Redirect URI from the authorization request.
    pub redirect_uri: String,
    /// Requested scopes, as received.
    pub scope: String,
    /// Requested `response_type`, as received.
    pub response_type: String,
    /// Client `state`; passed along when `initiate_registration` builds error responses.
    pub state: Option<String>,
    /// `nonce` from the authorization request (not consumed by `initiate_registration`).
    pub nonce: Option<String>,

    /// Space-separated `prompt` values; must contain `create` for registration to start.
    pub prompt: Option<String>,
    /// Pre-fills the session's email (when it contains `@`) or otherwise its preferred username.
    pub login_hint: Option<String>,
    /// Preferred UI locales (not consumed by `initiate_registration`).
    pub ui_locales: Option<String>,
    /// JSON object whose entries are merged into the session's custom fields.
    pub registration_metadata: Option<String>,
    /// Requested `claims` parameter (not consumed by `initiate_registration`).
    pub claims: Option<String>,
}
89
90impl RegistrationRequest {
91 pub fn builder(
93 client_id: impl Into<String>,
94 redirect_uri: impl Into<String>,
95 scope: impl Into<String>,
96 response_type: impl Into<String>,
97 ) -> RegistrationRequestBuilder {
98 RegistrationRequestBuilder {
99 client_id: client_id.into(),
100 redirect_uri: redirect_uri.into(),
101 scope: scope.into(),
102 response_type: response_type.into(),
103 state: None,
104 nonce: None,
105 prompt: None,
106 login_hint: None,
107 ui_locales: None,
108 registration_metadata: None,
109 claims: None,
110 }
111 }
112}
113
/// Fluent builder for [`RegistrationRequest`], created by [`RegistrationRequest::builder`].
///
/// Each setter consumes and returns the builder, so an unused chain is almost certainly a
/// bug; `#[must_use]` turns that mistake into a compiler warning.
#[derive(Debug, Clone)]
#[must_use = "a RegistrationRequestBuilder does nothing until `build()` is called"]
pub struct RegistrationRequestBuilder {
    /// Client identifier (mandatory).
    client_id: String,
    /// Redirect URI (mandatory).
    redirect_uri: String,
    /// Requested scopes (mandatory).
    scope: String,
    /// Requested `response_type` (mandatory).
    response_type: String,
    /// Optional client `state`.
    state: Option<String>,
    /// Optional `nonce`.
    nonce: Option<String>,
    /// Optional space-separated `prompt` values.
    prompt: Option<String>,
    /// Optional `login_hint`.
    login_hint: Option<String>,
    /// Optional `ui_locales`.
    ui_locales: Option<String>,
    /// Optional registration metadata (JSON object string).
    registration_metadata: Option<String>,
    /// Optional `claims` parameter.
    claims: Option<String>,
}
128
129impl RegistrationRequestBuilder {
130 pub fn state(mut self, state: impl Into<String>) -> Self {
132 self.state = Some(state.into());
133 self
134 }
135
136 pub fn nonce(mut self, nonce: impl Into<String>) -> Self {
138 self.nonce = Some(nonce.into());
139 self
140 }
141
142 pub fn prompt(mut self, prompt: impl Into<String>) -> Self {
144 self.prompt = Some(prompt.into());
145 self
146 }
147
148 pub fn login_hint(mut self, login_hint: impl Into<String>) -> Self {
150 self.login_hint = Some(login_hint.into());
151 self
152 }
153
154 pub fn ui_locales(mut self, ui_locales: impl Into<String>) -> Self {
156 self.ui_locales = Some(ui_locales.into());
157 self
158 }
159
160 pub fn registration_metadata(mut self, registration_metadata: impl Into<String>) -> Self {
162 self.registration_metadata = Some(registration_metadata.into());
163 self
164 }
165
166 pub fn claims(mut self, claims: impl Into<String>) -> Self {
168 self.claims = Some(claims.into());
169 self
170 }
171
172 pub fn build(self) -> RegistrationRequest {
174 RegistrationRequest {
175 client_id: self.client_id,
176 redirect_uri: self.redirect_uri,
177 scope: self.scope,
178 response_type: self.response_type,
179 state: self.state,
180 nonce: self.nonce,
181 prompt: self.prompt,
182 login_hint: self.login_hint,
183 ui_locales: self.ui_locales,
184 registration_metadata: self.registration_metadata,
185 claims: self.claims,
186 }
187 }
188}
189
/// Profile data collected during one registration session.
///
/// `RegistrationManager` keeps one of these per in-progress registration, keyed by
/// `registration_id`.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct RegistrationData {
    /// Identifier of the registration session.
    pub registration_id: String,
    /// `email` claim.
    pub email: Option<String>,
    /// `phone_number` claim.
    pub phone_number: Option<String>,
    /// `given_name` claim.
    pub given_name: Option<String>,
    /// `family_name` claim.
    pub family_name: Option<String>,
    /// `name` claim.
    pub name: Option<String>,
    /// `preferred_username` claim.
    pub preferred_username: Option<String>,
    /// `picture` claim (stored as supplied; not validated as a URL here).
    pub picture: Option<String>,
    /// `website` claim (stored as supplied; not validated as a URL here).
    pub website: Option<String>,
    /// `gender` claim.
    pub gender: Option<String>,
    /// `birthdate` claim (stored as supplied; format not validated here).
    pub birthdate: Option<String>,
    /// `zoneinfo` claim.
    pub zoneinfo: Option<String>,
    /// `locale` claim.
    pub locale: Option<String>,
    /// Any field that is not one of the named claims above. `complete_registration`
    /// reads a `password` entry from here.
    pub custom_fields: HashMap<String, serde_json::Value>,
    /// Completion flag; set by `RegistrationData::mark_completed`.
    pub completed: bool,
    /// Session creation time in whole seconds since the Unix epoch; compared against
    /// `RegistrationConfig::session_timeout`.
    pub created_at: u64,
}
226
227impl RegistrationData {
228 pub fn new(registration_id: impl Into<String>) -> Self {
230 use std::time::SystemTime;
231 Self {
232 registration_id: registration_id.into(),
233 email: None,
234 phone_number: None,
235 given_name: None,
236 family_name: None,
237 name: None,
238 preferred_username: None,
239 picture: None,
240 website: None,
241 gender: None,
242 birthdate: None,
243 zoneinfo: None,
244 locale: None,
245 custom_fields: HashMap::new(),
246 completed: false,
247 created_at: SystemTime::now()
248 .duration_since(SystemTime::UNIX_EPOCH)
249 .unwrap_or_default()
250 .as_secs(),
251 }
252 }
253
254 pub fn with_email(mut self, email: impl Into<String>) -> Self {
256 self.email = Some(email.into());
257 self
258 }
259
260 pub fn with_phone_number(mut self, phone_number: impl Into<String>) -> Self {
262 self.phone_number = Some(phone_number.into());
263 self
264 }
265
266 pub fn with_name(mut self, name: impl Into<String>) -> Self {
268 self.name = Some(name.into());
269 self
270 }
271
272 pub fn with_names(mut self, given: impl Into<String>, family: impl Into<String>) -> Self {
274 self.given_name = Some(given.into());
275 self.family_name = Some(family.into());
276 self
277 }
278
279 pub fn mark_completed(mut self) -> Self {
281 self.completed = true;
282 self
283 }
284}
285
/// Outcome returned by `RegistrationManager::complete_registration`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RegistrationResponse {
    /// Subject identifier generated for the newly registered user.
    pub sub: String,
    /// Whether the registration finished.
    pub completed: bool,
    /// Code issued on completion (`complete_registration` generates a `reg_auth_`-prefixed value).
    pub code: Option<String>,
    /// Client `state` to echo back; `complete_registration` currently always leaves this `None`.
    pub state: Option<String>,
}
298
/// Settings controlling user self-registration.
#[derive(Debug, Clone)]
pub struct RegistrationConfig {
    /// Master switch. When `false`, registration requests are rejected and no registration
    /// metadata is added to discovery.
    pub enabled: bool,
    /// Value advertised as `registration_endpoint` in discovery metadata.
    pub registration_endpoint: String,
    /// Field names that must be filled before a registration can be completed. Names of
    /// standard claims map to `RegistrationData` fields; other names map to `custom_fields`.
    pub required_fields: Vec<String>,
    /// Additional field names offered, but not required, on the registration form.
    pub optional_fields: Vec<String>,
    /// Maximum age of a registration session, in seconds.
    pub session_timeout: u64,
    /// NOTE(review): not read by `RegistrationManager` in this file; confirm where it is enforced.
    pub require_email_verification: bool,
    /// NOTE(review): not read by `RegistrationManager` in this file; confirm where it is enforced.
    pub require_phone_verification: bool,
    /// Per-field validation rules. Presumably field name -> rule, but not read in this file;
    /// TODO confirm the rule format.
    pub field_validation_rules: HashMap<String, String>,
}
319
320impl Default for RegistrationConfig {
321 fn default() -> Self {
322 Self {
323 enabled: true,
324 registration_endpoint: "/connect/register".to_string(),
325 required_fields: vec!["email".to_string()],
326 optional_fields: vec![
327 "given_name".to_string(),
328 "family_name".to_string(),
329 "name".to_string(),
330 "preferred_username".to_string(),
331 "phone_number".to_string(),
332 ],
333 session_timeout: 1800, require_email_verification: true,
335 require_phone_verification: false,
336 field_validation_rules: HashMap::new(),
337 }
338 }
339}
340
/// Drives the OIDC `prompt=create` user-registration flow.
///
/// In-progress registrations are held in an in-memory map keyed by registration id. When a
/// storage backend is attached, completed users are written to it. Cloning the manager
/// copies the session map, so clones do not share sessions.
#[derive(Clone)]
pub struct RegistrationManager {
    /// Registration settings.
    config: RegistrationConfig,
    /// Builds the OIDC error responses for failed registration steps.
    error_manager: OidcErrorManager,
    /// In-progress registration sessions keyed by registration id.
    registration_sessions: HashMap<String, RegistrationData>,
    /// Optional backend that `complete_registration` persists new users to.
    storage: Option<Arc<dyn AuthStorage>>,
}
353
354impl std::fmt::Debug for RegistrationManager {
355 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
356 f.debug_struct("RegistrationManager")
357 .field("config", &self.config)
358 .field("sessions", &self.registration_sessions.len())
359 .field("has_storage", &self.storage.is_some())
360 .finish()
361 }
362}
363
364impl RegistrationManager {
365 pub fn new(config: RegistrationConfig) -> Self {
367 Self {
368 config,
369 error_manager: OidcErrorManager::default(),
370 registration_sessions: HashMap::new(),
371 storage: None,
372 }
373 }
374
375 pub fn storage(mut self, storage: Arc<dyn AuthStorage>) -> Self {
377 self.storage = Some(storage);
378 self
379 }
380
381 pub fn error_manager(mut self, error_manager: OidcErrorManager) -> Self {
383 self.error_manager = error_manager;
384 self
385 }
386
387 pub fn create_registration_disabled_error(&self, state: Option<String>) -> OidcErrorResponse {
389 self.error_manager.create_error_response(
390 OidcErrorCode::RegistrationNotSupported,
391 Some("User registration is not enabled on this server".to_string()),
392 state,
393 HashMap::new(),
394 )
395 }
396
397 pub fn create_invalid_registration_request_error(
399 &self,
400 description: String,
401 state: Option<String>,
402 ) -> OidcErrorResponse {
403 self.error_manager.create_error_response(
404 OidcErrorCode::InvalidRequest,
405 Some(description),
406 state,
407 HashMap::new(),
408 )
409 }
410
411 pub fn create_session_not_found_error(&self, state: Option<String>) -> OidcErrorResponse {
413 self.error_manager.create_error_response(
414 OidcErrorCode::SessionSelectionRequired,
415 Some("Registration session not found or expired".to_string()),
416 state,
417 HashMap::new(),
418 )
419 }
420
421 pub fn create_registration_incomplete_error(
423 &self,
424 missing_fields: Vec<String>,
425 state: Option<String>,
426 ) -> OidcErrorResponse {
427 let mut additional_details = HashMap::new();
428 additional_details.insert(
429 "missing_fields".to_string(),
430 serde_json::to_value(missing_fields.clone()).unwrap_or_default(),
431 );
432
433 self.error_manager.create_error_response(
434 OidcErrorCode::RegistrationRequired,
435 Some(format!(
436 "Registration incomplete. Missing required fields: {}",
437 missing_fields.join(", ")
438 )),
439 state,
440 additional_details,
441 )
442 }
443
444 pub fn create_session_expired_error(&self, state: Option<String>) -> OidcErrorResponse {
446 self.error_manager.create_error_response(
447 OidcErrorCode::LoginRequired,
448 Some("Registration session has expired. Please start registration again".to_string()),
449 state,
450 HashMap::new(),
451 )
452 }
453
454 pub fn get_error_manager(&self) -> &OidcErrorManager {
456 &self.error_manager
457 }
458
459 pub fn update_error_manager(&mut self, error_manager: OidcErrorManager) {
461 self.error_manager = error_manager;
462 }
463
464 pub fn is_registration_requested(&self, prompt: Option<&str>) -> bool {
466 if !self.config.enabled {
467 return false;
468 }
469
470 if let Some(prompt_values) = prompt {
471 let prompts: Vec<&str> = prompt_values.split_whitespace().collect();
472 prompts.contains(&"create")
473 } else {
474 false
475 }
476 }
477
478 pub fn initiate_registration(&mut self, request: RegistrationRequest) -> Result<String> {
480 if !self.config.enabled {
481 let error_response = self.create_registration_disabled_error(request.state.clone());
482 return Err(AuthError::validation(format!(
483 "Registration disabled: {}",
484 error_response.error_description.unwrap_or_default()
485 )));
486 }
487
488 if !self.is_registration_requested(request.prompt.as_deref()) {
490 let error_response = self.create_invalid_registration_request_error(
491 "Registration requires prompt=create parameter".to_string(),
492 request.state.clone(),
493 );
494 return Err(AuthError::validation(format!(
495 "Invalid request: {}",
496 error_response.error_description.unwrap_or_default()
497 )));
498 }
499
500 let registration_id = Uuid::new_v4().to_string();
502 let now = std::time::SystemTime::now()
503 .duration_since(std::time::UNIX_EPOCH)
504 .unwrap_or_default()
505 .as_secs();
506
507 let mut registration_data = RegistrationData {
508 registration_id: registration_id.clone(),
509 email: None,
510 phone_number: None,
511 given_name: None,
512 family_name: None,
513 name: None,
514 preferred_username: None,
515 picture: None,
516 website: None,
517 gender: None,
518 birthdate: None,
519 zoneinfo: None,
520 locale: None,
521 custom_fields: HashMap::new(),
522 completed: false,
523 created_at: now,
524 };
525
526 if let Some(login_hint) = &request.login_hint {
528 if login_hint.contains('@') {
529 registration_data.email = Some(login_hint.clone());
530 } else {
531 registration_data.preferred_username = Some(login_hint.clone());
532 }
533 }
534
535 if let Some(metadata_str) = &request.registration_metadata {
537 match serde_json::from_str::<HashMap<String, serde_json::Value>>(metadata_str) {
538 Ok(metadata) => {
539 registration_data.custom_fields.extend(metadata);
540 }
541 Err(_) => {
542 let error_response = self.create_invalid_registration_request_error(
543 "Invalid registration metadata JSON format".to_string(),
544 request.state.clone(),
545 );
546 return Err(AuthError::validation(format!(
547 "Invalid metadata: {}",
548 error_response.error_description.unwrap_or_default()
549 )));
550 }
551 }
552 }
553
554 self.registration_sessions
555 .insert(registration_id.clone(), registration_data);
556
557 Ok(registration_id)
558 }
559
560 pub fn update_registration_data(
562 &mut self,
563 registration_id: &str,
564 updates: HashMap<String, serde_json::Value>,
565 ) -> Result<()> {
566 let Some(registration) = self.registration_sessions.get_mut(registration_id) else {
568 let error_response = self.create_session_not_found_error(None);
569 return Err(AuthError::validation(format!(
570 "Session error: {}",
571 error_response.error_description.unwrap_or_default()
572 )));
573 };
574
575 let now = std::time::SystemTime::now()
577 .duration_since(std::time::UNIX_EPOCH)
578 .unwrap_or_default()
579 .as_secs();
580
581 if now - registration.created_at > self.config.session_timeout {
582 let error_response = self.create_session_expired_error(None);
583 return Err(AuthError::validation(format!(
584 "Session expired: {}",
585 error_response.error_description.unwrap_or_default()
586 )));
587 }
588
589 for (key, value) in updates {
591 match key.as_str() {
592 "email" => registration.email = value.as_str().map(|s| s.to_string()),
593 "phone_number" => registration.phone_number = value.as_str().map(|s| s.to_string()),
594 "given_name" => registration.given_name = value.as_str().map(|s| s.to_string()),
595 "family_name" => registration.family_name = value.as_str().map(|s| s.to_string()),
596 "name" => registration.name = value.as_str().map(|s| s.to_string()),
597 "preferred_username" => {
598 registration.preferred_username = value.as_str().map(|s| s.to_string())
599 }
600 "picture" => registration.picture = value.as_str().map(|s| s.to_string()),
601 "website" => registration.website = value.as_str().map(|s| s.to_string()),
602 "gender" => registration.gender = value.as_str().map(|s| s.to_string()),
603 "birthdate" => registration.birthdate = value.as_str().map(|s| s.to_string()),
604 "zoneinfo" => registration.zoneinfo = value.as_str().map(|s| s.to_string()),
605 "locale" => registration.locale = value.as_str().map(|s| s.to_string()),
606 _ => {
607 registration.custom_fields.insert(key, value);
609 }
610 }
611 }
612
613 Ok(())
614 }
615
616 pub fn validate_registration_data(&self, registration_id: &str) -> Result<Vec<String>> {
618 let Some(registration) = self.registration_sessions.get(registration_id) else {
620 let error_response = self.create_session_not_found_error(None);
621 return Err(AuthError::validation(format!(
622 "Session error: {}",
623 error_response.error_description.unwrap_or_default()
624 )));
625 };
626
627 let mut missing_fields = Vec::new();
628
629 for field in &self.config.required_fields {
631 let is_present = match field.as_str() {
632 "email" => registration.email.is_some(),
633 "phone_number" => registration.phone_number.is_some(),
634 "given_name" => registration.given_name.is_some(),
635 "family_name" => registration.family_name.is_some(),
636 "name" => registration.name.is_some(),
637 "preferred_username" => registration.preferred_username.is_some(),
638 _ => registration.custom_fields.contains_key(field),
639 };
640
641 if !is_present {
642 missing_fields.push(field.clone());
643 }
644 }
645
646 Ok(missing_fields)
647 }
648
649 pub fn validate_registration_completeness(
651 &self,
652 registration_id: &str,
653 state: Option<String>,
654 ) -> Result<()> {
655 let missing_fields = self.validate_registration_data(registration_id)?;
656 if !missing_fields.is_empty() {
657 let error_response = self.create_registration_incomplete_error(missing_fields, state);
658 return Err(AuthError::validation(format!(
659 "Registration incomplete: {}",
660 error_response.error_description.unwrap_or_default()
661 )));
662 }
663 Ok(())
664 }
665
666 pub async fn complete_registration(
677 &mut self,
678 registration_id: &str,
679 ) -> Result<RegistrationResponse> {
680 self.validate_registration_completeness(registration_id, None)?;
682
683 let Some(mut registration) = self.registration_sessions.remove(registration_id) else {
685 let error_response = self.create_session_not_found_error(None);
686 return Err(AuthError::validation(format!(
687 "Session error: {}",
688 error_response.error_description.unwrap_or_default()
689 )));
690 };
691
692 let sub = crate::utils::string::generate_id(Some("user"));
694
695 registration.completed = true;
697
698 if let Some(storage) = &self.storage {
700 let username = registration
703 .preferred_username
704 .clone()
705 .or_else(|| {
706 registration
707 .email
708 .as_deref()
709 .and_then(|e| e.split('@').next())
710 .map(|s| s.to_string())
711 })
712 .unwrap_or_else(|| sub.clone());
713
714 let email = registration
715 .email
716 .clone()
717 .unwrap_or_else(|| format!("{}@unknown.invalid", sub));
718
719 let plain_password = registration
724 .custom_fields
725 .get("password")
726 .and_then(|v| v.as_str())
727 .map(|s| s.to_string())
728 .unwrap_or_else(|| format!("tmp_{}", Uuid::new_v4()));
729
730 let password_hash =
731 bcrypt::hash(&plain_password, bcrypt::DEFAULT_COST).map_err(|e| {
732 AuthError::crypto(format!(
733 "Failed to hash password during registration: {}",
734 e
735 ))
736 })?;
737
738 let user_data = serde_json::json!({
739 "user_id": sub,
740 "username": username,
741 "email": email,
742 "password_hash": password_hash,
743 "given_name": registration.given_name,
744 "family_name": registration.family_name,
745 "name": registration.name,
746 "phone_number": registration.phone_number,
747 "picture": registration.picture,
748 "roles": ["user"],
749 "active": true,
750 "created_at": chrono::Utc::now().to_rfc3339(),
751 });
752
753 storage
755 .store_kv(
756 &format!("user:{}", sub),
757 user_data.to_string().as_bytes(),
758 None,
759 )
760 .await?;
761 storage
762 .store_kv(&format!("user:username:{}", username), sub.as_bytes(), None)
763 .await?;
764 storage
765 .store_kv(&format!("user:email:{}", email), sub.as_bytes(), None)
766 .await?;
767
768 tracing::info!(
769 "OIDC registration complete: user '{}' (sub={}) persisted to storage",
770 username,
771 sub
772 );
773 } else {
774 warn!(
775 "OIDC user registration for sub '{}' completed the session but did NOT persist \
776 the user to storage. Provide a storage backend via \
777 RegistrationManager::storage() to enable user persistence.",
778 sub
779 );
780 }
781
782 let authorization_code = format!("reg_auth_{}", Uuid::new_v4());
784
785 Ok(RegistrationResponse {
786 sub,
787 completed: true,
788 code: Some(authorization_code),
789 state: None, })
791 }
792
793 pub fn get_registration_data(&self, registration_id: &str) -> Option<&RegistrationData> {
795 self.registration_sessions.get(registration_id)
796 }
797
798 pub fn generate_registration_form(&self, registration_id: &str) -> Result<String> {
800 let Some(registration) = self.registration_sessions.get(registration_id) else {
802 let error_response = self.create_session_not_found_error(None);
803 return Err(AuthError::validation(format!(
804 "Session error: {}",
805 error_response.error_description.unwrap_or_default()
806 )));
807 };
808
809 let mut form = format!(
810 r#"<!DOCTYPE html>
811<html>
812<head>
813 <title>User Registration</title>
814 <style>
815 body {{ font-family: Arial, sans-serif; margin: 40px; }}
816 .form-group {{ margin-bottom: 15px; }}
817 label {{ display: block; margin-bottom: 5px; font-weight: bold; }}
818 input {{ width: 100%; padding: 8px; border: 1px solid #ccc; border-radius: 4px; }}
819 .required {{ color: red; }}
820 .submit-btn {{ background: #007bff; color: white; padding: 10px 20px; border: none; border-radius: 4px; cursor: pointer; }}
821 </style>
822</head>
823<body>
824 <h1>Create Your Account</h1>
825 <form method="post" action="/connect/register/{}/complete">
826"#,
827 registration.registration_id
828 );
829
830 for field in &self.config.required_fields {
832 let (field_name, field_type, current_value) = match field.as_str() {
833 "email" => (
834 "Email Address",
835 "email",
836 registration.email.as_deref().unwrap_or(""),
837 ),
838 "given_name" => (
839 "First Name",
840 "text",
841 registration.given_name.as_deref().unwrap_or(""),
842 ),
843 "family_name" => (
844 "Last Name",
845 "text",
846 registration.family_name.as_deref().unwrap_or(""),
847 ),
848 "phone_number" => (
849 "Phone Number",
850 "tel",
851 registration.phone_number.as_deref().unwrap_or(""),
852 ),
853 _ => (field.as_str(), "text", ""),
854 };
855
856 form.push_str(&format!(
857 r#" <div class="form-group">
858 <label for="{}">{} <span class="required">*</span></label>
859 <input type="{}" id="{}" name="{}" value="{}" required>
860 </div>
861"#,
862 field, field_name, field_type, field, field, current_value
863 ));
864 }
865
866 for field in &self.config.optional_fields {
868 if !self.config.required_fields.contains(field) {
869 let (field_name, field_type, current_value) = match field.as_str() {
870 "preferred_username" => (
871 "Username",
872 "text",
873 registration.preferred_username.as_deref().unwrap_or(""),
874 ),
875 "website" => (
876 "Website",
877 "url",
878 registration.website.as_deref().unwrap_or(""),
879 ),
880 "picture" => (
881 "Profile Picture URL",
882 "url",
883 registration.picture.as_deref().unwrap_or(""),
884 ),
885 _ => (field.as_str(), "text", ""),
886 };
887
888 form.push_str(&format!(
889 r#" <div class="form-group">
890 <label for="{}">{}</label>
891 <input type="{}" id="{}" name="{}" value="{}">
892 </div>
893"#,
894 field, field_name, field_type, field, field, current_value
895 ));
896 }
897 }
898
899 form.push_str(
900 r#" <button type="submit" class="submit-btn">Create Account</button>
901 </form>
902</body>
903</html>"#,
904 );
905
906 Ok(form)
907 }
908
909 pub fn cleanup_expired_sessions(&mut self) -> usize {
911 let now = std::time::SystemTime::now()
912 .duration_since(std::time::UNIX_EPOCH)
913 .unwrap_or_default()
914 .as_secs();
915
916 let initial_count = self.registration_sessions.len();
917
918 self.registration_sessions
919 .retain(|_, registration| now - registration.created_at < self.config.session_timeout);
920
921 initial_count - self.registration_sessions.len()
922 }
923
924 pub fn get_discovery_metadata(&self) -> HashMap<String, serde_json::Value> {
926 let mut metadata = HashMap::new();
927
928 if self.config.enabled {
929 metadata.insert(
930 "registration_endpoint".to_string(),
931 serde_json::Value::String(self.config.registration_endpoint.clone()),
932 );
933 metadata.insert(
934 "prompt_values_supported".to_string(),
935 serde_json::Value::Array(vec![
936 serde_json::Value::String("none".to_string()),
937 serde_json::Value::String("login".to_string()),
938 serde_json::Value::String("consent".to_string()),
939 serde_json::Value::String("select_account".to_string()),
940 serde_json::Value::String("create".to_string()),
941 ]),
942 );
943 }
944
945 metadata
946 }
947}
948
// Unit tests for the registration request/data builders and RegistrationManager flows.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_registration_request_builder() {
        let req = RegistrationRequest::builder("client_123", "https://app/cb", "openid", "code")
            .prompt("create")
            .login_hint("user@example.com")
            .ui_locales("fr-FR")
            .build();

        assert_eq!(req.client_id, "client_123");
        assert_eq!(req.prompt, Some("create".to_string()));
        assert_eq!(req.login_hint, Some("user@example.com".to_string()));
        assert_eq!(req.ui_locales, Some("fr-FR".to_string()));
    }

    #[test]
    fn test_registration_data_builder() {
        let data = RegistrationData::new("reg_123")
            .with_email("user@example.com")
            .with_names("John", "Doe")
            .mark_completed();

        assert_eq!(data.registration_id, "reg_123");
        assert_eq!(data.email, Some("user@example.com".to_string()));
        assert_eq!(data.given_name, Some("John".to_string()));
        assert_eq!(data.family_name, Some("Doe".to_string()));
        assert!(data.completed);
    }

    #[test]
    fn test_error_manager_integration() {
        let mut manager = RegistrationManager::new(RegistrationConfig::default());

        let disabled_config = RegistrationConfig {
            enabled: false,
            ..Default::default()
        };
        let mut disabled_manager = RegistrationManager::new(disabled_config);

        let request = RegistrationRequest {
            client_id: "test_client".to_string(),
            redirect_uri: "https://client.example.com/callback".to_string(),
            scope: "openid profile email".to_string(),
            response_type: "code".to_string(),
            state: Some("state123".to_string()),
            nonce: Some("nonce456".to_string()),
            prompt: Some("create".to_string()),
            login_hint: None,
            ui_locales: None,
            registration_metadata: None,
            claims: None,
        };

        let result = disabled_manager.initiate_registration(request.clone());
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Registration disabled")
        );

        let invalid_request = RegistrationRequest {
            prompt: Some("login".to_string()),
            ..request.clone()
        };

        let result = manager.initiate_registration(invalid_request);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Invalid request"));

        let invalid_metadata_request = RegistrationRequest {
            registration_metadata: Some("invalid json".to_string()),
            ..request
        };

        let result = manager.initiate_registration(invalid_metadata_request);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Invalid metadata"));
    }

    #[test]
    fn test_error_manager_session_handling() {
        let mut manager = RegistrationManager::new(RegistrationConfig::default());

        let result = manager.update_registration_data("nonexistent", HashMap::new());
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Session error"));

        let config = RegistrationConfig {
            required_fields: vec!["email".to_string(), "given_name".to_string()],
            ..Default::default()
        };
        let mut manager = RegistrationManager::new(config);

        let registration_data = RegistrationData {
            registration_id: "test123".to_string(),
            email: Some("user@example.com".to_string()),
            given_name: None,
            ..Default::default()
        };

        manager
            .registration_sessions
            .insert("test123".to_string(), registration_data);

        let result =
            manager.validate_registration_completeness("test123", Some("state456".to_string()));
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Registration incomplete")
        );
    }

    #[test]
    fn test_error_manager_custom_configuration() {
        use crate::server::oidc::oidc_error_extensions::{OidcErrorCode, OidcErrorManager};

        let mut custom_error_manager = OidcErrorManager::default();
        custom_error_manager.add_custom_error_mapping(
            "custom_registration_error".to_string(),
            OidcErrorCode::RegistrationRequired,
        );

        let manager = RegistrationManager::new(RegistrationConfig::default())
            .error_manager(custom_error_manager);

        assert!(
            manager
                .get_error_manager()
                .has_custom_mapping("custom_registration_error")
        );

        let error_response =
            manager.create_registration_disabled_error(Some("test_state".to_string()));
        assert_eq!(error_response.state.as_ref().unwrap(), "test_state");

        let session_error = manager.create_session_not_found_error(None);
        assert_eq!(session_error.error, OidcErrorCode::SessionSelectionRequired);
    }

    #[test]
    fn test_registration_request_detection() {
        let manager = RegistrationManager::new(RegistrationConfig::default());

        assert!(manager.is_registration_requested(Some("create")));
        assert!(manager.is_registration_requested(Some("login create")));
        assert!(manager.is_registration_requested(Some("create consent")));
        assert!(!manager.is_registration_requested(Some("login")));
        assert!(!manager.is_registration_requested(None));
    }

    #[test]
    fn test_registration_initiation() {
        let mut manager = RegistrationManager::new(RegistrationConfig::default());

        let request = RegistrationRequest {
            client_id: "test_client".to_string(),
            redirect_uri: "https://client.example.com/callback".to_string(),
            scope: "openid profile email".to_string(),
            response_type: "code".to_string(),
            state: Some("state123".to_string()),
            nonce: Some("nonce456".to_string()),
            prompt: Some("create".to_string()),
            login_hint: Some("user@example.com".to_string()),
            ui_locales: None,
            registration_metadata: None,
            claims: None,
        };

        let registration_id = manager.initiate_registration(request).unwrap();
        assert!(!registration_id.is_empty());

        // `&registration_id` (previously mis-encoded as `®istration_id`, which did not compile).
        let registration_data = manager.get_registration_data(&registration_id).unwrap();
        assert_eq!(
            registration_data.email,
            Some("user@example.com".to_string())
        );
        assert!(!registration_data.completed);
    }

    #[test]
    fn test_registration_data_validation() {
        let mut manager = RegistrationManager::new(RegistrationConfig {
            required_fields: vec!["email".to_string(), "given_name".to_string()],
            ..RegistrationConfig::default()
        });

        let registration_id = "test_reg_123";
        let registration_data = RegistrationData {
            registration_id: registration_id.to_string(),
            email: Some("user@example.com".to_string()),
            given_name: None,
            ..Default::default()
        };

        manager
            .registration_sessions
            .insert(registration_id.to_string(), registration_data);

        let missing_fields = manager.validate_registration_data(registration_id).unwrap();
        assert_eq!(missing_fields, vec!["given_name"]);
    }
}