mysql_connector/connection/packets/err.rs

1use {
2    crate::{
3        bitflags::CapabilityFlags,
4        connection::{Deserialize, ParseBuf},
5        error::ProtocolError,
6    },
7    std::fmt,
8};
9
10#[derive(Debug)]
11pub enum ErrPacket {
12    Error(ErrorPacket),
13    Progress(ProgressPacket),
14}
15
16impl Deserialize<'_> for ErrPacket {
17    const SIZE: Option<usize> = None;
18    type Ctx = CapabilityFlags;
19
20    fn deserialize(buf: &mut ParseBuf<'_>, capabilities: Self::Ctx) -> Result<Self, ProtocolError> {
21        buf.check_len(3)?;
22        if buf.eat_u8() != 0xFF {
23            return Err(ProtocolError::unexpected_packet(
24                buf.0.to_vec(),
25                Some("Err Packet"),
26            ));
27        }
28        let code = buf.eat_u16();
29
30        if code == 0xFFFF && capabilities.contains(CapabilityFlags::PROGRESS_OBSOLETE) {
31            buf.parse(()).map(ErrPacket::Progress)
32        } else {
33            buf.parse((code, capabilities.contains(CapabilityFlags::PROTOCOL_41)))
34                .map(ErrPacket::Error)
35        }
36    }
37}
38
/// A regular (non-progress) ERR packet sent by the server.
///
/// Fields are private; use the accessor methods to read them.
pub struct ErrorPacket {
    /// Server error code.
    code: u16,
    /// Five-byte SQL state; `None` when the packet was parsed without `PROTOCOL_41`.
    state: Option<[u8; 5]>,
    /// Human-readable error message.
    message: String,
}

impl ErrorPacket {
    /// Returns the server error code.
    pub fn code(&self) -> u16 {
        self.code
    }

    /// Returns the raw SQL-state bytes, if the packet carried a state.
    pub fn state(&self) -> Option<&[u8; 5]> {
        self.state.as_ref()
    }

    /// Returns the SQL state as text, if the packet carried a state.
    ///
    /// Deserialization already rejects non-UTF-8 state bytes, so for parsed packets this is
    /// `Some` exactly when [`state`](Self::state) is. The conversion is checked rather than
    /// `unsafe` so that a value constructed without going through parsing yields `None`
    /// instead of undefined behavior.
    pub fn state_str(&self) -> Option<&str> {
        self.state
            .as_ref()
            .and_then(|state| std::str::from_utf8(state).ok())
    }

    /// Returns the human-readable error message.
    pub fn message(&self) -> &str {
        &self.message
    }
}

impl fmt::Debug for ErrorPacket {
    // Hand-written (instead of derived) so the SQL state prints as text rather than a byte
    // array, and is omitted entirely when absent.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut format = f.debug_struct("ErrorPacket");
        format.field("code", &self.code);
        if let Some(state) = self.state_str() {
            format.field("state", &state);
        }
        format.field("message", &self.message);
        format.finish()
    }
}
77
78impl Deserialize<'_> for ErrorPacket {
79    const SIZE: Option<usize> = None;
80    type Ctx = (u16, bool);
81
82    fn deserialize(
83        buf: &mut crate::connection::ParseBuf<'_>,
84        (code, protocol_41): Self::Ctx,
85    ) -> Result<Self, ProtocolError> {
86        let state = if protocol_41 {
87            if buf.checked_eat_u8()? != b'#' {
88                return Err(ProtocolError::invalid_packet(
89                    buf.0.to_vec(),
90                    "Err",
91                    "missing state",
92                ));
93            }
94            let state = unsafe { *(buf.checked_eat(5)? as *const _ as *const [u8; 5]) };
95            std::str::from_utf8(&state)?;
96            Some(state)
97        } else {
98            None
99        };
100        Ok(ErrorPacket {
101            code,
102            state,
103            message: String::from_utf8(buf.eat_all().to_owned())?,
104        })
105    }
106}
107
/// A progress report delivered through the ERR-packet channel (error code `0xFFFF`).
#[derive(Debug)]
pub struct ProgressPacket {
    /// Current stage, as reported by the server.
    stage: u8,
    /// Maximum stage, as reported by the server.
    max_stage: u8,
    /// Progress value read from a 3-byte integer.
    /// NOTE(review): units are not established in this module — confirm against the server docs.
    progress: u32,
    /// Raw stage-info bytes (length-encoded on the wire; not decoded as text here).
    stage_info: Vec<u8>,
}

impl ProgressPacket {
    /// Current stage number.
    pub fn stage(&self) -> u8 {
        self.stage
    }

    /// Maximum stage number.
    pub fn max_stage(&self) -> u8 {
        self.max_stage
    }

    /// Raw progress value (see the field note about units).
    pub fn progress(&self) -> u32 {
        self.progress
    }

    /// Raw stage-info bytes.
    pub fn stage_info(&self) -> &Vec<u8> {
        &self.stage_info
    }
}
133
134impl Deserialize<'_> for ProgressPacket {
135    const SIZE: Option<usize> = None;
136    type Ctx = ();
137
138    fn deserialize(
139        buf: &mut crate::connection::ParseBuf<'_>,
140        _ctx: Self::Ctx,
141    ) -> Result<Self, ProtocolError> {
142        buf.check_len(6)?;
143        buf.skip(1);
144        Ok(ProgressPacket {
145            stage: buf.eat_u8(),
146            max_stage: buf.eat_u8(),
147            progress: buf.eat_u24(),
148            stage_info: buf.checked_eat_lenenc_slice()?.to_vec(),
149        })
150    }
151}