Skip to main content

qusql_parse/
insert_replace.rs

1// Licensed under the Apache License, Version 2.0 (the "License");
2// you may not use this file except in compliance with the License.
3// You may obtain a copy of the License at
4//
5// http://www.apache.org/licenses/LICENSE-2.0
6//
7// Unless required by applicable law or agreed to in writing, software
8// distributed under the License is distributed on an "AS IS" BASIS,
9// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10// See the License for the specific language governing permissions and
11// limitations under the License.
12use alloc::vec::Vec;
13
14use crate::{
15    Identifier, OptSpanned, QualifiedName, Span, Spanned, Statement,
16    expression::{
17        Expression, PRIORITY_MAX, parse_expression_or_default, parse_expression_unreserved,
18    },
19    keywords::Keyword,
20    lexer::Token,
21    parser::{ParseError, Parser},
22    qualified_name::parse_qualified_name_unreserved,
23    select::{Select, SelectExpr, parse_select, parse_select_expr},
24};
25
26/// Flags for insert
27#[derive(Clone, Debug)]
28pub enum InsertReplaceFlag {
29    LowPriority(Span),
30    HighPriority(Span),
31    Delayed(Span),
32    Ignore(Span),
33}
34
35impl Spanned for InsertReplaceFlag {
36    fn span(&self) -> Span {
37        match &self {
38            InsertReplaceFlag::LowPriority(v) => v.span(),
39            InsertReplaceFlag::HighPriority(v) => v.span(),
40            InsertReplaceFlag::Delayed(v) => v.span(),
41            InsertReplaceFlag::Ignore(v) => v.span(),
42        }
43    }
44}
45
46#[derive(Clone, Debug)]
47pub enum InsertReplaceType {
48    Insert(Span),
49    Replace(Span),
50}
51
52impl Spanned for InsertReplaceType {
53    fn span(&self) -> Span {
54        match self {
55            InsertReplaceType::Insert(a) => a.clone(),
56            InsertReplaceType::Replace(a) => a.clone(),
57        }
58    }
59}
60
61#[derive(Clone, Debug)]
62pub enum OnConflictTarget<'a> {
63    Columns {
64        names: Vec<Identifier<'a>>,
65    },
66    OnConstraint {
67        on_constraint_span: Span,
68        name: Identifier<'a>,
69    },
70    None,
71}
72
73impl<'a> OptSpanned for OnConflictTarget<'a> {
74    fn opt_span(&self) -> Option<Span> {
75        match self {
76            OnConflictTarget::Columns { names } => names.opt_span(),
77            OnConflictTarget::OnConstraint {
78                on_constraint_span: token,
79                name,
80            } => Some(token.join_span(name)),
81            OnConflictTarget::None => None,
82        }
83    }
84}
85
86#[derive(Clone, Debug)]
87pub enum OnConflictAction<'a> {
88    DoNothing(Span),
89    DoUpdateSet {
90        do_update_set_span: Span,
91        sets: Vec<(Identifier<'a>, Expression<'a>)>,
92        where_: Option<(Span, Expression<'a>)>,
93    },
94}
95
96impl<'a> Spanned for OnConflictAction<'a> {
97    fn span(&self) -> Span {
98        match self {
99            OnConflictAction::DoNothing(span) => span.span(),
100            OnConflictAction::DoUpdateSet {
101                do_update_set_span,
102                sets,
103                where_,
104            } => do_update_set_span.join_span(sets).join_span(where_),
105        }
106    }
107}
108
109#[derive(Clone, Debug)]
110pub struct OnConflict<'a> {
111    pub on_conflict_span: Span,
112    pub target: OnConflictTarget<'a>,
113    pub action: OnConflictAction<'a>,
114}
115
116impl<'a> Spanned for OnConflict<'a> {
117    fn span(&self) -> Span {
118        self.on_conflict_span
119            .join_span(&self.target)
120            .join_span(&self.action)
121    }
122}
123
124#[derive(Clone, Debug)]
125pub struct InsertReplaceSetPair<'a> {
126    pub column: Identifier<'a>,
127    pub equal_span: Span,
128    pub value: Expression<'a>,
129}
130
131impl<'a> Spanned for InsertReplaceSetPair<'a> {
132    fn span(&self) -> Span {
133        self.column
134            .join_span(&self.equal_span)
135            .join_span(&self.value)
136    }
137}
138
139#[derive(Clone, Debug)]
140pub struct InsertReplaceSet<'a> {
141    pub set_span: Span,
142    pub pairs: Vec<InsertReplaceSetPair<'a>>,
143}
144
145impl<'a> Spanned for InsertReplaceSet<'a> {
146    fn span(&self) -> Span {
147        self.set_span.join_span(&self.pairs)
148    }
149}
150
151#[derive(Clone, Debug)]
152pub struct InsertReplaceOnDuplicateKeyUpdate<'a> {
153    pub on_duplicate_key_update_span: Span,
154    pub pairs: Vec<InsertReplaceSetPair<'a>>,
155}
156
157impl<'a> Spanned for InsertReplaceOnDuplicateKeyUpdate<'a> {
158    fn span(&self) -> Span {
159        self.on_duplicate_key_update_span.join_span(&self.pairs)
160    }
161}
162
163/// Representation of Insert or Replace Statement
164///
165/// ```
166/// # use qusql_parse::{SQLDialect, SQLArguments, ParseOptions, parse_statement, InsertReplace, InsertReplaceType, Statement, Issues};
167/// # let options = ParseOptions::new().dialect(SQLDialect::MariaDB);
168/// #
169/// let sql1 = "INSERT INTO person (first_name, last_name) VALUES ('John', 'Doe')";
170/// # let mut issues = Issues::new(sql1);
171/// let stmt1 = parse_statement(sql1, &mut issues, &options);
172/// # assert!(issues.is_ok());/// #
173/// let sql2 = "INSERT INTO contractor SELECT * FROM person WHERE status = 'c'";
174/// # let mut issues = Issues::new(sql2);
175/// let stmt2 = parse_statement(sql2, &mut issues, &options);
176/// # assert!(issues.is_ok());/// #
177/// let sql3 = "INSERT INTO account (`key`, `value`) VALUES ('foo', 42)
178///             ON DUPLICATE KEY UPDATE `value`=`value`+42";
179/// # let mut issues = Issues::new(sql3);
180/// let stmt3 = parse_statement(sql3, &mut issues, &options);
181/// # assert!(issues.is_ok());
182///
183/// let i: InsertReplace = match stmt1 {
184///     Some(Statement::InsertReplace(i)) if matches!(i.type_, InsertReplaceType::Insert(_)) => *i,
185///     _ => panic!("We should get an insert statement")
186/// };
187///
188/// assert!(i.table.identifier.as_str() == "person");
189/// println!("{:#?}", i.values.unwrap());
190///
191///
192/// let sql = "REPLACE INTO t2 VALUES (1,'Leopard'),(2,'Dog')";
193/// # let mut issues = Issues::new(sql);
194/// let stmt = parse_statement(sql, &mut issues, &options);
195/// # assert!(issues.is_ok());
196/// #
197/// let r: InsertReplace = match stmt {
198///     Some(Statement::InsertReplace(r)) if matches!(r.type_, InsertReplaceType::Replace(_)) => *r,
199///     _ => panic!("We should get an replace statement")
200/// };
201///
202/// assert!(r.table.identifier.as_str() == "t2");
203/// println!("{:#?}", r.values.unwrap());
204/// ```
205///
206/// PostgreSQL
207/// ```
208/// # use qusql_parse::{SQLDialect, SQLArguments, ParseOptions, parse_statement, InsertReplace, InsertReplaceType, Statement, Issues};
209/// # let options = ParseOptions::new().dialect(SQLDialect::PostgreSQL).arguments(SQLArguments::Dollar);
210/// #
211///
212/// let sql4 = "INSERT INTO contractor SELECT * FROM person WHERE status = $1 ON CONFLICT (name) DO NOTHING";
213/// # let mut issues = Issues::new(sql4);
214/// let stmt4 = parse_statement(sql4, &mut issues, &options);
215///
216/// println!("{}", issues);
217/// # assert!(issues.is_ok());
218/// ```
219#[derive(Clone, Debug)]
220pub struct InsertReplace<'a> {
221    /// Span of "INSERT" or "REPLACE"
222    pub type_: InsertReplaceType,
223    /// Flags specified after "INSERT"
224    pub flags: Vec<InsertReplaceFlag>,
225    /// Span of "INTO" if specified
226    pub into_span: Option<Span>,
227    /// Table to insert into
228    pub table: QualifiedName<'a>,
229    /// List of columns to set
230    pub columns: Vec<Identifier<'a>>,
231    /// Span of values "VALUES" and list of tuples to insert if specified
232    pub values: Option<(Span, Vec<Vec<Expression<'a>>>)>,
233    /// Select statement to insert if specified
234    pub select: Option<Select<'a>>,
235    /// Span of "SET" and list of key, value pairs to set if specified
236    pub set: Option<InsertReplaceSet<'a>>,
237    /// Updates to execute on duplicate key (mysql)
238    pub on_duplicate_key_update: Option<InsertReplaceOnDuplicateKeyUpdate<'a>>,
239    /// Action to take on duplicate keys (postgresql)
240    pub on_conflict: Option<OnConflict<'a>>,
241    /// AS alias with optional column list (MySQL/MariaDB): AS alias [(col1, col2, ...)]
242    pub as_alias: Option<(Span, Identifier<'a>, Option<Vec<Identifier<'a>>>)>,
243    /// Span of "RETURNING" and select expressions after "RETURNING", if "RETURNING" is present
244    pub returning: Option<(Span, Vec<SelectExpr<'a>>)>,
245}
246
247impl<'a> Spanned for InsertReplace<'a> {
248    fn span(&self) -> Span {
249        self.type_
250            .join_span(&self.flags)
251            .join_span(&self.into_span)
252            .join_span(&self.table)
253            .join_span(&self.values)
254            .join_span(&self.select)
255            .join_span(&self.set)
256            .join_span(&self.as_alias)
257            .join_span(&self.on_duplicate_key_update)
258            .join_span(&self.on_conflict)
259            .join_span(&self.returning)
260    }
261}
262
263pub(crate) fn parse_insert_replace<'a>(
264    parser: &mut Parser<'a, '_>,
265) -> Result<InsertReplace<'a>, ParseError> {
266    let type_ = match &parser.token {
267        Token::Ident(_, Keyword::INSERT) => InsertReplaceType::Insert(parser.consume()),
268        Token::Ident(_, Keyword::REPLACE) => InsertReplaceType::Replace(parser.consume()),
269        _ => parser.expected_failure("INSERT or REPLACE")?,
270    };
271
272    let insert = matches!(type_, InsertReplaceType::Insert(_));
273
274    let mut flags = Vec::new();
275    loop {
276        match &parser.token {
277            Token::Ident(_, Keyword::LOW_PRIORITY) => flags.push(InsertReplaceFlag::LowPriority(
278                parser.consume_keyword(Keyword::LOW_PRIORITY)?,
279            )),
280            Token::Ident(_, Keyword::HIGH_PRIORITY) => flags.push(InsertReplaceFlag::HighPriority(
281                parser.consume_keyword(Keyword::HIGH_PRIORITY)?,
282            )),
283            Token::Ident(_, Keyword::DELAYED) => flags.push(InsertReplaceFlag::Delayed(
284                parser.consume_keyword(Keyword::DELAYED)?,
285            )),
286            Token::Ident(_, Keyword::IGNORE) => flags.push(InsertReplaceFlag::Ignore(
287                parser.consume_keyword(Keyword::IGNORE)?,
288            )),
289            _ => break,
290        }
291    }
292
293    for flag in &flags {
294        match flag {
295            InsertReplaceFlag::LowPriority(_) => {}
296            InsertReplaceFlag::HighPriority(s) => {
297                if !insert {
298                    parser.err("Not supported for replace", s);
299                }
300            }
301            InsertReplaceFlag::Delayed(_) => {}
302            InsertReplaceFlag::Ignore(s) => {
303                if !insert {
304                    parser.err("Not supported for replace", s);
305                }
306            }
307        }
308    }
309
310    let into_span = parser.skip_keyword(Keyword::INTO);
311    let table = parse_qualified_name_unreserved(parser)?;
312    // [PARTITION (partition_list)]
313
314    let mut columns = Vec::new();
315    if parser.skip_token(Token::LParen).is_some() {
316        // Check for empty column list ()
317        if !matches!(parser.token, Token::RParen) {
318            parser.recovered(")", &|t| t == &Token::RParen, |parser| {
319                loop {
320                    columns.push(parser.consume_plain_identifier_unreserved()?);
321                    if parser.skip_token(Token::Comma).is_none() {
322                        break;
323                    }
324                }
325                Ok(())
326            })?;
327        }
328        parser.consume_token(Token::RParen)?;
329    }
330
331    // Parse AS alias before VALUES/SELECT/SET (PostgreSQL style)
332    let as_alias_before = if let Some(as_span) = parser.skip_keyword(Keyword::AS) {
333        let alias = parser.consume_plain_identifier_unreserved()?;
334        let columns = if parser.skip_token(Token::LParen).is_some() {
335            let mut cols = Vec::new();
336            // Check for empty column list ()
337            if !matches!(parser.token, Token::RParen) {
338                loop {
339                    cols.push(parser.consume_plain_identifier_unreserved()?);
340                    if parser.skip_token(Token::Comma).is_none() {
341                        break;
342                    }
343                }
344            }
345            parser.consume_token(Token::RParen)?;
346            Some(cols)
347        } else {
348            None
349        };
350        Some((as_span, alias, columns))
351    } else {
352        None
353    };
354
355    let mut select = None;
356    let mut values = None;
357    let mut set = None;
358    match &parser.token {
359        Token::Ident(_, Keyword::SELECT) => {
360            select = Some(parse_select(parser)?);
361        }
362        Token::Ident(_, Keyword::WITH) => {
363            // INSERT ... WITH [RECURSIVE] cte AS (...) SELECT ...
364            // Parse as a WithQuery and extract the inner SELECT.
365            use crate::with_query::parse_with_query;
366            let wq = parse_with_query(parser)?;
367            if let Statement::Select(s) = *wq.statement {
368                select = Some(*s);
369            } else {
370                parser.err("Expected SELECT after WITH", &wq.with_span);
371            }
372        }
373        Token::Ident(_, Keyword::VALUE | Keyword::VALUES) => {
374            let values_span = parser.consume();
375            let mut values_items = Vec::new();
376            loop {
377                let mut vals = Vec::new();
378                parser.consume_token(Token::LParen)?;
379                // Check for empty VALUES ()
380                if !matches!(parser.token, Token::RParen) {
381                    parser.recovered(")", &|t| t == &Token::RParen, |parser| {
382                        loop {
383                            vals.push(parse_expression_or_default(parser, PRIORITY_MAX)?);
384                            if parser.skip_token(Token::Comma).is_none() {
385                                break;
386                            }
387                        }
388                        Ok(())
389                    })?;
390                }
391                parser.consume_token(Token::RParen)?;
392                values_items.push(vals);
393                if parser.skip_token(Token::Comma).is_none() {
394                    break;
395                }
396            }
397            values = Some((values_span, values_items));
398        }
399        Token::Ident(_, Keyword::SET) => {
400            let set_span = parser.consume_keyword(Keyword::SET)?;
401            let mut pairs = Vec::new();
402            loop {
403                let column = parser.consume_plain_identifier_unreserved()?;
404                let equal_span = parser.consume_token(Token::Eq)?;
405                let value: Expression<'_> = parse_expression_or_default(parser, PRIORITY_MAX)?;
406                pairs.push(InsertReplaceSetPair {
407                    column,
408                    equal_span,
409                    value,
410                });
411                if parser.skip_token(Token::Comma).is_none() {
412                    break;
413                }
414            }
415            if let Some(cs) = columns.opt_span() {
416                parser
417                    .err("Columns may not be used here", &cs)
418                    .frag("Together with SET", &set_span);
419            }
420            set = Some(InsertReplaceSet { set_span, pairs });
421        }
422        _ => {
423            parser.expected_error("VALUE, VALUES, SELECT or SET");
424        }
425    }
426
427    let (on_duplicate_key_update, on_conflict) =
428        if matches!(parser.token, Token::Ident(_, Keyword::ON)) {
429            let on = parser.consume_keyword(Keyword::ON)?;
430            match &parser.token {
431                Token::Ident(_, Keyword::DUPLICATE) => {
432                    let on_duplicate_key_update_span =
433                        on.join_span(&parser.consume_keywords(&[
434                            Keyword::DUPLICATE,
435                            Keyword::KEY,
436                            Keyword::UPDATE,
437                        ])?);
438                    let mut pairs = Vec::new();
439                    loop {
440                        let column = parser.consume_plain_identifier_unreserved()?;
441                        let equal_span = parser.consume_token(Token::Eq)?;
442                        let value = parse_expression_or_default(parser, PRIORITY_MAX)?;
443                        pairs.push(InsertReplaceSetPair {
444                            column,
445                            equal_span,
446                            value,
447                        });
448                        if parser.skip_token(Token::Comma).is_none() {
449                            break;
450                        }
451                    }
452                    parser.maria_only(&on_duplicate_key_update_span.join_span(&pairs));
453                    (
454                        Some(InsertReplaceOnDuplicateKeyUpdate {
455                            on_duplicate_key_update_span,
456                            pairs,
457                        }),
458                        None,
459                    )
460                }
461                Token::Ident(_, Keyword::CONFLICT) => {
462                    let on_conflict_span =
463                        on.join_span(&parser.consume_keyword(Keyword::CONFLICT)?);
464
465                    let target = match &parser.token {
466                        Token::LParen => {
467                            parser.consume_token(Token::LParen)?;
468                            let mut names = Vec::new();
469                            names.push(parser.consume_plain_identifier_unreserved()?);
470                            while parser.skip_token(Token::Comma).is_some() {
471                                names.push(parser.consume_plain_identifier_unreserved()?);
472                            }
473                            parser.consume_token(Token::RParen)?;
474                            OnConflictTarget::Columns { names }
475                        }
476                        Token::Ident(_, Keyword::ON) => {
477                            let on_constraint =
478                                parser.consume_keywords(&[Keyword::ON, Keyword::CONSTRAINT])?;
479                            let name = parser.consume_plain_identifier_unreserved()?;
480                            OnConflictTarget::OnConstraint {
481                                on_constraint_span: on_constraint,
482                                name,
483                            }
484                        }
485                        _ => OnConflictTarget::None,
486                    };
487
488                    let do_ = parser.consume_keyword(Keyword::DO)?;
489                    let action = match &parser.token {
490                        Token::Ident(_, Keyword::NOTHING) => OnConflictAction::DoNothing(
491                            do_.join_span(&parser.consume_keyword(Keyword::NOTHING)?),
492                        ),
493                        Token::Ident(_, Keyword::UPDATE) => {
494                            let do_update_set_span = do_.join_span(
495                                &parser.consume_keywords(&[Keyword::UPDATE, Keyword::SET])?,
496                            );
497                            let mut sets = Vec::new();
498                            loop {
499                                let name = parser.consume_plain_identifier_unreserved()?;
500                                parser.consume_token(Token::Eq)?;
501                                let expr = parse_expression_or_default(parser, PRIORITY_MAX)?;
502                                sets.push((name, expr));
503                                if parser.skip_token(Token::Comma).is_none() {
504                                    break;
505                                }
506                            }
507                            let where_ = if matches!(parser.token, Token::Ident(_, Keyword::WHERE))
508                            {
509                                let where_span = parser.consume_keyword(Keyword::WHERE)?;
510                                let where_expr = parse_expression_unreserved(parser, PRIORITY_MAX)?;
511                                Some((where_span, where_expr))
512                            } else {
513                                None
514                            };
515                            OnConflictAction::DoUpdateSet {
516                                do_update_set_span,
517                                sets,
518                                where_,
519                            }
520                        }
521                        _ => parser.expected_failure("'NOTHING' or 'UPDATE'")?,
522                    };
523
524                    let on_conflict = OnConflict {
525                        on_conflict_span,
526                        target,
527                        action,
528                    };
529
530                    parser.postgres_only(&on_conflict);
531
532                    (None, Some(on_conflict))
533                }
534                _ => parser.expected_failure("'DUPLICATE' OR 'CONFLICT'")?,
535            }
536        } else {
537            (None, None)
538        };
539
540    // Parse AS alias after VALUES/SELECT/SET (MySQL/MariaDB style) if not already parsed
541    let as_alias = if as_alias_before.is_none() {
542        if let Some(as_span) = parser.skip_keyword(Keyword::AS) {
543            let alias = parser.consume_plain_identifier_unreserved()?;
544            let columns = if parser.skip_token(Token::LParen).is_some() {
545                let mut cols = Vec::new();
546                // Check for empty column list ()
547                if !matches!(parser.token, Token::RParen) {
548                    loop {
549                        cols.push(parser.consume_plain_identifier_unreserved()?);
550                        if parser.skip_token(Token::Comma).is_none() {
551                            break;
552                        }
553                    }
554                }
555                parser.consume_token(Token::RParen)?;
556                Some(cols)
557            } else {
558                None
559            };
560            Some((as_span, alias, columns))
561        } else {
562            None
563        }
564    } else {
565        as_alias_before
566    };
567
568    let returning = if let Some(returning_span) = parser.skip_keyword(Keyword::RETURNING) {
569        let mut returning_exprs = Vec::new();
570        loop {
571            returning_exprs.push(parse_select_expr(parser)?);
572            if parser.skip_token(Token::Comma).is_none() {
573                break;
574            }
575        }
576        Some((returning_span, returning_exprs))
577    } else {
578        None
579    };
580
581    Ok(InsertReplace {
582        type_,
583        flags,
584        table,
585        columns,
586        into_span,
587        values,
588        select,
589        set,
590        as_alias,
591        on_duplicate_key_update,
592        on_conflict,
593        returning,
594    })
595}