Skip to main content

cargo_capsec/
parser.rs

1//! Rust source file parser built on [`syn`].
2//!
3//! Parses `.rs` files into a structured representation that captures the information
4//! the [`Detector`](crate::detector::Detector) needs: function boundaries, call sites,
5//! `use` imports, and `extern` blocks. Handles free functions, `impl` block methods,
6//! and trait default methods.
7//!
8//! The parser uses [`syn::visit::Visit`] to walk the AST. It does **not** perform type
9//! resolution — all matching is done on syntactic path segments. Import aliases are
10//! tracked so the [`Detector`](crate::detector::Detector) can expand them.
11
12use std::path::Path;
13use syn::visit::Visit;
14
15/// The parsed representation of a single `.rs` source file.
16///
17/// Contains every function body, `use` import, and `extern` block found in the file.
18/// This is the input to [`Detector::analyse`](crate::detector::Detector::analyse).
19#[derive(Debug, Clone)]
20pub struct ParsedFile {
21    /// File path (for reporting).
22    pub path: String,
23    /// All functions found: free functions, `impl` methods, and trait default methods.
24    pub functions: Vec<ParsedFunction>,
25    /// All `use` imports, with aliases tracked.
26    pub use_imports: Vec<ImportPath>,
27    /// All `extern` blocks (FFI declarations).
28    pub extern_blocks: Vec<ExternBlock>,
29}
30
31/// A single function (free, `impl` method, or trait default method) and its call sites.
32#[derive(Debug, Clone)]
33pub struct ParsedFunction {
34    /// The function name (e.g., `"load_config"`).
35    pub name: String,
36    /// Line number where the function is defined.
37    pub line: usize,
38    /// Every call expression found inside the function body.
39    pub calls: Vec<CallSite>,
40    /// True if this is the `main()` function inside a `build.rs` file.
41    pub is_build_script: bool,
42    /// Categories denied by `#[capsec::deny(...)]` on this function.
43    /// Parsed from `#[doc = "capsec::deny(...)"]` attributes.
44    pub deny_categories: Vec<String>,
45}
46
47/// A single call expression inside a function body.
48///
49/// Call sites are either qualified function calls (`fs::read(...)`) or method calls
50/// (`stream.connect(...)`). The [`segments`](CallSite::segments) field holds the
51/// raw path segments before import expansion.
52#[derive(Debug, Clone)]
53pub struct CallSite {
54    /// Path segments of the call (e.g., `["fs", "read"]` or `["TcpStream", "connect"]`).
55    pub segments: Vec<String>,
56    /// Source line number.
57    pub line: usize,
58    /// Source column number.
59    pub col: usize,
60    /// Whether this is a function call or a method call.
61    pub kind: CallKind,
62}
63
/// Distinguishes qualified function calls from method calls.
///
/// Implements `PartialEq`/`Eq` so call kinds can be compared directly with `==`
/// or `assert_eq!` instead of only through pattern matching.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CallKind {
    /// A qualified path call like `fs::read(...)` or `Command::new(...)`.
    FunctionCall,
    /// A method call like `stream.connect(...)` or `cmd.output()`.
    MethodCall {
        /// The method name (e.g., `"connect"`, `"output"`).
        method: String,
    },
}
75
/// A `use` import statement, with optional alias.
///
/// For `use std::fs::read as load`, the segments are `["std", "fs", "read"]` and
/// the alias is `Some("load")`. The [`Detector`](crate::detector::Detector) uses
/// this to expand bare calls: when it sees `load(...)`, it looks up the alias and
/// expands it to `std::fs::read`.
///
/// Implements `PartialEq`/`Eq` so imports can be compared directly with `==`
/// or `assert_eq!`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ImportPath {
    /// The full path segments (e.g., `["std", "fs", "read"]`).
    pub segments: Vec<String>,
    /// The `as` alias, if any (e.g., `Some("load")` for `use std::fs::read as load`).
    pub alias: Option<String>,
}
89
/// An `extern` block declaring foreign functions.
///
/// Any `extern` block is flagged as [`Category::Ffi`](crate::authorities::Category::Ffi)
/// by the detector, since FFI calls bypass Rust's safety model entirely.
///
/// Implements `PartialEq`/`Eq` so blocks can be compared directly with `==`
/// or `assert_eq!`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ExternBlock {
    /// The ABI string (e.g., `Some("C")` for `extern "C"`).
    pub abi: Option<String>,
    /// Names of functions declared in the block.
    pub functions: Vec<String>,
    /// Source line number.
    pub line: usize,
}
103
104/// Parses a `.rs` file from disk into a [`ParsedFile`].
105///
106/// Requires an [`FsRead`](capsec_core::permission::FsRead) capability token,
107/// proving the caller has permission to read files. This is the dogfood example —
108/// `cargo capsec audit` flagged this function's `std::fs::read_to_string` call,
109/// and now it's gated by the capsec type system.
110///
111/// # Example
112///
113/// ```rust,ignore
114/// use capsec_core::root::test_root;
115/// use capsec_core::permission::FsRead;
116///
117/// let root = test_root();
118/// let cap = root.grant::<FsRead>();
119/// let parsed = parse_file(Path::new("src/main.rs"), &cap).unwrap();
120/// ```
121pub fn parse_file(
122    path: &Path,
123    cap: &impl capsec_core::has::Has<capsec_core::permission::FsRead>,
124) -> Result<ParsedFile, String> {
125    let source = capsec_std::fs::read_to_string(path, cap)
126        .map_err(|e| format!("Failed to read {}: {e}", path.display()))?;
127    parse_source(&source, &path.display().to_string())
128}
129
130/// Parses Rust source code from a string into a [`ParsedFile`].
131///
132/// This is the primary entry point for programmatic usage and testing.
133/// The `path` parameter is used only for error messages and the
134/// [`ParsedFile::path`] field — it doesn't need to be a real file.
135///
136/// # Errors
137///
138/// Returns an error string if [`syn::parse_file`] fails (e.g., invalid Rust syntax).
139pub fn parse_source(source: &str, path: &str) -> Result<ParsedFile, String> {
140    let syntax = syn::parse_file(source).map_err(|e| format!("Failed to parse {path}: {e}"))?;
141
142    let mut visitor = FileVisitor::new(path.to_string());
143    visitor.visit_file(&syntax);
144
145    Ok(ParsedFile {
146        path: path.to_string(),
147        functions: visitor.functions,
148        use_imports: visitor.imports,
149        extern_blocks: visitor.extern_blocks,
150    })
151}
152
153struct FileVisitor {
154    file_path: String,
155    functions: Vec<ParsedFunction>,
156    imports: Vec<ImportPath>,
157    extern_blocks: Vec<ExternBlock>,
158    current_function: Option<ParsedFunction>,
159}
160
161impl FileVisitor {
162    fn new(file_path: String) -> Self {
163        Self {
164            file_path,
165            functions: Vec::new(),
166            imports: Vec::new(),
167            extern_blocks: Vec::new(),
168            current_function: None,
169        }
170    }
171}
172
173impl<'ast> Visit<'ast> for FileVisitor {
174    fn visit_item_fn(&mut self, node: &'ast syn::ItemFn) {
175        let func = ParsedFunction {
176            name: node.sig.ident.to_string(),
177            line: node.sig.ident.span().start().line,
178            calls: Vec::new(),
179            is_build_script: self.file_path.ends_with("build.rs") && node.sig.ident == "main",
180            deny_categories: extract_deny_categories(&node.attrs),
181        };
182
183        let prev = self.current_function.take();
184        self.current_function = Some(func);
185
186        syn::visit::visit_item_fn(self, node);
187
188        if let Some(func) = self.current_function.take() {
189            self.functions.push(func);
190        }
191        self.current_function = prev;
192    }
193
194    fn visit_impl_item_fn(&mut self, node: &'ast syn::ImplItemFn) {
195        let func = ParsedFunction {
196            name: node.sig.ident.to_string(),
197            line: node.sig.ident.span().start().line,
198            calls: Vec::new(),
199            is_build_script: false,
200            deny_categories: extract_deny_categories(&node.attrs),
201        };
202
203        let prev = self.current_function.take();
204        self.current_function = Some(func);
205
206        syn::visit::visit_impl_item_fn(self, node);
207
208        if let Some(func) = self.current_function.take() {
209            self.functions.push(func);
210        }
211        self.current_function = prev;
212    }
213
214    fn visit_trait_item_fn(&mut self, node: &'ast syn::TraitItemFn) {
215        // Only visit if there's a default body
216        if node.default.is_some() {
217            let func = ParsedFunction {
218                name: node.sig.ident.to_string(),
219                line: node.sig.ident.span().start().line,
220                calls: Vec::new(),
221                is_build_script: false,
222                deny_categories: extract_deny_categories(&node.attrs),
223            };
224
225            let prev = self.current_function.take();
226            self.current_function = Some(func);
227
228            syn::visit::visit_trait_item_fn(self, node);
229
230            if let Some(func) = self.current_function.take() {
231                self.functions.push(func);
232            }
233            self.current_function = prev;
234        } else {
235            syn::visit::visit_trait_item_fn(self, node);
236        }
237    }
238
239    fn visit_expr_call(&mut self, node: &'ast syn::ExprCall) {
240        if let Some(ref mut func) = self.current_function
241            && let syn::Expr::Path(ref path) = *node.func
242        {
243            let segments: Vec<String> = path
244                .path
245                .segments
246                .iter()
247                .map(|s| s.ident.to_string())
248                .collect();
249
250            if !segments.is_empty() {
251                func.calls.push(CallSite {
252                    segments,
253                    line: path
254                        .path
255                        .segments
256                        .first()
257                        .map(|s| s.ident.span().start().line)
258                        .unwrap_or(0),
259                    col: path
260                        .path
261                        .segments
262                        .first()
263                        .map(|s| s.ident.span().start().column)
264                        .unwrap_or(0),
265                    kind: CallKind::FunctionCall,
266                });
267            }
268        }
269
270        syn::visit::visit_expr_call(self, node);
271    }
272
273    fn visit_expr_method_call(&mut self, node: &'ast syn::ExprMethodCall) {
274        if let Some(ref mut func) = self.current_function {
275            func.calls.push(CallSite {
276                segments: vec![node.method.to_string()],
277                line: node.method.span().start().line,
278                col: node.method.span().start().column,
279                kind: CallKind::MethodCall {
280                    method: node.method.to_string(),
281                },
282            });
283        }
284
285        syn::visit::visit_expr_method_call(self, node);
286    }
287
288    fn visit_item_use(&mut self, node: &'ast syn::ItemUse) {
289        let mut paths = Vec::new();
290        collect_use_paths(&node.tree, &mut Vec::new(), &mut paths);
291        self.imports.extend(paths);
292
293        syn::visit::visit_item_use(self, node);
294    }
295
296    fn visit_item_foreign_mod(&mut self, node: &'ast syn::ItemForeignMod) {
297        let functions: Vec<String> = node
298            .items
299            .iter()
300            .filter_map(|item| {
301                if let syn::ForeignItem::Fn(f) = item {
302                    Some(f.sig.ident.to_string())
303                } else {
304                    None
305                }
306            })
307            .collect();
308
309        self.extern_blocks.push(ExternBlock {
310            abi: node.abi.name.as_ref().map(|n| n.value()),
311            functions,
312            line: node.abi.extern_token.span.start().line,
313        });
314
315        syn::visit::visit_item_foreign_mod(self, node);
316    }
317}
318
319/// Extracts denied categories from `#[doc = "capsec::deny(...)"]` attributes.
320///
321/// The `#[capsec::deny(...)]` macro emits a doc attribute like
322/// `#[doc = "capsec::deny(all, fs)"]`. This function parses that string
323/// and returns the category names (e.g., `["all", "fs"]`).
324fn extract_deny_categories(attrs: &[syn::Attribute]) -> Vec<String> {
325    let mut categories = Vec::new();
326    for attr in attrs {
327        if !attr.path().is_ident("doc") {
328            continue;
329        }
330        if let syn::Meta::NameValue(nv) = &attr.meta
331            && let syn::Expr::Lit(syn::ExprLit {
332                lit: syn::Lit::Str(lit_str),
333                ..
334            }) = &nv.value
335        {
336            let value = lit_str.value();
337            if let Some(inner) = value
338                .strip_prefix("capsec::deny(")
339                .and_then(|s| s.strip_suffix(')'))
340            {
341                for cat in inner.split(',') {
342                    let trimmed = cat.trim();
343                    if !trimmed.is_empty() {
344                        categories.push(trimmed.to_string());
345                    }
346                }
347            }
348        }
349    }
350    categories
351}
352
353fn collect_use_paths(tree: &syn::UseTree, prefix: &mut Vec<String>, out: &mut Vec<ImportPath>) {
354    match tree {
355        syn::UseTree::Path(p) => {
356            prefix.push(p.ident.to_string());
357            collect_use_paths(&p.tree, prefix, out);
358            prefix.pop();
359        }
360        syn::UseTree::Name(n) => {
361            let mut segments = prefix.clone();
362            segments.push(n.ident.to_string());
363            out.push(ImportPath {
364                segments,
365                alias: None,
366            });
367        }
368        syn::UseTree::Rename(r) => {
369            let mut segments = prefix.clone();
370            segments.push(r.ident.to_string());
371            out.push(ImportPath {
372                segments,
373                alias: Some(r.rename.to_string()),
374            });
375        }
376        syn::UseTree::Group(g) => {
377            for item in &g.items {
378                collect_use_paths(item, prefix, out);
379            }
380        }
381        syn::UseTree::Glob(_) => {
382            let mut segments = prefix.clone();
383            segments.push("*".to_string());
384            out.push(ImportPath {
385                segments,
386                alias: None,
387            });
388        }
389    }
390}
391
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `source` under the placeholder name `test.rs`, panicking on a parse error.
    fn parse(source: &str) -> ParsedFile {
        parse_source(source, "test.rs").unwrap()
    }

    /// True for call sites recorded as method calls.
    fn is_method_call(call: &CallSite) -> bool {
        matches!(call.kind, CallKind::MethodCall { .. })
    }

    /// A free function is recorded together with the calls in its body.
    #[test]
    fn parse_function_calls() {
        let parsed = parse(
            r#"
            use std::fs;
            fn do_stuff() {
                let _ = fs::read("test");
            }
        "#,
        );

        assert_eq!(parsed.functions.len(), 1);
        let only = &parsed.functions[0];
        assert_eq!(only.name, "do_stuff");
        assert!(!only.calls.is_empty());
    }

    /// Plain, grouped, and aliased imports are each flattened into their own entry.
    #[test]
    fn parse_use_statements() {
        let parsed = parse(
            r#"
            use std::fs::read;
            use std::net::{TcpStream, TcpListener};
            use std::env::var as get_env;
        "#,
        );
        let imports = &parsed.use_imports;
        assert_eq!(imports.len(), 4);

        let first = &imports[0];
        assert_eq!(first.segments, ["std", "fs", "read"]);
        assert!(first.alias.is_none());

        let aliased = imports
            .iter()
            .find(|import| import.alias.is_some())
            .unwrap();
        assert_eq!(aliased.segments, ["std", "env", "var"]);
        assert_eq!(aliased.alias.as_deref(), Some("get_env"));
    }

    /// Method calls are recorded with the `MethodCall` kind.
    #[test]
    fn parse_method_calls() {
        let parsed = parse(
            r#"
            fn network() {
                let stream = something();
                stream.connect("127.0.0.1:8080");
                stream.send_to(b"data", "addr");
            }
        "#,
        );

        let method_call_count = parsed.functions[0]
            .calls
            .iter()
            .filter(|call| is_method_call(call))
            .count();
        assert_eq!(method_call_count, 2);
    }

    /// An `extern "C"` block yields its ABI and declared function names.
    #[test]
    fn parse_extern_blocks() {
        let parsed = parse(
            r#"
            extern "C" {
                fn open(path: *const u8, flags: i32) -> i32;
                fn close(fd: i32) -> i32;
            }
        "#,
        );

        assert_eq!(parsed.extern_blocks.len(), 1);
        let block = &parsed.extern_blocks[0];
        assert_eq!(block.abi.as_deref(), Some("C"));
        assert_eq!(block.functions, ["open", "close"]);
    }

    /// Invalid Rust is reported as an error rather than a panic.
    #[test]
    fn parse_error_returns_err() {
        let outcome = parse_source("this is not valid rust {{{", "bad.rs");
        assert!(outcome.is_err());
    }

    /// Methods inside an `impl` block are recorded like free functions.
    #[test]
    fn parse_impl_block_methods() {
        let parsed = parse(
            r#"
            use std::fs;
            struct Loader;
            impl Loader {
                fn load(&self) -> Vec<u8> {
                    fs::read("data.bin").unwrap()
                }
                fn name(&self) -> &str {
                    "loader"
                }
            }
        "#,
        );

        assert_eq!(parsed.functions.len(), 2);
        let load = parsed
            .functions
            .iter()
            .find(|func| func.name == "load")
            .unwrap();
        assert!(!load.calls.is_empty());
    }

    /// Path expressions that are not called (enum variants) must not show up as calls.
    #[test]
    fn enum_variants_not_captured_as_calls() {
        let parsed = parse(
            r#"
            enum Category { Fs, Net }
            fn classify() -> Category {
                let cat = Category::Fs;
                let none: Option<i32> = Option::None;
                cat
            }
        "#,
        );

        let classify = parsed
            .functions
            .iter()
            .find(|func| func.name == "classify")
            .unwrap();
        let fn_calls: Vec<&CallSite> = classify
            .calls
            .iter()
            .filter(|call| matches!(call.kind, CallKind::FunctionCall))
            .collect();
        assert!(
            fn_calls.is_empty(),
            "Enum variants should not be captured as function calls, got: {:?}",
            fn_calls
                .iter()
                .map(|call| call.segments.join("::"))
                .collect::<Vec<_>>()
        );
    }

    /// A `capsec::deny(all)` doc attribute is read back as the `all` category.
    #[test]
    fn parse_deny_annotation() {
        let parsed = parse(
            r#"
            #[doc = "capsec::deny(all)"]
            fn pure_function() {
                let x = 1 + 2;
            }
        "#,
        );

        assert_eq!(parsed.functions.len(), 1);
        assert_eq!(parsed.functions[0].deny_categories, ["all"]);
    }

    /// Several comma-separated categories are returned in order, trimmed.
    #[test]
    fn parse_deny_specific_categories() {
        let parsed = parse(
            r#"
            #[doc = "capsec::deny(fs, net)"]
            fn no_io() {}
        "#,
        );

        assert_eq!(parsed.functions[0].deny_categories, ["fs", "net"]);
    }

    /// A function without the annotation has no denied categories.
    #[test]
    fn parse_no_deny_annotation() {
        let parsed = parse(
            r#"
            fn normal() {}
        "#,
        );

        assert!(parsed.functions[0].deny_categories.is_empty());
    }

    /// Only trait methods that carry a default body are recorded.
    #[test]
    fn parse_trait_default_methods() {
        let parsed = parse(
            r#"
            use std::fs;
            trait Readable {
                fn read_data(&self) -> Vec<u8> {
                    fs::read("default.dat").unwrap()
                }
                fn name(&self) -> &str;
            }
        "#,
        );

        // The body-less `name` declaration must not produce an entry.
        assert_eq!(parsed.functions.len(), 1);
        assert_eq!(parsed.functions[0].name, "read_data");
    }
}
571}