1use std::error::Error;
4use std::fmt;
5
6use time::{Duration, OffsetDateTime};
7
8use crate::crypto::password::{hash_password, verify_password};
9use crate::db::{DbAdapter, DbRecord, Session, User};
10use crate::error::OpenAuthError;
11use crate::session::{CreateSessionInput, DbSessionStore};
12use crate::user::{CreateCredentialAccountInput, CreateUserInput, DbUserStore};
13
14pub type PasswordHashFn = fn(&str) -> Result<String, OpenAuthError>;
15pub type PasswordVerifyFn = fn(&str, &str) -> Result<bool, OpenAuthError>;
16
/// Machine-readable failure categories produced by the email/password flows.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AuthFlowErrorCode {
    /// The email address failed the structural check.
    InvalidEmail,
    /// The password length is outside the configured bounds.
    InvalidPasswordLength,
    /// Sign-in credentials were rejected; the error does not reveal which part was wrong.
    InvalidEmailOrPassword,
    /// Sign-up was attempted for an email that already belongs to a user.
    UserAlreadyExists,
    /// Credentials matched, but the configuration requires a verified email.
    EmailNotVerified,
    /// A session could not be created for the user.
    FailedToCreateSession,
    /// A lower-level storage or hashing operation returned an error.
    StorageError,
}
27
28impl AuthFlowErrorCode {
29 pub fn as_str(self) -> &'static str {
30 match self {
31 Self::InvalidEmail => "INVALID_EMAIL",
32 Self::InvalidPasswordLength => "INVALID_PASSWORD_LENGTH",
33 Self::InvalidEmailOrPassword => "INVALID_EMAIL_OR_PASSWORD",
34 Self::UserAlreadyExists => "USER_ALREADY_EXISTS",
35 Self::EmailNotVerified => "EMAIL_NOT_VERIFIED",
36 Self::FailedToCreateSession => "FAILED_TO_CREATE_SESSION",
37 Self::StorageError => "STORAGE_ERROR",
38 }
39 }
40
41 pub fn message(self) -> &'static str {
42 match self {
43 Self::InvalidEmail => "Invalid email",
44 Self::InvalidPasswordLength => "Invalid password length",
45 Self::InvalidEmailOrPassword => "Invalid email or password",
46 Self::UserAlreadyExists => "User already exists",
47 Self::EmailNotVerified => "Email not verified",
48 Self::FailedToCreateSession => "Failed to create session",
49 Self::StorageError => "Storage error",
50 }
51 }
52}
53
54#[derive(Debug, Clone, PartialEq, Eq)]
55pub struct AuthFlowError {
56 code: AuthFlowErrorCode,
57 message: String,
58}
59
60impl AuthFlowError {
61 pub fn new(code: AuthFlowErrorCode) -> Self {
62 Self {
63 code,
64 message: code.message().to_owned(),
65 }
66 }
67
68 pub fn storage(error: OpenAuthError) -> Self {
69 Self {
70 code: AuthFlowErrorCode::StorageError,
71 message: error.to_string(),
72 }
73 }
74
75 pub fn code(&self) -> AuthFlowErrorCode {
76 self.code
77 }
78
79 pub fn code_str(&self) -> &'static str {
80 self.code.as_str()
81 }
82
83 pub fn message(&self) -> &str {
84 self.message.as_str()
85 }
86}
87
88impl fmt::Display for AuthFlowError {
89 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
90 write!(formatter, "{}: {}", self.code.as_str(), self.message)
91 }
92}
93
94impl Error for AuthFlowError {}
95
96impl From<OpenAuthError> for AuthFlowError {
97 fn from(error: OpenAuthError) -> Self {
98 Self::storage(error)
99 }
100}
101
/// Tunables for the email/password flows in [`EmailPasswordAuth`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct EmailPasswordConfig {
    /// Session lifetime in seconds when the caller asked to be remembered.
    pub session_expires_in: u64,
    /// Session lifetime in seconds when `remember_me` is false.
    pub dont_remember_session_expires_in: u64,
    /// Minimum accepted sign-up password length, measured in UTF-8 bytes.
    pub min_password_length: usize,
    /// Maximum accepted sign-up password length, measured in UTF-8 bytes.
    pub max_password_length: usize,
    /// When true, sign-in is refused for users whose email is not verified.
    pub require_email_verification: bool,
}
110
111impl Default for EmailPasswordConfig {
112 fn default() -> Self {
113 Self {
114 session_expires_in: 60 * 60 * 24 * 7,
115 dont_remember_session_expires_in: 60 * 60 * 24,
116 min_password_length: 8,
117 max_password_length: 128,
118 require_email_verification: false,
119 }
120 }
121}
122
123#[derive(Debug, Clone, PartialEq)]
124pub struct SignUpInput {
125 pub name: String,
126 pub email: String,
127 pub password: String,
128 pub image: Option<String>,
129 pub username: Option<String>,
130 pub display_username: Option<String>,
131 pub remember_me: bool,
132 pub ip_address: Option<String>,
133 pub user_agent: Option<String>,
134 pub additional_user_fields: DbRecord,
135 pub additional_session_fields: DbRecord,
136}
137
138impl SignUpInput {
139 pub fn new(
140 name: impl Into<String>,
141 email: impl Into<String>,
142 password: impl Into<String>,
143 ) -> Self {
144 Self {
145 name: name.into(),
146 email: email.into(),
147 password: password.into(),
148 image: None,
149 username: None,
150 display_username: None,
151 remember_me: true,
152 ip_address: None,
153 user_agent: None,
154 additional_user_fields: DbRecord::new(),
155 additional_session_fields: DbRecord::new(),
156 }
157 }
158
159 #[must_use]
160 pub fn image(mut self, image: impl Into<String>) -> Self {
161 self.image = Some(image.into());
162 self
163 }
164
165 #[must_use]
166 pub fn username(mut self, username: impl Into<String>) -> Self {
167 self.username = Some(username.into());
168 self
169 }
170
171 #[must_use]
172 pub fn display_username(mut self, display_username: impl Into<String>) -> Self {
173 self.display_username = Some(display_username.into());
174 self
175 }
176
177 #[must_use]
178 pub fn remember_me(mut self, remember_me: bool) -> Self {
179 self.remember_me = remember_me;
180 self
181 }
182
183 #[must_use]
184 pub fn ip_address(mut self, ip_address: impl Into<String>) -> Self {
185 self.ip_address = Some(ip_address.into());
186 self
187 }
188
189 #[must_use]
190 pub fn user_agent(mut self, user_agent: impl Into<String>) -> Self {
191 self.user_agent = Some(user_agent.into());
192 self
193 }
194
195 #[must_use]
196 pub fn additional_user_fields(mut self, fields: DbRecord) -> Self {
197 self.additional_user_fields = fields;
198 self
199 }
200
201 #[must_use]
202 pub fn additional_session_fields(mut self, fields: DbRecord) -> Self {
203 self.additional_session_fields = fields;
204 self
205 }
206}
207
208#[derive(Debug, Clone, PartialEq)]
209pub struct SignInInput {
210 pub email: String,
211 pub password: String,
212 pub remember_me: bool,
213 pub ip_address: Option<String>,
214 pub user_agent: Option<String>,
215 pub additional_session_fields: DbRecord,
216}
217
218impl SignInInput {
219 pub fn new(email: impl Into<String>, password: impl Into<String>) -> Self {
220 Self {
221 email: email.into(),
222 password: password.into(),
223 remember_me: true,
224 ip_address: None,
225 user_agent: None,
226 additional_session_fields: DbRecord::new(),
227 }
228 }
229
230 #[must_use]
231 pub fn remember_me(mut self, remember_me: bool) -> Self {
232 self.remember_me = remember_me;
233 self
234 }
235
236 #[must_use]
237 pub fn ip_address(mut self, ip_address: impl Into<String>) -> Self {
238 self.ip_address = Some(ip_address.into());
239 self
240 }
241
242 #[must_use]
243 pub fn user_agent(mut self, user_agent: impl Into<String>) -> Self {
244 self.user_agent = Some(user_agent.into());
245 self
246 }
247
248 #[must_use]
249 pub fn additional_session_fields(mut self, fields: DbRecord) -> Self {
250 self.additional_session_fields = fields;
251 self
252 }
253}
254
255#[derive(Debug, Clone, PartialEq, Eq)]
256pub struct EmailPasswordAuthResult {
257 pub user: User,
258 pub session: Session,
259}
260
261#[derive(Clone)]
262pub struct EmailPasswordAuth<'a> {
263 adapter: &'a dyn DbAdapter,
264 config: EmailPasswordConfig,
265 hash_password: PasswordHashFn,
266 verify_password: PasswordVerifyFn,
267}
268
269impl<'a> EmailPasswordAuth<'a> {
270 pub fn new(
271 adapter: &'a dyn DbAdapter,
272 config: EmailPasswordConfig,
273 hash_password: PasswordHashFn,
274 verify_password: PasswordVerifyFn,
275 ) -> Self {
276 Self {
277 adapter,
278 config,
279 hash_password,
280 verify_password,
281 }
282 }
283
284 pub fn with_defaults(adapter: &'a dyn DbAdapter, config: EmailPasswordConfig) -> Self {
285 Self::new(adapter, config, hash_password, verify_password)
286 }
287
288 pub async fn sign_up(
289 &self,
290 input: SignUpInput,
291 ) -> Result<EmailPasswordAuthResult, AuthFlowError> {
292 self.validate_email_and_password(&input.email, &input.password)?;
293 let users = DbUserStore::new(self.adapter);
294 if users.find_user_by_email(&input.email).await?.is_some() {
295 return Err(AuthFlowError::new(AuthFlowErrorCode::UserAlreadyExists));
296 }
297
298 let password_hash = (self.hash_password)(&input.password)?;
299 let mut create_user = CreateUserInput::new(input.name, input.email)
300 .additional_fields(input.additional_user_fields);
301 if let Some(image) = input.image {
302 create_user = create_user.image(image);
303 }
304 if let Some(username) = input.username {
305 create_user = create_user.username(username);
306 }
307 if let Some(display_username) = input.display_username {
308 create_user = create_user.display_username(display_username);
309 }
310 let user = users.create_user(create_user).await?;
311 users
312 .create_credential_account(CreateCredentialAccountInput::new(
313 user.id.clone(),
314 password_hash,
315 ))
316 .await?;
317 let session = self
318 .create_session(
319 &user.id,
320 input.remember_me,
321 input.ip_address,
322 input.user_agent,
323 input.additional_session_fields,
324 )
325 .await?;
326
327 Ok(EmailPasswordAuthResult { user, session })
328 }
329
330 pub async fn sign_in(
331 &self,
332 input: SignInInput,
333 ) -> Result<EmailPasswordAuthResult, AuthFlowError> {
334 validate_email(&input.email)?;
335 let users = DbUserStore::new(self.adapter);
336 let Some(user_with_accounts) = users.find_user_by_email_with_accounts(&input.email).await?
337 else {
338 let _ = (self.hash_password)(&input.password);
339 return Err(AuthFlowError::new(
340 AuthFlowErrorCode::InvalidEmailOrPassword,
341 ));
342 };
343 let Some(account) = user_with_accounts
344 .accounts
345 .iter()
346 .find(|account| account.provider_id == "credential")
347 else {
348 let _ = (self.hash_password)(&input.password);
349 return Err(AuthFlowError::new(
350 AuthFlowErrorCode::InvalidEmailOrPassword,
351 ));
352 };
353 let Some(password_hash) = account.password.as_deref() else {
354 let _ = (self.hash_password)(&input.password);
355 return Err(AuthFlowError::new(
356 AuthFlowErrorCode::InvalidEmailOrPassword,
357 ));
358 };
359 if !(self.verify_password)(password_hash, &input.password)? {
360 return Err(AuthFlowError::new(
361 AuthFlowErrorCode::InvalidEmailOrPassword,
362 ));
363 }
364 if self.config.require_email_verification && !user_with_accounts.user.email_verified {
365 return Err(AuthFlowError::new(AuthFlowErrorCode::EmailNotVerified));
366 }
367 let session = self
368 .create_session(
369 &user_with_accounts.user.id,
370 input.remember_me,
371 input.ip_address,
372 input.user_agent,
373 input.additional_session_fields,
374 )
375 .await?;
376
377 Ok(EmailPasswordAuthResult {
378 user: user_with_accounts.user,
379 session,
380 })
381 }
382
383 fn validate_email_and_password(
384 &self,
385 email: &str,
386 password: &str,
387 ) -> Result<(), AuthFlowError> {
388 validate_email(email)?;
389 if password.len() < self.config.min_password_length
390 || password.len() > self.config.max_password_length
391 {
392 return Err(AuthFlowError::new(AuthFlowErrorCode::InvalidPasswordLength));
393 }
394 Ok(())
395 }
396
397 async fn create_session(
398 &self,
399 user_id: &str,
400 remember_me: bool,
401 ip_address: Option<String>,
402 user_agent: Option<String>,
403 additional_fields: DbRecord,
404 ) -> Result<Session, AuthFlowError> {
405 let expires_in = if remember_me {
406 self.config.session_expires_in
407 } else {
408 self.config.dont_remember_session_expires_in
409 };
410 let seconds = i64::try_from(expires_in)
411 .map_err(|_| AuthFlowError::new(AuthFlowErrorCode::FailedToCreateSession))?;
412 let expires_at = OffsetDateTime::now_utc() + Duration::seconds(seconds);
413 let mut input =
414 CreateSessionInput::new(user_id, expires_at).additional_fields(additional_fields);
415 if let Some(ip_address) = ip_address {
416 input = input.ip_address(ip_address);
417 }
418 if let Some(user_agent) = user_agent {
419 input = input.user_agent(user_agent);
420 }
421
422 DbSessionStore::new(self.adapter)
423 .create_session(input)
424 .await
425 .map_err(|_| AuthFlowError::new(AuthFlowErrorCode::FailedToCreateSession))
426 }
427}
428
429fn validate_email(email: &str) -> Result<(), AuthFlowError> {
430 let email = email.trim();
431 let Some((local, domain)) = email.split_once('@') else {
432 return Err(AuthFlowError::new(AuthFlowErrorCode::InvalidEmail));
433 };
434 if local.is_empty()
435 || domain.is_empty()
436 || domain.starts_with('.')
437 || domain.ends_with('.')
438 || !domain.contains('.')
439 {
440 return Err(AuthFlowError::new(AuthFlowErrorCode::InvalidEmail));
441 }
442 Ok(())
443}