Skip to main content

openauth_core/auth/
email_password.rs

1//! Email/password auth service built on top of core DB stores.
2
3use std::error::Error;
4use std::fmt;
5use std::sync::{Arc, Mutex};
6
7use time::{Duration, OffsetDateTime};
8
9use crate::crypto::password::{hash_password, verify_password};
10use crate::db::{DbAdapter, DbRecord, Session, User};
11use crate::error::OpenAuthError;
12use crate::session::{CreateSessionInput, DbSessionStore};
13use crate::user::{CreateCredentialAccountInput, CreateUserInput, DbUserStore};
14
/// Function pointer that turns a plaintext password into the hash string
/// stored on the credential account.
pub type PasswordHashFn = fn(&str) -> Result<String, OpenAuthError>;
/// Function pointer that checks a candidate password against a stored hash.
///
/// NOTE(review): this file invokes it as `(stored_hash, candidate_password)`;
/// confirm that argument order matches the default `verify_password`.
pub type PasswordVerifyFn = fn(&str, &str) -> Result<bool, OpenAuthError>;
17
/// Machine-readable failure categories produced by the email/password flows.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AuthFlowErrorCode {
    InvalidEmail,
    InvalidPasswordLength,
    InvalidEmailOrPassword,
    UserAlreadyExists,
    EmailNotVerified,
    FailedToCreateSession,
    StorageError,
}

impl AuthFlowErrorCode {
    /// Returns the upper-snake-case identifier for this code.
    pub fn as_str(self) -> &'static str {
        let (identifier, _) = self.describe();
        identifier
    }

    /// Returns the default human-readable message for this code.
    pub fn message(self) -> &'static str {
        let (_, text) = self.describe();
        text
    }

    /// Single lookup table pairing each code with `(identifier, message)`.
    ///
    /// Keeping both strings side by side means a new variant cannot gain an
    /// identifier without also gaining a message.
    fn describe(self) -> (&'static str, &'static str) {
        match self {
            Self::InvalidEmail => ("INVALID_EMAIL", "Invalid email"),
            Self::InvalidPasswordLength => ("INVALID_PASSWORD_LENGTH", "Invalid password length"),
            Self::InvalidEmailOrPassword => {
                ("INVALID_EMAIL_OR_PASSWORD", "Invalid email or password")
            }
            Self::UserAlreadyExists => ("USER_ALREADY_EXISTS", "User already exists"),
            Self::EmailNotVerified => ("EMAIL_NOT_VERIFIED", "Email not verified"),
            Self::FailedToCreateSession => {
                ("FAILED_TO_CREATE_SESSION", "Failed to create session")
            }
            Self::StorageError => ("STORAGE_ERROR", "Storage error"),
        }
    }
}
54
/// Error returned by the email/password flows: a category code plus the
/// message text shown by `Display`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AuthFlowError {
    // Category of the failure.
    code: AuthFlowErrorCode,
    // Either the code's default message or, for storage failures, the
    // stringified underlying `OpenAuthError`.
    message: String,
}
60
61impl AuthFlowError {
62    pub fn new(code: AuthFlowErrorCode) -> Self {
63        Self {
64            code,
65            message: code.message().to_owned(),
66        }
67    }
68
69    pub fn storage(error: OpenAuthError) -> Self {
70        Self {
71            code: AuthFlowErrorCode::StorageError,
72            message: error.to_string(),
73        }
74    }
75
76    pub fn code(&self) -> AuthFlowErrorCode {
77        self.code
78    }
79
80    pub fn code_str(&self) -> &'static str {
81        self.code.as_str()
82    }
83
84    pub fn message(&self) -> &str {
85        self.message.as_str()
86    }
87}
88
89impl fmt::Display for AuthFlowError {
90    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
91        write!(formatter, "{}: {}", self.code.as_str(), self.message)
92    }
93}
94
// Marker impl: relies on the `Debug` and `Display` impls; no source error is exposed.
impl Error for AuthFlowError {}
96
97impl From<OpenAuthError> for AuthFlowError {
98    fn from(error: OpenAuthError) -> Self {
99        Self::storage(error)
100    }
101}
102
/// Tunables for the email/password flows.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct EmailPasswordConfig {
    /// Session lifetime in seconds when `remember_me` is set.
    pub session_expires_in: u64,
    /// Session lifetime in seconds when `remember_me` is not set.
    pub dont_remember_session_expires_in: u64,
    /// Smallest accepted password length at sign-up.
    pub min_password_length: usize,
    /// Largest accepted password length at sign-up.
    pub max_password_length: usize,
    /// When true, sign-in is refused for users whose email is not verified.
    pub require_email_verification: bool,
}

impl Default for EmailPasswordConfig {
    /// Seven-day remembered sessions, one-day otherwise, passwords of
    /// 8 to 128, and no email-verification requirement.
    fn default() -> Self {
        const SECONDS_PER_DAY: u64 = 24 * 60 * 60;

        Self {
            session_expires_in: 7 * SECONDS_PER_DAY,
            dont_remember_session_expires_in: SECONDS_PER_DAY,
            min_password_length: 8,
            max_password_length: 128,
            require_email_verification: false,
        }
    }
}
123
124#[derive(Debug, Clone, PartialEq)]
125pub struct SignUpInput {
126    pub name: String,
127    pub email: String,
128    pub password: String,
129    pub image: Option<String>,
130    pub username: Option<String>,
131    pub display_username: Option<String>,
132    pub remember_me: bool,
133    pub ip_address: Option<String>,
134    pub user_agent: Option<String>,
135    pub additional_user_fields: DbRecord,
136    pub additional_session_fields: DbRecord,
137}
138
139impl SignUpInput {
140    pub fn new(
141        name: impl Into<String>,
142        email: impl Into<String>,
143        password: impl Into<String>,
144    ) -> Self {
145        Self {
146            name: name.into(),
147            email: email.into(),
148            password: password.into(),
149            image: None,
150            username: None,
151            display_username: None,
152            remember_me: true,
153            ip_address: None,
154            user_agent: None,
155            additional_user_fields: DbRecord::new(),
156            additional_session_fields: DbRecord::new(),
157        }
158    }
159
160    #[must_use]
161    pub fn image(mut self, image: impl Into<String>) -> Self {
162        self.image = Some(image.into());
163        self
164    }
165
166    #[must_use]
167    pub fn username(mut self, username: impl Into<String>) -> Self {
168        self.username = Some(username.into());
169        self
170    }
171
172    #[must_use]
173    pub fn display_username(mut self, display_username: impl Into<String>) -> Self {
174        self.display_username = Some(display_username.into());
175        self
176    }
177
178    #[must_use]
179    pub fn remember_me(mut self, remember_me: bool) -> Self {
180        self.remember_me = remember_me;
181        self
182    }
183
184    #[must_use]
185    pub fn ip_address(mut self, ip_address: impl Into<String>) -> Self {
186        self.ip_address = Some(ip_address.into());
187        self
188    }
189
190    #[must_use]
191    pub fn user_agent(mut self, user_agent: impl Into<String>) -> Self {
192        self.user_agent = Some(user_agent.into());
193        self
194    }
195
196    #[must_use]
197    pub fn additional_user_fields(mut self, fields: DbRecord) -> Self {
198        self.additional_user_fields = fields;
199        self
200    }
201
202    #[must_use]
203    pub fn additional_session_fields(mut self, fields: DbRecord) -> Self {
204        self.additional_session_fields = fields;
205        self
206    }
207}
208
209#[derive(Debug, Clone, PartialEq)]
210pub struct SignInInput {
211    pub email: String,
212    pub password: String,
213    pub remember_me: bool,
214    pub ip_address: Option<String>,
215    pub user_agent: Option<String>,
216    pub additional_session_fields: DbRecord,
217}
218
219impl SignInInput {
220    pub fn new(email: impl Into<String>, password: impl Into<String>) -> Self {
221        Self {
222            email: email.into(),
223            password: password.into(),
224            remember_me: true,
225            ip_address: None,
226            user_agent: None,
227            additional_session_fields: DbRecord::new(),
228        }
229    }
230
231    #[must_use]
232    pub fn remember_me(mut self, remember_me: bool) -> Self {
233        self.remember_me = remember_me;
234        self
235    }
236
237    #[must_use]
238    pub fn ip_address(mut self, ip_address: impl Into<String>) -> Self {
239        self.ip_address = Some(ip_address.into());
240        self
241    }
242
243    #[must_use]
244    pub fn user_agent(mut self, user_agent: impl Into<String>) -> Self {
245        self.user_agent = Some(user_agent.into());
246        self
247    }
248
249    #[must_use]
250    pub fn additional_session_fields(mut self, fields: DbRecord) -> Self {
251        self.additional_session_fields = fields;
252        self
253    }
254}
255
/// Successful outcome of `sign_up` or `sign_in`: the user and the session
/// created for that call.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct EmailPasswordAuthResult {
    /// The user that was created (sign-up) or looked up (sign-in).
    pub user: User,
    /// The session record created during the call.
    pub session: Session,
}
261
/// Email/password auth service borrowing a database adapter.
#[derive(Clone)]
pub struct EmailPasswordAuth<'a> {
    // Storage backend used for user, account and session records.
    adapter: &'a dyn DbAdapter,
    // Session lifetimes, password length bounds and the verification switch.
    config: EmailPasswordConfig,
    // Hashes a plaintext password at sign-up (and on failed sign-in lookups).
    hash_password: PasswordHashFn,
    // Checks a password against a stored hash at sign-in.
    verify_password: PasswordVerifyFn,
}
269
270impl<'a> EmailPasswordAuth<'a> {
271    pub fn new(
272        adapter: &'a dyn DbAdapter,
273        config: EmailPasswordConfig,
274        hash_password: PasswordHashFn,
275        verify_password: PasswordVerifyFn,
276    ) -> Self {
277        Self {
278            adapter,
279            config,
280            hash_password,
281            verify_password,
282        }
283    }
284
285    pub fn with_defaults(adapter: &'a dyn DbAdapter, config: EmailPasswordConfig) -> Self {
286        Self::new(adapter, config, hash_password, verify_password)
287    }
288
    /// Registers a new user with an email/password credential and opens a
    /// session for them.
    ///
    /// Steps, in order: validate the email shape and password length, reject
    /// the request if a user with this email already exists, hash the
    /// password, then create the user, credential account and session inside
    /// one `adapter.transaction(..)` call.
    ///
    /// # Errors
    ///
    /// - `InvalidEmail` / `InvalidPasswordLength` when validation fails.
    /// - `UserAlreadyExists` when the email lookup finds a user.
    /// - `FailedToCreateSession` when the session step fails.
    /// - `StorageError` for hashing errors, store errors, a poisoned result
    ///   lock, or a transaction that finishes without recording a result.
    pub async fn sign_up(
        &self,
        input: SignUpInput,
    ) -> Result<EmailPasswordAuthResult, AuthFlowError> {
        self.validate_email_and_password(&input.email, &input.password)?;
        let users = DbUserStore::new(self.adapter);
        // NOTE(review): this existence check runs on `self.adapter` before the
        // transaction starts; whether a concurrent sign-up for the same email
        // is still rejected depends on the storage layer — confirm.
        if users.find_user_by_email(&input.email).await?.is_some() {
            return Err(AuthFlowError::new(AuthFlowErrorCode::UserAlreadyExists));
        }

        // Hash before entering the transaction so the closure only does storage work.
        let password_hash = (self.hash_password)(&input.password)?;
        let mut create_user = CreateUserInput::new(input.name, input.email)
            .additional_fields(input.additional_user_fields);
        if let Some(image) = input.image {
            create_user = create_user.image(image);
        }
        if let Some(username) = input.username {
            create_user = create_user.username(username);
        }
        if let Some(display_username) = input.display_username {
            create_user = create_user.display_username(display_username);
        }
        // The transaction callback can only report `Result<(), OpenAuthError>`,
        // so the typed outcome is handed back through this shared slot.
        let result = Arc::new(Mutex::new(None));
        let result_for_transaction = Arc::clone(&result);
        // Owned copy so the `move` closure does not borrow from `self`.
        let config = self.config.clone();
        let transaction_status = self
            .adapter
            .transaction(Box::new(move |transaction| {
                Box::pin(async move {
                    let outcome = create_sign_up_records(SignUpRecordsInput {
                        adapter: transaction.as_ref(),
                        config: &config,
                        create_user,
                        password_hash,
                        remember_me: input.remember_me,
                        ip_address: input.ip_address,
                        user_agent: input.user_agent,
                        additional_session_fields: input.additional_session_fields,
                    })
                    .await;
                    match outcome {
                        Ok(result) => {
                            store_sign_up_result(&result_for_transaction, Ok(result))?;
                            Ok(())
                        }
                        Err(error) => {
                            // Keep the typed error in the slot for the caller and
                            // return a stringified copy to the adapter
                            // (presumably so it can abort the transaction — confirm).
                            let transaction_error = OpenAuthError::Adapter(error.to_string());
                            store_sign_up_result(&result_for_transaction, Err(error))?;
                            Err(transaction_error)
                        }
                    }
                })
            }))
            .await;

        match transaction_status {
            Ok(()) => match take_sign_up_result(&result)? {
                Some(Ok(result)) => Ok(result),
                Some(Err(error)) => Err(error),
                // The callback never stored anything although the adapter reported success.
                None => Err(AuthFlowError::storage(OpenAuthError::Adapter(
                    "sign-up transaction completed without a result".to_owned(),
                ))),
            },
            Err(error) => match take_sign_up_result(&result)? {
                // Prefer the typed flow error recorded inside the callback.
                Some(Err(auth_error)) => Err(auth_error),
                // Nothing typed was recorded (or an `Ok` was recorded but the
                // transaction still failed): surface the adapter's error.
                _ => Err(AuthFlowError::storage(error)),
            },
        }
    }
358
359    pub async fn sign_in(
360        &self,
361        input: SignInInput,
362    ) -> Result<EmailPasswordAuthResult, AuthFlowError> {
363        validate_email(&input.email)?;
364        let users = DbUserStore::new(self.adapter);
365        let Some(user_with_accounts) = users.find_user_by_email_with_accounts(&input.email).await?
366        else {
367            let _ = (self.hash_password)(&input.password);
368            return Err(AuthFlowError::new(
369                AuthFlowErrorCode::InvalidEmailOrPassword,
370            ));
371        };
372        let Some(account) = user_with_accounts
373            .accounts
374            .iter()
375            .find(|account| account.provider_id == "credential")
376        else {
377            let _ = (self.hash_password)(&input.password);
378            return Err(AuthFlowError::new(
379                AuthFlowErrorCode::InvalidEmailOrPassword,
380            ));
381        };
382        let Some(password_hash) = account.password.as_deref() else {
383            let _ = (self.hash_password)(&input.password);
384            return Err(AuthFlowError::new(
385                AuthFlowErrorCode::InvalidEmailOrPassword,
386            ));
387        };
388        if !(self.verify_password)(password_hash, &input.password)? {
389            return Err(AuthFlowError::new(
390                AuthFlowErrorCode::InvalidEmailOrPassword,
391            ));
392        }
393        if self.config.require_email_verification && !user_with_accounts.user.email_verified {
394            return Err(AuthFlowError::new(AuthFlowErrorCode::EmailNotVerified));
395        }
396        let session = create_session_record(
397            self.adapter,
398            &self.config,
399            &user_with_accounts.user.id,
400            input.remember_me,
401            input.ip_address,
402            input.user_agent,
403            input.additional_session_fields,
404        )
405        .await?;
406
407        Ok(EmailPasswordAuthResult {
408            user: user_with_accounts.user,
409            session,
410        })
411    }
412
413    fn validate_email_and_password(
414        &self,
415        email: &str,
416        password: &str,
417    ) -> Result<(), AuthFlowError> {
418        validate_email(email)?;
419        if password.len() < self.config.min_password_length
420            || password.len() > self.config.max_password_length
421        {
422            return Err(AuthFlowError::new(AuthFlowErrorCode::InvalidPasswordLength));
423        }
424        Ok(())
425    }
426}
427
/// Argument bundle for `create_sign_up_records`.
struct SignUpRecordsInput<'a> {
    // Adapter the records are written through (the transaction handle in `sign_up`).
    adapter: &'a dyn DbAdapter,
    // Supplies the session lifetimes.
    config: &'a EmailPasswordConfig,
    // Prepared user-creation input.
    create_user: CreateUserInput,
    // Already-hashed password for the credential account.
    password_hash: String,
    // Selects the long or short session lifetime.
    remember_me: bool,
    // Optional client IP recorded on the session.
    ip_address: Option<String>,
    // Optional client user agent recorded on the session.
    user_agent: Option<String>,
    // Extra fields passed through to the session record.
    additional_session_fields: DbRecord,
}
438
439async fn create_sign_up_records(
440    input: SignUpRecordsInput<'_>,
441) -> Result<EmailPasswordAuthResult, AuthFlowError> {
442    let users = DbUserStore::new(input.adapter);
443    let user = users.create_user(input.create_user).await?;
444    users
445        .create_credential_account(CreateCredentialAccountInput::new(
446            user.id.clone(),
447            input.password_hash,
448        ))
449        .await?;
450    let session = create_session_record(
451        input.adapter,
452        input.config,
453        &user.id,
454        input.remember_me,
455        input.ip_address,
456        input.user_agent,
457        input.additional_session_fields,
458    )
459    .await?;
460
461    Ok(EmailPasswordAuthResult { user, session })
462}
463
464async fn create_session_record(
465    adapter: &dyn DbAdapter,
466    config: &EmailPasswordConfig,
467    user_id: &str,
468    remember_me: bool,
469    ip_address: Option<String>,
470    user_agent: Option<String>,
471    additional_fields: DbRecord,
472) -> Result<Session, AuthFlowError> {
473    let expires_in = if remember_me {
474        config.session_expires_in
475    } else {
476        config.dont_remember_session_expires_in
477    };
478    let seconds = i64::try_from(expires_in)
479        .map_err(|_| AuthFlowError::new(AuthFlowErrorCode::FailedToCreateSession))?;
480    let expires_at = OffsetDateTime::now_utc() + Duration::seconds(seconds);
481    let mut input =
482        CreateSessionInput::new(user_id, expires_at).additional_fields(additional_fields);
483    if let Some(ip_address) = ip_address {
484        input = input.ip_address(ip_address);
485    }
486    if let Some(user_agent) = user_agent {
487        input = input.user_agent(user_agent);
488    }
489
490    DbSessionStore::new(adapter)
491        .create_session(input)
492        .await
493        .map_err(|_| AuthFlowError::new(AuthFlowErrorCode::FailedToCreateSession))
494}
495
496fn store_sign_up_result(
497    result: &Mutex<Option<Result<EmailPasswordAuthResult, AuthFlowError>>>,
498    value: Result<EmailPasswordAuthResult, AuthFlowError>,
499) -> Result<(), OpenAuthError> {
500    let mut guard = result
501        .lock()
502        .map_err(|_| OpenAuthError::Adapter("sign-up result lock poisoned".to_owned()))?;
503    *guard = Some(value);
504    Ok(())
505}
506
507fn take_sign_up_result(
508    result: &Mutex<Option<Result<EmailPasswordAuthResult, AuthFlowError>>>,
509) -> Result<Option<Result<EmailPasswordAuthResult, AuthFlowError>>, AuthFlowError> {
510    result
511        .lock()
512        .map_err(|_| {
513            AuthFlowError::storage(OpenAuthError::Adapter(
514                "sign-up result lock poisoned".to_owned(),
515            ))
516        })
517        .map(|mut guard| guard.take())
518}
519
520fn validate_email(email: &str) -> Result<(), AuthFlowError> {
521    let email = email.trim();
522    let Some((local, domain)) = email.split_once('@') else {
523        return Err(AuthFlowError::new(AuthFlowErrorCode::InvalidEmail));
524    };
525    if local.is_empty()
526        || domain.is_empty()
527        || domain.starts_with('.')
528        || domain.ends_with('.')
529        || !domain.contains('.')
530    {
531        return Err(AuthFlowError::new(AuthFlowErrorCode::InvalidEmail));
532    }
533    Ok(())
534}