mockforge_smtp/
server.rs

1//! SMTP server implementation
2
3use crate::{SmtpConfig, SmtpSpecRegistry};
4use mockforge_core::protocol_abstraction::{
5    MessagePattern, MiddlewareChain, Protocol, ProtocolRequest, SpecRegistry,
6};
7use mockforge_core::Result;
8use std::collections::HashMap;
9use std::net::SocketAddr;
10use std::sync::Arc;
11use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
12use tokio::net::{TcpListener, TcpStream};
13use tokio_rustls::{rustls, TlsAcceptor};
14use tracing::{debug, error, info, warn};
15
/// SMTP server
///
/// Owns the listener configuration plus the shared registry and middleware that
/// every spawned session uses to build responses for received messages.
pub struct SmtpServer {
    /// Bind address (`host`/`port`), greeting hostname and TLS settings.
    config: SmtpConfig,
    /// Source of mock responses for completed mail transactions.
    spec_registry: Arc<SmtpSpecRegistry>,
    /// Applied to each request before, and each response after, mock generation.
    middleware_chain: Arc<MiddlewareChain>,
    // Built only when `config.enable_starttls` is set. No session code reads it
    // yet (STARTTLS is acknowledged without upgrading the connection), hence the
    // lint allowance.
    #[allow(dead_code)]
    tls_acceptor: Option<TlsAcceptor>,
}
24
25impl SmtpServer {
26    /// Create a new SMTP server
27    pub fn new(config: SmtpConfig, spec_registry: Arc<SmtpSpecRegistry>) -> Result<Self> {
28        let middleware_chain = Arc::new(MiddlewareChain::new());
29
30        let tls_acceptor = if config.enable_starttls {
31            Some(Self::load_tls_acceptor(&config)?)
32        } else {
33            None
34        };
35
36        Ok(Self {
37            config,
38            spec_registry,
39            middleware_chain,
40            tls_acceptor,
41        })
42    }
43
44    /// Load TLS acceptor from certificate and key files
45    fn load_tls_acceptor(config: &SmtpConfig) -> Result<TlsAcceptor> {
46        use rustls_pemfile::{certs, pkcs8_private_keys};
47        use std::fs::File;
48        use std::io::BufReader;
49
50        let cert_path = config
51            .tls_cert_path
52            .as_ref()
53            .ok_or_else(|| mockforge_core::Error::generic("TLS certificate path not configured"))?;
54        let key_path = config
55            .tls_key_path
56            .as_ref()
57            .ok_or_else(|| mockforge_core::Error::generic("TLS private key path not configured"))?;
58
59        // Load certificate
60        let cert_file = File::open(cert_path)?;
61        let mut cert_reader = BufReader::new(cert_file);
62        let certs: Vec<Vec<u8>> = certs(&mut cert_reader)?;
63        // Use rustls types from tokio-rustls for compatibility
64        let certs: Vec<rustls::Certificate> = certs.into_iter().map(rustls::Certificate).collect();
65
66        // Load private key
67        let key_file = File::open(key_path)?;
68        let mut key_reader = BufReader::new(key_file);
69        let mut keys: Vec<Vec<u8>> = pkcs8_private_keys(&mut key_reader)?;
70
71        if keys.is_empty() {
72            return Err(mockforge_core::Error::generic("No private keys found"));
73        }
74
75        // Use rustls from tokio-rustls which has compatible API
76        let mut server_config = rustls::ServerConfig::builder()
77            .with_safe_defaults()
78            .with_no_client_auth()
79            .with_single_cert(certs, rustls::PrivateKey(keys.remove(0)))
80            .map_err(|e| mockforge_core::Error::generic(format!("TLS config error: {}", e)))?;
81
82        server_config.alpn_protocols = vec![b"smtp".to_vec()];
83
84        Ok(TlsAcceptor::from(Arc::new(server_config)))
85    }
86
87    /// Create a new SMTP server with custom middleware
88    pub fn with_middleware(
89        config: SmtpConfig,
90        spec_registry: Arc<SmtpSpecRegistry>,
91        middleware_chain: Arc<MiddlewareChain>,
92    ) -> Result<Self> {
93        let tls_acceptor = if config.enable_starttls {
94            Some(Self::load_tls_acceptor(&config)?)
95        } else {
96            None
97        };
98
99        Ok(Self {
100            config,
101            spec_registry,
102            middleware_chain,
103            tls_acceptor,
104        })
105    }
106
107    /// Start the SMTP server
108    pub async fn start(&self) -> Result<()> {
109        let addr = format!("{}:{}", self.config.host, self.config.port);
110        let listener = TcpListener::bind(&addr).await?;
111
112        info!("SMTP server listening on {}", addr);
113
114        loop {
115            match listener.accept().await {
116                Ok((stream, peer_addr)) => {
117                    debug!("New SMTP connection from {}", peer_addr);
118
119                    let registry = self.spec_registry.clone();
120                    let middleware = self.middleware_chain.clone();
121                    let hostname = self.config.hostname.clone();
122
123                    tokio::spawn(async move {
124                        if let Err(e) =
125                            handle_smtp_session(stream, peer_addr, registry, middleware, hostname)
126                                .await
127                        {
128                            error!("SMTP session error from {}: {}", peer_addr, e);
129                        }
130                    });
131                }
132                Err(e) => {
133                    error!("Failed to accept SMTP connection: {}", e);
134                }
135            }
136        }
137    }
138}
139
140/// Handle a single SMTP session
141async fn handle_smtp_session(
142    stream: TcpStream,
143    peer_addr: SocketAddr,
144    registry: Arc<SmtpSpecRegistry>,
145    middleware: Arc<MiddlewareChain>,
146    hostname: String,
147) -> Result<()> {
148    let (reader, mut writer) = stream.into_split();
149    let mut reader = BufReader::new(reader);
150
151    // Send greeting
152    let greeting = format!("220 {} ESMTP MockForge SMTP Server\r\n", hostname);
153    writer.write_all(greeting.as_bytes()).await?;
154
155    let mut session_state = SessionState::new();
156    let mut line = String::new();
157
158    while reader.read_line(&mut line).await? > 0 {
159        let command = line.trim();
160        debug!("SMTP command from {}: {}", peer_addr, command);
161
162        if command.is_empty() {
163            line.clear();
164            continue;
165        }
166
167        // Parse and handle SMTP command
168        match handle_smtp_command(
169            command,
170            &mut session_state,
171            &mut writer,
172            &hostname,
173            &registry,
174            &middleware,
175            peer_addr,
176        )
177        .await
178        {
179            Ok(should_continue) => {
180                if !should_continue {
181                    debug!("SMTP session ended for {}", peer_addr);
182                    break;
183                }
184            }
185            Err(e) => {
186                error!("Error handling SMTP command: {}", e);
187                let error_response = "500 Internal server error\r\n";
188                writer.write_all(error_response.as_bytes()).await?;
189            }
190        }
191
192        line.clear();
193    }
194
195    Ok(())
196}
197
198/// Handle a single SMTP command
199async fn handle_smtp_command<W: AsyncWriteExt + Unpin>(
200    command: &str,
201    state: &mut SessionState,
202    writer: &mut W,
203    hostname: &str,
204    registry: &Arc<SmtpSpecRegistry>,
205    middleware: &Arc<MiddlewareChain>,
206    peer_addr: SocketAddr,
207) -> Result<bool> {
208    let parts: Vec<&str> = command.splitn(2, ' ').collect();
209    let cmd = parts[0].to_uppercase();
210
211    match cmd.as_str() {
212        "HELLO" | "EHLO" => {
213            let domain = parts.get(1).unwrap_or(&hostname);
214            let response = if cmd == "EHLO" {
215                format!(
216                    "250-{} Hello {}\r\n250-SIZE 10485760\r\n250-8BITMIME\r\n250-STARTTLS\r\n250 HELP\r\n",
217                    hostname, domain
218                )
219            } else {
220                format!("250 {} Hello {}\r\n", hostname, domain)
221            };
222            writer.write_all(response.as_bytes()).await?;
223            Ok(true)
224        }
225
226        "MAIL" => {
227            if let Some(from_part) = parts.get(1) {
228                // Parse MAIL FROM:<address>
229                let from = extract_email_address(from_part);
230                state.mail_from = Some(from);
231                writer.write_all(b"250 OK\r\n").await?;
232            } else {
233                writer.write_all(b"501 Syntax error in parameters\r\n").await?;
234            }
235            Ok(true)
236        }
237
238        "RCPT" => {
239            if let Some(to_part) = parts.get(1) {
240                // Parse RCPT TO:<address>
241                let to = extract_email_address(to_part);
242                state.rcpt_to.push(to);
243                writer.write_all(b"250 OK\r\n").await?;
244            } else {
245                writer.write_all(b"501 Syntax error in parameters\r\n").await?;
246            }
247            Ok(true)
248        }
249
250        "DATA" => {
251            writer.write_all(b"354 Start mail input; end with <CRLF>.<CRLF>\r\n").await?;
252            state.in_data_mode = true;
253            Ok(true)
254        }
255
256        "RSET" => {
257            state.reset();
258            writer.write_all(b"250 OK\r\n").await?;
259            Ok(true)
260        }
261
262        "NOOP" => {
263            writer.write_all(b"250 OK\r\n").await?;
264            Ok(true)
265        }
266
267        "QUIT" => {
268            writer.write_all(b"221 Bye\r\n").await?;
269            Ok(false) // End session
270        }
271
272        "STARTTLS" => {
273            // Mock STARTTLS implementation - accept but don't actually upgrade
274            writer.write_all(b"220 Ready to start TLS\r\n").await?;
275            Ok(true)
276        }
277
278        "HELP" => {
279            let help_text = "214-Commands supported:\r\n\
280                            214-  HELLO EHLO MAIL RCPT DATA\r\n\
281                            214-  RSET NOOP QUIT HELP STARTTLS\r\n\
282                            214 End of HELP info\r\n";
283            writer.write_all(help_text.as_bytes()).await?;
284            Ok(true)
285        }
286
287        _ => {
288            // Handle data mode or unknown command
289            if state.in_data_mode {
290                if command == "." {
291                    // End of data
292                    state.in_data_mode = false;
293
294                    // Process the email
295                    let response = process_email(state, registry, middleware, peer_addr).await?;
296
297                    writer.write_all(response.as_bytes()).await?;
298                    state.reset();
299                } else {
300                    // Accumulate email data
301                    state.data.push_str(command);
302                    state.data.push('\n');
303                }
304                Ok(true)
305            } else {
306                warn!("Unknown SMTP command: {}", command);
307                writer.write_all(b"502 Command not implemented\r\n").await?;
308                Ok(true)
309            }
310        }
311    }
312}
313
314/// Process received email and generate response
315async fn process_email(
316    state: &SessionState,
317    registry: &Arc<SmtpSpecRegistry>,
318    middleware: &Arc<MiddlewareChain>,
319    peer_addr: SocketAddr,
320) -> Result<String> {
321    let from = state
322        .mail_from
323        .as_ref()
324        .ok_or_else(|| mockforge_core::Error::generic("Missing MAIL FROM"))?;
325    let to = state.rcpt_to.join(", ");
326
327    // Extract subject from data
328    let subject = extract_subject(&state.data);
329
330    // Create protocol request
331    let mut request = ProtocolRequest {
332        protocol: Protocol::Smtp,
333        pattern: MessagePattern::OneWay,
334        operation: "SEND".to_string(),
335        path: from.clone(),
336        topic: None,
337        routing_key: None,
338        partition: None,
339        qos: None,
340        metadata: HashMap::from([
341            ("from".to_string(), from.clone()),
342            ("to".to_string(), to.clone()),
343            ("subject".to_string(), subject.clone()),
344        ]),
345        body: Some(state.data.as_bytes().to_vec()),
346        client_ip: Some(peer_addr.ip().to_string()),
347    };
348
349    // Process through middleware
350    middleware.process_request(&mut request).await?;
351
352    // Generate response
353    let mut response = registry.generate_mock_response(&request)?;
354
355    // Process response through middleware
356    middleware.process_response(&request, &mut response).await?;
357
358    // Return SMTP response
359    Ok(String::from_utf8_lossy(&response.body).to_string())
360}
361
362/// Extract email address from SMTP command parameter
363fn extract_email_address(param: &str) -> String {
364    // Handle formats like "FROM:<user@example.com>" or "TO:<user@example.com>"
365    if let Some(start) = param.find('<') {
366        if let Some(end) = param.find('>') {
367            return param[start + 1..end].to_string();
368        }
369    }
370
371    // If no angle brackets, just trim and return
372    param.trim().to_string()
373}
374
/// Extract the `Subject:` header value from raw message text.
///
/// Only the header section is searched, i.e. the lines before the first empty
/// line (RFC 5322 §2.1). A `Subject:` line in the body is therefore ignored.
/// The header name is matched ASCII case-insensitively and the value is
/// trimmed. Returns an empty string if there is no subject header.
fn extract_subject(data: &str) -> String {
    const PREFIX: &str = "subject:";

    data.lines()
        .take_while(|line| !line.is_empty())
        .find_map(|line| {
            // Compare the prefix in place instead of lowercasing (and allocating)
            // every line. `get` avoids slicing inside a multi-byte character.
            line.get(..PREFIX.len())
                .filter(|head| head.eq_ignore_ascii_case(PREFIX))
                .map(|_| line[PREFIX.len()..].trim().to_string())
        })
        .unwrap_or_default()
}
384
/// Session state for one SMTP connection.
///
/// Tracks the mail transaction currently being assembled. `Default` produces
/// the same idle state as `SessionState::new`.
#[derive(Debug, Default)]
struct SessionState {
    /// Reverse-path from the most recent `MAIL FROM` command, if any.
    mail_from: Option<String>,
    /// Forward-paths from `RCPT TO` commands, in the order received.
    rcpt_to: Vec<String>,
    /// Message content received after `DATA`, stored as `\n`-terminated lines.
    data: String,
    /// True between an accepted `DATA` command and the terminating `.` line.
    in_data_mode: bool,
}
392
393impl SessionState {
394    fn new() -> Self {
395        Self {
396            mail_from: None,
397            rcpt_to: Vec::new(),
398            data: String::new(),
399            in_data_mode: false,
400        }
401    }
402
403    fn reset(&mut self) {
404        self.mail_from = None;
405        self.rcpt_to.clear();
406        self.data.clear();
407        self.in_data_mode = false;
408    }
409}
410
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_extract_email_address() {
        assert_eq!(extract_email_address("FROM:<user@example.com>"), "user@example.com");
        assert_eq!(extract_email_address("TO:<admin@test.com>"), "admin@test.com");
        assert_eq!(extract_email_address("user@example.com"), "user@example.com");
    }

    #[test]
    fn test_extract_email_address_whitespace() {
        assert_eq!(extract_email_address("  user@example.com  "), "user@example.com");
    }

    #[test]
    fn test_extract_email_address_no_brackets() {
        assert_eq!(extract_email_address("plain@email.com"), "plain@email.com");
    }

    #[test]
    fn test_extract_email_address_mail_from_format() {
        assert_eq!(extract_email_address("FROM:<sender@domain.com>"), "sender@domain.com");
    }

    #[test]
    fn test_extract_email_address_with_esmtp_parameters() {
        // ESMTP parameters (RFC 1870 SIZE) may follow the bracketed address.
        assert_eq!(
            extract_email_address("FROM:<sender@domain.com> SIZE=1024"),
            "sender@domain.com"
        );
    }

    #[test]
    fn test_extract_subject() {
        let data =
            "From: sender@example.com\nSubject: Test Email\nTo: recipient@example.com\n\nBody text";
        assert_eq!(extract_subject(data), "Test Email");
    }

    #[test]
    fn test_extract_subject_not_found() {
        let data = "From: sender@example.com\nTo: recipient@example.com\n\nBody text";
        assert_eq!(extract_subject(data), "");
    }

    #[test]
    fn test_extract_subject_lowercase() {
        let data = "subject: lowercase subject\nFrom: sender@example.com";
        assert_eq!(extract_subject(data), "lowercase subject");
    }

    #[test]
    fn test_extract_subject_uppercase() {
        let data = "SUBJECT: UPPERCASE SUBJECT\nFrom: sender@example.com";
        assert_eq!(extract_subject(data), "UPPERCASE SUBJECT");
    }

    #[test]
    fn test_extract_subject_mixed_case() {
        // The header name really is mixed case here; the value keeps its own casing.
        let data = "SuBjEcT: Mixed Case Subject\nFrom: sender@example.com";
        assert_eq!(extract_subject(data), "Mixed Case Subject");
    }

    #[test]
    fn test_session_state() {
        let mut state = SessionState::new();
        assert!(state.mail_from.is_none());
        assert_eq!(state.rcpt_to.len(), 0);

        state.mail_from = Some("sender@example.com".to_string());
        state.rcpt_to.push("recipient@example.com".to_string());

        state.reset();
        assert!(state.mail_from.is_none());
        assert_eq!(state.rcpt_to.len(), 0);
    }

    #[test]
    fn test_session_state_new() {
        let state = SessionState::new();
        assert!(state.mail_from.is_none());
        assert!(state.rcpt_to.is_empty());
        assert!(state.data.is_empty());
        assert!(!state.in_data_mode);
    }

    #[test]
    fn test_session_state_reset() {
        let mut state = SessionState::new();
        state.mail_from = Some("test@example.com".to_string());
        state.rcpt_to.push("recipient1@example.com".to_string());
        state.rcpt_to.push("recipient2@example.com".to_string());
        state.data = "Email body content".to_string();
        state.in_data_mode = true;

        state.reset();

        assert!(state.mail_from.is_none());
        assert!(state.rcpt_to.is_empty());
        assert!(state.data.is_empty());
        assert!(!state.in_data_mode);
    }

    #[test]
    fn test_session_state_multiple_recipients() {
        let mut state = SessionState::new();
        state.rcpt_to.push("a@example.com".to_string());
        state.rcpt_to.push("b@example.com".to_string());
        state.rcpt_to.push("c@example.com".to_string());
        assert_eq!(state.rcpt_to.len(), 3);
    }

    #[test]
    fn test_session_state_data_accumulation() {
        let mut state = SessionState::new();
        state.data.push_str("Line 1\n");
        state.data.push_str("Line 2\n");
        state.data.push_str("Line 3\n");
        assert_eq!(state.data, "Line 1\nLine 2\nLine 3\n");
    }

    #[tokio::test]
    async fn test_smtp_server_new() {
        let config = SmtpConfig::default();
        let registry = Arc::new(SmtpSpecRegistry::new());
        let server = SmtpServer::new(config, registry);
        assert!(server.is_ok());
    }

    #[tokio::test]
    async fn test_smtp_server_with_middleware() {
        let config = SmtpConfig::default();
        let registry = Arc::new(SmtpSpecRegistry::new());
        let middleware = Arc::new(MiddlewareChain::new());
        let server = SmtpServer::with_middleware(config, registry, middleware);
        assert!(server.is_ok());
    }
}