1use crate::errors::{AuthError, Result};
56use crate::server::oidc::oidc_error_extensions::{OidcErrorCode, OidcErrorManager, OidcErrorResponse};
57use serde::{Deserialize, Serialize};
58use std::collections::HashMap;
59use uuid::Uuid;
60
/// Authorization-request parameters used to start a self-service registration flow.
///
/// [`RegistrationManager::initiate_registration`] reads only `prompt`, `login_hint`,
/// `registration_metadata` and `state`. The remaining fields are carried along but not
/// inspected by the code in this module.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RegistrationRequest {
    /// Identifier of the requesting client (not inspected by `RegistrationManager`).
    pub client_id: String,
    /// Redirect URI supplied with the request (not inspected by `RegistrationManager`).
    pub redirect_uri: String,
    /// Requested scope string, e.g. `"openid profile email"` (not inspected here).
    pub scope: String,
    /// Requested response type, e.g. `"code"` (not inspected here).
    pub response_type: String,
    /// Opaque client state; copied into the error responses built when initiation fails.
    pub state: Option<String>,
    /// Nonce supplied with the request (not inspected by `RegistrationManager`).
    pub nonce: Option<String>,

    /// Whitespace-separated prompt values; registration requires a `create` token.
    pub prompt: Option<String>,
    /// Seeds the new session: a value containing `@` becomes `email`, anything else
    /// becomes `preferred_username`.
    pub login_hint: Option<String>,
    /// Preferred UI locales (not inspected by `RegistrationManager`).
    pub ui_locales: Option<String>,
    /// JSON object whose entries are merged into `RegistrationData::custom_fields`;
    /// malformed JSON causes the request to be rejected.
    pub registration_metadata: Option<String>,
    /// Requested claims (not inspected by `RegistrationManager`).
    pub claims: Option<String>,
}
84
/// State of one in-progress registration session.
///
/// Standard profile attributes are stored in dedicated fields. Any other key supplied
/// through registration metadata or `update_registration_data` lands in `custom_fields`.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct RegistrationData {
    /// Session identifier (a v4 UUID when created by `initiate_registration`).
    pub registration_id: String,
    /// `email` attribute, if provided (also seeded from an `@`-containing login hint).
    pub email: Option<String>,
    /// `phone_number` attribute, if provided.
    pub phone_number: Option<String>,
    /// `given_name` attribute, if provided.
    pub given_name: Option<String>,
    /// `family_name` attribute, if provided.
    pub family_name: Option<String>,
    /// `name` attribute, if provided.
    pub name: Option<String>,
    /// `preferred_username` attribute (also seeded from a login hint without `@`).
    pub preferred_username: Option<String>,
    /// `picture` attribute, if provided.
    pub picture: Option<String>,
    /// `website` attribute, if provided.
    pub website: Option<String>,
    /// `gender` attribute, if provided.
    pub gender: Option<String>,
    /// `birthdate` attribute, if provided.
    pub birthdate: Option<String>,
    /// `zoneinfo` attribute, if provided.
    pub zoneinfo: Option<String>,
    /// `locale` attribute, if provided.
    pub locale: Option<String>,
    /// Fields that do not map to one of the attributes above.
    pub custom_fields: HashMap<String, serde_json::Value>,
    /// Completion flag; sessions are created with `false`.
    pub completed: bool,
    /// Session creation time in whole seconds since the Unix epoch; drives session expiry.
    pub created_at: u64,
}
121
/// Result of a successfully completed registration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RegistrationResponse {
    /// Newly generated subject identifier of the form `user_<uuid>`.
    pub sub: String,
    /// Whether registration finished; `complete_registration` always sets this to `true`.
    pub completed: bool,
    /// Generated code of the form `reg_auth_<uuid>`.
    pub code: Option<String>,
    /// Client state; currently left as `None` by `RegistrationManager::complete_registration`.
    pub state: Option<String>,
}
134
/// Settings controlling the registration flow handled by `RegistrationManager`.
#[derive(Debug, Clone)]
pub struct RegistrationConfig {
    /// Master switch. When `false`, initiation is rejected and no discovery metadata
    /// is advertised.
    pub enabled: bool,
    /// Endpoint path advertised as `registration_endpoint` in discovery metadata.
    pub registration_endpoint: String,
    /// Fields that must be present before a registration can complete; they are rendered
    /// as `required` form inputs.
    pub required_fields: Vec<String>,
    /// Additional fields rendered as optional inputs, unless they are also required.
    pub optional_fields: Vec<String>,
    /// Maximum session age in seconds.
    pub session_timeout: u64,
    /// Email-verification flag (not read by `RegistrationManager` in this module).
    pub require_email_verification: bool,
    /// Phone-verification flag (not read by `RegistrationManager` in this module).
    pub require_phone_verification: bool,
    /// Per-field validation rules (not read by `RegistrationManager` in this module).
    pub field_validation_rules: HashMap<String, String>,
}

impl Default for RegistrationConfig {
    /// Enabled registration requiring only `email`, with a 30-minute session lifetime.
    fn default() -> Self {
        let optional_fields = [
            "given_name",
            "family_name",
            "name",
            "preferred_username",
            "phone_number",
        ]
        .iter()
        .map(|field| field.to_string())
        .collect();

        RegistrationConfig {
            enabled: true,
            registration_endpoint: String::from("/connect/register"),
            required_fields: vec![String::from("email")],
            optional_fields,
            // Thirty minutes, in seconds.
            session_timeout: 30 * 60,
            require_email_verification: true,
            require_phone_verification: false,
            field_validation_rules: HashMap::default(),
        }
    }
}
176
/// Drives `prompt=create` user registration, keeping open sessions in an in-memory map.
///
/// All mutating operations take `&mut self`; the type performs no internal synchronization.
#[derive(Debug, Clone)]
pub struct RegistrationManager {
    /// Registration feature settings.
    config: RegistrationConfig,
    /// Builds the OIDC error responses whose descriptions are embedded in returned errors.
    error_manager: OidcErrorManager,
    /// Open registration sessions keyed by registration id.
    registration_sessions: HashMap<String, RegistrationData>,
}
187
188impl RegistrationManager {
189 pub fn new(config: RegistrationConfig) -> Self {
191 Self {
192 config,
193 error_manager: OidcErrorManager::default(),
194 registration_sessions: HashMap::new(),
195 }
196 }
197
198 pub fn with_error_manager(config: RegistrationConfig, error_manager: OidcErrorManager) -> Self {
200 Self {
201 config,
202 error_manager,
203 registration_sessions: HashMap::new(),
204 }
205 }
206
207 pub fn create_registration_disabled_error(&self, state: Option<String>) -> OidcErrorResponse {
209 self.error_manager.create_error_response(
210 OidcErrorCode::RegistrationNotSupported,
211 Some("User registration is not enabled on this server".to_string()),
212 state,
213 HashMap::new(),
214 )
215 }
216
217 pub fn create_invalid_registration_request_error(
219 &self,
220 description: String,
221 state: Option<String>,
222 ) -> OidcErrorResponse {
223 self.error_manager.create_error_response(
224 OidcErrorCode::InvalidRequest,
225 Some(description),
226 state,
227 HashMap::new(),
228 )
229 }
230
231 pub fn create_session_not_found_error(&self, state: Option<String>) -> OidcErrorResponse {
233 self.error_manager.create_error_response(
234 OidcErrorCode::SessionSelectionRequired,
235 Some("Registration session not found or expired".to_string()),
236 state,
237 HashMap::new(),
238 )
239 }
240
241 pub fn create_registration_incomplete_error(
243 &self,
244 missing_fields: Vec<String>,
245 state: Option<String>,
246 ) -> OidcErrorResponse {
247 let mut additional_details = HashMap::new();
248 additional_details.insert(
249 "missing_fields".to_string(),
250 serde_json::to_value(missing_fields.clone()).unwrap(),
251 );
252
253 self.error_manager.create_error_response(
254 OidcErrorCode::RegistrationRequired,
255 Some(format!(
256 "Registration incomplete. Missing required fields: {}",
257 missing_fields.join(", ")
258 )),
259 state,
260 additional_details,
261 )
262 }
263
264 pub fn create_session_expired_error(&self, state: Option<String>) -> OidcErrorResponse {
266 self.error_manager.create_error_response(
267 OidcErrorCode::LoginRequired,
268 Some("Registration session has expired. Please start registration again".to_string()),
269 state,
270 HashMap::new(),
271 )
272 }
273
274 pub fn get_error_manager(&self) -> &OidcErrorManager {
276 &self.error_manager
277 }
278
279 pub fn update_error_manager(&mut self, error_manager: OidcErrorManager) {
281 self.error_manager = error_manager;
282 }
283
284 pub fn is_registration_requested(&self, prompt: Option<&str>) -> bool {
286 if !self.config.enabled {
287 return false;
288 }
289
290 if let Some(prompt_values) = prompt {
291 let prompts: Vec<&str> = prompt_values.split_whitespace().collect();
292 prompts.contains(&"create")
293 } else {
294 false
295 }
296 }
297
298 pub fn initiate_registration(&mut self, request: RegistrationRequest) -> Result<String> {
300 if !self.config.enabled {
301 let error_response = self.create_registration_disabled_error(request.state.clone());
302 return Err(AuthError::validation(format!(
303 "Registration disabled: {}",
304 error_response.error_description.unwrap_or_default()
305 )));
306 }
307
308 if !self.is_registration_requested(request.prompt.as_deref()) {
310 let error_response = self.create_invalid_registration_request_error(
311 "Registration requires prompt=create parameter".to_string(),
312 request.state.clone(),
313 );
314 return Err(AuthError::validation(format!(
315 "Invalid request: {}",
316 error_response.error_description.unwrap_or_default()
317 )));
318 }
319
320 let registration_id = Uuid::new_v4().to_string();
322 let now = std::time::SystemTime::now()
323 .duration_since(std::time::UNIX_EPOCH)
324 .unwrap()
325 .as_secs();
326
327 let mut registration_data = RegistrationData {
328 registration_id: registration_id.clone(),
329 email: None,
330 phone_number: None,
331 given_name: None,
332 family_name: None,
333 name: None,
334 preferred_username: None,
335 picture: None,
336 website: None,
337 gender: None,
338 birthdate: None,
339 zoneinfo: None,
340 locale: None,
341 custom_fields: HashMap::new(),
342 completed: false,
343 created_at: now,
344 };
345
346 if let Some(login_hint) = &request.login_hint {
348 if login_hint.contains('@') {
349 registration_data.email = Some(login_hint.clone());
350 } else {
351 registration_data.preferred_username = Some(login_hint.clone());
352 }
353 }
354
355 if let Some(metadata_str) = &request.registration_metadata {
357 match serde_json::from_str::<HashMap<String, serde_json::Value>>(metadata_str) {
358 Ok(metadata) => {
359 registration_data.custom_fields.extend(metadata);
360 }
361 Err(_) => {
362 let error_response = self.create_invalid_registration_request_error(
363 "Invalid registration metadata JSON format".to_string(),
364 request.state.clone(),
365 );
366 return Err(AuthError::validation(format!(
367 "Invalid metadata: {}",
368 error_response.error_description.unwrap_or_default()
369 )));
370 }
371 }
372 }
373
374 self.registration_sessions
375 .insert(registration_id.clone(), registration_data);
376
377 Ok(registration_id)
378 }
379
380 pub fn update_registration_data(
382 &mut self,
383 registration_id: &str,
384 updates: HashMap<String, serde_json::Value>,
385 ) -> Result<()> {
386 if !self.registration_sessions.contains_key(registration_id) {
388 let error_response = self.create_session_not_found_error(None);
389 return Err(AuthError::validation(format!(
390 "Session error: {}",
391 error_response.error_description.unwrap_or_default()
392 )));
393 }
394
395 let registration = self.registration_sessions.get_mut(registration_id).unwrap(); let now = std::time::SystemTime::now()
399 .duration_since(std::time::UNIX_EPOCH)
400 .unwrap()
401 .as_secs();
402
403 if now - registration.created_at > self.config.session_timeout {
404 let error_response = self.create_session_expired_error(None);
405 return Err(AuthError::validation(format!(
406 "Session expired: {}",
407 error_response.error_description.unwrap_or_default()
408 )));
409 }
410
411 for (key, value) in updates {
413 match key.as_str() {
414 "email" => registration.email = value.as_str().map(|s| s.to_string()),
415 "phone_number" => registration.phone_number = value.as_str().map(|s| s.to_string()),
416 "given_name" => registration.given_name = value.as_str().map(|s| s.to_string()),
417 "family_name" => registration.family_name = value.as_str().map(|s| s.to_string()),
418 "name" => registration.name = value.as_str().map(|s| s.to_string()),
419 "preferred_username" => {
420 registration.preferred_username = value.as_str().map(|s| s.to_string())
421 }
422 "picture" => registration.picture = value.as_str().map(|s| s.to_string()),
423 "website" => registration.website = value.as_str().map(|s| s.to_string()),
424 "gender" => registration.gender = value.as_str().map(|s| s.to_string()),
425 "birthdate" => registration.birthdate = value.as_str().map(|s| s.to_string()),
426 "zoneinfo" => registration.zoneinfo = value.as_str().map(|s| s.to_string()),
427 "locale" => registration.locale = value.as_str().map(|s| s.to_string()),
428 _ => {
429 registration.custom_fields.insert(key, value);
431 }
432 }
433 }
434
435 Ok(())
436 }
437
438 pub fn validate_registration_data(&self, registration_id: &str) -> Result<Vec<String>> {
440 if !self.registration_sessions.contains_key(registration_id) {
442 let error_response = self.create_session_not_found_error(None);
443 return Err(AuthError::validation(format!(
444 "Session error: {}",
445 error_response.error_description.unwrap_or_default()
446 )));
447 }
448
449 let registration = self.registration_sessions.get(registration_id).unwrap(); let mut missing_fields = Vec::new();
452
453 for field in &self.config.required_fields {
455 let is_present = match field.as_str() {
456 "email" => registration.email.is_some(),
457 "phone_number" => registration.phone_number.is_some(),
458 "given_name" => registration.given_name.is_some(),
459 "family_name" => registration.family_name.is_some(),
460 "name" => registration.name.is_some(),
461 "preferred_username" => registration.preferred_username.is_some(),
462 _ => registration.custom_fields.contains_key(field),
463 };
464
465 if !is_present {
466 missing_fields.push(field.clone());
467 }
468 }
469
470 Ok(missing_fields)
471 }
472
473 pub fn validate_registration_completeness(
475 &self,
476 registration_id: &str,
477 state: Option<String>,
478 ) -> Result<()> {
479 let missing_fields = self.validate_registration_data(registration_id)?;
480 if !missing_fields.is_empty() {
481 let error_response = self.create_registration_incomplete_error(missing_fields, state);
482 return Err(AuthError::validation(format!(
483 "Registration incomplete: {}",
484 error_response.error_description.unwrap_or_default()
485 )));
486 }
487 Ok(())
488 }
489
490 pub fn complete_registration(&mut self, registration_id: &str) -> Result<RegistrationResponse> {
492 self.validate_registration_completeness(registration_id, None)?;
494
495 if !self.registration_sessions.contains_key(registration_id) {
497 let error_response = self.create_session_not_found_error(None);
498 return Err(AuthError::validation(format!(
499 "Session error: {}",
500 error_response.error_description.unwrap_or_default()
501 )));
502 }
503
504 let mut registration = self.registration_sessions.remove(registration_id).unwrap(); let sub = format!("user_{}", Uuid::new_v4());
508
509 registration.completed = true;
511
512 let authorization_code = format!("reg_auth_{}", Uuid::new_v4());
520
521 Ok(RegistrationResponse {
522 sub,
523 completed: true,
524 code: Some(authorization_code),
525 state: None, })
527 }
528
529 pub fn get_registration_data(&self, registration_id: &str) -> Option<&RegistrationData> {
531 self.registration_sessions.get(registration_id)
532 }
533
534 pub fn generate_registration_form(&self, registration_id: &str) -> Result<String> {
536 if !self.registration_sessions.contains_key(registration_id) {
538 let error_response = self.create_session_not_found_error(None);
539 return Err(AuthError::validation(format!(
540 "Session error: {}",
541 error_response.error_description.unwrap_or_default()
542 )));
543 }
544
545 let registration = self.registration_sessions.get(registration_id).unwrap(); let mut form = format!(
548 r#"<!DOCTYPE html>
549<html>
550<head>
551 <title>User Registration</title>
552 <style>
553 body {{ font-family: Arial, sans-serif; margin: 40px; }}
554 .form-group {{ margin-bottom: 15px; }}
555 label {{ display: block; margin-bottom: 5px; font-weight: bold; }}
556 input {{ width: 100%; padding: 8px; border: 1px solid #ccc; border-radius: 4px; }}
557 .required {{ color: red; }}
558 .submit-btn {{ background: #007bff; color: white; padding: 10px 20px; border: none; border-radius: 4px; cursor: pointer; }}
559 </style>
560</head>
561<body>
562 <h1>Create Your Account</h1>
563 <form method="post" action="/connect/register/{}/complete">
564"#,
565 registration.registration_id
566 );
567
568 for field in &self.config.required_fields {
570 let (field_name, field_type, current_value) = match field.as_str() {
571 "email" => (
572 "Email Address",
573 "email",
574 registration.email.as_deref().unwrap_or(""),
575 ),
576 "given_name" => (
577 "First Name",
578 "text",
579 registration.given_name.as_deref().unwrap_or(""),
580 ),
581 "family_name" => (
582 "Last Name",
583 "text",
584 registration.family_name.as_deref().unwrap_or(""),
585 ),
586 "phone_number" => (
587 "Phone Number",
588 "tel",
589 registration.phone_number.as_deref().unwrap_or(""),
590 ),
591 _ => (field.as_str(), "text", ""),
592 };
593
594 form.push_str(&format!(
595 r#" <div class="form-group">
596 <label for="{}">{} <span class="required">*</span></label>
597 <input type="{}" id="{}" name="{}" value="{}" required>
598 </div>
599"#,
600 field, field_name, field_type, field, field, current_value
601 ));
602 }
603
604 for field in &self.config.optional_fields {
606 if !self.config.required_fields.contains(field) {
607 let (field_name, field_type, current_value) = match field.as_str() {
608 "preferred_username" => (
609 "Username",
610 "text",
611 registration.preferred_username.as_deref().unwrap_or(""),
612 ),
613 "website" => (
614 "Website",
615 "url",
616 registration.website.as_deref().unwrap_or(""),
617 ),
618 "picture" => (
619 "Profile Picture URL",
620 "url",
621 registration.picture.as_deref().unwrap_or(""),
622 ),
623 _ => (field.as_str(), "text", ""),
624 };
625
626 form.push_str(&format!(
627 r#" <div class="form-group">
628 <label for="{}">{}</label>
629 <input type="{}" id="{}" name="{}" value="{}">
630 </div>
631"#,
632 field, field_name, field_type, field, field, current_value
633 ));
634 }
635 }
636
637 form.push_str(
638 r#" <button type="submit" class="submit-btn">Create Account</button>
639 </form>
640</body>
641</html>"#,
642 );
643
644 Ok(form)
645 }
646
647 pub fn cleanup_expired_sessions(&mut self) -> usize {
649 let now = std::time::SystemTime::now()
650 .duration_since(std::time::UNIX_EPOCH)
651 .unwrap()
652 .as_secs();
653
654 let initial_count = self.registration_sessions.len();
655
656 self.registration_sessions
657 .retain(|_, registration| now - registration.created_at < self.config.session_timeout);
658
659 initial_count - self.registration_sessions.len()
660 }
661
662 pub fn get_discovery_metadata(&self) -> HashMap<String, serde_json::Value> {
664 let mut metadata = HashMap::new();
665
666 if self.config.enabled {
667 metadata.insert(
668 "registration_endpoint".to_string(),
669 serde_json::Value::String(self.config.registration_endpoint.clone()),
670 );
671 metadata.insert(
672 "prompt_values_supported".to_string(),
673 serde_json::Value::Array(vec![
674 serde_json::Value::String("none".to_string()),
675 serde_json::Value::String("login".to_string()),
676 serde_json::Value::String("consent".to_string()),
677 serde_json::Value::String("select_account".to_string()),
678 serde_json::Value::String("create".to_string()),
679 ]),
680 );
681 }
682
683 metadata
684 }
685}
686
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `prompt=create` authorization request with the given `login_hint`.
    fn create_request(login_hint: Option<&str>) -> RegistrationRequest {
        RegistrationRequest {
            client_id: "test_client".to_string(),
            redirect_uri: "https://client.example.com/callback".to_string(),
            scope: "openid profile email".to_string(),
            response_type: "code".to_string(),
            state: Some("state123".to_string()),
            nonce: Some("nonce456".to_string()),
            prompt: Some("create".to_string()),
            login_hint: login_hint.map(str::to_string),
            ui_locales: None,
            registration_metadata: None,
            claims: None,
        }
    }

    #[test]
    fn test_error_manager_integration() {
        let mut manager = RegistrationManager::new(RegistrationConfig::default());

        let disabled_config = RegistrationConfig {
            enabled: false,
            ..Default::default()
        };
        let mut disabled_manager = RegistrationManager::new(disabled_config);

        let request = create_request(None);

        let result = disabled_manager.initiate_registration(request.clone());
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Registration disabled")
        );

        let invalid_request = RegistrationRequest {
            prompt: Some("login".to_string()),
            ..request.clone()
        };

        let result = manager.initiate_registration(invalid_request);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Invalid request"));

        let invalid_metadata_request = RegistrationRequest {
            registration_metadata: Some("invalid json".to_string()),
            ..request
        };

        let result = manager.initiate_registration(invalid_metadata_request);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Invalid metadata"));
    }

    #[test]
    fn test_error_manager_session_handling() {
        let mut manager = RegistrationManager::new(RegistrationConfig::default());

        let result = manager.update_registration_data("nonexistent", HashMap::new());
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Session error"));

        let config = RegistrationConfig {
            required_fields: vec!["email".to_string(), "given_name".to_string()],
            ..Default::default()
        };
        let mut manager = RegistrationManager::new(config);

        let registration_data = RegistrationData {
            registration_id: "test123".to_string(),
            email: Some("user@example.com".to_string()),
            given_name: None,
            ..Default::default()
        };

        manager
            .registration_sessions
            .insert("test123".to_string(), registration_data);

        let result =
            manager.validate_registration_completeness("test123", Some("state456".to_string()));
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Registration incomplete")
        );
    }

    #[test]
    fn test_error_manager_custom_configuration() {
        use crate::server::oidc::oidc_error_extensions::{OidcErrorCode, OidcErrorManager};

        let mut custom_error_manager = OidcErrorManager::default();
        custom_error_manager.add_custom_error_mapping(
            "custom_registration_error".to_string(),
            OidcErrorCode::RegistrationRequired,
        );

        let manager = RegistrationManager::with_error_manager(
            RegistrationConfig::default(),
            custom_error_manager,
        );

        assert!(
            manager
                .get_error_manager()
                .has_custom_mapping("custom_registration_error")
        );

        let error_response =
            manager.create_registration_disabled_error(Some("test_state".to_string()));
        assert_eq!(error_response.state.as_ref().unwrap(), "test_state");

        let session_error = manager.create_session_not_found_error(None);
        assert_eq!(session_error.error, OidcErrorCode::SessionSelectionRequired);
    }

    #[test]
    fn test_registration_request_detection() {
        let manager = RegistrationManager::new(RegistrationConfig::default());

        assert!(manager.is_registration_requested(Some("create")));
        assert!(manager.is_registration_requested(Some("login create")));
        assert!(manager.is_registration_requested(Some("create consent")));
        assert!(!manager.is_registration_requested(Some("login")));
        assert!(!manager.is_registration_requested(None));
    }

    #[test]
    fn test_registration_initiation() {
        let mut manager = RegistrationManager::new(RegistrationConfig::default());

        let request = create_request(Some("user@example.com"));

        let registration_id = manager.initiate_registration(request).unwrap();
        assert!(!registration_id.is_empty());

        let registration_data = manager.get_registration_data(&registration_id).unwrap();
        assert_eq!(
            registration_data.email,
            Some("user@example.com".to_string())
        );
        assert!(!registration_data.completed);
    }

    #[test]
    fn test_registration_data_validation() {
        let mut manager = RegistrationManager::new(RegistrationConfig {
            required_fields: vec!["email".to_string(), "given_name".to_string()],
            ..RegistrationConfig::default()
        });

        let registration_id = "test_reg_123";
        let registration_data = RegistrationData {
            registration_id: registration_id.to_string(),
            email: Some("user@example.com".to_string()),
            given_name: None,
            ..Default::default()
        };

        manager
            .registration_sessions
            .insert(registration_id.to_string(), registration_data);

        let missing_fields = manager.validate_registration_data(registration_id).unwrap();
        assert_eq!(missing_fields, vec!["given_name"]);
    }

    #[test]
    fn test_update_registration_data_routes_fields() {
        let mut manager = RegistrationManager::new(RegistrationConfig::default());
        let registration_id = manager.initiate_registration(create_request(None)).unwrap();

        let mut updates = HashMap::new();
        updates.insert(
            "given_name".to_string(),
            serde_json::Value::String("Ada".to_string()),
        );
        updates.insert(
            "favorite_color".to_string(),
            serde_json::Value::String("blue".to_string()),
        );
        manager
            .update_registration_data(&registration_id, updates)
            .unwrap();

        let data = manager.get_registration_data(&registration_id).unwrap();
        assert_eq!(data.given_name.as_deref(), Some("Ada"));
        assert_eq!(
            data.custom_fields.get("favorite_color"),
            Some(&serde_json::Value::String("blue".to_string()))
        );
    }

    #[test]
    fn test_complete_registration_consumes_session() {
        let mut manager = RegistrationManager::new(RegistrationConfig::default());
        let registration_id = manager
            .initiate_registration(create_request(Some("user@example.com")))
            .unwrap();

        let response = manager.complete_registration(&registration_id).unwrap();
        assert!(response.completed);
        assert!(response.sub.starts_with("user_"));
        assert!(response.code.as_deref().unwrap().starts_with("reg_auth_"));
        assert!(manager.get_registration_data(&registration_id).is_none());
    }

    #[test]
    fn test_registration_form_prefills_session() {
        let mut manager = RegistrationManager::new(RegistrationConfig::default());
        let registration_id = manager
            .initiate_registration(create_request(Some("user@example.com")))
            .unwrap();

        let form = manager.generate_registration_form(&registration_id).unwrap();
        assert!(form.contains(&format!("/connect/register/{}/complete", registration_id)));
        assert!(form.contains(r#"value="user@example.com""#));
    }

    #[test]
    fn test_cleanup_removes_only_expired_sessions() {
        let mut manager = RegistrationManager::new(RegistrationConfig::default());
        let fresh_id = manager.initiate_registration(create_request(None)).unwrap();

        // created_at = 0 is far older than the default 30-minute timeout.
        manager.registration_sessions.insert(
            "stale".to_string(),
            RegistrationData {
                registration_id: "stale".to_string(),
                ..Default::default()
            },
        );

        assert_eq!(manager.cleanup_expired_sessions(), 1);
        assert!(manager.get_registration_data("stale").is_none());
        assert!(manager.get_registration_data(&fresh_id).is_some());
    }
}
878
879