1use tokio::io::AsyncWriteExt;
2
3#[cfg(feature = "logs")]
4use tracing::trace;
5
6use crate::packets::ConnAck;
7use crate::packets::{ConnAckReasonCode, Packet};
8use crate::{connect_options::ConnectOptions, error::ConnectionError};
9
10pub(crate) trait StreamExt {
11 fn connect(&mut self, options: &ConnectOptions) -> impl std::future::Future<Output = Result<ConnAck, ConnectionError>>;
12 fn read_packet(&mut self) -> impl std::future::Future<Output = Result<Packet, ConnectionError>>;
13 fn write_packet(&mut self, packet: &Packet) -> impl std::future::Future<Output = Result<(), ConnectionError>>;
14 fn write_packets(&mut self, packets: &[Packet]) -> impl std::future::Future<Output = Result<(), ConnectionError>>;
15 fn flush_packets(&mut self) -> impl std::future::Future<Output = std::io::Result<()>>;
16}
17
18impl<S> StreamExt for S
19where
20 S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Sized + Unpin,
21{
22 async fn connect(&mut self, options: &ConnectOptions) -> Result<ConnAck, ConnectionError> {
23 let connect = options.create_connect_from_options();
24
25 self.write_packet(&connect).await?;
26
27 let packet = Packet::async_read(self).await?;
28 if let Packet::ConnAck(con) = packet {
29 if con.reason_code == ConnAckReasonCode::Success {
30 #[cfg(feature = "logs")]
31 trace!("Connected to server");
32 Ok(con)
33 } else {
34 Err(ConnectionError::ConnectionRefused(con.reason_code))
35 }
36 } else {
37 Err(ConnectionError::NotConnAck(packet))
38 }
39 }
40
41 async fn read_packet(&mut self) -> Result<Packet, ConnectionError> {
42 Ok(Packet::async_read(self).await?)
43 }
44
45 async fn write_packet(&mut self, packet: &Packet) -> Result<(), ConnectionError> {
46 match packet.async_write(self).await {
47 Ok(_) => (),
48 Err(err) => {
49 return match err {
50 crate::packets::error::WriteError::SerializeError(serialize_error) => Err(ConnectionError::SerializationError(serialize_error)),
51 crate::packets::error::WriteError::IoError(error) => Err(ConnectionError::Io(error)),
52 }
53 }
54 }
55
56 #[cfg(feature = "logs")]
57 trace!("Sending packet {}", packet);
58
59 self.flush().await?;
60 Ok(())
63 }
64
65 async fn write_packets(&mut self, packets: &[Packet]) -> Result<(), ConnectionError> {
66 for packet in packets {
67 let _ = packet.async_write(self).await;
68 #[cfg(feature = "logs")]
69 trace!("Sending packet {}", packet);
70 }
71 self.flush_packets().await?;
72 Ok(())
73 }
74
75 fn flush_packets(&mut self) -> impl std::future::Future<Output = std::io::Result<()>> {
76 tokio::io::AsyncWriteExt::flush(self)
77 }
78}