// qusql_parse/create_trigger.rs

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use crate::{
    Expression, Identifier, QualifiedName, Span, Spanned, Statement,
    create_option::CreateOption,
    expression::{PRIORITY_MAX, parse_expression_unreserved},
    keywords::Keyword,
    lexer::Token,
    parser::{ParseError, Parser},
    qualified_name::parse_qualified_name_unreserved,
    statement::parse_statement,
};
use alloc::{boxed::Box, vec::Vec};

24/// PostgreSQL trigger EXECUTE FUNCTION func_name(args...) body
25#[derive(Clone, Debug)]
26pub struct ExecuteFunction<'a> {
27    /// Span of "EXECUTE FUNCTION" or "EXECUTE PROCEDURE"
28    pub execute_span: Span,
29    /// Name of the function to execute
30    pub func_name: QualifiedName<'a>,
31    /// Arguments passed to the function
32    pub args: Vec<Expression<'a>>,
33}
34
35impl<'a> Spanned for ExecuteFunction<'a> {
36    fn span(&self) -> Span {
37        self.execute_span
38            .join_span(&self.func_name)
39            .join_span(&self.args)
40    }
41}
42
43/// Whether the trigger fires once per row or once per statement
44#[derive(Clone, Debug)]
45pub enum TriggerForEach {
46    Row(Span),
47    Statement(Span),
48}
49
50impl Spanned for TriggerForEach {
51    fn span(&self) -> Span {
52        match self {
53            TriggerForEach::Row(s) => s.clone(),
54            TriggerForEach::Statement(s) => s.clone(),
55        }
56    }
57}
58
59/// When to fire the trigger
60#[derive(Clone, Debug)]
61pub enum TriggerTime {
62    Before(Span),
63    After(Span),
64    InsteadOf(Span),
65}
66
67impl Spanned for TriggerTime {
68    fn span(&self) -> Span {
69        match &self {
70            TriggerTime::Before(v) => v.span(),
71            TriggerTime::After(v) => v.span(),
72            TriggerTime::InsteadOf(v) => v.span(),
73        }
74    }
75}
76
77/// On what event to fire the trigger
78#[derive(Clone, Debug)]
79pub enum TriggerEvent {
80    Update(Span),
81    Insert(Span),
82    Delete(Span),
83    Truncate(Span),
84}
85
86impl Spanned for TriggerEvent {
87    fn span(&self) -> Span {
88        match &self {
89            TriggerEvent::Update(v) => v.span(),
90            TriggerEvent::Insert(v) => v.span(),
91            TriggerEvent::Delete(v) => v.span(),
92            TriggerEvent::Truncate(v) => v.span(),
93        }
94    }
95}
96
97#[derive(Clone, Debug)]
98pub enum TriggerReferenceDirection {
99    New(Span),
100    Old(Span),
101}
102
103impl Spanned for TriggerReferenceDirection {
104    fn span(&self) -> Span {
105        match &self {
106            TriggerReferenceDirection::New(v) => v.span(),
107            TriggerReferenceDirection::Old(v) => v.span(),
108        }
109    }
110}
111
112#[derive(Clone, Debug)]
113pub struct TriggerReference<'a> {
114    direction: TriggerReferenceDirection,
115    table_as_span: Span,
116    alias: Identifier<'a>,
117}
118
119impl Spanned for TriggerReference<'_> {
120    fn span(&self) -> Span {
121        self.direction
122            .join_span(&self.table_as_span)
123            .join_span(&self.alias)
124    }
125}
126
127/// Represent a create trigger statement
128/// ```
129/// # use qusql_parse::{SQLDialect, SQLArguments, ParseOptions, parse_statements, CreateTrigger, Statement, Issues};
130/// # let options = ParseOptions::new().dialect(SQLDialect::MariaDB);
131/// #
132/// let sql = "DROP TRIGGER IF EXISTS `my_trigger`;
133/// DELIMITER $$
134/// CREATE TRIGGER `my_trigger` AFTER DELETE ON `things` FOR EACH ROW BEGIN
135///     IF OLD.`value` IS NOT NULL THEN
136///         UPDATE `t2` AS `j`
137///             SET
138///             `j`.`total_items` = `total_items` - 1
139///             WHERE `j`.`id`=OLD.`value` AND NOT `j`.`frozen`;
140///         END IF;
141///     INSERT INTO `updated_things` (`thing`) VALUES (OLD.`id`);
142/// END
143/// $$
144/// DELIMITER ;";
145/// let mut issues = Issues::new(sql);
146/// let mut stmts = parse_statements(sql, &mut issues, &options);
147///
148/// # assert_eq!(issues.get(), &[]);
149/// #
150/// let create: CreateTrigger = match stmts.pop() {
151///     Some(Statement::CreateTrigger(c)) => *c,
152///     _ => panic!("We should get an create trigger statement")
153/// };
154///
155/// assert!(create.name.as_str() == "my_trigger");
156/// println!("{:#?}", create.statement)
157/// ```
158#[derive(Clone, Debug)]
159pub struct CreateTrigger<'a> {
160    /// Span of "CREATE"
161    pub create_span: Span,
162    /// Options after "CREATE"
163    pub create_options: Vec<CreateOption<'a>>,
164    /// Span of "TRIGGER"
165    pub trigger_span: Span,
166    /// Span of "IF NOT EXISTS" if specified
167    pub if_not_exists: Option<Span>,
168    /// Name of the created trigger
169    pub name: Identifier<'a>,
170    /// Should the trigger be fired before or after the event
171    pub trigger_time: TriggerTime,
172    /// What events should the trigger be fired on (multiple events joined by OR)
173    pub trigger_events: Vec<TriggerEvent>,
174    /// Span of "ON"
175    pub on_span: Span,
176    /// Name of table to create the trigger on
177    pub table: Identifier<'a>,
178    /// Whether the trigger fires once per row or once per statement (None if omitted, PostgreSQL only)
179    pub for_each: Option<TriggerForEach>,
180    /// Optional REFERENCING NEW TABLE AS alias / OLD TABLE AS alias clauses
181    pub referencing: Vec<TriggerReference<'a>>,
182    /// Optional WHEN (condition)
183    pub when_condition: Option<(Span, Expression<'a>)>,
184    /// Statement to execute
185    pub statement: Statement<'a>,
186}
187
188impl<'a> Spanned for CreateTrigger<'a> {
189    fn span(&self) -> Span {
190        self.create_span
191            .join_span(&self.create_options)
192            .join_span(&self.trigger_span)
193            .join_span(&self.if_not_exists)
194            .join_span(&self.name)
195            .join_span(&self.trigger_time)
196            .join_span(&self.trigger_events)
197            .join_span(&self.on_span)
198            .join_span(&self.table)
199            .join_span(&self.for_each)
200            .join_span(&self.referencing)
201            .join_span(&self.when_condition.as_ref().map(|(s, e)| s.join_span(e)))
202            .join_span(&self.statement)
203    }
204}
205
206pub(crate) fn parse_create_trigger<'a>(
207    parser: &mut Parser<'a, '_>,
208    create_span: Span,
209    create_options: Vec<CreateOption<'a>>,
210) -> Result<CreateTrigger<'a>, ParseError> {
211    let trigger_span = parser.consume_keyword(Keyword::TRIGGER)?;
212
213    let if_not_exists = if let Some(if_) = parser.skip_keyword(Keyword::IF) {
214        Some(
215            parser
216                .consume_keywords(&[Keyword::NOT, Keyword::EXISTS])?
217                .join_span(&if_),
218        )
219    } else {
220        None
221    };
222
223    let name = parser.consume_plain_identifier_unreserved()?;
224
225    let trigger_time = match &parser.token {
226        Token::Ident(_, Keyword::AFTER) => {
227            TriggerTime::After(parser.consume_keyword(Keyword::AFTER)?)
228        }
229        Token::Ident(_, Keyword::BEFORE) => {
230            TriggerTime::Before(parser.consume_keyword(Keyword::BEFORE)?)
231        }
232        Token::Ident(_, Keyword::INSTEAD) => {
233            TriggerTime::InsteadOf(parser.consume_keywords(&[Keyword::INSTEAD, Keyword::OF])?)
234        }
235        _ => parser.expected_failure("'BEFORE', 'AFTER', or 'INSTEAD OF'")?,
236    };
237
238    let mut trigger_events = Vec::new();
239    loop {
240        let event = match &parser.token {
241            Token::Ident(_, Keyword::UPDATE) => {
242                TriggerEvent::Update(parser.consume_keyword(Keyword::UPDATE)?)
243            }
244            Token::Ident(_, Keyword::INSERT) => {
245                TriggerEvent::Insert(parser.consume_keyword(Keyword::INSERT)?)
246            }
247            Token::Ident(_, Keyword::DELETE) => {
248                TriggerEvent::Delete(parser.consume_keyword(Keyword::DELETE)?)
249            }
250            Token::Ident(_, Keyword::TRUNCATE) => {
251                TriggerEvent::Truncate(parser.consume_keyword(Keyword::TRUNCATE)?)
252            }
253            _ => parser.expected_failure("'UPDATE', 'INSERT', 'DELETE', or 'TRUNCATE'")?,
254        };
255        trigger_events.push(event);
256        if parser.skip_keyword(Keyword::OR).is_none() {
257            break;
258        }
259    }
260
261    let on_span = parser.consume_keyword(Keyword::ON)?;
262
263    let table = parser.consume_plain_identifier_unreserved()?;
264
265    let for_each = if parser.options.dialect.is_postgresql() {
266        if let Some(for_span) = parser.skip_keyword(Keyword::FOR) {
267            let each_span = parser.skip_keyword(Keyword::EACH);
268            match &parser.token {
269                Token::Ident(_, Keyword::ROW) => Some(TriggerForEach::Row(
270                    for_span
271                        .join_span(&each_span)
272                        .join_span(&parser.consume_keyword(Keyword::ROW)?),
273                )),
274                Token::Ident(_, Keyword::STATEMENT) => Some(TriggerForEach::Statement(
275                    for_span
276                        .join_span(&each_span)
277                        .join_span(&parser.consume_keyword(Keyword::STATEMENT)?),
278                )),
279                _ => Some(TriggerForEach::Row(for_span.join_span(&each_span))),
280            }
281        } else {
282            None
283        }
284    } else {
285        Some(TriggerForEach::Row(parser.consume_keywords(&[
286            Keyword::FOR,
287            Keyword::EACH,
288            Keyword::ROW,
289        ])?))
290    };
291
292    // Parse optional REFERENCING clause (PostgreSQL transition table aliases)
293    let mut referencing = Vec::new();
294    if parser.skip_keyword(Keyword::REFERENCING).is_some() {
295        // Each REFERENCING item: { NEW | OLD } TABLE AS alias
296        loop {
297            let direction = match &parser.token {
298                Token::Ident(_, Keyword::NEW) => {
299                    TriggerReferenceDirection::New(parser.consume_keyword(Keyword::NEW)?)
300                }
301                Token::Ident(_, Keyword::OLD) => {
302                    TriggerReferenceDirection::Old(parser.consume_keyword(Keyword::OLD)?)
303                }
304                _ => break,
305            };
306            let table_as_span = parser.consume_keywords(&[Keyword::TABLE, Keyword::AS])?;
307            let alias = parser.consume_plain_identifier_unreserved()?;
308            referencing.push(TriggerReference {
309                direction,
310                table_as_span,
311                alias,
312            });
313        }
314    }
315
316    // Parse optional WHEN (condition)
317    let when_condition = if let Some(when_span) = parser.skip_keyword(Keyword::WHEN) {
318        parser.consume_token(Token::LParen)?;
319        let expr = parser.recovered(")", &|t| t == &Token::RParen, |parser| {
320            Ok(Some(parse_expression_unreserved(parser, PRIORITY_MAX)?))
321        })?;
322        parser.consume_token(Token::RParen)?;
323        expr.map(|e| (when_span, e))
324    } else {
325        None
326    };
327
328    // TODO [{ FOLLOWS | PRECEDES } other_trigger_name ]
329
330    // PostgreSQL allows EXECUTE FUNCTION func_name(...) instead of a statement block
331    let statement = if matches!(parser.token, Token::Ident(_, Keyword::EXECUTE)) {
332        let execute_span = parser.consume_keyword(Keyword::EXECUTE)?;
333        // Accept both FUNCTION and PROCEDURE (synonyms in this context)
334        let execute_span = if let Some(s) = parser.skip_keyword(Keyword::FUNCTION) {
335            execute_span.join_span(&s)
336        } else {
337            execute_span.join_span(&parser.consume_keyword(Keyword::PROCEDURE)?)
338        };
339        let func_name = parse_qualified_name_unreserved(parser)?;
340        parser.consume_token(Token::LParen)?;
341        let mut args = Vec::new();
342        parser.recovered("')'", &|t| t == &Token::RParen, |parser| {
343            loop {
344                if matches!(parser.token, Token::RParen) {
345                    break;
346                }
347                args.push(parse_expression_unreserved(parser, PRIORITY_MAX)?);
348                if parser.skip_token(Token::Comma).is_none() {
349                    break;
350                }
351            }
352            Ok(())
353        })?;
354        parser.consume_token(Token::RParen)?;
355        Statement::ExecuteFunction(Box::new(ExecuteFunction {
356            execute_span,
357            func_name,
358            args,
359        }))
360    } else {
361        let old = core::mem::replace(&mut parser.permit_compound_statements, true);
362        let statement = match parse_statement(parser)? {
363            Some(v) => v,
364            None => parser.expected_failure("statement")?,
365        };
366        parser.permit_compound_statements = old;
367        statement
368    };
369
370    Ok(CreateTrigger {
371        create_span,
372        create_options,
373        trigger_span,
374        if_not_exists,
375        name,
376        trigger_time,
377        trigger_events,
378        on_span,
379        table,
380        for_each,
381        referencing,
382        when_condition,
383        statement,
384    })
385}