Skip to main content

qusql_parse/
create_trigger.rs

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
12use crate::{
13    Expression, Identifier, Span, Spanned, Statement,
14    create_option::CreateOption,
15    expression::{PRIORITY_MAX, parse_expression_unreserved},
16    keywords::Keyword,
17    lexer::Token,
18    parser::{ParseError, Parser},
19    statement::{Block, parse_statement},
20};
21use alloc::{boxed::Box, vec::Vec};
22
23/// When to fire the trigger
24#[derive(Clone, Debug)]
25pub enum TriggerTime {
26    Before(Span),
27    After(Span),
28    InsteadOf(Span),
29}
30
31impl Spanned for TriggerTime {
32    fn span(&self) -> Span {
33        match &self {
34            TriggerTime::Before(v) => v.span(),
35            TriggerTime::After(v) => v.span(),
36            TriggerTime::InsteadOf(v) => v.span(),
37        }
38    }
39}
40
41/// On what event to fire the trigger
42#[derive(Clone, Debug)]
43pub enum TriggerEvent {
44    Update(Span),
45    Insert(Span),
46    Delete(Span),
47}
48
49impl Spanned for TriggerEvent {
50    fn span(&self) -> Span {
51        match &self {
52            TriggerEvent::Update(v) => v.span(),
53            TriggerEvent::Insert(v) => v.span(),
54            TriggerEvent::Delete(v) => v.span(),
55        }
56    }
57}
58
59#[derive(Clone, Debug)]
60pub enum TriggerReferenceDirection {
61    New(Span),
62    Old(Span),
63}
64
65impl Spanned for TriggerReferenceDirection {
66    fn span(&self) -> Span {
67        match &self {
68            TriggerReferenceDirection::New(v) => v.span(),
69            TriggerReferenceDirection::Old(v) => v.span(),
70        }
71    }
72}
73
74#[derive(Clone, Debug)]
75pub struct TriggerReference<'a> {
76    direction: TriggerReferenceDirection,
77    table_as_span: Span,
78    alias: Identifier<'a>,
79}
80
81impl Spanned for TriggerReference<'_> {
82    fn span(&self) -> Span {
83        self.direction
84            .join_span(&self.table_as_span)
85            .join_span(&self.alias)
86    }
87}
88
89/// Represent a create trigger statement
90/// ```
91/// # use qusql_parse::{SQLDialect, SQLArguments, ParseOptions, parse_statements, CreateTrigger, Statement, Issues};
92/// # let options = ParseOptions::new().dialect(SQLDialect::MariaDB);
93/// #
94/// let sql = "DROP TRIGGER IF EXISTS `my_trigger`;
95/// DELIMITER $$
96/// CREATE TRIGGER `my_trigger` AFTER DELETE ON `things` FOR EACH ROW BEGIN
97///     IF OLD.`value` IS NOT NULL THEN
98///         UPDATE `t2` AS `j`
99///             SET
100///             `j`.`total_items` = `total_items` - 1
101///             WHERE `j`.`id`=OLD.`value` AND NOT `j`.`frozen`;
102///         END IF;
103///     INSERT INTO `updated_things` (`thing`) VALUES (OLD.`id`);
104/// END
105/// $$
106/// DELIMITER ;";
107/// let mut issues = Issues::new(sql);
108/// let mut stmts = parse_statements(sql, &mut issues, &options);
109///
110/// # assert_eq!(issues.get(), &[]);
111/// #
112/// let create: CreateTrigger = match stmts.pop() {
113///     Some(Statement::CreateTrigger(c)) => *c,
114///     _ => panic!("We should get an create trigger statement")
115/// };
116///
117/// assert!(create.name.as_str() == "my_trigger");
118/// println!("{:#?}", create.statement)
119/// ```
120#[derive(Clone, Debug)]
121pub struct CreateTrigger<'a> {
122    /// Span of "CREATE"
123    pub create_span: Span,
124    /// Options after "CREATE"
125    pub create_options: Vec<CreateOption<'a>>,
126    /// Span of "TRIGGER"
127    pub trigger_span: Span,
128    /// Span of "IF NOT EXISTS" if specified
129    pub if_not_exists: Option<Span>,
130    /// Name of the created trigger
131    pub name: Identifier<'a>,
132    /// Should the trigger be fired before or after the event
133    pub trigger_time: TriggerTime,
134    /// What event should the trigger be fired on
135    pub trigger_event: TriggerEvent,
136    /// Span of "ON"
137    pub on_span: Span,
138    /// Name of table to create the trigger on
139    pub table: Identifier<'a>,
140    /// Span of "FOR EACH ROW"
141    pub for_each_row_span: Span,
142    /// Optional REFERENCING NEW TABLE AS alias / OLD TABLE AS alias clauses
143    pub referencing: Vec<TriggerReference<'a>>,
144    /// Optional WHEN (condition)
145    pub when_condition: Option<(Span, Expression<'a>)>,
146    /// Statement to execute
147    pub statement: Statement<'a>,
148}
149
150impl<'a> Spanned for CreateTrigger<'a> {
151    fn span(&self) -> Span {
152        self.create_span
153            .join_span(&self.create_options)
154            .join_span(&self.trigger_span)
155            .join_span(&self.if_not_exists)
156            .join_span(&self.name)
157            .join_span(&self.trigger_time)
158            .join_span(&self.trigger_event)
159            .join_span(&self.on_span)
160            .join_span(&self.table)
161            .join_span(&self.for_each_row_span)
162            .join_span(&self.referencing)
163            .join_span(&self.when_condition.as_ref().map(|(s, e)| s.join_span(e)))
164            .join_span(&self.statement)
165    }
166}
167
168pub(crate) fn parse_create_trigger<'a>(
169    parser: &mut Parser<'a, '_>,
170    create_span: Span,
171    create_options: Vec<CreateOption<'a>>,
172) -> Result<CreateTrigger<'a>, ParseError> {
173    let trigger_span = parser.consume_keyword(Keyword::TRIGGER)?;
174
175    let if_not_exists = if let Some(if_) = parser.skip_keyword(Keyword::IF) {
176        Some(
177            parser
178                .consume_keywords(&[Keyword::NOT, Keyword::EXISTS])?
179                .join_span(&if_),
180        )
181    } else {
182        None
183    };
184
185    let name = parser.consume_plain_identifier_unreserved()?;
186
187    let trigger_time = match &parser.token {
188        Token::Ident(_, Keyword::AFTER) => {
189            TriggerTime::After(parser.consume_keyword(Keyword::AFTER)?)
190        }
191        Token::Ident(_, Keyword::BEFORE) => {
192            TriggerTime::Before(parser.consume_keyword(Keyword::BEFORE)?)
193        }
194        Token::Ident(_, Keyword::INSTEAD) => {
195            TriggerTime::InsteadOf(parser.consume_keywords(&[Keyword::INSTEAD, Keyword::OF])?)
196        }
197        _ => parser.expected_failure("'BEFORE', 'AFTER', or 'INSTEAD OF'")?,
198    };
199
200    let trigger_event = match &parser.token {
201        Token::Ident(_, Keyword::UPDATE) => {
202            TriggerEvent::Update(parser.consume_keyword(Keyword::UPDATE)?)
203        }
204        Token::Ident(_, Keyword::INSERT) => {
205            TriggerEvent::Insert(parser.consume_keyword(Keyword::INSERT)?)
206        }
207        Token::Ident(_, Keyword::DELETE) => {
208            TriggerEvent::Delete(parser.consume_keyword(Keyword::DELETE)?)
209        }
210        _ => parser.expected_failure("'UPDATE' or 'INSERT' or 'DELETE'")?,
211    };
212
213    let on_span = parser.consume_keyword(Keyword::ON)?;
214
215    let table = parser.consume_plain_identifier_unreserved()?;
216
217    let for_each_row_span =
218        parser.consume_keywords(&[Keyword::FOR, Keyword::EACH, Keyword::ROW])?;
219
220    // Parse optional REFERENCING clause (PostgreSQL transition table aliases)
221    let mut referencing = Vec::new();
222    if parser.skip_keyword(Keyword::REFERENCING).is_some() {
223        // Each REFERENCING item: { NEW | OLD } TABLE AS alias
224        loop {
225            let direction = match &parser.token {
226                Token::Ident(_, Keyword::NEW) => {
227                    TriggerReferenceDirection::New(parser.consume_keyword(Keyword::NEW)?)
228                }
229                Token::Ident(_, Keyword::OLD) => {
230                    TriggerReferenceDirection::Old(parser.consume_keyword(Keyword::OLD)?)
231                }
232                _ => break,
233            };
234            let table_as_span = parser.consume_keywords(&[Keyword::TABLE, Keyword::AS])?;
235            let alias = parser.consume_plain_identifier_unreserved()?;
236            referencing.push(TriggerReference {
237                direction,
238                table_as_span,
239                alias,
240            });
241        }
242    }
243
244    // Parse optional WHEN (condition)
245    let when_condition = if let Some(when_span) = parser.skip_keyword(Keyword::WHEN) {
246        parser.consume_token(Token::LParen)?;
247        let expr = parser.recovered(")", &|t| t == &Token::RParen, |parser| {
248            Ok(Some(parse_expression_unreserved(parser, PRIORITY_MAX)?))
249        })?;
250        parser.consume_token(Token::RParen)?;
251        expr.map(|e| (when_span, e))
252    } else {
253        None
254    };
255
256    // TODO [{ FOLLOWS | PRECEDES } other_trigger_name ]
257
258    // PostgreSQL allows EXECUTE FUNCTION func_name() instead of a statement block
259    let statement = if matches!(parser.token, Token::Ident(_, Keyword::EXECUTE)) {
260        // Parse EXECUTE FUNCTION func_name()
261        let _execute_span = parser.consume_keyword(Keyword::EXECUTE)?;
262        parser.consume_keyword(Keyword::FUNCTION)?;
263        parser.consume_plain_identifier_unreserved()?;
264        let begin_span = parser.consume_token(Token::LParen)?;
265        // TODO: parse function arguments if needed
266        let end_span = parser.consume_token(Token::RParen)?;
267
268        // Use an empty block as a placeholder for EXECUTE FUNCTION
269        Statement::Block(Box::new(Block {
270            begin_span,
271            end_span,
272            statements: Vec::new(),
273        }))
274    } else {
275        let old = core::mem::replace(&mut parser.permit_compound_statements, true);
276        let statement = match parse_statement(parser)? {
277            Some(v) => v,
278            None => parser.expected_failure("statement")?,
279        };
280        parser.permit_compound_statements = old;
281        statement
282    };
283
284    Ok(CreateTrigger {
285        create_span,
286        create_options,
287        trigger_span,
288        if_not_exists,
289        name,
290        trigger_time,
291        trigger_event,
292        on_span,
293        table,
294        for_each_row_span,
295        referencing,
296        when_condition,
297        statement,
298    })
299}