nom_sql/
insert.rs

1use nom::character::complete::{multispace0, multispace1};
2use std::fmt;
3use std::str;
4
5use column::Column;
6use common::{
7    assignment_expr_list, field_list, statement_terminator, table_reference, value_list,
8    ws_sep_comma, FieldValueExpression, Literal,
9};
10use keywords::escape_if_keyword;
11use nom::bytes::complete::{tag, tag_no_case};
12use nom::combinator::opt;
13use nom::multi::many1;
14use nom::sequence::{delimited, preceded, tuple};
15use nom::IResult;
16use table::Table;
17
18#[derive(Clone, Debug, Default, Eq, Hash, PartialEq, Serialize, Deserialize)]
19pub struct InsertStatement {
20    pub table: Table,
21    pub fields: Option<Vec<Column>>,
22    pub data: Vec<Vec<Literal>>,
23    pub ignore: bool,
24    pub on_duplicate: Option<Vec<(Column, FieldValueExpression)>>,
25}
26
27impl fmt::Display for InsertStatement {
28    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
29        write!(f, "INSERT INTO {}", escape_if_keyword(&self.table.name))?;
30        if let Some(ref fields) = self.fields {
31            write!(
32                f,
33                " ({})",
34                fields
35                    .iter()
36                    .map(|ref col| col.name.to_owned())
37                    .collect::<Vec<_>>()
38                    .join(", ")
39            )?;
40        }
41        write!(
42            f,
43            " VALUES {}",
44            self.data
45                .iter()
46                .map(|datas| format!(
47                    "({})",
48                    datas
49                        .into_iter()
50                        .map(|l| l.to_string())
51                        .collect::<Vec<_>>()
52                        .join(", ")
53                ))
54                .collect::<Vec<_>>()
55                .join(", ")
56        )
57    }
58}
59
60fn fields(i: &[u8]) -> IResult<&[u8], Vec<Column>> {
61    delimited(
62        preceded(tag("("), multispace0),
63        field_list,
64        delimited(multispace0, tag(")"), multispace1),
65    )(i)
66}
67
68fn data(i: &[u8]) -> IResult<&[u8], Vec<Literal>> {
69    delimited(tag("("), value_list, preceded(tag(")"), opt(ws_sep_comma)))(i)
70}
71
72fn on_duplicate(i: &[u8]) -> IResult<&[u8], Vec<(Column, FieldValueExpression)>> {
73    preceded(
74        multispace0,
75        preceded(
76            tag_no_case("on duplicate key update"),
77            preceded(multispace1, assignment_expr_list),
78        ),
79    )(i)
80}
81
82// Parse rule for a SQL insert query.
83// TODO(malte): support REPLACE, nested selection, DEFAULT VALUES
84pub fn insertion(i: &[u8]) -> IResult<&[u8], InsertStatement> {
85    let (remaining_input, (_, ignore_res, _, _, _, table, _, fields, _, _, data, on_duplicate, _)) =
86        tuple((
87            tag_no_case("insert"),
88            opt(preceded(multispace1, tag_no_case("ignore"))),
89            multispace1,
90            tag_no_case("into"),
91            multispace1,
92            table_reference,
93            multispace0,
94            opt(fields),
95            tag_no_case("values"),
96            multispace0,
97            many1(data),
98            opt(on_duplicate),
99            statement_terminator,
100        ))(i)?;
101    assert!(table.alias.is_none());
102    let ignore = ignore_res.is_some();
103
104    Ok((
105        remaining_input,
106        InsertStatement {
107            table,
108            fields,
109            data,
110            ignore,
111            on_duplicate,
112        },
113    ))
114}
115
#[cfg(test)]
mod tests {
    use super::*;
    use arithmetic::{ArithmeticBase, ArithmeticExpression, ArithmeticOperator};
    use column::Column;
    use table::Table;

    /// Parses `qstring` with `insertion`, panicking with the query text if parsing fails.
    fn parse(qstring: &str) -> InsertStatement {
        match insertion(qstring.as_bytes()) {
            Ok((_, stmt)) => stmt,
            Err(e) => panic!("failed to parse {:?}: {:?}", qstring, e),
        }
    }

    #[test]
    fn simple_insert() {
        assert_eq!(
            parse("INSERT INTO users VALUES (42, \"test\");"),
            InsertStatement {
                table: Table::from("users"),
                fields: None,
                data: vec![vec![42.into(), "test".into()]],
                ..Default::default()
            }
        );
    }

    #[test]
    fn insert_ignore() {
        assert_eq!(
            parse("INSERT IGNORE INTO users VALUES (42, \"test\");"),
            InsertStatement {
                table: Table::from("users"),
                fields: None,
                data: vec![vec![42.into(), "test".into()]],
                ignore: true,
                ..Default::default()
            }
        );
    }

    #[test]
    fn complex_insert() {
        assert_eq!(
            parse("INSERT INTO users VALUES (42, 'test', \"test\", CURRENT_TIMESTAMP);"),
            InsertStatement {
                table: Table::from("users"),
                fields: None,
                data: vec![vec![
                    42.into(),
                    "test".into(),
                    "test".into(),
                    Literal::CurrentTimestamp,
                ]],
                ..Default::default()
            }
        );
    }

    #[test]
    fn insert_with_field_names() {
        assert_eq!(
            parse("INSERT INTO users (id, name) VALUES (42, \"test\");"),
            InsertStatement {
                table: Table::from("users"),
                fields: Some(vec![Column::from("id"), Column::from("name")]),
                data: vec![vec![42.into(), "test".into()]],
                ..Default::default()
            }
        );
    }

    // Issue #3
    #[test]
    fn insert_without_spaces() {
        assert_eq!(
            parse("INSERT INTO users(id, name) VALUES(42, \"test\");"),
            InsertStatement {
                table: Table::from("users"),
                fields: Some(vec![Column::from("id"), Column::from("name")]),
                data: vec![vec![42.into(), "test".into()]],
                ..Default::default()
            }
        );
    }

    #[test]
    fn multi_insert() {
        assert_eq!(
            parse("INSERT INTO users (id, name) VALUES (42, \"test\"),(21, \"test2\");"),
            InsertStatement {
                table: Table::from("users"),
                fields: Some(vec![Column::from("id"), Column::from("name")]),
                data: vec![
                    vec![42.into(), "test".into()],
                    vec![21.into(), "test2".into()],
                ],
                ..Default::default()
            }
        );
    }

    #[test]
    fn multi_insert_with_spaced_separator() {
        assert_eq!(
            parse("INSERT INTO users (id, name) VALUES (42, \"test\") , (21, \"test2\");"),
            InsertStatement {
                table: Table::from("users"),
                fields: Some(vec![Column::from("id"), Column::from("name")]),
                data: vec![
                    vec![42.into(), "test".into()],
                    vec![21.into(), "test2".into()],
                ],
                ..Default::default()
            }
        );
    }

    #[test]
    fn insert_with_parameters() {
        assert_eq!(
            parse("INSERT INTO users (id, name) VALUES (?, ?);"),
            InsertStatement {
                table: Table::from("users"),
                fields: Some(vec![Column::from("id"), Column::from("name")]),
                data: vec![vec![Literal::Placeholder, Literal::Placeholder]],
                ..Default::default()
            }
        );
    }

    #[test]
    fn insert_with_on_dup_update() {
        let qstring = "INSERT INTO keystores (`key`, `value`) VALUES (?, ?) \
                       ON DUPLICATE KEY UPDATE `value` = `value` + 1";

        let expected_ae = ArithmeticExpression {
            op: ArithmeticOperator::Add,
            left: ArithmeticBase::Column(Column::from("value")),
            right: ArithmeticBase::Scalar(1.into()),
            alias: None,
        };
        assert_eq!(
            parse(qstring),
            InsertStatement {
                table: Table::from("keystores"),
                fields: Some(vec![Column::from("key"), Column::from("value")]),
                data: vec![vec![Literal::Placeholder, Literal::Placeholder]],
                on_duplicate: Some(vec![(
                    Column::from("value"),
                    FieldValueExpression::Arithmetic(expected_ae),
                )]),
                ..Default::default()
            }
        );
    }

    #[test]
    fn insert_with_leading_value_whitespace() {
        assert_eq!(
            parse("INSERT INTO users (id, name) VALUES ( 42, \"test\");"),
            InsertStatement {
                table: Table::from("users"),
                fields: Some(vec![Column::from("id"), Column::from("name")]),
                data: vec![vec![42.into(), "test".into()]],
                ..Default::default()
            }
        );
    }
}