mail_send/smtp/
client.rs

1/*
2 * SPDX-FileCopyrightText: 2020 Stalwart Labs LLC <hello@stalw.art>
3 *
4 * SPDX-License-Identifier: Apache-2.0 OR MIT
5 */
6
7use std::{
8    net::{IpAddr, SocketAddr},
9    time::Duration,
10};
11
12use smtp_proto::{response::parser::ResponseReceiver, Response};
13use tokio::{
14    io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt},
15    net::{TcpSocket, TcpStream},
16};
17
18use crate::SmtpClient;
19
20impl<T: AsyncRead + AsyncWrite + Unpin> SmtpClient<T> {
21    pub async fn read(&mut self) -> crate::Result<Response<String>> {
22        let mut buf = vec![0u8; 1024];
23        let mut parser = ResponseReceiver::default();
24
25        loop {
26            let br = self.stream.read(&mut buf).await?;
27
28            if br > 0 {
29                match parser.parse(&mut buf[..br].iter()) {
30                    Ok(reply) => return Ok(reply),
31                    Err(err) => match err {
32                        smtp_proto::Error::NeedsMoreData { .. } => (),
33                        _ => {
34                            return Err(crate::Error::UnparseableReply);
35                        }
36                    },
37                }
38            } else {
39                return Err(crate::Error::UnparseableReply);
40            }
41        }
42    }
43
44    pub async fn read_many(&mut self, num: usize) -> crate::Result<Vec<Response<String>>> {
45        let mut buf = vec![0u8; 1024];
46        let mut response = Vec::with_capacity(num);
47        let mut parser = ResponseReceiver::default();
48
49        'outer: loop {
50            let br = self.stream.read(&mut buf).await?;
51
52            if br > 0 {
53                let mut iter = buf[..br].iter();
54
55                loop {
56                    match parser.parse(&mut iter) {
57                        Ok(reply) => {
58                            response.push(reply);
59                            if response.len() != num {
60                                parser.reset();
61                            } else {
62                                break 'outer;
63                            }
64                        }
65                        Err(err) => match err {
66                            smtp_proto::Error::NeedsMoreData { .. } => break,
67                            _ => {
68                                return Err(crate::Error::UnparseableReply);
69                            }
70                        },
71                    }
72                }
73            } else {
74                return Err(crate::Error::UnparseableReply);
75            }
76        }
77
78        Ok(response)
79    }
80
81    /// Sends a command to the SMTP server and waits for a reply.
82    pub async fn cmd(&mut self, cmd: impl AsRef<[u8]>) -> crate::Result<Response<String>> {
83        tokio::time::timeout(self.timeout, async {
84            self.stream.write_all(cmd.as_ref()).await?;
85            self.stream.flush().await?;
86            self.read().await
87        })
88        .await
89        .map_err(|_| crate::Error::Timeout)?
90    }
91
92    /// Pipelines multiple command to the SMTP server and waits for a reply.
93    pub async fn cmds(
94        &mut self,
95        cmds: impl IntoIterator<Item = impl AsRef<[u8]>>,
96    ) -> crate::Result<Vec<Response<String>>> {
97        tokio::time::timeout(self.timeout, async {
98            let mut num_replies = 0;
99            for cmd in cmds {
100                self.stream.write_all(cmd.as_ref()).await?;
101                num_replies += 1;
102            }
103            self.stream.flush().await?;
104            self.read_many(num_replies).await
105        })
106        .await
107        .map_err(|_| crate::Error::Timeout)?
108    }
109}
110
111impl SmtpClient<TcpStream> {
112    /// Connects to a remote host address
113    pub async fn connect(remote_addr: SocketAddr, timeout: Duration) -> crate::Result<Self> {
114        tokio::time::timeout(timeout, async {
115            Ok(SmtpClient {
116                stream: TcpStream::connect(remote_addr).await?,
117                timeout,
118            })
119        })
120        .await
121        .map_err(|_| crate::Error::Timeout)?
122    }
123
124    /// Connects to a remote host address using the provided local IP
125    pub async fn connect_using(
126        local_ip: IpAddr,
127        remote_addr: SocketAddr,
128        timeout: Duration,
129    ) -> crate::Result<Self> {
130        tokio::time::timeout(timeout, async {
131            let socket = if local_ip.is_ipv4() {
132                TcpSocket::new_v4()?
133            } else {
134                TcpSocket::new_v6()?
135            };
136            socket.bind(SocketAddr::new(local_ip, 0))?;
137
138            Ok(SmtpClient {
139                stream: socket.connect(remote_addr).await?,
140                timeout,
141            })
142        })
143        .await
144        .map_err(|_| crate::Error::Timeout)?
145    }
146}
147
#[cfg(test)]
mod test {
    use std::time::Duration;

    use tokio::io::{AsyncRead, AsyncWrite};

    use crate::{SmtpClient, SmtpClientBuilder};

    /// Connects to public SMTP servers and sends QUIT. Covers STARTTLS
    /// (`implicit_tls(false)`) and implicit TLS, each with and without
    /// `allow_invalid_certs()`.
    ///
    /// Requires outbound network access to `mail.smtp2go.com:2525` and
    /// `smtp.gmail.com:465`.
    #[tokio::test]
    async fn smtp_basic() {
        // `try_init` rather than `init`: `init` panics if a global logger has
        // already been installed, e.g. by another test in the same binary.
        let _ = env_logger::try_init();

        // StartTLS test
        let client = SmtpClientBuilder::new("mail.smtp2go.com", 2525)
            .implicit_tls(false)
            .connect()
            .await
            .unwrap();
        client.quit().await.unwrap();

        // StartTLS again, accepting invalid certificates
        let client = SmtpClientBuilder::new("mail.smtp2go.com", 2525)
            .allow_invalid_certs()
            .implicit_tls(false)
            .connect()
            .await
            .unwrap();
        client.quit().await.unwrap();

        // Say hello to Google over TLS and quit
        let client = SmtpClientBuilder::new("smtp.gmail.com", 465)
            .connect()
            .await
            .unwrap();
        client.quit().await.unwrap();

        // Say hello to Google over TLS, accepting invalid certificates, and quit
        let client = SmtpClientBuilder::new("smtp.gmail.com", 465)
            .allow_invalid_certs()
            .connect()
            .await
            .unwrap();
        client.quit().await.unwrap();
    }

    /// In-memory stream that records every byte written to it.
    ///
    /// It is write-only: `poll_read` is `unreachable!()`, so any attempt to
    /// read from it panics.
    #[derive(Default)]
    struct AsyncBufWriter {
        buf: Vec<u8>,
    }

    impl AsyncRead for AsyncBufWriter {
        fn poll_read(
            self: std::pin::Pin<&mut Self>,
            _cx: &mut std::task::Context<'_>,
            _buf: &mut tokio::io::ReadBuf<'_>,
        ) -> std::task::Poll<std::io::Result<()>> {
            unreachable!()
        }
    }

    impl AsyncWrite for AsyncBufWriter {
        fn poll_write(
            mut self: std::pin::Pin<&mut Self>,
            _cx: &mut std::task::Context<'_>,
            buf: &[u8],
        ) -> std::task::Poll<Result<usize, std::io::Error>> {
            self.buf.extend_from_slice(buf);
            std::task::Poll::Ready(Ok(buf.len()))
        }

        fn poll_flush(
            self: std::pin::Pin<&mut Self>,
            _cx: &mut std::task::Context<'_>,
        ) -> std::task::Poll<Result<(), std::io::Error>> {
            std::task::Poll::Ready(Ok(()))
        }

        fn poll_shutdown(
            self: std::pin::Pin<&mut Self>,
            _cx: &mut std::task::Context<'_>,
        ) -> std::task::Poll<Result<(), std::io::Error>> {
            std::task::Poll::Ready(Ok(()))
        }
    }

    /// Checks that `write_message` dot-stuffs the payload and appends the
    /// `\r\n.\r\n` terminator.
    ///
    /// Per the expected outputs below, a `.` that follows `\r\n`, a bare
    /// `\n` or a bare `\r` is doubled. This means an embedded
    /// end-of-data sequence cannot end the DATA phase early. The `SMUGGLER`
    /// cases check that a second message hidden behind such a sequence is
    /// neutralised.
    #[tokio::test]
    async fn transparency_procedure() {
        const SMUGGLER: &str = r#"From: Joe SixPack <john@foobar.net>
To: Suzie Q <suzie@foobar.org>
Subject: Is dinner ready?

Hi.

We lost the game. Are you hungry yet?

Joe.

<SEP>.
MAIL FROM:<admin@foobar.net>
RCPT TO:<ok@foobar.org>
DATA
From: Joe SixPack <admin@foobar.net>
To: Suzie Q <suzie@foobar.org>
Subject: smuggled message

This is a smuggled message
"#;

        // (input passed to `write_message`, exact bytes expected on the wire)
        for (test, result) in [
            (
                "A: b\r\n.\r\n".to_string(),
                "A: b\r\n..\r\n\r\n.\r\n".to_string(),
            ),
            ("A: b\r\n.".to_string(), "A: b\r\n..\r\n.\r\n".to_string()),
            (
                "A: b\r\n..\r\n".to_string(),
                "A: b\r\n...\r\n\r\n.\r\n".to_string(),
            ),
            ("A: ...b".to_string(), "A: ...b\r\n.\r\n".to_string()),
            (
                "A: \n.\r\nMAIL FROM:<>".to_string(),
                "A: \n..\r\nMAIL FROM:<>\r\n.\r\n".to_string(),
            ),
            (
                "A: \r.\r\nMAIL FROM:<>".to_string(),
                "A: \r..\r\nMAIL FROM:<>\r\n.\r\n".to_string(),
            ),
            (
                SMUGGLER
                    .replace('\r', "")
                    .replace('\n', "\r\n")
                    .replace("<SEP>", "\r"),
                SMUGGLER
                    .replace('\r', "")
                    .replace('\n', "\r\n")
                    .replace("<SEP>", "\r.")
                    + "\r\n.\r\n",
            ),
            (
                SMUGGLER
                    .replace('\r', "")
                    .replace('\n', "\r\n")
                    .replace("<SEP>", "\n"),
                SMUGGLER
                    .replace('\r', "")
                    .replace('\n', "\r\n")
                    .replace("<SEP>", "\n.")
                    + "\r\n.\r\n",
            ),
        ] {
            let mut client = SmtpClient {
                stream: AsyncBufWriter::default(),
                timeout: Duration::from_secs(30),
            };
            client.write_message(test.as_bytes()).await.unwrap();
            // Include the input so a failing case is identifiable in the
            // table-driven loop.
            assert_eq!(
                String::from_utf8(client.stream.buf).unwrap(),
                result,
                "input: {:?}",
                test
            );
        }
    }
}
303}