Skip to main content

qusql_parse/
update.rs

1// Licensed under the Apache License, Version 2.0 (the "License");
2// you may not use this file except in compliance with the License.
3// You may obtain a copy of the License at
4//
5// http://www.apache.org/licenses/LICENSE-2.0
6//
7// Unless required by applicable law or agreed to in writing, software
8// distributed under the License is distributed on an "AS IS" BASIS,
9// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10// See the License for the specific language governing permissions and
11// limitations under the License.
12
13use alloc::vec;
14use alloc::vec::Vec;
15
16use crate::{
17    Identifier, SelectExpr, Span, Spanned,
18    expression::{
19        Expression, PRIORITY_MAX, parse_expression_or_default, parse_expression_unreserved,
20    },
21    keywords::{Keyword, Restrict},
22    lexer::Token,
23    parser::{ParseError, Parser},
24    select::{TableReference, parse_select_expr, parse_table_reference},
25    span::OptSpanned,
26};
27
28/// Flags specified after "UPDATE"
29#[derive(Clone, Debug)]
30pub enum UpdateFlag {
31    LowPriority(Span),
32    Ignore(Span),
33}
34
35impl Spanned for UpdateFlag {
36    fn span(&self) -> Span {
37        match &self {
38            UpdateFlag::LowPriority(v) => v.span(),
39            UpdateFlag::Ignore(v) => v.span(),
40        }
41    }
42}
43
44/// Representation of replace Statement
45///
46/// ```
47/// # use qusql_parse::{SQLDialect, SQLArguments, ParseOptions, parse_statement, Update, Statement, Issues};
48/// # let options = ParseOptions::new().dialect(SQLDialect::MariaDB);
49/// #
50/// let sql = "UPDATE tab1, tab2 SET tab1.column1 = value1, tab1.column2 = value2 WHERE tab1.id = tab2.id";
51/// let mut issues = Issues::new(sql);
52/// let stmt = parse_statement(sql, &mut issues, &options);
53///
54/// # assert!(issues.is_ok());
55/// let u: Update = match stmt {
56///     Some(Statement::Update(u)) => *u,
57///     _ => panic!("We should get an update statement")
58/// };
59///
60/// println!("{:#?}", u.where_.unwrap())
61/// ```
62#[derive(Clone, Debug)]
63pub struct Update<'a> {
64    /// Span of "UPDATE"
65    pub update_span: Span,
66    /// Flags specified after "UPDATE"
67    pub flags: Vec<UpdateFlag>,
68    /// List of tables to update
69    pub tables: Vec<TableReference<'a>>,
70    /// Span of "SET"
71    pub set_span: Span,
72    /// List of key,value pairs to set
73    pub set: Vec<(Vec<Identifier<'a>>, Expression<'a>)>,
74    /// Where expression and span of "WHERE" if specified
75    pub where_: Option<(Expression<'a>, Span)>,
76    /// PostgreSQL: FROM clause (additional tables)
77    pub from: Option<(Span, Vec<TableReference<'a>>)>,
78    /// Span of "RETURNING" and select expressions after "RETURNING", if "RETURNING" is present
79    pub returning: Option<(Span, Vec<SelectExpr<'a>>)>,
80}
81
82impl<'a> Spanned for Update<'a> {
83    fn span(&self) -> Span {
84        let mut set_span = None;
85        for (a, b) in &self.set {
86            set_span = set_span.opt_join_span(a).opt_join_span(b)
87        }
88
89        self.update_span
90            .join_span(&self.flags)
91            .join_span(&self.tables)
92            .join_span(&self.set_span)
93            .join_span(&set_span)
94            .join_span(&self.from)
95            .join_span(&self.where_)
96            .join_span(&self.returning)
97    }
98}
99
100pub(crate) fn parse_update<'a>(parser: &mut Parser<'a, '_>) -> Result<Update<'a>, ParseError> {
101    let update_span = parser.consume_keyword(Keyword::UPDATE)?;
102    let mut flags = Vec::new();
103
104    loop {
105        match &parser.token {
106            Token::Ident(_, Keyword::LOW_PRIORITY) => flags.push(UpdateFlag::LowPriority(
107                parser.consume_keyword(Keyword::LOW_PRIORITY)?,
108            )),
109            Token::Ident(_, Keyword::IGNORE) => {
110                flags.push(UpdateFlag::Ignore(parser.consume_keyword(Keyword::IGNORE)?))
111            }
112            _ => break,
113        }
114    }
115
116    let mut tables = Vec::new();
117    loop {
118        tables.push(parse_table_reference(parser, Restrict::UPDATE_TABLE)?);
119        // In PostgreSQL UPDATE, SET is not fully reserved, so stop here before comma
120        if matches!(parser.token, Token::Ident(_, Keyword::SET)) {
121            break;
122        }
123        if parser.skip_token(Token::Comma).is_none() {
124            break;
125        }
126    }
127    let set_span = parser.consume_keyword(Keyword::SET)?;
128    let mut set = Vec::new();
129    loop {
130        let mut col = vec![parser.consume_plain_identifier_unreserved()?];
131        while parser.skip_token(Token::Period).is_some() {
132            col.push(parser.consume_plain_identifier_unreserved()?);
133        }
134        parser.consume_token(Token::Eq)?;
135        let val = parse_expression_or_default(parser, PRIORITY_MAX)?;
136        set.push((col, val));
137        if parser.skip_token(Token::Comma).is_none() {
138            break;
139        }
140    }
141
142    let where_ = if let Some(span) = parser.skip_keyword(Keyword::WHERE) {
143        Some((parse_expression_unreserved(parser, PRIORITY_MAX)?, span))
144    } else {
145        None
146    };
147
148    // PostgreSQL: FROM clause after SET (before WHERE)
149    let from = if where_.is_none() {
150        if let Some(from_span) = parser.skip_keyword(Keyword::FROM) {
151            parser.postgres_only(&from_span);
152            let mut from_tables = Vec::new();
153            loop {
154                from_tables.push(parse_table_reference(parser, Restrict::EMPTY)?);
155                if parser.skip_token(Token::Comma).is_none() {
156                    break;
157                }
158            }
159            let where_inner = if let Some(span) = parser.skip_keyword(Keyword::WHERE) {
160                Some((parse_expression_unreserved(parser, PRIORITY_MAX)?, span))
161            } else {
162                None
163            };
164            // Re-assign where_ by returning it from the block — handled below
165            Some((from_span, from_tables, where_inner))
166        } else {
167            None
168        }
169    } else {
170        None
171    };
172    let (from, where_) = if let Some((from_span, from_tables, where_inner)) = from {
173        (Some((from_span, from_tables)), where_inner)
174    } else {
175        (None, where_)
176    };
177
178    let returning = if let Some(returning_span) = parser.skip_keyword(Keyword::RETURNING) {
179        let mut returning_exprs = Vec::new();
180        loop {
181            returning_exprs.push(parse_select_expr(parser)?);
182            if parser.skip_token(Token::Comma).is_none() {
183                break;
184            }
185        }
186        parser.postgres_only(&returning_span);
187        Some((returning_span, returning_exprs))
188    } else {
189        None
190    };
191
192    Ok(Update {
193        flags,
194        update_span,
195        tables,
196        set_span,
197        set,
198        from,
199        where_,
200        returning,
201    })
202}
203
// UPDATE [LOW_PRIORITY] [IGNORE] table_references
// SET col1={expr1|DEFAULT} [, col2={expr2|DEFAULT}] ...
// [FROM table_references]                          (PostgreSQL)
// [WHERE where_condition]
// [RETURNING select_expr [, select_expr] ...]      (PostgreSQL)