mysql_connector/connection/io.rs

1use {
2    super::{
3        packets::{ErrPacket, OkPacket},
4        Connection, ParseBuf, Serialize, BUFFER_POOL, MAX_PAYLOAD_LEN,
5    },
6    crate::{
7        error::ProtocolError,
8        packets::StmtSendLongData,
9        pool::PoolItem,
10        types::{SimpleValue, Value},
11        utils::read_u32,
12        Deserialize, Error, StreamRequirements, Timeout, TimeoutFuture,
13    },
14    bytes::Buf,
15    std::time::Duration,
16    tokio::io::{AsyncReadExt, AsyncWriteExt},
17};
18
19impl Connection {
20    pub(super) async fn send_long_data<'a, V, I>(
21        &mut self,
22        statement_id: u32,
23        params: I,
24    ) -> Result<(), Error>
25    where
26        V: SimpleValue + 'a,
27        I: Iterator<Item = &'a V>,
28    {
29        for (i, value) in params.enumerate() {
30            if let Value::Bytes(bytes) = value.value() {
31                if bytes.is_empty() {
32                    self.write_command(&StmtSendLongData::new(statement_id, i as u16, &[]))
33                        .await?;
34                } else {
35                    let chunks = bytes.chunks(MAX_PAYLOAD_LEN - 6);
36                    for chunk in chunks {
37                        self.write_command(&StmtSendLongData::new(statement_id, i as u16, chunk))
38                            .await?;
39                    }
40                }
41            }
42        }
43        Ok(())
44    }
45
46    async fn read_chunk_to_buf(
47        stream: &mut dyn StreamRequirements,
48        dst: &mut Vec<u8>,
49        sleep: &dyn Fn(std::time::Duration) -> TimeoutFuture,
50        timeout: Duration,
51    ) -> Result<(u8, bool), Error> {
52        let mut metadata_buf = [0u8; 4];
53        Timeout::new(stream.read_exact(&mut metadata_buf), sleep, timeout).await??;
54        let chunk_len = read_u32(&metadata_buf[..3]) as usize;
55        let seq_id = metadata_buf[3];
56
57        if chunk_len == 0 {
58            return Ok((seq_id, true));
59        }
60
61        let start = dst.len();
62        dst.resize(start + chunk_len, 0);
63        Timeout::new(stream.read_exact(&mut dst[start..]), sleep, timeout).await??;
64
65        if dst.len() % MAX_PAYLOAD_LEN == 0 {
66            Ok((seq_id, false))
67        } else {
68            Ok((seq_id, true))
69        }
70    }
71
72    pub(super) async fn read_packet_to_buf(
73        stream: &mut dyn StreamRequirements,
74        seq_id: &mut u8,
75        dst: &mut Vec<u8>,
76        sleep: &dyn Fn(std::time::Duration) -> TimeoutFuture,
77        timeout: Duration,
78    ) -> Result<(), Error> {
79        loop {
80            let (read_seq_id, last_chunk) =
81                Self::read_chunk_to_buf(stream, dst, sleep, timeout).await?;
82            if *seq_id != read_seq_id {
83                return Err(Error::Protocol(ProtocolError::OutOfSync));
84            }
85
86            *seq_id = seq_id.wrapping_add(1);
87
88            if last_chunk {
89                return Ok(());
90            }
91        }
92    }
93
94    pub(super) async fn read_packet<'b>(&mut self) -> Result<PoolItem<'b, Vec<u8>>, Error> {
95        let mut decode_buf = BUFFER_POOL.get();
96        Self::read_packet_to_buf(
97            &mut self.stream,
98            &mut self.seq_id,
99            decode_buf.as_mut(),
100            self.data.sleep,
101            self.options.timeout(),
102        )
103        .await?;
104        Ok(decode_buf)
105    }
106
107    pub(super) async fn write_packet(&mut self, mut bytes: &[u8]) -> Result<(), Error> {
108        let extra_packet = bytes.remaining() % MAX_PAYLOAD_LEN == 0;
109
110        while bytes.has_remaining() {
111            let chunk_len = usize::min(bytes.remaining(), MAX_PAYLOAD_LEN);
112            Timeout::new(
113                self.stream
114                    .write_u32_le(chunk_len as u32 | (u32::from(self.seq_id) << 24)),
115                self.data.sleep,
116                self.options.timeout(),
117            )
118            .await??;
119            Timeout::new(
120                self.stream.write_all(&bytes[..chunk_len]),
121                self.data.sleep,
122                self.options.timeout(),
123            )
124            .await??;
125            bytes = &bytes[chunk_len..];
126            self.seq_id = self.seq_id.wrapping_add(1);
127        }
128
129        if extra_packet {
130            Timeout::new(
131                self.stream.write_u32_le(u32::from(self.seq_id) << 24),
132                self.data.sleep,
133                self.options.timeout(),
134            )
135            .await??;
136            self.seq_id = self.seq_id.wrapping_add(1);
137        }
138        Ok(())
139    }
140
141    pub(super) async fn write_struct<S: Serialize>(&mut self, x: &S) -> Result<(), Error> {
142        let mut buf = BUFFER_POOL.get();
143        x.serialize(buf.as_mut());
144        self.write_packet(&buf).await
145    }
146
147    pub(super) async fn write_command<S: Serialize>(&mut self, cmd: &S) -> Result<(), Error> {
148        self.cleanup().await?;
149        self.seq_id = 0;
150        self.write_struct(cmd).await
151    }
152
153    pub(crate) async fn decode_response(
154        &mut self,
155        packet: &[u8],
156    ) -> Result<Result<OkPacket, ErrPacket>, Error> {
157        let capabilities = self.data().capabilities();
158        if packet.is_empty() {
159            return Err(ProtocolError::eof().into());
160        }
161        match packet[0] {
162            0x00 => Ok(Ok(OkPacket::read_ok(packet, capabilities)?)),
163            0xFF => Ok(Err(ErrPacket::deserialize(
164                &mut ParseBuf(packet),
165                capabilities,
166            )?)),
167            _ => Err(
168                ProtocolError::unexpected_packet(packet.to_vec(), Some("Ok or Err Packet")).into(),
169            ),
170        }
171    }
172
173    pub(crate) async fn read_response(&mut self) -> Result<Result<OkPacket, ErrPacket>, Error> {
174        let packet = self.read_packet().await?;
175        self.decode_response(&packet).await
176    }
177}