mockforge_smtp/
server.rs

1//! SMTP server implementation
2
3use crate::{SmtpConfig, SmtpSpecRegistry};
4use mockforge_core::protocol_abstraction::{
5    MessagePattern, MiddlewareChain, Protocol, ProtocolRequest, SpecRegistry,
6};
7use mockforge_core::Result;
8use std::collections::HashMap;
9use std::net::SocketAddr;
10use std::sync::Arc;
11use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
12use tokio::net::{TcpListener, TcpStream};
13use tokio_rustls::TlsAcceptor;
14use tracing::{debug, error, info, warn};
15
16/// SMTP server
17pub struct SmtpServer {
18    config: SmtpConfig,
19    spec_registry: Arc<SmtpSpecRegistry>,
20    middleware_chain: Arc<MiddlewareChain>,
21    #[allow(dead_code)]
22    tls_acceptor: Option<TlsAcceptor>,
23}
24
25impl SmtpServer {
26    /// Create a new SMTP server
27    pub fn new(config: SmtpConfig, spec_registry: Arc<SmtpSpecRegistry>) -> Result<Self> {
28        let middleware_chain = Arc::new(MiddlewareChain::new());
29
30        let tls_acceptor = if config.enable_starttls {
31            Some(Self::load_tls_acceptor(&config)?)
32        } else {
33            None
34        };
35
36        Ok(Self {
37            config,
38            spec_registry,
39            middleware_chain,
40            tls_acceptor,
41        })
42    }
43
44    /// Load TLS acceptor from certificate and key files
45    fn load_tls_acceptor(config: &SmtpConfig) -> Result<TlsAcceptor> {
46        use rustls_pemfile::{certs, pkcs8_private_keys};
47        use std::fs::File;
48        use std::io::BufReader;
49
50        let cert_path = config
51            .tls_cert_path
52            .as_ref()
53            .ok_or_else(|| mockforge_core::Error::generic("TLS certificate path not configured"))?;
54        let key_path = config
55            .tls_key_path
56            .as_ref()
57            .ok_or_else(|| mockforge_core::Error::generic("TLS private key path not configured"))?;
58
59        // Load certificate
60        let cert_file = File::open(cert_path)?;
61        let mut cert_reader = BufReader::new(cert_file);
62        let certs: Vec<Vec<u8>> = certs(&mut cert_reader)?;
63        let certs = certs.into_iter().map(rustls::Certificate).collect();
64
65        // Load private key
66        let key_file = File::open(key_path)?;
67        let mut key_reader = BufReader::new(key_file);
68        let mut keys: Vec<Vec<u8>> = pkcs8_private_keys(&mut key_reader)?;
69
70        if keys.is_empty() {
71            return Err(mockforge_core::Error::generic("No private keys found"));
72        }
73
74        let mut server_config = rustls::ServerConfig::builder()
75            .with_safe_defaults()
76            .with_no_client_auth()
77            .with_single_cert(certs, rustls::PrivateKey(keys.remove(0)))
78            .map_err(|e| mockforge_core::Error::generic(format!("TLS config error: {}", e)))?;
79
80        server_config.alpn_protocols = vec![b"smtp".to_vec()];
81
82        Ok(TlsAcceptor::from(Arc::new(server_config)))
83    }
84
85    /// Create a new SMTP server with custom middleware
86    pub fn with_middleware(
87        config: SmtpConfig,
88        spec_registry: Arc<SmtpSpecRegistry>,
89        middleware_chain: Arc<MiddlewareChain>,
90    ) -> Result<Self> {
91        let tls_acceptor = if config.enable_starttls {
92            Some(Self::load_tls_acceptor(&config)?)
93        } else {
94            None
95        };
96
97        Ok(Self {
98            config,
99            spec_registry,
100            middleware_chain,
101            tls_acceptor,
102        })
103    }
104
105    /// Start the SMTP server
106    pub async fn start(&self) -> Result<()> {
107        let addr = format!("{}:{}", self.config.host, self.config.port);
108        let listener = TcpListener::bind(&addr).await?;
109
110        info!("SMTP server listening on {}", addr);
111
112        loop {
113            match listener.accept().await {
114                Ok((stream, peer_addr)) => {
115                    debug!("New SMTP connection from {}", peer_addr);
116
117                    let registry = self.spec_registry.clone();
118                    let middleware = self.middleware_chain.clone();
119                    let hostname = self.config.hostname.clone();
120
121                    tokio::spawn(async move {
122                        if let Err(e) =
123                            handle_smtp_session(stream, peer_addr, registry, middleware, hostname)
124                                .await
125                        {
126                            error!("SMTP session error from {}: {}", peer_addr, e);
127                        }
128                    });
129                }
130                Err(e) => {
131                    error!("Failed to accept SMTP connection: {}", e);
132                }
133            }
134        }
135    }
136}
137
138/// Handle a single SMTP session
139async fn handle_smtp_session(
140    stream: TcpStream,
141    peer_addr: SocketAddr,
142    registry: Arc<SmtpSpecRegistry>,
143    middleware: Arc<MiddlewareChain>,
144    hostname: String,
145) -> Result<()> {
146    let (reader, mut writer) = stream.into_split();
147    let mut reader = BufReader::new(reader);
148
149    // Send greeting
150    let greeting = format!("220 {} ESMTP MockForge SMTP Server\r\n", hostname);
151    writer.write_all(greeting.as_bytes()).await?;
152
153    let mut session_state = SessionState::new();
154    let mut line = String::new();
155
156    while reader.read_line(&mut line).await? > 0 {
157        let command = line.trim();
158        debug!("SMTP command from {}: {}", peer_addr, command);
159
160        if command.is_empty() {
161            line.clear();
162            continue;
163        }
164
165        // Parse and handle SMTP command
166        match handle_smtp_command(
167            command,
168            &mut session_state,
169            &mut writer,
170            &hostname,
171            &registry,
172            &middleware,
173            peer_addr,
174        )
175        .await
176        {
177            Ok(should_continue) => {
178                if !should_continue {
179                    debug!("SMTP session ended for {}", peer_addr);
180                    break;
181                }
182            }
183            Err(e) => {
184                error!("Error handling SMTP command: {}", e);
185                let error_response = "500 Internal server error\r\n";
186                writer.write_all(error_response.as_bytes()).await?;
187            }
188        }
189
190        line.clear();
191    }
192
193    Ok(())
194}
195
196/// Handle a single SMTP command
197async fn handle_smtp_command<W: AsyncWriteExt + Unpin>(
198    command: &str,
199    state: &mut SessionState,
200    writer: &mut W,
201    hostname: &str,
202    registry: &Arc<SmtpSpecRegistry>,
203    middleware: &Arc<MiddlewareChain>,
204    peer_addr: SocketAddr,
205) -> Result<bool> {
206    let parts: Vec<&str> = command.splitn(2, ' ').collect();
207    let cmd = parts[0].to_uppercase();
208
209    match cmd.as_str() {
210        "HELLO" | "EHLO" => {
211            let domain = parts.get(1).unwrap_or(&hostname);
212            let response = if cmd == "EHLO" {
213                format!(
214                    "250-{} Hello {}\r\n250-SIZE 10485760\r\n250-8BITMIME\r\n250-STARTTLS\r\n250 HELP\r\n",
215                    hostname, domain
216                )
217            } else {
218                format!("250 {} Hello {}\r\n", hostname, domain)
219            };
220            writer.write_all(response.as_bytes()).await?;
221            Ok(true)
222        }
223
224        "MAIL" => {
225            if let Some(from_part) = parts.get(1) {
226                // Parse MAIL FROM:<address>
227                let from = extract_email_address(from_part);
228                state.mail_from = Some(from);
229                writer.write_all(b"250 OK\r\n").await?;
230            } else {
231                writer.write_all(b"501 Syntax error in parameters\r\n").await?;
232            }
233            Ok(true)
234        }
235
236        "RCPT" => {
237            if let Some(to_part) = parts.get(1) {
238                // Parse RCPT TO:<address>
239                let to = extract_email_address(to_part);
240                state.rcpt_to.push(to);
241                writer.write_all(b"250 OK\r\n").await?;
242            } else {
243                writer.write_all(b"501 Syntax error in parameters\r\n").await?;
244            }
245            Ok(true)
246        }
247
248        "DATA" => {
249            writer.write_all(b"354 Start mail input; end with <CRLF>.<CRLF>\r\n").await?;
250            state.in_data_mode = true;
251            Ok(true)
252        }
253
254        "RSET" => {
255            state.reset();
256            writer.write_all(b"250 OK\r\n").await?;
257            Ok(true)
258        }
259
260        "NOOP" => {
261            writer.write_all(b"250 OK\r\n").await?;
262            Ok(true)
263        }
264
265        "QUIT" => {
266            writer.write_all(b"221 Bye\r\n").await?;
267            Ok(false) // End session
268        }
269
270        "STARTTLS" => {
271            // Mock STARTTLS implementation - accept but don't actually upgrade
272            writer.write_all(b"220 Ready to start TLS\r\n").await?;
273            Ok(true)
274        }
275
276        "HELP" => {
277            let help_text = "214-Commands supported:\r\n\
278                            214-  HELLO EHLO MAIL RCPT DATA\r\n\
279                            214-  RSET NOOP QUIT HELP STARTTLS\r\n\
280                            214 End of HELP info\r\n";
281            writer.write_all(help_text.as_bytes()).await?;
282            Ok(true)
283        }
284
285        _ => {
286            // Handle data mode or unknown command
287            if state.in_data_mode {
288                if command == "." {
289                    // End of data
290                    state.in_data_mode = false;
291
292                    // Process the email
293                    let response = process_email(state, registry, middleware, peer_addr).await?;
294
295                    writer.write_all(response.as_bytes()).await?;
296                    state.reset();
297                } else {
298                    // Accumulate email data
299                    state.data.push_str(command);
300                    state.data.push('\n');
301                }
302                Ok(true)
303            } else {
304                warn!("Unknown SMTP command: {}", command);
305                writer.write_all(b"502 Command not implemented\r\n").await?;
306                Ok(true)
307            }
308        }
309    }
310}
311
312/// Process received email and generate response
313async fn process_email(
314    state: &SessionState,
315    registry: &Arc<SmtpSpecRegistry>,
316    middleware: &Arc<MiddlewareChain>,
317    peer_addr: SocketAddr,
318) -> Result<String> {
319    let from = state
320        .mail_from
321        .as_ref()
322        .ok_or_else(|| mockforge_core::Error::generic("Missing MAIL FROM"))?;
323    let to = state.rcpt_to.join(", ");
324
325    // Extract subject from data
326    let subject = extract_subject(&state.data);
327
328    // Create protocol request
329    let mut request = ProtocolRequest {
330        protocol: Protocol::Smtp,
331        pattern: MessagePattern::OneWay,
332        operation: "SEND".to_string(),
333        path: from.clone(),
334        topic: None,
335        routing_key: None,
336        partition: None,
337        qos: None,
338        metadata: HashMap::from([
339            ("from".to_string(), from.clone()),
340            ("to".to_string(), to.clone()),
341            ("subject".to_string(), subject.clone()),
342        ]),
343        body: Some(state.data.as_bytes().to_vec()),
344        client_ip: Some(peer_addr.ip().to_string()),
345    };
346
347    // Process through middleware
348    middleware.process_request(&mut request).await?;
349
350    // Generate response
351    let mut response = registry.generate_mock_response(&request)?;
352
353    // Process response through middleware
354    middleware.process_response(&request, &mut response).await?;
355
356    // Return SMTP response
357    Ok(String::from_utf8_lossy(&response.body).to_string())
358}
359
/// Extract the email address from an SMTP command parameter.
///
/// Handles parameters such as `FROM:<user@example.com>` or
/// `TO:<user@example.com>` by returning the text between the first `<` and the
/// first `>` that follows it. If there is no such bracket pair, the whole
/// parameter is returned with surrounding whitespace trimmed.
///
/// The closing bracket is searched for only after the opening one, so input
/// with a `>` before the `<` cannot produce an inverted slice range and panic.
fn extract_email_address(param: &str) -> String {
    if let Some(open) = param.find('<') {
        // '<' is a single byte, so `open + 1` is always a valid char boundary.
        let after_open = &param[open + 1..];
        if let Some(close) = after_open.find('>') {
            return after_open[..close].to_string();
        }
    }

    // If no angle brackets, just trim and return
    param.trim().to_string()
}
372
/// Extract the subject from email data.
///
/// Scans the header section (the lines before the first empty line) for a line
/// starting with `Subject:` (ASCII case-insensitive) and returns the rest of
/// that line with surrounding whitespace trimmed. Lines in the message body
/// are never considered. Returns an empty string when no subject header exists.
fn extract_subject(data: &str) -> String {
    const PREFIX: &str = "subject:";

    data.lines()
        // The first empty line ends the header section.
        .take_while(|line| !line.is_empty())
        .find_map(|line| {
            // `get` yields `None` for short lines and non-boundary offsets, so
            // no slice here can panic regardless of the line's content.
            let name = line.get(..PREFIX.len())?;
            let value = line.get(PREFIX.len()..)?;
            if name.eq_ignore_ascii_case(PREFIX) {
                Some(value.trim().to_string())
            } else {
                None
            }
        })
        .unwrap_or_default()
}
382
/// Per-connection state of the mail transaction currently in progress.
struct SessionState {
    /// Sender address recorded by the most recent `MAIL` command, if any.
    mail_from: Option<String>,
    /// Recipient addresses recorded by `RCPT` commands, in arrival order.
    rcpt_to: Vec<String>,
    /// Message content accumulated so far, one `\n`-terminated line per received line.
    data: String,
    /// `true` between the `DATA` command and the terminating `.` line.
    in_data_mode: bool,
}

impl SessionState {
    /// Creates an empty state: no sender, no recipients, no data, not in data mode.
    fn new() -> Self {
        let mail_from = None;
        let rcpt_to = Vec::new();
        let data = String::new();
        let in_data_mode = false;

        Self {
            mail_from,
            rcpt_to,
            data,
            in_data_mode,
        }
    }

    /// Returns the state to its initial, empty condition so a new transaction can begin.
    fn reset(&mut self) {
        self.in_data_mode = false;
        self.data.clear();
        self.rcpt_to.clear();
        self.mail_from = None;
    }
}
408
#[cfg(test)]
mod tests {
    use super::*;

    /// Bracketed and bare parameters both yield the plain address.
    #[test]
    fn test_extract_email_address() {
        let cases = [
            ("FROM:<user@example.com>", "user@example.com"),
            ("TO:<admin@test.com>", "admin@test.com"),
            ("user@example.com", "user@example.com"),
        ];

        for &(input, expected) in cases.iter() {
            assert_eq!(extract_email_address(input), expected);
        }
    }

    /// The subject header is found among the other headers and returned trimmed.
    #[test]
    fn test_extract_subject() {
        let message = [
            "From: sender@example.com",
            "Subject: Test Email",
            "To: recipient@example.com",
            "",
            "Body text",
        ]
        .join("\n");

        assert_eq!(extract_subject(&message), "Test Email");
    }

    /// A fresh state is empty, and `reset` returns a populated state to empty.
    #[test]
    fn test_session_state() {
        let mut session = SessionState::new();
        assert!(session.mail_from.is_none());
        assert_eq!(session.rcpt_to.len(), 0);

        session.mail_from = Some(String::from("sender@example.com"));
        session.rcpt_to.push(String::from("recipient@example.com"));

        session.reset();
        assert!(session.mail_from.is_none());
        assert_eq!(session.rcpt_to.len(), 0);
    }
}