1use std::error::Error;
4use std::fmt;
5
6use time::{Duration, OffsetDateTime};
7
8use crate::crypto::password::{hash_password, verify_password};
9use crate::db::{DbAdapter, Session, User};
10use crate::error::OpenAuthError;
11use crate::session::{CreateSessionInput, DbSessionStore};
12use crate::user::{CreateCredentialAccountInput, CreateUserInput, DbUserStore};
13
14pub type PasswordHashFn = fn(&str) -> Result<String, OpenAuthError>;
15pub type PasswordVerifyFn = fn(&str, &str) -> Result<bool, OpenAuthError>;
16
/// Machine-readable reason an email/password flow failed.
///
/// Each variant maps to a stable string code (`as_str`) and a default
/// human-readable message (`message`).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AuthFlowErrorCode {
    /// The email address failed syntactic validation.
    InvalidEmail,
    /// The password is outside the configured length bounds.
    InvalidPasswordLength,
    /// Sign-in failed; the same code is used whether the user, the credential
    /// account, or the password was the problem.
    InvalidEmailOrPassword,
    /// Sign-up found an existing user with the same email.
    UserAlreadyExists,
    /// Sign-in succeeded on credentials but the email is not yet verified.
    EmailNotVerified,
    /// The session could not be created or its expiry could not be computed.
    FailedToCreateSession,
    /// An underlying `OpenAuthError` surfaced during the flow.
    StorageError,
}
27
28impl AuthFlowErrorCode {
29 pub fn as_str(self) -> &'static str {
30 match self {
31 Self::InvalidEmail => "INVALID_EMAIL",
32 Self::InvalidPasswordLength => "INVALID_PASSWORD_LENGTH",
33 Self::InvalidEmailOrPassword => "INVALID_EMAIL_OR_PASSWORD",
34 Self::UserAlreadyExists => "USER_ALREADY_EXISTS",
35 Self::EmailNotVerified => "EMAIL_NOT_VERIFIED",
36 Self::FailedToCreateSession => "FAILED_TO_CREATE_SESSION",
37 Self::StorageError => "STORAGE_ERROR",
38 }
39 }
40
41 pub fn message(self) -> &'static str {
42 match self {
43 Self::InvalidEmail => "Invalid email",
44 Self::InvalidPasswordLength => "Invalid password length",
45 Self::InvalidEmailOrPassword => "Invalid email or password",
46 Self::UserAlreadyExists => "User already exists",
47 Self::EmailNotVerified => "Email not verified",
48 Self::FailedToCreateSession => "Failed to create session",
49 Self::StorageError => "Storage error",
50 }
51 }
52}
53
54#[derive(Debug, Clone, PartialEq, Eq)]
55pub struct AuthFlowError {
56 code: AuthFlowErrorCode,
57 message: String,
58}
59
60impl AuthFlowError {
61 pub fn new(code: AuthFlowErrorCode) -> Self {
62 Self {
63 code,
64 message: code.message().to_owned(),
65 }
66 }
67
68 pub fn storage(error: OpenAuthError) -> Self {
69 Self {
70 code: AuthFlowErrorCode::StorageError,
71 message: error.to_string(),
72 }
73 }
74
75 pub fn code(&self) -> AuthFlowErrorCode {
76 self.code
77 }
78
79 pub fn code_str(&self) -> &'static str {
80 self.code.as_str()
81 }
82
83 pub fn message(&self) -> &str {
84 self.message.as_str()
85 }
86}
87
88impl fmt::Display for AuthFlowError {
89 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
90 write!(formatter, "{}: {}", self.code.as_str(), self.message)
91 }
92}
93
// Marker impl: `AuthFlowError` stores only a code and a message string, so there
// is no underlying `source()` to expose and the trait's default methods are used.
impl Error for AuthFlowError {}
95
96impl From<OpenAuthError> for AuthFlowError {
97 fn from(error: OpenAuthError) -> Self {
98 Self::storage(error)
99 }
100}
101
/// Tunables for the email/password flows.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct EmailPasswordConfig {
    /// Session lifetime, in seconds, when the caller asks to be remembered.
    pub session_expires_in: u64,
    /// Session lifetime, in seconds, when `remember_me` is false.
    pub dont_remember_session_expires_in: u64,
    /// Smallest password length accepted at sign-up (inclusive).
    pub min_password_length: usize,
    /// Largest password length accepted at sign-up (inclusive).
    pub max_password_length: usize,
    /// When true, sign-in rejects users whose email is not verified.
    pub require_email_verification: bool,
}
110
111impl Default for EmailPasswordConfig {
112 fn default() -> Self {
113 Self {
114 session_expires_in: 60 * 60 * 24 * 7,
115 dont_remember_session_expires_in: 60 * 60 * 24,
116 min_password_length: 8,
117 max_password_length: 128,
118 require_email_verification: false,
119 }
120 }
121}
122
/// Everything needed to register a new email/password user.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct SignUpInput {
    /// Display name for the new user.
    pub name: String,
    /// Email address; validated before use.
    pub email: String,
    /// Plaintext password; length-checked, then hashed before storage.
    pub password: String,
    /// Optional image value passed through to the created user.
    pub image: Option<String>,
    /// Picks the long (`true`) or short (`false`) session lifetime.
    pub remember_me: bool,
    /// Client IP recorded on the created session, if known.
    pub ip_address: Option<String>,
    /// Client user agent recorded on the created session, if known.
    pub user_agent: Option<String>,
}
133
134impl SignUpInput {
135 pub fn new(
136 name: impl Into<String>,
137 email: impl Into<String>,
138 password: impl Into<String>,
139 ) -> Self {
140 Self {
141 name: name.into(),
142 email: email.into(),
143 password: password.into(),
144 image: None,
145 remember_me: true,
146 ip_address: None,
147 user_agent: None,
148 }
149 }
150
151 #[must_use]
152 pub fn image(mut self, image: impl Into<String>) -> Self {
153 self.image = Some(image.into());
154 self
155 }
156
157 #[must_use]
158 pub fn remember_me(mut self, remember_me: bool) -> Self {
159 self.remember_me = remember_me;
160 self
161 }
162
163 #[must_use]
164 pub fn ip_address(mut self, ip_address: impl Into<String>) -> Self {
165 self.ip_address = Some(ip_address.into());
166 self
167 }
168
169 #[must_use]
170 pub fn user_agent(mut self, user_agent: impl Into<String>) -> Self {
171 self.user_agent = Some(user_agent.into());
172 self
173 }
174}
175
/// Credentials and client metadata for an email/password sign-in attempt.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct SignInInput {
    /// Email address to look up.
    pub email: String,
    /// Plaintext password checked against the stored credential hash.
    pub password: String,
    /// Picks the long (`true`) or short (`false`) session lifetime.
    pub remember_me: bool,
    /// Client IP recorded on the created session, if known.
    pub ip_address: Option<String>,
    /// Client user agent recorded on the created session, if known.
    pub user_agent: Option<String>,
}
184
185impl SignInInput {
186 pub fn new(email: impl Into<String>, password: impl Into<String>) -> Self {
187 Self {
188 email: email.into(),
189 password: password.into(),
190 remember_me: true,
191 ip_address: None,
192 user_agent: None,
193 }
194 }
195
196 #[must_use]
197 pub fn remember_me(mut self, remember_me: bool) -> Self {
198 self.remember_me = remember_me;
199 self
200 }
201
202 #[must_use]
203 pub fn ip_address(mut self, ip_address: impl Into<String>) -> Self {
204 self.ip_address = Some(ip_address.into());
205 self
206 }
207
208 #[must_use]
209 pub fn user_agent(mut self, user_agent: impl Into<String>) -> Self {
210 self.user_agent = Some(user_agent.into());
211 self
212 }
213}
214
215#[derive(Debug, Clone, PartialEq, Eq)]
216pub struct EmailPasswordAuthResult {
217 pub user: User,
218 pub session: Session,
219}
220
221#[derive(Clone)]
222pub struct EmailPasswordAuth<'a> {
223 adapter: &'a dyn DbAdapter,
224 config: EmailPasswordConfig,
225 hash_password: PasswordHashFn,
226 verify_password: PasswordVerifyFn,
227}
228
229impl<'a> EmailPasswordAuth<'a> {
230 pub fn new(
231 adapter: &'a dyn DbAdapter,
232 config: EmailPasswordConfig,
233 hash_password: PasswordHashFn,
234 verify_password: PasswordVerifyFn,
235 ) -> Self {
236 Self {
237 adapter,
238 config,
239 hash_password,
240 verify_password,
241 }
242 }
243
244 pub fn with_defaults(adapter: &'a dyn DbAdapter, config: EmailPasswordConfig) -> Self {
245 Self::new(adapter, config, hash_password, verify_password)
246 }
247
248 pub async fn sign_up(
249 &self,
250 input: SignUpInput,
251 ) -> Result<EmailPasswordAuthResult, AuthFlowError> {
252 self.validate_email_and_password(&input.email, &input.password)?;
253 let users = DbUserStore::new(self.adapter);
254 if users.find_user_by_email(&input.email).await?.is_some() {
255 return Err(AuthFlowError::new(AuthFlowErrorCode::UserAlreadyExists));
256 }
257
258 let password_hash = (self.hash_password)(&input.password)?;
259 let mut create_user = CreateUserInput::new(input.name, input.email);
260 if let Some(image) = input.image {
261 create_user = create_user.image(image);
262 }
263 let user = users.create_user(create_user).await?;
264 users
265 .create_credential_account(CreateCredentialAccountInput::new(
266 user.id.clone(),
267 password_hash,
268 ))
269 .await?;
270 let session = self
271 .create_session(
272 &user.id,
273 input.remember_me,
274 input.ip_address,
275 input.user_agent,
276 )
277 .await?;
278
279 Ok(EmailPasswordAuthResult { user, session })
280 }
281
282 pub async fn sign_in(
283 &self,
284 input: SignInInput,
285 ) -> Result<EmailPasswordAuthResult, AuthFlowError> {
286 validate_email(&input.email)?;
287 let users = DbUserStore::new(self.adapter);
288 let Some(user_with_accounts) = users.find_user_by_email_with_accounts(&input.email).await?
289 else {
290 let _ = (self.hash_password)(&input.password);
291 return Err(AuthFlowError::new(
292 AuthFlowErrorCode::InvalidEmailOrPassword,
293 ));
294 };
295 let Some(account) = user_with_accounts
296 .accounts
297 .iter()
298 .find(|account| account.provider_id == "credential")
299 else {
300 let _ = (self.hash_password)(&input.password);
301 return Err(AuthFlowError::new(
302 AuthFlowErrorCode::InvalidEmailOrPassword,
303 ));
304 };
305 let Some(password_hash) = account.password.as_deref() else {
306 let _ = (self.hash_password)(&input.password);
307 return Err(AuthFlowError::new(
308 AuthFlowErrorCode::InvalidEmailOrPassword,
309 ));
310 };
311 if !(self.verify_password)(password_hash, &input.password)? {
312 return Err(AuthFlowError::new(
313 AuthFlowErrorCode::InvalidEmailOrPassword,
314 ));
315 }
316 if self.config.require_email_verification && !user_with_accounts.user.email_verified {
317 return Err(AuthFlowError::new(AuthFlowErrorCode::EmailNotVerified));
318 }
319 let session = self
320 .create_session(
321 &user_with_accounts.user.id,
322 input.remember_me,
323 input.ip_address,
324 input.user_agent,
325 )
326 .await?;
327
328 Ok(EmailPasswordAuthResult {
329 user: user_with_accounts.user,
330 session,
331 })
332 }
333
334 fn validate_email_and_password(
335 &self,
336 email: &str,
337 password: &str,
338 ) -> Result<(), AuthFlowError> {
339 validate_email(email)?;
340 if password.len() < self.config.min_password_length
341 || password.len() > self.config.max_password_length
342 {
343 return Err(AuthFlowError::new(AuthFlowErrorCode::InvalidPasswordLength));
344 }
345 Ok(())
346 }
347
348 async fn create_session(
349 &self,
350 user_id: &str,
351 remember_me: bool,
352 ip_address: Option<String>,
353 user_agent: Option<String>,
354 ) -> Result<Session, AuthFlowError> {
355 let expires_in = if remember_me {
356 self.config.session_expires_in
357 } else {
358 self.config.dont_remember_session_expires_in
359 };
360 let seconds = i64::try_from(expires_in)
361 .map_err(|_| AuthFlowError::new(AuthFlowErrorCode::FailedToCreateSession))?;
362 let expires_at = OffsetDateTime::now_utc() + Duration::seconds(seconds);
363 let mut input = CreateSessionInput::new(user_id, expires_at);
364 if let Some(ip_address) = ip_address {
365 input = input.ip_address(ip_address);
366 }
367 if let Some(user_agent) = user_agent {
368 input = input.user_agent(user_agent);
369 }
370
371 DbSessionStore::new(self.adapter)
372 .create_session(input)
373 .await
374 .map_err(|_| AuthFlowError::new(AuthFlowErrorCode::FailedToCreateSession))
375 }
376}
377
378fn validate_email(email: &str) -> Result<(), AuthFlowError> {
379 let email = email.trim();
380 let Some((local, domain)) = email.split_once('@') else {
381 return Err(AuthFlowError::new(AuthFlowErrorCode::InvalidEmail));
382 };
383 if local.is_empty()
384 || domain.is_empty()
385 || domain.starts_with('.')
386 || domain.ends_with('.')
387 || !domain.contains('.')
388 {
389 return Err(AuthFlowError::new(AuthFlowErrorCode::InvalidEmail));
390 }
391 Ok(())
392}