Skip to main content

makepad_gen_plugin/visitor/fn/
replacer.rs

1use gen_analyzer::Binds;
2use gen_plugin::MacroContext;
3use gen_utils::{
4    common::string::FixedString,
5    error::{CompilerError, Error},
6};
7use quote::ToTokens;
8use ra_ap_syntax::{
9    ast::{self, HasArgList, MethodCallExpr},
10    AstNode, Edition, SourceFile, TextRange,
11};
12use std::collections::HashMap;
13use syn::{parse_str, ImplItemFn};
14
15use crate::compiler::{Context, WidgetPoll};
16
17/// 访问双向绑定访问器结构体
18#[allow(unused)]
19#[derive(Debug)]
20struct BindingReplacer {
21    replacements: HashMap<TextRange, String>,
22    fields: Vec<String>,
23}
24
25impl BindingReplacer {
26    fn new(fields: Vec<String>) -> Self {
27        Self {
28            replacements: HashMap::new(),
29            fields,
30        }
31    }
32
33    fn add_replacement(&mut self, range: TextRange, new_text: String) {
34        self.replacements.insert(range, new_text);
35    }
36
37    fn apply_replacements(&self, input: &str) -> String {
38        let mut result = input.to_string();
39        let mut offset = 0_i32;
40        // 按照范围排序,确保替换的正确性
41        let mut ranges: Vec<_> = self.replacements.iter().collect();
42        ranges.sort_by_key(|(range, _)| range.start());
43
44        for (range, new_text) in ranges {
45            let range_start: u32 = range.start().into();
46            let range_end: u32 = range.end().into();
47            let start = (range_start as i32 + offset) as usize;
48            let end = (range_end as i32 + offset) as usize;
49            result.replace_range(start..end, new_text);
50
51            offset += new_text.len() as i32 - (range_end - range_start) as i32;
52        }
53
54        result
55    }
56}
57
58/// ## 访问fuction并进行替换
59/// 以下内容需要进行处理:
60/// 1. c_ref!宏 (转为self.#widget(id!(#id)))
61/// 2. active!宏 (转为self.active_event(cx, |cx, uid, path| {cx.widget_action(uid, path, #param);}))
62/// 3. get_和set_方法 (转为self.#field_name()和self.#field_name(#param))
63/// 4. signal_fns中的方法 (在参数列表最后添加cx)
64/// 5. 当方法中含有set_方法时, 最终需要增加一行重新绘制的代码 (self.redraw(cx);) 来触发重绘
65///
66/// is_special: 标记当前方法是否是特殊的访问器,例如生命周期就无需进行redraw
67pub fn visit_fns(
68    input: &mut ImplItemFn,
69    fields: &Vec<String>,
70    computeds: &Vec<String>,
71    widgets: &WidgetPoll,
72    prop_binds: Option<&Binds>,
73    signal_fns: &Vec<String>,
74    ctx: &Context,
75    is_special: bool,
76) -> Result<(), Error> {
77    let processor = ctx.dyn_processor.as_ref();
78    let router = ctx.router.as_ref();
79    let input_str = input.to_token_stream().to_string();
80    let source_file = SourceFile::parse(&input_str, Edition::Edition2021);
81    let syntax = source_file.tree();
82    // 记录需要检查并调用get|set的组件,当使用者调用c_ref!时需要将组件id记录到这里,然后在get|set访问时进行替换
83    let mut addition_widgets = HashMap::new();
84    // 记录是否需要增加重绘
85    let mut redraw = false;
86    // 创建替换器
87    let mut replacer = BindingReplacer::new(fields.clone());
88    // 记录方法中访问到的路由组件,路由组件需要在其nav_to和nav_back方法中添加cx参数
89    let mut router_widget = None;
90    // [visit_two_way_binding] -------------------------------------------------------------------------------
91    // 遍历语法树
92    for node in syntax.syntax().descendants() {
93        // [c_ref!, active!] ---------------------------------------------------------------------------------------------------
94        // c_ref宏一定是let语句中的MacroCall
95        if let Some(let_stmt) = ast::LetStmt::cast(node.clone()) {
96            for node in let_stmt.syntax().descendants() {
97                if let Some(macro_call) = ast::MacroCall::cast(node) {
98                    if let Some(path) = macro_call.path() {
99                        let ident = path.syntax().text().to_string();
100                        if ident == "c_ref" {
101                            // [replace c_ref!() => self.#widget(id!(#id))] ------------------------------------------------------------
102                            if let Some(tt) = macro_call.token_tree() {
103                                // remove `{}` or `()`
104                                let id = inner_tt(tt);
105                                // 记录id
106                                let let_ident = let_stmt
107                                    .pat()
108                                    .and_then(|pat| {
109                                        ast::IdentPat::cast(pat.syntax().clone()).map(|ident_pat| {
110                                            ident_pat
111                                                .syntax()
112                                                .last_token()
113                                                .unwrap()
114                                                .text()
115                                                .to_string()
116                                        })
117                                    })
118                                    .ok_or(CompilerError::runtime(
119                                        "Makepad Compiler - Script",
120                                        "c_ref! macro should has let statement",
121                                    ))?;
122                                // 这里需要将id记录到addition_widgets中
123                                addition_widgets.insert(let_ident.to_string(), id.to_string());
124
125                                let widget = widgets.get(&id).map_or_else(
126                                    || {
127                                        Err(Error::from(CompilerError::runtime(
128                                            "Makepad Compiler - Script",
129                                            "can not find id in template, please check!",
130                                        )))
131                                    },
132                                    |widget| Ok(widget.snake_name()),
133                                )?;
134
135                                let new_expr = format!("self.{}(id!({}))", &widget, id);
136                                let full_range = macro_call.syntax().text_range();
137                                replacer.add_replacement(full_range, new_expr);
138
139                                // 尝试获取路由组件
140                                if let Some(router) = router {
141                                    // 这里需要根据组件名字的缩写来判断是否是路由组件
142                                    if widget == router.name.camel_to_snake() {
143                                        router_widget.replace(UsedRouter {
144                                            id: id.to_string(),
145                                            name: widget.to_string(),
146                                            ident: let_ident.to_string(),
147                                        });
148                                    }
149                                }
150                            } else {
151                                return Err(CompilerError::runtime(
152                                    "Makepad Compiler - Script",
153                                    "c_ref! macro should has widget id as token",
154                                )
155                                .into());
156                            }
157                        }
158                    }
159                }
160            }
161        }
162
163        if let Some(macro_call) = ast::MacroCall::cast(node.clone()) {
164            if let Some(path) = macro_call.path() {
165                let ident = path.syntax().text().to_string();
166                if ident == "active" {
167                    // [replace active!() => self.active_event(cx, |cx, uid, path| {cx.widget_action(uid, path, #param);})] -------
168                    if let Some(tt) = macro_call.token_tree() {
169                        let param = inner_tt(tt);
170                        let new_expr = format!(
171                            "self.active_event(cx, |cx, uid, path| {{cx.widget_action(uid, path, {});}})",
172                            param
173                        );
174                        let full_range = macro_call.syntax().text_range();
175                        replacer.add_replacement(full_range, new_expr);
176                    } else {
177                        return Err(CompilerError::runtime(
178                            "Makepad Compiler - Script",
179                            "active! macro should has param as token",
180                        )
181                        .into());
182                    }
183                } else if ident == "nav_to" {
184                    if let Some(tt) = macro_call.token_tree() {
185                        let tt = inner_tt(tt);
186                        if !tt.is_empty() {
187                            // add cx, self.widget_uid(), &mut Scope::empty() as param
188                            let new_expr = format!(
189                                "nav_to!({}, cx, self.widget_uid(), &mut Scope::empty());",
190                                tt
191                            );
192                            let full_range = macro_call.syntax().text_range();
193                            replacer.add_replacement(full_range, new_expr);
194                        } else {
195                            return Err(CompilerError::runtime(
196                                "Makepad Compiler - Script",
197                                "nav_to! macro should has param, param is the id of the page you registered in router toml",
198                            )
199                            .into());
200                        }
201                    }
202                } else if ident == "nav_back" {
203                    if let Some(tt) = macro_call.token_tree() {
204                        let tt = inner_tt(tt);
205                        // nav_back should have no tt, so tt should be empty
206                        if tt.is_empty() {
207                            // add cx, self.widget_uid(), &mut Scope::empty() as param
208                            let new_expr =
209                                format!("nav_back!(cx, self.widget_uid(), &mut Scope::empty());");
210                            let full_range = macro_call.syntax().text_range();
211                            replacer.add_replacement(full_range, new_expr);
212                        } else {
213                            return Err(CompilerError::runtime(
214                                "Makepad Compiler - Script",
215                                "nav_back! macro should has no param",
216                            )
217                            .into());
218                        }
219                    }
220                } else {
221                    if let Some(processor) = processor {
222                        let tokens = if let Some(tt) = macro_call.token_tree() {
223                            inner_tt(tt)
224                        } else {
225                            String::new()
226                        };
227                        let mut mac_context = MacroContext { ident, tokens };
228
229                        let is_replace = unsafe {
230                            processor.process_macro(&mut mac_context).map_err(|e| {
231                                CompilerError::runtime("Makepad Compiler - Script", &e.to_string())
232                            })?
233                        };
234
235                        if is_replace {
236                            let new_expr =
237                                format!("{}!({})", mac_context.ident, mac_context.tokens);
238                            let full_range = macro_call.syntax().text_range();
239                            replacer.add_replacement(full_range, new_expr);
240                        }
241                    }
242                }
243            }
244        }
245
246        // get and set method call
247        if let Some(method_call) = ast::MethodCallExpr::cast(node.clone()) {
248            if let Some(receiver) = method_call.receiver() {
249                let receiver_text = receiver.syntax().text().to_string();
250
251                let from_widget = method_call
252                    .syntax()
253                    .first_child()
254                    .and_then(|first| addition_widgets.get_key_value(&first.text().to_string()));
255
256                // 检查是否是目标属性访问
257                if receiver_text == "self" || from_widget.is_some() {
258                    // dbg!(method_call.syntax().text());
259                    if let Some(name_ref) = method_call.name_ref() {
260                        let method_name = name_ref.syntax().text().to_string();
261                        if method_name.starts_with("get_") || method_name.starts_with("set_") {
262                            let field_name = method_name
263                                .trim_start_matches("get_")
264                                .trim_start_matches("set_")
265                                .to_string();
266                            // dbg!(&fields, &field_name);
267                            let is_computed = computeds.contains(&field_name);
268                            // 检查字段是否在目标列表中
269                            if fields.contains(&field_name) || from_widget.is_some() || is_computed
270                            {
271                                let prefix = if let Some((w, _)) = from_widget {
272                                    w.to_string()
273                                } else {
274                                    "self".to_string()
275                                };
276
277                                let is_setter = method_name.starts_with("set_");
278                                // 获取完整的方法调用范围
279                                let full_range = method_call.syntax().text_range();
280
281                                // 构建新的调用表达式
282                                let new_expr = if is_setter {
283                                    let mut redraw_cref = None;
284                                    // computed不需要重绘
285                                    // if !is_computed {
286                                    //     redraw = true;
287                                    // }
288                                    redraw = true && !is_special;
289
290                                    let mut new_call = String::new();
291                                    // 如果from_widget则需要反向绑定到父组件中完成双向绑定
292                                    if let Some((widget_ident, widget_id)) = from_widget {
293                                        // 获取function中的参数
294                                        let param = method_call.arg_list().map_or_else(
295                                            || Err(Error::from("set prop need a param!")),
296                                            |arg_list| Ok(arg_list.syntax().text().to_string()),
297                                        )?;
298                                        // 通过field_name获取父组件中绑定的字段名
299                                        // 没有找到的话可能是因为并没有采取双向绑定的方式,而是c_ref的直接内部访问,这里就不需要处理
300                                        if let Some(prop_binds) = prop_binds {
301                                            let _ = prop_binds
302                                                .iter()
303                                                .find(|(_, v)| {
304                                                    v.iter().any(|widget| {
305                                                        &widget.id == widget_id
306                                                            && widget.prop.as_str() == field_name
307                                                    })
308                                                })
309                                                .map(|(bind_field, _)| {
310                                                    new_call.push_str(
311                                                        format!(
312                                                            "self.{} = {}.clone();",
313                                                            bind_field,
314                                                            remove_holder(&param)
315                                                        )
316                                                        .as_str(),
317                                                    );
318                                                });
319                                        }
320                                        // c_ref!调用后使用了set_方法则需要对这个组件进行重绘
321                                        redraw_cref.replace(widget_ident);
322                                    }
323
324                                    // 对于setter,需要添加cx参数
325                                    new_call.push_str(&format!("{}.", prefix));
326                                    new_call.push_str(&method_name);
327
328                                    // 检查是否已经有cx参数
329                                    if let Some(arg_list) = method_call.arg_list() {
330                                        let args = arg_list.syntax().text().to_string();
331                                        if !args.contains("cx") {
332                                            // 在参数列表开始位置插入cx
333                                            let mut args = args.to_string();
334                                            if args == "()" {
335                                                args = "(cx)".to_string();
336                                            } else {
337                                                args.insert_str(1, "cx, ");
338                                            }
339                                            new_call.push_str(&args);
340                                        } else {
341                                            new_call.push_str(&args);
342                                        }
343                                    }
344
345                                    if let Some(widget) = redraw_cref {
346                                        new_call.push_str(&format!("; {}.redraw(cx);", widget));
347                                    }
348
349                                    new_call
350                                } else {
351                                    // 对于getter,直接替换接收者
352                                    format!("{}.{}()", prefix, method_name)
353                                };
354
355                                replacer.add_replacement(full_range, new_expr);
356                            }
357                        } else {
358                            // 检查是否在signal_fns中
359                            if signal_fns.contains(&method_name) {
360                                // 这里只需要为方法调用的参数中最后一个参数添加cx即可
361                                if let Some(arg_list) = method_call.arg_list() {
362                                    let args = arg_list.syntax().text().to_string();
363                                    if !args.contains("cx") {
364                                        // 在参数列表最后添加cx
365                                        let mut args = args.to_string();
366                                        if args == "()" {
367                                            args = "(cx)".to_string();
368                                        } else {
369                                            args.insert(args.len() - 1, ',');
370                                            args.push_str("cx");
371                                        }
372                                        let full_range = method_call.syntax().text_range();
373                                        let new_expr =
374                                            format!("{}.{}{}", receiver_text, method_name, args);
375                                        replacer.add_replacement(full_range, new_expr);
376                                    }
377                                }
378                            } else {
379                                if let Some(router_widget) = router_widget.as_ref() {
380                                    if router_widget.ident == receiver_text {
381                                        // 这里说明当前方法调用了路由组件的nav_to或nav_back方法,需要添加cx参数
382                                        match RouterCalled::from(method_name.as_str()) {
383                                            RouterCalled::NavTo => {
384                                                let args =
385                                                    handle_router_args(&method_call, 1, "nav_to")?;
386                                                let full_range = method_call.syntax().text_range();
387                                                let new_expr = format!(
388                                                    "{}.{}({});",
389                                                    receiver_text,
390                                                    method_name,
391                                                    args.join(",")
392                                                );
393                                                replacer.add_replacement(full_range, new_expr);
394                                            }
395                                            RouterCalled::NavBack => {
396                                                let args = handle_router_args(
397                                                    &method_call,
398                                                    0,
399                                                    "nav_back",
400                                                )?;
401                                                let full_range = method_call.syntax().text_range();
402                                                let new_expr = format!(
403                                                    "{}.{}({});",
404                                                    receiver_text,
405                                                    method_name,
406                                                    args.join(",")
407                                                );
408                                                replacer.add_replacement(full_range, new_expr);
409                                            }
410                                            RouterCalled::Unknown => {}
411                                        }
412
413                                        // let new_expr =format!("{}.{}")
414                                    }
415                                }
416                            }
417                        }
418                    }
419                }
420            }
421        }
422    }
423
424    // 应用所有替换
425    let modified_code = replacer.apply_replacements(&input_str);
426
427    // 解析回ImplItemFn
428    match parse_str::<ImplItemFn>(&modified_code) {
429        Ok(mut new_fn) => {
430            // 如果有需要重绘的情况,需要在最后添加self.redraw(cx);
431            if redraw {
432                new_fn
433                    .block
434                    .stmts
435                    .push(parse_str("self.redraw(cx);").unwrap());
436            }
437
438            *input = new_fn;
439            Ok(())
440        }
441        Err(e) => Err(Error::from(format!("Failed to parse modified code: {}", e))),
442    }
443    // [visit c_ref!]
444}
445
446fn inner_tt(tt: ast::TokenTree) -> String {
447    let param = tt.syntax().text().to_string();
448    remove_holder(&param).to_string()
449}
450
/// Removes one layer of enclosing `()`, `{}` or `[]` from `input`.
///
/// The input is returned unchanged unless it starts with one of the opening delimiters and
/// ends with the corresponding closing one. Only the first and last characters are checked,
/// so `"(a)(b)"` becomes `"a)(b"`.
fn remove_holder(input: &str) -> &str {
    const DELIMITERS: [(char, char); 3] = [('(', ')'), ('{', '}'), ('[', ']')];
    DELIMITERS
        .iter()
        .find_map(|&(open, close)| input.strip_prefix(open)?.strip_suffix(close))
        .unwrap_or(input)
}
462
463fn handle_router_args(
464    method_call: &MethodCallExpr,
465    arg_num: usize,
466    method_name: &str,
467) -> Result<Vec<String>, Error> {
468    method_call.arg_list().map_or_else(
469        || {
470            Err(Error::from(format!(
471                "can not find args in {}()",
472                method_name
473            )))
474        },
475        |arg_list| {
476            // 这里还需要继续判断,args的数量
477            if arg_list.args().count() == arg_num {
478                // 在参数列表末尾添加cx
479                let mut args = arg_list.args().fold(Vec::new(), |mut acc, arg| {
480                    acc.push(arg.syntax().text().to_string());
481                    acc
482                });
483                args.push("cx".to_string());
484                Ok(args)
485            } else {
486                Err(Error::from(format!(
487                    "{}(), should has only {} arg",
488                    method_name, arg_num
489                )))
490            }
491        },
492    )
493}
494
/// A router widget obtained through `c_ref!` inside the visited function.
///
/// Recorded so that `nav_to` / `nav_back` calls on that binding can be rewritten.
// Only `ident` is read by the visitor; `id` and `name` are currently never read, so the
// `dead_code` lint is allowed explicitly instead of the whole `unused` lint group.
#[allow(dead_code)]
#[derive(Debug, Clone, PartialEq, Eq)]
struct UsedRouter {
    /// Widget id as written in `c_ref!(id)`.
    pub id: String,
    /// Snake-case name of the router component.
    pub name: String,
    /// Name of the `let` binding that holds the widget reference.
    pub ident: String,
}
504
/// Router navigation method recognised on a `c_ref!` router binding.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum RouterCalled {
    /// `nav_to(page_id)`
    NavTo,
    /// `nav_back()`
    NavBack,
    /// Any other method name; such calls are not rewritten.
    Unknown,
}

impl From<&str> for RouterCalled {
    /// Maps a method name to the navigation call it denotes.
    fn from(value: &str) -> Self {
        match value {
            "nav_to" => Self::NavTo,
            "nav_back" => Self::NavBack,
            _ => Self::Unknown,
        }
    }
}