Skip to main content

cha_core/plugins/
async_callback_leak.rs

1use crate::{AnalysisContext, Finding, Location, Plugin, Severity, SmellCategory, TypeRef};
2
/// Type identifiers treated as raw async / concurrency handles.
///
/// `is_async_handle` compares these for exact, case-sensitive equality with
/// `TypeRef::name`, which holds the innermost type identifier — so a path such
/// as `tokio::task::JoinHandle<T>` is matched the same as a bare `JoinHandle`.
///
/// Go channel types are spelled with punctuation and surface as Unknown, so
/// they are left to a later pass; `context.CancelFunc` and `sync.WaitGroup`
/// are still caught by name.
///
/// NOTE(review): `oneshot` and `mpsc` are tokio *module* names rather than
/// type names; under innermost-identifier matching they only fire if
/// `TypeRef::name` ever yields a module segment — confirm they are reachable.
const ASYNC_HANDLE_TYPES: &[&str] = &[
    // Rust async ecosystem.
    "JoinHandle", "Future", "Task", "AbortHandle", "oneshot", "mpsc",
    // Channel halves seen across ecosystems.
    "Sender", "Receiver",
    "UnboundedSender", "UnboundedReceiver",
    "WatchSender", "WatchReceiver",
    // JavaScript / TypeScript.
    "Promise", "PromiseLike",
    // Go.
    "CancelFunc", "WaitGroup",
    // Python.
    "Awaitable", "Coroutine", "Queue",
];
33
/// Function-name prefixes that mark a launcher/spawner — a function whose
/// purpose is to hand back a handle, so exposing one is expected rather than
/// a leak.
///
/// `is_launcher_shaped` tests names against these as prefixes, and `analyze`
/// skips matching functions entirely. Under prefix matching the bare verbs
/// (`spawn`, `launch`, `start`) already cover their `_`-suffixed forms; both
/// spellings are listed anyway.
const LAUNCHER_PREFIXES: &[&str] = &[
    // Bare verbs, each followed by its snake_case `verb_` form.
    "spawn", "spawn_",
    "launch", "launch_",
    "start", "start_",
    // Compound or underscore-terminated prefixes.
    "run_async",
    "fire_",
    "dispatch_",
    "background_",
];
48
/// Analyzer that flags functions whose signatures mention a raw async /
/// concurrency handle type (join handle, future, promise, channel half, ...)
/// in their return type or parameters.
///
/// It is a stateless unit struct: build it with `Default` or use the unit
/// value directly.
#[derive(Debug, Default, Clone, Copy)]
pub struct AsyncCallbackLeakAnalyzer;
51
52impl Plugin for AsyncCallbackLeakAnalyzer {
53    fn name(&self) -> &str {
54        "async_callback_leak"
55    }
56
57    fn smells(&self) -> Vec<String> {
58        vec!["async_callback_leak".into()]
59    }
60
61    fn description(&self) -> &str {
62        "Function signature leaks a raw async handle (JoinHandle/Future/Channel)"
63    }
64
65    fn analyze(&self, ctx: &AnalysisContext) -> Vec<Finding> {
66        ctx.model
67            .functions
68            .iter()
69            .filter_map(|f| {
70                if is_launcher_shaped(&f.name) {
71                    return None;
72                }
73                if let Some(ret) = &f.return_type
74                    && is_async_handle(ret)
75                {
76                    return Some(build_finding(ctx, f, ret, Position::Return));
77                }
78                for (idx, param) in f.parameter_types.iter().enumerate() {
79                    if is_async_handle(param) {
80                        return Some(build_finding(ctx, f, param, Position::Param(idx + 1)));
81                    }
82                }
83                None
84            })
85            .collect()
86    }
87}
88
89fn is_launcher_shaped(name: &str) -> bool {
90    LAUNCHER_PREFIXES
91        .iter()
92        .any(|p| name == *p || name.starts_with(p))
93}
94
95fn is_async_handle(t: &TypeRef) -> bool {
96    ASYNC_HANDLE_TYPES.contains(&t.name.as_str())
97}
98
/// Where in a function signature an async handle was found.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Position {
    /// The function's return type.
    Return,
    /// A parameter, identified by its 1-based index in the parameter list.
    Param(usize),
}
103
104fn build_finding(
105    ctx: &AnalysisContext,
106    f: &crate::FunctionInfo,
107    t: &TypeRef,
108    pos: Position,
109) -> Finding {
110    let (where_it, suggestion) = match pos {
111        Position::Return => (
112            "return type".to_string(),
113            "Wait inside the function and return a domain value, or wrap the handle in a local Task abstraction".to_string(),
114        ),
115        Position::Param(i) => (
116            format!("parameter #{i}"),
117            "Accept a domain callback/value instead of a raw async handle".to_string(),
118        ),
119    };
120    Finding {
121        smell_name: "async_callback_leak".into(),
122        category: SmellCategory::Couplers,
123        severity: Severity::Hint,
124        location: Location {
125            path: ctx.file.path.clone(),
126            start_line: f.start_line,
127            start_col: f.name_col,
128            end_line: f.start_line,
129            end_col: f.name_end_col,
130            name: Some(f.name.clone()),
131        },
132        message: format!(
133            "Function `{}` has `{}` in its {} — concurrency primitive leaks to callers",
134            f.name, t.name, where_it
135        ),
136        suggested_refactorings: vec![
137            suggestion,
138            "Expose a higher-level interface (domain event, callback) instead of the raw handle"
139                .into(),
140        ],
141        ..Default::default()
142    }
143}
144
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{FunctionInfo, SourceFile, SourceModel, TypeOrigin};
    use std::path::PathBuf;

    /// Wraps `functions` in a minimal Rust `SourceModel` backed by an empty
    /// `test.rs`, and returns whatever the analyzer reports for it.
    fn analyze_functions(functions: Vec<FunctionInfo>) -> Vec<Finding> {
        let file = SourceFile::new(PathBuf::from("test.rs"), String::new());
        let model = SourceModel {
            language: "rust".into(),
            total_lines: 10,
            functions,
            classes: vec![],
            imports: vec![],
            comments: vec![],
            type_aliases: vec![],
        };
        let ctx = AnalysisContext {
            file: &file,
            model: &model,
        };
        AsyncCallbackLeakAnalyzer.analyze(&ctx)
    }

    /// Builds a `TypeRef` whose `name` and `raw` spelling are identical.
    fn type_ref(name: &str, origin: TypeOrigin) -> TypeRef {
        TypeRef {
            name: name.into(),
            raw: name.into(),
            origin,
        }
    }

    /// Builds a `TypeRef` attributed to the external `tokio` crate.
    fn tokio_type(name: &str) -> TypeRef {
        type_ref(name, TypeOrigin::External("tokio".into()))
    }

    #[test]
    fn flags_function_returning_join_handle() {
        let findings = analyze_functions(vec![FunctionInfo {
            name: "load_user".into(),
            start_line: 1,
            end_line: 5,
            return_type: Some(tokio_type("JoinHandle")),
            ..Default::default()
        }]);

        assert_eq!(findings.len(), 1);
        let message = &findings[0].message;
        assert!(message.contains("JoinHandle"));
        assert!(message.contains("return type"));
    }

    #[test]
    fn flags_function_taking_sender() {
        let findings = analyze_functions(vec![FunctionInfo {
            name: "configure".into(),
            start_line: 1,
            end_line: 5,
            parameter_count: 1,
            parameter_types: vec![tokio_type("Sender")],
            ..Default::default()
        }]);

        assert_eq!(findings.len(), 1);
        let message = &findings[0].message;
        assert!(message.contains("Sender"));
        assert!(message.contains("parameter #1"));
    }

    #[test]
    fn ignores_launcher_shaped_names() {
        // A spawner handing back a JoinHandle is doing exactly its job, so
        // `spawn_worker` must not be reported.
        let findings = analyze_functions(vec![FunctionInfo {
            name: "spawn_worker".into(),
            start_line: 1,
            end_line: 5,
            return_type: Some(tokio_type("JoinHandle")),
            ..Default::default()
        }]);

        assert!(findings.is_empty());
    }

    #[test]
    fn ignores_plain_domain_signatures() {
        let findings = analyze_functions(vec![FunctionInfo {
            name: "get_user".into(),
            start_line: 1,
            end_line: 5,
            return_type: Some(type_ref("User", TypeOrigin::Local)),
            parameter_count: 1,
            parameter_types: vec![type_ref("UserId", TypeOrigin::Local)],
            ..Default::default()
        }]);

        assert!(findings.is_empty());
    }

    #[test]
    fn flags_promise_typescript() {
        let findings = analyze_functions(vec![FunctionInfo {
            name: "fetch_users".into(),
            start_line: 1,
            end_line: 5,
            return_type: Some(type_ref("Promise", TypeOrigin::Primitive)),
            ..Default::default()
        }]);

        assert_eq!(findings.len(), 1);
        assert!(findings[0].message.contains("Promise"));
    }
}