msql_srv/
commands.rs

1use crate::myc::constants::{CapabilityFlags, Command as CommandByte};
2
3#[derive(Debug)]
4#[allow(dead_code)] // The fields here are read, but only in tests. This keeps clippy quiet.
5pub struct ClientHandshake<'a> {
6    pub capabilities: CapabilityFlags,
7    maxps: u32,
8    collation: u16,
9    pub(crate) username: Option<&'a [u8]>,
10}
11
12pub fn client_handshake(i: &[u8], after_tls: bool) -> nom::IResult<&[u8], ClientHandshake<'_>> {
13    // mysql handshake protocol documentation
14    // https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase_packets_protocol_handshake_response.html
15
16    let (i, cap) = nom::number::complete::le_u16(i)?;
17
18    if CapabilityFlags::from_bits_truncate(cap as u32).contains(CapabilityFlags::CLIENT_PROTOCOL_41)
19    {
20        // HandshakeResponse41
21        let (i, cap2) = nom::number::complete::le_u16(i)?;
22        let cap = (cap2 as u32) << 16 | cap as u32;
23
24        let capabilities = CapabilityFlags::from_bits_truncate(cap);
25
26        let (i, maxps) = nom::number::complete::le_u32(i)?;
27        let (i, collation) = nom::bytes::complete::take(1u8)(i)?;
28        let (i, _) = nom::bytes::complete::take(23u8)(i)?;
29
30        let (i, username) = if after_tls || !capabilities.contains(CapabilityFlags::CLIENT_SSL) {
31            let (i, user) = nom::bytes::complete::take_until(&b"\0"[..])(i)?;
32            let (i, _) = nom::bytes::complete::tag(b"\0")(i)?;
33            (i, Some(user))
34        } else {
35            (i, None)
36        };
37
38        Ok((
39            i,
40            ClientHandshake {
41                capabilities,
42                maxps,
43                collation: u16::from(collation[0]),
44                username,
45            },
46        ))
47    } else {
48        // HandshakeResponse320
49        let (i, maxps1) = nom::number::complete::le_u16(i)?;
50        let (i, maxps2) = nom::number::complete::le_u8(i)?;
51        let maxps = (maxps2 as u32) << 16 | maxps1 as u32;
52        let (i, username) = nom::bytes::complete::take_until(&b"\0"[..])(i)?;
53
54        Ok((
55            i,
56            ClientHandshake {
57                capabilities: CapabilityFlags::from_bits_truncate(cap as u32),
58                maxps,
59                collation: 0,
60                username: Some(username),
61            },
62        ))
63    }
64}
65
/// A client command decoded by [`parse`].
///
/// Byte-slice payloads borrow from the input that was parsed.
#[derive(Debug, Eq, PartialEq)]
pub enum Command<'a> {
    /// `COM_QUERY`: every byte after the command byte (the query text).
    Query(&'a [u8]),
    /// `COM_FIELD_LIST`: every byte after the command byte, undecoded.
    ListFields(&'a [u8]),
    /// `COM_STMT_CLOSE`: the statement id, read as a little-endian `u32`.
    Close(u32),
    /// `COM_STMT_PREPARE`: every byte after the command byte (the statement text).
    Prepare(&'a [u8]),
    /// `COM_INIT_DB`: every byte after the command byte (presumably the schema name).
    Init(&'a [u8]),
    /// `COM_STMT_EXECUTE`, as decoded by [`execute`].
    Execute {
        /// Statement id.
        stmt: u32,
        /// Bytes following the flags and iteration count, left undecoded.
        params: &'a [u8],
    },
    /// `COM_STMT_SEND_LONG_DATA`, as decoded by [`send_long_data`].
    SendLongData {
        /// Statement id.
        stmt: u32,
        /// Parameter number, read as a little-endian `u16`.
        param: u16,
        /// The data chunk: everything after the parameter number.
        data: &'a [u8],
    },
    /// `COM_PING`.
    Ping,
    /// `COM_QUIT`.
    Quit,
}
85
86pub fn execute(i: &[u8]) -> nom::IResult<&[u8], Command<'_>> {
87    let (i, stmt) = nom::number::complete::le_u32(i)?;
88    let (i, _flags) = nom::bytes::complete::take(1u8)(i)?;
89    let (i, _iterations) = nom::number::complete::le_u32(i)?;
90    Ok((&[], Command::Execute { stmt, params: i }))
91}
92
93pub fn send_long_data(i: &[u8]) -> nom::IResult<&[u8], Command<'_>> {
94    let (i, stmt) = nom::number::complete::le_u32(i)?;
95    let (i, param) = nom::number::complete::le_u16(i)?;
96    Ok((
97        &[],
98        Command::SendLongData {
99            stmt,
100            param,
101            data: i,
102        },
103    ))
104}
105
106pub fn parse(i: &[u8]) -> nom::IResult<&[u8], Command<'_>> {
107    use nom::bytes::complete::tag;
108    use nom::combinator::{map, rest};
109    use nom::sequence::preceded;
110    nom::branch::alt((
111        map(
112            preceded(tag(&[CommandByte::COM_QUERY as u8]), rest),
113            Command::Query,
114        ),
115        map(
116            preceded(tag(&[CommandByte::COM_FIELD_LIST as u8]), rest),
117            Command::ListFields,
118        ),
119        map(
120            preceded(tag(&[CommandByte::COM_INIT_DB as u8]), rest),
121            Command::Init,
122        ),
123        map(
124            preceded(tag(&[CommandByte::COM_STMT_PREPARE as u8]), rest),
125            Command::Prepare,
126        ),
127        preceded(tag(&[CommandByte::COM_STMT_EXECUTE as u8]), execute),
128        preceded(
129            tag(&[CommandByte::COM_STMT_SEND_LONG_DATA as u8]),
130            send_long_data,
131        ),
132        map(
133            preceded(
134                tag(&[CommandByte::COM_STMT_CLOSE as u8]),
135                nom::number::complete::le_u32,
136            ),
137            Command::Close,
138        ),
139        map(tag(&[CommandByte::COM_QUIT as u8]), |_| Command::Quit),
140        map(tag(&[CommandByte::COM_PING as u8]), |_| Command::Ping),
141    ))(i)
142}
143
#[cfg(test)]
mod tests {
    use super::*;
    use crate::myc::constants::{CapabilityFlags, UTF8_GENERAL_CI};
    use crate::packet::PacketConn;
    use std::io::Cursor;

    /// A framed `HandshakeResponse41` with `CLIENT_SSL` set and username "jon".
    const SSL_HANDSHAKE: &[u8] = &[
        0x25, 0x00, 0x00, 0x01, 0x85, 0xae, 0x3f, 0x20, 0x00, 0x00, 0x00, 0x01, 0x21, 0x00,
        0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
        0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x6a, 0x6f, 0x6e, 0x00, 0x00, 0x05,
    ];

    /// Runs `data` through `PacketConn` and returns the payload of the first packet.
    fn first_packet(data: &[u8]) -> Vec<u8> {
        let mut conn = PacketConn::new(Cursor::new(data.to_vec()));
        let (_, packet) = conn.next().unwrap().unwrap();
        packet.to_vec()
    }

    /// Capability assertions shared by every 4.1 handshake fixture in this module.
    fn assert_fixture_capabilities(capabilities: CapabilityFlags) {
        assert!(capabilities.contains(CapabilityFlags::CLIENT_LONG_PASSWORD));
        assert!(capabilities.contains(CapabilityFlags::CLIENT_MULTI_RESULTS));
        assert!(!capabilities.contains(CapabilityFlags::CLIENT_CONNECT_WITH_DB));
        assert!(!capabilities.contains(CapabilityFlags::CLIENT_DEPRECATE_EOF));
    }

    /// Builds an unframed command body: the command byte followed by `body`.
    fn command_packet(byte: CommandByte, body: &[u8]) -> Vec<u8> {
        let mut packet = vec![byte as u8];
        packet.extend_from_slice(body);
        packet
    }

    #[test]
    fn it_parses_handshake() {
        let p = first_packet(&[
            0x25, 0x00, 0x00, 0x01, 0x85, 0xa6, 0x3f, 0x20, 0x00, 0x00, 0x00, 0x01, 0x21, 0x00,
            0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
            0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x6a, 0x6f, 0x6e, 0x00, 0x00,
        ]);
        let (_, handshake) = client_handshake(&p, false).unwrap();
        assert_fixture_capabilities(handshake.capabilities);
        assert_eq!(handshake.collation, UTF8_GENERAL_CI);
        assert_eq!(handshake.username.unwrap(), &b"jon"[..]);
        assert_eq!(handshake.maxps, 16_777_216);
    }

    #[test]
    fn it_parses_handshake_with_ssl_enabled() {
        let p = first_packet(SSL_HANDSHAKE);
        let (_, handshake) = client_handshake(&p, false).unwrap();
        assert_fixture_capabilities(handshake.capabilities);
        assert!(handshake.capabilities.contains(CapabilityFlags::CLIENT_SSL));
        assert_eq!(handshake.collation, UTF8_GENERAL_CI);
        assert_eq!(handshake.username, None);
        assert_eq!(handshake.maxps, 16_777_216);
    }

    #[test]
    fn it_parses_handshake_after_ssl() {
        let p = first_packet(SSL_HANDSHAKE);
        let (_, handshake) = client_handshake(&p, true).unwrap();
        assert_fixture_capabilities(handshake.capabilities);
        assert!(handshake.capabilities.contains(CapabilityFlags::CLIENT_SSL));
        assert_eq!(handshake.collation, UTF8_GENERAL_CI);
        assert_eq!(handshake.username.unwrap(), &b"jon"[..]);
        assert_eq!(handshake.maxps, 16_777_216);
    }

    #[test]
    fn it_parses_handshake_320() {
        // 16-bit capabilities without CLIENT_PROTOCOL_41, a 3-byte max packet size of
        // 0x010000, then "jon\0" and an empty auth response.
        let data: &[u8] = &[0x01, 0x00, 0x00, 0x00, 0x01, 0x6a, 0x6f, 0x6e, 0x00, 0x00];
        let (_, handshake) = client_handshake(data, false).unwrap();
        assert!(handshake
            .capabilities
            .contains(CapabilityFlags::CLIENT_LONG_PASSWORD));
        assert!(!handshake
            .capabilities
            .contains(CapabilityFlags::CLIENT_PROTOCOL_41));
        assert_eq!(handshake.username, Some(&b"jon"[..]));
        assert_eq!(handshake.maxps, 0x01_0000);
        assert_eq!(handshake.collation, 0);
    }

    #[test]
    fn it_parses_request() {
        let p = first_packet(&[
            0x21, 0x00, 0x00, 0x00, 0x03, 0x73, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x20, 0x40, 0x40,
            0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x5f, 0x63, 0x6f, 0x6d, 0x6d, 0x65, 0x6e,
            0x74, 0x20, 0x6c, 0x69, 0x6d, 0x69, 0x74, 0x20, 0x31,
        ]);
        let (_, cmd) = parse(&p).unwrap();
        assert_eq!(
            cmd,
            Command::Query(&b"select @@version_comment limit 1"[..])
        );
    }

    #[test]
    fn it_handles_list_fields() {
        // mysql_list_fields (CommandByte::COM_FIELD_LIST / 0x04) has been deprecated in mysql 5.7 and will be removed
        // in a future version. The mysql command line tool issues one of these commands after
        // switching databases with USE <DB>.
        let p = first_packet(&[
            0x21, 0x00, 0x00, 0x00, 0x04, 0x73, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x20, 0x40, 0x40,
            0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x5f, 0x63, 0x6f, 0x6d, 0x6d, 0x65, 0x6e,
            0x74, 0x20, 0x6c, 0x69, 0x6d, 0x69, 0x74, 0x20, 0x31,
        ]);
        let (_, cmd) = parse(&p).unwrap();
        assert_eq!(
            cmd,
            Command::ListFields(&b"select @@version_comment limit 1"[..])
        );
    }

    #[test]
    fn it_parses_init_db_and_prepare() {
        let init = command_packet(CommandByte::COM_INIT_DB, b"db");
        let (rest, cmd) = parse(&init).unwrap();
        assert!(rest.is_empty());
        assert_eq!(cmd, Command::Init(&b"db"[..]));

        let prepare = command_packet(CommandByte::COM_STMT_PREPARE, b"SELECT 1");
        let (rest, cmd) = parse(&prepare).unwrap();
        assert!(rest.is_empty());
        assert_eq!(cmd, Command::Prepare(&b"SELECT 1"[..]));
    }

    #[test]
    fn it_parses_execute() {
        // stmt id 1, flags 0, iteration count 1, then two undecoded parameter bytes.
        let data = command_packet(
            CommandByte::COM_STMT_EXECUTE,
            &[0x01, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0xaa, 0xbb],
        );
        let (rest, cmd) = parse(&data).unwrap();
        assert!(rest.is_empty());
        assert_eq!(
            cmd,
            Command::Execute {
                stmt: 1,
                params: &[0xaa, 0xbb][..],
            }
        );
    }

    #[test]
    fn it_parses_send_long_data() {
        // stmt id 7, param 2, then the data chunk "abc".
        let data = command_packet(
            CommandByte::COM_STMT_SEND_LONG_DATA,
            &[0x07, 0x00, 0x00, 0x00, 0x02, 0x00, b'a', b'b', b'c'],
        );
        let (rest, cmd) = parse(&data).unwrap();
        assert!(rest.is_empty());
        assert_eq!(
            cmd,
            Command::SendLongData {
                stmt: 7,
                param: 2,
                data: &b"abc"[..],
            }
        );
    }

    #[test]
    fn it_parses_close_ping_and_quit() {
        let close = command_packet(CommandByte::COM_STMT_CLOSE, &[0x05, 0x00, 0x00, 0x00]);
        let (rest, cmd) = parse(&close).unwrap();
        assert!(rest.is_empty());
        assert_eq!(cmd, Command::Close(5));

        let ping = command_packet(CommandByte::COM_PING, &[]);
        assert_eq!(parse(&ping).unwrap().1, Command::Ping);

        let quit = command_packet(CommandByte::COM_QUIT, &[]);
        assert_eq!(parse(&quit).unwrap().1, Command::Quit);
    }

    #[test]
    fn it_rejects_unknown_or_truncated_commands() {
        assert!(parse(b"").is_err());
        assert!(parse(&[0xffu8]).is_err());

        let short_close = command_packet(CommandByte::COM_STMT_CLOSE, &[0x05]);
        assert!(parse(&short_close).is_err());

        let short_execute = command_packet(CommandByte::COM_STMT_EXECUTE, &[0x01, 0x00]);
        assert!(parse(&short_execute).is_err());
    }
}