1use proc_macro2::{Ident, Span};
2use quote::{quote, ToTokens};
3use std::collections::{BTreeMap, HashMap};
4use std::fs::File;
5use std::io::Read;
6use std::path::PathBuf;
7use syn::{ItemFn, LitStr};
8use url::Url;
9
10use crate::codegen::loader_html::{load_html, Element};
11use crate::codegen::proc_macro::TokenStream;
12use crate::codegen::string_util::find_convert_string;
13use crate::codegen::ParseArgs;
14use crate::error::Error;
15
// ---- document structure ----
/// Wrapper element; its children are flattened into the top level by `load_mapper_vec`.
const MAPPER_TAG: &str = "mapper";
/// Reusable SQL fragment, registered under its `id`.
const SQL_TAG: &str = "sql";
/// Reference to a previously registered element via its `refid` attribute.
const INCLUDE_TAG: &str = "include";

// ---- statements: each one generates a `pub fn` (see `handle_crud_element`) ----
const SELECT_TAG: &str = "select";
const UPDATE_TAG: &str = "update";
const INSERT_TAG: &str = "insert";
const DELETE_TAG: &str = "delete";

// ---- dynamic SQL ----
const IF_TAG: &str = "if";
const TRIM_TAG: &str = "trim";
const BIND_TAG: &str = "bind";
const WHERE_TAG: &str = "where";
const CHOOSE_TAG: &str = "choose";
const WHEN_TAG: &str = "when";
const OTHERWISE_TAG: &str = "otherwise";
const FOREACH_TAG: &str = "foreach";
const SET_TAG: &str = "set";

// ---- loop control: emit `continue;` / `break;` ----
const CONTINUE_TAG: &str = "continue";
const BREAK_TAG: &str = "break";
35
36pub fn load_mapper_map(html: &str) -> Result<BTreeMap<String, Element>, Error> {
38 let elements = load_mapper_vec(html)?;
39 let mut sql_map = BTreeMap::new();
40 let processed_elements = include_replace(elements, &mut sql_map);
41
42 let mut m = BTreeMap::new();
43 for x in processed_elements {
44 if let Some(v) = x.attrs.get("id") {
45 m.insert(v.to_string(), x);
46 }
47 }
48 Ok(m)
49}
50
51pub fn load_mapper_vec(html: &str) -> Result<Vec<Element>, Error> {
53 let elements = load_html(html).map_err(|e| Error::from(e.to_string()))?;
54
55 let mut mappers = Vec::new();
56 for element in elements {
57 if element.tag == MAPPER_TAG {
58 mappers.extend(element.childs);
59 } else {
60 mappers.push(element);
61 }
62 }
63
64 Ok(mappers)
65}
66
67pub fn parse_html(html: &str, fn_name: &str, ignore: &mut Vec<String>) -> proc_macro2::TokenStream {
69 let processed_html = html
70 .replace("\\\"", "\"")
71 .replace("\\n", "\n")
72 .trim_matches('"')
73 .to_string();
74
75 let elements = load_mapper_map(&processed_html)
76 .unwrap_or_else(|_| panic!("Failed to load html: {}", processed_html));
77
78 let (_, element) = elements.into_iter().next()
79 .unwrap_or_else(|| panic!("HTML not found for function: {}", fn_name));
80
81 parse_html_node(vec![element], ignore, fn_name)
82}
83
84fn include_replace(elements: Vec<Element>, sql_map: &mut BTreeMap<String, Element>) -> Vec<Element> {
86 elements.into_iter().map(|mut element| {
87 match element.tag.as_str() {
88 SQL_TAG => {
89 let id = element.attrs.get("id")
90 .expect("[rbatis-codegen] <sql> element must have id!");
91 sql_map.insert(id.clone(), element.clone());
92 }
93 INCLUDE_TAG => {
94 element = handle_include_element(&element, sql_map);
95 }
96 _ => {
97 if let Some(id) = element.attrs.get("id").filter(|id| !id.is_empty()) {
98 sql_map.insert(id.clone(), element.clone());
99 }
100 }
101 }
102
103 if !element.childs.is_empty() {
104 element.childs = include_replace(element.childs, sql_map);
105 }
106
107 element
108 }).collect()
109}
110
111fn handle_include_element(element: &Element, sql_map: &BTreeMap<String, Element>) -> Element {
113 let ref_id = element.attrs.get("refid")
114 .expect("[rbatis-codegen] <include> element must have attr <include refid=\"\">!");
115
116 let url = if ref_id.contains("://") {
117 Url::parse(ref_id).unwrap_or_else(|_| panic!(
118 "[rbatis-codegen] parse <include refid=\"{}\"> fail!", ref_id
119 ))
120 } else {
121 Url::parse(&format!("current://current?refid={}", ref_id)).unwrap_or_else(|_| panic!(
122 "[rbatis-codegen] parse <include refid=\"{}\"> fail!", ref_id
123 ))
124 };
125
126 match url.scheme() {
127 "file" => handle_file_include(&url, ref_id),
128 "current" => handle_current_include(&url, ref_id, sql_map),
129 _ => panic!("Unimplemented scheme <include refid=\"{}\">", ref_id),
130 }
131}
132
133fn handle_file_include(url: &Url, ref_id: &str) -> Element {
135 let mut manifest_dir = std::env::var("CARGO_MANIFEST_DIR")
136 .expect("Failed to read CARGO_MANIFEST_DIR");
137 manifest_dir.push('/');
138
139 let path = url.host_str().unwrap_or_default().to_string() +
140 url.path().trim_end_matches(&['/', '\\'][..]);
141 let mut file_path = PathBuf::from(&path);
142
143 if file_path.is_relative() {
144 file_path = PathBuf::from(format!("{}{}", manifest_dir, path));
145 }
146
147 let ref_id = url.query_pairs()
148 .find(|(k, _)| k == "refid")
149 .map(|(_, v)| v.to_string())
150 .unwrap_or_else(|| {
151 panic!("No ref_id found in URL {}", ref_id);
152 });
153
154 let mut file = File::open(&file_path).unwrap_or_else(|_| panic!(
155 "[rbatis-codegen] can't find file='{}', url='{}'",
156 file_path.to_str().unwrap_or_default(),
157 url
158 ));
159
160 let mut html = String::new();
161 file.read_to_string(&mut html).expect("Failed to read file");
162
163 load_mapper_vec(&html).expect("Failed to parse HTML")
164 .into_iter()
165 .find(|e| e.tag == SQL_TAG && e.attrs.get("id") == Some(&ref_id))
166 .unwrap_or_else(|| panic!(
167 "No ref_id={} found in file={}",
168 ref_id,
169 file_path.to_str().unwrap_or_default()
170 ))
171}
172
173fn handle_current_include(url: &Url, ref_id: &str, sql_map: &BTreeMap<String, Element>) -> Element {
175 let ref_id = url.query_pairs()
176 .find(|(k, _)| k == "refid")
177 .map(|(_, v)| v.to_string())
178 .unwrap_or(ref_id.to_string());
179
180 sql_map.get(&ref_id).unwrap_or_else(|| panic!(
181 "[rbatis-codegen] cannot find element <include refid=\"{}\">!",
182 ref_id
183 )).clone()
184}
185
186fn parse_html_node(
188 elements: Vec<Element>,
189 ignore: &mut Vec<String>,
190 fn_name: &str,
191) -> proc_macro2::TokenStream {
192 let mut methods = quote!();
193 let fn_impl = parse_elements(&elements, &mut methods, ignore, fn_name);
194 quote! { #methods #fn_impl }
195}
196
197fn parse_elements(
199 elements: &[Element],
200 methods: &mut proc_macro2::TokenStream,
201 ignore: &mut Vec<String>,
202 fn_name: &str,
203) -> proc_macro2::TokenStream {
204 let mut body = quote! {};
205
206 for element in elements {
207 match element.tag.as_str() {
208 MAPPER_TAG => {
209 return parse_elements(&element.childs, methods, ignore, fn_name);
210 }
211 SQL_TAG | INCLUDE_TAG => {
212 let code = parse_elements(&element.childs, methods, ignore, fn_name);
213 body = quote! { #body #code };
214 }
215 CONTINUE_TAG => impl_continue(&mut body),
216 BREAK_TAG => impl_break(&mut body),
217 "" => handle_text_element(element, &mut body, ignore),
218 IF_TAG => handle_if_element(element, &mut body, methods, ignore, fn_name),
219 TRIM_TAG => handle_trim_element(element, &mut body, methods, ignore, fn_name),
220 BIND_TAG => handle_bind_element(element, &mut body, ignore),
221 WHERE_TAG => handle_where_element(element, &mut body, methods, ignore, fn_name),
222 CHOOSE_TAG => handle_choose_element(element, &mut body, methods, ignore, fn_name),
223 FOREACH_TAG => handle_foreach_element(element, &mut body, methods, ignore, fn_name),
224 SET_TAG => handle_set_element(element, &mut body, methods, ignore, fn_name),
225 SELECT_TAG | UPDATE_TAG | INSERT_TAG | DELETE_TAG => {
226 handle_crud_element(element, &mut body, methods, ignore, fn_name)
227 }
228 _ => {}
229 }
230 }
231
232 body
233}
234
235fn handle_text_element(
237 element: &Element,
238 body: &mut proc_macro2::TokenStream,
239 ignore: &mut Vec<String>,
240) {
241 let mut string_data = remove_extra(&element.data);
242 let convert_list = find_convert_string(&string_data);
243
244 let mut formats_value = quote! {};
245 let mut replace_num = 0;
246
247 for (k, v) in convert_list {
248 let method_impl = crate::codegen::func::impl_fn(
249 &body.to_string(),
250 "",
251 &format!("\"{}\"", k),
252 false,
253 ignore,
254 );
255
256 if v.starts_with('#') {
257 string_data = string_data.replacen(&v, "?", 1);
258 *body = quote! {
259 #body
260 args.push(rbs::value(#method_impl).unwrap_or_default());
261 };
262 } else {
263 string_data = string_data.replacen(&v, "{}", 1);
264 if !formats_value.to_string().trim().ends_with(',') {
265 formats_value = quote!(#formats_value,);
266 }
267 formats_value = quote!(#formats_value &#method_impl.string());
268 replace_num += 1;
269 }
270 }
271
272 if !string_data.is_empty() {
273 *body = if replace_num == 0 {
274 quote! { #body sql.push_str(#string_data); }
275 } else {
276 quote! { #body sql.push_str(&format!(#string_data #formats_value)); }
277 };
278 }
279}
280
281fn handle_if_element(
283 element: &Element,
284 body: &mut proc_macro2::TokenStream,
285 methods: &mut proc_macro2::TokenStream,
286 ignore: &mut Vec<String>,
287 fn_name: &str,
288) {
289 let test_value = element.attrs.get("test")
290 .unwrap_or_else(|| panic!("{} element must have test field!", element.tag));
291
292 let if_tag_body = if !element.childs.is_empty() {
293 parse_elements(&element.childs, methods, ignore, fn_name)
294 } else {
295 quote! {}
296 };
297
298 impl_condition(test_value, if_tag_body, body, methods, quote! {}, ignore);
299}
300
301fn handle_trim_element(
303 element: &Element,
304 body: &mut proc_macro2::TokenStream,
305 methods: &mut proc_macro2::TokenStream,
306 ignore: &mut Vec<String>,
307 fn_name: &str,
308) {
309 let empty = String::new();
310 let prefix = element.attrs.get("prefix").unwrap_or(&empty);
311 let suffix = element.attrs.get("suffix").unwrap_or(&empty);
312 let prefix_overrides = element.attrs.get("start")
313 .or_else(|| element.attrs.get("prefixOverrides"))
314 .unwrap_or(&empty);
315 let suffix_overrides = element.attrs.get("end")
316 .or_else(|| element.attrs.get("suffixOverrides"))
317 .unwrap_or(&empty);
318
319 impl_trim(
320 prefix,
321 suffix,
322 prefix_overrides,
323 suffix_overrides,
324 element,
325 body,
326 methods,
327 ignore,
328 fn_name,
329 );
330}
331
332fn handle_bind_element(
334 element: &Element,
335 body: &mut proc_macro2::TokenStream,
336 ignore: &mut Vec<String>,
337) {
338 let name = element.attrs.get("name")
339 .expect("<bind> must have name!");
340 let value = element.attrs.get("value")
341 .expect("<bind> element must have value!");
342
343 let method_impl = crate::codegen::func::impl_fn(
344 &body.to_string(),
345 "",
346 &format!("\"{}\"", value),
347 false,
348 ignore,
349 );
350
351 let lit_str = LitStr::new(name, Span::call_site());
352
353 *body = quote! {
354 #body
355 if arg[#lit_str] == rbs::Value::Null {
356 arg.insert(rbs::Value::String(#lit_str.to_string()), rbs::Value::Null);
357 }
358 arg[#lit_str] = rbs::value(#method_impl).unwrap_or_default();
359 };
360}
361
362fn handle_where_element(
364 element: &Element,
365 body: &mut proc_macro2::TokenStream,
366 methods: &mut proc_macro2::TokenStream,
367 ignore: &mut Vec<String>,
368 fn_name: &str,
369) {
370 impl_trim(
371 " where ",
372 " ",
373 " |and |or ",
374 " | and| or",
375 element,
376 body,
377 methods,
378 ignore,
379 fn_name,
380 );
381
382 *body = quote! {
383 #body
384 sql = sql.trim_end_matches(" where ").to_string();
385 };
386}
387
388fn handle_choose_element(
390 element: &Element,
391 body: &mut proc_macro2::TokenStream,
392 methods: &mut proc_macro2::TokenStream,
393 ignore: &mut Vec<String>,
394 fn_name: &str,
395) {
396 let mut inner_body = quote! {};
397
398 for child in &element.childs {
399 match child.tag.as_str() {
400 WHEN_TAG => {
401 let test_value = child.attrs.get("test")
402 .unwrap_or_else(|| panic!("{} element must have test field!", child.tag));
403
404 let if_tag_body = if !child.childs.is_empty() {
405 parse_elements(&child.childs, methods, ignore, fn_name)
406 } else {
407 quote! {}
408 };
409
410 impl_condition(
411 test_value,
412 if_tag_body,
413 &mut inner_body,
414 methods,
415 quote! { return sql; },
416 ignore,
417 );
418 }
419 OTHERWISE_TAG => {
420 let child_body = parse_elements(&child.childs, methods, ignore, fn_name);
421 impl_otherwise(child_body, &mut inner_body);
422 }
423 _ => panic!("choose node's children must be when or otherwise nodes!"),
424 }
425 }
426
427 let capacity = element.child_string_cup() + 1000;
428 *body = quote! {
429 #body
430 sql.push_str(&|| -> String {
431 let mut sql = String::with_capacity(#capacity);
432 #inner_body
433 return sql;
434 }());
435 };
436}
437
438fn handle_foreach_element(
440 element: &Element,
441 body: &mut proc_macro2::TokenStream,
442 methods: &mut proc_macro2::TokenStream,
443 ignore: &mut Vec<String>,
444 fn_name: &str,
445) {
446 let empty = String::new();
447 let def_item = "item".to_string();
448 let def_index = "index".to_string();
449
450 let collection = element.attrs.get("collection").unwrap_or(&empty);
451 let mut item = element.attrs.get("item").unwrap_or(&def_item);
452 let mut index = element.attrs.get("index").unwrap_or(&def_index);
453 let open = element.attrs.get("open").unwrap_or(&empty);
454 let close = element.attrs.get("close").unwrap_or(&empty);
455 let separator = element.attrs.get("separator").unwrap_or(&empty);
456
457 if item.is_empty() || item == "_" {
458 item = &def_item;
459 }
460 if index.is_empty() || index == "_" {
461 index = &def_index;
462 }
463
464 let mut ignores = ignore.clone();
465 ignores.push(index.to_string());
466 ignores.push(item.to_string());
467
468 let impl_body = parse_elements(&element.childs, methods, &mut ignores, fn_name);
469 let method_impl = crate::codegen::func::impl_fn(
470 &body.to_string(),
471 "",
472 &format!("\"{}\"", collection),
473 false,
474 ignore,
475 );
476
477 let open_impl = if !open.is_empty() {
478 quote! { sql.push_str(#open); }
479 } else {
480 quote! {}
481 };
482
483 let close_impl = if !close.is_empty() {
484 quote! { sql.push_str(#close); }
485 } else {
486 quote! {}
487 };
488
489 let item_ident = Ident::new(item, Span::call_site());
490 let index_ident = Ident::new(index, Span::call_site());
491
492 let (split_code, split_code_trim) = if !separator.is_empty() {
493 (
494 quote! { sql.push_str(#separator); },
495 quote! { sql = sql.trim_end_matches(#separator).to_string(); }
496 )
497 } else {
498 (quote! {}, quote! {})
499 };
500
501 *body = quote! {
502 #body
503 #open_impl
504 for (ref #index_ident, #item_ident) in #method_impl {
505 #impl_body
506 #split_code
507 }
508 #split_code_trim
509 #close_impl
510 };
511}
512
513fn handle_set_element(
515 element: &Element,
516 body: &mut proc_macro2::TokenStream,
517 methods: &mut proc_macro2::TokenStream,
518 ignore: &mut Vec<String>,
519 fn_name: &str,
520) {
521 if let Some(collection) = element.attrs.get("collection") {
522 let skip_null = element.attrs.get("skip_null");
523 let skips = element.attrs.get("skips").unwrap_or(&"id".to_string()).to_string();
524 let elements = make_sets(collection, skip_null, &skips);
525 let code = parse_elements(&elements, methods, ignore, fn_name);
526 *body = quote! { #body #code };
527 } else {
528 impl_trim(
529 " set ", " ", " |,", " |,", element, body, methods, ignore, fn_name,
530 );
531 }
532}
533
534fn handle_crud_element(
536 element: &Element,
537 body: &mut proc_macro2::TokenStream,
538 methods: &mut proc_macro2::TokenStream,
539 ignore: &mut Vec<String>,
540 fn_name: &str,
541) {
542 let method_name = Ident::new(fn_name, Span::call_site());
543 let child_body = parse_elements(&element.childs, methods, ignore, fn_name);
544 let capacity = element.child_string_cup() + 1000;
545 let push_count = child_body.to_string().matches("args.push").count();
546
547 let function = quote! {
548 pub fn #method_name(mut arg: rbs::Value, _tag: char) -> (String, Vec<rbs::Value>) {
549 use rbatis_codegen::ops::*;
550 let mut sql = String::with_capacity(#capacity);
551 let mut args = Vec::with_capacity(#push_count);
552 #child_body
553 (sql, args)
554 }
555 };
556
557 *body = quote! { #body #function };
558}
559
560fn make_sets(collection: &str, skip_null: Option<&String>, skips: &str) -> Vec<Element> {
562 let is_skip_null = skip_null.map_or(true, |v| v != "false");
563 let skip_strs: Vec<&str> = skips.split(',').collect();
564
565 let skip_elements = skip_strs.iter().map(|x| Element {
566 tag: IF_TAG.to_string(),
567 data: String::new(),
568 attrs: {
569 let mut attr = HashMap::new();
570 attr.insert("test".to_string(), format!("k == '{}'", x));
571 attr
572 },
573 childs: vec![Element {
574 tag: CONTINUE_TAG.to_string(),
575 data: String::new(),
576 attrs: HashMap::new(),
577 childs: vec![],
578 }],
579 }).collect::<Vec<_>>();
580
581 let mut for_each_body = skip_elements;
582
583 if is_skip_null {
584 for_each_body.push(Element {
585 tag: IF_TAG.to_string(),
586 data: String::new(),
587 attrs: {
588 let mut attr = HashMap::new();
589 attr.insert("test".to_string(), "v == null".to_string());
590 attr
591 },
592 childs: vec![Element {
593 tag: CONTINUE_TAG.to_string(),
594 data: String::new(),
595 attrs: HashMap::new(),
596 childs: vec![],
597 }],
598 });
599 }
600
601 for_each_body.push(Element {
602 tag: "".to_string(),
603 data: "${k}=#{v},".to_string(),
604 attrs: HashMap::new(),
605 childs: vec![],
606 });
607
608 vec![Element {
609 tag: TRIM_TAG.to_string(),
610 data: String::new(),
611 attrs: {
612 let mut attr = HashMap::new();
613 attr.insert("prefix".to_string(), " set ".to_string());
614 attr.insert("suffix".to_string(), " ".to_string());
615 attr.insert("start".to_string(), " ".to_string());
616 attr.insert("end".to_string(), " ".to_string());
617 attr
618 },
619 childs: vec![Element {
620 tag: TRIM_TAG.to_string(),
621 data: String::new(),
622 attrs: {
623 let mut attr = HashMap::new();
624 attr.insert("prefix".to_string(), "".to_string());
625 attr.insert("suffix".to_string(), "".to_string());
626 attr.insert("start".to_string(), ",".to_string());
627 attr.insert("end".to_string(), ",".to_string());
628 attr
629 },
630 childs: vec![Element {
631 tag: FOREACH_TAG.to_string(),
632 data: String::new(),
633 attrs: {
634 let mut attr = HashMap::new();
635 attr.insert("collection".to_string(), collection.to_string());
636 attr.insert("index".to_string(), "k".to_string());
637 attr.insert("item".to_string(), "v".to_string());
638 attr
639 },
640 childs: for_each_body,
641 }],
642 }],
643 }]
644}
645
/// Normalises the raw text of a text node before it becomes SQL.
///
/// Steps, in order:
/// 1. Trim the whole text and remove escaped carriage returns (the two
///    characters `\r`).
/// 2. Trim each line, then strip backticks from both ends of the line. Spaces
///    inside backticks survive, so backticks can be used to keep leading or
///    trailing whitespace.
/// 3. Rejoin the lines with `\n`, strip any remaining outer backticks, and
///    delete every `` `` `` pair.
fn remove_extra(text: &str) -> String {
    let cleaned = text.trim().replace("\\r", "");

    let joined = cleaned
        .split('\n')
        .map(|line| line.trim().trim_start_matches('`').trim_end_matches('`'))
        .collect::<Vec<&str>>()
        .join("\n");

    joined.trim_matches('`').replace("``", "")
}
663
664fn impl_continue(body: &mut proc_macro2::TokenStream) {
666 *body = quote! { #body continue; };
667}
668
669fn impl_break(body: &mut proc_macro2::TokenStream) {
671 *body = quote! { #body break; };
672}
673
674fn impl_condition(
676 test_value: &str,
677 condition_body: proc_macro2::TokenStream,
678 body: &mut proc_macro2::TokenStream,
679 _methods: &mut proc_macro2::TokenStream,
680 appends: proc_macro2::TokenStream,
681 ignore: &mut Vec<String>,
682) {
683 let method_impl = crate::codegen::func::impl_fn(
684 &body.to_string(),
685 "",
686 &format!("\"{}\"", test_value),
687 false,
688 ignore,
689 );
690
691 *body = quote! {
692 #body
693 if #method_impl.to_owned().into() {
694 #condition_body
695 #appends
696 }
697 };
698}
699
700fn impl_otherwise(
702 child_body: proc_macro2::TokenStream,
703 body: &mut proc_macro2::TokenStream,
704) {
705 *body = quote! { #body #child_body };
706}
707
708fn impl_trim(
710 prefix: &str,
711 suffix: &str,
712 start: &str,
713 end: &str,
714 element: &Element,
715 body: &mut proc_macro2::TokenStream,
716 methods: &mut proc_macro2::TokenStream,
717 ignore: &mut Vec<String>,
718 fn_name: &str,
719) {
720 let trim_body = parse_elements(&element.childs, methods, ignore, fn_name);
721 let prefixes: Vec<&str> = start.split('|').collect();
722 let suffixes: Vec<&str> = end.split('|').collect();
723 let has_trim = !prefixes.is_empty() && !suffixes.is_empty();
724 let capacity = element.child_string_cup();
725
726 let mut trims = quote! {
727 let mut sql = String::with_capacity(#capacity);
728 #trim_body
729 sql = sql
730 };
731
732 for prefix in prefixes {
733 trims = quote! { #trims .trim_start_matches(#prefix) };
734 }
735
736 for suffix in suffixes {
737 trims = quote! { #trims .trim_end_matches(#suffix) };
738 }
739
740 if !prefix.is_empty() {
741 *body = quote! { #body sql.push_str(#prefix); };
742 }
743
744 if has_trim {
745 *body = quote! { #body sql.push_str(&{#trims.to_string(); sql}); };
746 }
747
748 if !suffix.is_empty() {
749 *body = quote! { #body sql.push_str(#suffix); };
750 }
751}
752
753pub fn impl_fn_html(m: &ItemFn, args: &ParseArgs) -> TokenStream {
755 let fn_name = m.sig.ident.to_string();
756
757 if args.sqls.is_empty() {
758 panic!(
759 "[rbatis-codegen] #[html_sql()] must have html_data, for example: {}",
760 stringify!(#[html_sql(r#"<select id="select_by_condition">`select * from biz_activity</select>"#)])
761 );
762 }
763
764 let html_data = args.sqls[0].to_token_stream().to_string();
765 parse_html(&html_data, &fn_name, &mut vec![]).into()
766}