1use crate::codegen::parser_html::parse_html;
2use crate::codegen::proc_macro::TokenStream;
3use crate::codegen::syntax_tree_pysql::{
4 bind_node::BindNode, break_node::BreakNode, choose_node::ChooseNode, continue_node::ContinueNode,
5 error::Error, foreach_node::ForEachNode, if_node::IfNode, otherwise_node::OtherwiseNode,
6 set_node::SetNode, sql_node::SqlNode, string_node::StringNode, trim_node::TrimNode,
7 when_node::WhenNode, where_node::WhereNode, DefaultName, Name, NodeType,
8};
9use crate::codegen::ParseArgs;
10use quote::ToTokens;
11use std::collections::HashMap;
12use syn::ItemFn;
13
14pub trait ParsePySql {
16 fn parse_pysql(arg: &str) -> Result<Vec<NodeType>, Error>;
17}
18
19pub fn impl_fn_py(m: &ItemFn, args: &ParseArgs) -> TokenStream {
20 let fn_name = m.sig.ident.to_string();
21
22 let mut data = args.sqls.iter()
23 .map(|x| x.to_token_stream().to_string())
24 .collect::<String>();
25
26 if data.ne("\"\"") && data.starts_with('"') && data.ends_with('"') {
27 data = data[1..data.len() - 1].to_string();
28 }
29
30 data = data.replace("\\n", "\n");
31
32 let nodes = NodeType::parse_pysql(&data)
33 .expect("[rbatis-codegen] parse py_sql fail!");
34
35 let is_select = data.starts_with("select") || data.starts_with(" select");
36 let htmls = crate::codegen::syntax_tree_pysql::to_html::to_html_mapper(&nodes, is_select, &fn_name);
37
38 parse_html(&htmls, &fn_name, &mut vec![]).into()
39}
40
41impl ParsePySql for NodeType {
42 fn parse_pysql(arg: &str) -> Result<Vec<NodeType>, Error> {
43 let line_space_map = Self::create_line_space_map(arg);
44 let mut main_node = Vec::new();
45 let mut space = -1;
46 let mut line = -1;
47 let mut skip = -1;
48
49 for x in arg.lines() {
50 line += 1;
51
52 if x.is_empty() || (skip != -1 && line <= skip) {
53 continue;
54 }
55
56 let count_index = *line_space_map
57 .get(&line)
58 .ok_or_else(|| Error::from(format!("line_space_map not have line:{}", line)))?;
59
60 if space == -1 {
61 space = count_index;
62 }
63
64 let (child_str, do_skip) = Self::find_child_str(line, count_index, arg, &line_space_map);
65 if do_skip != -1 && do_skip >= skip {
66 skip = do_skip;
67 }
68
69 let parsed = if !child_str.is_empty() {
70 Self::parse_pysql(&child_str)?
71 } else {
72 vec![]
73 };
74
75 let current_space = *line_space_map
76 .get(&line)
77 .ok_or_else(|| Error::from(format!("line:{} not exist!", line)))?;
78
79 Self::parse(&mut main_node, x, current_space as usize, parsed)?;
80 }
81
82 Ok(main_node)
83 }
84}
85
86impl NodeType {
87 fn parse(
88 main_node: &mut Vec<NodeType>,
89 line: &str,
90 space: usize,
91 mut childs: Vec<NodeType>,
92 ) -> Result<(), Error> {
93 let mut trim_line = line.trim();
94
95 if trim_line.starts_with("//") {
96 return Ok(());
97 }
98
99 if trim_line.ends_with(':') {
100 trim_line = trim_line[..trim_line.len() - 1].trim();
101
102 if trim_line.contains(": ") {
103 let parts: Vec<&str> = trim_line.split(": ").collect();
104 if parts.len() > 1 {
105 for index in (0..parts.len()).rev() {
106 let item = parts[index];
107 childs = vec![Self::parse_node(item, line, childs)?];
108
109 if index == 0 {
110 main_node.extend(childs);
111 return Ok(());
112 }
113 }
114 }
115 }
116
117 let node = Self::parse_node(trim_line, line, childs)?;
118 main_node.push(node);
119 } else {
120 let data = if space <= 1 {
121 line.to_string()
122 } else {
123 line[(space - 1)..].to_string()
124 };
125
126 main_node.push(NodeType::NString(StringNode {
127 value: data.trim().to_string(),
128 }));
129 main_node.extend(childs);
130 }
131
132 Ok(())
133 }
134
135 fn count_space(arg: &str) -> i32 {
136 arg.chars()
137 .take_while(|&c| c == ' ')
138 .count() as i32
139 }
140
141 fn find_child_str(
142 line_index: i32,
143 space_index: i32,
144 arg: &str,
145 line_space_map: &HashMap<i32, i32>,
146 ) -> (String, i32) {
147 let mut result = String::new();
148 let mut skip_line = -1;
149 let mut current_line = -1;
150
151 for line in arg.lines() {
152 current_line += 1;
153
154 if current_line > line_index {
155 let cached_space = *line_space_map.get(¤t_line).expect("line not exists");
156
157 if cached_space > space_index {
158 result.push_str(line);
159 result.push('\n');
160 skip_line = current_line;
161 } else {
162 break;
163 }
164 }
165 }
166
167 (result, skip_line)
168 }
169
170 fn create_line_space_map(arg: &str) -> HashMap<i32, i32> {
171 arg.lines()
172 .enumerate()
173 .map(|(i, line)| (i as i32, Self::count_space(line)))
174 .collect()
175 }
176
177 fn parse_node(
178 trim_express: &str,
179 source_str: &str,
180 childs: Vec<NodeType>,
181 ) -> Result<NodeType, Error> {
182 match trim_express {
183 s if s.starts_with(IfNode::name()) => Ok(NodeType::NIf(IfNode {
184 childs,
185 test: s.trim_start_matches("if ").to_string(),
186 })),
187
188 s if s.starts_with(ForEachNode::name()) => Self::parse_for_each_node(s, source_str, childs),
189
190 s if s.starts_with(TrimNode::name()) => Self::parse_trim_tag_node(s, source_str, childs),
191
192 s if s.starts_with(ChooseNode::name()) => Self::parse_choose_node(childs),
193
194 s if s.starts_with(OtherwiseNode::default_name()) || s.starts_with(OtherwiseNode::name()) => {
195 Ok(NodeType::NOtherwise(OtherwiseNode { childs }))
196 }
197
198 s if s.starts_with(WhenNode::name()) => Ok(NodeType::NWhen(WhenNode {
199 childs,
200 test: s[WhenNode::name().len()..].trim().to_string(),
201 })),
202
203 s if s.starts_with(BindNode::default_name()) || s.starts_with(BindNode::name()) => {
204 Self::parse_bind_node(s)
205 }
206
207 s if s.starts_with(SetNode::name()) => Self::parse_set_node(s,source_str,childs),
208
209 s if s.starts_with(WhereNode::name()) => Ok(NodeType::NWhere(WhereNode { childs })),
210
211 s if s.starts_with(ContinueNode::name()) => Ok(NodeType::NContinue(ContinueNode {})),
212
213 s if s.starts_with(BreakNode::name()) => Ok(NodeType::NBreak(BreakNode {})),
214
215 s if s.starts_with(SqlNode::name()) => Self::parse_sql_node(s, childs),
216
217 _ => Err(Error::from("[rbatis-codegen] unknown tag: ".to_string() + source_str)),
218 }
219 }
220
221 fn parse_for_each_node(express: &str, source_str: &str, childs: Vec<NodeType>) -> Result<NodeType, Error> {
222 const FOR_TAG: &str = "for";
223 const IN_TAG: &str = " in ";
224
225 if !express.starts_with(FOR_TAG) {
226 return Err(Error::from("[rbatis-codegen] parser express fail:".to_string() + source_str));
227 }
228
229 if !express.contains(IN_TAG) {
230 return Err(Error::from("[rbatis-codegen] parser express fail:".to_string() + source_str));
231 }
232
233 let in_index = express.find(IN_TAG)
234 .ok_or_else(|| Error::from(format!("{} not have {}", express, IN_TAG)))?;
235
236 let col = express[in_index + IN_TAG.len()..].trim();
237 let mut item = express[FOR_TAG.len()..in_index].trim();
238 let mut index = "";
239
240 if item.contains(',') {
241 let splits: Vec<&str> = item.split(',').collect();
242 if splits.len() != 2 {
243 panic!("[rbatis-codegen_codegen] for node must be 'for key,item in col:'");
244 }
245 index = splits[0].trim();
246 item = splits[1].trim();
247 }
248
249 Ok(NodeType::NForEach(ForEachNode {
250 childs,
251 collection: col.to_string(),
252 index: index.to_string(),
253 item: item.to_string(),
254 }))
255 }
256
257 fn parse_trim_tag_node(express: &str, _source_str: &str, childs: Vec<NodeType>) -> Result<NodeType, Error> {
258 let trim_express = express.trim().trim_start_matches("trim ").trim();
259
260 if (trim_express.starts_with('\'') && trim_express.ends_with('\'')) ||
261 (trim_express.starts_with('`') && trim_express.ends_with('`'))
262 {
263 let trimmed = if trim_express.starts_with('`') {
264 trim_express.trim_matches('`')
265 } else {
266 trim_express.trim_matches('\'')
267 };
268
269 Ok(NodeType::NTrim(TrimNode {
270 childs,
271 start: trimmed.to_string(),
272 end: trimmed.to_string(),
273 }))
274 } else if trim_express.contains('=') || trim_express.contains(',') {
275 let mut prefix = "";
276 let mut suffix = "";
277
278 for expr in trim_express.split(',') {
279 let expr = expr.trim();
280 if expr.starts_with("start") {
281 prefix = expr.trim_start_matches("start")
282 .trim()
283 .trim_start_matches('=')
284 .trim()
285 .trim_matches(|c| c == '\'' || c == '`');
286 } else if expr.starts_with("end") {
287 suffix = expr.trim_start_matches("end")
288 .trim()
289 .trim_start_matches('=')
290 .trim()
291 .trim_matches(|c| c == '\'' || c == '`');
292 } else {
293 return Err(Error::from(format!(
294 "[rbatis-codegen] express trim node error, for example trim 'value': \
295 trim start='value': trim start='value',end='value': express = {}",
296 trim_express
297 )));
298 }
299 }
300
301 Ok(NodeType::NTrim(TrimNode {
302 childs,
303 start: prefix.to_string(),
304 end: suffix.to_string(),
305 }))
306 } else {
307 Err(Error::from(format!(
308 "[rbatis-codegen] express trim node error, for example trim 'value': \
309 trim start='value': trim start='value',end='value': error express = {}",
310 trim_express
311 )))
312 }
313 }
314
315 fn parse_choose_node(childs: Vec<NodeType>) -> Result<NodeType, Error> {
316 let mut node = ChooseNode {
317 when_nodes: vec![],
318 otherwise_node: None,
319 };
320
321 for child in childs {
322 match child {
323 NodeType::NWhen(_) => node.when_nodes.push(child),
324 NodeType::NOtherwise(_) => node.otherwise_node = Some(Box::new(child)),
325 _ => return Err(Error::from(
326 "[rbatis-codegen] parser node fail,choose node' child must be when and otherwise nodes!".to_string()
327 )),
328 }
329 }
330
331 Ok(NodeType::NChoose(node))
332 }
333
334 fn parse_bind_node(express: &str) -> Result<NodeType, Error> {
335 let expr = if express.starts_with(BindNode::default_name()) {
336 express[BindNode::default_name().len()..].trim()
337 } else {
338 express[BindNode::name().len()..].trim()
339 };
340
341 let parts: Vec<&str> = expr.split('=').collect();
342 if parts.len() != 2 {
343 return Err(Error::from(
344 "[rbatis-codegen] parser bind express fail:".to_string() + express,
345 ));
346 }
347
348 Ok(NodeType::NBind(BindNode {
349 name: parts[0].trim().to_string(),
350 value: parts[1].trim().to_string(),
351 }))
352 }
353
354 fn parse_sql_node(express: &str, childs: Vec<NodeType>) -> Result<NodeType, Error> {
355 let expr = express[SqlNode::name().len()..].trim();
356
357 if !expr.starts_with("id=") {
358 return Err(Error::from(
359 "[rbatis-codegen] parser sql express fail, need id param:".to_string() + express,
360 ));
361 }
362
363 let id_value = expr.trim_start_matches("id=").trim();
364
365 let id = if (id_value.starts_with('\'') && id_value.ends_with('\'')) ||
366 (id_value.starts_with('"') && id_value.ends_with('"'))
367 {
368 id_value[1..id_value.len() - 1].to_string()
369 } else {
370 return Err(Error::from(
371 "[rbatis-codegen] parser sql id value need quotes:".to_string() + express,
372 ));
373 };
374
375 Ok(NodeType::NSql(SqlNode { childs, id }))
376 }
377
378 fn strip_quotes_for_attr(s: &str) -> String {
379 let val = s.trim(); if val.starts_with('\'') && val.ends_with('\'') ||
381 (val.starts_with('"') && val.ends_with('"')) {
382 if val.len() >= 2 {
383 return val[1..val.len()-1].to_string();
384 }
385 }
386 val.to_string() }
388
389 fn parse_set_node(express: &str, source_str: &str, childs: Vec<NodeType>) -> Result<NodeType, Error> {
390 let actual_attrs_str = if express.starts_with(SetNode::name()) {
391 express[SetNode::name().len()..].trim()
392 } else {
393 return Err(Error::from(format!("[rbatis-codegen] SetNode expression '{}' does not start with '{}'", express, SetNode::name())));
395 };
396 let mut collection_opt: Option<String> = None;
397 let mut skip_null_val = false; let mut skips_val: String = String::new(); for part_str in actual_attrs_str.split(',') {
400 let clean_part = part_str.trim();
401 if clean_part.is_empty() {
402 continue;
403 }
404
405 let kv: Vec<&str> = clean_part.splitn(2, '=').collect();
406 if kv.len() != 2 {
407 return Err(Error::from(format!("[rbatis-codegen] Malformed attribute in set node near '{}' in '{}'", clean_part, source_str)));
408 }
409
410 let key = kv[0].trim();
411 let value_str_raw = kv[1].trim();
412
413 match key {
414 "collection" => {
415 collection_opt = Some(Self::strip_quotes_for_attr(value_str_raw));
416 }
417 "skip_null" => {
418 let val_bool_str = Self::strip_quotes_for_attr(value_str_raw);
419 if val_bool_str.eq_ignore_ascii_case("true") {
420 skip_null_val = true;
421 } else if val_bool_str.eq_ignore_ascii_case("false") {
422 skip_null_val = false;
423 } else {
424 return Err(Error::from(format!("[rbatis-codegen] Invalid boolean value for skip_null: '{}' in '{}'", value_str_raw, source_str)));
425 }
426 }
427 "skips" => {
428 let inner_skips_str = Self::strip_quotes_for_attr(value_str_raw);
429 skips_val = inner_skips_str;
430 }
431 _ => {
432 return Err(Error::from(format!("[rbatis-codegen] Unknown attribute '{}' for set node in '{}'", key, source_str)));
433 }
434 }
435 }
436 let collection_val = collection_opt.unwrap_or_default();
437 Ok(NodeType::NSet(SetNode {
438 childs,
439 collection: collection_val,
440 skip_null: skip_null_val,
441 skips: skips_val,
442 }))
443 }
444}