// mysql_connector/connection/io.rs
use {
    super::{
        packets::{ErrPacket, OkPacket},
        Connection, ParseBuf, Serialize, BUFFER_POOL, MAX_PAYLOAD_LEN,
    },
    crate::{
        error::ProtocolError,
        packets::StmtSendLongData,
        pool::PoolItem,
        types::{SimpleValue, Value},
        utils::read_u32,
        Deserialize, Error, StreamRequirements, Timeout, TimeoutFuture,
    },
    bytes::Buf,
    std::time::Duration,
    tokio::io::{AsyncReadExt, AsyncWriteExt},
};

19impl Connection {
20 pub(super) async fn send_long_data<'a, V, I>(
21 &mut self,
22 statement_id: u32,
23 params: I,
24 ) -> Result<(), Error>
25 where
26 V: SimpleValue + 'a,
27 I: Iterator<Item = &'a V>,
28 {
29 for (i, value) in params.enumerate() {
30 if let Value::Bytes(bytes) = value.value() {
31 if bytes.is_empty() {
32 self.write_command(&StmtSendLongData::new(statement_id, i as u16, &[]))
33 .await?;
34 } else {
35 let chunks = bytes.chunks(MAX_PAYLOAD_LEN - 6);
36 for chunk in chunks {
37 self.write_command(&StmtSendLongData::new(statement_id, i as u16, chunk))
38 .await?;
39 }
40 }
41 }
42 }
43 Ok(())
44 }
45
46 async fn read_chunk_to_buf(
47 stream: &mut dyn StreamRequirements,
48 dst: &mut Vec<u8>,
49 sleep: &dyn Fn(std::time::Duration) -> TimeoutFuture,
50 timeout: Duration,
51 ) -> Result<(u8, bool), Error> {
52 let mut metadata_buf = [0u8; 4];
53 Timeout::new(stream.read_exact(&mut metadata_buf), sleep, timeout).await??;
54 let chunk_len = read_u32(&metadata_buf[..3]) as usize;
55 let seq_id = metadata_buf[3];
56
57 if chunk_len == 0 {
58 return Ok((seq_id, true));
59 }
60
61 let start = dst.len();
62 dst.resize(start + chunk_len, 0);
63 Timeout::new(stream.read_exact(&mut dst[start..]), sleep, timeout).await??;
64
65 if dst.len() % MAX_PAYLOAD_LEN == 0 {
66 Ok((seq_id, false))
67 } else {
68 Ok((seq_id, true))
69 }
70 }
71
72 pub(super) async fn read_packet_to_buf(
73 stream: &mut dyn StreamRequirements,
74 seq_id: &mut u8,
75 dst: &mut Vec<u8>,
76 sleep: &dyn Fn(std::time::Duration) -> TimeoutFuture,
77 timeout: Duration,
78 ) -> Result<(), Error> {
79 loop {
80 let (read_seq_id, last_chunk) =
81 Self::read_chunk_to_buf(stream, dst, sleep, timeout).await?;
82 if *seq_id != read_seq_id {
83 return Err(Error::Protocol(ProtocolError::OutOfSync));
84 }
85
86 *seq_id = seq_id.wrapping_add(1);
87
88 if last_chunk {
89 return Ok(());
90 }
91 }
92 }
93
94 pub(super) async fn read_packet<'b>(&mut self) -> Result<PoolItem<'b, Vec<u8>>, Error> {
95 let mut decode_buf = BUFFER_POOL.get();
96 Self::read_packet_to_buf(
97 &mut self.stream,
98 &mut self.seq_id,
99 decode_buf.as_mut(),
100 self.data.sleep,
101 self.options.timeout(),
102 )
103 .await?;
104 Ok(decode_buf)
105 }
106
107 pub(super) async fn write_packet(&mut self, mut bytes: &[u8]) -> Result<(), Error> {
108 let extra_packet = bytes.remaining() % MAX_PAYLOAD_LEN == 0;
109
110 while bytes.has_remaining() {
111 let chunk_len = usize::min(bytes.remaining(), MAX_PAYLOAD_LEN);
112 Timeout::new(
113 self.stream
114 .write_u32_le(chunk_len as u32 | (u32::from(self.seq_id) << 24)),
115 self.data.sleep,
116 self.options.timeout(),
117 )
118 .await??;
119 Timeout::new(
120 self.stream.write_all(&bytes[..chunk_len]),
121 self.data.sleep,
122 self.options.timeout(),
123 )
124 .await??;
125 bytes = &bytes[chunk_len..];
126 self.seq_id = self.seq_id.wrapping_add(1);
127 }
128
129 if extra_packet {
130 Timeout::new(
131 self.stream.write_u32_le(u32::from(self.seq_id) << 24),
132 self.data.sleep,
133 self.options.timeout(),
134 )
135 .await??;
136 self.seq_id = self.seq_id.wrapping_add(1);
137 }
138 Ok(())
139 }
140
141 pub(super) async fn write_struct<S: Serialize>(&mut self, x: &S) -> Result<(), Error> {
142 let mut buf = BUFFER_POOL.get();
143 x.serialize(buf.as_mut());
144 self.write_packet(&buf).await
145 }
146
147 pub(super) async fn write_command<S: Serialize>(&mut self, cmd: &S) -> Result<(), Error> {
148 self.cleanup().await?;
149 self.seq_id = 0;
150 self.write_struct(cmd).await
151 }
152
153 pub(crate) async fn decode_response(
154 &mut self,
155 packet: &[u8],
156 ) -> Result<Result<OkPacket, ErrPacket>, Error> {
157 let capabilities = self.data().capabilities();
158 if packet.is_empty() {
159 return Err(ProtocolError::eof().into());
160 }
161 match packet[0] {
162 0x00 => Ok(Ok(OkPacket::read_ok(packet, capabilities)?)),
163 0xFF => Ok(Err(ErrPacket::deserialize(
164 &mut ParseBuf(packet),
165 capabilities,
166 )?)),
167 _ => Err(
168 ProtocolError::unexpected_packet(packet.to_vec(), Some("Ok or Err Packet")).into(),
169 ),
170 }
171 }
172
173 pub(crate) async fn read_response(&mut self) -> Result<Result<OkPacket, ErrPacket>, Error> {
174 let packet = self.read_packet().await?;
175 self.decode_response(&packet).await
176 }
177}