nom_sql/
parser.rs

1use std::fmt;
2use std::str;
3
4use compound_select::{compound_selection, CompoundSelectStatement};
5use create::{creation, view_creation, CreateTableStatement, CreateViewStatement};
6use delete::{deletion, DeleteStatement};
7use drop::{drop_table, DropTableStatement};
8use insert::{insertion, InsertStatement};
9use nom::branch::alt;
10use nom::combinator::map;
11use nom::IResult;
12use select::{selection, SelectStatement};
13use set::{set, SetStatement};
14use update::{updating, UpdateStatement};
15
16#[derive(Clone, Debug, Eq, Hash, PartialEq, Serialize, Deserialize)]
17pub enum SqlQuery {
18    CreateTable(CreateTableStatement),
19    CreateView(CreateViewStatement),
20    Insert(InsertStatement),
21    CompoundSelect(CompoundSelectStatement),
22    Select(SelectStatement),
23    Delete(DeleteStatement),
24    DropTable(DropTableStatement),
25    Update(UpdateStatement),
26    Set(SetStatement),
27}
28
29impl fmt::Display for SqlQuery {
30    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
31        match *self {
32            SqlQuery::Select(ref select) => write!(f, "{}", select),
33            SqlQuery::Insert(ref insert) => write!(f, "{}", insert),
34            SqlQuery::CreateTable(ref create) => write!(f, "{}", create),
35            SqlQuery::CreateView(ref create) => write!(f, "{}", create),
36            SqlQuery::Delete(ref delete) => write!(f, "{}", delete),
37            SqlQuery::DropTable(ref drop) => write!(f, "{}", drop),
38            SqlQuery::Update(ref update) => write!(f, "{}", update),
39            SqlQuery::Set(ref set) => write!(f, "{}", set),
40            _ => unimplemented!(),
41        }
42    }
43}
44
45pub fn sql_query(i: &[u8]) -> IResult<&[u8], SqlQuery> {
46    alt((
47        map(creation, |c| SqlQuery::CreateTable(c)),
48        map(insertion, |i| SqlQuery::Insert(i)),
49        map(compound_selection, |cs| SqlQuery::CompoundSelect(cs)),
50        map(selection, |s| SqlQuery::Select(s)),
51        map(deletion, |d| SqlQuery::Delete(d)),
52        map(drop_table, |dt| SqlQuery::DropTable(dt)),
53        map(updating, |u| SqlQuery::Update(u)),
54        map(set, |s| SqlQuery::Set(s)),
55        map(view_creation, |vc| SqlQuery::CreateView(vc)),
56    ))(i)
57}
58
59pub fn parse_query_bytes<T>(input: T) -> Result<SqlQuery, &'static str>
60where
61    T: AsRef<[u8]>,
62{
63    match sql_query(input.as_ref()) {
64        Ok((_, o)) => Ok(o),
65        Err(_) => Err("failed to parse query"),
66    }
67}
68
69pub fn parse_query<T>(input: T) -> Result<SqlQuery, &'static str>
70where
71    T: AsRef<str>,
72{
73    parse_query_bytes(input.as_ref().trim().as_bytes())
74}
75
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    use table::Table;

    /// Parses `query`, failing the test with the offending query text if the
    /// parser rejects it (a bare `is_ok()` assertion gives no such context).
    fn parse_ok(query: &str) -> SqlQuery {
        match parse_query(query) {
            Ok(parsed) => parsed,
            Err(e) => panic!("failed to parse {:?}: {}", query, e),
        }
    }

    /// Asserts that `query` parses and that its `Display` output is exactly
    /// `expected`.
    fn assert_formats_as(query: &str, expected: &str) {
        assert_eq!(expected, parse_ok(query).to_string(), "query: {:?}", query);
    }

    /// Asserts that `query` parses and prints back as identical text.
    fn assert_round_trips(query: &str) {
        assert_formats_as(query, query);
    }

    /// Hashes `value` with a freshly constructed `DefaultHasher`.
    fn hash_of<T: Hash>(value: &T) -> u64 {
        let mut hasher = DefaultHasher::new();
        value.hash(&mut hasher);
        hasher.finish()
    }

    #[test]
    fn hash_query() {
        let parsed = parse_ok("INSERT INTO users VALUES (42, \"test\");");

        let expected = SqlQuery::Insert(InsertStatement {
            table: Table::from("users"),
            fields: None,
            data: vec![vec![42.into(), "test".into()]],
            ..Default::default()
        });
        assert_eq!(hash_of(&parsed), hash_of(&expected));
    }

    #[test]
    fn trim_query() {
        parse_ok("   INSERT INTO users VALUES (42, \"test\");     ");
    }

    #[test]
    fn parse_byte_slice() {
        let qstring: &[u8] = b"INSERT INTO users VALUES (42, \"test\");";
        assert!(parse_query_bytes(qstring).is_ok());
    }

    #[test]
    fn parse_byte_vector() {
        let qstring: Vec<u8> = b"INSERT INTO users VALUES (42, \"test\");".to_vec();
        assert!(parse_query_bytes(&qstring).is_ok());
    }

    #[test]
    fn display_select_query() {
        let queries = [
            "SELECT * FROM users",
            "SELECT * FROM users AS u",
            "SELECT name, password FROM users AS u",
            "SELECT name, password FROM users AS u WHERE user_id = '1'",
            "SELECT name, password FROM users AS u WHERE user = 'aaa' AND password = 'xxx'",
            "SELECT name * 2 AS double_name FROM users",
        ];
        for query in &queries {
            assert_round_trips(query);
        }
    }

    #[test]
    fn format_select_query() {
        let cases = [
            ("select * from users u", "SELECT * FROM users AS u"),
            (
                "select name,password from users u;",
                "SELECT name, password FROM users AS u",
            ),
            (
                "select name,password from users u WHERE user_id='1'",
                "SELECT name, password FROM users AS u WHERE user_id = '1'",
            ),
        ];
        for &(query, expected) in &cases {
            assert_formats_as(query, expected);
        }
    }

    #[test]
    fn format_select_query_with_where_clause() {
        let cases = [
            (
                "select name, password from users as u where user='aaa' and password= 'xxx'",
                "SELECT name, password FROM users AS u WHERE user = 'aaa' AND password = 'xxx'",
            ),
            (
                "select name, password from users as u where user=? and password =?",
                "SELECT name, password FROM users AS u WHERE user = ? AND password = ?",
            ),
        ];
        for &(query, expected) in &cases {
            assert_formats_as(query, expected);
        }
    }

    #[test]
    fn format_select_query_with_function() {
        assert_formats_as("select count(*) from users", "SELECT count(*) FROM users");
    }

    #[test]
    fn display_insert_query() {
        assert_round_trips("INSERT INTO users (name, password) VALUES ('aaa', 'xxx')");
    }

    #[test]
    fn display_insert_query_no_columns() {
        assert_round_trips("INSERT INTO users VALUES ('aaa', 'xxx')");
    }

    #[test]
    fn format_insert_query() {
        assert_formats_as(
            "insert into users (name, password) values ('aaa', 'xxx')",
            "INSERT INTO users (name, password) VALUES ('aaa', 'xxx')",
        );
    }

    #[test]
    fn format_update_query() {
        assert_formats_as(
            "update users set name=42, password='xxx' where id=1",
            "UPDATE users SET name = 42, password = 'xxx' WHERE id = 1",
        );
    }

    #[test]
    fn format_delete_query_with_where_clause() {
        let cases = [
            (
                "delete from users where user='aaa' and password= 'xxx'",
                "DELETE FROM users WHERE user = 'aaa' AND password = 'xxx'",
            ),
            (
                "delete from users where user=? and password =?",
                "DELETE FROM users WHERE user = ? AND password = ?",
            ),
        ];
        for &(query, expected) in &cases {
            assert_formats_as(query, expected);
        }
    }

    #[test]
    fn format_query_with_escaped_keyword() {
        let cases = [
            (
                "delete from articles where `key`='aaa'",
                "DELETE FROM articles WHERE `key` = 'aaa'",
            ),
            (
                "delete from `where` where user=?",
                "DELETE FROM `where` WHERE user = ?",
            ),
        ];
        for &(query, expected) in &cases {
            assert_formats_as(query, expected);
        }
    }
}
272}