Skip to main content

cha_core/plugins/
async_callback_leak.rs

1use crate::{AnalysisContext, Finding, Location, Plugin, Severity, SmellCategory, TypeRef};
2
/// Known async / channel handle type names across Rust, TypeScript, Python,
/// Go. Matched on the innermost type identifier (via `TypeRef::name`), so
/// `tokio::task::JoinHandle<T>` and `JoinHandle` both match.
///
/// Entries are compared as exact, case-sensitive strings, so each one must be
/// spelled exactly as the type identifier appears in source.
const ASYNC_HANDLE_TYPES: &[&str] = &[
    // Rust async ecosystem
    "JoinHandle",
    "Future",
    "Task",
    "AbortHandle",
    // NOTE(review): `oneshot` and `mpsc` are module names (`oneshot::Sender`,
    // `mpsc::Receiver`); with innermost-identifier matching the reported name
    // should normally be `Sender`/`Receiver`. Kept in case an extractor
    // surfaces the module segment — confirm and drop if never hit.
    "oneshot",
    "mpsc",
    // Channel halves (cross-ecosystem)
    "Sender",
    // Bounded sender half returned by `std::sync::mpsc::sync_channel`.
    "SyncSender",
    "Receiver",
    "UnboundedSender",
    "UnboundedReceiver",
    "WatchSender",
    "WatchReceiver",
    // JavaScript / TypeScript
    "Promise",
    "PromiseLike",
    // Go (channel types are punctuation so they surface as Unknown — leave to
    // a later pass; we still catch `context.CancelFunc`, `sync.WaitGroup`).
    "CancelFunc",
    "WaitGroup",
    // Python
    "Awaitable",
    "Coroutine",
    "Queue",
];
33
/// Function-name prefixes that identify launchers/spawners: functions whose
/// purpose is to start some work and hand back a handle to it. Exposing an
/// async handle is the expected contract for such functions, so the analyzer
/// skips them to keep its signal tight.
const LAUNCHER_PREFIXES: &[&str] = &[
    // Core launcher verbs, bare and with a trailing underscore.
    "spawn", "spawn_",
    "launch", "launch_",
    "start", "start_",
    // Other fire-and-forget / background verbs.
    "run_async", "fire_", "dispatch_", "background_",
];
48
/// Plugin that flags functions whose signatures expose a raw async or
/// concurrency handle (join handles, futures, promises, channel halves, ...)
/// to their callers.
///
/// A fieldless unit struct: all matching rules live in module-level tables.
#[derive(Debug, Default, Clone, Copy)]
pub struct AsyncCallbackLeakAnalyzer;
51
52impl Plugin for AsyncCallbackLeakAnalyzer {
53    fn name(&self) -> &str {
54        "async_callback_leak"
55    }
56
57    fn smells(&self) -> Vec<String> {
58        vec!["async_callback_leak".into()]
59    }
60
61    fn description(&self) -> &str {
62        "Function signature leaks a raw async handle (JoinHandle/Future/Channel)"
63    }
64
65    fn analyze(&self, ctx: &AnalysisContext) -> Vec<Finding> {
66        ctx.model
67            .functions
68            .iter()
69            .filter_map(|f| {
70                if is_launcher_shaped(&f.name) {
71                    return None;
72                }
73                if let Some(ret) = &f.return_type
74                    && is_async_handle(ret)
75                {
76                    return Some(build_finding(ctx, f, ret, Position::Return));
77                }
78                for (idx, param) in f.parameter_types.iter().enumerate() {
79                    if is_async_handle(param) {
80                        return Some(build_finding(ctx, f, param, Position::Param(idx + 1)));
81                    }
82                }
83                None
84            })
85            .collect()
86    }
87}
88
89fn is_launcher_shaped(name: &str) -> bool {
90    LAUNCHER_PREFIXES
91        .iter()
92        .any(|p| name == *p || name.starts_with(p))
93}
94
95fn is_async_handle(t: &TypeRef) -> bool {
96    ASYNC_HANDLE_TYPES.contains(&t.name.as_str())
97}
98
/// Where in a function signature a leaked async handle was found.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Position {
    /// The function's return type.
    Return,
    /// A parameter, identified by its 1-based index in the parameter list.
    Param(usize),
}
103
104fn build_finding(
105    ctx: &AnalysisContext,
106    f: &crate::FunctionInfo,
107    t: &TypeRef,
108    pos: Position,
109) -> Finding {
110    let (where_it, suggestion) = match pos {
111        Position::Return => (
112            "return type".to_string(),
113            "Wait inside the function and return a domain value, or wrap the handle in a local Task abstraction".to_string(),
114        ),
115        Position::Param(i) => (
116            format!("parameter #{i}"),
117            "Accept a domain callback/value instead of a raw async handle".to_string(),
118        ),
119    };
120    Finding {
121        smell_name: "async_callback_leak".into(),
122        category: SmellCategory::Couplers,
123        severity: Severity::Hint,
124        location: Location {
125            path: ctx.file.path.clone(),
126            start_line: f.start_line,
127            start_col: f.name_col,
128            end_line: f.start_line,
129            end_col: f.name_end_col,
130            name: Some(f.name.clone()),
131        },
132        message: format!(
133            "Function `{}` has `{}` in its {} — concurrency primitive leaks to callers",
134            f.name, t.name, where_it
135        ),
136        suggested_refactorings: vec![
137            suggestion,
138            "Expose a higher-level interface (domain event, callback) instead of the raw handle"
139                .into(),
140        ],
141        ..Default::default()
142    }
143}
144
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{FunctionInfo, SourceFile, SourceModel, TypeOrigin};
    use std::path::PathBuf;

    /// Wraps `functions` in a minimal Rust source model backed by an empty
    /// `test.rs` file.
    fn ctx_with(functions: Vec<FunctionInfo>) -> (SourceFile, SourceModel) {
        let file = SourceFile::new(PathBuf::from("test.rs"), String::new());
        let model = SourceModel {
            language: "rust".into(),
            total_lines: 10,
            functions,
            classes: vec![],
            imports: vec![],
            comments: vec![],
            type_aliases: vec![],
        };
        (file, model)
    }

    /// Builds a `TypeRef` whose `name` and `raw` text are both `name`.
    fn tref(name: &str, origin: TypeOrigin) -> TypeRef {
        TypeRef {
            name: name.into(),
            raw: name.into(),
            origin,
        }
    }

    /// Shorthand for a type that originates from the `tokio` crate.
    fn tokio_type(name: &str) -> TypeRef {
        tref(name, TypeOrigin::External("tokio".into()))
    }

    /// A function called `name` spanning lines 1-5 with no types attached.
    fn function(name: &str) -> FunctionInfo {
        FunctionInfo {
            name: name.into(),
            start_line: 1,
            end_line: 5,
            ..Default::default()
        }
    }

    /// Runs the analyzer over `functions` and returns its findings.
    fn run(functions: Vec<FunctionInfo>) -> Vec<Finding> {
        let (file, model) = ctx_with(functions);
        let ctx = AnalysisContext {
            file: &file,
            model: &model,
            tree: None,
            ts_language: None,
        };
        AsyncCallbackLeakAnalyzer.analyze(&ctx)
    }

    #[test]
    fn flags_function_returning_join_handle() {
        let findings = run(vec![FunctionInfo {
            return_type: Some(tokio_type("JoinHandle")),
            ..function("load_user")
        }]);
        assert_eq!(findings.len(), 1);
        let message = &findings[0].message;
        assert!(message.contains("JoinHandle"));
        assert!(message.contains("return type"));
    }

    #[test]
    fn flags_function_taking_sender() {
        let findings = run(vec![FunctionInfo {
            parameter_count: 1,
            parameter_types: vec![tokio_type("Sender")],
            ..function("configure")
        }]);
        assert_eq!(findings.len(), 1);
        let message = &findings[0].message;
        assert!(message.contains("Sender"));
        assert!(message.contains("parameter #1"));
    }

    #[test]
    fn ignores_launcher_shaped_names() {
        // A spawner handing back its JoinHandle is the intended contract of
        // such a function, so it must not be reported.
        let findings = run(vec![FunctionInfo {
            return_type: Some(tokio_type("JoinHandle")),
            ..function("spawn_worker")
        }]);
        assert!(findings.is_empty());
    }

    #[test]
    fn ignores_plain_domain_signatures() {
        let findings = run(vec![FunctionInfo {
            return_type: Some(tref("User", TypeOrigin::Local)),
            parameter_count: 1,
            parameter_types: vec![tref("UserId", TypeOrigin::Local)],
            ..function("get_user")
        }]);
        assert!(findings.is_empty());
    }

    #[test]
    fn flags_promise_typescript() {
        let findings = run(vec![FunctionInfo {
            return_type: Some(tref("Promise", TypeOrigin::Primitive)),
            ..function("fetch_users")
        }]);
        assert_eq!(findings.len(), 1);
        assert!(findings[0].message.contains("Promise"));
    }
}