use std::time::Duration;
use regex::Regex;
use serde::{Deserialize, Serialize};
use crate::{
cache::Cache,
dns::DnsClient,
error::{MailGuardError, Result},
threat::ThreatType,
};
/// Verdict for a single e-mail address, as produced by [`MailGuard::check_email`].
///
/// The threat-related fields are copied from the [`DomainStatus`] obtained for the
/// address's domain.
#[derive(Debug, Clone, PartialEq, Eq)]
#[derive(Serialize, Deserialize)]
pub struct EmailStatus {
    /// The address exactly as the caller supplied it.
    pub email: String,
    /// The domain part of the address; `MailGuard::check_email` fills this with the
    /// lower-cased domain.
    pub domain: String,
    /// Set by `MailGuard` to `true` exactly when `threat_type` is `Some`.
    pub is_threat: bool,
    /// The threat category reported for the domain, if any.
    pub threat_type: Option<ThreatType>,
    /// `true` when the verdict was served from the cache instead of a fresh lookup.
    pub from_cache: bool,
}
/// Verdict for a single domain, as produced by [`MailGuard::check_domain`].
#[derive(Debug, Clone, PartialEq, Eq)]
#[derive(Serialize, Deserialize)]
pub struct DomainStatus {
    /// The domain that was looked up; `MailGuard::check_domain` stores it lower-cased.
    pub domain: String,
    /// Set by `MailGuard` to `true` exactly when `threat_type` is `Some`.
    pub is_threat: bool,
    /// The threat category reported for the domain, if any.
    pub threat_type: Option<ThreatType>,
    /// `true` when the verdict was served from the cache instead of a fresh lookup.
    pub from_cache: bool,
}
/// Construction-time settings for [`MailGuard`].
///
/// The stock values come from this type's `Default` implementation.
#[derive(Clone, Debug)]
pub struct MailGuardConfig {
    /// Timeout handed to `DnsClient::with_timeout` when the checker is built.
    pub dns_timeout: Duration,
    /// When `false`, no cache is created and every check performs a lookup.
    pub enable_cache: bool,
    /// TTL handed to `Cache::with_ttl`; ignored when `enable_cache` is `false`.
    pub cache_ttl: Duration,
}
impl Default for MailGuardConfig {
fn default() -> Self {
Self {
dns_timeout: Duration::from_secs(5),
enable_cache: true,
cache_ttl: Duration::from_secs(300), }
}
}
/// Checks e-mail addresses and domains through a DNS client, optionally memoising
/// verdicts in a TTL cache.
pub struct MailGuard {
    /// Validates domains and performs the threat lookups.
    dns_client: DnsClient,
    /// Verdict cache; `None` when `MailGuardConfig::enable_cache` was `false`.
    cache: Option<Cache>,
    /// Syntactic filter every address must pass before any lookup happens.
    email_regex: Regex,
    /// The configuration this instance was built from.
    #[allow(
        dead_code,
        reason = "retained for the lifetime of the instance even when no method reads it"
    )]
    config: MailGuardConfig,
}
impl MailGuard {
pub fn new() -> Self {
Self::with_config(MailGuardConfig::default())
}
pub fn with_config(config: MailGuardConfig) -> Self {
let dns_client = DnsClient::with_timeout(config.dns_timeout);
let cache = if config.enable_cache {
Some(Cache::with_ttl(config.cache_ttl))
} else {
None
};
let email_regex = Regex::new(
r"^[a-zA-Z0-9.!#$%&'*+/=?^_`{|}~-]+@[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?(?:\.[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?)*$"
).expect("Invalid email regex");
Self {
dns_client,
cache,
email_regex,
config,
}
}
pub async fn check_email(&self, email: &str) -> Result<EmailStatus> {
if !self.email_regex.is_match(email) {
return Err(MailGuardError::InvalidEmail(email.to_string()));
}
let domain = self.extract_domain(email)?;
let domain_status = self.check_domain(&domain).await?;
Ok(EmailStatus {
email: email.to_string(),
domain: domain_status.domain,
is_threat: domain_status.is_threat,
threat_type: domain_status.threat_type,
from_cache: domain_status.from_cache,
})
}
pub async fn check_domain(&self, domain: &str) -> Result<DomainStatus> {
self.dns_client.validate_domain(domain)?;
let domain = domain.to_lowercase();
if let Some(cache) = &self.cache
&& let Some(cached_threat) = cache.get(&domain)
{
return Ok(DomainStatus {
domain: domain.clone(),
is_threat: cached_threat.is_some(),
threat_type: cached_threat,
from_cache: true,
});
}
let threat_type = self.dns_client.query_surbl(&domain).await?;
if let Some(cache) = &self.cache {
cache.set(domain.clone(), threat_type.clone());
}
Ok(DomainStatus {
domain,
is_threat: threat_type.is_some(),
threat_type,
from_cache: false,
})
}
pub async fn check_emails_batch(&self, emails: &[&str]) -> Vec<Result<EmailStatus>> {
let mut results = Vec::with_capacity(emails.len());
for email in emails {
let result = self.check_email(email).await;
results.push(result);
}
results
}
pub async fn check_domains_batch(&self, domains: &[&str]) -> Vec<Result<DomainStatus>> {
let mut results = Vec::with_capacity(domains.len());
for domain in domains {
let result = self.check_domain(domain).await;
results.push(result);
}
results
}
fn extract_domain(&self, email: &str) -> Result<String> {
if let Some(at_pos) = email.rfind('@') {
let domain = &email[at_pos + 1..];
if domain.is_empty() {
return Err(MailGuardError::InvalidEmail("邮箱域名为空".to_string()));
}
Ok(domain.to_string())
} else {
Err(MailGuardError::InvalidEmail("邮箱格式无效".to_string()))
}
}
pub fn cleanup_cache(&self) {
if let Some(cache) = &self.cache {
cache.cleanup_expired();
}
}
pub fn cache_stats(&self) -> Option<usize> {
self.cache.as_ref().map(|cache| cache.size())
}
pub fn clear_cache(&self) {
if let Some(cache) = &self.cache {
cache.clear();
}
}
}
impl Default for MailGuard {
fn default() -> Self {
Self::new()
}
}