Skip to main content

qusql_parse/
insert_replace.rs

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
12use alloc::vec::Vec;
13
14use crate::{
15    Identifier, OptSpanned, QualifiedName, Span, Spanned, Statement,
16    expression::{
17        Expression, PRIORITY_MAX, parse_expression_or_default, parse_expression_unreserved,
18    },
19    keywords::Keyword,
20    lexer::Token,
21    parser::{ParseError, Parser},
22    qualified_name::parse_qualified_name_unreserved,
23    select::{SelectExpr, parse_select_expr},
24    statement::parse_compound_query,
25};
26
27/// Flags for insert
28#[derive(Clone, Debug)]
29pub enum InsertReplaceFlag {
30    LowPriority(Span),
31    HighPriority(Span),
32    Delayed(Span),
33    Ignore(Span),
34}
35
36impl Spanned for InsertReplaceFlag {
37    fn span(&self) -> Span {
38        match &self {
39            InsertReplaceFlag::LowPriority(v) => v.span(),
40            InsertReplaceFlag::HighPriority(v) => v.span(),
41            InsertReplaceFlag::Delayed(v) => v.span(),
42            InsertReplaceFlag::Ignore(v) => v.span(),
43        }
44    }
45}
46
47#[derive(Clone, Debug)]
48pub enum InsertReplaceType {
49    Insert(Span),
50    Replace(Span),
51}
52
53impl Spanned for InsertReplaceType {
54    fn span(&self) -> Span {
55        match self {
56            InsertReplaceType::Insert(a) => a.clone(),
57            InsertReplaceType::Replace(a) => a.clone(),
58        }
59    }
60}
61
62#[derive(Clone, Debug)]
63pub enum OnConflictTarget<'a> {
64    Columns {
65        names: Vec<Identifier<'a>>,
66    },
67    OnConstraint {
68        on_constraint_span: Span,
69        name: Identifier<'a>,
70    },
71    None,
72}
73
74impl<'a> OptSpanned for OnConflictTarget<'a> {
75    fn opt_span(&self) -> Option<Span> {
76        match self {
77            OnConflictTarget::Columns { names } => names.opt_span(),
78            OnConflictTarget::OnConstraint {
79                on_constraint_span: token,
80                name,
81            } => Some(token.join_span(name)),
82            OnConflictTarget::None => None,
83        }
84    }
85}
86
87#[derive(Clone, Debug)]
88pub enum OnConflictAction<'a> {
89    DoNothing(Span),
90    DoUpdateSet {
91        do_update_set_span: Span,
92        sets: Vec<(Identifier<'a>, Expression<'a>)>,
93        where_: Option<(Span, Expression<'a>)>,
94    },
95}
96
97impl<'a> Spanned for OnConflictAction<'a> {
98    fn span(&self) -> Span {
99        match self {
100            OnConflictAction::DoNothing(span) => span.span(),
101            OnConflictAction::DoUpdateSet {
102                do_update_set_span,
103                sets,
104                where_,
105            } => do_update_set_span.join_span(sets).join_span(where_),
106        }
107    }
108}
109
110#[derive(Clone, Debug)]
111pub struct OnConflict<'a> {
112    pub on_conflict_span: Span,
113    pub target: OnConflictTarget<'a>,
114    pub action: OnConflictAction<'a>,
115}
116
117impl<'a> Spanned for OnConflict<'a> {
118    fn span(&self) -> Span {
119        self.on_conflict_span
120            .join_span(&self.target)
121            .join_span(&self.action)
122    }
123}
124
125#[derive(Clone, Debug)]
126pub struct InsertReplaceSetPair<'a> {
127    pub column: Identifier<'a>,
128    pub equal_span: Span,
129    pub value: Expression<'a>,
130}
131
132impl<'a> Spanned for InsertReplaceSetPair<'a> {
133    fn span(&self) -> Span {
134        self.column
135            .join_span(&self.equal_span)
136            .join_span(&self.value)
137    }
138}
139
140#[derive(Clone, Debug)]
141pub struct InsertReplaceSet<'a> {
142    pub set_span: Span,
143    pub pairs: Vec<InsertReplaceSetPair<'a>>,
144}
145
146impl<'a> Spanned for InsertReplaceSet<'a> {
147    fn span(&self) -> Span {
148        self.set_span.join_span(&self.pairs)
149    }
150}
151
152#[derive(Clone, Debug)]
153pub struct InsertReplaceOnDuplicateKeyUpdate<'a> {
154    pub on_duplicate_key_update_span: Span,
155    pub pairs: Vec<InsertReplaceSetPair<'a>>,
156}
157
158impl<'a> Spanned for InsertReplaceOnDuplicateKeyUpdate<'a> {
159    fn span(&self) -> Span {
160        self.on_duplicate_key_update_span.join_span(&self.pairs)
161    }
162}
163
164/// Representation of Insert or Replace Statement
165///
166/// ```
167/// # use qusql_parse::{SQLDialect, SQLArguments, ParseOptions, parse_statement, InsertReplace, InsertReplaceType, Statement, Issues};
168/// # let options = ParseOptions::new().dialect(SQLDialect::MariaDB);
169/// #
170/// let sql1 = "INSERT INTO person (first_name, last_name) VALUES ('John', 'Doe')";
171/// # let mut issues = Issues::new(sql1);
172/// let stmt1 = parse_statement(sql1, &mut issues, &options);
173/// # assert!(issues.is_ok());/// #
174/// let sql2 = "INSERT INTO contractor SELECT * FROM person WHERE status = 'c'";
175/// # let mut issues = Issues::new(sql2);
176/// let stmt2 = parse_statement(sql2, &mut issues, &options);
177/// # assert!(issues.is_ok());/// #
178/// let sql3 = "INSERT INTO account (`key`, `value`) VALUES ('foo', 42)
179///             ON DUPLICATE KEY UPDATE `value`=`value`+42";
180/// # let mut issues = Issues::new(sql3);
181/// let stmt3 = parse_statement(sql3, &mut issues, &options);
182/// # assert!(issues.is_ok());
183///
184/// let i: InsertReplace = match stmt1 {
185///     Some(Statement::InsertReplace(i)) if matches!(i.type_, InsertReplaceType::Insert(_)) => *i,
186///     _ => panic!("We should get an insert statement")
187/// };
188///
189/// assert!(i.table.identifier.as_str() == "person");
190/// println!("{:#?}", i.values.unwrap());
191///
192///
193/// let sql = "REPLACE INTO t2 VALUES (1,'Leopard'),(2,'Dog')";
194/// # let mut issues = Issues::new(sql);
195/// let stmt = parse_statement(sql, &mut issues, &options);
196/// # assert!(issues.is_ok());
197/// #
198/// let r: InsertReplace = match stmt {
199///     Some(Statement::InsertReplace(r)) if matches!(r.type_, InsertReplaceType::Replace(_)) => *r,
200///     _ => panic!("We should get an replace statement")
201/// };
202///
203/// assert!(r.table.identifier.as_str() == "t2");
204/// println!("{:#?}", r.values.unwrap());
205/// ```
206///
207/// PostgreSQL
208/// ```
209/// # use qusql_parse::{SQLDialect, SQLArguments, ParseOptions, parse_statement, InsertReplace, InsertReplaceType, Statement, Issues};
210/// # let options = ParseOptions::new().dialect(SQLDialect::PostgreSQL).arguments(SQLArguments::Dollar);
211/// #
212///
213/// let sql4 = "INSERT INTO contractor SELECT * FROM person WHERE status = $1 ON CONFLICT (name) DO NOTHING";
214/// # let mut issues = Issues::new(sql4);
215/// let stmt4 = parse_statement(sql4, &mut issues, &options);
216///
217/// println!("{}", issues);
218/// # assert!(issues.is_ok());
219/// ```
220#[derive(Clone, Debug)]
221pub struct InsertReplace<'a> {
222    /// Span of "INSERT" or "REPLACE"
223    pub type_: InsertReplaceType,
224    /// Flags specified after "INSERT"
225    pub flags: Vec<InsertReplaceFlag>,
226    /// Span of "INTO" if specified
227    pub into_span: Option<Span>,
228    /// Table to insert into
229    pub table: QualifiedName<'a>,
230    /// List of columns to set
231    pub columns: Vec<Identifier<'a>>,
232    /// Span of values "VALUES" and list of tuples to insert if specified
233    pub values: Option<(Span, Vec<Vec<Expression<'a>>>)>,
234    /// Select statement (possibly compound with UNION/INTERSECT/EXCEPT) to insert if specified
235    pub select: Option<Statement<'a>>,
236    /// Span of "SET" and list of key, value pairs to set if specified
237    pub set: Option<InsertReplaceSet<'a>>,
238    /// Updates to execute on duplicate key (mysql)
239    pub on_duplicate_key_update: Option<InsertReplaceOnDuplicateKeyUpdate<'a>>,
240    /// Action to take on duplicate keys (postgresql)
241    pub on_conflict: Option<OnConflict<'a>>,
242    /// AS alias with optional column list (MySQL/MariaDB): AS alias [(col1, col2, ...)]
243    pub as_alias: Option<(Span, Identifier<'a>, Option<Vec<Identifier<'a>>>)>,
244    /// Span of "RETURNING" and select expressions after "RETURNING", if "RETURNING" is present
245    pub returning: Option<(Span, Vec<SelectExpr<'a>>)>,
246}
247
248impl<'a> Spanned for InsertReplace<'a> {
249    fn span(&self) -> Span {
250        self.type_
251            .join_span(&self.flags)
252            .join_span(&self.into_span)
253            .join_span(&self.table)
254            .join_span(&self.values)
255            .join_span(&self.select)
256            .join_span(&self.set)
257            .join_span(&self.as_alias)
258            .join_span(&self.on_duplicate_key_update)
259            .join_span(&self.on_conflict)
260            .join_span(&self.returning)
261    }
262}
263
264pub(crate) fn parse_insert_replace<'a>(
265    parser: &mut Parser<'a, '_>,
266) -> Result<InsertReplace<'a>, ParseError> {
267    let type_ = match &parser.token {
268        Token::Ident(_, Keyword::INSERT) => InsertReplaceType::Insert(parser.consume()),
269        Token::Ident(_, Keyword::REPLACE) => InsertReplaceType::Replace(parser.consume()),
270        _ => parser.expected_failure("INSERT or REPLACE")?,
271    };
272
273    let insert = matches!(type_, InsertReplaceType::Insert(_));
274
275    let mut flags = Vec::new();
276    loop {
277        match &parser.token {
278            Token::Ident(_, Keyword::LOW_PRIORITY) => flags.push(InsertReplaceFlag::LowPriority(
279                parser.consume_keyword(Keyword::LOW_PRIORITY)?,
280            )),
281            Token::Ident(_, Keyword::HIGH_PRIORITY) => flags.push(InsertReplaceFlag::HighPriority(
282                parser.consume_keyword(Keyword::HIGH_PRIORITY)?,
283            )),
284            Token::Ident(_, Keyword::DELAYED) => flags.push(InsertReplaceFlag::Delayed(
285                parser.consume_keyword(Keyword::DELAYED)?,
286            )),
287            Token::Ident(_, Keyword::IGNORE) => flags.push(InsertReplaceFlag::Ignore(
288                parser.consume_keyword(Keyword::IGNORE)?,
289            )),
290            _ => break,
291        }
292    }
293
294    for flag in &flags {
295        match flag {
296            InsertReplaceFlag::LowPriority(_) => {}
297            InsertReplaceFlag::HighPriority(s) => {
298                if !insert {
299                    parser.err("Not supported for replace", s);
300                }
301            }
302            InsertReplaceFlag::Delayed(_) => {}
303            InsertReplaceFlag::Ignore(s) => {
304                if !insert {
305                    parser.err("Not supported for replace", s);
306                }
307            }
308        }
309    }
310
311    let into_span = parser.skip_keyword(Keyword::INTO);
312    let table = parse_qualified_name_unreserved(parser)?;
313    // [PARTITION (partition_list)]
314
315    let mut columns = Vec::new();
316    if parser.skip_token(Token::LParen).is_some() {
317        // Check for empty column list ()
318        if !matches!(parser.token, Token::RParen) {
319            parser.recovered(")", &|t| t == &Token::RParen, |parser| {
320                loop {
321                    columns.push(parser.consume_plain_identifier_unreserved()?);
322                    if parser.skip_token(Token::Comma).is_none() {
323                        break;
324                    }
325                }
326                Ok(())
327            })?;
328        }
329        parser.consume_token(Token::RParen)?;
330    }
331
332    // Parse AS alias before VALUES/SELECT/SET (PostgreSQL style)
333    let as_alias_before = if let Some(as_span) = parser.skip_keyword(Keyword::AS) {
334        let alias = parser.consume_plain_identifier_unreserved()?;
335        let columns = if parser.skip_token(Token::LParen).is_some() {
336            let mut cols = Vec::new();
337            // Check for empty column list ()
338            if !matches!(parser.token, Token::RParen) {
339                loop {
340                    cols.push(parser.consume_plain_identifier_unreserved()?);
341                    if parser.skip_token(Token::Comma).is_none() {
342                        break;
343                    }
344                }
345            }
346            parser.consume_token(Token::RParen)?;
347            Some(cols)
348        } else {
349            None
350        };
351        Some((as_span, alias, columns))
352    } else {
353        None
354    };
355
356    let mut select = None;
357    let mut values = None;
358    let mut set = None;
359    match &parser.token {
360        Token::Ident(_, Keyword::SELECT) | Token::LParen => {
361            select = Some(parse_compound_query(parser)?);
362        }
363        Token::Ident(_, Keyword::WITH) => {
364            // INSERT ... WITH [RECURSIVE] cte AS (...) SELECT ...
365            use crate::with_query::parse_with_query;
366            let wq = parse_with_query(parser)?;
367            select = Some(Statement::WithQuery(alloc::boxed::Box::new(wq)));
368        }
369        Token::Ident(_, Keyword::VALUE | Keyword::VALUES) => {
370            let values_span = parser.consume();
371            let mut values_items = Vec::new();
372            loop {
373                let mut vals = Vec::new();
374                parser.consume_token(Token::LParen)?;
375                // Check for empty VALUES ()
376                if !matches!(parser.token, Token::RParen) {
377                    parser.recovered(")", &|t| t == &Token::RParen, |parser| {
378                        loop {
379                            vals.push(parse_expression_or_default(parser, PRIORITY_MAX)?);
380                            if parser.skip_token(Token::Comma).is_none() {
381                                break;
382                            }
383                        }
384                        Ok(())
385                    })?;
386                }
387                parser.consume_token(Token::RParen)?;
388                values_items.push(vals);
389                if parser.skip_token(Token::Comma).is_none() {
390                    break;
391                }
392            }
393            values = Some((values_span, values_items));
394        }
395        Token::Ident(_, Keyword::SET) => {
396            let set_span = parser.consume_keyword(Keyword::SET)?;
397            let mut pairs = Vec::new();
398            loop {
399                let column = parser.consume_plain_identifier_unreserved()?;
400                let equal_span = parser.consume_token(Token::Eq)?;
401                let value: Expression<'_> = parse_expression_or_default(parser, PRIORITY_MAX)?;
402                pairs.push(InsertReplaceSetPair {
403                    column,
404                    equal_span,
405                    value,
406                });
407                if parser.skip_token(Token::Comma).is_none() {
408                    break;
409                }
410            }
411            if let Some(cs) = columns.opt_span() {
412                parser
413                    .err("Columns may not be used here", &cs)
414                    .frag("Together with SET", &set_span);
415            }
416            set = Some(InsertReplaceSet { set_span, pairs });
417        }
418        _ => {
419            parser.expected_error("VALUE, VALUES, SELECT or SET");
420        }
421    }
422
423    let (on_duplicate_key_update, on_conflict) =
424        if matches!(parser.token, Token::Ident(_, Keyword::ON)) {
425            let on = parser.consume_keyword(Keyword::ON)?;
426            match &parser.token {
427                Token::Ident(_, Keyword::DUPLICATE) => {
428                    let on_duplicate_key_update_span =
429                        on.join_span(&parser.consume_keywords(&[
430                            Keyword::DUPLICATE,
431                            Keyword::KEY,
432                            Keyword::UPDATE,
433                        ])?);
434                    let mut pairs = Vec::new();
435                    loop {
436                        let column = parser.consume_plain_identifier_unreserved()?;
437                        let equal_span = parser.consume_token(Token::Eq)?;
438                        let value = parse_expression_or_default(parser, PRIORITY_MAX)?;
439                        pairs.push(InsertReplaceSetPair {
440                            column,
441                            equal_span,
442                            value,
443                        });
444                        if parser.skip_token(Token::Comma).is_none() {
445                            break;
446                        }
447                    }
448                    parser.maria_only(&on_duplicate_key_update_span.join_span(&pairs));
449                    (
450                        Some(InsertReplaceOnDuplicateKeyUpdate {
451                            on_duplicate_key_update_span,
452                            pairs,
453                        }),
454                        None,
455                    )
456                }
457                Token::Ident(_, Keyword::CONFLICT) => {
458                    let on_conflict_span =
459                        on.join_span(&parser.consume_keyword(Keyword::CONFLICT)?);
460
461                    let target = match &parser.token {
462                        Token::LParen => {
463                            parser.consume_token(Token::LParen)?;
464                            let mut names = Vec::new();
465                            names.push(parser.consume_plain_identifier_unreserved()?);
466                            while parser.skip_token(Token::Comma).is_some() {
467                                names.push(parser.consume_plain_identifier_unreserved()?);
468                            }
469                            parser.consume_token(Token::RParen)?;
470                            OnConflictTarget::Columns { names }
471                        }
472                        Token::Ident(_, Keyword::ON) => {
473                            let on_constraint =
474                                parser.consume_keywords(&[Keyword::ON, Keyword::CONSTRAINT])?;
475                            let name = parser.consume_plain_identifier_unreserved()?;
476                            OnConflictTarget::OnConstraint {
477                                on_constraint_span: on_constraint,
478                                name,
479                            }
480                        }
481                        _ => OnConflictTarget::None,
482                    };
483
484                    let do_ = parser.consume_keyword(Keyword::DO)?;
485                    let action = match &parser.token {
486                        Token::Ident(_, Keyword::NOTHING) => OnConflictAction::DoNothing(
487                            do_.join_span(&parser.consume_keyword(Keyword::NOTHING)?),
488                        ),
489                        Token::Ident(_, Keyword::UPDATE) => {
490                            let do_update_set_span = do_.join_span(
491                                &parser.consume_keywords(&[Keyword::UPDATE, Keyword::SET])?,
492                            );
493                            let mut sets = Vec::new();
494                            loop {
495                                let name = parser.consume_plain_identifier_unreserved()?;
496                                parser.consume_token(Token::Eq)?;
497                                let expr = parse_expression_or_default(parser, PRIORITY_MAX)?;
498                                sets.push((name, expr));
499                                if parser.skip_token(Token::Comma).is_none() {
500                                    break;
501                                }
502                            }
503                            let where_ = if matches!(parser.token, Token::Ident(_, Keyword::WHERE))
504                            {
505                                let where_span = parser.consume_keyword(Keyword::WHERE)?;
506                                let where_expr = parse_expression_unreserved(parser, PRIORITY_MAX)?;
507                                Some((where_span, where_expr))
508                            } else {
509                                None
510                            };
511                            OnConflictAction::DoUpdateSet {
512                                do_update_set_span,
513                                sets,
514                                where_,
515                            }
516                        }
517                        _ => parser.expected_failure("'NOTHING' or 'UPDATE'")?,
518                    };
519
520                    let on_conflict = OnConflict {
521                        on_conflict_span,
522                        target,
523                        action,
524                    };
525
526                    parser.postgres_only(&on_conflict);
527
528                    (None, Some(on_conflict))
529                }
530                _ => parser.expected_failure("'DUPLICATE' OR 'CONFLICT'")?,
531            }
532        } else {
533            (None, None)
534        };
535
536    // Parse AS alias after VALUES/SELECT/SET (MySQL/MariaDB style) if not already parsed
537    let as_alias = if as_alias_before.is_none() {
538        if let Some(as_span) = parser.skip_keyword(Keyword::AS) {
539            let alias = parser.consume_plain_identifier_unreserved()?;
540            let columns = if parser.skip_token(Token::LParen).is_some() {
541                let mut cols = Vec::new();
542                // Check for empty column list ()
543                if !matches!(parser.token, Token::RParen) {
544                    loop {
545                        cols.push(parser.consume_plain_identifier_unreserved()?);
546                        if parser.skip_token(Token::Comma).is_none() {
547                            break;
548                        }
549                    }
550                }
551                parser.consume_token(Token::RParen)?;
552                Some(cols)
553            } else {
554                None
555            };
556            Some((as_span, alias, columns))
557        } else {
558            None
559        }
560    } else {
561        as_alias_before
562    };
563
564    let returning = if let Some(returning_span) = parser.skip_keyword(Keyword::RETURNING) {
565        let mut returning_exprs = Vec::new();
566        loop {
567            returning_exprs.push(parse_select_expr(parser)?);
568            if parser.skip_token(Token::Comma).is_none() {
569                break;
570            }
571        }
572        Some((returning_span, returning_exprs))
573    } else {
574        None
575    };
576
577    Ok(InsertReplace {
578        type_,
579        flags,
580        table,
581        columns,
582        into_span,
583        values,
584        select,
585        set,
586        as_alias,
587        on_duplicate_key_update,
588        on_conflict,
589        returning,
590    })
591}