1use crate::errors::{AuthError, Result};
4use crate::storage::AuthStorage;
5use serde::{Deserialize, Serialize};
6use subtle::ConstantTimeEq;
7use serde_json::json;
8use std::sync::Arc;
9use std::time::Duration;
10use tracing::{debug, error, info};
11
12#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct EmailProviderConfig {
15 pub provider: EmailProvider,
17 pub from_email: String,
19 pub from_name: Option<String>,
21 pub provider_config: ProviderConfig,
23}
24
25#[derive(Debug, Clone, Serialize, Deserialize)]
31pub enum EmailProvider {
32 SendGrid,
34 AwsSes,
37 Smtp,
39 Development,
41}
42
43#[derive(Debug, Clone, Serialize, Deserialize)]
45pub enum ProviderConfig {
46 SendGrid {
48 api_key: String,
49 endpoint: Option<String>,
50 },
51 AwsSes {
53 region: String,
54 access_key_id: String,
55 secret_access_key: String,
56 },
57 Smtp {
59 host: String,
60 port: u16,
61 username: String,
62 password: String,
63 use_tls: bool,
64 },
65 Development,
67}
68
69impl Default for EmailProviderConfig {
70 fn default() -> Self {
71 Self {
72 provider: EmailProvider::Development,
73 from_email: "noreply@example.com".to_string(),
74 from_name: Some("AuthFramework".to_string()),
75 provider_config: ProviderConfig::Development,
76 }
77 }
78}
79
80pub struct EmailManager {
82 storage: Arc<dyn AuthStorage>,
83 email_config: EmailProviderConfig,
84}
85
86impl EmailManager {
87 pub fn new(storage: Arc<dyn AuthStorage>) -> Self {
89 Self {
90 storage,
91 email_config: EmailProviderConfig::default(),
92 }
93 }
94
95 pub fn new_with_config(
97 storage: Arc<dyn AuthStorage>,
98 email_config: EmailProviderConfig,
99 ) -> Self {
100 Self {
101 storage,
102 email_config,
103 }
104 }
105
106 pub async fn register_email(&self, user_id: &str, email: &str) -> Result<()> {
108 debug!("Registering email for user '{}'", user_id);
109
110 if email.is_empty() {
112 return Err(AuthError::validation("Email address cannot be empty"));
113 }
114
115 if !email.contains('@') || !email.contains('.') {
117 return Err(AuthError::validation(
118 "Email address must be in valid format (user@domain.com)",
119 ));
120 }
121
122 let parts: Vec<&str> = email.split('@').collect();
124 if parts.len() != 2 || parts[0].is_empty() || parts[1].is_empty() {
125 return Err(AuthError::validation("Email address format is invalid"));
126 }
127
128 let domain = parts[1];
129 if !domain.contains('.') || domain.starts_with('.') || domain.ends_with('.') {
130 return Err(AuthError::validation("Email domain format is invalid"));
131 }
132
133 let key = format!("user:{}:email", user_id);
135 self.storage.store_kv(&key, email.as_bytes(), None).await?;
136
137 info!("Email registered for user '{}': {}", user_id, email);
138 Ok(())
139 }
140
141 pub async fn initiate_challenge(&self, user_id: &str) -> Result<String> {
143 debug!("Initiating email challenge for user '{}'", user_id);
144
145 let challenge_id = crate::utils::string::generate_id(Some("email"));
146
147 info!("Email challenge initiated for user '{}'", user_id);
148 Ok(challenge_id)
149 }
150
151 pub async fn generate_code(&self, challenge_id: &str) -> Result<String> {
153 debug!("Generating email code for challenge '{}'", challenge_id);
154
155 let code = format!("{:06}", rand::random::<u32>() % 1000000);
156
157 let email_key = format!("email_challenge:{}:code", challenge_id);
159 self.storage
160 .store_kv(
161 &email_key,
162 code.as_bytes(),
163 Some(Duration::from_secs(300)), )
165 .await?;
166
167 Ok(code)
168 }
169
170 pub async fn verify_code(&self, challenge_id: &str, code: &str) -> Result<bool> {
172 debug!("Verifying email code for challenge '{}'", challenge_id);
173
174 if challenge_id.is_empty() {
176 return Err(AuthError::validation("Challenge ID cannot be empty"));
177 }
178
179 if code.is_empty() {
180 return Err(AuthError::validation("Email code cannot be empty"));
181 }
182
183 let email_key = format!("email_challenge:{}:code", challenge_id);
185 if let Some(stored_code_data) = self.storage.get_kv(&email_key).await? {
186 let stored_code = std::str::from_utf8(&stored_code_data).unwrap_or("");
187
188 let is_valid_format = code.len() == 6 && code.chars().all(|c| c.is_ascii_digit());
190
191 if !is_valid_format {
192 return Ok(false);
193 }
194
195 let is_valid: bool = stored_code.as_bytes().ct_eq(code.as_bytes()).into();
197
198 if is_valid {
199 let _ = self.storage.delete_kv(&email_key).await;
201 }
202
203 Ok(is_valid)
204 } else {
205 Err(AuthError::validation("Invalid or expired challenge ID"))
207 }
208 }
209
210 pub async fn send_code(&self, user_id: &str, code: &str) -> Result<()> {
212 debug!("Sending email code to user '{}'", user_id);
213
214 let email_key = format!("user:{}:email", user_id);
216 if let Some(email_data) = self.storage.get_kv(&email_key).await? {
217 let email_address = String::from_utf8(email_data).map_err(|e| {
218 AuthError::internal(format!("Failed to parse email address: {}", e))
219 })?;
220
221 match self.send_email_via_provider(&email_address, "MFA Code", &format!(
223 "Your authentication code is: {}\n\nThis code will expire in 5 minutes.\nIf you didn't request this code, please ignore this email.",
224 code
225 )).await {
226 Ok(()) => {
227 info!(
228 "Email code '{}' sent successfully to {} for user '{}' via {:?}",
229 code, email_address, user_id, self.email_config.provider
230 );
231 Ok(())
232 }
233 Err(e) => {
234 error!(
235 "Failed to send email code to {} for user '{}': {}",
236 email_address, user_id, e
237 );
238 Err(e)
239 }
240 }
241 } else {
242 Err(AuthError::validation(
243 "No email address registered for user",
244 ))
245 }
246 }
247
248 pub async fn get_user_email(&self, user_id: &str) -> Result<Option<String>> {
250 let email_key = format!("user:{}:email", user_id);
251
252 if let Some(email_data) = self.storage.get_kv(&email_key).await? {
253 Ok(Some(String::from_utf8(email_data).map_err(|e| {
254 AuthError::internal(format!("Failed to parse email address: {}", e))
255 })?))
256 } else {
257 Ok(None)
258 }
259 }
260
261 async fn send_email_via_provider(
263 &self,
264 to_email: &str,
265 subject: &str,
266 body: &str,
267 ) -> Result<()> {
268 match &self.email_config.provider {
269 EmailProvider::SendGrid => self.send_via_sendgrid(to_email, subject, body).await,
270 EmailProvider::AwsSes => self.send_via_aws_ses(to_email, subject, body).await,
271 EmailProvider::Smtp => self.send_via_smtp(to_email, subject, body).await,
272 EmailProvider::Development => {
273 info!("📧 [DEVELOPMENT] Email would be sent:");
275 info!(" To: {}", to_email);
276 info!(" Subject: {}", subject);
277 info!(" Body: {}", body);
278 Ok(())
279 }
280 }
281 }
282
283 async fn send_via_sendgrid(&self, to_email: &str, subject: &str, body: &str) -> Result<()> {
285 if let ProviderConfig::SendGrid { api_key, endpoint } = &self.email_config.provider_config {
286 let client = reqwest::Client::new();
287 let sendgrid_endpoint = endpoint
288 .as_deref()
289 .unwrap_or("https://api.sendgrid.com/v3/mail/send");
290
291 let payload = json!({
292 "personalizations": [{
293 "to": [{"email": to_email}]
294 }],
295 "from": {
296 "email": self.email_config.from_email,
297 "name": self.email_config.from_name.as_deref().unwrap_or("AuthFramework")
298 },
299 "subject": subject,
300 "content": [{
301 "type": "text/plain",
302 "value": body
303 }]
304 });
305
306 let response = client
307 .post(sendgrid_endpoint)
308 .header("Authorization", format!("Bearer {}", api_key))
309 .header("Content-Type", "application/json")
310 .json(&payload)
311 .send()
312 .await
313 .map_err(|e| AuthError::internal(format!("SendGrid request failed: {}", e)))?;
314
315 let status = response.status();
316 if status.is_success() {
317 debug!("SendGrid email sent successfully to {}", to_email);
318 Ok(())
319 } else {
320 let error_text = response.text().await.unwrap_or_default();
321 Err(AuthError::internal(format!(
322 "SendGrid API error: {} - {}",
323 status, error_text
324 )))
325 }
326 } else {
327 Err(AuthError::internal("Invalid SendGrid configuration"))
328 }
329 }
330
331 async fn send_via_aws_ses(&self, to_email: &str, subject: &str, body: &str) -> Result<()> {
333 if let ProviderConfig::AwsSes {
334 region,
335 access_key_id,
336 secret_access_key,
337 } = &self.email_config.provider_config
338 {
339 let from_email = &self.email_config.from_email;
340 let from_name = self
341 .email_config
342 .from_name
343 .as_deref()
344 .unwrap_or("AuthFramework");
345
346 let host = format!("email.{}.amazonaws.com", region);
347 let url = format!("https://{}/v2/email/outbound-emails", host);
348 let now = chrono::Utc::now();
349 let date_stamp = now.format("%Y%m%d").to_string();
350 let amz_date = now.format("%Y%m%dT%H%M%SZ").to_string();
351
352 let payload = serde_json::json!({
353 "Content": {
354 "Simple": {
355 "Subject": { "Data": subject, "Charset": "UTF-8" },
356 "Body": { "Text": { "Data": body, "Charset": "UTF-8" } }
357 }
358 },
359 "Destination": {
360 "ToAddresses": [to_email]
361 },
362 "FromEmailAddress": format!("{} <{}>", from_name, from_email)
363 });
364 let payload_bytes = serde_json::to_vec(&payload).map_err(|e| {
365 AuthError::internal(format!("SES payload serialization failed: {}", e))
366 })?;
367
368 let payload_hash = ses_sha256_hex(&payload_bytes);
370 let canonical_headers = format!(
371 "content-type:application/json\nhost:{}\nx-amz-date:{}\n",
372 host, amz_date
373 );
374 let signed_headers = "content-type;host;x-amz-date";
375 let canonical_request = format!(
376 "POST\n/v2/email/outbound-emails\n\n{}\n{}\n{}",
377 canonical_headers, signed_headers, payload_hash
378 );
379
380 let credential_scope = format!("{}/{}/ses/aws4_request", date_stamp, region);
381 let string_to_sign = format!(
382 "AWS4-HMAC-SHA256\n{}\n{}\n{}",
383 amz_date,
384 credential_scope,
385 ses_sha256_hex(canonical_request.as_bytes())
386 );
387
388 let signing_key =
389 ses_sigv4_key(secret_access_key.as_bytes(), &date_stamp, region, "ses");
390 let signature = ses_hmac_sha256_hex(&signing_key, string_to_sign.as_bytes());
391
392 let authorization = format!(
393 "AWS4-HMAC-SHA256 Credential={}/{}, SignedHeaders={}, Signature={}",
394 access_key_id, credential_scope, signed_headers, signature
395 );
396
397 let client = reqwest::Client::new();
398 let resp = client
399 .post(&url)
400 .header("Content-Type", "application/json")
401 .header("x-amz-date", &amz_date)
402 .header("Authorization", &authorization)
403 .body(payload_bytes)
404 .send()
405 .await
406 .map_err(|e| AuthError::internal(format!("AWS SES request failed: {}", e)))?;
407
408 let status = resp.status();
409 if status.is_success() {
410 debug!("AWS SES email sent successfully to {}", to_email);
411 Ok(())
412 } else {
413 let error_text = resp.text().await.unwrap_or_default();
414 Err(AuthError::internal(format!(
415 "AWS SES error ({}): {}",
416 status, error_text
417 )))
418 }
419 } else {
420 Err(AuthError::internal("Invalid AWS SES configuration"))
421 }
422 }
423
424 async fn send_via_smtp(&self, to_email: &str, subject: &str, body: &str) -> Result<()> {
426 if let ProviderConfig::Smtp {
427 host,
428 port,
429 username,
430 password,
431 use_tls,
432 } = &self.email_config.provider_config
433 {
434 use lettre::{
435 Message, SmtpTransport, Transport, transport::smtp::authentication::Credentials,
436 };
437
438 let from_address = self.email_config.from_email.clone();
439 let from_name = self
440 .email_config
441 .from_name
442 .clone()
443 .unwrap_or_else(|| "AuthFramework".to_string());
444
445 let email = Message::builder()
446 .from(
447 format!("{} <{}>", from_name, from_address)
448 .parse()
449 .map_err(|e| AuthError::internal(format!("Invalid from address: {}", e)))?,
450 )
451 .to(to_email
452 .parse()
453 .map_err(|e| AuthError::internal(format!("Invalid to address: {}", e)))?)
454 .subject(subject)
455 .body(body.to_string())
456 .map_err(|e| AuthError::internal(format!("Failed to build email: {}", e)))?;
457
458 let creds = Credentials::new(username.clone(), password.clone());
459
460 let host = host.clone();
461 let port = *port;
462 let use_tls = *use_tls;
463
464 let result = tokio::task::spawn_blocking(move || {
467 let transport = if use_tls {
468 SmtpTransport::relay(&host)
469 .map_err(|e| AuthError::internal(format!("SMTP relay error: {}", e)))?
470 .port(port)
471 .credentials(creds)
472 .build()
473 } else {
474 SmtpTransport::builder_dangerous(&host)
475 .port(port)
476 .credentials(creds)
477 .build()
478 };
479
480 transport
481 .send(&email)
482 .map_err(|e| AuthError::internal(format!("SMTP send failed: {}", e)))
483 })
484 .await
485 .map_err(|e| AuthError::internal(format!("SMTP task join error: {}", e)))?;
486
487 result?;
488 debug!("SMTP email sent successfully to {}", to_email);
489 Ok(())
490 } else {
491 Err(AuthError::internal("Invalid SMTP configuration"))
492 }
493 }
494
495 pub async fn has_email(&self, user_id: &str) -> Result<bool> {
497 let email_key = format!("user:{}:email", user_id);
498 match self.storage.get_kv(&email_key).await {
499 Ok(Some(_)) => Ok(true),
500 Ok(None) => Ok(false),
501 Err(_) => Ok(false), }
503 }
504
505 pub async fn send_email_code(&self, user_id: &str) -> Result<String> {
507 let code = format!("{:06}", rand::random::<u32>() % 1_000_000);
509
510 self.send_code(user_id, &code).await?;
512
513 let email_key = format!("email_code:{}", user_id);
515 self.storage
516 .store_kv(
517 &email_key,
518 code.as_bytes(),
519 Some(std::time::Duration::from_secs(300)),
520 )
521 .await?;
522
523 Ok(code)
524 }
525}
526
527fn ses_sha256_hex(data: &[u8]) -> String {
530 use ring::digest;
531 let d = digest::digest(&digest::SHA256, data);
532 hex::encode(d.as_ref())
533}
534
535fn ses_hmac_sha256(key: &[u8], data: &[u8]) -> Vec<u8> {
536 use ring::hmac;
537 let s_key = hmac::Key::new(hmac::HMAC_SHA256, key);
538 hmac::sign(&s_key, data).as_ref().to_vec()
539}
540
541fn ses_hmac_sha256_hex(key: &[u8], data: &[u8]) -> String {
542 hex::encode(ses_hmac_sha256(key, data))
543}
544
545fn ses_sigv4_key(secret: &[u8], date_stamp: &str, region: &str, service: &str) -> Vec<u8> {
546 let k_date = ses_hmac_sha256(&[b"AWS4", secret].concat(), date_stamp.as_bytes());
547 let k_region = ses_hmac_sha256(&k_date, region.as_bytes());
548 let k_service = ses_hmac_sha256(&k_region, service.as_bytes());
549 ses_hmac_sha256(&k_service, b"aws4_request")
550}