Skip to main content

qusql_parse/
create_constraint_trigger.rs

1// Licensed under the Apache License, Version 2.0
2// CREATE CONSTRAINT TRIGGER parser for PostgreSQL
3use crate::{
4    Identifier, Span, Spanned,
5    create_option::CreateOption,
6    expression::{Expression, PRIORITY_MAX, parse_expression_unreserved},
7    keywords::Keyword,
8    lexer::Token,
9    parser::{ParseError, Parser},
10};
11use alloc::vec::Vec;
12
13/// Enum for constraint trigger events
14#[derive(Clone, Debug)]
15pub enum AfterEvent {
16    Insert(Span),
17    Update(Span),
18    Delete(Span),
19}
20
21impl Spanned for AfterEvent {
22    fn span(&self) -> Span {
23        match self {
24            AfterEvent::Insert(s) => s.clone(),
25            AfterEvent::Update(s) => s.clone(),
26            AfterEvent::Delete(s) => s.clone(),
27        }
28    }
29}
30
31/// Whether the trigger is deferrable
32#[derive(Clone, Debug)]
33pub enum Deferrable {
34    /// DEFERRABLE
35    Deferrable(Span),
36    /// NOT DEFERRABLE
37    NotDeferrable(Span),
38}
39
40impl Spanned for Deferrable {
41    fn span(&self) -> Span {
42        match self {
43            Deferrable::Deferrable(s) => s.clone(),
44            Deferrable::NotDeferrable(s) => s.clone(),
45        }
46    }
47}
48
49/// Initial timing of the constraint trigger
50#[derive(Clone, Debug)]
51pub enum Initially {
52    /// INITIALLY IMMEDIATE
53    Immediate(Span),
54    /// INITIALLY DEFERRED
55    Deferred(Span),
56}
57
58impl Spanned for Initially {
59    fn span(&self) -> Span {
60        match self {
61            Initially::Immediate(s) => s.clone(),
62            Initially::Deferred(s) => s.clone(),
63        }
64    }
65}
66
67/// Represent a create constraint trigger statement
68#[derive(Clone, Debug)]
69pub struct CreateConstraintTrigger<'a> {
70    /// The span of the entire CREATE keyword
71    pub create_span: Span,
72    /// The span of the CONSTRAINT TRIGGER keywords
73    pub constraint_trigger_span: Span,
74    /// The name of the constraint trigger
75    pub name: Identifier<'a>,
76    /// The events that fire the trigger (AFTER INSERT, AFTER UPDATE, AFTER DELETE)
77    pub after_span: Span,
78    pub after_events: Vec<AfterEvent>,
79    /// The table the trigger is on
80    pub on_span: Span,
81    pub table_name: Identifier<'a>,
82    /// The referenced table for the trigger (optional, used for referencing foreign keys)
83    pub referenced_table_name: Option<Identifier<'a>>,
84    /// Whether the trigger is deferrable or not (optional, PostgreSQL specific)
85    pub deferrable: Option<Deferrable>,
86    /// The initial timing of the trigger (optional, PostgreSQL specific)
87    pub initially: Option<Initially>,
88    /// The span of the FOR EACH ROW keywords
89    pub for_each_row_span: Span,
90    /// The WHEN condition for the trigger (optional, PostgreSQL specific)
91    pub when_condition: Option<(Span, Expression<'a>)>,
92    /// The span of the EXECUTE PROCEDURE keywords
93    pub execute_procedure_span: Span,
94    /// The name of the function to execute when the trigger fires
95    pub function_name: Identifier<'a>,
96    /// The arguments to the function (optional)
97    pub function_args: Vec<Expression<'a>>,
98}
99
100impl<'a> Spanned for CreateConstraintTrigger<'a> {
101    fn span(&self) -> Span {
102        self.create_span
103            .join_span(&self.constraint_trigger_span)
104            .join_span(&self.name)
105            .join_span(&self.after_span)
106            .join_span(&self.after_events)
107            .join_span(&self.on_span)
108            .join_span(&self.table_name)
109            .join_span(&self.referenced_table_name)
110            .join_span(&self.deferrable)
111            .join_span(&self.initially)
112            .join_span(&self.for_each_row_span)
113            .join_span(&self.when_condition)
114            .join_span(&self.execute_procedure_span)
115            .join_span(&self.function_name)
116            .join_span(&self.function_args)
117    }
118}
119
120pub(crate) fn parse_create_constraint_trigger<'a>(
121    parser: &mut Parser<'a, '_>,
122    create_span: Span,
123    create_options: Vec<CreateOption<'a>>,
124) -> Result<CreateConstraintTrigger<'a>, ParseError> {
125    let constraint_span = parser.consume_keywords(&[Keyword::CONSTRAINT, Keyword::TRIGGER])?;
126    parser.postgres_only(&constraint_span);
127
128    for option in create_options {
129        parser.err(
130            "Not supported for CREATE CONSTRAINT TRIGGER",
131            &option.span(),
132        );
133    }
134    let name = parser.consume_plain_identifier_unreserved()?;
135
136    // Parse AFTER event(s)
137    let mut after_events = Vec::new();
138    let after_span = parser.consume_keyword(Keyword::AFTER)?;
139    loop {
140        match &parser.token {
141            Token::Ident(_, Keyword::INSERT) => {
142                after_events.push(AfterEvent::Insert(parser.consume_keyword(Keyword::INSERT)?))
143            }
144            Token::Ident(_, Keyword::UPDATE) => {
145                after_events.push(AfterEvent::Update(parser.consume_keyword(Keyword::UPDATE)?))
146            }
147            Token::Ident(_, Keyword::DELETE) => {
148                after_events.push(AfterEvent::Delete(parser.consume_keyword(Keyword::DELETE)?))
149            }
150            Token::Ident(_, Keyword::OR) => {
151                parser.consume_keyword(Keyword::OR)?;
152            }
153            _ => break,
154        }
155    }
156
157    let on_span = parser.consume_keyword(Keyword::ON)?;
158    let table_name = parser.consume_plain_identifier_unreserved()?;
159
160    let referenced_table_name = if parser.skip_keyword(Keyword::FROM).is_some() {
161        Some(parser.consume_plain_identifier_unreserved()?)
162    } else {
163        None
164    };
165
166    let deferrable = if let Some(span) = parser.skip_keyword(Keyword::DEFERRABLE) {
167        Some(Deferrable::Deferrable(span))
168    } else if let Some(not_span) = parser.skip_keyword(Keyword::NOT) {
169        let deferrable_span = parser.consume_keyword(Keyword::DEFERRABLE)?;
170        Some(Deferrable::NotDeferrable(
171            not_span.join_span(&deferrable_span),
172        ))
173    } else {
174        None
175    };
176
177    #[allow(clippy::manual_map)]
178    let initially = if let Some(initially_span) = parser.skip_keyword(Keyword::INITIALLY) {
179        if let Some(s) = parser.skip_keyword(Keyword::IMMEDIATE) {
180            Some(Initially::Immediate(initially_span.join_span(&s)))
181        } else if let Some(s) = parser.skip_keyword(Keyword::DEFERRED) {
182            Some(Initially::Deferred(initially_span.join_span(&s)))
183        } else {
184            None
185        }
186    } else {
187        None
188    };
189
190    let for_each_row_span =
191        parser.consume_keywords(&[Keyword::FOR, Keyword::EACH, Keyword::ROW])?;
192
193    let when_condition = if let Some(when_span) = parser.skip_keyword(Keyword::WHEN) {
194        parser.consume_token(Token::LParen)?;
195        let cond = parse_expression_unreserved(parser, PRIORITY_MAX)?;
196        parser.consume_token(Token::RParen)?;
197        Some((when_span, cond))
198    } else {
199        None
200    };
201
202    parser.consume_keyword(Keyword::EXECUTE)?;
203    let execute_procedure_span = parser.consume_keyword(Keyword::PROCEDURE)?;
204    let function_name = parser.consume_plain_identifier_unreserved()?;
205    let mut function_args = Vec::new();
206    if parser.skip_token(Token::LParen).is_some() {
207        // Parse arguments as expressions
208        loop {
209            function_args.push(parse_expression_unreserved(parser, PRIORITY_MAX)?);
210            if parser.skip_token(Token::Comma).is_none() {
211                break;
212            }
213        }
214        parser.consume_token(Token::RParen)?;
215    }
216
217    Ok(CreateConstraintTrigger {
218        create_span,
219        constraint_trigger_span: constraint_span,
220        name,
221        after_span,
222        after_events,
223        on_span,
224        table_name,
225        referenced_table_name,
226        deferrable,
227        initially,
228        for_each_row_span,
229        when_condition,
230        execute_procedure_span,
231        function_name,
232        function_args,
233    })
234}