py_sql/
py_sql.rs

1use dashmap::DashMap;
2use std::collections::HashMap;
3
4use rexpr::runtime::RExprRuntime;
5
6use serde_json::Value;
7
8use crate::ast::RbatisAST;
9use crate::error::Error;
10use crate::node::bind_node::BindNode;
11use crate::node::choose_node::ChooseNode;
12use crate::node::foreach_node::ForEachNode;
13use crate::node::if_node::IfNode;
14use crate::node::node::do_child_nodes;
15use crate::node::node_type::NodeType;
16use crate::node::otherwise_node::OtherwiseNode;
17use crate::node::print_node::PrintNode;
18use crate::node::proxy_node::NodeFactory;
19use crate::node::set_node::SetNode;
20use crate::node::string_node::StringNode;
21use crate::node::trim_node::TrimNode;
22use crate::node::when_node::WhenNode;
23use crate::node::where_node::WhereNode;
24
25/// Py lang,make sure Send+Sync
26#[derive(Debug)]
27pub struct PyRuntime {
28    pub cache: DashMap<String, Vec<NodeType>>,
29    pub generate: Vec<Box<dyn NodeFactory>>,
30}
31
32impl PyRuntime {
33    pub fn new(generate: Vec<Box<dyn NodeFactory>>) -> Self {
34        Self {
35            cache: Default::default(),
36            generate: generate,
37        }
38    }
39    ///eval with cache
40    pub fn eval(
41        &self,
42        driver_type: &dyn crate::StringConvert,
43        py_sql: &str,
44        env: &mut Value,
45        engine: &RExprRuntime,
46    ) -> Result<(String, Vec<serde_json::Value>), Error> {
47        if !env.is_object() {
48            return Result::Err(Error::from(
49                "[rbatis] py_sql Requires that the parameter be an json object!",
50            ));
51        }
52        let mut sql = String::new();
53        let mut arg_array = vec![];
54        let cache_value = self.cache.get(py_sql);
55        match cache_value {
56            Some(v) => {
57                do_child_nodes(driver_type, &v, env, engine, &mut arg_array, &mut sql)?;
58            }
59            _ => {
60                let nodes = Self::parse(engine, py_sql, &self.generate)?;
61                do_child_nodes(driver_type, &nodes, env, engine, &mut arg_array, &mut sql)?;
62                self.cache.insert(py_sql.to_string(), nodes);
63            }
64        }
65        sql = sql.trim().to_string();
66        return Ok((sql, arg_array));
67    }
68
69    pub fn add_gen<T>(&mut self, arg: T)
70    where
71        T: NodeFactory + 'static,
72    {
73        self.generate.push(Box::new(arg));
74    }
75
76    /// parser py string data
77    pub fn parse(
78        runtime: &RExprRuntime,
79        arg: &str,
80        generates: &Vec<Box<dyn NodeFactory>>,
81    ) -> Result<Vec<NodeType>, crate::error::Error> {
82        let line_space_map = PyRuntime::create_line_space_map(arg);
83        let mut main_node = vec![];
84        let ls = arg.lines();
85        let mut space = -1;
86        let mut line = -1;
87        let mut skip = -1;
88        for x in ls {
89            line += 1;
90            if x.is_empty() || (skip != -1 && line <= skip) {
91                continue;
92            }
93            let count_index = *line_space_map.get(&line).unwrap();
94            if space == -1 {
95                space = count_index;
96            }
97            let (child_str, do_skip) =
98                PyRuntime::find_child_str(line, count_index, arg, &line_space_map);
99            if do_skip != -1 && do_skip >= skip {
100                skip = do_skip;
101            }
102            let parserd;
103            if !child_str.is_empty() {
104                parserd = PyRuntime::parse(runtime, child_str.as_str(), generates)?;
105            } else {
106                parserd = vec![];
107            }
108            PyRuntime::parse_node(
109                runtime,
110                generates,
111                &mut main_node,
112                x,
113                *line_space_map.get(&line).unwrap() as usize,
114                parserd,
115            )?;
116        }
117        return Ok(main_node);
118    }
119
120    fn parse_trim_node(
121        runtime: &RExprRuntime,
122        factorys: &Vec<Box<dyn NodeFactory>>,
123        trim_express: &str,
124        source_str: &str,
125        childs: Vec<NodeType>,
126    ) -> Result<NodeType, crate::error::Error> {
127        if trim_express.starts_with(IfNode::name()) {
128            return Ok(NodeType::NIf(IfNode::from(runtime, trim_express, childs)?));
129        } else if trim_express.starts_with(ForEachNode::name()) {
130            return Ok(NodeType::NForEach(ForEachNode::from(
131                runtime,
132                source_str,
133                &trim_express,
134                childs,
135            )?));
136        } else if trim_express.starts_with(TrimNode::name()) {
137            return Ok(NodeType::NTrim(TrimNode::from(
138                source_str,
139                trim_express,
140                childs,
141            )?));
142        } else if trim_express.starts_with(ChooseNode::name()) {
143            return Ok(NodeType::NChoose(ChooseNode::from(
144                source_str,
145                trim_express,
146                childs,
147            )?));
148        } else if trim_express.starts_with(OtherwiseNode::def_name())
149            || trim_express.starts_with(OtherwiseNode::name())
150        {
151            return Ok(NodeType::NOtherwise(OtherwiseNode::from(
152                source_str,
153                trim_express,
154                childs,
155            )?));
156        } else if trim_express.starts_with(WhenNode::name()) {
157            return Ok(NodeType::NWhen(WhenNode::from(
158                runtime,
159                source_str,
160                trim_express,
161                childs,
162            )?));
163        } else if trim_express.starts_with(BindNode::def_name())
164            || trim_express.starts_with(BindNode::name())
165        {
166            return Ok(NodeType::NBind(BindNode::from(
167                runtime,
168                source_str,
169                trim_express,
170                childs,
171            )?));
172        } else if trim_express.starts_with(SetNode::name()) {
173            return Ok(NodeType::NSet(SetNode::from(
174                source_str,
175                trim_express,
176                childs,
177            )?));
178        } else if trim_express.starts_with(WhereNode::name()) {
179            return Ok(NodeType::NWhere(WhereNode::from(
180                source_str,
181                trim_express,
182                childs,
183            )?));
184        } else if trim_express.starts_with(PrintNode::name()) {
185            return Ok(NodeType::NPrint(PrintNode::from(
186                runtime,
187                source_str,
188                trim_express,
189                childs,
190            )?));
191        } else {
192            for f in factorys {
193                let gen = f.try_new(trim_express, childs.clone())?;
194                if gen.is_some() {
195                    return Ok(NodeType::NCustom(gen.unwrap()));
196                }
197            }
198            // unkonw tag
199            return Err(crate::error::Error::from(
200                "[rbatis] unknow tag: ".to_string() + source_str,
201            ));
202        }
203    }
204
205    fn parse_node(
206        runtime: &RExprRuntime,
207        generates: &Vec<Box<dyn NodeFactory>>,
208        main_node: &mut Vec<NodeType>,
209        x: &str,
210        space: usize,
211        mut childs: Vec<NodeType>,
212    ) -> Result<(), crate::error::Error> {
213        let mut trim_x = x.trim();
214        if trim_x.starts_with("//") {
215            return Ok(());
216        }
217        if trim_x.ends_with(":") {
218            trim_x = trim_x[0..trim_x.len() - 1].trim();
219            if trim_x.contains(": ") {
220                let vecs: Vec<&str> = trim_x.split(": ").collect();
221                if vecs.len() > 1 {
222                    let len = vecs.len();
223                    for index in 0..len {
224                        let index = len - 1 - index;
225                        let item = vecs[index];
226                        childs = vec![Self::parse_trim_node(runtime, generates, item, x, childs)?];
227                        if index == 0 {
228                            for x in &childs {
229                                main_node.push(x.clone());
230                            }
231                            return Ok(());
232                        }
233                    }
234                }
235            }
236            let node = Self::parse_trim_node(runtime, generates, trim_x, x, childs)?;
237            main_node.push(node);
238            return Ok(());
239        } else {
240            //string,replace space to only one
241            let mut data = x.to_owned();
242            if space <= 1 {
243                data = x.to_string();
244            } else {
245                data = x[(space - 1)..].to_string();
246            }
247            main_node.push(NodeType::NString(StringNode::new(runtime, &data)?));
248            for x in childs {
249                main_node.push(x);
250            }
251            return Ok(());
252        }
253    }
254
255    fn count_space(arg: &str) -> i32 {
256        let cs = arg.chars();
257        let mut index = 0;
258        for x in cs {
259            match x {
260                ' ' => {
261                    index += 1;
262                }
263                _ => {
264                    break;
265                }
266            }
267        }
268        return index;
269    }
270
271    ///find_child_str
272    fn find_child_str(
273        line_index: i32,
274        space_index: i32,
275        arg: &str,
276        m: &HashMap<i32, i32>,
277    ) -> (String, i32) {
278        let mut result = String::new();
279        let mut skip_line = -1;
280        let mut line = -1;
281        let lines = arg.lines();
282        for x in lines {
283            line += 1;
284            if line > line_index {
285                let cached_space = *m.get(&line).unwrap();
286                if cached_space > space_index {
287                    result = result + x + "\n";
288                    skip_line = line;
289                } else {
290                    break;
291                }
292            }
293        }
294        return (result, skip_line);
295    }
296
297    ///Map<line,space>
298    fn create_line_space_map(arg: &str) -> HashMap<i32, i32> {
299        let mut m = HashMap::new();
300        let lines = arg.lines();
301        let mut line = -1;
302        for x in lines {
303            line += 1;
304            let space = PyRuntime::count_space(x);
305            //dothing
306            m.insert(line, space);
307        }
308        return m;
309    }
310}