rbatis_codegen/codegen/
parser_pysql.rs

1use crate::codegen::parser_html::parse_html;
2use crate::codegen::proc_macro::TokenStream;
3use crate::codegen::syntax_tree_pysql::{
4    bind_node::BindNode, break_node::BreakNode, choose_node::ChooseNode, continue_node::ContinueNode,
5    error::Error, foreach_node::ForEachNode, if_node::IfNode, otherwise_node::OtherwiseNode,
6    set_node::SetNode, sql_node::SqlNode, string_node::StringNode, trim_node::TrimNode,
7    when_node::WhenNode, where_node::WhereNode, DefaultName, Name, NodeType,
8};
9use crate::codegen::ParseArgs;
10use quote::ToTokens;
11use std::collections::HashMap;
12use syn::ItemFn;
13
/// A handwritten recursive descent algorithm for parsing PySQL.
///
/// Implementors convert py_sql source text into a list of syntax-tree nodes.
pub trait ParsePySql {
    /// Parses the py_sql text `arg`.
    ///
    /// Returns the parsed [`NodeType`] nodes on success, or an [`Error`]
    /// describing why the text could not be parsed.
    fn parse_pysql(arg: &str) -> Result<Vec<NodeType>, Error>;
}
18
19pub fn impl_fn_py(m: &ItemFn, args: &ParseArgs) -> TokenStream {
20    let fn_name = m.sig.ident.to_string();
21
22    let mut data = args.sqls.iter()
23        .map(|x| x.to_token_stream().to_string())
24        .collect::<String>();
25
26    if data.ne("\"\"") && data.starts_with('"') && data.ends_with('"') {
27        data = data[1..data.len() - 1].to_string();
28    }
29
30    data = data.replace("\\n", "\n");
31
32    let nodes = NodeType::parse_pysql(&data)
33        .expect("[rbatis-codegen] parse py_sql fail!");
34
35    let is_select = data.starts_with("select") || data.starts_with(" select");
36    let htmls = crate::codegen::syntax_tree_pysql::to_html::to_html_mapper(&nodes, is_select, &fn_name);
37
38    parse_html(&htmls, &fn_name, &mut vec![]).into()
39}
40
41impl ParsePySql for NodeType {
42    fn parse_pysql(arg: &str) -> Result<Vec<NodeType>, Error> {
43        let line_space_map = Self::create_line_space_map(arg);
44        let mut main_node = Vec::new();
45        let mut space = -1;
46        let mut line = -1;
47        let mut skip = -1;
48
49        for x in arg.lines() {
50            line += 1;
51
52            if x.is_empty() || (skip != -1 && line <= skip) {
53                continue;
54            }
55
56            let count_index = *line_space_map
57                .get(&line)
58                .ok_or_else(|| Error::from(format!("line_space_map not have line:{}", line)))?;
59
60            if space == -1 {
61                space = count_index;
62            }
63
64            let (child_str, do_skip) = Self::find_child_str(line, count_index, arg, &line_space_map);
65            if do_skip != -1 && do_skip >= skip {
66                skip = do_skip;
67            }
68
69            let parsed = if !child_str.is_empty() {
70                Self::parse_pysql(&child_str)?
71            } else {
72                vec![]
73            };
74
75            let current_space = *line_space_map
76                .get(&line)
77                .ok_or_else(|| Error::from(format!("line:{} not exist!", line)))?;
78
79            Self::parse(&mut main_node, x, current_space as usize, parsed)?;
80        }
81
82        Ok(main_node)
83    }
84}
85
86impl NodeType {
87    fn parse(
88        main_node: &mut Vec<NodeType>,
89        line: &str,
90        space: usize,
91        mut childs: Vec<NodeType>,
92    ) -> Result<(), Error> {
93        let mut trim_line = line.trim();
94
95        if trim_line.starts_with("//") {
96            return Ok(());
97        }
98
99        if trim_line.ends_with(':') {
100            trim_line = trim_line[..trim_line.len() - 1].trim();
101
102            if trim_line.contains(": ") {
103                let parts: Vec<&str> = trim_line.split(": ").collect();
104                if parts.len() > 1 {
105                    for index in (0..parts.len()).rev() {
106                        let item = parts[index];
107                        childs = vec![Self::parse_node(item, line, childs)?];
108
109                        if index == 0 {
110                            main_node.extend(childs);
111                            return Ok(());
112                        }
113                    }
114                }
115            }
116
117            let node = Self::parse_node(trim_line, line, childs)?;
118            main_node.push(node);
119        } else {
120            let data = if space <= 1 {
121                line.to_string()
122            } else {
123                line[(space - 1)..].to_string()
124            };
125
126            main_node.push(NodeType::NString(StringNode {
127                value: data.trim().to_string(),
128            }));
129            main_node.extend(childs);
130        }
131
132        Ok(())
133    }
134
135    fn count_space(arg: &str) -> i32 {
136        arg.chars()
137            .take_while(|&c| c == ' ')
138            .count() as i32
139    }
140
141    fn find_child_str(
142        line_index: i32,
143        space_index: i32,
144        arg: &str,
145        line_space_map: &HashMap<i32, i32>,
146    ) -> (String, i32) {
147        let mut result = String::new();
148        let mut skip_line = -1;
149        let mut current_line = -1;
150
151        for line in arg.lines() {
152            current_line += 1;
153
154            if current_line > line_index {
155                let cached_space = *line_space_map.get(&current_line).expect("line not exists");
156
157                if cached_space > space_index {
158                    result.push_str(line);
159                    result.push('\n');
160                    skip_line = current_line;
161                } else {
162                    break;
163                }
164            }
165        }
166
167        (result, skip_line)
168    }
169
170    fn create_line_space_map(arg: &str) -> HashMap<i32, i32> {
171        arg.lines()
172            .enumerate()
173            .map(|(i, line)| (i as i32, Self::count_space(line)))
174            .collect()
175    }
176
177    fn parse_node(
178        trim_express: &str,
179        source_str: &str,
180        childs: Vec<NodeType>,
181    ) -> Result<NodeType, Error> {
182        match trim_express {
183            s if s.starts_with(IfNode::name()) => Ok(NodeType::NIf(IfNode {
184                childs,
185                test: s.trim_start_matches("if ").to_string(),
186            })),
187
188            s if s.starts_with(ForEachNode::name()) => Self::parse_for_each_node(s, source_str, childs),
189
190            s if s.starts_with(TrimNode::name()) => Self::parse_trim_tag_node(s, source_str, childs),
191
192            s if s.starts_with(ChooseNode::name()) => Self::parse_choose_node(childs),
193
194            s if s.starts_with(OtherwiseNode::default_name()) || s.starts_with(OtherwiseNode::name()) => {
195                Ok(NodeType::NOtherwise(OtherwiseNode { childs }))
196            }
197
198            s if s.starts_with(WhenNode::name()) => Ok(NodeType::NWhen(WhenNode {
199                childs,
200                test: s[WhenNode::name().len()..].trim().to_string(),
201            })),
202
203            s if s.starts_with(BindNode::default_name()) || s.starts_with(BindNode::name()) => {
204                Self::parse_bind_node(s)
205            }
206
207            s if s.starts_with(SetNode::name()) => Self::parse_set_node(s,source_str,childs),
208
209            s if s.starts_with(WhereNode::name()) => Ok(NodeType::NWhere(WhereNode { childs })),
210
211            s if s.starts_with(ContinueNode::name()) => Ok(NodeType::NContinue(ContinueNode {})),
212
213            s if s.starts_with(BreakNode::name()) => Ok(NodeType::NBreak(BreakNode {})),
214
215            s if s.starts_with(SqlNode::name()) => Self::parse_sql_node(s, childs),
216
217            _ => Err(Error::from("[rbatis-codegen] unknown tag: ".to_string() + source_str)),
218        }
219    }
220
221    fn parse_for_each_node(express: &str, source_str: &str, childs: Vec<NodeType>) -> Result<NodeType, Error> {
222        const FOR_TAG: &str = "for";
223        const IN_TAG: &str = " in ";
224
225        if !express.starts_with(FOR_TAG) {
226            return Err(Error::from("[rbatis-codegen] parser express fail:".to_string() + source_str));
227        }
228
229        if !express.contains(IN_TAG) {
230            return Err(Error::from("[rbatis-codegen] parser express fail:".to_string() + source_str));
231        }
232
233        let in_index = express.find(IN_TAG)
234            .ok_or_else(|| Error::from(format!("{} not have {}", express, IN_TAG)))?;
235
236        let col = express[in_index + IN_TAG.len()..].trim();
237        let mut item = express[FOR_TAG.len()..in_index].trim();
238        let mut index = "";
239
240        if item.contains(',') {
241            let splits: Vec<&str> = item.split(',').collect();
242            if splits.len() != 2 {
243                panic!("[rbatis-codegen_codegen] for node must be 'for key,item in col:'");
244            }
245            index = splits[0].trim();
246            item = splits[1].trim();
247        }
248
249        Ok(NodeType::NForEach(ForEachNode {
250            childs,
251            collection: col.to_string(),
252            index: index.to_string(),
253            item: item.to_string(),
254        }))
255    }
256
257    fn parse_trim_tag_node(express: &str, _source_str: &str, childs: Vec<NodeType>) -> Result<NodeType, Error> {
258        let trim_express = express.trim().trim_start_matches("trim ").trim();
259
260        if (trim_express.starts_with('\'') && trim_express.ends_with('\'')) ||
261            (trim_express.starts_with('`') && trim_express.ends_with('`'))
262        {
263            let trimmed = if trim_express.starts_with('`') {
264                trim_express.trim_matches('`')
265            } else {
266                trim_express.trim_matches('\'')
267            };
268
269            Ok(NodeType::NTrim(TrimNode {
270                childs,
271                start: trimmed.to_string(),
272                end: trimmed.to_string(),
273            }))
274        } else if trim_express.contains('=') || trim_express.contains(',') {
275            let mut prefix = "";
276            let mut suffix = "";
277
278            for expr in trim_express.split(',') {
279                let expr = expr.trim();
280                if expr.starts_with("start") {
281                    prefix = expr.trim_start_matches("start")
282                        .trim()
283                        .trim_start_matches('=')
284                        .trim()
285                        .trim_matches(|c| c == '\'' || c == '`');
286                } else if expr.starts_with("end") {
287                    suffix = expr.trim_start_matches("end")
288                        .trim()
289                        .trim_start_matches('=')
290                        .trim()
291                        .trim_matches(|c| c == '\'' || c == '`');
292                } else {
293                    return Err(Error::from(format!(
294                        "[rbatis-codegen] express trim node error, for example  trim 'value':  \
295                        trim start='value': trim start='value',end='value':   express = {}",
296                        trim_express
297                    )));
298                }
299            }
300
301            Ok(NodeType::NTrim(TrimNode {
302                childs,
303                start: prefix.to_string(),
304                end: suffix.to_string(),
305            }))
306        } else {
307            Err(Error::from(format!(
308                "[rbatis-codegen] express trim node error, for example  trim 'value':  \
309                trim start='value': trim start='value',end='value':   error express = {}",
310                trim_express
311            )))
312        }
313    }
314
315    fn parse_choose_node(childs: Vec<NodeType>) -> Result<NodeType, Error> {
316        let mut node = ChooseNode {
317            when_nodes: vec![],
318            otherwise_node: None,
319        };
320
321        for child in childs {
322            match child {
323                NodeType::NWhen(_) => node.when_nodes.push(child),
324                NodeType::NOtherwise(_) => node.otherwise_node = Some(Box::new(child)),
325                _ => return Err(Error::from(
326                    "[rbatis-codegen] parser node fail,choose node' child must be when and otherwise nodes!".to_string()
327                )),
328            }
329        }
330
331        Ok(NodeType::NChoose(node))
332    }
333
334    fn parse_bind_node(express: &str) -> Result<NodeType, Error> {
335        let expr = if express.starts_with(BindNode::default_name()) {
336            express[BindNode::default_name().len()..].trim()
337        } else {
338            express[BindNode::name().len()..].trim()
339        };
340
341        let parts: Vec<&str> = expr.split('=').collect();
342        if parts.len() != 2 {
343            return Err(Error::from(
344                "[rbatis-codegen] parser bind express fail:".to_string() + express,
345            ));
346        }
347
348        Ok(NodeType::NBind(BindNode {
349            name: parts[0].trim().to_string(),
350            value: parts[1].trim().to_string(),
351        }))
352    }
353
354    fn parse_sql_node(express: &str, childs: Vec<NodeType>) -> Result<NodeType, Error> {
355        let expr = express[SqlNode::name().len()..].trim();
356
357        if !expr.starts_with("id=") {
358            return Err(Error::from(
359                "[rbatis-codegen] parser sql express fail, need id param:".to_string() + express,
360            ));
361        }
362
363        let id_value = expr.trim_start_matches("id=").trim();
364
365        let id = if (id_value.starts_with('\'') && id_value.ends_with('\'')) ||
366            (id_value.starts_with('"') && id_value.ends_with('"'))
367        {
368            id_value[1..id_value.len() - 1].to_string()
369        } else {
370            return Err(Error::from(
371                "[rbatis-codegen] parser sql id value need quotes:".to_string() + express,
372            ));
373        };
374
375        Ok(NodeType::NSql(SqlNode { childs, id }))
376    }
377
378    fn strip_quotes_for_attr(s: &str) -> String {
379        let val = s.trim(); // Trim whitespace around the value first
380        if val.starts_with('\'') && val.ends_with('\'') ||
381           (val.starts_with('"') && val.ends_with('"')) {
382            if val.len() >= 2 {
383                return val[1..val.len()-1].to_string();
384            }
385        }
386        val.to_string() // Return the trimmed string if no quotes or malformed quotes
387    }
388    
389    fn parse_set_node(express: &str, source_str: &str,  childs: Vec<NodeType>) -> Result<NodeType, Error>  {
390        let actual_attrs_str = if express.starts_with(SetNode::name()) {
391            express[SetNode::name().len()..].trim()
392        } else {
393            // This case should ideally not happen if called correctly from the match arm
394            return Err(Error::from(format!("[rbatis-codegen] SetNode expression '{}' does not start with '{}'", express, SetNode::name())));
395        };
396        let mut collection_opt: Option<String> = None;
397        let mut skip_null_val = false; // Default
398        let mut skips_val: String = String::new(); // Default is now an empty String
399        for part_str in actual_attrs_str.split(',') {
400            let clean_part = part_str.trim();
401            if clean_part.is_empty() {
402                continue;
403            }
404
405            let kv: Vec<&str> = clean_part.splitn(2, '=').collect();
406            if kv.len() != 2 {
407                return Err(Error::from(format!("[rbatis-codegen] Malformed attribute in set node near '{}' in '{}'", clean_part, source_str)));
408            }
409
410            let key = kv[0].trim();
411            let value_str_raw = kv[1].trim();
412
413            match key {
414                "collection" => {
415                    collection_opt = Some(Self::strip_quotes_for_attr(value_str_raw));
416                }
417                "skip_null" => {
418                    let val_bool_str = Self::strip_quotes_for_attr(value_str_raw);
419                    if val_bool_str.eq_ignore_ascii_case("true") {
420                        skip_null_val = true;
421                    } else if val_bool_str.eq_ignore_ascii_case("false") {
422                        skip_null_val = false;
423                    } else {
424                        return Err(Error::from(format!("[rbatis-codegen] Invalid boolean value for skip_null: '{}' in '{}'", value_str_raw, source_str)));
425                    }
426                }
427                "skips" => {
428                    let inner_skips_str = Self::strip_quotes_for_attr(value_str_raw);
429                    skips_val = inner_skips_str;
430                }
431                _ => {
432                    return Err(Error::from(format!("[rbatis-codegen] Unknown attribute '{}' for set node in '{}'", key, source_str)));
433                }
434            }
435        }
436        let collection_val = collection_opt.unwrap_or_default();
437        Ok(NodeType::NSet(SetNode {
438            childs,
439            collection: collection_val,
440            skip_null: skip_null_val,
441            skips: skips_val,
442        }))
443    }
444}