Skip to main content

rbatis_codegen/codegen/
parser_pysql.rs

1use crate::codegen::parser_html::parse_html;
2use crate::codegen::proc_macro::TokenStream;
3use crate::codegen::syntax_tree_pysql::{
4    bind_node::BindNode, break_node::BreakNode, choose_node::ChooseNode,
5    continue_node::ContinueNode, error::Error, foreach_node::ForEachNode, if_node::IfNode,
6    otherwise_node::OtherwiseNode, set_node::SetNode, sql_node::SqlNode, string_node::StringNode,
7    trim_node::TrimNode, when_node::WhenNode, where_node::WhereNode, DefaultName, Name, NodeType,
8};
9use crate::codegen::ParseArgs;
10use quote::ToTokens;
11use std::collections::HashMap;
12use syn::ItemFn;
13
/// A handwritten recursive descent algorithm for parsing PySQL.
///
/// Implementors turn py_sql source text into a list of syntax-tree nodes.
pub trait ParsePySql {
    /// Parses the py_sql text `arg` into its top-level nodes.
    ///
    /// # Errors
    /// Returns an [`Error`] when the text cannot be parsed.
    fn parse_pysql(arg: &str) -> Result<Vec<NodeType>, Error>;
}
18
19pub fn impl_fn_py(m: &ItemFn, args: &ParseArgs) -> TokenStream {
20    let fn_name = m.sig.ident.to_string();
21
22    let mut data = args
23        .sqls
24        .iter()
25        .map(|x| x.to_token_stream().to_string())
26        .collect::<String>();
27
28    if data.ne("\"\"") && data.starts_with('"') && data.ends_with('"') {
29        data = data[1..data.len() - 1].to_string();
30    }
31
32    data = data.replace("\\n", "\n");
33
34    let nodes = NodeType::parse_pysql(&data).expect("[rbatis-codegen] parse py_sql fail!");
35
36    let is_select = data.starts_with("select") || data.starts_with(" select");
37    let htmls =
38        crate::codegen::syntax_tree_pysql::to_html::to_html_mapper(&nodes, is_select, &fn_name);
39
40    parse_html(&htmls, &fn_name, &mut vec![]).into()
41}
42
43impl ParsePySql for NodeType {
44    fn parse_pysql(arg: &str) -> Result<Vec<NodeType>, Error> {
45        let line_space_map = Self::create_line_space_map(arg);
46        let mut main_node = Vec::new();
47        let mut space = -1;
48        let mut line = -1;
49        let mut skip = -1;
50
51        for x in arg.lines() {
52            line += 1;
53
54            if x.is_empty() || (skip != -1 && line <= skip) {
55                continue;
56            }
57
58            let count_index = *line_space_map
59                .get(&line)
60                .ok_or_else(|| Error::from(format!("line_space_map not have line:{}", line)))?;
61
62            if space == -1 {
63                space = count_index;
64            }
65
66            let (child_str, do_skip) =
67                Self::find_child_str(line, count_index, arg, &line_space_map);
68            if do_skip != -1 && do_skip >= skip {
69                skip = do_skip;
70            }
71
72            let parsed = if !child_str.is_empty() {
73                Self::parse_pysql(&child_str)?
74            } else {
75                vec![]
76            };
77
78            let current_space = *line_space_map
79                .get(&line)
80                .ok_or_else(|| Error::from(format!("line:{} not exist!", line)))?;
81
82            Self::parse(&mut main_node, x, current_space as usize, parsed)?;
83        }
84
85        Ok(main_node)
86    }
87}
88
89impl NodeType {
90    fn parse(
91        main_node: &mut Vec<NodeType>,
92        line: &str,
93        space: usize,
94        mut childs: Vec<NodeType>,
95    ) -> Result<(), Error> {
96        let mut trim_line = line.trim();
97
98        if trim_line.starts_with("//") {
99            return Ok(());
100        }
101
102        if trim_line.ends_with(':') {
103            trim_line = trim_line[..trim_line.len() - 1].trim();
104
105            if trim_line.contains(": ") {
106                let parts: Vec<&str> = trim_line.split(": ").collect();
107                if parts.len() > 1 {
108                    for index in (0..parts.len()).rev() {
109                        let item = parts[index];
110                        childs = vec![Self::parse_node(item, line, childs)?];
111
112                        if index == 0 {
113                            main_node.extend(childs);
114                            return Ok(());
115                        }
116                    }
117                }
118            }
119
120            let node = Self::parse_node(trim_line, line, childs)?;
121            main_node.push(node);
122        } else {
123            let data = if space <= 1 {
124                line.to_string()
125            } else {
126                line[(space - 1)..].to_string()
127            };
128
129            main_node.push(NodeType::NString(StringNode {
130                value: data.trim().to_string(),
131            }));
132            main_node.extend(childs);
133        }
134
135        Ok(())
136    }
137
138    fn count_space(arg: &str) -> i32 {
139        arg.chars().take_while(|&c| c == ' ').count() as i32
140    }
141
142    fn find_child_str(
143        line_index: i32,
144        space_index: i32,
145        arg: &str,
146        line_space_map: &HashMap<i32, i32>,
147    ) -> (String, i32) {
148        let mut result = String::new();
149        let mut skip_line = -1;
150        let mut current_line = -1;
151
152        for line in arg.lines() {
153            current_line += 1;
154
155            if current_line > line_index {
156                let cached_space = *line_space_map.get(&current_line).expect("line not exists");
157
158                if cached_space > space_index {
159                    result.push_str(line);
160                    result.push('\n');
161                    skip_line = current_line;
162                } else {
163                    break;
164                }
165            }
166        }
167
168        (result, skip_line)
169    }
170
171    fn create_line_space_map(arg: &str) -> HashMap<i32, i32> {
172        arg.lines()
173            .enumerate()
174            .map(|(i, line)| (i as i32, Self::count_space(line)))
175            .collect()
176    }
177
178    fn parse_node(
179        trim_express: &str,
180        source_str: &str,
181        childs: Vec<NodeType>,
182    ) -> Result<NodeType, Error> {
183        match trim_express {
184            s if s.starts_with(IfNode::name()) => Ok(NodeType::NIf(IfNode {
185                childs,
186                test: s.trim_start_matches("if ").to_string(),
187            })),
188
189            s if s.starts_with(ForEachNode::name()) => {
190                Self::parse_for_each_node(s, source_str, childs)
191            }
192
193            s if s.starts_with(TrimNode::name()) => {
194                Self::parse_trim_tag_node(s, source_str, childs)
195            }
196
197            s if s.starts_with(ChooseNode::name()) => Self::parse_choose_node(childs),
198
199            s if s.starts_with(OtherwiseNode::default_name())
200                || s.starts_with(OtherwiseNode::name()) =>
201            {
202                Ok(NodeType::NOtherwise(OtherwiseNode { childs }))
203            }
204
205            s if s.starts_with(WhenNode::name()) => Ok(NodeType::NWhen(WhenNode {
206                childs,
207                test: s[WhenNode::name().len()..].trim().to_string(),
208            })),
209
210            s if s.starts_with(BindNode::default_name()) || s.starts_with(BindNode::name()) => {
211                Self::parse_bind_node(s)
212            }
213
214            s if s.starts_with(SetNode::name()) => Self::parse_set_node(s, source_str, childs),
215
216            s if s.starts_with(WhereNode::name()) => Ok(NodeType::NWhere(WhereNode { childs })),
217
218            s if s.starts_with(ContinueNode::name()) => Ok(NodeType::NContinue(ContinueNode {})),
219
220            s if s.starts_with(BreakNode::name()) => Ok(NodeType::NBreak(BreakNode {})),
221
222            s if s.starts_with(SqlNode::name()) => Self::parse_sql_node(s, childs),
223
224            _ => Err(Error::from(
225                "[rbatis-codegen] unknown tag: ".to_string() + source_str,
226            )),
227        }
228    }
229
230    fn parse_for_each_node(
231        express: &str,
232        source_str: &str,
233        childs: Vec<NodeType>,
234    ) -> Result<NodeType, Error> {
235        const FOR_TAG: &str = "for";
236        const IN_TAG: &str = " in ";
237
238        if !express.starts_with(FOR_TAG) {
239            return Err(Error::from(
240                "[rbatis-codegen] parser express fail:".to_string() + source_str,
241            ));
242        }
243
244        if !express.contains(IN_TAG) {
245            return Err(Error::from(
246                "[rbatis-codegen] parser express fail:".to_string() + source_str,
247            ));
248        }
249
250        let in_index = express
251            .find(IN_TAG)
252            .ok_or_else(|| Error::from(format!("{} not have {}", express, IN_TAG)))?;
253
254        let col = express[in_index + IN_TAG.len()..].trim();
255        let mut item = express[FOR_TAG.len()..in_index].trim();
256        let mut index = "";
257
258        if item.contains(',') {
259            let splits: Vec<&str> = item.split(',').collect();
260            if splits.len() != 2 {
261                panic!("[rbatis-codegen_codegen] for node must be 'for key,item in col:'");
262            }
263            index = splits[0].trim();
264            item = splits[1].trim();
265        }
266
267        Ok(NodeType::NForEach(ForEachNode {
268            childs,
269            collection: col.to_string(),
270            index: index.to_string(),
271            item: item.to_string(),
272        }))
273    }
274
275    fn parse_trim_tag_node(
276        express: &str,
277        _source_str: &str,
278        childs: Vec<NodeType>,
279    ) -> Result<NodeType, Error> {
280        let trim_express = express.trim().trim_start_matches("trim ").trim();
281
282        if (trim_express.starts_with('\'') && trim_express.ends_with('\''))
283            || (trim_express.starts_with('`') && trim_express.ends_with('`'))
284        {
285            let trimmed = if trim_express.starts_with('`') {
286                trim_express.trim_matches('`')
287            } else {
288                trim_express.trim_matches('\'')
289            };
290
291            Ok(NodeType::NTrim(TrimNode {
292                childs,
293                start: trimmed.to_string(),
294                end: trimmed.to_string(),
295            }))
296        } else if trim_express.contains('=') || trim_express.contains(',') {
297            let mut prefix = "";
298            let mut suffix = "";
299
300            for expr in trim_express.split(',') {
301                let expr = expr.trim();
302                if expr.starts_with("start") {
303                    prefix = expr
304                        .trim_start_matches("start")
305                        .trim()
306                        .trim_start_matches('=')
307                        .trim()
308                        .trim_matches(|c| c == '\'' || c == '`');
309                } else if expr.starts_with("end") {
310                    suffix = expr
311                        .trim_start_matches("end")
312                        .trim()
313                        .trim_start_matches('=')
314                        .trim()
315                        .trim_matches(|c| c == '\'' || c == '`');
316                } else {
317                    return Err(Error::from(format!(
318                        "[rbatis-codegen] express trim node error, for example  trim 'value':  \
319                        trim start='value': trim start='value',end='value':   express = {}",
320                        trim_express
321                    )));
322                }
323            }
324
325            Ok(NodeType::NTrim(TrimNode {
326                childs,
327                start: prefix.to_string(),
328                end: suffix.to_string(),
329            }))
330        } else {
331            Err(Error::from(format!(
332                "[rbatis-codegen] express trim node error, for example  trim 'value':  \
333                trim start='value': trim start='value',end='value':   error express = {}",
334                trim_express
335            )))
336        }
337    }
338
339    fn parse_choose_node(childs: Vec<NodeType>) -> Result<NodeType, Error> {
340        let mut node = ChooseNode {
341            when_nodes: vec![],
342            otherwise_node: None,
343        };
344
345        for child in childs {
346            match child {
347                NodeType::NWhen(_) => node.when_nodes.push(child),
348                NodeType::NOtherwise(_) => node.otherwise_node = Some(Box::new(child)),
349                _ => return Err(Error::from(
350                    "[rbatis-codegen] parser node fail,choose node' child must be when and otherwise nodes!".to_string()
351                )),
352            }
353        }
354
355        Ok(NodeType::NChoose(node))
356    }
357
358    fn parse_bind_node(express: &str) -> Result<NodeType, Error> {
359        let expr = if express.starts_with(BindNode::default_name()) {
360            express[BindNode::default_name().len()..].trim()
361        } else {
362            express[BindNode::name().len()..].trim()
363        };
364
365        let parts: Vec<&str> = expr.split('=').collect();
366        if parts.len() != 2 {
367            return Err(Error::from(
368                "[rbatis-codegen] parser bind express fail:".to_string() + express,
369            ));
370        }
371
372        Ok(NodeType::NBind(BindNode {
373            name: parts[0].trim().to_string(),
374            value: parts[1].trim().to_string(),
375        }))
376    }
377
378    fn parse_sql_node(express: &str, childs: Vec<NodeType>) -> Result<NodeType, Error> {
379        let expr = express[SqlNode::name().len()..].trim();
380
381        if !expr.starts_with("id=") {
382            return Err(Error::from(
383                "[rbatis-codegen] parser sql express fail, need id param:".to_string() + express,
384            ));
385        }
386
387        let id_value = expr.trim_start_matches("id=").trim();
388
389        let id = if (id_value.starts_with('\'') && id_value.ends_with('\''))
390            || (id_value.starts_with('"') && id_value.ends_with('"'))
391        {
392            id_value[1..id_value.len() - 1].to_string()
393        } else {
394            return Err(Error::from(
395                "[rbatis-codegen] parser sql id value need quotes:".to_string() + express,
396            ));
397        };
398
399        Ok(NodeType::NSql(SqlNode { childs, id }))
400    }
401
402    fn strip_quotes_for_attr(s: &str) -> String {
403        let val = s.trim(); // Trim whitespace around the value first
404        if val.starts_with('\'') && val.ends_with('\'')
405            || (val.starts_with('"') && val.ends_with('"'))
406        {
407            if val.len() >= 2 {
408                return val[1..val.len() - 1].to_string();
409            }
410        }
411        val.to_string() // Return the trimmed string if no quotes or malformed quotes
412    }
413
414    fn parse_set_node(
415        express: &str,
416        source_str: &str,
417        childs: Vec<NodeType>,
418    ) -> Result<NodeType, Error> {
419        let actual_attrs_str = if express.starts_with(SetNode::name()) {
420            express[SetNode::name().len()..].trim()
421        } else {
422            // This case should ideally not happen if called correctly from the match arm
423            return Err(Error::from(format!(
424                "[rbatis-codegen] SetNode expression '{}' does not start with '{}'",
425                express,
426                SetNode::name()
427            )));
428        };
429        let mut collection_opt: Option<String> = None;
430        let mut skip_null_val = None; // Default
431        let mut skips_val: String = String::new(); // Default is now an empty String
432        for part_str in actual_attrs_str.split(',') {
433            let clean_part = part_str.trim();
434            if clean_part.is_empty() {
435                continue;
436            }
437
438            let kv: Vec<&str> = clean_part.splitn(2, '=').collect();
439            if kv.len() != 2 {
440                return Err(Error::from(format!(
441                    "[rbatis-codegen] Malformed attribute in set node near '{}' in '{}'",
442                    clean_part, source_str
443                )));
444            }
445
446            let key = kv[0].trim();
447            let value_str_raw = kv[1].trim();
448
449            match key {
450                "collection" => {
451                    collection_opt = Some(Self::strip_quotes_for_attr(value_str_raw));
452                }
453                "skip_null" => {
454                    let val_bool_str = Self::strip_quotes_for_attr(value_str_raw);
455                    if val_bool_str.is_empty() {
456                        skip_null_val = None;
457                    } else if val_bool_str.eq_ignore_ascii_case("true") {
458                        skip_null_val = Some(true);
459                    } else if val_bool_str.eq_ignore_ascii_case("false") {
460                        skip_null_val = Some(false);
461                    } else {
462                        return Err(Error::from(format!(
463                            "[rbatis-codegen] Invalid boolean value for skip_null: '{}' in '{}'",
464                            value_str_raw, source_str
465                        )));
466                    }
467                }
468                "skips" => {
469                    let inner_skips_str = Self::strip_quotes_for_attr(value_str_raw);
470                    skips_val = inner_skips_str;
471                }
472                _ => {
473                    return Err(Error::from(format!(
474                        "[rbatis-codegen] Unknown attribute '{}' for set node in '{}'",
475                        key, source_str
476                    )));
477                }
478            }
479        }
480        let collection_val = collection_opt.unwrap_or_default();
481        Ok(NodeType::NSet(SetNode {
482            childs,
483            collection: collection_val,
484            skip_null: skip_null_val,
485            skips: skips_val,
486        }))
487    }
488}