imap_next/
stream.rs

1use std::{
2    convert::Infallible,
3    io::{ErrorKind, Read, Write},
4};
5
6use bytes::{Buf, BufMut, BytesMut};
7#[cfg(debug_assertions)]
8use imap_codec::imap_types::utils::escape_byte_string;
9use thiserror::Error;
10use tokio::{
11    io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt},
12    net::TcpStream,
13    select,
14};
15use tokio_rustls::{rustls, TlsStream};
16#[cfg(debug_assertions)]
17use tracing::trace;
18
19use crate::{Interrupt, Io, State};
20
21pub struct Stream {
22    stream: TcpStream,
23    tls: Option<rustls::Connection>,
24    read_buffer: BytesMut,
25    write_buffer: BytesMut,
26}
27
28impl Stream {
29    pub fn insecure(stream: TcpStream) -> Self {
30        Self {
31            stream,
32            tls: None,
33            read_buffer: BytesMut::default(),
34            write_buffer: BytesMut::default(),
35        }
36    }
37
38    pub fn tls(stream: TlsStream<TcpStream>) -> Self {
39        // We want to use `TcpStream::split` for handling reading and writing separately,
40        // but `TlsStream` does not expose this functionality. Therefore, we destruct `TlsStream`
41        // into `TcpStream` and `rustls::Connection` and handling them ourselves.
42        //
43        // Some notes:
44        //
45        // - There is also `tokio::io::split` which works for all kind of streams. But this
46        //   involves too much scary magic because its use-case is reading and writing from
47        //   different threads. We prefer to use the more low-level `TcpStream::split`.
48        //
49        // - We could get rid of `TlsStream` and construct `rustls::Connection` directly.
50        //   But `TlsStream` is still useful because it gives us the guarantee that the handshake
51        //   was already handled properly.
52        //
53        // - In the long run it would be nice if `TlsStream::split` would exist and we would use
54        //   it because `TlsStream` is better at handling the edge cases of `rustls`.
55        let (stream, tls) = match stream {
56            TlsStream::Client(stream) => {
57                let (stream, tls) = stream.into_inner();
58                (stream, rustls::Connection::Client(tls))
59            }
60            TlsStream::Server(stream) => {
61                let (stream, tls) = stream.into_inner();
62                (stream, rustls::Connection::Server(tls))
63            }
64        };
65
66        Self {
67            stream,
68            tls: Some(tls),
69            read_buffer: BytesMut::default(),
70            write_buffer: BytesMut::default(),
71        }
72    }
73
74    pub async fn flush(&mut self) -> Result<(), Error<Infallible>> {
75        // Flush TLS
76        if let Some(tls) = &mut self.tls {
77            tls.writer().flush()?;
78            encrypt(tls, &mut self.write_buffer, Vec::new())?;
79        }
80
81        // Flush TCP
82        write(&mut self.stream, &mut self.write_buffer).await?;
83        self.stream.flush().await?;
84
85        Ok(())
86    }
87
88    pub async fn next<F: State>(&mut self, mut state: F) -> Result<F::Event, Error<F::Error>> {
89        let event = loop {
90            match &mut self.tls {
91                None => {
92                    // Provide input bytes to the client/server
93                    if !self.read_buffer.is_empty() {
94                        state.enqueue_input(&self.read_buffer);
95                        self.read_buffer.clear();
96                    }
97                }
98                Some(tls) => {
99                    // Decrypt input bytes
100                    let plain_bytes = decrypt(tls, &mut self.read_buffer)?;
101
102                    // Provide input bytes to the client/server
103                    if !plain_bytes.is_empty() {
104                        state.enqueue_input(&plain_bytes);
105                    }
106                }
107            }
108
109            // Progress the client/server
110            let result = state.next();
111
112            // Return events immediately without doing IO
113            let interrupt = match result {
114                Err(interrupt) => interrupt,
115                Ok(event) => break event,
116            };
117
118            // Return errors immediately without doing IO
119            let io = match interrupt {
120                Interrupt::Io(io) => io,
121                Interrupt::Error(err) => return Err(Error::State(err)),
122            };
123
124            match &mut self.tls {
125                None => {
126                    // Handle the output bytes from the client/server
127                    if let Io::Output(bytes) = io {
128                        self.write_buffer.extend(bytes);
129                    }
130                }
131                Some(tls) => {
132                    // Handle the output bytes from the client/server
133                    let plain_bytes = if let Io::Output(bytes) = io {
134                        bytes
135                    } else {
136                        Vec::new()
137                    };
138
139                    // Encrypt output bytes
140                    encrypt(tls, &mut self.write_buffer, plain_bytes)?;
141                }
142            }
143
144            // Progress the stream
145            if self.write_buffer.is_empty() {
146                read(&mut self.stream, &mut self.read_buffer).await?;
147            } else {
148                // We read and write the stream simultaneously because otherwise
149                // a deadlock between client and server might occur if both sides
150                // would only read or only write.
151                let (read_stream, write_stream) = self.stream.split();
152                select! {
153                    result = read(read_stream, &mut self.read_buffer) => result,
154                    result = write(write_stream, &mut self.write_buffer) => result,
155                }?;
156            };
157        };
158
159        Ok(event)
160    }
161
162    #[cfg(feature = "expose_stream")]
163    /// Return the underlying stream for debug purposes (or experiments).
164    ///
165    /// Note: Writing to or reading from the stream may introduce
166    /// conflicts with `imap-next`.
167    pub fn stream_mut(&mut self) -> &mut TcpStream {
168        &mut self.stream
169    }
170}
171
/// Extracts the underlying [`TcpStream`] from a [`Stream`].
///
/// The TLS state and any buffered, not yet processed bytes are dropped. This is useful when
/// a TCP stream needs to be upgraded to a TLS one.
#[cfg(feature = "expose_stream")]
impl From<Stream> for TcpStream {
    fn from(value: Stream) -> Self {
        let Stream { stream, .. } = value;
        stream
    }
}
181
182/// Error during reading into or writing from a stream.
183#[derive(Debug, Error)]
184pub enum Error<E> {
185    /// Operation failed because stream is closed.
186    ///
187    /// We detect this by checking if the read or written byte count is 0. Whether the stream is
188    /// closed indefinitely or temporarily depends on the actual stream implementation.
189    #[error("Stream was closed")]
190    Closed,
191    /// An I/O error occurred in the underlying stream.
192    #[error(transparent)]
193    Io(#[from] tokio::io::Error),
194    /// An error occurred in the underlying TLS connection.
195    #[error(transparent)]
196    Tls(#[from] rustls::Error),
197    /// An error occurred while progressing the state.
198    #[error(transparent)]
199    State(E),
200}
201
202async fn read<S: AsyncRead + Unpin>(
203    mut stream: S,
204    read_buffer: &mut BytesMut,
205) -> Result<(), ReadWriteError> {
206    #[cfg(debug_assertions)]
207    let old_len = read_buffer.len();
208    let byte_count = stream.read_buf(read_buffer).await?;
209    #[cfg(debug_assertions)]
210    trace!(
211        data = escape_byte_string(&read_buffer[old_len..]),
212        "io/read/raw"
213    );
214
215    if byte_count == 0 {
216        // The result is 0 if the stream reached "end of file" or the read buffer was
217        // already full before calling `read_buf`. Because we use an unlimited buffer we
218        // know that the first case occurred.
219        return Err(ReadWriteError::Closed);
220    }
221
222    Ok(())
223}
224
225async fn write<S: AsyncWrite + Unpin>(
226    mut stream: S,
227    write_buffer: &mut BytesMut,
228) -> Result<(), ReadWriteError> {
229    while !write_buffer.is_empty() {
230        let byte_count = stream.write(write_buffer).await?;
231        #[cfg(debug_assertions)]
232        trace!(
233            data = escape_byte_string(&write_buffer[..byte_count]),
234            "io/write/raw"
235        );
236        write_buffer.advance(byte_count);
237
238        if byte_count == 0 {
239            // The result is 0 if the stream doesn't accept bytes anymore or the write buffer
240            // was already empty before calling `write_buf`. Because we checked the buffer
241            // we know that the first case occurred.
242            return Err(ReadWriteError::Closed);
243        }
244    }
245
246    Ok(())
247}
248
249#[derive(Debug, Error)]
250enum ReadWriteError {
251    #[error("Stream was closed")]
252    Closed,
253    #[error(transparent)]
254    Io(#[from] tokio::io::Error),
255}
256
257impl<E> From<ReadWriteError> for Error<E> {
258    fn from(value: ReadWriteError) -> Self {
259        match value {
260            ReadWriteError::Closed => Error::Closed,
261            ReadWriteError::Io(err) => Error::Io(err),
262        }
263    }
264}
265
266fn decrypt(
267    tls: &mut rustls::Connection,
268    read_buffer: &mut BytesMut,
269) -> Result<Vec<u8>, DecryptEncryptError> {
270    let mut plain_bytes = Vec::new();
271
272    while tls.wants_read() && !read_buffer.is_empty() {
273        let mut encrypted_bytes = read_buffer.reader();
274        tls.read_tls(&mut encrypted_bytes)?;
275        tls.process_new_packets()?;
276    }
277
278    loop {
279        let mut plain_bytes_chunk = [0; 128];
280        // We need to handle different cases according to:
281        // https://docs.rs/rustls/latest/rustls/struct.Reader.html#method.read
282        match tls.reader().read(&mut plain_bytes_chunk) {
283            // There are no more bytes to read
284            Err(err) if err.kind() == ErrorKind::WouldBlock => break,
285            // The TLS session was closed uncleanly
286            Err(err) if err.kind() == ErrorKind::UnexpectedEof => {
287                return Err(DecryptEncryptError::Closed)
288            }
289            // We got an unexpected error
290            Err(err) => return Err(DecryptEncryptError::Io(err)),
291            // The TLS session was closed cleanly
292            Ok(0) => return Err(DecryptEncryptError::Closed),
293            // We read some plaintext bytes
294            Ok(n) => plain_bytes.extend(&plain_bytes_chunk[0..n]),
295        };
296    }
297
298    Ok(plain_bytes)
299}
300
301fn encrypt(
302    tls: &mut rustls::Connection,
303    write_buffer: &mut BytesMut,
304    plain_bytes: Vec<u8>,
305) -> Result<(), DecryptEncryptError> {
306    if !plain_bytes.is_empty() {
307        tls.writer().write_all(&plain_bytes)?;
308    }
309
310    while tls.wants_write() {
311        let mut encrypted_bytes = write_buffer.writer();
312        tls.write_tls(&mut encrypted_bytes)?;
313    }
314
315    Ok(())
316}
317
318#[derive(Debug, Error)]
319enum DecryptEncryptError {
320    #[error("Session was closed")]
321    Closed,
322    #[error(transparent)]
323    Io(#[from] std::io::Error),
324    #[error(transparent)]
325    Tls(#[from] rustls::Error),
326}
327
328impl<E> From<DecryptEncryptError> for Error<E> {
329    fn from(value: DecryptEncryptError) -> Self {
330        match value {
331            DecryptEncryptError::Closed => Error::Closed,
332            DecryptEncryptError::Io(err) => Error::Io(err),
333            DecryptEncryptError::Tls(err) => Error::Tls(err),
334        }
335    }
336}