1use std::error::Error;
4use std::fmt;
5use std::sync::{Arc, Mutex};
6
7use time::{Duration, OffsetDateTime};
8
9use crate::crypto::password::{hash_password, verify_password};
10use crate::db::{DbAdapter, DbRecord, Session, User};
11use crate::error::OpenAuthError;
12use crate::session::{CreateSessionInput, DbSessionStore};
13use crate::user::{CreateCredentialAccountInput, CreateUserInput, DbUserStore};
14
15pub type PasswordHashFn = fn(&str) -> Result<String, OpenAuthError>;
16pub type PasswordVerifyFn = fn(&str, &str) -> Result<bool, OpenAuthError>;
17
/// Machine-readable reasons an email/password flow can fail.
///
/// Every variant has a stable `SCREAMING_SNAKE_CASE` identifier
/// ([`AuthFlowErrorCode::as_str`]) and a default human-readable sentence
/// ([`AuthFlowErrorCode::message`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AuthFlowErrorCode {
    /// The email address failed syntactic validation.
    InvalidEmail,
    /// The password is outside the configured minimum/maximum length.
    InvalidPasswordLength,
    /// Sign-in failed without revealing whether the email or the password was wrong.
    InvalidEmailOrPassword,
    /// Sign-up was attempted for an email that already belongs to a user.
    UserAlreadyExists,
    /// Sign-in was refused because email verification is required but not done.
    EmailNotVerified,
    /// The session could not be created (expiry computation or store insert failed).
    FailedToCreateSession,
    /// A lower-level `OpenAuthError` surfaced through the flow.
    StorageError,
}

impl AuthFlowErrorCode {
    /// Stable identifier for this code, e.g. `"INVALID_EMAIL"`.
    pub fn as_str(self) -> &'static str {
        let (identifier, _) = self.describe();
        identifier
    }

    /// Default human-readable message for this code, e.g. `"Invalid email"`.
    pub fn message(self) -> &'static str {
        let (_, text) = self.describe();
        text
    }

    /// Single table pairing each variant with its `(identifier, message)`.
    fn describe(self) -> (&'static str, &'static str) {
        match self {
            Self::InvalidEmail => ("INVALID_EMAIL", "Invalid email"),
            Self::InvalidPasswordLength => ("INVALID_PASSWORD_LENGTH", "Invalid password length"),
            Self::InvalidEmailOrPassword => {
                ("INVALID_EMAIL_OR_PASSWORD", "Invalid email or password")
            }
            Self::UserAlreadyExists => ("USER_ALREADY_EXISTS", "User already exists"),
            Self::EmailNotVerified => ("EMAIL_NOT_VERIFIED", "Email not verified"),
            Self::FailedToCreateSession => ("FAILED_TO_CREATE_SESSION", "Failed to create session"),
            Self::StorageError => ("STORAGE_ERROR", "Storage error"),
        }
    }
}
54
55#[derive(Debug, Clone, PartialEq, Eq)]
56pub struct AuthFlowError {
57 code: AuthFlowErrorCode,
58 message: String,
59}
60
61impl AuthFlowError {
62 pub fn new(code: AuthFlowErrorCode) -> Self {
63 Self {
64 code,
65 message: code.message().to_owned(),
66 }
67 }
68
69 pub fn storage(error: OpenAuthError) -> Self {
70 Self {
71 code: AuthFlowErrorCode::StorageError,
72 message: error.to_string(),
73 }
74 }
75
76 pub fn code(&self) -> AuthFlowErrorCode {
77 self.code
78 }
79
80 pub fn code_str(&self) -> &'static str {
81 self.code.as_str()
82 }
83
84 pub fn message(&self) -> &str {
85 self.message.as_str()
86 }
87}
88
89impl fmt::Display for AuthFlowError {
90 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
91 write!(formatter, "{}: {}", self.code.as_str(), self.message)
92 }
93}
94
95impl Error for AuthFlowError {}
96
97impl From<OpenAuthError> for AuthFlowError {
98 fn from(error: OpenAuthError) -> Self {
99 Self::storage(error)
100 }
101}
102
/// Tunables for the email/password flows.
///
/// Session lifetimes are in seconds. Password bounds are inclusive and are compared
/// against `str::len`, i.e. the UTF-8 byte length.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct EmailPasswordConfig {
    /// Session lifetime (seconds) when the caller asked to be remembered.
    pub session_expires_in: u64,
    /// Session lifetime (seconds) when `remember_me` is `false`.
    pub dont_remember_session_expires_in: u64,
    /// Shortest accepted password at sign-up (inclusive).
    pub min_password_length: usize,
    /// Longest accepted password at sign-up (inclusive).
    pub max_password_length: usize,
    /// When `true`, sign-in is refused for users whose email is not verified.
    pub require_email_verification: bool,
}

impl Default for EmailPasswordConfig {
    /// Seven-day sessions (one day without "remember me"), 8–128 byte passwords,
    /// and no email-verification requirement.
    fn default() -> Self {
        const DAY: u64 = 24 * 60 * 60;
        Self {
            session_expires_in: 7 * DAY,
            dont_remember_session_expires_in: DAY,
            min_password_length: 8,
            max_password_length: 128,
            require_email_verification: false,
        }
    }
}
123
124#[derive(Debug, Clone, PartialEq)]
125pub struct SignUpInput {
126 pub name: String,
127 pub email: String,
128 pub password: String,
129 pub image: Option<String>,
130 pub username: Option<String>,
131 pub display_username: Option<String>,
132 pub remember_me: bool,
133 pub ip_address: Option<String>,
134 pub user_agent: Option<String>,
135 pub additional_user_fields: DbRecord,
136 pub additional_session_fields: DbRecord,
137}
138
139impl SignUpInput {
140 pub fn new(
141 name: impl Into<String>,
142 email: impl Into<String>,
143 password: impl Into<String>,
144 ) -> Self {
145 Self {
146 name: name.into(),
147 email: email.into(),
148 password: password.into(),
149 image: None,
150 username: None,
151 display_username: None,
152 remember_me: true,
153 ip_address: None,
154 user_agent: None,
155 additional_user_fields: DbRecord::new(),
156 additional_session_fields: DbRecord::new(),
157 }
158 }
159
160 #[must_use]
161 pub fn image(mut self, image: impl Into<String>) -> Self {
162 self.image = Some(image.into());
163 self
164 }
165
166 #[must_use]
167 pub fn username(mut self, username: impl Into<String>) -> Self {
168 self.username = Some(username.into());
169 self
170 }
171
172 #[must_use]
173 pub fn display_username(mut self, display_username: impl Into<String>) -> Self {
174 self.display_username = Some(display_username.into());
175 self
176 }
177
178 #[must_use]
179 pub fn remember_me(mut self, remember_me: bool) -> Self {
180 self.remember_me = remember_me;
181 self
182 }
183
184 #[must_use]
185 pub fn ip_address(mut self, ip_address: impl Into<String>) -> Self {
186 self.ip_address = Some(ip_address.into());
187 self
188 }
189
190 #[must_use]
191 pub fn user_agent(mut self, user_agent: impl Into<String>) -> Self {
192 self.user_agent = Some(user_agent.into());
193 self
194 }
195
196 #[must_use]
197 pub fn additional_user_fields(mut self, fields: DbRecord) -> Self {
198 self.additional_user_fields = fields;
199 self
200 }
201
202 #[must_use]
203 pub fn additional_session_fields(mut self, fields: DbRecord) -> Self {
204 self.additional_session_fields = fields;
205 self
206 }
207}
208
209#[derive(Debug, Clone, PartialEq)]
210pub struct SignInInput {
211 pub email: String,
212 pub password: String,
213 pub remember_me: bool,
214 pub ip_address: Option<String>,
215 pub user_agent: Option<String>,
216 pub additional_session_fields: DbRecord,
217}
218
219impl SignInInput {
220 pub fn new(email: impl Into<String>, password: impl Into<String>) -> Self {
221 Self {
222 email: email.into(),
223 password: password.into(),
224 remember_me: true,
225 ip_address: None,
226 user_agent: None,
227 additional_session_fields: DbRecord::new(),
228 }
229 }
230
231 #[must_use]
232 pub fn remember_me(mut self, remember_me: bool) -> Self {
233 self.remember_me = remember_me;
234 self
235 }
236
237 #[must_use]
238 pub fn ip_address(mut self, ip_address: impl Into<String>) -> Self {
239 self.ip_address = Some(ip_address.into());
240 self
241 }
242
243 #[must_use]
244 pub fn user_agent(mut self, user_agent: impl Into<String>) -> Self {
245 self.user_agent = Some(user_agent.into());
246 self
247 }
248
249 #[must_use]
250 pub fn additional_session_fields(mut self, fields: DbRecord) -> Self {
251 self.additional_session_fields = fields;
252 self
253 }
254}
255
256#[derive(Debug, Clone, PartialEq, Eq)]
257pub struct EmailPasswordAuthResult {
258 pub user: User,
259 pub session: Session,
260}
261
262#[derive(Clone)]
263pub struct EmailPasswordAuth<'a> {
264 adapter: &'a dyn DbAdapter,
265 config: EmailPasswordConfig,
266 hash_password: PasswordHashFn,
267 verify_password: PasswordVerifyFn,
268}
269
270impl<'a> EmailPasswordAuth<'a> {
271 pub fn new(
272 adapter: &'a dyn DbAdapter,
273 config: EmailPasswordConfig,
274 hash_password: PasswordHashFn,
275 verify_password: PasswordVerifyFn,
276 ) -> Self {
277 Self {
278 adapter,
279 config,
280 hash_password,
281 verify_password,
282 }
283 }
284
285 pub fn with_defaults(adapter: &'a dyn DbAdapter, config: EmailPasswordConfig) -> Self {
286 Self::new(adapter, config, hash_password, verify_password)
287 }
288
289 pub async fn sign_up(
290 &self,
291 input: SignUpInput,
292 ) -> Result<EmailPasswordAuthResult, AuthFlowError> {
293 self.validate_email_and_password(&input.email, &input.password)?;
294 let users = DbUserStore::new(self.adapter);
295 if users.find_user_by_email(&input.email).await?.is_some() {
296 return Err(AuthFlowError::new(AuthFlowErrorCode::UserAlreadyExists));
297 }
298
299 let password_hash = (self.hash_password)(&input.password)?;
300 let mut create_user = CreateUserInput::new(input.name, input.email)
301 .additional_fields(input.additional_user_fields);
302 if let Some(image) = input.image {
303 create_user = create_user.image(image);
304 }
305 if let Some(username) = input.username {
306 create_user = create_user.username(username);
307 }
308 if let Some(display_username) = input.display_username {
309 create_user = create_user.display_username(display_username);
310 }
311 let result = Arc::new(Mutex::new(None));
312 let result_for_transaction = Arc::clone(&result);
313 let config = self.config.clone();
314 let transaction_status = self
315 .adapter
316 .transaction(Box::new(move |transaction| {
317 Box::pin(async move {
318 let outcome = create_sign_up_records(SignUpRecordsInput {
319 adapter: transaction.as_ref(),
320 config: &config,
321 create_user,
322 password_hash,
323 remember_me: input.remember_me,
324 ip_address: input.ip_address,
325 user_agent: input.user_agent,
326 additional_session_fields: input.additional_session_fields,
327 })
328 .await;
329 match outcome {
330 Ok(result) => {
331 store_sign_up_result(&result_for_transaction, Ok(result))?;
332 Ok(())
333 }
334 Err(error) => {
335 let transaction_error = OpenAuthError::Adapter(error.to_string());
336 store_sign_up_result(&result_for_transaction, Err(error))?;
337 Err(transaction_error)
338 }
339 }
340 })
341 }))
342 .await;
343
344 match transaction_status {
345 Ok(()) => match take_sign_up_result(&result)? {
346 Some(Ok(result)) => Ok(result),
347 Some(Err(error)) => Err(error),
348 None => Err(AuthFlowError::storage(OpenAuthError::Adapter(
349 "sign-up transaction completed without a result".to_owned(),
350 ))),
351 },
352 Err(error) => match take_sign_up_result(&result)? {
353 Some(Err(auth_error)) => Err(auth_error),
354 _ => Err(AuthFlowError::storage(error)),
355 },
356 }
357 }
358
359 pub async fn sign_in(
360 &self,
361 input: SignInInput,
362 ) -> Result<EmailPasswordAuthResult, AuthFlowError> {
363 validate_email(&input.email)?;
364 let users = DbUserStore::new(self.adapter);
365 let Some(user_with_accounts) = users.find_user_by_email_with_accounts(&input.email).await?
366 else {
367 let _ = (self.hash_password)(&input.password);
368 return Err(AuthFlowError::new(
369 AuthFlowErrorCode::InvalidEmailOrPassword,
370 ));
371 };
372 let Some(account) = user_with_accounts
373 .accounts
374 .iter()
375 .find(|account| account.provider_id == "credential")
376 else {
377 let _ = (self.hash_password)(&input.password);
378 return Err(AuthFlowError::new(
379 AuthFlowErrorCode::InvalidEmailOrPassword,
380 ));
381 };
382 let Some(password_hash) = account.password.as_deref() else {
383 let _ = (self.hash_password)(&input.password);
384 return Err(AuthFlowError::new(
385 AuthFlowErrorCode::InvalidEmailOrPassword,
386 ));
387 };
388 if !(self.verify_password)(password_hash, &input.password)? {
389 return Err(AuthFlowError::new(
390 AuthFlowErrorCode::InvalidEmailOrPassword,
391 ));
392 }
393 if self.config.require_email_verification && !user_with_accounts.user.email_verified {
394 return Err(AuthFlowError::new(AuthFlowErrorCode::EmailNotVerified));
395 }
396 let session = create_session_record(
397 self.adapter,
398 &self.config,
399 &user_with_accounts.user.id,
400 input.remember_me,
401 input.ip_address,
402 input.user_agent,
403 input.additional_session_fields,
404 )
405 .await?;
406
407 Ok(EmailPasswordAuthResult {
408 user: user_with_accounts.user,
409 session,
410 })
411 }
412
413 fn validate_email_and_password(
414 &self,
415 email: &str,
416 password: &str,
417 ) -> Result<(), AuthFlowError> {
418 validate_email(email)?;
419 if password.len() < self.config.min_password_length
420 || password.len() > self.config.max_password_length
421 {
422 return Err(AuthFlowError::new(AuthFlowErrorCode::InvalidPasswordLength));
423 }
424 Ok(())
425 }
426}
427
428struct SignUpRecordsInput<'a> {
429 adapter: &'a dyn DbAdapter,
430 config: &'a EmailPasswordConfig,
431 create_user: CreateUserInput,
432 password_hash: String,
433 remember_me: bool,
434 ip_address: Option<String>,
435 user_agent: Option<String>,
436 additional_session_fields: DbRecord,
437}
438
439async fn create_sign_up_records(
440 input: SignUpRecordsInput<'_>,
441) -> Result<EmailPasswordAuthResult, AuthFlowError> {
442 let users = DbUserStore::new(input.adapter);
443 let user = users.create_user(input.create_user).await?;
444 users
445 .create_credential_account(CreateCredentialAccountInput::new(
446 user.id.clone(),
447 input.password_hash,
448 ))
449 .await?;
450 let session = create_session_record(
451 input.adapter,
452 input.config,
453 &user.id,
454 input.remember_me,
455 input.ip_address,
456 input.user_agent,
457 input.additional_session_fields,
458 )
459 .await?;
460
461 Ok(EmailPasswordAuthResult { user, session })
462}
463
464async fn create_session_record(
465 adapter: &dyn DbAdapter,
466 config: &EmailPasswordConfig,
467 user_id: &str,
468 remember_me: bool,
469 ip_address: Option<String>,
470 user_agent: Option<String>,
471 additional_fields: DbRecord,
472) -> Result<Session, AuthFlowError> {
473 let expires_in = if remember_me {
474 config.session_expires_in
475 } else {
476 config.dont_remember_session_expires_in
477 };
478 let seconds = i64::try_from(expires_in)
479 .map_err(|_| AuthFlowError::new(AuthFlowErrorCode::FailedToCreateSession))?;
480 let expires_at = OffsetDateTime::now_utc() + Duration::seconds(seconds);
481 let mut input =
482 CreateSessionInput::new(user_id, expires_at).additional_fields(additional_fields);
483 if let Some(ip_address) = ip_address {
484 input = input.ip_address(ip_address);
485 }
486 if let Some(user_agent) = user_agent {
487 input = input.user_agent(user_agent);
488 }
489
490 DbSessionStore::new(adapter)
491 .create_session(input)
492 .await
493 .map_err(|_| AuthFlowError::new(AuthFlowErrorCode::FailedToCreateSession))
494}
495
496fn store_sign_up_result(
497 result: &Mutex<Option<Result<EmailPasswordAuthResult, AuthFlowError>>>,
498 value: Result<EmailPasswordAuthResult, AuthFlowError>,
499) -> Result<(), OpenAuthError> {
500 let mut guard = result
501 .lock()
502 .map_err(|_| OpenAuthError::Adapter("sign-up result lock poisoned".to_owned()))?;
503 *guard = Some(value);
504 Ok(())
505}
506
507fn take_sign_up_result(
508 result: &Mutex<Option<Result<EmailPasswordAuthResult, AuthFlowError>>>,
509) -> Result<Option<Result<EmailPasswordAuthResult, AuthFlowError>>, AuthFlowError> {
510 result
511 .lock()
512 .map_err(|_| {
513 AuthFlowError::storage(OpenAuthError::Adapter(
514 "sign-up result lock poisoned".to_owned(),
515 ))
516 })
517 .map(|mut guard| guard.take())
518}
519
520fn validate_email(email: &str) -> Result<(), AuthFlowError> {
521 let email = email.trim();
522 let Some((local, domain)) = email.split_once('@') else {
523 return Err(AuthFlowError::new(AuthFlowErrorCode::InvalidEmail));
524 };
525 if local.is_empty()
526 || domain.is_empty()
527 || domain.starts_with('.')
528 || domain.ends_with('.')
529 || !domain.contains('.')
530 {
531 return Err(AuthFlowError::new(AuthFlowErrorCode::InvalidEmail));
532 }
533 Ok(())
534}