Skip to main content

dbx_core/automation/
trigger_parser.rs

//! SQL trigger parser.
//!
//! Parses `CREATE TRIGGER` and `DROP TRIGGER` SQL statements.
4
5use crate::automation::{ForEachType, Trigger, TriggerOperation, TriggerTiming};
6use crate::error::{DbxError, DbxResult};
7
8/// CREATE TRIGGER 파싱
9///
10/// # 문법
11/// ```sql
12/// CREATE TRIGGER trigger_name
13/// { BEFORE | AFTER } { INSERT | UPDATE | DELETE }
14/// ON table_name
15/// [ FOR EACH { ROW | STATEMENT } ]
16/// [ WHEN ( condition ) ]
17/// BEGIN
18///   sql_statement;
19///   ...
20/// END;
21/// ```
22pub fn parse_create_trigger(sql: &str) -> DbxResult<Trigger> {
23    // 간단한 수동 파싱 (sqlparser는 CREATE TRIGGER를 완전히 지원하지 않음)
24    let sql = sql.trim();
25
26    // CREATE TRIGGER 확인
27    if !sql.to_uppercase().starts_with("CREATE TRIGGER") {
28        return Err(DbxError::InvalidOperation {
29            message: format!("Expected CREATE TRIGGER, got: {}", sql),
30            context: "SQL parsing".to_string(),
31        });
32    }
33
34    // 토큰 분리
35    let tokens: Vec<&str> = sql.split_whitespace().collect();
36    if tokens.len() < 7 {
37        return Err(DbxError::InvalidOperation {
38            message: "Invalid CREATE TRIGGER syntax".to_string(),
39            context: "SQL parsing".to_string(),
40        });
41    }
42
43    // Trigger 이름 추출 (CREATE TRIGGER name)
44    let name = tokens[2].to_string();
45
46    // Timing 추출 (BEFORE/AFTER)
47    let timing = match tokens[3].to_uppercase().as_str() {
48        "BEFORE" => TriggerTiming::Before,
49        "AFTER" => TriggerTiming::After,
50        _ => {
51            return Err(DbxError::InvalidOperation {
52                message: format!("Invalid timing: {}", tokens[3]),
53                context: "CREATE TRIGGER parsing".to_string(),
54            });
55        }
56    };
57
58    // Operation 추출 (INSERT/UPDATE/DELETE)
59    let operation = match tokens[4].to_uppercase().as_str() {
60        "INSERT" => TriggerOperation::Insert,
61        "UPDATE" => TriggerOperation::Update,
62        "DELETE" => TriggerOperation::Delete,
63        _ => {
64            return Err(DbxError::InvalidOperation {
65                message: format!("Invalid operation: {}", tokens[4]),
66                context: "CREATE TRIGGER parsing".to_string(),
67            });
68        }
69    };
70
71    // ON 확인
72    if tokens[5].to_uppercase() != "ON" {
73        return Err(DbxError::InvalidOperation {
74            message: format!("Expected ON, got: {}", tokens[5]),
75            context: "CREATE TRIGGER parsing".to_string(),
76        });
77    }
78
79    // 테이블 이름 추출
80    let table = tokens[6].to_string();
81
82    // FOR EACH 추출 (선택적)
83    let mut for_each = ForEachType::Row; // 기본값
84    let mut body_start_idx = 7;
85
86    if tokens.len() > 9 && tokens[7].to_uppercase() == "FOR" && tokens[8].to_uppercase() == "EACH" {
87        for_each = match tokens[9].to_uppercase().as_str() {
88            "ROW" => ForEachType::Row,
89            "STATEMENT" => ForEachType::Statement,
90            _ => {
91                return Err(DbxError::InvalidOperation {
92                    message: format!("Invalid FOR EACH type: {}", tokens[9]),
93                    context: "CREATE TRIGGER parsing".to_string(),
94                });
95            }
96        };
97        body_start_idx = 10;
98    }
99
100    // WHEN 조건 추출 (선택적)
101    let mut condition = None;
102    if tokens.len() > body_start_idx && tokens[body_start_idx].to_uppercase() == "WHEN" {
103        // WHEN (condition) 형태 파싱
104        let when_start = sql.find("WHEN").unwrap();
105        let begin_pos = sql
106            .find("BEGIN")
107            .ok_or_else(|| DbxError::InvalidOperation {
108                message: "Missing BEGIN".to_string(),
109                context: "CREATE TRIGGER parsing".to_string(),
110            })?;
111        let condition_str = sql[when_start + 4..begin_pos].trim();
112
113        // 괄호 제거
114        let condition_str = condition_str
115            .trim_start_matches('(')
116            .trim_end_matches(')')
117            .trim();
118        condition = Some(condition_str.to_string());
119    }
120
121    // BEGIN ... END 사이의 SQL 문장들 추출
122    let begin_pos = sql
123        .find("BEGIN")
124        .ok_or_else(|| DbxError::InvalidOperation {
125            message: "Missing BEGIN".to_string(),
126            context: "CREATE TRIGGER parsing".to_string(),
127        })?;
128    let end_pos = sql.rfind("END").ok_or_else(|| DbxError::InvalidOperation {
129        message: "Missing END".to_string(),
130        context: "CREATE TRIGGER parsing".to_string(),
131    })?;
132
133    let body_sql = sql[begin_pos + 5..end_pos].trim();
134
135    // 세미콜론으로 분리
136    let body: Vec<String> = body_sql
137        .split(';')
138        .map(|s| s.trim().to_string())
139        .filter(|s| !s.is_empty())
140        .collect();
141
142    if body.is_empty() {
143        return Err(DbxError::InvalidOperation {
144            message: "Trigger body is empty".to_string(),
145            context: "CREATE TRIGGER parsing".to_string(),
146        });
147    }
148
149    Ok(Trigger::new(
150        name, timing, operation, table, for_each, condition, body,
151    ))
152}
153
154/// DROP TRIGGER 파싱
155///
156/// # 문법
157/// ```sql
158/// DROP TRIGGER trigger_name;
159/// ```
160pub fn parse_drop_trigger(sql: &str) -> DbxResult<String> {
161    let sql = sql.trim();
162
163    if !sql.to_uppercase().starts_with("DROP TRIGGER") {
164        return Err(DbxError::InvalidOperation {
165            message: format!("Expected DROP TRIGGER, got: {}", sql),
166            context: "SQL parsing".to_string(),
167        });
168    }
169
170    let tokens: Vec<&str> = sql.split_whitespace().collect();
171    if tokens.len() < 3 {
172        return Err(DbxError::InvalidOperation {
173            message: "Invalid DROP TRIGGER syntax".to_string(),
174            context: "SQL parsing".to_string(),
175        });
176    }
177
178    let name = tokens[2].trim_end_matches(';').to_string();
179    Ok(name)
180}
181
#[cfg(test)]
mod tests {
    //! Unit tests for `parse_create_trigger` and `parse_drop_trigger`.

    use super::*;

    /// A minimal trigger without FOR EACH / WHEN clauses picks up the defaults.
    #[test]
    fn test_parse_create_trigger_basic() {
        let input = r#"
            CREATE TRIGGER audit_trigger
            AFTER INSERT ON users
            BEGIN
                INSERT INTO audit_logs VALUES (NEW.id, 'INSERT');
            END;
        "#;
        let parsed = parse_create_trigger(input).unwrap();

        // Header fields.
        assert_eq!(parsed.name, "audit_trigger");
        assert_eq!(parsed.timing, TriggerTiming::After);
        assert_eq!(parsed.operation, TriggerOperation::Insert);
        assert_eq!(parsed.table, "users");

        // Defaults for the omitted optional clauses.
        assert_eq!(parsed.for_each, ForEachType::Row);
        assert!(parsed.condition.is_none());

        assert_eq!(parsed.body.len(), 1);
    }

    /// FOR EACH and WHEN are both recognised; the WHEN parentheses are removed.
    #[test]
    fn test_parse_create_trigger_with_condition() {
        let input = r#"
            CREATE TRIGGER check_age
            BEFORE INSERT ON users
            FOR EACH ROW
            WHEN (NEW.age > 18)
            BEGIN
                INSERT INTO adult_users VALUES (NEW.id);
            END;
        "#;
        let parsed = parse_create_trigger(input).unwrap();

        assert_eq!(parsed.name, "check_age");
        assert_eq!(parsed.timing, TriggerTiming::Before);
        assert_eq!(parsed.for_each, ForEachType::Row);

        assert!(parsed.condition.is_some());
        assert_eq!(parsed.condition.unwrap(), "NEW.age > 18");
    }

    /// Each `;`-terminated statement in the body becomes its own entry.
    #[test]
    fn test_parse_create_trigger_multiple_statements() {
        let input = r#"
            CREATE TRIGGER multi_action
            AFTER UPDATE ON products
            BEGIN
                UPDATE logs SET count = count + 1;
                INSERT INTO history VALUES (OLD.id, NEW.id);
            END;
        "#;

        assert_eq!(parse_create_trigger(input).unwrap().body.len(), 2);
    }

    /// The trailing semicolon is not part of the returned name.
    #[test]
    fn test_parse_drop_trigger() {
        let dropped = parse_drop_trigger("DROP TRIGGER audit_trigger;").unwrap();
        assert_eq!(dropped, "audit_trigger");
    }

    /// A non-trigger statement is rejected.
    #[test]
    fn test_parse_create_trigger_invalid() {
        let outcome = parse_create_trigger("CREATE TABLE users (id INT);");
        assert!(outcome.is_err());
    }
}