1use std::fmt;
2use std::str;
3
4use compound_select::{compound_selection, CompoundSelectStatement};
5use create::{creation, view_creation, CreateTableStatement, CreateViewStatement};
6use delete::{deletion, DeleteStatement};
7use drop::{drop_table, DropTableStatement};
8use insert::{insertion, InsertStatement};
9use nom::branch::alt;
10use nom::combinator::map;
11use nom::IResult;
12use select::{selection, SelectStatement};
13use set::{set, SetStatement};
14use update::{updating, UpdateStatement};
15
/// A single parsed SQL statement.
///
/// Each variant wraps the statement type produced by one of the
/// statement-level parsers that `sql_query` tries.
#[derive(Clone, Debug, Eq, Hash, PartialEq, Serialize, Deserialize)]
pub enum SqlQuery {
    /// Statement produced by the `creation` parser.
    CreateTable(CreateTableStatement),
    /// Statement produced by the `view_creation` parser.
    CreateView(CreateViewStatement),
    /// Statement produced by the `insertion` parser.
    Insert(InsertStatement),
    /// Statement produced by the `compound_selection` parser.
    CompoundSelect(CompoundSelectStatement),
    /// Statement produced by the `selection` parser.
    Select(SelectStatement),
    /// Statement produced by the `deletion` parser.
    Delete(DeleteStatement),
    /// Statement produced by the `drop_table` parser.
    DropTable(DropTableStatement),
    /// Statement produced by the `updating` parser.
    Update(UpdateStatement),
    /// Statement produced by the `set` parser.
    Set(SetStatement),
}
28
29impl fmt::Display for SqlQuery {
30 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
31 match *self {
32 SqlQuery::Select(ref select) => write!(f, "{}", select),
33 SqlQuery::Insert(ref insert) => write!(f, "{}", insert),
34 SqlQuery::CreateTable(ref create) => write!(f, "{}", create),
35 SqlQuery::CreateView(ref create) => write!(f, "{}", create),
36 SqlQuery::Delete(ref delete) => write!(f, "{}", delete),
37 SqlQuery::DropTable(ref drop) => write!(f, "{}", drop),
38 SqlQuery::Update(ref update) => write!(f, "{}", update),
39 SqlQuery::Set(ref set) => write!(f, "{}", set),
40 _ => unimplemented!(),
41 }
42 }
43}
44
45pub fn sql_query(i: &[u8]) -> IResult<&[u8], SqlQuery> {
46 alt((
47 map(creation, |c| SqlQuery::CreateTable(c)),
48 map(insertion, |i| SqlQuery::Insert(i)),
49 map(compound_selection, |cs| SqlQuery::CompoundSelect(cs)),
50 map(selection, |s| SqlQuery::Select(s)),
51 map(deletion, |d| SqlQuery::Delete(d)),
52 map(drop_table, |dt| SqlQuery::DropTable(dt)),
53 map(updating, |u| SqlQuery::Update(u)),
54 map(set, |s| SqlQuery::Set(s)),
55 map(view_creation, |vc| SqlQuery::CreateView(vc)),
56 ))(i)
57}
58
59pub fn parse_query_bytes<T>(input: T) -> Result<SqlQuery, &'static str>
60where
61 T: AsRef<[u8]>,
62{
63 match sql_query(input.as_ref()) {
64 Ok((_, o)) => Ok(o),
65 Err(_) => Err("failed to parse query"),
66 }
67}
68
69pub fn parse_query<T>(input: T) -> Result<SqlQuery, &'static str>
70where
71 T: AsRef<str>,
72{
73 parse_query_bytes(input.as_ref().trim().as_bytes())
74}
75
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    use table::Table;

    /// Parses `sql`, asserts that parsing succeeded, and returns the query
    /// formatted through its `Display` implementation.
    fn render(sql: &str) -> String {
        let parsed = parse_query(sql);
        assert!(parsed.is_ok());
        format!("{}", parsed.unwrap())
    }

    /// Asserts that each `(input, expected)` pair formats to `expected`.
    fn assert_renders(cases: &[(&str, &str)]) {
        for &(input, expected) in cases.iter() {
            assert_eq!(expected, render(input));
        }
    }

    /// Feeds `value` into a fresh `DefaultHasher` and returns the digest.
    fn hash_of<T: Hash>(value: &T) -> u64 {
        let mut hasher = DefaultHasher::new();
        value.hash(&mut hasher);
        hasher.finish()
    }

    #[test]
    fn hash_query() {
        let parsed = parse_query("INSERT INTO users VALUES (42, \"test\");");
        assert!(parsed.is_ok());

        let expected = SqlQuery::Insert(InsertStatement {
            table: Table::from("users"),
            fields: None,
            data: vec![vec![42.into(), "test".into()]],
            ..Default::default()
        });

        // A parsed query must hash identically to the hand-built equivalent.
        assert_eq!(hash_of(&parsed.unwrap()), hash_of(&expected));
    }

    #[test]
    fn trim_query() {
        // Surrounding whitespace must not prevent parsing.
        let parsed = parse_query(" INSERT INTO users VALUES (42, \"test\"); ");
        assert!(parsed.is_ok());
    }

    #[test]
    fn parse_byte_slice() {
        let bytes: &[u8] = b"INSERT INTO users VALUES (42, \"test\");";
        assert!(parse_query_bytes(bytes).is_ok());
    }

    #[test]
    fn parse_byte_vector() {
        let bytes: Vec<u8> = b"INSERT INTO users VALUES (42, \"test\");".to_vec();
        assert!(parse_query_bytes(&bytes).is_ok());
    }

    #[test]
    fn display_select_query() {
        // Each of these is expected to format back to exactly its own text.
        let queries = [
            "SELECT * FROM users",
            "SELECT * FROM users AS u",
            "SELECT name, password FROM users AS u",
            "SELECT name, password FROM users AS u WHERE user_id = '1'",
            "SELECT name, password FROM users AS u WHERE user = 'aaa' AND password = 'xxx'",
            "SELECT name * 2 AS double_name FROM users",
        ];

        for &sql in queries.iter() {
            assert_eq!(sql, render(sql));
        }
    }

    #[test]
    fn format_select_query() {
        assert_renders(&[
            ("select * from users u", "SELECT * FROM users AS u"),
            (
                "select name,password from users u;",
                "SELECT name, password FROM users AS u",
            ),
            (
                "select name,password from users u WHERE user_id='1'",
                "SELECT name, password FROM users AS u WHERE user_id = '1'",
            ),
        ]);
    }

    #[test]
    fn format_select_query_with_where_clause() {
        assert_renders(&[
            (
                "select name, password from users as u where user='aaa' and password= 'xxx'",
                "SELECT name, password FROM users AS u WHERE user = 'aaa' AND password = 'xxx'",
            ),
            (
                "select name, password from users as u where user=? and password =?",
                "SELECT name, password FROM users AS u WHERE user = ? AND password = ?",
            ),
        ]);
    }

    #[test]
    fn format_select_query_with_function() {
        assert_renders(&[("select count(*) from users", "SELECT count(*) FROM users")]);
    }

    #[test]
    fn display_insert_query() {
        let sql = "INSERT INTO users (name, password) VALUES ('aaa', 'xxx')";
        assert_eq!(sql, render(sql));
    }

    #[test]
    fn display_insert_query_no_columns() {
        assert_renders(&[(
            "INSERT INTO users VALUES ('aaa', 'xxx')",
            "INSERT INTO users VALUES ('aaa', 'xxx')",
        )]);
    }

    #[test]
    fn format_insert_query() {
        assert_renders(&[(
            "insert into users (name, password) values ('aaa', 'xxx')",
            "INSERT INTO users (name, password) VALUES ('aaa', 'xxx')",
        )]);
    }

    #[test]
    fn format_update_query() {
        assert_renders(&[(
            "update users set name=42, password='xxx' where id=1",
            "UPDATE users SET name = 42, password = 'xxx' WHERE id = 1",
        )]);
    }

    #[test]
    fn format_delete_query_with_where_clause() {
        assert_renders(&[
            (
                "delete from users where user='aaa' and password= 'xxx'",
                "DELETE FROM users WHERE user = 'aaa' AND password = 'xxx'",
            ),
            (
                "delete from users where user=? and password =?",
                "DELETE FROM users WHERE user = ? AND password = ?",
            ),
        ]);
    }

    #[test]
    fn format_query_with_escaped_keyword() {
        assert_renders(&[
            (
                "delete from articles where `key`='aaa'",
                "DELETE FROM articles WHERE `key` = 'aaa'",
            ),
            (
                "delete from `where` where user=?",
                "DELETE FROM `where` WHERE user = ?",
            ),
        ]);
    }
}