1use crate::codegen::parser_html::parse_html;
2use crate::codegen::proc_macro::TokenStream;
3use crate::codegen::syntax_tree_pysql::{
4 bind_node::BindNode, break_node::BreakNode, choose_node::ChooseNode,
5 continue_node::ContinueNode, error::Error, foreach_node::ForEachNode, if_node::IfNode,
6 otherwise_node::OtherwiseNode, set_node::SetNode, sql_node::SqlNode, string_node::StringNode,
7 trim_node::TrimNode, when_node::WhenNode, where_node::WhereNode, DefaultName, Name, NodeType,
8};
9use crate::codegen::ParseArgs;
10use quote::ToTokens;
11use std::collections::HashMap;
12use syn::ItemFn;
13
14pub trait ParsePySql {
16 fn parse_pysql(arg: &str) -> Result<Vec<NodeType>, Error>;
17}
18
19pub fn impl_fn_py(m: &ItemFn, args: &ParseArgs) -> TokenStream {
20 let fn_name = m.sig.ident.to_string();
21
22 let mut data = args
23 .sqls
24 .iter()
25 .map(|x| x.to_token_stream().to_string())
26 .collect::<String>();
27
28 if data.ne("\"\"") && data.starts_with('"') && data.ends_with('"') {
29 data = data[1..data.len() - 1].to_string();
30 }
31
32 data = data.replace("\\n", "\n");
33
34 let nodes = NodeType::parse_pysql(&data).expect("[rbatis-codegen] parse py_sql fail!");
35
36 let is_select = data.starts_with("select") || data.starts_with(" select");
37 let htmls =
38 crate::codegen::syntax_tree_pysql::to_html::to_html_mapper(&nodes, is_select, &fn_name);
39
40 parse_html(&htmls, &fn_name, &mut vec![]).into()
41}
42
43impl ParsePySql for NodeType {
44 fn parse_pysql(arg: &str) -> Result<Vec<NodeType>, Error> {
45 let line_space_map = Self::create_line_space_map(arg);
46 let mut main_node = Vec::new();
47 let mut space = -1;
48 let mut line = -1;
49 let mut skip = -1;
50
51 for x in arg.lines() {
52 line += 1;
53
54 if x.is_empty() || (skip != -1 && line <= skip) {
55 continue;
56 }
57
58 let count_index = *line_space_map
59 .get(&line)
60 .ok_or_else(|| Error::from(format!("line_space_map not have line:{}", line)))?;
61
62 if space == -1 {
63 space = count_index;
64 }
65
66 let (child_str, do_skip) =
67 Self::find_child_str(line, count_index, arg, &line_space_map);
68 if do_skip != -1 && do_skip >= skip {
69 skip = do_skip;
70 }
71
72 let parsed = if !child_str.is_empty() {
73 Self::parse_pysql(&child_str)?
74 } else {
75 vec![]
76 };
77
78 let current_space = *line_space_map
79 .get(&line)
80 .ok_or_else(|| Error::from(format!("line:{} not exist!", line)))?;
81
82 Self::parse(&mut main_node, x, current_space as usize, parsed)?;
83 }
84
85 Ok(main_node)
86 }
87}
88
89impl NodeType {
90 fn parse(
91 main_node: &mut Vec<NodeType>,
92 line: &str,
93 space: usize,
94 mut childs: Vec<NodeType>,
95 ) -> Result<(), Error> {
96 let mut trim_line = line.trim();
97
98 if trim_line.starts_with("//") {
99 return Ok(());
100 }
101
102 if trim_line.ends_with(':') {
103 trim_line = trim_line[..trim_line.len() - 1].trim();
104
105 if trim_line.contains(": ") {
106 let parts: Vec<&str> = trim_line.split(": ").collect();
107 if parts.len() > 1 {
108 for index in (0..parts.len()).rev() {
109 let item = parts[index];
110 childs = vec![Self::parse_node(item, line, childs)?];
111
112 if index == 0 {
113 main_node.extend(childs);
114 return Ok(());
115 }
116 }
117 }
118 }
119
120 let node = Self::parse_node(trim_line, line, childs)?;
121 main_node.push(node);
122 } else {
123 let data = if space <= 1 {
124 line.to_string()
125 } else {
126 line[(space - 1)..].to_string()
127 };
128
129 main_node.push(NodeType::NString(StringNode {
130 value: data.trim().to_string(),
131 }));
132 main_node.extend(childs);
133 }
134
135 Ok(())
136 }
137
138 fn count_space(arg: &str) -> i32 {
139 arg.chars().take_while(|&c| c == ' ').count() as i32
140 }
141
142 fn find_child_str(
143 line_index: i32,
144 space_index: i32,
145 arg: &str,
146 line_space_map: &HashMap<i32, i32>,
147 ) -> (String, i32) {
148 let mut result = String::new();
149 let mut skip_line = -1;
150 let mut current_line = -1;
151
152 for line in arg.lines() {
153 current_line += 1;
154
155 if current_line > line_index {
156 let cached_space = *line_space_map.get(¤t_line).expect("line not exists");
157
158 if cached_space > space_index {
159 result.push_str(line);
160 result.push('\n');
161 skip_line = current_line;
162 } else {
163 break;
164 }
165 }
166 }
167
168 (result, skip_line)
169 }
170
171 fn create_line_space_map(arg: &str) -> HashMap<i32, i32> {
172 arg.lines()
173 .enumerate()
174 .map(|(i, line)| (i as i32, Self::count_space(line)))
175 .collect()
176 }
177
178 fn parse_node(
179 trim_express: &str,
180 source_str: &str,
181 childs: Vec<NodeType>,
182 ) -> Result<NodeType, Error> {
183 match trim_express {
184 s if s.starts_with(IfNode::name()) => Ok(NodeType::NIf(IfNode {
185 childs,
186 test: s.trim_start_matches("if ").to_string(),
187 })),
188
189 s if s.starts_with(ForEachNode::name()) => {
190 Self::parse_for_each_node(s, source_str, childs)
191 }
192
193 s if s.starts_with(TrimNode::name()) => {
194 Self::parse_trim_tag_node(s, source_str, childs)
195 }
196
197 s if s.starts_with(ChooseNode::name()) => Self::parse_choose_node(childs),
198
199 s if s.starts_with(OtherwiseNode::default_name())
200 || s.starts_with(OtherwiseNode::name()) =>
201 {
202 Ok(NodeType::NOtherwise(OtherwiseNode { childs }))
203 }
204
205 s if s.starts_with(WhenNode::name()) => Ok(NodeType::NWhen(WhenNode {
206 childs,
207 test: s[WhenNode::name().len()..].trim().to_string(),
208 })),
209
210 s if s.starts_with(BindNode::default_name()) || s.starts_with(BindNode::name()) => {
211 Self::parse_bind_node(s)
212 }
213
214 s if s.starts_with(SetNode::name()) => Self::parse_set_node(s, source_str, childs),
215
216 s if s.starts_with(WhereNode::name()) => Ok(NodeType::NWhere(WhereNode { childs })),
217
218 s if s.starts_with(ContinueNode::name()) => Ok(NodeType::NContinue(ContinueNode {})),
219
220 s if s.starts_with(BreakNode::name()) => Ok(NodeType::NBreak(BreakNode {})),
221
222 s if s.starts_with(SqlNode::name()) => Self::parse_sql_node(s, childs),
223
224 _ => Err(Error::from(
225 "[rbatis-codegen] unknown tag: ".to_string() + source_str,
226 )),
227 }
228 }
229
230 fn parse_for_each_node(
231 express: &str,
232 source_str: &str,
233 childs: Vec<NodeType>,
234 ) -> Result<NodeType, Error> {
235 const FOR_TAG: &str = "for";
236 const IN_TAG: &str = " in ";
237
238 if !express.starts_with(FOR_TAG) {
239 return Err(Error::from(
240 "[rbatis-codegen] parser express fail:".to_string() + source_str,
241 ));
242 }
243
244 if !express.contains(IN_TAG) {
245 return Err(Error::from(
246 "[rbatis-codegen] parser express fail:".to_string() + source_str,
247 ));
248 }
249
250 let in_index = express
251 .find(IN_TAG)
252 .ok_or_else(|| Error::from(format!("{} not have {}", express, IN_TAG)))?;
253
254 let col = express[in_index + IN_TAG.len()..].trim();
255 let mut item = express[FOR_TAG.len()..in_index].trim();
256 let mut index = "";
257
258 if item.contains(',') {
259 let splits: Vec<&str> = item.split(',').collect();
260 if splits.len() != 2 {
261 panic!("[rbatis-codegen_codegen] for node must be 'for key,item in col:'");
262 }
263 index = splits[0].trim();
264 item = splits[1].trim();
265 }
266
267 Ok(NodeType::NForEach(ForEachNode {
268 childs,
269 collection: col.to_string(),
270 index: index.to_string(),
271 item: item.to_string(),
272 }))
273 }
274
275 fn parse_trim_tag_node(
276 express: &str,
277 _source_str: &str,
278 childs: Vec<NodeType>,
279 ) -> Result<NodeType, Error> {
280 let trim_express = express.trim().trim_start_matches("trim ").trim();
281
282 if (trim_express.starts_with('\'') && trim_express.ends_with('\''))
283 || (trim_express.starts_with('`') && trim_express.ends_with('`'))
284 {
285 let trimmed = if trim_express.starts_with('`') {
286 trim_express.trim_matches('`')
287 } else {
288 trim_express.trim_matches('\'')
289 };
290
291 Ok(NodeType::NTrim(TrimNode {
292 childs,
293 start: trimmed.to_string(),
294 end: trimmed.to_string(),
295 }))
296 } else if trim_express.contains('=') || trim_express.contains(',') {
297 let mut prefix = "";
298 let mut suffix = "";
299
300 for expr in trim_express.split(',') {
301 let expr = expr.trim();
302 if expr.starts_with("start") {
303 prefix = expr
304 .trim_start_matches("start")
305 .trim()
306 .trim_start_matches('=')
307 .trim()
308 .trim_matches(|c| c == '\'' || c == '`');
309 } else if expr.starts_with("end") {
310 suffix = expr
311 .trim_start_matches("end")
312 .trim()
313 .trim_start_matches('=')
314 .trim()
315 .trim_matches(|c| c == '\'' || c == '`');
316 } else {
317 return Err(Error::from(format!(
318 "[rbatis-codegen] express trim node error, for example trim 'value': \
319 trim start='value': trim start='value',end='value': express = {}",
320 trim_express
321 )));
322 }
323 }
324
325 Ok(NodeType::NTrim(TrimNode {
326 childs,
327 start: prefix.to_string(),
328 end: suffix.to_string(),
329 }))
330 } else {
331 Err(Error::from(format!(
332 "[rbatis-codegen] express trim node error, for example trim 'value': \
333 trim start='value': trim start='value',end='value': error express = {}",
334 trim_express
335 )))
336 }
337 }
338
339 fn parse_choose_node(childs: Vec<NodeType>) -> Result<NodeType, Error> {
340 let mut node = ChooseNode {
341 when_nodes: vec![],
342 otherwise_node: None,
343 };
344
345 for child in childs {
346 match child {
347 NodeType::NWhen(_) => node.when_nodes.push(child),
348 NodeType::NOtherwise(_) => node.otherwise_node = Some(Box::new(child)),
349 _ => return Err(Error::from(
350 "[rbatis-codegen] parser node fail,choose node' child must be when and otherwise nodes!".to_string()
351 )),
352 }
353 }
354
355 Ok(NodeType::NChoose(node))
356 }
357
358 fn parse_bind_node(express: &str) -> Result<NodeType, Error> {
359 let expr = if express.starts_with(BindNode::default_name()) {
360 express[BindNode::default_name().len()..].trim()
361 } else {
362 express[BindNode::name().len()..].trim()
363 };
364
365 let parts: Vec<&str> = expr.split('=').collect();
366 if parts.len() != 2 {
367 return Err(Error::from(
368 "[rbatis-codegen] parser bind express fail:".to_string() + express,
369 ));
370 }
371
372 Ok(NodeType::NBind(BindNode {
373 name: parts[0].trim().to_string(),
374 value: parts[1].trim().to_string(),
375 }))
376 }
377
378 fn parse_sql_node(express: &str, childs: Vec<NodeType>) -> Result<NodeType, Error> {
379 let expr = express[SqlNode::name().len()..].trim();
380
381 if !expr.starts_with("id=") {
382 return Err(Error::from(
383 "[rbatis-codegen] parser sql express fail, need id param:".to_string() + express,
384 ));
385 }
386
387 let id_value = expr.trim_start_matches("id=").trim();
388
389 let id = if (id_value.starts_with('\'') && id_value.ends_with('\''))
390 || (id_value.starts_with('"') && id_value.ends_with('"'))
391 {
392 id_value[1..id_value.len() - 1].to_string()
393 } else {
394 return Err(Error::from(
395 "[rbatis-codegen] parser sql id value need quotes:".to_string() + express,
396 ));
397 };
398
399 Ok(NodeType::NSql(SqlNode { childs, id }))
400 }
401
402 fn strip_quotes_for_attr(s: &str) -> String {
403 let val = s.trim(); if val.starts_with('\'') && val.ends_with('\'')
405 || (val.starts_with('"') && val.ends_with('"'))
406 {
407 if val.len() >= 2 {
408 return val[1..val.len() - 1].to_string();
409 }
410 }
411 val.to_string() }
413
414 fn parse_set_node(
415 express: &str,
416 source_str: &str,
417 childs: Vec<NodeType>,
418 ) -> Result<NodeType, Error> {
419 let actual_attrs_str = if express.starts_with(SetNode::name()) {
420 express[SetNode::name().len()..].trim()
421 } else {
422 return Err(Error::from(format!(
424 "[rbatis-codegen] SetNode expression '{}' does not start with '{}'",
425 express,
426 SetNode::name()
427 )));
428 };
429 let mut collection_opt: Option<String> = None;
430 let mut skip_null_val = None; let mut skips_val: String = String::new(); for part_str in actual_attrs_str.split(',') {
433 let clean_part = part_str.trim();
434 if clean_part.is_empty() {
435 continue;
436 }
437
438 let kv: Vec<&str> = clean_part.splitn(2, '=').collect();
439 if kv.len() != 2 {
440 return Err(Error::from(format!(
441 "[rbatis-codegen] Malformed attribute in set node near '{}' in '{}'",
442 clean_part, source_str
443 )));
444 }
445
446 let key = kv[0].trim();
447 let value_str_raw = kv[1].trim();
448
449 match key {
450 "collection" => {
451 collection_opt = Some(Self::strip_quotes_for_attr(value_str_raw));
452 }
453 "skip_null" => {
454 let val_bool_str = Self::strip_quotes_for_attr(value_str_raw);
455 if val_bool_str.is_empty() {
456 skip_null_val = None;
457 } else if val_bool_str.eq_ignore_ascii_case("true") {
458 skip_null_val = Some(true);
459 } else if val_bool_str.eq_ignore_ascii_case("false") {
460 skip_null_val = Some(false);
461 } else {
462 return Err(Error::from(format!(
463 "[rbatis-codegen] Invalid boolean value for skip_null: '{}' in '{}'",
464 value_str_raw, source_str
465 )));
466 }
467 }
468 "skips" => {
469 let inner_skips_str = Self::strip_quotes_for_attr(value_str_raw);
470 skips_val = inner_skips_str;
471 }
472 _ => {
473 return Err(Error::from(format!(
474 "[rbatis-codegen] Unknown attribute '{}' for set node in '{}'",
475 key, source_str
476 )));
477 }
478 }
479 }
480 let collection_val = collection_opt.unwrap_or_default();
481 Ok(NodeType::NSet(SetNode {
482 childs,
483 collection: collection_val,
484 skip_null: skip_null_val,
485 skips: skips_val,
486 }))
487 }
488}