1use alloc::vec;
14use alloc::vec::Vec;
15
16use crate::{
17 Identifier, SelectExpr, Span, Spanned,
18 expression::{
19 Expression, PRIORITY_MAX, parse_expression_or_default, parse_expression_unreserved,
20 },
21 keywords::{Keyword, Restrict},
22 lexer::Token,
23 parser::{ParseError, Parser},
24 select::{TableReference, parse_select_expr, parse_table_reference},
25 span::OptSpanned,
26};
27
28#[derive(Clone, Debug)]
30pub enum UpdateFlag {
31 LowPriority(Span),
32 Ignore(Span),
33}
34
35impl Spanned for UpdateFlag {
36 fn span(&self) -> Span {
37 match &self {
38 UpdateFlag::LowPriority(v) => v.span(),
39 UpdateFlag::Ignore(v) => v.span(),
40 }
41 }
42}
43
44#[derive(Clone, Debug)]
63pub struct Update<'a> {
64 pub update_span: Span,
66 pub flags: Vec<UpdateFlag>,
68 pub tables: Vec<TableReference<'a>>,
70 pub set_span: Span,
72 pub set: Vec<(Vec<Identifier<'a>>, Expression<'a>)>,
74 pub where_: Option<(Expression<'a>, Span)>,
76 pub from: Option<(Span, Vec<TableReference<'a>>)>,
78 pub returning: Option<(Span, Vec<SelectExpr<'a>>)>,
80}
81
82impl<'a> Spanned for Update<'a> {
83 fn span(&self) -> Span {
84 let mut set_span = None;
85 for (a, b) in &self.set {
86 set_span = set_span.opt_join_span(a).opt_join_span(b)
87 }
88
89 self.update_span
90 .join_span(&self.flags)
91 .join_span(&self.tables)
92 .join_span(&self.set_span)
93 .join_span(&set_span)
94 .join_span(&self.from)
95 .join_span(&self.where_)
96 .join_span(&self.returning)
97 }
98}
99
100pub(crate) fn parse_update<'a>(parser: &mut Parser<'a, '_>) -> Result<Update<'a>, ParseError> {
101 let update_span = parser.consume_keyword(Keyword::UPDATE)?;
102 let mut flags = Vec::new();
103
104 loop {
105 match &parser.token {
106 Token::Ident(_, Keyword::LOW_PRIORITY) => flags.push(UpdateFlag::LowPriority(
107 parser.consume_keyword(Keyword::LOW_PRIORITY)?,
108 )),
109 Token::Ident(_, Keyword::IGNORE) => {
110 flags.push(UpdateFlag::Ignore(parser.consume_keyword(Keyword::IGNORE)?))
111 }
112 _ => break,
113 }
114 }
115
116 let mut tables = Vec::new();
117 loop {
118 tables.push(parse_table_reference(parser, Restrict::UPDATE_TABLE)?);
119 if matches!(parser.token, Token::Ident(_, Keyword::SET)) {
121 break;
122 }
123 if parser.skip_token(Token::Comma).is_none() {
124 break;
125 }
126 }
127 let set_span = parser.consume_keyword(Keyword::SET)?;
128 let mut set = Vec::new();
129 loop {
130 let mut col = vec![parser.consume_plain_identifier_unreserved()?];
131 while parser.skip_token(Token::Period).is_some() {
132 col.push(parser.consume_plain_identifier_unreserved()?);
133 }
134 parser.consume_token(Token::Eq)?;
135 let val = parse_expression_or_default(parser, PRIORITY_MAX)?;
136 set.push((col, val));
137 if parser.skip_token(Token::Comma).is_none() {
138 break;
139 }
140 }
141
142 let where_ = if let Some(span) = parser.skip_keyword(Keyword::WHERE) {
143 Some((parse_expression_unreserved(parser, PRIORITY_MAX)?, span))
144 } else {
145 None
146 };
147
148 let from = if where_.is_none() {
150 if let Some(from_span) = parser.skip_keyword(Keyword::FROM) {
151 parser.postgres_only(&from_span);
152 let mut from_tables = Vec::new();
153 loop {
154 from_tables.push(parse_table_reference(parser, Restrict::EMPTY)?);
155 if parser.skip_token(Token::Comma).is_none() {
156 break;
157 }
158 }
159 let where_inner = if let Some(span) = parser.skip_keyword(Keyword::WHERE) {
160 Some((parse_expression_unreserved(parser, PRIORITY_MAX)?, span))
161 } else {
162 None
163 };
164 Some((from_span, from_tables, where_inner))
166 } else {
167 None
168 }
169 } else {
170 None
171 };
172 let (from, where_) = if let Some((from_span, from_tables, where_inner)) = from {
173 (Some((from_span, from_tables)), where_inner)
174 } else {
175 (None, where_)
176 };
177
178 let returning = if let Some(returning_span) = parser.skip_keyword(Keyword::RETURNING) {
179 let mut returning_exprs = Vec::new();
180 loop {
181 returning_exprs.push(parse_select_expr(parser)?);
182 if parser.skip_token(Token::Comma).is_none() {
183 break;
184 }
185 }
186 parser.postgres_only(&returning_span);
187 Some((returning_span, returning_exprs))
188 } else {
189 None
190 };
191
192 Ok(Update {
193 flags,
194 update_span,
195 tables,
196 set_span,
197 set,
198 from,
199 where_,
200 returning,
201 })
202}
203
204