nyx-scanner 0.5.0

A multi-language static analysis tool for detecting security vulnerabilities
Documentation
use super::AuthExtractor;
use super::axum::{
    GuardFramework, apply_aliases, dedup_call_sites, guard_calls_for_handler, inject_guard_checks,
    rust_param_aliases,
};
use super::common::{
    attach_route_handler, collect_top_level_units, function_definition_node, function_name,
    named_children, text,
};
use crate::auth_analysis::config::AuthAnalysisRules;
use crate::auth_analysis::model::{AuthorizationModel, Framework, HttpMethod, RouteRegistration};
use crate::utils::project::{DetectedFramework, FrameworkContext};
use std::path::Path;
use tree_sitter::{Node, Tree};

/// Authorization-model extractor for Rust sources that use the Rocket web framework.
pub struct RocketExtractor;

impl AuthExtractor for RocketExtractor {
    /// Accepts Rust sources only.
    ///
    /// Within Rust, the extractor applies when no framework context was supplied,
    /// when the context detected no frameworks at all, or when Rocket was detected.
    fn supports(&self, lang: &str, framework_ctx: Option<&FrameworkContext>) -> bool {
        if lang != "rust" {
            return false;
        }

        match framework_ctx {
            None => true,
            Some(ctx) => ctx.frameworks.is_empty() || ctx.has(DetectedFramework::Rocket),
        }
    }

    /// Builds the authorization model for one parsed file.
    ///
    /// Top-level units are gathered first via `collect_top_level_units`. The whole
    /// tree is then walked for route-annotated handler functions, which are
    /// registered as routes on the same model.
    fn extract(
        &self,
        tree: &Tree,
        bytes: &[u8],
        path: &Path,
        rules: &AuthAnalysisRules,
    ) -> AuthorizationModel {
        let mut model = AuthorizationModel::default();
        let root_node = tree.root_node();

        collect_top_level_units(root_node, bytes, rules, &mut model);
        collect_handlers(root_node, root_node, bytes, path, rules, &mut model);

        model
    }
}

/// Recursively visits `node` and all of its named descendants in pre-order.
///
/// Every `function_item` encountered is handed to `maybe_collect_route`. The walk
/// does not stop at functions, so items nested inside modules, `impl` blocks or
/// other function bodies are also visited. `root` is forwarded unchanged to each
/// visit.
fn collect_handlers(
    root: Node<'_>,
    node: Node<'_>,
    bytes: &[u8],
    path: &Path,
    rules: &AuthAnalysisRules,
    model: &mut AuthorizationModel,
) {
    let is_function = node.kind() == "function_item";
    if is_function {
        maybe_collect_route(root, node, bytes, path, rules, model);
    }

    named_children(node)
        .into_iter()
        .for_each(|child| collect_handlers(root, child, bytes, path, rules, model));
}

/// Registers a route for every Rocket method attribute found on the `function_item` `node`.
///
/// For each `(method, path)` pair returned by `route_attributes`, this function:
/// 1. attaches the handler unit with `attach_route_handler`, named `"<Method:?> <fn name>"`,
///    or `"<Method:?> rocket_handler"` when the function name cannot be resolved;
/// 2. gathers the guard calls for the route and de-duplicates them;
/// 3. applies the parameter aliases and guard checks to the handler's unit;
/// 4. pushes a `RouteRegistration` onto `model.routes`.
///
/// Functions without any route attribute are left untouched.
fn maybe_collect_route(
    root: Node<'_>,
    node: Node<'_>,
    bytes: &[u8],
    path: &Path,
    rules: &AuthAnalysisRules,
    model: &mut AuthorizationModel,
) {
    let route_attrs = route_attributes(node, bytes);
    if route_attrs.is_empty() {
        // Not a route handler: skip the name lookup below entirely.
        return;
    }

    // The handler name is the same for every route attribute on this function,
    // so resolve it once instead of once per attribute.
    let handler_name = function_name(function_definition_node(node), bytes)
        .unwrap_or_else(|| "rocket_handler".to_string());

    for (method, route_path) in route_attrs {
        let Some(handler) = attach_route_handler(
            root,
            node,
            format!("{method:?} {handler_name}"),
            bytes,
            rules,
            model,
        ) else {
            // NOTE(review): attach_route_handler yielded nothing for this node —
            // presumably no unit could be attached; the route is skipped.
            continue;
        };

        let mut middleware_calls =
            guard_calls_for_handler(node, &route_path, bytes, GuardFramework::Rocket);
        dedup_call_sites(&mut middleware_calls);

        if let Some(unit) = model.units.get_mut(handler.unit_idx) {
            let aliases = rust_param_aliases(node, &route_path, bytes, GuardFramework::Rocket);
            apply_aliases(unit, &aliases);
            inject_guard_checks(unit, &middleware_calls, rules);
        }

        model.routes.push(RouteRegistration {
            framework: Framework::Rocket,
            method,
            path: route_path,
            middleware: middleware_calls
                .iter()
                .map(|call| call.name.clone())
                .collect(),
            handler_span: handler.span,
            handler_params: handler.params,
            file: path.to_path_buf(),
            line: handler.line,
            unit_idx: handler.unit_idx,
            middleware_calls,
        });
    }
}

fn route_attributes(node: Node<'_>, bytes: &[u8]) -> Vec<(HttpMethod, String)> {
    text(node, bytes)
        .lines()
        .map(str::trim)
        .take_while(|line| line.starts_with("#["))
        .filter_map(parse_route_attribute)
        .collect()
}

/// Parses a single Rocket method attribute such as `#[get("/users/<id>")]`.
///
/// The attribute identifier must be exactly `get`, `post`, `put`, `delete` or
/// `patch`. It may be qualified as `rocket::get`. Attributes whose name merely
/// starts with a method name (e.g. `#[get_user("/x")]`) are rejected.
///
/// The route path is the first string literal inside the argument list. Extra
/// arguments such as `data = "<form>"` or `rank = 2` are ignored. The input may
/// span several lines.
///
/// Returns `None` for any other attribute, or when no string literal follows the
/// opening parenthesis.
fn parse_route_attribute(line: &str) -> Option<(HttpMethod, String)> {
    let body = line.trim().strip_prefix("#[")?.trim_start();
    // Accept `get(..)`, `rocket::get(..)` and `::rocket::get(..)`.
    let body = body.strip_prefix("::").unwrap_or(body);
    let body = body.strip_prefix("rocket::").unwrap_or(body);

    // Split off the full attribute identifier so `get_user` does not match `get`.
    let name_len = body
        .find(|c: char| !(c.is_ascii_alphanumeric() || c == '_'))
        .unwrap_or(body.len());
    let (name, rest) = body.split_at(name_len);

    let method = match name {
        "get" => HttpMethod::Get,
        "post" => HttpMethod::Post,
        "put" => HttpMethod::Put,
        "delete" => HttpMethod::Delete,
        "patch" => HttpMethod::Patch,
        _ => return None,
    };

    let args = rest.trim_start().strip_prefix('(')?;
    let start = args.find('"')?;
    let tail = &args[start + 1..];
    let end = tail.find('"')?;
    Some((method, tail[..end].to_string()))
}