nom_sql/update.rs

1use nom::character::complete::{multispace0, multispace1};
2use std::{fmt, str};
3
4use column::Column;
5use common::{assignment_expr_list, statement_terminator, table_reference, FieldValueExpression};
6use condition::ConditionExpression;
7use keywords::escape_if_keyword;
8use nom::bytes::complete::tag_no_case;
9use nom::combinator::opt;
10use nom::sequence::tuple;
11use nom::IResult;
12use select::where_clause;
13use table::Table;
14
15#[derive(Clone, Debug, Default, Eq, Hash, PartialEq, Serialize, Deserialize)]
16pub struct UpdateStatement {
17    pub table: Table,
18    pub fields: Vec<(Column, FieldValueExpression)>,
19    pub where_clause: Option<ConditionExpression>,
20}
21
22impl fmt::Display for UpdateStatement {
23    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
24        write!(f, "UPDATE {} ", escape_if_keyword(&self.table.name))?;
25        assert!(self.fields.len() > 0);
26        write!(
27            f,
28            "SET {}",
29            self.fields
30                .iter()
31                .map(|&(ref col, ref literal)| format!("{} = {}", col, literal.to_string()))
32                .collect::<Vec<_>>()
33                .join(", ")
34        )?;
35        if let Some(ref where_clause) = self.where_clause {
36            write!(f, " WHERE ")?;
37            write!(f, "{}", where_clause)?;
38        }
39        Ok(())
40    }
41}
42
43pub fn updating(i: &[u8]) -> IResult<&[u8], UpdateStatement> {
44    let (remaining_input, (_, _, table, _, _, _, fields, _, where_clause, _)) = tuple((
45        tag_no_case("update"),
46        multispace1,
47        table_reference,
48        multispace1,
49        tag_no_case("set"),
50        multispace1,
51        assignment_expr_list,
52        multispace0,
53        opt(where_clause),
54        statement_terminator,
55    ))(i)?;
56    Ok((
57        remaining_input,
58        UpdateStatement {
59            table,
60            fields,
61            where_clause,
62        },
63    ))
64}
65
#[cfg(test)]
mod tests {
    use super::*;
    use arithmetic::{ArithmeticBase, ArithmeticExpression, ArithmeticOperator};
    use column::Column;
    use common::{Literal, LiteralExpression, Operator, Real};
    use condition::ConditionBase::*;
    use condition::ConditionExpression::*;
    use condition::ConditionTree;
    use table::Table;

    #[test]
    fn simple_update() {
        let qstring = "UPDATE users SET id = 42, name = 'test'";

        let res = updating(qstring.as_bytes());
        assert_eq!(
            res.unwrap().1,
            UpdateStatement {
                table: Table::from("users"),
                fields: vec![
                    (
                        Column::from("id"),
                        FieldValueExpression::Literal(LiteralExpression::from(Literal::from(42))),
                    ),
                    (
                        Column::from("name"),
                        FieldValueExpression::Literal(LiteralExpression::from(Literal::from(
                            "test",
                        ))),
                    ),
                ],
                // No WHERE clause: `where_clause` stays `None`.
                ..Default::default()
            }
        );
    }

    #[test]
    fn format_simple_update() {
        // Display of a statement without a WHERE clause round-trips to the input.
        let qstring = "UPDATE users SET id = 42, name = 'test'";
        let res = updating(qstring.as_bytes());
        assert_eq!(format!("{}", res.unwrap().1), qstring);
    }

    #[test]
    fn update_without_set_fails() {
        // The SET clause is mandatory.
        assert!(updating(b"UPDATE users").is_err());
    }

    #[test]
    fn update_with_where_clause() {
        let qstring = "UPDATE users SET id = 42, name = 'test' WHERE id = 1";

        let res = updating(qstring.as_bytes());
        let expected_left = Base(Field(Column::from("id")));
        let expected_where_cond = Some(ComparisonOp(ConditionTree {
            left: Box::new(expected_left),
            right: Box::new(Base(Literal(Literal::Integer(1)))),
            operator: Operator::Equal,
        }));
        assert_eq!(
            res.unwrap().1,
            UpdateStatement {
                table: Table::from("users"),
                fields: vec![
                    (
                        Column::from("id"),
                        FieldValueExpression::Literal(LiteralExpression::from(Literal::from(42))),
                    ),
                    (
                        Column::from("name"),
                        FieldValueExpression::Literal(LiteralExpression::from(Literal::from(
                            "test",
                        ))),
                    ),
                ],
                where_clause: expected_where_cond,
            }
        );
    }

    #[test]
    fn format_update_with_where_clause() {
        let qstring = "UPDATE users SET id = 42, name = 'test' WHERE id = 1";
        let expected = "UPDATE users SET id = 42, name = 'test' WHERE id = 1";
        let res = updating(qstring.as_bytes());
        assert_eq!(format!("{}", res.unwrap().1), expected);
    }

    #[test]
    fn update_with_neg_float() {
        let qstring = "UPDATE `stories` SET `hotness` = -19216.5479744 WHERE `stories`.`id` = ?";

        let res = updating(qstring.as_bytes());
        let expected_left = Base(Field(Column::from("stories.id")));
        let expected_where_cond = Some(ComparisonOp(ConditionTree {
            left: Box::new(expected_left),
            right: Box::new(Base(Literal(Literal::Placeholder))),
            operator: Operator::Equal,
        }));
        assert_eq!(
            res.unwrap().1,
            UpdateStatement {
                table: Table::from("stories"),
                fields: vec![(
                    Column::from("hotness"),
                    FieldValueExpression::Literal(LiteralExpression::from(Literal::FixedPoint(
                        Real {
                            integral: -19216,
                            fractional: 5479744,
                        },
                    ))),
                )],
                where_clause: expected_where_cond,
            }
        );
    }

    #[test]
    fn update_with_arithmetic_and_where() {
        let qstring = "UPDATE users SET karma = karma + 1 WHERE users.id = ?;";

        let res = updating(qstring.as_bytes());
        let expected_where_cond = Some(ComparisonOp(ConditionTree {
            left: Box::new(Base(Field(Column::from("users.id")))),
            right: Box::new(Base(Literal(Literal::Placeholder))),
            operator: Operator::Equal,
        }));
        let expected_ae = ArithmeticExpression {
            op: ArithmeticOperator::Add,
            left: ArithmeticBase::Column(Column::from("karma")),
            right: ArithmeticBase::Scalar(1.into()),
            alias: None,
        };
        assert_eq!(
            res.unwrap().1,
            UpdateStatement {
                table: Table::from("users"),
                fields: vec![(
                    Column::from("karma"),
                    FieldValueExpression::Arithmetic(expected_ae),
                )],
                where_clause: expected_where_cond,
            }
        );
    }

    #[test]
    fn update_with_arithmetic() {
        let qstring = "UPDATE users SET karma = karma + 1;";

        let res = updating(qstring.as_bytes());
        let expected_ae = ArithmeticExpression {
            op: ArithmeticOperator::Add,
            left: ArithmeticBase::Column(Column::from("karma")),
            right: ArithmeticBase::Scalar(1.into()),
            alias: None,
        };
        assert_eq!(
            res.unwrap().1,
            UpdateStatement {
                table: Table::from("users"),
                fields: vec![(
                    Column::from("karma"),
                    FieldValueExpression::Arithmetic(expected_ae),
                )],
                // No WHERE clause: `where_clause` stays `None`.
                ..Default::default()
            }
        );
    }
}