1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
use {
    super::{
        packets::{ErrPacket, OkPacket},
        Connection, ParseBuf, Serialize, Socket, BUFFER_POOL, MAX_PAYLOAD_LEN,
    },
    crate::{
        error::ProtocolError,
        packets::StmtSendLongData,
        pool::PoolItem,
        types::{SimpleValue, Value},
        utils::read_u32,
        Deserialize, Error,
    },
    bytes::Buf,
    tokio::io::{AsyncReadExt, AsyncWriteExt},
};

impl<T: Socket> Connection<T> {
    /// Streams every byte-valued parameter of a prepared statement to the
    /// server using `StmtSendLongData` commands.
    ///
    /// `params` is walked in order; a parameter's position in the iterator is
    /// used as its parameter id. Only parameters whose `value()` is
    /// `Value::Bytes` are sent — all other values are skipped here.
    ///
    /// * An empty byte value produces exactly one command with an empty payload.
    /// * A non-empty byte value is split into pieces of at most
    ///   `MAX_PAYLOAD_LEN - 6` bytes, one command per piece.
    ///
    /// # Errors
    ///
    /// Returns the first error reported by `write_command`; parameters after
    /// the failing one are not sent.
    pub(super) async fn send_long_data<'a, V, I>(
        &mut self,
        statement_id: u32,
        params: I,
    ) -> Result<(), Error>
    where
        V: SimpleValue + 'a,
        I: Iterator<Item = &'a V>,
    {
        for (index, param) in params.enumerate() {
            if let Value::Bytes(payload) = param.value() {
                // NOTE: the cast silently truncates indices above `u16::MAX`.
                let param_id = index as u16;

                if payload.is_empty() {
                    // `chunks` yields nothing for an empty slice, so the empty
                    // value has to be announced explicitly.
                    let cmd = StmtSendLongData::new(statement_id, param_id, &[]);
                    self.write_command(&cmd).await?;
                    continue;
                }

                for piece in payload.chunks(MAX_PAYLOAD_LEN - 6) {
                    let cmd = StmtSendLongData::new(statement_id, param_id, piece);
                    self.write_command(&cmd).await?;
                }
            }
        }
        Ok(())
    }

    /// Reads a single wire chunk (4-byte header plus payload) from `socket`
    /// and appends its payload to `dst`.
    ///
    /// The header is read as four bytes: the first three are decoded with
    /// `read_u32` as the payload length, the fourth is the sequence id.
    ///
    /// Returns `(sequence_id, is_last_chunk)`. A chunk is the last one of a
    /// logical packet unless its payload length is exactly `MAX_PAYLOAD_LEN`;
    /// this mirrors `write_packet`, which emits full-sized chunks followed by
    /// a shorter (possibly empty) terminating chunk.
    ///
    /// The decision is based on the length of *this* chunk rather than on
    /// `dst.len()`, so it is also correct when `dst` already holds data
    /// before the first chunk is appended.
    ///
    /// # Errors
    ///
    /// Propagates any I/O error from `read_exact`. If reading the payload
    /// fails, `dst` has already been grown by the announced chunk length.
    async fn read_chunk_to_buf(socket: &mut T, dst: &mut Vec<u8>) -> Result<(u8, bool), Error> {
        let mut header = [0u8; 4];
        socket.read_exact(&mut header).await?;
        let chunk_len = read_u32(&header[..3]) as usize;
        let seq_id = header[3];

        // A zero-length chunk carries no payload and always ends the packet.
        if chunk_len == 0 {
            return Ok((seq_id, true));
        }

        let start = dst.len();
        dst.resize(start + chunk_len, 0);
        socket.read_exact(&mut dst[start..]).await?;

        // Only a maximally sized chunk announces a continuation.
        Ok((seq_id, chunk_len != MAX_PAYLOAD_LEN))
    }

    /// Reads one complete logical packet from `socket`, appending the payload
    /// of every chunk to `dst`.
    ///
    /// `seq_id` is the sequence id expected for the next chunk. After each
    /// chunk whose id matches, it is advanced by one (wrapping on overflow).
    ///
    /// # Errors
    ///
    /// * `Error::Protocol(ProtocolError::OutOfSync)` if a chunk's sequence id
    ///   differs from the expected one. `seq_id` is left unchanged for that
    ///   chunk, but the chunk's payload has already been appended to `dst`.
    /// * Any error returned by `read_chunk_to_buf`.
    pub(super) async fn read_packet_to_buf(
        socket: &mut T,
        seq_id: &mut u8,
        dst: &mut Vec<u8>,
    ) -> Result<(), Error> {
        loop {
            let (wire_seq_id, is_final) = Self::read_chunk_to_buf(socket, dst).await?;

            if wire_seq_id != *seq_id {
                return Err(Error::Protocol(ProtocolError::OutOfSync));
            }
            *seq_id = seq_id.wrapping_add(1);

            if is_final {
                break Ok(());
            }
        }
    }

    /// Reads one complete logical packet from the connection's socket into a
    /// buffer obtained from `BUFFER_POOL` and returns that buffer.
    ///
    /// The connection's `seq_id` is advanced for every chunk read.
    ///
    /// # Errors
    ///
    /// Propagates any error from `read_packet_to_buf`.
    pub(super) async fn read_packet<'b>(&mut self) -> Result<PoolItem<'b, Vec<u8>>, Error> {
        let mut packet = BUFFER_POOL.get();
        {
            let dst: &mut Vec<u8> = packet.as_mut();
            Self::read_packet_to_buf(&mut self.socket, &mut self.seq_id, dst).await?;
        }
        Ok(packet)
    }

    /// Writes `bytes` to the socket as one logical packet, split into chunks
    /// of at most `MAX_PAYLOAD_LEN` bytes.
    ///
    /// Every chunk is preceded by a little-endian `u32` header: the chunk
    /// length OR-ed with the current sequence id shifted into the top byte.
    /// The connection's `seq_id` is incremented (wrapping) after each chunk.
    ///
    /// When the total length is a multiple of `MAX_PAYLOAD_LEN` — which
    /// includes an empty `bytes` — an additional header with length zero is
    /// written after the data chunks.
    ///
    /// # Errors
    ///
    /// Propagates the first I/O error from the socket; `seq_id` is not
    /// advanced for the chunk that failed.
    pub(super) async fn write_packet(&mut self, mut bytes: &[u8]) -> Result<(), Error> {
        // Must be decided before the loop consumes `bytes`.
        let needs_empty_trailer = bytes.remaining() % MAX_PAYLOAD_LEN == 0;

        while bytes.has_remaining() {
            let frame_len = bytes.remaining().min(MAX_PAYLOAD_LEN);
            let (frame, rest) = bytes.split_at(frame_len);

            let header = (u32::from(self.seq_id) << 24) | frame_len as u32;
            self.socket.write_u32_le(header).await?;
            self.socket.write_all(frame).await?;

            bytes = rest;
            self.seq_id = self.seq_id.wrapping_add(1);
        }

        if needs_empty_trailer {
            // Length bits are all zero; only the sequence id is set.
            let header = u32::from(self.seq_id) << 24;
            self.socket.write_u32_le(header).await?;
            self.seq_id = self.seq_id.wrapping_add(1);
        }

        Ok(())
    }

    /// Serializes `x` into a buffer taken from `BUFFER_POOL` and sends the
    /// result with `write_packet`.
    ///
    /// # Errors
    ///
    /// Propagates any error from `write_packet`.
    pub(super) async fn write_struct<S: Serialize>(&mut self, x: &S) -> Result<(), Error> {
        let mut scratch = BUFFER_POOL.get();
        x.serialize(scratch.as_mut());

        let payload: &[u8] = &scratch;
        self.write_packet(payload).await
    }

    /// Sends `cmd` as a new command.
    ///
    /// Runs `cleanup` first, then resets the connection's sequence id to zero
    /// before the command is serialized and written via `write_struct`.
    ///
    /// # Errors
    ///
    /// Propagates errors from `cleanup` (in which case nothing is written and
    /// the sequence id is not reset) and from `write_struct`.
    pub(super) async fn write_command<S: Serialize>(&mut self, cmd: &S) -> Result<(), Error> {
        self.cleanup().await?;

        // Every command begins its own sequence.
        self.seq_id = 0;
        self.write_struct(cmd).await?;
        Ok(())
    }

    /// Interprets `packet` as either an OK or an ERR response, dispatching on
    /// its first byte.
    ///
    /// * `0x00` — parsed with `OkPacket::read_ok`, returned as `Ok(Ok(_))`.
    /// * `0xFF` — parsed with `ErrPacket::deserialize`, returned as `Ok(Err(_))`.
    ///
    /// Both parsers receive the capabilities reported by `self.data()`.
    ///
    /// # Errors
    ///
    /// * `ProtocolError::eof()` if `packet` is empty.
    /// * `ProtocolError::unexpected_packet` (carrying a copy of `packet`) if
    ///   the first byte is neither `0x00` nor `0xFF`.
    /// * Any error produced while parsing the OK or ERR packet.
    pub(crate) async fn decode_response(
        &mut self,
        packet: &[u8],
    ) -> Result<Result<OkPacket, ErrPacket>, Error> {
        let capabilities = self.data().capabilities();

        match packet.first().copied() {
            None => Err(ProtocolError::eof().into()),
            Some(0x00) => {
                let ok = OkPacket::read_ok(packet, capabilities)?;
                Ok(Ok(ok))
            }
            Some(0xFF) => {
                let err = ErrPacket::deserialize(&mut ParseBuf(packet), capabilities)?;
                Ok(Err(err))
            }
            Some(_) => {
                let unexpected =
                    ProtocolError::unexpected_packet(packet.to_vec(), Some("Ok or Err Packet"));
                Err(unexpected.into())
            }
        }
    }

    pub(crate) async fn read_response(&mut self) -> Result<Result<OkPacket, ErrPacket>, Error> {
        let packet = self.read_packet().await?;
        self.decode_response(&packet).await
    }
}