Skip to main content

dbx_core/automation/
function_parser.rs

//! Function Parser
//!
//! Parsing for SQL `CREATE FUNCTION` and `DROP FUNCTION` statements.
4
5use crate::automation::UdfMetadata;
6use crate::error::{DbxError, DbxResult};
7
8/// CREATE FUNCTION 파싱
9///
10/// # 예제
11/// ```sql
12/// CREATE FUNCTION add_numbers (a INT, b INT) RETURNS INT
13/// BEGIN
14///     RETURN a + b;
15/// END;
16/// ```
17pub fn parse_create_function(sql: &str) -> DbxResult<UdfMetadata> {
18    let sql = sql.trim();
19
20    // CREATE FUNCTION 확인
21    if !sql.to_uppercase().starts_with("CREATE FUNCTION") {
22        return Err(DbxError::InvalidOperation {
23            message: format!("Expected CREATE FUNCTION, got: {}", sql),
24            context: "SQL parsing".to_string(),
25        });
26    }
27
28    // 함수 이름 추출
29    let tokens: Vec<&str> = sql.split_whitespace().collect();
30    if tokens.len() < 3 {
31        return Err(DbxError::InvalidOperation {
32            message: "Invalid CREATE FUNCTION syntax".to_string(),
33            context: "SQL parsing".to_string(),
34        });
35    }
36
37    let name = tokens[2].trim_end_matches('(').to_string();
38
39    // 파라미터 추출
40    let params_start = sql.find('(').ok_or_else(|| DbxError::InvalidOperation {
41        message: "Missing opening parenthesis".to_string(),
42        context: "CREATE FUNCTION parsing".to_string(),
43    })?;
44
45    let params_end = sql.find(')').ok_or_else(|| DbxError::InvalidOperation {
46        message: "Missing closing parenthesis".to_string(),
47        context: "CREATE FUNCTION parsing".to_string(),
48    })?;
49
50    let params_str = &sql[params_start + 1..params_end];
51    let mut param_types = Vec::new();
52
53    if !params_str.trim().is_empty() {
54        for param in params_str.split(',') {
55            let parts: Vec<&str> = param.split_whitespace().collect();
56            if parts.len() >= 2 {
57                param_types.push(parts[1].to_string());
58            }
59        }
60    }
61
62    // RETURNS 타입 추출
63    let returns_idx =
64        sql.to_uppercase()
65            .find("RETURNS")
66            .ok_or_else(|| DbxError::InvalidOperation {
67                message: "Missing RETURNS clause".to_string(),
68                context: "CREATE FUNCTION parsing".to_string(),
69            })?;
70
71    let after_returns = &sql[returns_idx + 7..].trim();
72    let return_type = after_returns
73        .split_whitespace()
74        .next()
75        .ok_or_else(|| DbxError::InvalidOperation {
76            message: "Missing return type".to_string(),
77            context: "CREATE FUNCTION parsing".to_string(),
78        })?
79        .to_string();
80
81    // UdfMetadata 생성
82    Ok(UdfMetadata::new(
83        &name,
84        crate::automation::UdfType::Scalar, // 기본값으로 Scalar 사용
85        param_types,
86        return_type,
87        false,
88    ))
89}
90
91/// DROP FUNCTION 파싱
92///
93/// # 예제
94/// ```sql
95/// DROP FUNCTION add_numbers;
96/// ```
97pub fn parse_drop_function(sql: &str) -> DbxResult<String> {
98    let sql = sql.trim();
99
100    // DROP FUNCTION 확인
101    if !sql.to_uppercase().starts_with("DROP FUNCTION") {
102        return Err(DbxError::InvalidOperation {
103            message: format!("Expected DROP FUNCTION, got: {}", sql),
104            context: "SQL parsing".to_string(),
105        });
106    }
107
108    // 함수 이름 추출
109    let tokens: Vec<&str> = sql.split_whitespace().collect();
110    if tokens.len() < 3 {
111        return Err(DbxError::InvalidOperation {
112            message: "Invalid DROP FUNCTION syntax".to_string(),
113            context: "SQL parsing".to_string(),
114        });
115    }
116
117    let name = tokens[2].trim_end_matches(';').to_string();
118    Ok(name)
119}
120
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `sql` as a CREATE FUNCTION statement, failing the test on any parse error.
    fn parsed(sql: &str) -> UdfMetadata {
        parse_create_function(sql).unwrap()
    }

    /// Two INT parameters, an INT result and a BEGIN ... END body.
    #[test]
    fn test_parse_create_function() {
        let udf = parsed(
            r#"
            CREATE FUNCTION add_numbers (a INT, b INT) RETURNS INT
            BEGIN
                RETURN a + b;
            END;
        "#,
        );

        assert_eq!(udf.name, "add_numbers");
        assert_eq!(udf.param_types.len(), 2);
        assert_eq!(udf.return_type, "INT");
    }

    /// An empty parameter list yields no parameter types.
    #[test]
    fn test_parse_create_function_no_params() {
        let udf = parsed(
            r#"
            CREATE FUNCTION get_version () RETURNS VARCHAR
            BEGIN
                RETURN '1.0.0';
            END;
        "#,
        );

        assert_eq!(udf.name, "get_version");
        assert_eq!(udf.param_types.len(), 0);
        assert_eq!(udf.return_type, "VARCHAR");
    }

    /// The trailing semicolon is not part of the dropped function's name.
    #[test]
    fn test_parse_drop_function() {
        let dropped = parse_drop_function("DROP FUNCTION add_numbers;").unwrap();
        assert_eq!(dropped, "add_numbers");
    }

    /// Parameter types are reported in declaration order.
    #[test]
    fn test_parse_create_function_multiple_params() {
        let udf = parsed(
            r#"
            CREATE FUNCTION calculate (x DECIMAL, y DECIMAL, z INT) RETURNS DECIMAL
            BEGIN
                RETURN x + y * z;
            END;
        "#,
        );

        assert_eq!(udf.name, "calculate");
        assert_eq!(udf.param_types.len(), 3);
        for (position, expected) in ["DECIMAL", "DECIMAL", "INT"].iter().enumerate() {
            assert_eq!(udf.param_types[position], *expected);
        }
        assert_eq!(udf.return_type, "DECIMAL");
    }
}