Skip to main content

code_baseline/rules/ast/
no_derived_state_effect.rs

1use crate::config::{RuleConfig, Severity};
2use crate::rules::ast::parse_file;
3use crate::rules::{Rule, RuleBuildError, ScanContext, Violation};
4
5/// Flags `useEffect` callbacks where the body contains ONLY `set*()` calls.
6///
7/// When every statement in a useEffect callback is a setState call, the effect
8/// is computing derived state and should be replaced with `useMemo` or inline
9/// computation during render.
10pub struct NoDerivedStateEffectRule {
11    id: String,
12    severity: Severity,
13    message: String,
14    suggest: Option<String>,
15    glob: Option<String>,
16}
17
18impl NoDerivedStateEffectRule {
19    pub fn new(config: &RuleConfig) -> Result<Self, RuleBuildError> {
20        Ok(Self {
21            id: config.id.clone(),
22            severity: config.severity,
23            message: config.message.clone(),
24            suggest: config.suggest.clone(),
25            glob: config.glob.clone(),
26        })
27    }
28}
29
30impl Rule for NoDerivedStateEffectRule {
31    fn id(&self) -> &str {
32        &self.id
33    }
34    fn severity(&self) -> Severity {
35        self.severity
36    }
37    fn file_glob(&self) -> Option<&str> {
38        self.glob.as_deref()
39    }
40    fn check_file(&self, ctx: &ScanContext) -> Vec<Violation> {
41        let mut violations = Vec::new();
42        let tree = match parse_file(ctx.file_path, ctx.content) {
43            Some(t) => t,
44            None => return violations,
45        };
46        let source = ctx.content.as_bytes();
47        self.visit(tree.root_node(), source, ctx, &mut violations);
48        violations
49    }
50}
51
52impl NoDerivedStateEffectRule {
53    fn visit(
54        &self,
55        node: tree_sitter::Node,
56        source: &[u8],
57        ctx: &ScanContext,
58        violations: &mut Vec<Violation>,
59    ) {
60        if node.kind() == "call_expression" {
61            if let Some(func) = node.child_by_field_name("function") {
62                if func.kind() == "identifier" {
63                    if let Ok(name) = func.utf8_text(source) {
64                        if name == "useEffect" {
65                            if let Some(args) = node.child_by_field_name("arguments") {
66                                if let Some(callback) = args.named_child(0) {
67                                    if self.is_only_set_state(&callback, source) {
68                                        let line = node.start_position().row;
69                                        violations.push(Violation {
70                                            rule_id: self.id.clone(),
71                                            severity: self.severity,
72                                            file: ctx.file_path.to_path_buf(),
73                                            line: Some(line + 1),
74                                            column: Some(node.start_position().column + 1),
75                                            message: self.message.clone(),
76                                            suggest: self.suggest.clone(),
77                                            source_line: ctx
78                                                .content
79                                                .lines()
80                                                .nth(line)
81                                                .map(String::from),
82                                            fix: None,
83                                        });
84                                    }
85                                }
86                            }
87                        }
88                    }
89                }
90            }
91        }
92
93        for i in 0..node.child_count() {
94            if let Some(child) = node.child(i) {
95                self.visit(child, source, ctx, violations);
96            }
97        }
98    }
99
100    /// Check if a callback body contains ONLY set* calls (at least one).
101    fn is_only_set_state(&self, callback: &tree_sitter::Node, source: &[u8]) -> bool {
102        // Find the statement_block in the callback
103        let body = self.find_body(callback);
104        let body = match body {
105            Some(b) => b,
106            None => {
107                // Arrow function with expression body (no block): e.g. () => setFoo(x)
108                // Check if the expression itself is a set* call
109                if callback.kind() == "arrow_function" {
110                    if let Some(body_node) = callback.child_by_field_name("body") {
111                        if body_node.kind() == "call_expression" {
112                            return is_set_state_call(&body_node, source);
113                        }
114                    }
115                }
116                return false;
117            }
118        };
119
120        let mut count = 0;
121        for i in 0..body.named_child_count() {
122            if let Some(stmt) = body.named_child(i) {
123                if stmt.kind() == "expression_statement" {
124                    if let Some(expr) = stmt.named_child(0) {
125                        if expr.kind() == "call_expression" && is_set_state_call(&expr, source) {
126                            count += 1;
127                            continue;
128                        }
129                    }
130                }
131                // Non-setState statement found
132                return false;
133            }
134        }
135        count > 0
136    }
137
138    fn find_body<'a>(&self, node: &'a tree_sitter::Node<'a>) -> Option<tree_sitter::Node<'a>> {
139        match node.kind() {
140            "arrow_function" | "function_expression" | "function" => {
141                node.child_by_field_name("body")
142                    .filter(|b| b.kind() == "statement_block")
143            }
144            _ => None,
145        }
146    }
147}
148
149fn is_set_state_call(node: &tree_sitter::Node, source: &[u8]) -> bool {
150    if let Some(func) = node.child_by_field_name("function") {
151        if func.kind() == "identifier" {
152            if let Ok(name) = func.utf8_text(source) {
153                if let Some(rest) = name.strip_prefix("set") {
154                    return rest.starts_with(|c: char| c.is_ascii_uppercase());
155                }
156            }
157        }
158    }
159    false
160}
161
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::Path;

    /// The rule configured as in production: warning severity, TSX/JSX glob.
    fn rule_under_test() -> NoDerivedStateEffectRule {
        let config = RuleConfig {
            id: "no-derived-state-effect".into(),
            severity: Severity::Warning,
            message: "useEffect that only calls setState is derived state".into(),
            suggest: Some("Compute during render with useMemo instead".into()),
            glob: Some("**/*.{tsx,jsx}".into()),
            ..Default::default()
        };
        NoDerivedStateEffectRule::new(&config).unwrap()
    }

    /// Runs the rule over `src` as though it were the contents of `test.tsx`.
    fn violations_for(src: &str) -> Vec<Violation> {
        rule_under_test().check_file(&ScanContext {
            file_path: Path::new("test.tsx"),
            content: src,
        })
    }

    #[test]
    fn only_set_state_flags() {
        let src = "\
function MyComponent({ data }) {
  const [derived, setDerived] = useState('');
  useEffect(() => {
    setDerived(compute(data));
  }, [data]);
  return <div>{derived}</div>;
}";
        assert_eq!(violations_for(src).len(), 1);
    }

    #[test]
    fn multiple_set_state_only_flags() {
        let src = "\
function MyComponent({ a, b }) {
  const [x, setX] = useState(0);
  const [y, setY] = useState(0);
  useEffect(() => {
    setX(a * 2);
    setY(b * 3);
  }, [a, b]);
  return <div />;
}";
        assert_eq!(violations_for(src).len(), 1);
    }

    #[test]
    fn mixed_statements_no_violation() {
        let src = "\
function MyComponent({ id }) {
  const [data, setData] = useState(null);
  useEffect(() => {
    fetch('/api/' + id).then(r => r.json()).then(setData);
  }, [id]);
  return <div />;
}";
        assert_eq!(violations_for(src).len(), 0);
    }

    #[test]
    fn set_state_plus_other_no_violation() {
        let src = "\
function MyComponent({ value }) {
  const [x, setX] = useState(0);
  useEffect(() => {
    console.log('updating');
    setX(value * 2);
  }, [value]);
  return <div />;
}";
        assert_eq!(violations_for(src).len(), 0);
    }

    #[test]
    fn empty_effect_no_violation() {
        let src = "\
function MyComponent() {
  useEffect(() => {
  }, []);
  return <div />;
}";
        assert_eq!(violations_for(src).len(), 0);
    }

    #[test]
    fn arrow_expression_body_set_state_flags() {
        let src = "\
function MyComponent({ data }) {
  const [x, setX] = useState(0);
  useEffect(() => setX(data.length), [data]);
  return <div />;
}";
        assert_eq!(violations_for(src).len(), 1);
    }

    #[test]
    fn non_tsx_skipped() {
        let ctx = ScanContext {
            file_path: Path::new("test.rs"),
            content: "fn main() {}",
        };
        assert_eq!(rule_under_test().check_file(&ctx).len(), 0);
    }
}