mqrstt/tokio/stream.rs

1use tokio::io::AsyncWriteExt;
2
3#[cfg(feature = "logs")]
4use tracing::trace;
5
6use crate::packets::ConnAck;
7use crate::packets::{ConnAckReasonCode, Packet};
8use crate::{connect_options::ConnectOptions, error::ConnectionError};
9
10pub(crate) trait StreamExt {
11    fn connect(&mut self, options: &ConnectOptions) -> impl std::future::Future<Output = Result<ConnAck, ConnectionError>>;
12    fn read_packet(&mut self) -> impl std::future::Future<Output = Result<Packet, ConnectionError>>;
13    fn write_packet(&mut self, packet: &Packet) -> impl std::future::Future<Output = Result<(), ConnectionError>>;
14    fn write_packets(&mut self, packets: &[Packet]) -> impl std::future::Future<Output = Result<(), ConnectionError>>;
15    fn flush_packets(&mut self) -> impl std::future::Future<Output = std::io::Result<()>>;
16}
17
18impl<S> StreamExt for S
19where
20    S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Sized + Unpin,
21{
22    async fn connect(&mut self, options: &ConnectOptions) -> Result<ConnAck, ConnectionError> {
23        let connect = options.create_connect_from_options();
24
25        self.write_packet(&connect).await?;
26
27        let packet = Packet::async_read(self).await?;
28        if let Packet::ConnAck(con) = packet {
29            if con.reason_code == ConnAckReasonCode::Success {
30                #[cfg(feature = "logs")]
31                trace!("Connected to server");
32                Ok(con)
33            } else {
34                Err(ConnectionError::ConnectionRefused(con.reason_code))
35            }
36        } else {
37            Err(ConnectionError::NotConnAck(packet))
38        }
39    }
40
41    async fn read_packet(&mut self) -> Result<Packet, ConnectionError> {
42        Ok(Packet::async_read(self).await?)
43    }
44
45    async fn write_packet(&mut self, packet: &Packet) -> Result<(), ConnectionError> {
46        match packet.async_write(self).await {
47            Ok(_) => (),
48            Err(err) => {
49                return match err {
50                    crate::packets::error::WriteError::SerializeError(serialize_error) => Err(ConnectionError::SerializationError(serialize_error)),
51                    crate::packets::error::WriteError::IoError(error) => Err(ConnectionError::Io(error)),
52                }
53            }
54        }
55
56        #[cfg(feature = "logs")]
57        trace!("Sending packet {}", packet);
58
59        self.flush().await?;
60        // self.flush_packets().await?;
61
62        Ok(())
63    }
64
65    async fn write_packets(&mut self, packets: &[Packet]) -> Result<(), ConnectionError> {
66        for packet in packets {
67            let _ = packet.async_write(self).await;
68            #[cfg(feature = "logs")]
69            trace!("Sending packet {}", packet);
70        }
71        self.flush_packets().await?;
72        Ok(())
73    }
74
75    fn flush_packets(&mut self) -> impl std::future::Future<Output = std::io::Result<()>> {
76        tokio::io::AsyncWriteExt::flush(self)
77    }
78}