// imap_next/stream.rs

1use std::{
2    convert::Infallible,
3    io::{ErrorKind, Read, Write},
4};
5
6use bytes::{Buf, BufMut, BytesMut};
7#[cfg(debug_assertions)]
8use imap_codec::imap_types::utils::escape_byte_string;
9use thiserror::Error;
10use tokio::{
11    io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt},
12    net::TcpStream,
13    select,
14};
15use tokio_rustls::{rustls, TlsStream};
16use tracing::instrument;
17#[cfg(debug_assertions)]
18use tracing::trace;
19
20use crate::{Interrupt, Io, State};
21
22pub struct Stream {
23    stream: TcpStream,
24    tls: Option<rustls::Connection>,
25    read_buffer: BytesMut,
26    write_buffer: BytesMut,
27}
28
29impl Stream {
30    pub fn insecure(stream: TcpStream) -> Self {
31        Self {
32            stream,
33            tls: None,
34            read_buffer: BytesMut::default(),
35            write_buffer: BytesMut::default(),
36        }
37    }
38
39    pub fn tls(stream: TlsStream<TcpStream>) -> Self {
40        // We want to use `TcpStream::split` for handling reading and writing separately,
41        // but `TlsStream` does not expose this functionality. Therefore, we destruct `TlsStream`
42        // into `TcpStream` and `rustls::Connection` and handling them ourselves.
43        //
44        // Some notes:
45        //
46        // - There is also `tokio::io::split` which works for all kind of streams. But this
47        //   involves too much scary magic because its use-case is reading and writing from
48        //   different threads. We prefer to use the more low-level `TcpStream::split`.
49        //
50        // - We could get rid of `TlsStream` and construct `rustls::Connection` directly.
51        //   But `TlsStream` is still useful because it gives us the guarantee that the handshake
52        //   was already handled properly.
53        //
54        // - In the long run it would be nice if `TlsStream::split` would exist and we would use
55        //   it because `TlsStream` is better at handling the edge cases of `rustls`.
56        let (stream, tls) = match stream {
57            TlsStream::Client(stream) => {
58                let (stream, tls) = stream.into_inner();
59                (stream, rustls::Connection::Client(tls))
60            }
61            TlsStream::Server(stream) => {
62                let (stream, tls) = stream.into_inner();
63                (stream, rustls::Connection::Server(tls))
64            }
65        };
66
67        Self {
68            stream,
69            tls: Some(tls),
70            read_buffer: BytesMut::default(),
71            write_buffer: BytesMut::default(),
72        }
73    }
74
75    pub async fn flush(&mut self) -> Result<(), Error<Infallible>> {
76        // Flush TLS
77        if let Some(tls) = &mut self.tls {
78            tls.writer().flush()?;
79            encrypt(tls, &mut self.write_buffer, Vec::new())?;
80        }
81
82        // Flush TCP
83        write(&mut self.stream, &mut self.write_buffer).await?;
84        self.stream.flush().await?;
85
86        Ok(())
87    }
88
89    pub async fn next<F: State>(&mut self, mut state: F) -> Result<F::Event, Error<F::Error>> {
90        let event = loop {
91            match &mut self.tls {
92                None => {
93                    // Provide input bytes to the client/server
94                    if !self.read_buffer.is_empty() {
95                        state.enqueue_input(&self.read_buffer);
96                        self.read_buffer.clear();
97                    }
98                }
99                Some(tls) => {
100                    // Decrypt input bytes
101                    let plain_bytes = decrypt(tls, &mut self.read_buffer)?;
102
103                    // Provide input bytes to the client/server
104                    if !plain_bytes.is_empty() {
105                        state.enqueue_input(&plain_bytes);
106                    }
107                }
108            }
109
110            // Progress the client/server
111            let result = state.next();
112
113            // Return events immediately without doing IO
114            let interrupt = match result {
115                Err(interrupt) => interrupt,
116                Ok(event) => break event,
117            };
118
119            // Return errors immediately without doing IO
120            let io = match interrupt {
121                Interrupt::Io(io) => io,
122                Interrupt::Error(err) => return Err(Error::State(err)),
123            };
124
125            match &mut self.tls {
126                None => {
127                    // Handle the output bytes from the client/server
128                    if let Io::Output(bytes) = io {
129                        self.write_buffer.extend(bytes);
130                    }
131                }
132                Some(tls) => {
133                    // Handle the output bytes from the client/server
134                    let plain_bytes = if let Io::Output(bytes) = io {
135                        bytes
136                    } else {
137                        Vec::new()
138                    };
139
140                    // Encrypt output bytes
141                    encrypt(tls, &mut self.write_buffer, plain_bytes)?;
142                }
143            }
144
145            // Progress the stream
146            if self.write_buffer.is_empty() {
147                read(&mut self.stream, &mut self.read_buffer).await?;
148            } else {
149                // We read and write the stream simultaneously because otherwise
150                // a deadlock between client and server might occur if both sides
151                // would only read or only write.
152                let (read_stream, write_stream) = self.stream.split();
153                select! {
154                    result = read(read_stream, &mut self.read_buffer) => result,
155                    result = write(write_stream, &mut self.write_buffer) => result,
156                }?;
157            };
158        };
159
160        Ok(event)
161    }
162
163    #[cfg(feature = "expose_stream")]
164    /// Return the underlying stream for debug purposes (or experiments).
165    ///
166    /// Note: Writing to or reading from the stream may introduce
167    /// conflicts with `imap-next`.
168    pub fn stream_mut(&mut self) -> &mut TcpStream {
169        &mut self.stream
170    }
171}
172
/// Take the [`TcpStream`] out of a [`Stream`].
///
/// Useful when a TCP stream needs to be upgraded to a TLS one.
///
/// Everything else is dropped: the TLS state and any bytes still sitting in the read or write
/// buffer. Call [`Stream::flush`] beforehand if pending output must reach the peer.
#[cfg(feature = "expose_stream")]
impl From<Stream> for TcpStream {
    fn from(value: Stream) -> Self {
        let Stream { stream, .. } = value;
        stream
    }
}
182
183/// Error during reading into or writing from a stream.
184#[derive(Debug, Error)]
185pub enum Error<E> {
186    /// Operation failed because stream is closed.
187    ///
188    /// We detect this by checking if the read or written byte count is 0. Whether the stream is
189    /// closed indefinitely or temporarily depends on the actual stream implementation.
190    #[error("Stream was closed")]
191    Closed,
192    /// An I/O error occurred in the underlying stream.
193    #[error(transparent)]
194    Io(#[from] tokio::io::Error),
195    /// An error occurred in the underlying TLS connection.
196    #[error(transparent)]
197    Tls(#[from] rustls::Error),
198    /// An error occurred while progressing the state.
199    #[error(transparent)]
200    State(E),
201}
202
203#[instrument(name = "io", skip_all, fields(action = "read"))]
204async fn read<S: AsyncRead + Unpin>(
205    mut stream: S,
206    read_buffer: &mut BytesMut,
207) -> Result<(), ReadWriteError> {
208    #[cfg(debug_assertions)]
209    let old_len = read_buffer.len();
210    let byte_count = stream.read_buf(read_buffer).await?;
211    #[cfg(debug_assertions)]
212    trace!(data = escape_byte_string(&read_buffer[old_len..]));
213
214    if byte_count == 0 {
215        // The result is 0 if the stream reached "end of file" or the read buffer was
216        // already full before calling `read_buf`. Because we use an unlimited buffer we
217        // know that the first case occurred.
218        return Err(ReadWriteError::Closed);
219    }
220
221    Ok(())
222}
223
224#[instrument(name = "io", skip_all, fields(action = "write"))]
225async fn write<S: AsyncWrite + Unpin>(
226    mut stream: S,
227    write_buffer: &mut BytesMut,
228) -> Result<(), ReadWriteError> {
229    while !write_buffer.is_empty() {
230        let byte_count = stream.write(write_buffer).await?;
231        #[cfg(debug_assertions)]
232        trace!(data = escape_byte_string(&write_buffer[..byte_count]));
233        write_buffer.advance(byte_count);
234
235        if byte_count == 0 {
236            // The result is 0 if the stream doesn't accept bytes anymore or the write buffer
237            // was already empty before calling `write_buf`. Because we checked the buffer
238            // we know that the first case occurred.
239            return Err(ReadWriteError::Closed);
240        }
241    }
242
243    Ok(())
244}
245
246#[derive(Debug, Error)]
247enum ReadWriteError {
248    #[error("Stream was closed")]
249    Closed,
250    #[error(transparent)]
251    Io(#[from] tokio::io::Error),
252}
253
254impl<E> From<ReadWriteError> for Error<E> {
255    fn from(value: ReadWriteError) -> Self {
256        match value {
257            ReadWriteError::Closed => Error::Closed,
258            ReadWriteError::Io(err) => Error::Io(err),
259        }
260    }
261}
262
263fn decrypt(
264    tls: &mut rustls::Connection,
265    read_buffer: &mut BytesMut,
266) -> Result<Vec<u8>, DecryptEncryptError> {
267    let mut plain_bytes = Vec::new();
268
269    while tls.wants_read() && !read_buffer.is_empty() {
270        let mut encrypted_bytes = read_buffer.reader();
271        tls.read_tls(&mut encrypted_bytes)?;
272        tls.process_new_packets()?;
273    }
274
275    loop {
276        let mut plain_bytes_chunk = [0; 128];
277        // We need to handle different cases according to:
278        // https://docs.rs/rustls/latest/rustls/struct.Reader.html#method.read
279        match tls.reader().read(&mut plain_bytes_chunk) {
280            // There are no more bytes to read
281            Err(err) if err.kind() == ErrorKind::WouldBlock => break,
282            // The TLS session was closed uncleanly
283            Err(err) if err.kind() == ErrorKind::UnexpectedEof => {
284                return Err(DecryptEncryptError::Closed)
285            }
286            // We got an unexpected error
287            Err(err) => return Err(DecryptEncryptError::Io(err)),
288            // The TLS session was closed cleanly
289            Ok(0) => return Err(DecryptEncryptError::Closed),
290            // We read some plaintext bytes
291            Ok(n) => plain_bytes.extend(&plain_bytes_chunk[0..n]),
292        };
293    }
294
295    Ok(plain_bytes)
296}
297
298fn encrypt(
299    tls: &mut rustls::Connection,
300    write_buffer: &mut BytesMut,
301    plain_bytes: Vec<u8>,
302) -> Result<(), DecryptEncryptError> {
303    if !plain_bytes.is_empty() {
304        tls.writer().write_all(&plain_bytes)?;
305    }
306
307    while tls.wants_write() {
308        let mut encrypted_bytes = write_buffer.writer();
309        tls.write_tls(&mut encrypted_bytes)?;
310    }
311
312    Ok(())
313}
314
315#[derive(Debug, Error)]
316enum DecryptEncryptError {
317    #[error("Session was closed")]
318    Closed,
319    #[error(transparent)]
320    Io(#[from] std::io::Error),
321    #[error(transparent)]
322    Tls(#[from] rustls::Error),
323}
324
325impl<E> From<DecryptEncryptError> for Error<E> {
326    fn from(value: DecryptEncryptError) -> Self {
327        match value {
328            DecryptEncryptError::Closed => Error::Closed,
329            DecryptEncryptError::Io(err) => Error::Io(err),
330            DecryptEncryptError::Tls(err) => Error::Tls(err),
331        }
332    }
333}