Skip to main content

mailrs_smtp_client/
connection.rs

1use std::io;
2use std::pin::Pin;
3use std::sync::Arc;
4use std::task::{Context, Poll};
5
6use rustls::ClientConfig;
7use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, BufStream, ReadBuf};
8use tokio::net::TcpStream;
9use tokio_rustls::client::TlsStream;
10use tokio_rustls::TlsConnector;
11
12use crate::mx::{format_mail_from, format_rcpt_to};
13use crate::response::{parse_response, SmtpResponse};
14use crate::tls_outcome::{StarttlsResult, TlsOutcome};
15
/// Timeouts applied while establishing and driving an SMTP client connection.
///
/// The [`Default`] values are 30 s to connect, 30 s for the greeting and
/// 60 s per command.
#[derive(Debug, Clone)]
pub struct TimeoutConfig {
    /// Upper bound on the TCP handshake.
    pub connect: std::time::Duration,
    /// Upper bound on waiting for the server's greeting (normally `220`) once connected.
    pub greeting: std::time::Duration,
    /// Upper bound on waiting for the reply to each SMTP command.
    pub command: std::time::Duration,
}

impl Default for TimeoutConfig {
    fn default() -> Self {
        use std::time::Duration;

        // connect and greeting share the same half-minute budget
        let half_minute = Duration::from_secs(30);
        TimeoutConfig {
            connect: half_minute,
            greeting: half_minute,
            command: Duration::from_secs(60),
        }
    }
}
36
37enum Transport {
38    Plain(TcpStream),
39    Tls(Box<TlsStream<TcpStream>>),
40}
41
42impl AsyncRead for Transport {
43    fn poll_read(
44        self: Pin<&mut Self>,
45        cx: &mut Context<'_>,
46        buf: &mut ReadBuf<'_>,
47    ) -> Poll<io::Result<()>> {
48        match self.get_mut() {
49            Transport::Plain(s) => Pin::new(s).poll_read(cx, buf),
50            Transport::Tls(s) => Pin::new(s).poll_read(cx, buf),
51        }
52    }
53}
54
55impl AsyncWrite for Transport {
56    fn poll_write(
57        self: Pin<&mut Self>,
58        cx: &mut Context<'_>,
59        buf: &[u8],
60    ) -> Poll<io::Result<usize>> {
61        match self.get_mut() {
62            Transport::Plain(s) => Pin::new(s).poll_write(cx, buf),
63            Transport::Tls(s) => Pin::new(s).poll_write(cx, buf),
64        }
65    }
66
67    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
68        match self.get_mut() {
69            Transport::Plain(s) => Pin::new(s).poll_flush(cx),
70            Transport::Tls(s) => Pin::new(s).poll_flush(cx),
71        }
72    }
73
74    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
75        match self.get_mut() {
76            Transport::Plain(s) => Pin::new(s).poll_shutdown(cx),
77            Transport::Tls(s) => Pin::new(s).poll_shutdown(cx),
78        }
79    }
80}
81
82/// SMTP client connection for outbound delivery
83pub struct SmtpConnection {
84    stream: BufStream<Transport>,
85    command_timeout: std::time::Duration,
86}
87
88impl SmtpConnection {
89    /// connect to an SMTP server and read the greeting
90    pub async fn connect(host: &str, port: u16) -> io::Result<Self> {
91        Self::connect_with_timeout(host, port, &TimeoutConfig::default()).await
92    }
93
94    /// connect with explicit timeout configuration
95    pub async fn connect_with_timeout(
96        host: &str,
97        port: u16,
98        timeouts: &TimeoutConfig,
99    ) -> io::Result<Self> {
100        let tcp = tokio::time::timeout(timeouts.connect, TcpStream::connect((host, port)))
101            .await
102            .map_err(|_| io::Error::new(io::ErrorKind::TimedOut, "connect timeout"))??;
103
104        let mut conn = Self {
105            stream: BufStream::new(Transport::Plain(tcp)),
106            command_timeout: timeouts.command,
107        };
108
109        let greeting = tokio::time::timeout(timeouts.greeting, conn.read_response())
110            .await
111            .map_err(|_| io::Error::new(io::ErrorKind::TimedOut, "greeting timeout"))??;
112
113        if !greeting.is_positive() {
114            return Err(io::Error::new(
115                io::ErrorKind::ConnectionRefused,
116                format!("server rejected: {}", greeting.message()),
117            ));
118        }
119        Ok(conn)
120    }
121
122    /// returns true if the connection is using TLS
123    pub fn is_tls(&self) -> bool {
124        matches!(self.stream.get_ref(), Transport::Tls(_))
125    }
126
127    /// send EHLO and return the response
128    pub async fn ehlo(&mut self, hostname: &str) -> io::Result<SmtpResponse> {
129        self.send_command(&format!("EHLO {hostname}\r\n")).await
130    }
131
132    /// Upgrade to TLS via STARTTLS — classic API returning the
133    /// upgraded connection or an opaque `io::Error`.
134    ///
135    /// Internally delegates to [`Self::try_starttls`]. Callers that
136    /// want structured TLS-failure classification (for TLSRPT
137    /// reporting, metrics labels, etc.) should call `try_starttls`
138    /// directly.
139    pub async fn starttls(self, hostname: &str) -> io::Result<Self> {
140        self.try_starttls(hostname).await.into_io_result()
141    }
142
143    /// Upgrade to TLS via STARTTLS with DANE TLSA verification —
144    /// classic API. See [`Self::try_starttls_dane`] for the
145    /// structured variant.
146    pub async fn starttls_dane(
147        self,
148        hostname: &str,
149        tlsa_records: Vec<crate::dane::TlsaRecord>,
150    ) -> io::Result<Self> {
151        self.try_starttls_dane(hostname, tlsa_records)
152            .await
153            .into_io_result()
154    }
155
156    /// Upgrade to TLS via STARTTLS, returning a structured
157    /// [`StarttlsResult`] that discriminates between server-side
158    /// rejection (connection still usable) and handshake failure
159    /// (connection unrecoverable, must reconnect). On handshake
160    /// failure, the wrapped [`TlsOutcome`] is RFC 8460 §4.3-aligned
161    /// so callers can build TLSRPT reports directly.
162    pub async fn try_starttls(mut self, hostname: &str) -> StarttlsResult {
163        let resp = match self.send_command("STARTTLS\r\n").await {
164            Ok(r) => r,
165            Err(e) => {
166                let outcome = crate::tls_outcome::classify_io_error(&e, false);
167                return StarttlsResult::HandshakeFailed {
168                    outcome,
169                    source: e,
170                };
171            }
172        };
173        if !resp.is_positive() {
174            return StarttlsResult::Rejected {
175                conn: self,
176                code: resp.code,
177                message: resp.message(),
178            };
179        }
180
181        let mut config = ClientConfig::builder()
182            .with_root_certificates(rustls::RootCertStore {
183                roots: webpki_roots::TLS_SERVER_ROOTS.to_vec(),
184            })
185            .with_no_client_auth();
186        config.alpn_protocols = vec![];
187
188        let connector = TlsConnector::from(Arc::new(config));
189        let server_name: rustls::pki_types::ServerName<'static> =
190            match hostname.to_string().try_into() {
191                Ok(n) => n,
192                Err(e) => {
193                    let detail = format!("{e}");
194                    return StarttlsResult::HandshakeFailed {
195                        outcome: TlsOutcome::InvalidServerName(detail.clone()),
196                        source: io::Error::new(
197                            io::ErrorKind::InvalidInput,
198                            format!("invalid SNI: {detail}"),
199                        ),
200                    };
201                }
202            };
203
204        let inner = self.stream.into_inner();
205        let tcp = match inner {
206            Transport::Plain(tcp) => tcp,
207            Transport::Tls(_) => {
208                let e = io::Error::other("already using TLS");
209                return StarttlsResult::HandshakeFailed {
210                    outcome: TlsOutcome::Other(e.to_string()),
211                    source: e,
212                };
213            }
214        };
215
216        match connector.connect(server_name, tcp).await {
217            Ok(tls_stream) => StarttlsResult::Success(Self {
218                stream: BufStream::new(Transport::Tls(Box::new(tls_stream))),
219                command_timeout: self.command_timeout,
220            }),
221            Err(e) => {
222                let outcome = crate::tls_outcome::classify_io_error(&e, false);
223                StarttlsResult::HandshakeFailed { outcome, source: e }
224            }
225        }
226    }
227
228    /// Upgrade to TLS via STARTTLS with DANE TLSA verification,
229    /// returning a structured [`StarttlsResult`]. DANE-specific
230    /// certificate rejections are reported as
231    /// [`TlsOutcome::DaneValidationFailure`] rather than the PKIX
232    /// `CertificateNotTrusted` so TLSRPT reports can distinguish
233    /// `tlsa-invalid` from generic untrusted-CA failures.
234    pub async fn try_starttls_dane(
235        mut self,
236        hostname: &str,
237        tlsa_records: Vec<crate::dane::TlsaRecord>,
238    ) -> StarttlsResult {
239        let resp = match self.send_command("STARTTLS\r\n").await {
240            Ok(r) => r,
241            Err(e) => {
242                let outcome = crate::tls_outcome::classify_io_error(&e, true);
243                return StarttlsResult::HandshakeFailed {
244                    outcome,
245                    source: e,
246                };
247            }
248        };
249        if !resp.is_positive() {
250            return StarttlsResult::Rejected {
251                conn: self,
252                code: resp.code,
253                message: resp.message(),
254            };
255        }
256
257        let config = crate::dane::dane_tls_config(tlsa_records);
258        let connector = TlsConnector::from(Arc::new(config));
259        let server_name: rustls::pki_types::ServerName<'static> =
260            match hostname.to_string().try_into() {
261                Ok(n) => n,
262                Err(e) => {
263                    let detail = format!("{e}");
264                    return StarttlsResult::HandshakeFailed {
265                        outcome: TlsOutcome::InvalidServerName(detail.clone()),
266                        source: io::Error::new(
267                            io::ErrorKind::InvalidInput,
268                            format!("invalid SNI: {detail}"),
269                        ),
270                    };
271                }
272            };
273
274        let inner = self.stream.into_inner();
275        let tcp = match inner {
276            Transport::Plain(tcp) => tcp,
277            Transport::Tls(_) => {
278                let e = io::Error::other("already using TLS");
279                return StarttlsResult::HandshakeFailed {
280                    outcome: TlsOutcome::Other(e.to_string()),
281                    source: e,
282                };
283            }
284        };
285
286        match connector.connect(server_name, tcp).await {
287            Ok(tls_stream) => StarttlsResult::Success(Self {
288                stream: BufStream::new(Transport::Tls(Box::new(tls_stream))),
289                command_timeout: self.command_timeout,
290            }),
291            Err(e) => {
292                let outcome = crate::tls_outcome::classify_io_error(&e, true);
293                StarttlsResult::HandshakeFailed { outcome, source: e }
294            }
295        }
296    }
297
298    /// send MAIL FROM, RCPT TO, DATA, and message body
299    pub async fn deliver(
300        &mut self,
301        from: &str,
302        to: &[&str],
303        message: &[u8],
304    ) -> io::Result<SmtpResponse> {
305        // MAIL FROM
306        let resp = self.send_command(&format_mail_from(from)).await?;
307        if !resp.is_positive() {
308            return Ok(resp);
309        }
310
311        // RCPT TO
312        for recipient in to {
313            let resp = self.send_command(&format_rcpt_to(recipient)).await?;
314            if !resp.is_positive() {
315                return Ok(resp);
316            }
317        }
318
319        // DATA
320        let resp = self.send_command("DATA\r\n").await?;
321        if resp.code != 354 {
322            return Ok(resp);
323        }
324
325        // send message body with dot-stuffing (RFC 5321 section 4.5.2)
326        let stuffed = dot_stuff(message);
327        self.stream.write_all(&stuffed).await?;
328        if !stuffed.ends_with(b"\r\n") {
329            self.stream.write_all(b"\r\n").await?;
330        }
331        self.stream.write_all(b".\r\n").await?;
332        self.stream.flush().await?;
333
334        tokio::time::timeout(self.command_timeout, self.read_response())
335            .await
336            .map_err(|_| io::Error::new(io::ErrorKind::TimedOut, "DATA response timeout"))?
337    }
338
339    /// send QUIT
340    pub async fn quit(&mut self) -> io::Result<()> {
341        let _ = self.send_command("QUIT\r\n").await;
342        Ok(())
343    }
344
345    async fn send_command(&mut self, cmd: &str) -> io::Result<SmtpResponse> {
346        self.stream.write_all(cmd.as_bytes()).await?;
347        self.stream.flush().await?;
348        tokio::time::timeout(self.command_timeout, self.read_response())
349            .await
350            .map_err(|_| io::Error::new(io::ErrorKind::TimedOut, "command timeout"))?
351    }
352
353    async fn read_response(&mut self) -> io::Result<SmtpResponse> {
354        const MAX_RESPONSE_SIZE: usize = 65536;
355        let mut buf = String::new();
356        loop {
357            let mut line = String::new();
358            let n = self.stream.read_line(&mut line).await?;
359            if n == 0 {
360                return Err(io::Error::new(
361                    io::ErrorKind::UnexpectedEof,
362                    "connection closed",
363                ));
364            }
365            buf.push_str(&line);
366            if buf.len() > MAX_RESPONSE_SIZE {
367                return Err(io::Error::new(
368                    io::ErrorKind::InvalidData,
369                    "SMTP response too large",
370                ));
371            }
372
373            // check if this is the final line (code followed by space)
374            if line.len() >= 4 && line.as_bytes()[3] == b' ' {
375                break;
376            }
377        }
378        parse_response(&buf).ok_or_else(|| {
379            io::Error::new(
380                io::ErrorKind::InvalidData,
381                format!("invalid SMTP response: {buf}"),
382            )
383        })
384    }
385}
386
/// Dot-stuffs a message body for SMTP DATA transmission (RFC 5321 section 4.5.2).
///
/// Every line that begins with `.` gets one extra `.` prepended. Only `\n`
/// starts a new line here: a `.` right after `\n` (CRLF or bare LF) is
/// stuffed, while one after a lone `\r` is not. The first byte of `data`
/// counts as a line start.
pub fn dot_stuff(data: &[u8]) -> Vec<u8> {
    let mut stuffed = Vec::with_capacity(data.len());
    // each segment is one line including its trailing '\n' (the last may lack it)
    for line in data.split_inclusive(|&byte| byte == b'\n') {
        if line.first() == Some(&b'.') {
            stuffed.push(b'.');
        }
        stuffed.extend_from_slice(line);
    }
    stuffed
}
402
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    // --- dot_stuff ---

    #[test]
    fn dot_stuff_no_dots() {
        assert_eq!(dot_stuff(b"hello\r\nworld\r\n"), b"hello\r\nworld\r\n");
    }

    #[test]
    fn dot_stuff_line_starting_with_dot() {
        assert_eq!(dot_stuff(b".hello\r\n"), b"..hello\r\n");
    }

    #[test]
    fn dot_stuff_multiple_dots() {
        assert_eq!(
            dot_stuff(b"ok\r\n.line1\r\n..line2\r\n"),
            b"ok\r\n..line1\r\n...line2\r\n"
        );
    }

    #[test]
    fn dot_stuff_dot_only_line() {
        // a lone dot on a line would be the end-of-data marker without stuffing
        assert_eq!(dot_stuff(b".\r\n"), b"..\r\n");
    }

    #[test]
    fn dot_stuff_empty() {
        assert_eq!(dot_stuff(b""), b"");
    }

    #[test]
    fn dot_stuff_bare_lf() {
        // bare \n (not \r\n) still marks the start of the next line
        assert_eq!(dot_stuff(b"ok\n.next\n"), b"ok\n..next\n");
    }

    #[test]
    fn dot_stuff_consecutive_dot_lines() {
        assert_eq!(dot_stuff(b".\r\n.\r\n.\r\n"), b"..\r\n..\r\n..\r\n");
    }

    #[test]
    fn dot_stuff_no_newline_at_end() {
        // message doesn't end with a newline; a leading dot is still stuffed
        assert_eq!(dot_stuff(b".hello"), b"..hello");
    }

    #[test]
    fn dot_stuff_dot_mid_line_not_stuffed() {
        assert_eq!(dot_stuff(b"hello.world\r\n"), b"hello.world\r\n");
    }

    #[test]
    fn dot_stuff_single_dot_no_newline() {
        assert_eq!(dot_stuff(b"."), b"..");
    }

    #[test]
    fn dot_stuff_crlf_only() {
        assert_eq!(dot_stuff(b"\r\n"), b"\r\n");
    }

    #[test]
    fn dot_stuff_multiple_dots_at_line_start() {
        // "..." at line start becomes "...."
        assert_eq!(dot_stuff(b"...test\r\n"), b"....test\r\n");
    }

    #[test]
    fn dot_stuff_large_message() {
        let mut input = Vec::new();
        for _ in 0..100 {
            input.extend_from_slice(b".line\r\n");
        }
        let result = dot_stuff(&input);
        // each ".line\r\n" (7 bytes) becomes "..line\r\n" (8 bytes)
        assert_eq!(result.len(), 800);
        assert!(result.chunks(8).all(|line| line == b"..line\r\n"));
    }

    #[test]
    fn dot_stuff_mixed_content() {
        let input = b"From: test@example.com\r\n\
                      Subject: Hello\r\n\
                      \r\n\
                      .This line starts with a dot.\r\n\
                      This line does not.\r\n\
                      ..Two dots here.\r\n";
        let result = dot_stuff(input);
        let result_str = String::from_utf8_lossy(&result);
        assert!(result_str.contains("..This line starts with a dot."));
        assert!(result_str.contains("...Two dots here."));
        assert!(result_str.contains("This line does not."));
    }

    #[test]
    fn dot_stuff_preserves_non_dot_content_exactly() {
        let input = b"Hello World\r\nSecond line\r\n";
        assert_eq!(dot_stuff(input), input.to_vec());
    }

    #[test]
    fn dot_stuff_after_bare_cr_no_stuff() {
        // \r alone does NOT start a new line
        let input = b"test\r.not-stuffed";
        assert_eq!(dot_stuff(input), b"test\r.not-stuffed".to_vec());
    }

    #[test]
    fn dot_stuff_first_byte_is_dot() {
        // the very first byte of the message counts as a line start
        assert_eq!(dot_stuff(b".first"), b"..first".to_vec());
    }

    #[test]
    fn dot_stuff_only_newlines() {
        let input = b"\r\n\r\n\r\n";
        assert_eq!(dot_stuff(input), input.to_vec());
    }

    #[test]
    fn dot_stuff_dot_after_crlf_crlf() {
        // empty line followed by a dot line
        let input = b"header\r\n\r\n.body\r\n";
        assert_eq!(dot_stuff(input), b"header\r\n\r\n..body\r\n".to_vec());
    }

    #[test]
    fn dot_stuff_binary_content() {
        let input = b"\x00\r\n.\x00\r\n";
        assert_eq!(dot_stuff(input), b"\x00\r\n..\x00\r\n".to_vec());
    }

    #[test]
    fn dot_stuff_never_shrinks_input() {
        // stuffing only ever inserts bytes, so output length >= input length
        let input = b"no dots here\r\n";
        assert!(dot_stuff(input).len() >= input.len());
    }

    #[test]
    fn dot_stuff_end_of_data_marker_inside_body() {
        // an embedded "<CRLF>.<CRLF>" must not end DATA early
        assert_eq!(dot_stuff(b"a\r\n.\r\nb\r\n"), b"a\r\n..\r\nb\r\n");
    }

    #[test]
    fn dot_stuff_lf_dot_crlf_inside_body() {
        // a dot right after a bare LF is stuffed as well
        assert_eq!(dot_stuff(b"a\n.\r\nb"), b"a\n..\r\nb");
    }

    // --- TimeoutConfig ---

    #[test]
    fn timeout_config_defaults() {
        let cfg = TimeoutConfig::default();
        assert_eq!(cfg.connect, Duration::from_secs(30));
        assert_eq!(cfg.greeting, Duration::from_secs(30));
        assert_eq!(cfg.command, Duration::from_secs(60));
    }

    #[test]
    fn timeout_config_clone() {
        let cfg = TimeoutConfig {
            connect: Duration::from_secs(5),
            greeting: Duration::from_secs(10),
            command: Duration::from_secs(15),
        };
        let cloned = cfg.clone();
        assert_eq!(cloned.connect, Duration::from_secs(5));
        assert_eq!(cloned.greeting, Duration::from_secs(10));
        assert_eq!(cloned.command, Duration::from_secs(15));
    }

    #[test]
    fn timeout_config_debug() {
        let cfg = TimeoutConfig::default();
        let debug = format!("{cfg:?}");
        assert!(debug.contains("TimeoutConfig"));
    }

    #[test]
    fn timeout_config_custom_values() {
        let cfg = TimeoutConfig {
            connect: Duration::from_millis(100),
            greeting: Duration::from_millis(200),
            command: Duration::from_millis(300),
        };
        assert_eq!(cfg.connect.as_millis(), 100);
        assert_eq!(cfg.greeting.as_millis(), 200);
        assert_eq!(cfg.command.as_millis(), 300);
    }

    #[test]
    fn timeout_config_zero_durations() {
        let cfg = TimeoutConfig {
            connect: Duration::ZERO,
            greeting: Duration::ZERO,
            command: Duration::ZERO,
        };
        assert_eq!(cfg.connect, Duration::ZERO);
        assert_eq!(cfg.greeting, Duration::ZERO);
        assert_eq!(cfg.command, Duration::ZERO);
    }
}