1use std::{
2 convert::Infallible,
3 io::{ErrorKind, Read, Write},
4};
5
6use bytes::{Buf, BufMut, BytesMut};
7#[cfg(debug_assertions)]
8use imap_codec::imap_types::utils::escape_byte_string;
9use thiserror::Error;
10use tokio::{
11 io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt},
12 net::TcpStream,
13 select,
14};
15use tokio_rustls::{rustls, TlsStream};
16#[cfg(debug_assertions)]
17use tracing::trace;
18
19use crate::{Interrupt, Io, State};
20
21pub struct Stream {
22 stream: TcpStream,
23 tls: Option<rustls::Connection>,
24 read_buffer: BytesMut,
25 write_buffer: BytesMut,
26}
27
28impl Stream {
29 pub fn insecure(stream: TcpStream) -> Self {
30 Self {
31 stream,
32 tls: None,
33 read_buffer: BytesMut::default(),
34 write_buffer: BytesMut::default(),
35 }
36 }
37
38 pub fn tls(stream: TlsStream<TcpStream>) -> Self {
39 let (stream, tls) = match stream {
56 TlsStream::Client(stream) => {
57 let (stream, tls) = stream.into_inner();
58 (stream, rustls::Connection::Client(tls))
59 }
60 TlsStream::Server(stream) => {
61 let (stream, tls) = stream.into_inner();
62 (stream, rustls::Connection::Server(tls))
63 }
64 };
65
66 Self {
67 stream,
68 tls: Some(tls),
69 read_buffer: BytesMut::default(),
70 write_buffer: BytesMut::default(),
71 }
72 }
73
74 pub async fn flush(&mut self) -> Result<(), Error<Infallible>> {
75 if let Some(tls) = &mut self.tls {
77 tls.writer().flush()?;
78 encrypt(tls, &mut self.write_buffer, Vec::new())?;
79 }
80
81 write(&mut self.stream, &mut self.write_buffer).await?;
83 self.stream.flush().await?;
84
85 Ok(())
86 }
87
88 pub async fn next<F: State>(&mut self, mut state: F) -> Result<F::Event, Error<F::Error>> {
89 let event = loop {
90 match &mut self.tls {
91 None => {
92 if !self.read_buffer.is_empty() {
94 state.enqueue_input(&self.read_buffer);
95 self.read_buffer.clear();
96 }
97 }
98 Some(tls) => {
99 let plain_bytes = decrypt(tls, &mut self.read_buffer)?;
101
102 if !plain_bytes.is_empty() {
104 state.enqueue_input(&plain_bytes);
105 }
106 }
107 }
108
109 let result = state.next();
111
112 let interrupt = match result {
114 Err(interrupt) => interrupt,
115 Ok(event) => break event,
116 };
117
118 let io = match interrupt {
120 Interrupt::Io(io) => io,
121 Interrupt::Error(err) => return Err(Error::State(err)),
122 };
123
124 match &mut self.tls {
125 None => {
126 if let Io::Output(bytes) = io {
128 self.write_buffer.extend(bytes);
129 }
130 }
131 Some(tls) => {
132 let plain_bytes = if let Io::Output(bytes) = io {
134 bytes
135 } else {
136 Vec::new()
137 };
138
139 encrypt(tls, &mut self.write_buffer, plain_bytes)?;
141 }
142 }
143
144 if self.write_buffer.is_empty() {
146 read(&mut self.stream, &mut self.read_buffer).await?;
147 } else {
148 let (read_stream, write_stream) = self.stream.split();
152 select! {
153 result = read(read_stream, &mut self.read_buffer) => result,
154 result = write(write_stream, &mut self.write_buffer) => result,
155 }?;
156 };
157 };
158
159 Ok(event)
160 }
161
162 #[cfg(feature = "expose_stream")]
163 pub fn stream_mut(&mut self) -> &mut TcpStream {
168 &mut self.stream
169 }
170}
171
/// Extracts the underlying TCP stream.
///
/// The TLS session and any bytes still held in the read and write buffers are dropped.
#[cfg(feature = "expose_stream")]
impl From<Stream> for TcpStream {
    fn from(Stream { stream, .. }: Stream) -> Self {
        stream
    }
}
181
182#[derive(Debug, Error)]
184pub enum Error<E> {
185 #[error("Stream was closed")]
190 Closed,
191 #[error(transparent)]
193 Io(#[from] tokio::io::Error),
194 #[error(transparent)]
196 Tls(#[from] rustls::Error),
197 #[error(transparent)]
199 State(E),
200}
201
202async fn read<S: AsyncRead + Unpin>(
203 mut stream: S,
204 read_buffer: &mut BytesMut,
205) -> Result<(), ReadWriteError> {
206 #[cfg(debug_assertions)]
207 let old_len = read_buffer.len();
208 let byte_count = stream.read_buf(read_buffer).await?;
209 #[cfg(debug_assertions)]
210 trace!(
211 data = escape_byte_string(&read_buffer[old_len..]),
212 "io/read/raw"
213 );
214
215 if byte_count == 0 {
216 return Err(ReadWriteError::Closed);
220 }
221
222 Ok(())
223}
224
225async fn write<S: AsyncWrite + Unpin>(
226 mut stream: S,
227 write_buffer: &mut BytesMut,
228) -> Result<(), ReadWriteError> {
229 while !write_buffer.is_empty() {
230 let byte_count = stream.write(write_buffer).await?;
231 #[cfg(debug_assertions)]
232 trace!(
233 data = escape_byte_string(&write_buffer[..byte_count]),
234 "io/write/raw"
235 );
236 write_buffer.advance(byte_count);
237
238 if byte_count == 0 {
239 return Err(ReadWriteError::Closed);
243 }
244 }
245
246 Ok(())
247}
248
249#[derive(Debug, Error)]
250enum ReadWriteError {
251 #[error("Stream was closed")]
252 Closed,
253 #[error(transparent)]
254 Io(#[from] tokio::io::Error),
255}
256
257impl<E> From<ReadWriteError> for Error<E> {
258 fn from(value: ReadWriteError) -> Self {
259 match value {
260 ReadWriteError::Closed => Error::Closed,
261 ReadWriteError::Io(err) => Error::Io(err),
262 }
263 }
264}
265
266fn decrypt(
267 tls: &mut rustls::Connection,
268 read_buffer: &mut BytesMut,
269) -> Result<Vec<u8>, DecryptEncryptError> {
270 let mut plain_bytes = Vec::new();
271
272 while tls.wants_read() && !read_buffer.is_empty() {
273 let mut encrypted_bytes = read_buffer.reader();
274 tls.read_tls(&mut encrypted_bytes)?;
275 tls.process_new_packets()?;
276 }
277
278 loop {
279 let mut plain_bytes_chunk = [0; 128];
280 match tls.reader().read(&mut plain_bytes_chunk) {
283 Err(err) if err.kind() == ErrorKind::WouldBlock => break,
285 Err(err) if err.kind() == ErrorKind::UnexpectedEof => {
287 return Err(DecryptEncryptError::Closed)
288 }
289 Err(err) => return Err(DecryptEncryptError::Io(err)),
291 Ok(0) => return Err(DecryptEncryptError::Closed),
293 Ok(n) => plain_bytes.extend(&plain_bytes_chunk[0..n]),
295 };
296 }
297
298 Ok(plain_bytes)
299}
300
301fn encrypt(
302 tls: &mut rustls::Connection,
303 write_buffer: &mut BytesMut,
304 plain_bytes: Vec<u8>,
305) -> Result<(), DecryptEncryptError> {
306 if !plain_bytes.is_empty() {
307 tls.writer().write_all(&plain_bytes)?;
308 }
309
310 while tls.wants_write() {
311 let mut encrypted_bytes = write_buffer.writer();
312 tls.write_tls(&mut encrypted_bytes)?;
313 }
314
315 Ok(())
316}
317
318#[derive(Debug, Error)]
319enum DecryptEncryptError {
320 #[error("Session was closed")]
321 Closed,
322 #[error(transparent)]
323 Io(#[from] std::io::Error),
324 #[error(transparent)]
325 Tls(#[from] rustls::Error),
326}
327
328impl<E> From<DecryptEncryptError> for Error<E> {
329 fn from(value: DecryptEncryptError) -> Self {
330 match value {
331 DecryptEncryptError::Closed => Error::Closed,
332 DecryptEncryptError::Io(err) => Error::Io(err),
333 DecryptEncryptError::Tls(err) => Error::Tls(err),
334 }
335 }
336}