Skip to main content

code_baseline/rules/ast/
mod.rs

1pub mod max_component_size;
2pub mod no_cascading_set_state;
3pub mod no_click_handler;
4pub mod no_derived_state_effect;
5pub mod no_nested_components;
6pub mod no_object_dep_array;
7pub mod no_outline_none;
8pub mod no_regexp_in_render;
9pub mod prefer_use_reducer;
10pub mod require_img_alt;
11
12pub use max_component_size::MaxComponentSizeRule;
13pub use no_cascading_set_state::NoCascadingSetStateRule;
14pub use no_click_handler::{NoDivClickHandlerRule, NoSpanClickHandlerRule};
15pub use no_derived_state_effect::NoDerivedStateEffectRule;
16pub use no_nested_components::NoNestedComponentsRule;
17pub use no_object_dep_array::NoObjectDepArrayRule;
18pub use no_outline_none::NoOutlineNoneRule;
19pub use no_regexp_in_render::NoRegexpInRenderRule;
20pub use prefer_use_reducer::PreferUseReducerRule;
21pub use require_img_alt::RequireImgAltRule;
22
23use std::path::Path;
24
/// Supported languages for AST parsing.
///
/// A variant is chosen from a file extension by [`detect_language`] and mapped
/// to a tree-sitter grammar by [`parse_file`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Lang {
    /// TypeScript with JSX (`.tsx`).
    Tsx,
    /// TypeScript (`.ts`).
    Typescript,
    /// JavaScript with JSX (`.jsx`).
    Jsx,
    /// JavaScript (`.js`).
    Javascript,
}
33
34/// Detect language from file extension.
35pub fn detect_language(path: &Path) -> Option<Lang> {
36    match path.extension()?.to_str()? {
37        "tsx" => Some(Lang::Tsx),
38        "ts" => Some(Lang::Typescript),
39        "jsx" => Some(Lang::Jsx),
40        "js" => Some(Lang::Javascript),
41        _ => None,
42    }
43}
44
45/// Parse a file into a tree-sitter syntax tree.
46pub fn parse_file(path: &Path, content: &str) -> Option<tree_sitter::Tree> {
47    let lang = detect_language(path)?;
48    let mut parser = tree_sitter::Parser::new();
49    let ts_lang: tree_sitter::Language = match lang {
50        Lang::Tsx => tree_sitter_typescript::LANGUAGE_TSX.into(),
51        Lang::Typescript => tree_sitter_typescript::LANGUAGE_TYPESCRIPT.into(),
52        Lang::Jsx | Lang::Javascript => tree_sitter_javascript::LANGUAGE.into(),
53    };
54    parser.set_language(&ts_lang).ok()?;
55    parser.parse(content, None)
56}
57
58/// Check if a tree-sitter node represents a React component declaration.
59///
60/// Recognizes PascalCase function declarations, arrow functions assigned to
61/// PascalCase variables, and PascalCase class declarations.
62pub fn is_component_node(node: &tree_sitter::Node, source: &[u8]) -> bool {
63    match node.kind() {
64        "function_declaration" => node
65            .child_by_field_name("name")
66            .and_then(|n| n.utf8_text(source).ok())
67            .map_or(false, starts_with_uppercase),
68        "arrow_function" => node
69            .parent()
70            .filter(|p| p.kind() == "variable_declarator")
71            .and_then(|p| p.child_by_field_name("name"))
72            .and_then(|n| n.utf8_text(source).ok())
73            .map_or(false, starts_with_uppercase),
74        "class_declaration" => node
75            .child_by_field_name("name")
76            .and_then(|n| n.utf8_text(source).ok())
77            .map_or(false, starts_with_uppercase),
78        _ => false,
79    }
80}
81
/// Returns `true` when `name` begins with an ASCII capital letter (`A`–`Z`).
///
/// Empty strings return `false`. So do names starting with any other
/// character, including non-ASCII uppercase letters, `_` and `$`.
fn starts_with_uppercase(name: &str) -> bool {
    matches!(name.chars().next(), Some('A'..='Z'))
}
87
88/// Count calls to a specific function within a node's subtree,
89/// skipping nested component definitions.
90pub fn count_calls_in_scope(
91    node: tree_sitter::Node,
92    source: &[u8],
93    target_name: &str,
94) -> usize {
95    let mut count = 0;
96    for i in 0..node.child_count() {
97        if let Some(child) = node.child(i) {
98            if is_component_node(&child, source) {
99                continue;
100            }
101            if child.kind() == "call_expression" && is_call_to(&child, source, target_name) {
102                count += 1;
103            }
104            count += count_calls_in_scope(child, source, target_name);
105        }
106    }
107    count
108}
109
110/// Check if a call_expression node calls a function with the given name.
111fn is_call_to(node: &tree_sitter::Node, source: &[u8], name: &str) -> bool {
112    node.child_by_field_name("function")
113        .filter(|f| f.kind() == "identifier")
114        .and_then(|f| f.utf8_text(source).ok())
115        .map_or(false, |n| n == name)
116}
117
/// A fragment of a class string extracted from a JSX `className`/`class` attribute.
///
/// `line` and `col` are copied from the tree-sitter start position of the
/// string-fragment node. Both are therefore zero-based, and `col` is a byte
/// offset within the line rather than a character count.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ClassFragment {
    /// Fragment text exactly as written in the source (not trimmed or split on whitespace).
    pub value: String,
    /// Zero-based row where the fragment starts.
    pub line: usize,
    /// Zero-based byte column where the fragment starts.
    pub col: usize,
}
125
/// Names of class-merging helpers whose call arguments are scanned for class
/// strings when they appear in a `className`/`class` value.
///
/// A call matches only when its callee is a bare identifier equal to one of
/// these names (case-sensitive). Member calls such as `utils.cn(...)` are not
/// matched.
const CLASSNAME_UTILS: &[&str] = &[
    "cn",
    "clsx",
    "classNames",
    "cva",
    "twMerge",
];
128
129/// Extract string fragments from a className attribute value node.
130///
131/// Recursively handles strings, jsx_expression, call_expression (cn/clsx/etc),
132/// binary_expression, ternary_expression, template_string, arrays, and
133/// parenthesized_expression.
134pub fn extract_classname_strings(node: tree_sitter::Node, source: &[u8]) -> Vec<ClassFragment> {
135    let mut fragments = Vec::new();
136    match node.kind() {
137        "string" => {
138            for i in 0..node.child_count() {
139                if let Some(child) = node.child(i) {
140                    if child.kind() == "string_fragment" {
141                        if let Ok(text) = child.utf8_text(source) {
142                            if !text.is_empty() {
143                                fragments.push(ClassFragment {
144                                    value: text.to_string(),
145                                    line: child.start_position().row,
146                                    col: child.start_position().column,
147                                });
148                            }
149                        }
150                    }
151                }
152            }
153        }
154        "jsx_expression" => {
155            for i in 0..node.named_child_count() {
156                if let Some(child) = node.named_child(i) {
157                    fragments.extend(extract_classname_strings(child, source));
158                }
159            }
160        }
161        "call_expression" => {
162            let is_util = node
163                .child_by_field_name("function")
164                .filter(|f| f.kind() == "identifier")
165                .and_then(|f| f.utf8_text(source).ok())
166                .map_or(false, |name| CLASSNAME_UTILS.contains(&name));
167            if is_util {
168                if let Some(args) = node.child_by_field_name("arguments") {
169                    fragments.extend(extract_classname_strings(args, source));
170                }
171            }
172        }
173        "arguments" | "array" | "parenthesized_expression" => {
174            for i in 0..node.named_child_count() {
175                if let Some(child) = node.named_child(i) {
176                    fragments.extend(extract_classname_strings(child, source));
177                }
178            }
179        }
180        "binary_expression" => {
181            if let Some(left) = node.child_by_field_name("left") {
182                fragments.extend(extract_classname_strings(left, source));
183            }
184            if let Some(right) = node.child_by_field_name("right") {
185                fragments.extend(extract_classname_strings(right, source));
186            }
187        }
188        "ternary_expression" => {
189            if let Some(cons) = node.child_by_field_name("consequence") {
190                fragments.extend(extract_classname_strings(cons, source));
191            }
192            if let Some(alt) = node.child_by_field_name("alternative") {
193                fragments.extend(extract_classname_strings(alt, source));
194            }
195        }
196        "template_string" => {
197            for i in 0..node.child_count() {
198                if let Some(child) = node.child(i) {
199                    match child.kind() {
200                        "string_fragment" => {
201                            if let Ok(text) = child.utf8_text(source) {
202                                if !text.is_empty() {
203                                    fragments.push(ClassFragment {
204                                        value: text.to_string(),
205                                        line: child.start_position().row,
206                                        col: child.start_position().column,
207                                    });
208                                }
209                            }
210                        }
211                        "template_substitution" => {
212                            for j in 0..child.named_child_count() {
213                                if let Some(sub) = child.named_child(j) {
214                                    fragments.extend(extract_classname_strings(sub, source));
215                                }
216                            }
217                        }
218                        _ => {}
219                    }
220                }
221            }
222        }
223        _ => {
224            for i in 0..node.child_count() {
225                if let Some(child) = node.child(i) {
226                    fragments.extend(extract_classname_strings(child, source));
227                }
228            }
229        }
230    }
231    fragments
232}
233
234/// Walk the syntax tree and collect class strings from all className/class JSX attributes.
235///
236/// Returns `Vec<Vec<ClassFragment>>` — outer vec is per-attribute, inner vec is
237/// all class string fragments from that attribute.
238pub fn collect_class_attributes(tree: &tree_sitter::Tree, source: &[u8]) -> Vec<Vec<ClassFragment>> {
239    let mut result = Vec::new();
240    collect_class_attrs_walk(tree.root_node(), source, &mut result);
241    result
242}
243
244fn collect_class_attrs_walk(
245    node: tree_sitter::Node,
246    source: &[u8],
247    result: &mut Vec<Vec<ClassFragment>>,
248) {
249    if node.kind() == "jsx_attribute" {
250        let is_class_attr = node
251            .named_child(0)
252            .and_then(|n| n.utf8_text(source).ok())
253            .map_or(false, |name| name == "className" || name == "class");
254        if is_class_attr {
255            if let Some(value) = node.named_child(1) {
256                let fragments = extract_classname_strings(value, source);
257                if !fragments.is_empty() {
258                    result.push(fragments);
259                }
260            }
261            return;
262        }
263    }
264    for i in 0..node.child_count() {
265        if let Some(child) = node.child(i) {
266            collect_class_attrs_walk(child, source, result);
267        }
268    }
269}
270
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::Path;

    #[test]
    fn detect_tsx() {
        assert!(matches!(
            detect_language(Path::new("foo.tsx")),
            Some(Lang::Tsx)
        ));
    }

    #[test]
    fn detect_ts() {
        assert!(matches!(
            detect_language(Path::new("bar.ts")),
            Some(Lang::Typescript)
        ));
    }

    #[test]
    fn detect_jsx() {
        assert!(matches!(
            detect_language(Path::new("baz.jsx")),
            Some(Lang::Jsx)
        ));
    }

    #[test]
    fn detect_js() {
        assert!(matches!(
            detect_language(Path::new("qux.js")),
            Some(Lang::Javascript)
        ));
    }

    #[test]
    fn detect_unknown() {
        assert!(detect_language(Path::new("file.rs")).is_none());
    }

    #[test]
    fn detect_no_extension() {
        assert!(detect_language(Path::new("Makefile")).is_none());
    }

    #[test]
    fn parse_tsx_file() {
        let content = "function App() { return <div />; }";
        let tree = parse_file(Path::new("app.tsx"), content);
        assert!(tree.is_some());
    }

    #[test]
    fn parse_jsx_in_js_file() {
        let content = "const App = () => <div />;";
        let tree = parse_file(Path::new("app.js"), content).unwrap();
        assert!(!tree.root_node().has_error());
    }

    #[test]
    fn parse_unknown_ext_returns_none() {
        let tree = parse_file(Path::new("app.rs"), "fn main() {}");
        assert!(tree.is_none());
    }

    #[test]
    fn component_function_declaration() {
        let content = "function MyComponent() { return <div />; }";
        let tree = parse_file(Path::new("a.tsx"), content).unwrap();
        let root = tree.root_node();
        let func = root.child(0).unwrap();
        assert!(is_component_node(&func, content.as_bytes()));
    }

    #[test]
    fn non_component_lowercase() {
        let content = "function helper() { return 1; }";
        let tree = parse_file(Path::new("a.tsx"), content).unwrap();
        let root = tree.root_node();
        let func = root.child(0).unwrap();
        assert!(!is_component_node(&func, content.as_bytes()));
    }

    #[test]
    fn component_arrow_function() {
        let content = "const MyComponent = () => { return <div />; };";
        let tree = parse_file(Path::new("a.tsx"), content).unwrap();
        let source = content.as_bytes();
        let root = tree.root_node();
        // Walk to find the arrow_function
        let mut found = false;
        visit_all(root, &mut |node| {
            if node.kind() == "arrow_function" && is_component_node(&node, source) {
                found = true;
            }
        });
        assert!(found);
    }

    #[test]
    fn non_component_lowercase_arrow() {
        let content = "const helper = () => 1;";
        let tree = parse_file(Path::new("a.tsx"), content).unwrap();
        let source = content.as_bytes();
        let mut found = false;
        visit_all(tree.root_node(), &mut |node| {
            if node.kind() == "arrow_function" && is_component_node(&node, source) {
                found = true;
            }
        });
        assert!(!found);
    }

    #[test]
    fn component_class_declaration() {
        let content = "class Widget extends React.Component {}";
        let tree = parse_file(Path::new("a.tsx"), content).unwrap();
        let class = tree.root_node().child(0).unwrap();
        assert_eq!(class.kind(), "class_declaration");
        assert!(is_component_node(&class, content.as_bytes()));
    }

    #[test]
    fn count_calls_skips_nested_components() {
        // setA(1), setA(2) and the call inside the lowercase helper count.
        // The member call and calls inside nested components do not.
        let content = "function App() {\n  setA(1);\n  if (x) { setA(2); }\n  obj.setA(3);\n  function Inner() { setA(4); }\n  const Other = () => { setA(5); };\n  function helper() { setA(6); }\n}";
        let tree = parse_file(Path::new("a.tsx"), content).unwrap();
        let app = tree.root_node().child(0).unwrap();
        assert_eq!(count_calls_in_scope(app, content.as_bytes(), "setA"), 3);
        assert_eq!(count_calls_in_scope(app, content.as_bytes(), "setB"), 0);
    }

    fn visit_all<F: FnMut(tree_sitter::Node)>(node: tree_sitter::Node, f: &mut F) {
        f(node);
        for i in 0..node.child_count() {
            if let Some(child) = node.child(i) {
                visit_all(child, f);
            }
        }
    }

    #[test]
    fn extract_simple_classname_string() {
        let content = r#"<div className="bg-white text-black" />"#;
        let tree = parse_file(Path::new("a.tsx"), content).unwrap();
        let attrs = collect_class_attributes(&tree, content.as_bytes());
        assert_eq!(attrs.len(), 1);
        assert_eq!(attrs[0].len(), 1);
        assert_eq!(attrs[0][0].value, "bg-white text-black");
    }

    #[test]
    fn fragment_position_is_zero_based() {
        let content = r#"<div className="p-4" />"#;
        let tree = parse_file(Path::new("a.tsx"), content).unwrap();
        let attrs = collect_class_attributes(&tree, content.as_bytes());
        assert_eq!(attrs[0][0].line, 0);
        assert_eq!(attrs[0][0].col, 16);
    }

    #[test]
    fn extract_plain_class_attribute() {
        let content = r#"<div class="p-4 m-2" />"#;
        let tree = parse_file(Path::new("a.tsx"), content).unwrap();
        let attrs = collect_class_attributes(&tree, content.as_bytes());
        assert_eq!(attrs.len(), 1);
        assert_eq!(attrs[0][0].value, "p-4 m-2");
    }

    #[test]
    fn non_class_attributes_ignored() {
        let content = r#"<div id="main" title="bg-white" />"#;
        let tree = parse_file(Path::new("a.tsx"), content).unwrap();
        let attrs = collect_class_attributes(&tree, content.as_bytes());
        assert!(attrs.is_empty());
    }

    #[test]
    fn class_attr_inside_other_attribute_value() {
        let content = r#"<Layout header={<div className="p-4" />} />"#;
        let tree = parse_file(Path::new("a.tsx"), content).unwrap();
        let attrs = collect_class_attributes(&tree, content.as_bytes());
        assert_eq!(attrs.len(), 1);
        assert_eq!(attrs[0][0].value, "p-4");
    }

    #[test]
    fn extract_cn_call_strings() {
        let content = r#"<div className={cn("bg-white", "text-black")} />"#;
        let tree = parse_file(Path::new("a.tsx"), content).unwrap();
        let attrs = collect_class_attributes(&tree, content.as_bytes());
        assert_eq!(attrs.len(), 1);
        assert_eq!(attrs[0].len(), 2);
        assert_eq!(attrs[0][0].value, "bg-white");
        assert_eq!(attrs[0][1].value, "text-black");
    }

    #[test]
    fn extract_multiline_cn_call() {
        let content = "<div className={cn(\n  \"bg-white\",\n  active && \"text-black\",\n  \"p-4\"\n)} />";
        let tree = parse_file(Path::new("a.tsx"), content).unwrap();
        let attrs = collect_class_attributes(&tree, content.as_bytes());
        assert_eq!(attrs.len(), 1);
        let values: Vec<&str> = attrs[0].iter().map(|f| f.value.as_str()).collect();
        assert_eq!(values, vec!["bg-white", "text-black", "p-4"]);
    }

    #[test]
    fn extract_ternary_expression() {
        let content = r#"<div className={active ? "bg-white" : "bg-gray-100"} />"#;
        let tree = parse_file(Path::new("a.tsx"), content).unwrap();
        let attrs = collect_class_attributes(&tree, content.as_bytes());
        assert_eq!(attrs.len(), 1);
        let values: Vec<&str> = attrs[0].iter().map(|f| f.value.as_str()).collect();
        assert_eq!(values, vec!["bg-white", "bg-gray-100"]);
    }

    #[test]
    fn ternary_condition_not_extracted() {
        let content = r#"<div className={mode === "dark" ? "bg-black" : "bg-white"} />"#;
        let tree = parse_file(Path::new("a.tsx"), content).unwrap();
        let attrs = collect_class_attributes(&tree, content.as_bytes());
        let values: Vec<&str> = attrs[0].iter().map(|f| f.value.as_str()).collect();
        assert_eq!(values, vec!["bg-black", "bg-white"]);
    }

    #[test]
    fn string_concatenation_extracts_both_sides() {
        let content = r#"<div className={"p-4 " + "m-2"} />"#;
        let tree = parse_file(Path::new("a.tsx"), content).unwrap();
        let attrs = collect_class_attributes(&tree, content.as_bytes());
        let values: Vec<&str> = attrs[0].iter().map(|f| f.value.as_str()).collect();
        assert_eq!(values, vec!["p-4 ", "m-2"]);
    }

    #[test]
    fn no_class_attrs_in_data_object() {
        let content = r#"const obj = { className: "bg-white" };"#;
        let tree = parse_file(Path::new("a.tsx"), content).unwrap();
        let attrs = collect_class_attributes(&tree, content.as_bytes());
        assert!(attrs.is_empty());
    }

    #[test]
    fn non_util_call_not_extracted() {
        let content = r#"<div className={getClass("special")} />"#;
        let tree = parse_file(Path::new("a.tsx"), content).unwrap();
        let attrs = collect_class_attributes(&tree, content.as_bytes());
        assert!(attrs.is_empty(), "non-utility calls should produce no fragments");
    }
}