1use std::error::Error;
4use std::fmt;
5use std::sync::{Arc, Mutex};
6
7use time::{Duration, OffsetDateTime};
8
9use crate::crypto::password::{hash_password, verify_password};
10use crate::db::{DbAdapter, DbRecord, Session, User};
11use crate::error::OpenAuthError;
12use crate::options::SecondaryStorage;
13use crate::session::{CreateSessionInput, SessionStore};
14use crate::user::{CreateCredentialAccountInput, CreateUserInput, DbUserStore};
15
16pub type PasswordHashFn = fn(&str) -> Result<String, OpenAuthError>;
17pub type PasswordVerifyFn = fn(&str, &str) -> Result<bool, OpenAuthError>;
18
/// Machine-readable failure categories produced by the email/password flows.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AuthFlowErrorCode {
    /// The email address failed syntactic validation.
    InvalidEmail,
    /// The password is outside the configured minimum/maximum length.
    InvalidPasswordLength,
    /// Generic sign-in failure shared by the unknown-email, missing-credential
    /// and wrong-password paths.
    InvalidEmailOrPassword,
    /// Sign-up was attempted with an email that already belongs to a user.
    UserAlreadyExists,
    /// Sign-in was refused because verification is required and missing.
    EmailNotVerified,
    /// The session could not be created.
    FailedToCreateSession,
    /// An underlying storage/adapter operation failed.
    StorageError,
}

impl AuthFlowErrorCode {
    /// Stable upper-snake-case identifier for this code, e.g. `"INVALID_EMAIL"`.
    pub fn as_str(self) -> &'static str {
        let (identifier, _) = self.describe();
        identifier
    }

    /// Default human-readable message for this code.
    pub fn message(self) -> &'static str {
        let (_, text) = self.describe();
        text
    }

    /// Single table pairing every code with its identifier and message, so
    /// the two public views can never drift apart.
    fn describe(self) -> (&'static str, &'static str) {
        match self {
            Self::InvalidEmail => ("INVALID_EMAIL", "Invalid email"),
            Self::InvalidPasswordLength => ("INVALID_PASSWORD_LENGTH", "Invalid password length"),
            Self::InvalidEmailOrPassword => {
                ("INVALID_EMAIL_OR_PASSWORD", "Invalid email or password")
            }
            Self::UserAlreadyExists => ("USER_ALREADY_EXISTS", "User already exists"),
            Self::EmailNotVerified => ("EMAIL_NOT_VERIFIED", "Email not verified"),
            Self::FailedToCreateSession => {
                ("FAILED_TO_CREATE_SESSION", "Failed to create session")
            }
            Self::StorageError => ("STORAGE_ERROR", "Storage error"),
        }
    }
}
55
56#[derive(Debug, Clone, PartialEq, Eq)]
57pub struct AuthFlowError {
58 code: AuthFlowErrorCode,
59 message: String,
60}
61
62impl AuthFlowError {
63 pub fn new(code: AuthFlowErrorCode) -> Self {
64 Self {
65 code,
66 message: code.message().to_owned(),
67 }
68 }
69
70 pub fn storage(error: OpenAuthError) -> Self {
71 Self {
72 code: AuthFlowErrorCode::StorageError,
73 message: error.to_string(),
74 }
75 }
76
77 pub fn code(&self) -> AuthFlowErrorCode {
78 self.code
79 }
80
81 pub fn code_str(&self) -> &'static str {
82 self.code.as_str()
83 }
84
85 pub fn message(&self) -> &str {
86 self.message.as_str()
87 }
88}
89
90impl fmt::Display for AuthFlowError {
91 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
92 write!(formatter, "{}: {}", self.code.as_str(), self.message)
93 }
94}
95
96impl Error for AuthFlowError {}
97
98impl From<OpenAuthError> for AuthFlowError {
99 fn from(error: OpenAuthError) -> Self {
100 Self::storage(error)
101 }
102}
103
104#[derive(Clone)]
105pub struct EmailPasswordConfig {
106 pub session_expires_in: u64,
107 pub dont_remember_session_expires_in: u64,
108 pub min_password_length: usize,
109 pub max_password_length: usize,
110 pub require_email_verification: bool,
111 pub secondary_storage: Option<Arc<dyn SecondaryStorage>>,
112 pub store_session_in_database: bool,
113 pub preserve_session_in_database: bool,
114}
115
116impl Default for EmailPasswordConfig {
117 fn default() -> Self {
118 Self {
119 session_expires_in: 60 * 60 * 24 * 7,
120 dont_remember_session_expires_in: 60 * 60 * 24,
121 min_password_length: 8,
122 max_password_length: 128,
123 require_email_verification: false,
124 secondary_storage: None,
125 store_session_in_database: false,
126 preserve_session_in_database: false,
127 }
128 }
129}
130
131impl fmt::Debug for EmailPasswordConfig {
132 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
133 formatter
134 .debug_struct("EmailPasswordConfig")
135 .field("session_expires_in", &self.session_expires_in)
136 .field(
137 "dont_remember_session_expires_in",
138 &self.dont_remember_session_expires_in,
139 )
140 .field("min_password_length", &self.min_password_length)
141 .field("max_password_length", &self.max_password_length)
142 .field(
143 "require_email_verification",
144 &self.require_email_verification,
145 )
146 .field(
147 "secondary_storage",
148 &self
149 .secondary_storage
150 .as_ref()
151 .map(|_| "<secondary-storage>"),
152 )
153 .field("store_session_in_database", &self.store_session_in_database)
154 .field(
155 "preserve_session_in_database",
156 &self.preserve_session_in_database,
157 )
158 .finish()
159 }
160}
161
162#[derive(Debug, Clone, PartialEq)]
163pub struct SignUpInput {
164 pub name: String,
165 pub email: String,
166 pub password: String,
167 pub image: Option<String>,
168 pub username: Option<String>,
169 pub display_username: Option<String>,
170 pub remember_me: bool,
171 pub ip_address: Option<String>,
172 pub user_agent: Option<String>,
173 pub additional_user_fields: DbRecord,
174 pub additional_session_fields: DbRecord,
175}
176
177impl SignUpInput {
178 pub fn new(
179 name: impl Into<String>,
180 email: impl Into<String>,
181 password: impl Into<String>,
182 ) -> Self {
183 Self {
184 name: name.into(),
185 email: email.into(),
186 password: password.into(),
187 image: None,
188 username: None,
189 display_username: None,
190 remember_me: true,
191 ip_address: None,
192 user_agent: None,
193 additional_user_fields: DbRecord::new(),
194 additional_session_fields: DbRecord::new(),
195 }
196 }
197
198 #[must_use]
199 pub fn image(mut self, image: impl Into<String>) -> Self {
200 self.image = Some(image.into());
201 self
202 }
203
204 #[must_use]
205 pub fn username(mut self, username: impl Into<String>) -> Self {
206 self.username = Some(username.into());
207 self
208 }
209
210 #[must_use]
211 pub fn display_username(mut self, display_username: impl Into<String>) -> Self {
212 self.display_username = Some(display_username.into());
213 self
214 }
215
216 #[must_use]
217 pub fn remember_me(mut self, remember_me: bool) -> Self {
218 self.remember_me = remember_me;
219 self
220 }
221
222 #[must_use]
223 pub fn ip_address(mut self, ip_address: impl Into<String>) -> Self {
224 self.ip_address = Some(ip_address.into());
225 self
226 }
227
228 #[must_use]
229 pub fn user_agent(mut self, user_agent: impl Into<String>) -> Self {
230 self.user_agent = Some(user_agent.into());
231 self
232 }
233
234 #[must_use]
235 pub fn additional_user_fields(mut self, fields: DbRecord) -> Self {
236 self.additional_user_fields = fields;
237 self
238 }
239
240 #[must_use]
241 pub fn additional_session_fields(mut self, fields: DbRecord) -> Self {
242 self.additional_session_fields = fields;
243 self
244 }
245}
246
247#[derive(Debug, Clone, PartialEq)]
248pub struct SignInInput {
249 pub email: String,
250 pub password: String,
251 pub remember_me: bool,
252 pub ip_address: Option<String>,
253 pub user_agent: Option<String>,
254 pub additional_session_fields: DbRecord,
255}
256
257impl SignInInput {
258 pub fn new(email: impl Into<String>, password: impl Into<String>) -> Self {
259 Self {
260 email: email.into(),
261 password: password.into(),
262 remember_me: true,
263 ip_address: None,
264 user_agent: None,
265 additional_session_fields: DbRecord::new(),
266 }
267 }
268
269 #[must_use]
270 pub fn remember_me(mut self, remember_me: bool) -> Self {
271 self.remember_me = remember_me;
272 self
273 }
274
275 #[must_use]
276 pub fn ip_address(mut self, ip_address: impl Into<String>) -> Self {
277 self.ip_address = Some(ip_address.into());
278 self
279 }
280
281 #[must_use]
282 pub fn user_agent(mut self, user_agent: impl Into<String>) -> Self {
283 self.user_agent = Some(user_agent.into());
284 self
285 }
286
287 #[must_use]
288 pub fn additional_session_fields(mut self, fields: DbRecord) -> Self {
289 self.additional_session_fields = fields;
290 self
291 }
292}
293
294#[derive(Debug, Clone, PartialEq, Eq)]
295pub struct EmailPasswordAuthResult {
296 pub user: User,
297 pub session: Session,
298}
299
300#[derive(Clone)]
301pub struct EmailPasswordAuth<'a> {
302 adapter: &'a dyn DbAdapter,
303 config: EmailPasswordConfig,
304 hash_password: PasswordHashFn,
305 verify_password: PasswordVerifyFn,
306}
307
308impl<'a> EmailPasswordAuth<'a> {
309 pub fn new(
310 adapter: &'a dyn DbAdapter,
311 config: EmailPasswordConfig,
312 hash_password: PasswordHashFn,
313 verify_password: PasswordVerifyFn,
314 ) -> Self {
315 Self {
316 adapter,
317 config,
318 hash_password,
319 verify_password,
320 }
321 }
322
323 pub fn with_defaults(adapter: &'a dyn DbAdapter, config: EmailPasswordConfig) -> Self {
324 Self::new(adapter, config, hash_password, verify_password)
325 }
326
327 pub async fn sign_up(
328 &self,
329 input: SignUpInput,
330 ) -> Result<EmailPasswordAuthResult, AuthFlowError> {
331 self.validate_email_and_password(&input.email, &input.password)?;
332 let users = DbUserStore::new(self.adapter);
333 if users.find_user_by_email(&input.email).await?.is_some() {
334 return Err(AuthFlowError::new(AuthFlowErrorCode::UserAlreadyExists));
335 }
336
337 let password_hash = (self.hash_password)(&input.password)?;
338 let mut create_user = CreateUserInput::new(input.name, input.email)
339 .additional_fields(input.additional_user_fields);
340 if let Some(image) = input.image {
341 create_user = create_user.image(image);
342 }
343 if let Some(username) = input.username {
344 create_user = create_user.username(username);
345 }
346 if let Some(display_username) = input.display_username {
347 create_user = create_user.display_username(display_username);
348 }
349 let result = Arc::new(Mutex::new(None));
350 let result_for_transaction = Arc::clone(&result);
351 let config = self.config.clone();
352 let transaction_status = self
353 .adapter
354 .transaction(Box::new(move |transaction| {
355 Box::pin(async move {
356 let outcome = create_sign_up_records(SignUpRecordsInput {
357 adapter: transaction.as_ref(),
358 config: &config,
359 create_user,
360 password_hash,
361 remember_me: input.remember_me,
362 ip_address: input.ip_address,
363 user_agent: input.user_agent,
364 additional_session_fields: input.additional_session_fields,
365 })
366 .await;
367 match outcome {
368 Ok(result) => {
369 store_sign_up_result(&result_for_transaction, Ok(result))?;
370 Ok(())
371 }
372 Err(error) => {
373 let transaction_error = OpenAuthError::Adapter(error.to_string());
374 store_sign_up_result(&result_for_transaction, Err(error))?;
375 Err(transaction_error)
376 }
377 }
378 })
379 }))
380 .await;
381
382 match transaction_status {
383 Ok(()) => match take_sign_up_result(&result)? {
384 Some(Ok(result)) => Ok(result),
385 Some(Err(error)) => Err(error),
386 None => Err(AuthFlowError::storage(OpenAuthError::Adapter(
387 "sign-up transaction completed without a result".to_owned(),
388 ))),
389 },
390 Err(error) => match take_sign_up_result(&result)? {
391 Some(Err(auth_error)) => Err(auth_error),
392 _ => Err(AuthFlowError::storage(error)),
393 },
394 }
395 }
396
397 pub async fn sign_in(
398 &self,
399 input: SignInInput,
400 ) -> Result<EmailPasswordAuthResult, AuthFlowError> {
401 validate_email(&input.email)?;
402 let users = DbUserStore::new(self.adapter);
403 let Some(user_with_accounts) = users.find_user_by_email_with_accounts(&input.email).await?
404 else {
405 let _ = (self.hash_password)(&input.password);
406 return Err(AuthFlowError::new(
407 AuthFlowErrorCode::InvalidEmailOrPassword,
408 ));
409 };
410 let Some(account) = user_with_accounts
411 .accounts
412 .iter()
413 .find(|account| account.provider_id == "credential")
414 else {
415 let _ = (self.hash_password)(&input.password);
416 return Err(AuthFlowError::new(
417 AuthFlowErrorCode::InvalidEmailOrPassword,
418 ));
419 };
420 let Some(password_hash) = account.password.as_deref() else {
421 let _ = (self.hash_password)(&input.password);
422 return Err(AuthFlowError::new(
423 AuthFlowErrorCode::InvalidEmailOrPassword,
424 ));
425 };
426 if !(self.verify_password)(password_hash, &input.password)? {
427 return Err(AuthFlowError::new(
428 AuthFlowErrorCode::InvalidEmailOrPassword,
429 ));
430 }
431 if self.config.require_email_verification && !user_with_accounts.user.email_verified {
432 return Err(AuthFlowError::new(AuthFlowErrorCode::EmailNotVerified));
433 }
434 let session = create_session_record(
435 self.adapter,
436 &self.config,
437 &user_with_accounts.user.id,
438 input.remember_me,
439 input.ip_address,
440 input.user_agent,
441 input.additional_session_fields,
442 )
443 .await?;
444
445 Ok(EmailPasswordAuthResult {
446 user: user_with_accounts.user,
447 session,
448 })
449 }
450
451 fn validate_email_and_password(
452 &self,
453 email: &str,
454 password: &str,
455 ) -> Result<(), AuthFlowError> {
456 validate_email(email)?;
457 if password.len() < self.config.min_password_length
458 || password.len() > self.config.max_password_length
459 {
460 return Err(AuthFlowError::new(AuthFlowErrorCode::InvalidPasswordLength));
461 }
462 Ok(())
463 }
464}
465
466struct SignUpRecordsInput<'a> {
467 adapter: &'a dyn DbAdapter,
468 config: &'a EmailPasswordConfig,
469 create_user: CreateUserInput,
470 password_hash: String,
471 remember_me: bool,
472 ip_address: Option<String>,
473 user_agent: Option<String>,
474 additional_session_fields: DbRecord,
475}
476
477async fn create_sign_up_records(
478 input: SignUpRecordsInput<'_>,
479) -> Result<EmailPasswordAuthResult, AuthFlowError> {
480 let users = DbUserStore::new(input.adapter);
481 let user = users.create_user(input.create_user).await?;
482 users
483 .create_credential_account(CreateCredentialAccountInput::new(
484 user.id.clone(),
485 input.password_hash,
486 ))
487 .await?;
488 let session = create_session_record(
489 input.adapter,
490 input.config,
491 &user.id,
492 input.remember_me,
493 input.ip_address,
494 input.user_agent,
495 input.additional_session_fields,
496 )
497 .await?;
498
499 Ok(EmailPasswordAuthResult { user, session })
500}
501
502async fn create_session_record(
503 adapter: &dyn DbAdapter,
504 config: &EmailPasswordConfig,
505 user_id: &str,
506 remember_me: bool,
507 ip_address: Option<String>,
508 user_agent: Option<String>,
509 additional_fields: DbRecord,
510) -> Result<Session, AuthFlowError> {
511 let expires_in = if remember_me {
512 config.session_expires_in
513 } else {
514 config.dont_remember_session_expires_in
515 };
516 let seconds = i64::try_from(expires_in)
517 .map_err(|_| AuthFlowError::new(AuthFlowErrorCode::FailedToCreateSession))?;
518 let expires_at = OffsetDateTime::now_utc() + Duration::seconds(seconds);
519 let mut input =
520 CreateSessionInput::new(user_id, expires_at).additional_fields(additional_fields);
521 if let Some(ip_address) = ip_address {
522 input = input.ip_address(ip_address);
523 }
524 if let Some(user_agent) = user_agent {
525 input = input.user_agent(user_agent);
526 }
527
528 SessionStore::with_storage(
529 adapter,
530 config.secondary_storage.clone(),
531 config.store_session_in_database,
532 config.preserve_session_in_database,
533 )
534 .create_session(input)
535 .await
536 .map_err(|_| AuthFlowError::new(AuthFlowErrorCode::FailedToCreateSession))
537}
538
539fn store_sign_up_result(
540 result: &Mutex<Option<Result<EmailPasswordAuthResult, AuthFlowError>>>,
541 value: Result<EmailPasswordAuthResult, AuthFlowError>,
542) -> Result<(), OpenAuthError> {
543 let mut guard = result.lock().map_err(|_| OpenAuthError::LockPoisoned {
544 context: "sign-up result",
545 })?;
546 *guard = Some(value);
547 Ok(())
548}
549
550fn take_sign_up_result(
551 result: &Mutex<Option<Result<EmailPasswordAuthResult, AuthFlowError>>>,
552) -> Result<Option<Result<EmailPasswordAuthResult, AuthFlowError>>, AuthFlowError> {
553 result
554 .lock()
555 .map_err(|_| {
556 AuthFlowError::storage(OpenAuthError::LockPoisoned {
557 context: "sign-up result",
558 })
559 })
560 .map(|mut guard| guard.take())
561}
562
563fn validate_email(email: &str) -> Result<(), AuthFlowError> {
564 let email = email.trim();
565 let Some((local, domain)) = email.split_once('@') else {
566 return Err(AuthFlowError::new(AuthFlowErrorCode::InvalidEmail));
567 };
568 if local.is_empty()
569 || domain.is_empty()
570 || domain.starts_with('.')
571 || domain.ends_with('.')
572 || !domain.contains('.')
573 {
574 return Err(AuthFlowError::new(AuthFlowErrorCode::InvalidEmail));
575 }
576 Ok(())
577}