1use crate::errors::{AuthError, Result};
56use crate::server::oidc::oidc_error_extensions::{
57 OidcErrorCode, OidcErrorManager, OidcErrorResponse,
58};
59use serde::{Deserialize, Serialize};
60use std::collections::HashMap;
61use uuid::Uuid;
62
/// Input accepted by `RegistrationManager::initiate_registration`.
///
/// Within this file only `state`, `prompt`, `login_hint` and
/// `registration_metadata` are read; the remaining fields are not inspected here.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RegistrationRequest {
    /// Not read in this file; presumably the requesting client's identifier (confirm against callers).
    pub client_id: String,
    /// Not read in this file.
    pub redirect_uri: String,
    /// Not read in this file.
    pub scope: String,
    /// Not read in this file.
    pub response_type: String,
    /// Passed to the error responses built while initiating registration.
    pub state: Option<String>,
    /// Not read in this file.
    pub nonce: Option<String>,

    /// Whitespace-separated prompt values; registration only starts when one of them is `create`.
    pub prompt: Option<String>,
    /// Pre-fills the session: stored as `email` when it contains `@`, otherwise as `preferred_username`.
    pub login_hint: Option<String>,
    /// Not read in this file.
    pub ui_locales: Option<String>,
    /// JSON object (serialized as a string) whose entries are copied into the session's `custom_fields`.
    pub registration_metadata: Option<String>,
    /// Not read in this file.
    pub claims: Option<String>,
}
86
/// State of one in-progress registration session.
///
/// `RegistrationManager` keeps these in a map keyed by `registration_id`.
/// The typed `Option<String>` fields are written by
/// `RegistrationManager::update_registration_data`.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct RegistrationData {
    /// Session identifier; `initiate_registration` sets it to a freshly generated UUID v4 string.
    pub registration_id: String,
    /// May be pre-filled from a `login_hint` that contains `@`.
    pub email: Option<String>,
    /// Updated through the `phone_number` key.
    pub phone_number: Option<String>,
    /// Updated through the `given_name` key.
    pub given_name: Option<String>,
    /// Updated through the `family_name` key.
    pub family_name: Option<String>,
    /// Updated through the `name` key.
    pub name: Option<String>,
    /// May be pre-filled from a `login_hint` that does not contain `@`.
    pub preferred_username: Option<String>,
    /// Updated through the `picture` key.
    pub picture: Option<String>,
    /// Updated through the `website` key.
    pub website: Option<String>,
    /// Updated through the `gender` key.
    pub gender: Option<String>,
    /// Updated through the `birthdate` key.
    pub birthdate: Option<String>,
    /// Updated through the `zoneinfo` key.
    pub zoneinfo: Option<String>,
    /// Updated through the `locale` key.
    pub locale: Option<String>,
    /// Parsed `registration_metadata` entries plus any update whose key is not one of the typed fields.
    pub custom_fields: HashMap<String, serde_json::Value>,
    /// `false` when the session is created by `initiate_registration`.
    pub completed: bool,
    /// Creation time in seconds since the Unix epoch; compared with
    /// `RegistrationConfig::session_timeout` to detect expiry.
    pub created_at: u64,
}
123
/// Result returned by `RegistrationManager::complete_registration`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RegistrationResponse {
    /// Subject identifier; `complete_registration` generates it as `user_` followed by a UUID v4.
    pub sub: String,
    /// Set to `true` by `complete_registration`.
    pub completed: bool,
    /// `complete_registration` fills this with `reg_auth_` followed by a UUID v4.
    pub code: Option<String>,
    /// `complete_registration` currently always returns `None` here.
    pub state: Option<String>,
}
136
/// Settings that drive `RegistrationManager`.
#[derive(Debug, Clone)]
pub struct RegistrationConfig {
    /// When `false`, registration cannot be initiated and no discovery metadata is published.
    pub enabled: bool,
    /// Value published as `registration_endpoint` in the discovery metadata.
    pub registration_endpoint: String,
    /// Fields that must be present before a registration counts as complete;
    /// rendered as required inputs in the generated form.
    pub required_fields: Vec<String>,
    /// Fields rendered as optional inputs in the generated form (skipped when also required).
    pub optional_fields: Vec<String>,
    /// Maximum session age in seconds, measured from `RegistrationData::created_at`.
    pub session_timeout: u64,
    /// Not read in this file.
    pub require_email_verification: bool,
    /// Not read in this file.
    pub require_phone_verification: bool,
    /// Not read in this file.
    pub field_validation_rules: HashMap<String, String>,
}
157
158impl Default for RegistrationConfig {
159 fn default() -> Self {
160 Self {
161 enabled: true,
162 registration_endpoint: "/connect/register".to_string(),
163 required_fields: vec!["email".to_string()],
164 optional_fields: vec![
165 "given_name".to_string(),
166 "family_name".to_string(),
167 "name".to_string(),
168 "preferred_username".to_string(),
169 "phone_number".to_string(),
170 ],
171 session_timeout: 1800, require_email_verification: true,
173 require_phone_verification: false,
174 field_validation_rules: HashMap::new(),
175 }
176 }
177}
178
/// Tracks registration sessions and builds the OIDC error responses for them.
///
/// Sessions live only in the in-memory map held by this value.
#[derive(Debug, Clone)]
pub struct RegistrationManager {
    /// Settings consulted by every operation.
    config: RegistrationConfig,
    /// Builds every `OidcErrorResponse` returned by the `create_*_error` helpers.
    error_manager: OidcErrorManager,
    /// In-progress sessions keyed by registration id.
    registration_sessions: HashMap<String, RegistrationData>,
}
189
190impl RegistrationManager {
191 pub fn new(config: RegistrationConfig) -> Self {
193 Self {
194 config,
195 error_manager: OidcErrorManager::default(),
196 registration_sessions: HashMap::new(),
197 }
198 }
199
200 pub fn with_error_manager(config: RegistrationConfig, error_manager: OidcErrorManager) -> Self {
202 Self {
203 config,
204 error_manager,
205 registration_sessions: HashMap::new(),
206 }
207 }
208
209 pub fn create_registration_disabled_error(&self, state: Option<String>) -> OidcErrorResponse {
211 self.error_manager.create_error_response(
212 OidcErrorCode::RegistrationNotSupported,
213 Some("User registration is not enabled on this server".to_string()),
214 state,
215 HashMap::new(),
216 )
217 }
218
219 pub fn create_invalid_registration_request_error(
221 &self,
222 description: String,
223 state: Option<String>,
224 ) -> OidcErrorResponse {
225 self.error_manager.create_error_response(
226 OidcErrorCode::InvalidRequest,
227 Some(description),
228 state,
229 HashMap::new(),
230 )
231 }
232
233 pub fn create_session_not_found_error(&self, state: Option<String>) -> OidcErrorResponse {
235 self.error_manager.create_error_response(
236 OidcErrorCode::SessionSelectionRequired,
237 Some("Registration session not found or expired".to_string()),
238 state,
239 HashMap::new(),
240 )
241 }
242
243 pub fn create_registration_incomplete_error(
245 &self,
246 missing_fields: Vec<String>,
247 state: Option<String>,
248 ) -> OidcErrorResponse {
249 let mut additional_details = HashMap::new();
250 additional_details.insert(
251 "missing_fields".to_string(),
252 serde_json::to_value(missing_fields.clone()).unwrap(),
253 );
254
255 self.error_manager.create_error_response(
256 OidcErrorCode::RegistrationRequired,
257 Some(format!(
258 "Registration incomplete. Missing required fields: {}",
259 missing_fields.join(", ")
260 )),
261 state,
262 additional_details,
263 )
264 }
265
266 pub fn create_session_expired_error(&self, state: Option<String>) -> OidcErrorResponse {
268 self.error_manager.create_error_response(
269 OidcErrorCode::LoginRequired,
270 Some("Registration session has expired. Please start registration again".to_string()),
271 state,
272 HashMap::new(),
273 )
274 }
275
276 pub fn get_error_manager(&self) -> &OidcErrorManager {
278 &self.error_manager
279 }
280
281 pub fn update_error_manager(&mut self, error_manager: OidcErrorManager) {
283 self.error_manager = error_manager;
284 }
285
286 pub fn is_registration_requested(&self, prompt: Option<&str>) -> bool {
288 if !self.config.enabled {
289 return false;
290 }
291
292 if let Some(prompt_values) = prompt {
293 let prompts: Vec<&str> = prompt_values.split_whitespace().collect();
294 prompts.contains(&"create")
295 } else {
296 false
297 }
298 }
299
300 pub fn initiate_registration(&mut self, request: RegistrationRequest) -> Result<String> {
302 if !self.config.enabled {
303 let error_response = self.create_registration_disabled_error(request.state.clone());
304 return Err(AuthError::validation(format!(
305 "Registration disabled: {}",
306 error_response.error_description.unwrap_or_default()
307 )));
308 }
309
310 if !self.is_registration_requested(request.prompt.as_deref()) {
312 let error_response = self.create_invalid_registration_request_error(
313 "Registration requires prompt=create parameter".to_string(),
314 request.state.clone(),
315 );
316 return Err(AuthError::validation(format!(
317 "Invalid request: {}",
318 error_response.error_description.unwrap_or_default()
319 )));
320 }
321
322 let registration_id = Uuid::new_v4().to_string();
324 let now = std::time::SystemTime::now()
325 .duration_since(std::time::UNIX_EPOCH)
326 .unwrap()
327 .as_secs();
328
329 let mut registration_data = RegistrationData {
330 registration_id: registration_id.clone(),
331 email: None,
332 phone_number: None,
333 given_name: None,
334 family_name: None,
335 name: None,
336 preferred_username: None,
337 picture: None,
338 website: None,
339 gender: None,
340 birthdate: None,
341 zoneinfo: None,
342 locale: None,
343 custom_fields: HashMap::new(),
344 completed: false,
345 created_at: now,
346 };
347
348 if let Some(login_hint) = &request.login_hint {
350 if login_hint.contains('@') {
351 registration_data.email = Some(login_hint.clone());
352 } else {
353 registration_data.preferred_username = Some(login_hint.clone());
354 }
355 }
356
357 if let Some(metadata_str) = &request.registration_metadata {
359 match serde_json::from_str::<HashMap<String, serde_json::Value>>(metadata_str) {
360 Ok(metadata) => {
361 registration_data.custom_fields.extend(metadata);
362 }
363 Err(_) => {
364 let error_response = self.create_invalid_registration_request_error(
365 "Invalid registration metadata JSON format".to_string(),
366 request.state.clone(),
367 );
368 return Err(AuthError::validation(format!(
369 "Invalid metadata: {}",
370 error_response.error_description.unwrap_or_default()
371 )));
372 }
373 }
374 }
375
376 self.registration_sessions
377 .insert(registration_id.clone(), registration_data);
378
379 Ok(registration_id)
380 }
381
382 pub fn update_registration_data(
384 &mut self,
385 registration_id: &str,
386 updates: HashMap<String, serde_json::Value>,
387 ) -> Result<()> {
388 if !self.registration_sessions.contains_key(registration_id) {
390 let error_response = self.create_session_not_found_error(None);
391 return Err(AuthError::validation(format!(
392 "Session error: {}",
393 error_response.error_description.unwrap_or_default()
394 )));
395 }
396
397 let registration = self.registration_sessions.get_mut(registration_id).unwrap(); let now = std::time::SystemTime::now()
401 .duration_since(std::time::UNIX_EPOCH)
402 .unwrap()
403 .as_secs();
404
405 if now - registration.created_at > self.config.session_timeout {
406 let error_response = self.create_session_expired_error(None);
407 return Err(AuthError::validation(format!(
408 "Session expired: {}",
409 error_response.error_description.unwrap_or_default()
410 )));
411 }
412
413 for (key, value) in updates {
415 match key.as_str() {
416 "email" => registration.email = value.as_str().map(|s| s.to_string()),
417 "phone_number" => registration.phone_number = value.as_str().map(|s| s.to_string()),
418 "given_name" => registration.given_name = value.as_str().map(|s| s.to_string()),
419 "family_name" => registration.family_name = value.as_str().map(|s| s.to_string()),
420 "name" => registration.name = value.as_str().map(|s| s.to_string()),
421 "preferred_username" => {
422 registration.preferred_username = value.as_str().map(|s| s.to_string())
423 }
424 "picture" => registration.picture = value.as_str().map(|s| s.to_string()),
425 "website" => registration.website = value.as_str().map(|s| s.to_string()),
426 "gender" => registration.gender = value.as_str().map(|s| s.to_string()),
427 "birthdate" => registration.birthdate = value.as_str().map(|s| s.to_string()),
428 "zoneinfo" => registration.zoneinfo = value.as_str().map(|s| s.to_string()),
429 "locale" => registration.locale = value.as_str().map(|s| s.to_string()),
430 _ => {
431 registration.custom_fields.insert(key, value);
433 }
434 }
435 }
436
437 Ok(())
438 }
439
440 pub fn validate_registration_data(&self, registration_id: &str) -> Result<Vec<String>> {
442 if !self.registration_sessions.contains_key(registration_id) {
444 let error_response = self.create_session_not_found_error(None);
445 return Err(AuthError::validation(format!(
446 "Session error: {}",
447 error_response.error_description.unwrap_or_default()
448 )));
449 }
450
451 let registration = self.registration_sessions.get(registration_id).unwrap(); let mut missing_fields = Vec::new();
454
455 for field in &self.config.required_fields {
457 let is_present = match field.as_str() {
458 "email" => registration.email.is_some(),
459 "phone_number" => registration.phone_number.is_some(),
460 "given_name" => registration.given_name.is_some(),
461 "family_name" => registration.family_name.is_some(),
462 "name" => registration.name.is_some(),
463 "preferred_username" => registration.preferred_username.is_some(),
464 _ => registration.custom_fields.contains_key(field),
465 };
466
467 if !is_present {
468 missing_fields.push(field.clone());
469 }
470 }
471
472 Ok(missing_fields)
473 }
474
475 pub fn validate_registration_completeness(
477 &self,
478 registration_id: &str,
479 state: Option<String>,
480 ) -> Result<()> {
481 let missing_fields = self.validate_registration_data(registration_id)?;
482 if !missing_fields.is_empty() {
483 let error_response = self.create_registration_incomplete_error(missing_fields, state);
484 return Err(AuthError::validation(format!(
485 "Registration incomplete: {}",
486 error_response.error_description.unwrap_or_default()
487 )));
488 }
489 Ok(())
490 }
491
492 pub fn complete_registration(&mut self, registration_id: &str) -> Result<RegistrationResponse> {
494 self.validate_registration_completeness(registration_id, None)?;
496
497 if !self.registration_sessions.contains_key(registration_id) {
499 let error_response = self.create_session_not_found_error(None);
500 return Err(AuthError::validation(format!(
501 "Session error: {}",
502 error_response.error_description.unwrap_or_default()
503 )));
504 }
505
506 let mut registration = self.registration_sessions.remove(registration_id).unwrap(); let sub = format!("user_{}", Uuid::new_v4());
510
511 registration.completed = true;
513
514 let authorization_code = format!("reg_auth_{}", Uuid::new_v4());
522
523 Ok(RegistrationResponse {
524 sub,
525 completed: true,
526 code: Some(authorization_code),
527 state: None, })
529 }
530
531 pub fn get_registration_data(&self, registration_id: &str) -> Option<&RegistrationData> {
533 self.registration_sessions.get(registration_id)
534 }
535
536 pub fn generate_registration_form(&self, registration_id: &str) -> Result<String> {
538 if !self.registration_sessions.contains_key(registration_id) {
540 let error_response = self.create_session_not_found_error(None);
541 return Err(AuthError::validation(format!(
542 "Session error: {}",
543 error_response.error_description.unwrap_or_default()
544 )));
545 }
546
547 let registration = self.registration_sessions.get(registration_id).unwrap(); let mut form = format!(
550 r#"<!DOCTYPE html>
551<html>
552<head>
553 <title>User Registration</title>
554 <style>
555 body {{ font-family: Arial, sans-serif; margin: 40px; }}
556 .form-group {{ margin-bottom: 15px; }}
557 label {{ display: block; margin-bottom: 5px; font-weight: bold; }}
558 input {{ width: 100%; padding: 8px; border: 1px solid #ccc; border-radius: 4px; }}
559 .required {{ color: red; }}
560 .submit-btn {{ background: #007bff; color: white; padding: 10px 20px; border: none; border-radius: 4px; cursor: pointer; }}
561 </style>
562</head>
563<body>
564 <h1>Create Your Account</h1>
565 <form method="post" action="/connect/register/{}/complete">
566"#,
567 registration.registration_id
568 );
569
570 for field in &self.config.required_fields {
572 let (field_name, field_type, current_value) = match field.as_str() {
573 "email" => (
574 "Email Address",
575 "email",
576 registration.email.as_deref().unwrap_or(""),
577 ),
578 "given_name" => (
579 "First Name",
580 "text",
581 registration.given_name.as_deref().unwrap_or(""),
582 ),
583 "family_name" => (
584 "Last Name",
585 "text",
586 registration.family_name.as_deref().unwrap_or(""),
587 ),
588 "phone_number" => (
589 "Phone Number",
590 "tel",
591 registration.phone_number.as_deref().unwrap_or(""),
592 ),
593 _ => (field.as_str(), "text", ""),
594 };
595
596 form.push_str(&format!(
597 r#" <div class="form-group">
598 <label for="{}">{} <span class="required">*</span></label>
599 <input type="{}" id="{}" name="{}" value="{}" required>
600 </div>
601"#,
602 field, field_name, field_type, field, field, current_value
603 ));
604 }
605
606 for field in &self.config.optional_fields {
608 if !self.config.required_fields.contains(field) {
609 let (field_name, field_type, current_value) = match field.as_str() {
610 "preferred_username" => (
611 "Username",
612 "text",
613 registration.preferred_username.as_deref().unwrap_or(""),
614 ),
615 "website" => (
616 "Website",
617 "url",
618 registration.website.as_deref().unwrap_or(""),
619 ),
620 "picture" => (
621 "Profile Picture URL",
622 "url",
623 registration.picture.as_deref().unwrap_or(""),
624 ),
625 _ => (field.as_str(), "text", ""),
626 };
627
628 form.push_str(&format!(
629 r#" <div class="form-group">
630 <label for="{}">{}</label>
631 <input type="{}" id="{}" name="{}" value="{}">
632 </div>
633"#,
634 field, field_name, field_type, field, field, current_value
635 ));
636 }
637 }
638
639 form.push_str(
640 r#" <button type="submit" class="submit-btn">Create Account</button>
641 </form>
642</body>
643</html>"#,
644 );
645
646 Ok(form)
647 }
648
649 pub fn cleanup_expired_sessions(&mut self) -> usize {
651 let now = std::time::SystemTime::now()
652 .duration_since(std::time::UNIX_EPOCH)
653 .unwrap()
654 .as_secs();
655
656 let initial_count = self.registration_sessions.len();
657
658 self.registration_sessions
659 .retain(|_, registration| now - registration.created_at < self.config.session_timeout);
660
661 initial_count - self.registration_sessions.len()
662 }
663
664 pub fn get_discovery_metadata(&self) -> HashMap<String, serde_json::Value> {
666 let mut metadata = HashMap::new();
667
668 if self.config.enabled {
669 metadata.insert(
670 "registration_endpoint".to_string(),
671 serde_json::Value::String(self.config.registration_endpoint.clone()),
672 );
673 metadata.insert(
674 "prompt_values_supported".to_string(),
675 serde_json::Value::Array(vec![
676 serde_json::Value::String("none".to_string()),
677 serde_json::Value::String("login".to_string()),
678 serde_json::Value::String("consent".to_string()),
679 serde_json::Value::String("select_account".to_string()),
680 serde_json::Value::String("create".to_string()),
681 ]),
682 );
683 }
684
685 metadata
686 }
687}
688
#[cfg(test)]
mod tests {
    use super::*;

    /// `initiate_registration` surfaces the disabled, wrong-prompt and
    /// bad-metadata failures with distinct message prefixes.
    #[test]
    fn test_error_manager_integration() {
        let mut manager = RegistrationManager::new(RegistrationConfig::default());

        let disabled_config = RegistrationConfig {
            enabled: false,
            ..Default::default()
        };
        let mut disabled_manager = RegistrationManager::new(disabled_config);

        let request = RegistrationRequest {
            client_id: "test_client".to_string(),
            redirect_uri: "https://client.example.com/callback".to_string(),
            scope: "openid profile email".to_string(),
            response_type: "code".to_string(),
            state: Some("state123".to_string()),
            nonce: Some("nonce456".to_string()),
            prompt: Some("create".to_string()),
            login_hint: None,
            ui_locales: None,
            registration_metadata: None,
            claims: None,
        };

        let result = disabled_manager.initiate_registration(request.clone());
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Registration disabled")
        );

        // A prompt without `create` is rejected.
        let invalid_request = RegistrationRequest {
            prompt: Some("login".to_string()),
            ..request.clone()
        };

        let result = manager.initiate_registration(invalid_request);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Invalid request"));

        // Metadata that is not a JSON object is rejected.
        let invalid_metadata_request = RegistrationRequest {
            registration_metadata: Some("invalid json".to_string()),
            ..request
        };

        let result = manager.initiate_registration(invalid_metadata_request);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Invalid metadata"));
    }

    /// Unknown sessions and incomplete registrations produce the expected errors.
    #[test]
    fn test_error_manager_session_handling() {
        let mut manager = RegistrationManager::new(RegistrationConfig::default());

        let result = manager.update_registration_data("nonexistent", HashMap::new());
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Session error"));

        let config = RegistrationConfig {
            required_fields: vec!["email".to_string(), "given_name".to_string()],
            ..Default::default()
        };
        let mut manager = RegistrationManager::new(config);

        // `given_name` is required but deliberately left unset.
        let registration_data = RegistrationData {
            registration_id: "test123".to_string(),
            email: Some("user@example.com".to_string()),
            given_name: None,
            ..Default::default()
        };

        manager
            .registration_sessions
            .insert("test123".to_string(), registration_data);

        let result =
            manager.validate_registration_completeness("test123", Some("state456".to_string()));
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Registration incomplete")
        );
    }

    /// A caller-supplied error manager is kept and used for the responses.
    #[test]
    fn test_error_manager_custom_configuration() {
        use crate::server::oidc::oidc_error_extensions::{OidcErrorCode, OidcErrorManager};

        let mut custom_error_manager = OidcErrorManager::default();
        custom_error_manager.add_custom_error_mapping(
            "custom_registration_error".to_string(),
            OidcErrorCode::RegistrationRequired,
        );

        let manager = RegistrationManager::with_error_manager(
            RegistrationConfig::default(),
            custom_error_manager,
        );

        assert!(
            manager
                .get_error_manager()
                .has_custom_mapping("custom_registration_error")
        );

        let error_response =
            manager.create_registration_disabled_error(Some("test_state".to_string()));
        assert_eq!(error_response.state.as_ref().unwrap(), "test_state");

        let session_error = manager.create_session_not_found_error(None);
        assert_eq!(session_error.error, OidcErrorCode::SessionSelectionRequired);
    }

    /// `create` is detected anywhere in a whitespace-separated prompt list.
    #[test]
    fn test_registration_request_detection() {
        let manager = RegistrationManager::new(RegistrationConfig::default());

        assert!(manager.is_registration_requested(Some("create")));
        assert!(manager.is_registration_requested(Some("login create")));
        assert!(manager.is_registration_requested(Some("create consent")));
        assert!(!manager.is_registration_requested(Some("login")));
        assert!(!manager.is_registration_requested(None));
    }

    /// A successful initiation stores a session pre-filled from `login_hint`.
    #[test]
    fn test_registration_initiation() {
        let mut manager = RegistrationManager::new(RegistrationConfig::default());

        let request = RegistrationRequest {
            client_id: "test_client".to_string(),
            redirect_uri: "https://client.example.com/callback".to_string(),
            scope: "openid profile email".to_string(),
            response_type: "code".to_string(),
            state: Some("state123".to_string()),
            nonce: Some("nonce456".to_string()),
            prompt: Some("create".to_string()),
            login_hint: Some("user@example.com".to_string()),
            ui_locales: None,
            registration_metadata: None,
            claims: None,
        };

        let registration_id = manager.initiate_registration(request).unwrap();
        assert!(!registration_id.is_empty());

        let registration_data = manager.get_registration_data(&registration_id).unwrap();
        assert_eq!(
            registration_data.email,
            Some("user@example.com".to_string())
        );
        assert!(!registration_data.completed);
    }

    /// Only the required fields that are actually missing are reported.
    #[test]
    fn test_registration_data_validation() {
        let mut manager = RegistrationManager::new(RegistrationConfig {
            required_fields: vec!["email".to_string(), "given_name".to_string()],
            ..RegistrationConfig::default()
        });

        let registration_id = "test_reg_123";
        let registration_data = RegistrationData {
            registration_id: registration_id.to_string(),
            email: Some("user@example.com".to_string()),
            given_name: None,
            ..Default::default()
        };

        manager
            .registration_sessions
            .insert(registration_id.to_string(), registration_data);

        let missing_fields = manager.validate_registration_data(registration_id).unwrap();
        assert_eq!(missing_fields, vec!["given_name"]);
    }
}