1use std::{
2 convert::Infallible,
3 io::{ErrorKind, Read, Write},
4};
5
6use bytes::{Buf, BufMut, BytesMut};
7#[cfg(debug_assertions)]
8use imap_codec::imap_types::utils::escape_byte_string;
9use thiserror::Error;
10use tokio::{
11 io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt},
12 net::TcpStream,
13 select,
14};
15use tokio_rustls::{rustls, TlsStream};
16use tracing::instrument;
17#[cfg(debug_assertions)]
18use tracing::trace;
19
20use crate::{Interrupt, Io, State};
21
22pub struct Stream {
23 stream: TcpStream,
24 tls: Option<rustls::Connection>,
25 read_buffer: BytesMut,
26 write_buffer: BytesMut,
27}
28
29impl Stream {
30 pub fn insecure(stream: TcpStream) -> Self {
31 Self {
32 stream,
33 tls: None,
34 read_buffer: BytesMut::default(),
35 write_buffer: BytesMut::default(),
36 }
37 }
38
39 pub fn tls(stream: TlsStream<TcpStream>) -> Self {
40 let (stream, tls) = match stream {
57 TlsStream::Client(stream) => {
58 let (stream, tls) = stream.into_inner();
59 (stream, rustls::Connection::Client(tls))
60 }
61 TlsStream::Server(stream) => {
62 let (stream, tls) = stream.into_inner();
63 (stream, rustls::Connection::Server(tls))
64 }
65 };
66
67 Self {
68 stream,
69 tls: Some(tls),
70 read_buffer: BytesMut::default(),
71 write_buffer: BytesMut::default(),
72 }
73 }
74
75 pub async fn flush(&mut self) -> Result<(), Error<Infallible>> {
76 if let Some(tls) = &mut self.tls {
78 tls.writer().flush()?;
79 encrypt(tls, &mut self.write_buffer, Vec::new())?;
80 }
81
82 write(&mut self.stream, &mut self.write_buffer).await?;
84 self.stream.flush().await?;
85
86 Ok(())
87 }
88
89 pub async fn next<F: State>(&mut self, mut state: F) -> Result<F::Event, Error<F::Error>> {
90 let event = loop {
91 match &mut self.tls {
92 None => {
93 if !self.read_buffer.is_empty() {
95 state.enqueue_input(&self.read_buffer);
96 self.read_buffer.clear();
97 }
98 }
99 Some(tls) => {
100 let plain_bytes = decrypt(tls, &mut self.read_buffer)?;
102
103 if !plain_bytes.is_empty() {
105 state.enqueue_input(&plain_bytes);
106 }
107 }
108 }
109
110 let result = state.next();
112
113 let interrupt = match result {
115 Err(interrupt) => interrupt,
116 Ok(event) => break event,
117 };
118
119 let io = match interrupt {
121 Interrupt::Io(io) => io,
122 Interrupt::Error(err) => return Err(Error::State(err)),
123 };
124
125 match &mut self.tls {
126 None => {
127 if let Io::Output(bytes) = io {
129 self.write_buffer.extend(bytes);
130 }
131 }
132 Some(tls) => {
133 let plain_bytes = if let Io::Output(bytes) = io {
135 bytes
136 } else {
137 Vec::new()
138 };
139
140 encrypt(tls, &mut self.write_buffer, plain_bytes)?;
142 }
143 }
144
145 if self.write_buffer.is_empty() {
147 read(&mut self.stream, &mut self.read_buffer).await?;
148 } else {
149 let (read_stream, write_stream) = self.stream.split();
153 select! {
154 result = read(read_stream, &mut self.read_buffer) => result,
155 result = write(write_stream, &mut self.write_buffer) => result,
156 }?;
157 };
158 };
159
160 Ok(event)
161 }
162
163 #[cfg(feature = "expose_stream")]
164 pub fn stream_mut(&mut self) -> &mut TcpStream {
169 &mut self.stream
170 }
171}
172
/// Unwraps the underlying socket, dropping the TLS state and both buffers.
///
/// Only available with the `expose_stream` feature.
#[cfg(feature = "expose_stream")]
impl From<Stream> for TcpStream {
    fn from(value: Stream) -> Self {
        let Stream { stream, .. } = value;
        stream
    }
}
182
183#[derive(Debug, Error)]
185pub enum Error<E> {
186 #[error("Stream was closed")]
191 Closed,
192 #[error(transparent)]
194 Io(#[from] tokio::io::Error),
195 #[error(transparent)]
197 Tls(#[from] rustls::Error),
198 #[error(transparent)]
200 State(E),
201}
202
203#[instrument(name = "io", skip_all, fields(action = "read"))]
204async fn read<S: AsyncRead + Unpin>(
205 mut stream: S,
206 read_buffer: &mut BytesMut,
207) -> Result<(), ReadWriteError> {
208 #[cfg(debug_assertions)]
209 let old_len = read_buffer.len();
210 let byte_count = stream.read_buf(read_buffer).await?;
211 #[cfg(debug_assertions)]
212 trace!(data = escape_byte_string(&read_buffer[old_len..]));
213
214 if byte_count == 0 {
215 return Err(ReadWriteError::Closed);
219 }
220
221 Ok(())
222}
223
224#[instrument(name = "io", skip_all, fields(action = "write"))]
225async fn write<S: AsyncWrite + Unpin>(
226 mut stream: S,
227 write_buffer: &mut BytesMut,
228) -> Result<(), ReadWriteError> {
229 while !write_buffer.is_empty() {
230 let byte_count = stream.write(write_buffer).await?;
231 #[cfg(debug_assertions)]
232 trace!(data = escape_byte_string(&write_buffer[..byte_count]));
233 write_buffer.advance(byte_count);
234
235 if byte_count == 0 {
236 return Err(ReadWriteError::Closed);
240 }
241 }
242
243 Ok(())
244}
245
246#[derive(Debug, Error)]
247enum ReadWriteError {
248 #[error("Stream was closed")]
249 Closed,
250 #[error(transparent)]
251 Io(#[from] tokio::io::Error),
252}
253
254impl<E> From<ReadWriteError> for Error<E> {
255 fn from(value: ReadWriteError) -> Self {
256 match value {
257 ReadWriteError::Closed => Error::Closed,
258 ReadWriteError::Io(err) => Error::Io(err),
259 }
260 }
261}
262
263fn decrypt(
264 tls: &mut rustls::Connection,
265 read_buffer: &mut BytesMut,
266) -> Result<Vec<u8>, DecryptEncryptError> {
267 let mut plain_bytes = Vec::new();
268
269 while tls.wants_read() && !read_buffer.is_empty() {
270 let mut encrypted_bytes = read_buffer.reader();
271 tls.read_tls(&mut encrypted_bytes)?;
272 tls.process_new_packets()?;
273 }
274
275 loop {
276 let mut plain_bytes_chunk = [0; 128];
277 match tls.reader().read(&mut plain_bytes_chunk) {
280 Err(err) if err.kind() == ErrorKind::WouldBlock => break,
282 Err(err) if err.kind() == ErrorKind::UnexpectedEof => {
284 return Err(DecryptEncryptError::Closed)
285 }
286 Err(err) => return Err(DecryptEncryptError::Io(err)),
288 Ok(0) => return Err(DecryptEncryptError::Closed),
290 Ok(n) => plain_bytes.extend(&plain_bytes_chunk[0..n]),
292 };
293 }
294
295 Ok(plain_bytes)
296}
297
298fn encrypt(
299 tls: &mut rustls::Connection,
300 write_buffer: &mut BytesMut,
301 plain_bytes: Vec<u8>,
302) -> Result<(), DecryptEncryptError> {
303 if !plain_bytes.is_empty() {
304 tls.writer().write_all(&plain_bytes)?;
305 }
306
307 while tls.wants_write() {
308 let mut encrypted_bytes = write_buffer.writer();
309 tls.write_tls(&mut encrypted_bytes)?;
310 }
311
312 Ok(())
313}
314
315#[derive(Debug, Error)]
316enum DecryptEncryptError {
317 #[error("Session was closed")]
318 Closed,
319 #[error(transparent)]
320 Io(#[from] std::io::Error),
321 #[error(transparent)]
322 Tls(#[from] rustls::Error),
323}
324
325impl<E> From<DecryptEncryptError> for Error<E> {
326 fn from(value: DecryptEncryptError) -> Self {
327 match value {
328 DecryptEncryptError::Closed => Error::Closed,
329 DecryptEncryptError::Io(err) => Error::Io(err),
330 DecryptEncryptError::Tls(err) => Error::Tls(err),
331 }
332 }
333}