rbdc_mysql/protocol/
packet.rs

1use std::ops::{Deref, DerefMut};
2
3use crate::protocol::response::{EofPacket, OkPacket};
4use crate::protocol::Capabilities;
5use bytes::Bytes;
6use rbdc::io::{Decode, Encode};
7use rbdc::Error;
8use std::cmp::min;
9
/// A single MySQL protocol packet wrapping a payload of type `T`.
///
/// When `T` implements `Encode`, the wrapper emits the packet header(s) around the
/// encoded payload; when `T` is `Bytes`, the wrapper exposes decoding helpers for a
/// payload that has already been read.
#[derive(Debug)]
pub struct Packet<T>(
    /// The wrapped payload.
    pub T,
);
12
13impl<'en, 'stream, T> Encode<'stream, (Capabilities, &'stream mut u8)> for Packet<T>
14where
15    T: Encode<'en, Capabilities>,
16{
17    fn encode_with(
18        &self,
19        buf: &mut Vec<u8>,
20        (capabilities, sequence_id): (Capabilities, &'stream mut u8),
21    ) {
22        let mut next_header = |len: u32| {
23            let mut buf = len.to_le_bytes();
24            buf[3] = *sequence_id;
25            *sequence_id = sequence_id.wrapping_add(1);
26
27            buf
28        };
29
30        // reserve space to write the prefixed length
31        let offset = buf.len();
32        buf.extend(&[0_u8; 4]);
33
34        // encode the payload
35        self.0.encode_with(buf, capabilities);
36
37        // determine the length of the encoded payload
38        // and write to our reserved space
39        let len = buf.len() - offset - 4;
40        let header = &mut buf[offset..];
41
42        header[..4].copy_from_slice(&next_header(min(len, 0xFF_FF_FF) as u32));
43
44        // add more packets if we need to split the data
45        if len >= 0xFF_FF_FF {
46            let rest = buf.split_off(offset + 4 + 0xFF_FF_FF);
47            let mut chunks = rest.chunks_exact(0xFF_FF_FF);
48
49            for chunk in chunks.by_ref() {
50                buf.reserve(chunk.len() + 4);
51                buf.extend(&next_header(chunk.len() as u32));
52                buf.extend(chunk);
53            }
54
55            // this will also handle adding a zero sized packet if the data size is a multiple of 0xFF_FF_FF
56            let remainder = chunks.remainder();
57            buf.reserve(remainder.len() + 4);
58            buf.extend(&next_header(remainder.len() as u32));
59            buf.extend(remainder);
60        }
61    }
62}
63
64impl Packet<Bytes> {
65    pub fn decode<'de, T>(self) -> Result<T, Error>
66    where
67        T: Decode<'de, ()>,
68    {
69        self.decode_with(())
70    }
71
72    pub fn decode_with<'de, T, C>(self, context: C) -> Result<T, Error>
73    where
74        T: Decode<'de, C>,
75    {
76        T::decode_with(self.0, context)
77    }
78
79    pub fn ok(self) -> Result<OkPacket, Error> {
80        self.decode()
81    }
82
83    pub fn eof(self, capabilities: Capabilities) -> Result<EofPacket, Error> {
84        if capabilities.contains(Capabilities::DEPRECATE_EOF) {
85            let ok = self.ok()?;
86
87            Ok(EofPacket {
88                warnings: ok.warnings,
89                status: ok.status,
90            })
91        } else {
92            self.decode_with(capabilities)
93        }
94    }
95}
96
97impl Deref for Packet<Bytes> {
98    type Target = Bytes;
99
100    fn deref(&self) -> &Bytes {
101        &self.0
102    }
103}
104
105impl DerefMut for Packet<Bytes> {
106    fn deref_mut(&mut self) -> &mut Bytes {
107        &mut self.0
108    }
109}