1use nom::character::complete::{multispace0, multispace1};
2use std::{fmt, str};
3
4use column::Column;
5use common::{assignment_expr_list, statement_terminator, table_reference, FieldValueExpression};
6use condition::ConditionExpression;
7use keywords::escape_if_keyword;
8use nom::bytes::complete::tag_no_case;
9use nom::combinator::opt;
10use nom::sequence::tuple;
11use nom::IResult;
12use select::where_clause;
13use table::Table;
14
/// A parsed SQL `UPDATE` statement, as produced by [`updating`].
#[derive(Clone, Debug, Default, Eq, Hash, PartialEq, Serialize, Deserialize)]
pub struct UpdateStatement {
    /// The table named after the `UPDATE` keyword.
    pub table: Table,
    /// The `SET` assignments in source order; each entry pairs the target
    /// column with the expression assigned to it.
    pub fields: Vec<(Column, FieldValueExpression)>,
    /// The `WHERE` condition, or `None` when the statement has no `WHERE` clause.
    pub where_clause: Option<ConditionExpression>,
}
21
22impl fmt::Display for UpdateStatement {
23 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
24 write!(f, "UPDATE {} ", escape_if_keyword(&self.table.name))?;
25 assert!(self.fields.len() > 0);
26 write!(
27 f,
28 "SET {}",
29 self.fields
30 .iter()
31 .map(|&(ref col, ref literal)| format!("{} = {}", col, literal.to_string()))
32 .collect::<Vec<_>>()
33 .join(", ")
34 )?;
35 if let Some(ref where_clause) = self.where_clause {
36 write!(f, " WHERE ")?;
37 write!(f, "{}", where_clause)?;
38 }
39 Ok(())
40 }
41}
42
43pub fn updating(i: &[u8]) -> IResult<&[u8], UpdateStatement> {
44 let (remaining_input, (_, _, table, _, _, _, fields, _, where_clause, _)) = tuple((
45 tag_no_case("update"),
46 multispace1,
47 table_reference,
48 multispace1,
49 tag_no_case("set"),
50 multispace1,
51 assignment_expr_list,
52 multispace0,
53 opt(where_clause),
54 statement_terminator,
55 ))(i)?;
56 Ok((
57 remaining_input,
58 UpdateStatement {
59 table,
60 fields,
61 where_clause,
62 },
63 ))
64}
65
#[cfg(test)]
mod tests {
    use super::*;
    use arithmetic::{ArithmeticBase, ArithmeticExpression, ArithmeticOperator};
    use column::Column;
    use common::{Literal, LiteralExpression, Operator, Real};
    use condition::ConditionBase::*;
    use condition::ConditionExpression::*;
    use condition::ConditionTree;
    use table::Table;

    #[test]
    fn simple_update() {
        let qstring = "UPDATE users SET id = 42, name = 'test'";

        let res = updating(qstring.as_bytes());
        assert_eq!(
            res.unwrap().1,
            UpdateStatement {
                table: Table::from("users"),
                fields: vec![
                    (
                        Column::from("id"),
                        FieldValueExpression::Literal(LiteralExpression::from(Literal::from(42))),
                    ),
                    (
                        Column::from("name"),
                        FieldValueExpression::Literal(LiteralExpression::from(Literal::from(
                            "test",
                        ))),
                    ),
                ],
                // Covers `where_clause: None`.
                ..Default::default()
            }
        );
    }

    #[test]
    fn lowercase_keywords() {
        // Keywords are matched case-insensitively, so both spellings must
        // produce the same statement.
        let lower = updating("update users set id = 42".as_bytes());
        let upper = updating("UPDATE users SET id = 42".as_bytes());
        assert_eq!(lower.unwrap().1, upper.unwrap().1);
    }

    #[test]
    fn update_with_where_clause() {
        let qstring = "UPDATE users SET id = 42, name = 'test' WHERE id = 1";

        let res = updating(qstring.as_bytes());
        let expected_left = Base(Field(Column::from("id")));
        let expected_where_cond = Some(ComparisonOp(ConditionTree {
            left: Box::new(expected_left),
            right: Box::new(Base(Literal(Literal::Integer(1)))),
            operator: Operator::Equal,
        }));
        assert_eq!(
            res.unwrap().1,
            UpdateStatement {
                table: Table::from("users"),
                fields: vec![
                    (
                        Column::from("id"),
                        FieldValueExpression::Literal(LiteralExpression::from(Literal::from(42))),
                    ),
                    (
                        Column::from("name"),
                        FieldValueExpression::Literal(LiteralExpression::from(Literal::from(
                            "test",
                        ))),
                    ),
                ],
                where_clause: expected_where_cond,
            }
        );
    }

    #[test]
    fn format_update_with_where_clause() {
        let qstring = "UPDATE users SET id = 42, name = 'test' WHERE id = 1";
        let expected = "UPDATE users SET id = 42, name = 'test' WHERE id = 1";
        let res = updating(qstring.as_bytes());
        assert_eq!(format!("{}", res.unwrap().1), expected);
    }

    #[test]
    fn format_update_without_where_clause() {
        let qstring = "UPDATE users SET id = 42, name = 'test'";
        let expected = "UPDATE users SET id = 42, name = 'test'";
        let res = updating(qstring.as_bytes());
        assert_eq!(format!("{}", res.unwrap().1), expected);
    }

    #[test]
    fn updated_with_neg_float() {
        let qstring = "UPDATE `stories` SET `hotness` = -19216.5479744 WHERE `stories`.`id` = ?";

        let res = updating(qstring.as_bytes());
        let expected_left = Base(Field(Column::from("stories.id")));
        let expected_where_cond = Some(ComparisonOp(ConditionTree {
            left: Box::new(expected_left),
            right: Box::new(Base(Literal(Literal::Placeholder))),
            operator: Operator::Equal,
        }));
        assert_eq!(
            res.unwrap().1,
            UpdateStatement {
                table: Table::from("stories"),
                fields: vec![(
                    Column::from("hotness"),
                    FieldValueExpression::Literal(LiteralExpression::from(Literal::FixedPoint(
                        Real {
                            integral: -19216,
                            fractional: 5479744,
                        }
                    ),)),
                ),],
                where_clause: expected_where_cond,
            }
        );
    }

    #[test]
    fn update_with_arithmetic_and_where() {
        let qstring = "UPDATE users SET karma = karma + 1 WHERE users.id = ?;";

        let res = updating(qstring.as_bytes());
        let expected_where_cond = Some(ComparisonOp(ConditionTree {
            left: Box::new(Base(Field(Column::from("users.id")))),
            right: Box::new(Base(Literal(Literal::Placeholder))),
            operator: Operator::Equal,
        }));
        let expected_ae = ArithmeticExpression {
            op: ArithmeticOperator::Add,
            left: ArithmeticBase::Column(Column::from("karma")),
            right: ArithmeticBase::Scalar(1.into()),
            alias: None,
        };
        assert_eq!(
            res.unwrap().1,
            UpdateStatement {
                table: Table::from("users"),
                fields: vec![(
                    Column::from("karma"),
                    FieldValueExpression::Arithmetic(expected_ae),
                ),],
                where_clause: expected_where_cond,
            }
        );
    }

    #[test]
    fn update_with_arithmetic() {
        let qstring = "UPDATE users SET karma = karma + 1;";

        let res = updating(qstring.as_bytes());
        let expected_ae = ArithmeticExpression {
            op: ArithmeticOperator::Add,
            left: ArithmeticBase::Column(Column::from("karma")),
            right: ArithmeticBase::Scalar(1.into()),
            alias: None,
        };
        assert_eq!(
            res.unwrap().1,
            UpdateStatement {
                table: Table::from("users"),
                fields: vec![(
                    Column::from("karma"),
                    FieldValueExpression::Arithmetic(expected_ae),
                ),],
                // Covers `where_clause: None`.
                ..Default::default()
            }
        );
    }
}