1use crate::myc::constants::{CapabilityFlags, Command as CommandByte};
2
/// Fields decoded from a client's handshake response by [`client_handshake`].
///
/// The same type also represents the short SSL request a client sends before
/// upgrading to TLS; in that case `username` is `None`.
#[derive(Debug)]
// NOTE(review): presumably silences `dead_code` for the private `maxps` and
// `collation` fields, which are only read by this module's tests — confirm
// before removing.
#[allow(dead_code)] pub struct ClientHandshake<'a> {
    /// Capability flags sent by the client: all 32 bits for protocol-4.1
    /// clients, only the low 16 bits for legacy (3.20) clients.
    pub capabilities: CapabilityFlags,
    /// Maximum packet size the client is willing to accept (4 bytes on the
    /// wire for 4.1 clients, 3 bytes for legacy clients).
    maxps: u32,
    /// Collation/character-set id sent by the client as a single byte; 0 for
    /// legacy clients, which do not send one.
    collation: u16,
    /// Login name (without its NUL terminator), borrowed from the packet;
    /// `None` when the packet was an SSL request parsed before the TLS upgrade.
    pub(crate) username: Option<&'a [u8]>,
}
12pub fn client_handshake(i: &[u8], after_tls: bool) -> nom::IResult<&[u8], ClientHandshake<'_>> {
13 let (i, cap) = nom::number::complete::le_u16(i)?;
17
18 if CapabilityFlags::from_bits_truncate(cap as u32).contains(CapabilityFlags::CLIENT_PROTOCOL_41)
19 {
20 let (i, cap2) = nom::number::complete::le_u16(i)?;
22 let cap = (cap2 as u32) << 16 | cap as u32;
23
24 let capabilities = CapabilityFlags::from_bits_truncate(cap);
25
26 let (i, maxps) = nom::number::complete::le_u32(i)?;
27 let (i, collation) = nom::bytes::complete::take(1u8)(i)?;
28 let (i, _) = nom::bytes::complete::take(23u8)(i)?;
29
30 let (i, username) = if after_tls || !capabilities.contains(CapabilityFlags::CLIENT_SSL) {
31 let (i, user) = nom::bytes::complete::take_until(&b"\0"[..])(i)?;
32 let (i, _) = nom::bytes::complete::tag(b"\0")(i)?;
33 (i, Some(user))
34 } else {
35 (i, None)
36 };
37
38 Ok((
39 i,
40 ClientHandshake {
41 capabilities,
42 maxps,
43 collation: u16::from(collation[0]),
44 username,
45 },
46 ))
47 } else {
48 let (i, maxps1) = nom::number::complete::le_u16(i)?;
50 let (i, maxps2) = nom::number::complete::le_u8(i)?;
51 let maxps = (maxps2 as u32) << 16 | maxps1 as u32;
52 let (i, username) = nom::bytes::complete::take_until(&b"\0"[..])(i)?;
53
54 Ok((
55 i,
56 ClientHandshake {
57 capabilities: CapabilityFlags::from_bits_truncate(cap as u32),
58 maxps,
59 collation: 0,
60 username: Some(username),
61 },
62 ))
63 }
64}
65
/// A client command packet decoded by [`parse`].
///
/// Byte-slice payloads borrow from the packet buffer that was passed to `parse`.
#[derive(Debug, PartialEq, Eq)]
pub enum Command<'a> {
    /// `COM_QUERY`: every byte after the command byte (the query text).
    Query(&'a [u8]),
    /// `COM_FIELD_LIST`: every byte after the command byte, not decoded further.
    ListFields(&'a [u8]),
    /// `COM_STMT_CLOSE`: the statement id, read as a 4-byte little-endian integer.
    Close(u32),
    /// `COM_STMT_PREPARE`: every byte after the command byte (the statement text).
    Prepare(&'a [u8]),
    /// `COM_INIT_DB`: every byte after the command byte, not decoded further.
    Init(&'a [u8]),
    /// `COM_STMT_EXECUTE`, as decoded by [`execute`].
    Execute {
        /// Statement id (4-byte little-endian).
        stmt: u32,
        /// Raw bytes following the flags byte and iteration count; not decoded here.
        params: &'a [u8],
    },
    /// `COM_STMT_SEND_LONG_DATA`, as decoded by [`send_long_data`].
    SendLongData {
        /// Statement id (4-byte little-endian).
        stmt: u32,
        /// Parameter number sent by the client (2-byte little-endian).
        param: u16,
        /// All remaining bytes of the packet.
        data: &'a [u8],
    },
    /// `COM_PING`; no payload is decoded.
    Ping,
    /// `COM_QUIT`; no payload is decoded.
    Quit,
}
85
86pub fn execute(i: &[u8]) -> nom::IResult<&[u8], Command<'_>> {
87 let (i, stmt) = nom::number::complete::le_u32(i)?;
88 let (i, _flags) = nom::bytes::complete::take(1u8)(i)?;
89 let (i, _iterations) = nom::number::complete::le_u32(i)?;
90 Ok((&[], Command::Execute { stmt, params: i }))
91}
92
93pub fn send_long_data(i: &[u8]) -> nom::IResult<&[u8], Command<'_>> {
94 let (i, stmt) = nom::number::complete::le_u32(i)?;
95 let (i, param) = nom::number::complete::le_u16(i)?;
96 Ok((
97 &[],
98 Command::SendLongData {
99 stmt,
100 param,
101 data: i,
102 },
103 ))
104}
105
106pub fn parse(i: &[u8]) -> nom::IResult<&[u8], Command<'_>> {
107 use nom::bytes::complete::tag;
108 use nom::combinator::{map, rest};
109 use nom::sequence::preceded;
110 nom::branch::alt((
111 map(
112 preceded(tag(&[CommandByte::COM_QUERY as u8]), rest),
113 Command::Query,
114 ),
115 map(
116 preceded(tag(&[CommandByte::COM_FIELD_LIST as u8]), rest),
117 Command::ListFields,
118 ),
119 map(
120 preceded(tag(&[CommandByte::COM_INIT_DB as u8]), rest),
121 Command::Init,
122 ),
123 map(
124 preceded(tag(&[CommandByte::COM_STMT_PREPARE as u8]), rest),
125 Command::Prepare,
126 ),
127 preceded(tag(&[CommandByte::COM_STMT_EXECUTE as u8]), execute),
128 preceded(
129 tag(&[CommandByte::COM_STMT_SEND_LONG_DATA as u8]),
130 send_long_data,
131 ),
132 map(
133 preceded(
134 tag(&[CommandByte::COM_STMT_CLOSE as u8]),
135 nom::number::complete::le_u32,
136 ),
137 Command::Close,
138 ),
139 map(tag(&[CommandByte::COM_QUIT as u8]), |_| Command::Quit),
140 map(tag(&[CommandByte::COM_PING as u8]), |_| Command::Ping),
141 ))(i)
142}
143
#[cfg(test)]
mod tests {
    use super::*;
    use crate::myc::constants::{CapabilityFlags, UTF8_GENERAL_CI};
    use crate::packet::PacketConn;
    use std::io::Cursor;

    /// Capability expectations shared by every handshake fixture below,
    /// whether or not it requests SSL.
    fn assert_base_capabilities(caps: CapabilityFlags) {
        assert!(caps.contains(CapabilityFlags::CLIENT_LONG_PASSWORD));
        assert!(caps.contains(CapabilityFlags::CLIENT_MULTI_RESULTS));
        assert!(!caps.contains(CapabilityFlags::CLIENT_CONNECT_WITH_DB));
        assert!(!caps.contains(CapabilityFlags::CLIENT_DEPRECATE_EOF));
    }

    #[test]
    fn it_parses_handshake() {
        // Packet header (length 0x25, sequence 1), then a 4.1 handshake response.
        // It carries capabilities, max packet size 1 << 24, collation 0x21,
        // 23 filler bytes, "jon\0" and one trailing byte.
        let data: Vec<u8> = vec![
            0x25, 0x00, 0x00, 0x01, 0x85, 0xa6, 0x3f, 0x20, 0x00, 0x00, 0x00, 0x01, 0x21, 0x00,
            0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
            0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x6a, 0x6f, 0x6e, 0x00, 0x00,
        ];
        let mut conn = PacketConn::new(Cursor::new(data));
        let (_, packet) = conn.next().unwrap().unwrap();

        let (_, handshake) = client_handshake(&packet, false).unwrap();
        println!("{:?}", handshake);

        assert_base_capabilities(handshake.capabilities);
        assert_eq!(handshake.collation, UTF8_GENERAL_CI);
        assert_eq!(handshake.username.unwrap(), &b"jon"[..]);
        assert_eq!(handshake.maxps, 16777216);
    }

    #[test]
    fn it_parses_handshake_with_ssl_enabled() {
        // Same response with CLIENT_SSL set (0xae instead of 0xa6) plus one byte
        // beyond the declared packet length. Before TLS, no username is read.
        let data: Vec<u8> = vec![
            0x25, 0x00, 0x00, 0x01, 0x85, 0xae, 0x3f, 0x20, 0x00, 0x00, 0x00, 0x01, 0x21, 0x00,
            0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
            0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x6a, 0x6f, 0x6e, 0x00, 0x00, 0x05,
        ];
        let mut conn = PacketConn::new(Cursor::new(data));
        let (_, packet) = conn.next().unwrap().unwrap();

        let (_, handshake) = client_handshake(&packet, false).unwrap();
        println!("{:?}", handshake);

        assert_base_capabilities(handshake.capabilities);
        assert!(handshake.capabilities.contains(CapabilityFlags::CLIENT_SSL));
        assert_eq!(handshake.collation, UTF8_GENERAL_CI);
        assert_eq!(handshake.username, None);
        assert_eq!(handshake.maxps, 16777216);
    }

    #[test]
    fn it_parses_handshake_after_ssl() {
        // Identical bytes to the SSL fixture. Parsed with `after_tls = true`,
        // the username is read.
        let data: Vec<u8> = vec![
            0x25, 0x00, 0x00, 0x01, 0x85, 0xae, 0x3f, 0x20, 0x00, 0x00, 0x00, 0x01, 0x21, 0x00,
            0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
            0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x6a, 0x6f, 0x6e, 0x00, 0x00, 0x05,
        ];
        let mut conn = PacketConn::new(Cursor::new(data));
        let (_, packet) = conn.next().unwrap().unwrap();

        let (_, handshake) = client_handshake(&packet, true).unwrap();
        println!("{:?}", handshake);

        assert_base_capabilities(handshake.capabilities);
        assert!(handshake.capabilities.contains(CapabilityFlags::CLIENT_SSL));
        assert_eq!(handshake.collation, UTF8_GENERAL_CI);
        assert_eq!(handshake.username.unwrap(), &b"jon"[..]);
        assert_eq!(handshake.maxps, 16777216);
    }

    #[test]
    fn it_parses_request() {
        // Packet header (length 0x21, sequence 0), command byte 0x03, then query text.
        let data: Vec<u8> = vec![
            0x21, 0x00, 0x00, 0x00, 0x03, 0x73, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x20, 0x40, 0x40,
            0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x5f, 0x63, 0x6f, 0x6d, 0x6d, 0x65, 0x6e,
            0x74, 0x20, 0x6c, 0x69, 0x6d, 0x69, 0x74, 0x20, 0x31,
        ];
        let mut conn = PacketConn::new(Cursor::new(data));
        let (_, packet) = conn.next().unwrap().unwrap();

        let (_, command) = parse(&packet).unwrap();
        assert_eq!(
            command,
            Command::Query(&b"select @@version_comment limit 1"[..])
        );
    }

    #[test]
    fn it_handles_list_fields() {
        // Same payload as the query fixture, but with command byte 0x04.
        let data: Vec<u8> = vec![
            0x21, 0x00, 0x00, 0x00, 0x04, 0x73, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x20, 0x40, 0x40,
            0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x5f, 0x63, 0x6f, 0x6d, 0x6d, 0x65, 0x6e,
            0x74, 0x20, 0x6c, 0x69, 0x6d, 0x69, 0x74, 0x20, 0x31,
        ];
        let mut conn = PacketConn::new(Cursor::new(data));
        let (_, packet) = conn.next().unwrap().unwrap();

        let (_, command) = parse(&packet).unwrap();
        assert_eq!(
            command,
            Command::ListFields(&b"select @@version_comment limit 1"[..])
        );
    }
}