Skip to main content

dbx_core/automation/
function_parser.rs

1//! Function Parser
2//!
3//! SQL CREATE FUNCTION, DROP FUNCTION 파싱
4
5use crate::automation::UdfMetadata;
6use crate::error::{DbxError, DbxResult};
7use crate::sql::StringCaseExt;
8
9/// CREATE FUNCTION 파싱
10///
11/// # 예제
12/// ```sql
13/// CREATE FUNCTION add_numbers (a INT, b INT) RETURNS INT
14/// BEGIN
15///     RETURN a + b;
16/// END;
17/// ```
18pub fn parse_create_function(sql: &str) -> DbxResult<UdfMetadata> {
19    let sql = sql.trim();
20
21    // CREATE FUNCTION 확인
22    if !sql.starts_with_ignore_ascii_case("CREATE FUNCTION") {
23        return Err(DbxError::InvalidOperation {
24            message: format!("Expected CREATE FUNCTION, got: {}", sql),
25            context: "SQL parsing".to_string(),
26        });
27    }
28
29    // 함수 이름 추출
30    let tokens: Vec<&str> = sql.split_whitespace().collect();
31    if tokens.len() < 3 {
32        return Err(DbxError::InvalidOperation {
33            message: "Invalid CREATE FUNCTION syntax".to_string(),
34            context: "SQL parsing".to_string(),
35        });
36    }
37
38    let name = tokens[2].trim_end_matches('(').to_string();
39
40    // 파라미터 추출
41    let params_start = sql.find('(').ok_or_else(|| DbxError::InvalidOperation {
42        message: "Missing opening parenthesis".to_string(),
43        context: "CREATE FUNCTION parsing".to_string(),
44    })?;
45
46    let params_end = sql.find(')').ok_or_else(|| DbxError::InvalidOperation {
47        message: "Missing closing parenthesis".to_string(),
48        context: "CREATE FUNCTION parsing".to_string(),
49    })?;
50
51    let params_str = &sql[params_start + 1..params_end];
52    let mut param_types = Vec::new();
53
54    if !params_str.trim().is_empty() {
55        for param in params_str.split(',') {
56            let parts: Vec<&str> = param.split_whitespace().collect();
57            if parts.len() >= 2 {
58                param_types.push(parts[1].to_string());
59            }
60        }
61    }
62
63    // RETURNS 타입 추출
64    let returns_idx =
65        sql.to_uppercase()
66            .find("RETURNS")
67            .ok_or_else(|| DbxError::InvalidOperation {
68                message: "Missing RETURNS clause".to_string(),
69                context: "CREATE FUNCTION parsing".to_string(),
70            })?;
71
72    let after_returns = &sql[returns_idx + 7..].trim();
73    let return_type = after_returns
74        .split_whitespace()
75        .next()
76        .ok_or_else(|| DbxError::InvalidOperation {
77            message: "Missing return type".to_string(),
78            context: "CREATE FUNCTION parsing".to_string(),
79        })?
80        .to_string();
81
82    // UdfMetadata 생성
83    Ok(UdfMetadata::new(
84        &name,
85        crate::automation::UdfType::Scalar, // 기본값으로 Scalar 사용
86        param_types,
87        return_type,
88        false,
89    ))
90}
91
92/// DROP FUNCTION 파싱
93///
94/// # 예제
95/// ```sql
96/// DROP FUNCTION add_numbers;
97/// ```
98pub fn parse_drop_function(sql: &str) -> DbxResult<String> {
99    let sql = sql.trim();
100
101    // DROP FUNCTION 확인
102    if !sql.starts_with_ignore_ascii_case("DROP FUNCTION") {
103        return Err(DbxError::InvalidOperation {
104            message: format!("Expected DROP FUNCTION, got: {}", sql),
105            context: "SQL parsing".to_string(),
106        });
107    }
108
109    // 함수 이름 추출
110    let tokens: Vec<&str> = sql.split_whitespace().collect();
111    if tokens.len() < 3 {
112        return Err(DbxError::InvalidOperation {
113            message: "Invalid DROP FUNCTION syntax".to_string(),
114            context: "SQL parsing".to_string(),
115        });
116    }
117
118    let name = tokens[2].trim_end_matches(';').to_string();
119    Ok(name)
120}
121
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `CREATE FUNCTION` statement from a header line and a one-line body
    /// wrapped in `BEGIN ... END;`, padded with surrounding whitespace the parser must trim.
    fn with_body(header: &str, body: &str) -> String {
        format!(
            "\n    {}\n    BEGIN\n        {}\n    END;\n",
            header, body
        )
    }

    #[test]
    fn test_parse_create_function() {
        let sql = with_body(
            "CREATE FUNCTION add_numbers (a INT, b INT) RETURNS INT",
            "RETURN a + b;",
        );
        let parsed = parse_create_function(&sql).unwrap();

        assert_eq!(parsed.name, "add_numbers");
        assert_eq!(parsed.param_types.len(), 2);
        assert_eq!(parsed.return_type, "INT");
    }

    #[test]
    fn test_parse_create_function_no_params() {
        let sql = with_body(
            "CREATE FUNCTION get_version () RETURNS VARCHAR",
            "RETURN '1.0.0';",
        );
        let parsed = parse_create_function(&sql).unwrap();

        assert_eq!(parsed.name, "get_version");
        assert_eq!(parsed.param_types.len(), 0);
        assert_eq!(parsed.return_type, "VARCHAR");
    }

    #[test]
    fn test_parse_drop_function() {
        let dropped = parse_drop_function("DROP FUNCTION add_numbers;").unwrap();
        assert_eq!(dropped, "add_numbers");
    }

    #[test]
    fn test_parse_create_function_multiple_params() {
        let sql = with_body(
            "CREATE FUNCTION calculate (x DECIMAL, y DECIMAL, z INT) RETURNS DECIMAL",
            "RETURN x + y * z;",
        );
        let parsed = parse_create_function(&sql).unwrap();
        let expected_types = ["DECIMAL", "DECIMAL", "INT"];

        assert_eq!(parsed.name, "calculate");
        assert_eq!(parsed.param_types.len(), expected_types.len());
        for (idx, expected) in expected_types.iter().enumerate() {
            assert_eq!(parsed.param_types[idx], *expected);
        }
        assert_eq!(parsed.return_type, "DECIMAL");
    }
}