1use std::io;
2use std::pin::Pin;
3use std::sync::Arc;
4use std::task::{Context, Poll};
5
6use rustls::ClientConfig;
7use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, BufStream, ReadBuf};
8use tokio::net::TcpStream;
9use tokio_rustls::client::TlsStream;
10use tokio_rustls::TlsConnector;
11
12use crate::mx::{format_mail_from, format_rcpt_to};
13use crate::response::{parse_response, SmtpResponse};
14use crate::tls_outcome::{StarttlsResult, TlsOutcome};
15
/// Per-stage timeouts for an SMTP client session.
///
/// The [`Default`] values are 30 s to establish the TCP connection, 30 s to
/// receive the server greeting, and 60 s to wait for each command's reply.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TimeoutConfig {
    /// Maximum time allowed for establishing the TCP connection.
    pub connect: std::time::Duration,
    /// Maximum time to wait for the server's initial greeting after connecting.
    pub greeting: std::time::Duration,
    /// Maximum time to wait for the server's reply to each command.
    pub command: std::time::Duration,
}

impl Default for TimeoutConfig {
    /// Returns 30 s connect, 30 s greeting, and 60 s per-command timeouts.
    fn default() -> Self {
        Self {
            connect: std::time::Duration::from_secs(30),
            greeting: std::time::Duration::from_secs(30),
            command: std::time::Duration::from_secs(60),
        }
    }
}
36
/// Byte stream underlying an SMTP session.
///
/// A session starts on `Plain`; a successful STARTTLS consumes the `TcpStream`
/// and replaces it with `Tls` wrapping that same TCP connection.
enum Transport {
    /// Unencrypted TCP connection (before STARTTLS).
    Plain(TcpStream),
    /// TLS session layered over the original TCP connection.
    // NOTE(review): presumably boxed to keep `Transport` small, since the TLS
    // state is much larger than a `TcpStream` — confirm if layout matters.
    Tls(Box<TlsStream<TcpStream>>),
}
41
42impl AsyncRead for Transport {
43 fn poll_read(
44 self: Pin<&mut Self>,
45 cx: &mut Context<'_>,
46 buf: &mut ReadBuf<'_>,
47 ) -> Poll<io::Result<()>> {
48 match self.get_mut() {
49 Transport::Plain(s) => Pin::new(s).poll_read(cx, buf),
50 Transport::Tls(s) => Pin::new(s).poll_read(cx, buf),
51 }
52 }
53}
54
55impl AsyncWrite for Transport {
56 fn poll_write(
57 self: Pin<&mut Self>,
58 cx: &mut Context<'_>,
59 buf: &[u8],
60 ) -> Poll<io::Result<usize>> {
61 match self.get_mut() {
62 Transport::Plain(s) => Pin::new(s).poll_write(cx, buf),
63 Transport::Tls(s) => Pin::new(s).poll_write(cx, buf),
64 }
65 }
66
67 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
68 match self.get_mut() {
69 Transport::Plain(s) => Pin::new(s).poll_flush(cx),
70 Transport::Tls(s) => Pin::new(s).poll_flush(cx),
71 }
72 }
73
74 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
75 match self.get_mut() {
76 Transport::Plain(s) => Pin::new(s).poll_shutdown(cx),
77 Transport::Tls(s) => Pin::new(s).poll_shutdown(cx),
78 }
79 }
80}
81
82pub struct SmtpConnection {
84 stream: BufStream<Transport>,
85 command_timeout: std::time::Duration,
86}
87
88impl SmtpConnection {
89 pub async fn connect(host: &str, port: u16) -> io::Result<Self> {
91 Self::connect_with_timeout(host, port, &TimeoutConfig::default()).await
92 }
93
94 pub async fn connect_with_timeout(
96 host: &str,
97 port: u16,
98 timeouts: &TimeoutConfig,
99 ) -> io::Result<Self> {
100 let tcp = tokio::time::timeout(timeouts.connect, TcpStream::connect((host, port)))
101 .await
102 .map_err(|_| io::Error::new(io::ErrorKind::TimedOut, "connect timeout"))??;
103
104 let mut conn = Self {
105 stream: BufStream::new(Transport::Plain(tcp)),
106 command_timeout: timeouts.command,
107 };
108
109 let greeting = tokio::time::timeout(timeouts.greeting, conn.read_response())
110 .await
111 .map_err(|_| io::Error::new(io::ErrorKind::TimedOut, "greeting timeout"))??;
112
113 if !greeting.is_positive() {
114 return Err(io::Error::new(
115 io::ErrorKind::ConnectionRefused,
116 format!("server rejected: {}", greeting.message()),
117 ));
118 }
119 Ok(conn)
120 }
121
122 pub fn is_tls(&self) -> bool {
124 matches!(self.stream.get_ref(), Transport::Tls(_))
125 }
126
127 pub async fn ehlo(&mut self, hostname: &str) -> io::Result<SmtpResponse> {
129 self.send_command(&format!("EHLO {hostname}\r\n")).await
130 }
131
132 pub async fn starttls(self, hostname: &str) -> io::Result<Self> {
140 self.try_starttls(hostname).await.into_io_result()
141 }
142
143 pub async fn starttls_dane(
147 self,
148 hostname: &str,
149 tlsa_records: Vec<crate::dane::TlsaRecord>,
150 ) -> io::Result<Self> {
151 self.try_starttls_dane(hostname, tlsa_records)
152 .await
153 .into_io_result()
154 }
155
156 pub async fn try_starttls(mut self, hostname: &str) -> StarttlsResult {
163 let resp = match self.send_command("STARTTLS\r\n").await {
164 Ok(r) => r,
165 Err(e) => {
166 let outcome = crate::tls_outcome::classify_io_error(&e, false);
167 return StarttlsResult::HandshakeFailed {
168 outcome,
169 source: e,
170 };
171 }
172 };
173 if !resp.is_positive() {
174 return StarttlsResult::Rejected {
175 conn: self,
176 code: resp.code,
177 message: resp.message(),
178 };
179 }
180
181 let mut config = ClientConfig::builder()
182 .with_root_certificates(rustls::RootCertStore {
183 roots: webpki_roots::TLS_SERVER_ROOTS.to_vec(),
184 })
185 .with_no_client_auth();
186 config.alpn_protocols = vec![];
187
188 let connector = TlsConnector::from(Arc::new(config));
189 let server_name: rustls::pki_types::ServerName<'static> =
190 match hostname.to_string().try_into() {
191 Ok(n) => n,
192 Err(e) => {
193 let detail = format!("{e}");
194 return StarttlsResult::HandshakeFailed {
195 outcome: TlsOutcome::InvalidServerName(detail.clone()),
196 source: io::Error::new(
197 io::ErrorKind::InvalidInput,
198 format!("invalid SNI: {detail}"),
199 ),
200 };
201 }
202 };
203
204 let inner = self.stream.into_inner();
205 let tcp = match inner {
206 Transport::Plain(tcp) => tcp,
207 Transport::Tls(_) => {
208 let e = io::Error::other("already using TLS");
209 return StarttlsResult::HandshakeFailed {
210 outcome: TlsOutcome::Other(e.to_string()),
211 source: e,
212 };
213 }
214 };
215
216 match connector.connect(server_name, tcp).await {
217 Ok(tls_stream) => StarttlsResult::Success(Self {
218 stream: BufStream::new(Transport::Tls(Box::new(tls_stream))),
219 command_timeout: self.command_timeout,
220 }),
221 Err(e) => {
222 let outcome = crate::tls_outcome::classify_io_error(&e, false);
223 StarttlsResult::HandshakeFailed { outcome, source: e }
224 }
225 }
226 }
227
228 pub async fn try_starttls_dane(
235 mut self,
236 hostname: &str,
237 tlsa_records: Vec<crate::dane::TlsaRecord>,
238 ) -> StarttlsResult {
239 let resp = match self.send_command("STARTTLS\r\n").await {
240 Ok(r) => r,
241 Err(e) => {
242 let outcome = crate::tls_outcome::classify_io_error(&e, true);
243 return StarttlsResult::HandshakeFailed {
244 outcome,
245 source: e,
246 };
247 }
248 };
249 if !resp.is_positive() {
250 return StarttlsResult::Rejected {
251 conn: self,
252 code: resp.code,
253 message: resp.message(),
254 };
255 }
256
257 let config = crate::dane::dane_tls_config(tlsa_records);
258 let connector = TlsConnector::from(Arc::new(config));
259 let server_name: rustls::pki_types::ServerName<'static> =
260 match hostname.to_string().try_into() {
261 Ok(n) => n,
262 Err(e) => {
263 let detail = format!("{e}");
264 return StarttlsResult::HandshakeFailed {
265 outcome: TlsOutcome::InvalidServerName(detail.clone()),
266 source: io::Error::new(
267 io::ErrorKind::InvalidInput,
268 format!("invalid SNI: {detail}"),
269 ),
270 };
271 }
272 };
273
274 let inner = self.stream.into_inner();
275 let tcp = match inner {
276 Transport::Plain(tcp) => tcp,
277 Transport::Tls(_) => {
278 let e = io::Error::other("already using TLS");
279 return StarttlsResult::HandshakeFailed {
280 outcome: TlsOutcome::Other(e.to_string()),
281 source: e,
282 };
283 }
284 };
285
286 match connector.connect(server_name, tcp).await {
287 Ok(tls_stream) => StarttlsResult::Success(Self {
288 stream: BufStream::new(Transport::Tls(Box::new(tls_stream))),
289 command_timeout: self.command_timeout,
290 }),
291 Err(e) => {
292 let outcome = crate::tls_outcome::classify_io_error(&e, true);
293 StarttlsResult::HandshakeFailed { outcome, source: e }
294 }
295 }
296 }
297
298 pub async fn deliver(
300 &mut self,
301 from: &str,
302 to: &[&str],
303 message: &[u8],
304 ) -> io::Result<SmtpResponse> {
305 let resp = self.send_command(&format_mail_from(from)).await?;
307 if !resp.is_positive() {
308 return Ok(resp);
309 }
310
311 for recipient in to {
313 let resp = self.send_command(&format_rcpt_to(recipient)).await?;
314 if !resp.is_positive() {
315 return Ok(resp);
316 }
317 }
318
319 let resp = self.send_command("DATA\r\n").await?;
321 if resp.code != 354 {
322 return Ok(resp);
323 }
324
325 let stuffed = dot_stuff(message);
327 self.stream.write_all(&stuffed).await?;
328 if !stuffed.ends_with(b"\r\n") {
329 self.stream.write_all(b"\r\n").await?;
330 }
331 self.stream.write_all(b".\r\n").await?;
332 self.stream.flush().await?;
333
334 tokio::time::timeout(self.command_timeout, self.read_response())
335 .await
336 .map_err(|_| io::Error::new(io::ErrorKind::TimedOut, "DATA response timeout"))?
337 }
338
339 pub async fn quit(&mut self) -> io::Result<()> {
341 let _ = self.send_command("QUIT\r\n").await;
342 Ok(())
343 }
344
345 async fn send_command(&mut self, cmd: &str) -> io::Result<SmtpResponse> {
346 self.stream.write_all(cmd.as_bytes()).await?;
347 self.stream.flush().await?;
348 tokio::time::timeout(self.command_timeout, self.read_response())
349 .await
350 .map_err(|_| io::Error::new(io::ErrorKind::TimedOut, "command timeout"))?
351 }
352
353 async fn read_response(&mut self) -> io::Result<SmtpResponse> {
354 const MAX_RESPONSE_SIZE: usize = 65536;
355 let mut buf = String::new();
356 loop {
357 let mut line = String::new();
358 let n = self.stream.read_line(&mut line).await?;
359 if n == 0 {
360 return Err(io::Error::new(
361 io::ErrorKind::UnexpectedEof,
362 "connection closed",
363 ));
364 }
365 buf.push_str(&line);
366 if buf.len() > MAX_RESPONSE_SIZE {
367 return Err(io::Error::new(
368 io::ErrorKind::InvalidData,
369 "SMTP response too large",
370 ));
371 }
372
373 if line.len() >= 4 && line.as_bytes()[3] == b' ' {
375 break;
376 }
377 }
378 parse_response(&buf).ok_or_else(|| {
379 io::Error::new(
380 io::ErrorKind::InvalidData,
381 format!("invalid SMTP response: {buf}"),
382 )
383 })
384 }
385}
386
/// Applies SMTP dot-stuffing: every line that begins with `.` gets one extra
/// leading `.`.
///
/// A "line" starts at the beginning of `data` and after each `\n` byte; a bare
/// `\r` does not start a new line. All other bytes are copied unchanged.
pub fn dot_stuff(data: &[u8]) -> Vec<u8> {
    let mut stuffed = Vec::with_capacity(data.len());
    // Each chunk keeps its trailing '\n', so chunks are exactly the line units.
    for chunk in data.split_inclusive(|&b| b == b'\n') {
        if let Some(b'.') = chunk.first() {
            stuffed.push(b'.');
        }
        stuffed.extend_from_slice(chunk);
    }
    stuffed
}
402
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    #[test]
    fn dot_stuff_no_dots() {
        assert_eq!(dot_stuff(b"hello\r\nworld\r\n"), b"hello\r\nworld\r\n");
    }

    #[test]
    fn dot_stuff_line_starting_with_dot() {
        assert_eq!(dot_stuff(b".hello\r\n"), b"..hello\r\n");
    }

    #[test]
    fn dot_stuff_multiple_dots() {
        assert_eq!(
            dot_stuff(b"ok\r\n.line1\r\n..line2\r\n"),
            b"ok\r\n..line1\r\n...line2\r\n"
        );
    }

    #[test]
    fn dot_stuff_dot_only_line() {
        // A lone "." line would otherwise be read as the end-of-data marker.
        assert_eq!(dot_stuff(b".\r\n"), b"..\r\n");
    }

    #[test]
    fn dot_stuff_empty() {
        assert_eq!(dot_stuff(b""), b"");
    }

    #[test]
    fn dot_stuff_escapes_embedded_terminator() {
        // The CRLF "." CRLF terminator must never survive unescaped inside a body.
        let result = dot_stuff(b"before\r\n.\r\nafter\r\n");
        assert_eq!(result, b"before\r\n..\r\nafter\r\n".to_vec());
        assert!(!result.windows(5).any(|w| w == b"\r\n.\r\n"));
    }

    #[test]
    fn timeout_config_defaults() {
        let cfg = TimeoutConfig::default();
        assert_eq!(cfg.connect, Duration::from_secs(30));
        assert_eq!(cfg.greeting, Duration::from_secs(30));
        assert_eq!(cfg.command, Duration::from_secs(60));
    }

    #[test]
    fn timeout_config_clone() {
        let cfg = TimeoutConfig {
            connect: Duration::from_secs(5),
            greeting: Duration::from_secs(10),
            command: Duration::from_secs(15),
        };
        let cloned = cfg.clone();
        assert_eq!(cloned.connect, Duration::from_secs(5));
        assert_eq!(cloned.greeting, Duration::from_secs(10));
        assert_eq!(cloned.command, Duration::from_secs(15));
    }

    #[test]
    fn timeout_config_debug() {
        let cfg = TimeoutConfig::default();
        let debug = format!("{:?}", cfg);
        assert!(debug.contains("TimeoutConfig"));
        assert!(debug.contains("connect"));
        assert!(debug.contains("greeting"));
        assert!(debug.contains("command"));
    }

    #[test]
    fn dot_stuff_bare_lf() {
        // Bare LF also starts a new line for stuffing purposes.
        assert_eq!(dot_stuff(b"ok\n.next\n"), b"ok\n..next\n");
    }

    #[test]
    fn dot_stuff_consecutive_dot_lines() {
        assert_eq!(dot_stuff(b".\r\n.\r\n.\r\n"), b"..\r\n..\r\n..\r\n");
    }

    #[test]
    fn dot_stuff_no_newline_at_end() {
        assert_eq!(dot_stuff(b".hello"), b"..hello");
    }

    #[test]
    fn dot_stuff_dot_mid_line_not_stuffed() {
        assert_eq!(dot_stuff(b"hello.world\r\n"), b"hello.world\r\n");
    }

    #[test]
    fn dot_stuff_single_dot_no_newline() {
        assert_eq!(dot_stuff(b"."), b"..");
    }

    #[test]
    fn dot_stuff_crlf_only() {
        assert_eq!(dot_stuff(b"\r\n"), b"\r\n");
    }

    #[test]
    fn dot_stuff_multiple_dots_at_line_start() {
        // Only a single extra dot is added, however many dots the line starts with.
        assert_eq!(dot_stuff(b"...test\r\n"), b"....test\r\n");
    }

    #[test]
    fn dot_stuff_large_message() {
        let input = b".line\r\n".repeat(100);
        let result = dot_stuff(&input);
        assert_eq!(result.len(), 800);
        assert_eq!(result, b"..line\r\n".repeat(100));
    }

    #[test]
    fn dot_stuff_mixed_content() {
        let input = b"From: test@example.com\r\n\
                      Subject: Hello\r\n\
                      \r\n\
                      .This line starts with a dot.\r\n\
                      This line does not.\r\n\
                      ..Two dots here.\r\n";
        let expected = b"From: test@example.com\r\n\
                         Subject: Hello\r\n\
                         \r\n\
                         ..This line starts with a dot.\r\n\
                         This line does not.\r\n\
                         ...Two dots here.\r\n";
        // Exact comparison: `contains` would also accept a wrongly stuffed
        // "..This line does not." line.
        assert_eq!(dot_stuff(input), expected.to_vec());
    }

    #[test]
    fn dot_stuff_preserves_non_dot_content_exactly() {
        let input = b"Hello World\r\nSecond line\r\n";
        let result = dot_stuff(input);
        assert_eq!(result, input.to_vec());
    }

    #[test]
    fn dot_stuff_after_bare_cr_no_stuff() {
        // A bare CR does not start a new line.
        let input = b"test\r.not-stuffed";
        let result = dot_stuff(input);
        assert_eq!(result, b"test\r.not-stuffed".to_vec());
    }

    #[test]
    fn dot_stuff_first_byte_is_dot() {
        let result = dot_stuff(b".first");
        assert_eq!(result, b"..first".to_vec());
    }

    #[test]
    fn dot_stuff_only_newlines() {
        let input = b"\r\n\r\n\r\n";
        let result = dot_stuff(input);
        assert_eq!(result, input.to_vec());
    }

    #[test]
    fn dot_stuff_dot_after_crlf_crlf() {
        let input = b"header\r\n\r\n.body\r\n";
        let result = dot_stuff(input);
        assert_eq!(result, b"header\r\n\r\n..body\r\n".to_vec());
    }

    #[test]
    fn dot_stuff_binary_content() {
        let input = b"\x00\r\n.\x00\r\n";
        let result = dot_stuff(input);
        assert_eq!(result, b"\x00\r\n..\x00\r\n".to_vec());
    }

    #[test]
    fn dot_stuff_result_capacity_hint() {
        let input = b"no dots here\r\n";
        let result = dot_stuff(input);
        assert_eq!(result, input.to_vec());
        assert!(result.capacity() >= input.len());
    }

    #[test]
    fn timeout_config_custom_values() {
        let cfg = TimeoutConfig {
            connect: Duration::from_millis(100),
            greeting: Duration::from_millis(200),
            command: Duration::from_millis(300),
        };
        assert_eq!(cfg.connect.as_millis(), 100);
        assert_eq!(cfg.greeting.as_millis(), 200);
        assert_eq!(cfg.command.as_millis(), 300);
    }

    #[test]
    fn timeout_config_zero_durations() {
        let cfg = TimeoutConfig {
            connect: Duration::ZERO,
            greeting: Duration::ZERO,
            command: Duration::ZERO,
        };
        assert_eq!(cfg.connect, Duration::ZERO);
        assert_eq!(cfg.greeting, Duration::ZERO);
        assert_eq!(cfg.command, Duration::ZERO);
    }
}