nyx-scanner 0.5.0

A multi-language static analysis tool for detecting security vulnerabilities
Documentation
use super::AuthExtractor;
use super::common::{
    auth_check_from_call_site, build_function_unit, call_name, call_site_from_node, function_name,
    named_children, span, text,
};
use crate::auth_analysis::config::{AuthAnalysisRules, matches_name, strip_quotes};
use crate::auth_analysis::model::{
    AnalysisUnitKind, AuthorizationModel, CallSite, Framework, HttpMethod, RouteRegistration,
};
use crate::labels::bare_method_name;
use crate::utils::project::{DetectedFramework, FrameworkContext};
use std::path::Path;
use tree_sitter::{Node, Tree};

/// Auth-analysis extractor for Ruby sources: classes whose name ends in
/// `Controller` are treated as Rails controllers and their methods as routes.
pub struct RailsExtractor;

impl AuthExtractor for RailsExtractor {
    /// Reports whether this extractor should run for `lang`.
    ///
    /// Only `"ruby"` is accepted. When a framework context is supplied it must
    /// either list no frameworks at all or include Rails; a missing context is
    /// accepted as-is.
    fn supports(&self, lang: &str, framework_ctx: Option<&FrameworkContext>) -> bool {
        if lang != "ruby" {
            return false;
        }
        match framework_ctx {
            None => true,
            Some(ctx) => ctx.frameworks.is_empty() || ctx.has(DetectedFramework::Rails),
        }
    }

    /// Walks the syntax tree from its root, starting with an empty namespace,
    /// and returns the model populated along the way.
    fn extract(
        &self,
        tree: &Tree,
        bytes: &[u8],
        path: &Path,
        rules: &AuthAnalysisRules,
    ) -> AuthorizationModel {
        let mut model = AuthorizationModel::default();
        let top_level: &[String] = &[];
        collect_nodes(tree.root_node(), top_level, bytes, path, rules, &mut model);
        model
    }
}

/// One filter named by a class-level `before_action`-style call, together with
/// the action restrictions parsed from that same call.
#[derive(Clone)]
struct FilterDirective {
    /// Call site describing the referenced filter.
    call: CallSite,
    /// Action names from the `only` option; empty means "no `only` restriction recorded".
    only: Vec<String>,
    /// Action names from the `except` option; the directive never applies to these.
    except: Vec<String>,
    /// `true` when the directive came from `skip_before_action`.
    skip: bool,
}

/// Recursively scans `node` for controller classes while tracking module nesting.
///
/// * `class` nodes are handed to `maybe_collect_controller`; the walk does not
///   descend into them any further.
/// * `module` nodes append their name segments to `namespace` and the walk
///   continues into the module body (if it has one).
/// * Every other node simply has its named children visited with the
///   namespace unchanged.
fn collect_nodes(
    node: Node<'_>,
    namespace: &[String],
    bytes: &[u8],
    path: &Path,
    rules: &AuthAnalysisRules,
    model: &mut AuthorizationModel,
) {
    let kind = node.kind();

    if kind == "class" {
        maybe_collect_controller(node, namespace, bytes, path, rules, model);
        return;
    }

    if kind != "module" {
        for child in named_children(node) {
            collect_nodes(child, namespace, bytes, path, rules, model);
        }
        return;
    }

    // A module without a body contributes nothing.
    let Some(body) = node.child_by_field_name("body") else {
        return;
    };
    let own_segments =
        ruby_constant_segments(node.child_by_field_name("name"), bytes).unwrap_or_default();
    let nested: Vec<String> = namespace.iter().cloned().chain(own_segments).collect();
    collect_nodes(body, &nested, bytes, path, rules, model);
}

/// Registers every method of a `*Controller` class as a Rails route handler.
///
/// Classes whose final name segment does not end in `Controller`, or that have
/// no body, are ignored. For each `method` node directly inside the class body
/// (methods with an empty name or a name ending in `=` are skipped) one
/// analysis unit and one `RouteRegistration` are pushed onto `model`:
///
/// * the unit is named `Namespace::FooController#action`;
/// * the route path is `/namespace/foo/action`, built from underscored segments;
/// * class-level filters that apply to the action become the route's
///   middleware and, where `rules` recognise them, auth checks on the unit.
///
/// NOTE(review): methods are not filtered by visibility here, so helpers that
/// follow a bare `private` would presumably be registered too; confirm intended.
fn maybe_collect_controller(
    class_node: Node<'_>,
    namespace: &[String],
    bytes: &[u8],
    path: &Path,
    rules: &AuthAnalysisRules,
    model: &mut AuthorizationModel,
) {
    let Some(name_segments) = ruby_constant_segments(class_node.child_by_field_name("name"), bytes)
    else {
        return;
    };
    // `class Admin::UsersController` -> class_name = "UsersController", qualifier = ["Admin"].
    let Some((class_name, qualifier)) = name_segments.split_last() else {
        return;
    };
    if !class_name.ends_with("Controller") {
        return;
    }
    let Some(body) = class_node.child_by_field_name("body") else {
        return;
    };

    let controller_namespace: Vec<String> = namespace.iter().chain(qualifier).cloned().collect();
    let controller_name = if controller_namespace.is_empty() {
        class_name.clone()
    } else {
        format!("{}::{}", controller_namespace.join("::"), class_name)
    };

    // Everything in the route path except the action is the same for all
    // methods of this controller, so it is assembled once up front.
    let mut prefix_segments: Vec<String> = controller_namespace
        .iter()
        .map(|segment| underscore_segment(segment))
        .collect();
    prefix_segments.push(underscore_segment(
        class_name.trim_end_matches("Controller"),
    ));
    let route_prefix = prefix_segments.join("/");

    let filter_directives = class_filter_directives(body, bytes);

    let methods = named_children(body)
        .into_iter()
        .filter(|child| child.kind() == "method");
    for method in methods {
        let action_name = match function_name(method, bytes) {
            Some(name) if !name.is_empty() && !name.ends_with('=') => name,
            _ => continue,
        };

        let line = method.start_position().row + 1;
        let middleware_calls = applicable_filters(&filter_directives, &action_name);

        let mut unit = build_function_unit(
            method,
            AnalysisUnitKind::RouteHandler,
            Some(format!("{controller_name}#{action_name}")),
            bytes,
            rules,
        );
        unit.auth_checks.extend(
            middleware_calls
                .iter()
                .filter_map(|call| auth_check_from_call_site(call, line, rules)),
        );

        let handler_span = span(method);
        let handler_params = unit.params.clone();
        // Index the unit will occupy once pushed below.
        let unit_idx = model.units.len();
        model.units.push(unit);

        let route_path = format!("/{route_prefix}/{}", underscore_segment(&action_name));

        model.routes.push(RouteRegistration {
            framework: Framework::Rails,
            method: infer_action_method(&action_name),
            path: route_path,
            middleware: middleware_calls
                .iter()
                .map(|call| call.name.clone())
                .collect(),
            handler_span,
            handler_params,
            file: path.to_path_buf(),
            line,
            unit_idx,
            middleware_calls,
        });
    }
}

/// Gathers the filter directives declared directly in a class body.
///
/// Only `call` nodes whose bare method name matches `before_action`,
/// `prepend_before_action` or `skip_before_action` are considered; directives
/// are returned in source order, with the `skip_before_action` ones flagged
/// as skips.
fn class_filter_directives(body: Node<'_>, bytes: &[u8]) -> Vec<FilterDirective> {
    named_children(body)
        .into_iter()
        .filter(|child| child.kind() == "call")
        .flat_map(|call| {
            let callee = call_name(call, bytes);
            let directive_name = bare_method_name(&callee);
            let is_skip = matches_name(directive_name, "skip_before_action");
            let is_filter = is_skip
                || matches_name(directive_name, "before_action")
                || matches_name(directive_name, "prepend_before_action");
            if is_filter {
                parse_filter_directive(call, bytes, is_skip)
            } else {
                Vec::new()
            }
        })
        .collect()
}

fn parse_filter_directive(node: Node<'_>, bytes: &[u8], skip: bool) -> Vec<FilterDirective> {
    let Some(arguments) = node.child_by_field_name("arguments") else {
        return Vec::new();
    };
    let args = named_children(arguments);
    if args.is_empty() {
        return Vec::new();
    }

    let mut filters = Vec::new();
    let mut only = Vec::new();
    let mut except = Vec::new();
    for arg in &args {
        if arg.kind() == "pair" {
            let key = text(arg.child_by_field_name("key").unwrap_or(*arg), bytes);
            let normalized = strip_quotes(&key).trim_start_matches(':').to_string();
            let value = arg.child_by_field_name("value").unwrap_or(*arg);
            if normalized == "only" {
                only = symbol_list(value, bytes);
            } else if normalized == "except" {
                except = symbol_list(value, bytes);
            }
            continue;
        }
        filters.extend(filter_calls_from_arg(*arg, bytes));
    }

    filters
        .into_iter()
        .map(|call| FilterDirective {
            call,
            only: only.clone(),
            except: except.clone(),
            skip,
        })
        .collect()
}

/// Turns one positional filter argument into the call sites it refers to.
///
/// * Arrays are flattened recursively.
/// * Symbols and identifiers become an argument-less call site named after
///   the symbol (quotes and leading `:` removed).
/// * Anything else goes through `call_site_from_node` and is kept only if
///   that produced a non-empty name.
fn filter_calls_from_arg(node: Node<'_>, bytes: &[u8]) -> Vec<CallSite> {
    let kind = node.kind();

    if kind == "array" {
        let mut calls = Vec::new();
        for element in named_children(node) {
            calls.extend(filter_calls_from_arg(element, bytes));
        }
        return calls;
    }

    if matches!(kind, "simple_symbol" | "hash_key_symbol" | "identifier") {
        let raw = text(node, bytes);
        let name = strip_quotes(&raw).trim_start_matches(':').to_string();
        return vec![CallSite {
            name,
            args: Vec::new(),
            span: span(node),
            args_value_refs: Vec::new(),
        }];
    }

    let call = call_site_from_node(node, bytes);
    if call.name.is_empty() {
        return Vec::new();
    }
    vec![call]
}

/// Resolves which filters are in effect for `action`.
///
/// Directives are replayed in declaration order. A regular directive adds its
/// filter unless one with the same name is already active; a skip directive
/// removes the active filter of that name. Directives that do not apply to
/// `action` are ignored. The result keeps first-insertion order.
fn applicable_filters(filters: &[FilterDirective], action: &str) -> Vec<CallSite> {
    filters
        .iter()
        .filter(|directive| filter_applies(directive, action))
        .fold(Vec::new(), |mut active: Vec<CallSite>, directive| {
            // Names are unique in `active` (we only push when absent), so a
            // single position lookup is enough for both add and remove.
            let already_at = active
                .iter()
                .position(|call| call.name == directive.call.name);
            match (directive.skip, already_at) {
                (true, Some(idx)) => {
                    active.remove(idx);
                }
                (false, None) => active.push(directive.call.clone()),
                _ => {}
            }
            active
        })
}

/// Returns `true` when `filter` covers `action`: the action must not be listed
/// in `except`, and `only` must either be empty or list the action.
fn filter_applies(filter: &FilterDirective, action: &str) -> bool {
    let listed = |names: &[String]| names.iter().any(|name| name == action);

    if listed(&filter.except) {
        return false;
    }
    filter.only.is_empty() || listed(&filter.only)
}

/// Extracts action names from the value of an `only:` / `except:` option.
///
/// Accepts a single symbol, identifier or string, or a list of them. Lists
/// may be written as a regular array (`[:show, :edit]`) or as a percent
/// literal (`%i[show edit]`, `%w[show edit]`). The percent forms are parsed
/// as `symbol_array` / `string_array` nodes whose elements are `bare_symbol`
/// / `bare_string`. They must be recognised here: an unrecognised value
/// yields an empty list, which callers read as "no restriction", i.e. the
/// filter (or skip) would be applied to every action.
///
/// Names are returned without surrounding quotes or a leading `:`. Any other
/// node kind produces an empty list.
fn symbol_list(node: Node<'_>, bytes: &[u8]) -> Vec<String> {
    match node.kind() {
        "simple_symbol" | "hash_key_symbol" | "identifier" | "string" | "bare_symbol"
        | "bare_string" => vec![
            strip_quotes(&text(node, bytes))
                .trim_start_matches(':')
                .to_string(),
        ],
        "array" | "symbol_array" | "string_array" => named_children(node)
            .into_iter()
            .flat_map(|child| symbol_list(child, bytes))
            .collect(),
        _ => Vec::new(),
    }
}

/// Splits the source text of a constant node on `::` into its segments.
///
/// Returns `None` when there is no node or its text is empty. Segments are
/// trimmed and blank ones dropped, so a leading `::` is ignored; the returned
/// vector can itself be empty if the text contains nothing but separators.
fn ruby_constant_segments(node: Option<Node<'_>>, bytes: &[u8]) -> Option<Vec<String>> {
    let raw = text(node?, bytes);
    if raw.is_empty() {
        return None;
    }

    let mut segments = Vec::new();
    for piece in raw.split("::") {
        let piece = piece.trim();
        if !piece.is_empty() {
            segments.push(piece.to_string());
        }
    }
    Some(segments)
}

/// Maps a conventional Rails action name to the HTTP method recorded for its
/// route. Names outside the seven listed here map to `HttpMethod::All`.
fn infer_action_method(action: &str) -> HttpMethod {
    // Read-only actions.
    if matches!(action, "index" | "show" | "new" | "edit") {
        return HttpMethod::Get;
    }
    // Mutating actions, then the catch-all for custom action names.
    if action == "create" {
        HttpMethod::Post
    } else if action == "update" {
        HttpMethod::Patch
    } else if action == "destroy" {
        HttpMethod::Delete
    } else {
        HttpMethod::All
    }
}

/// Converts a Ruby constant or method name into a lower-case,
/// underscore-separated path segment.
///
/// Word boundaries follow the rule used by Rails' `String#underscore`: an
/// underscore is inserted before an uppercase letter when
///
/// * the previous character is a lowercase letter or digit
///   (`UserProfiles` -> `user_profiles`), or
/// * the previous character is uppercase and the next one is lowercase, i.e.
///   the letter starts a new word after an acronym
///   (`HTMLParser` -> `html_parser`).
///
/// Runs of uppercase letters therefore stay together (`API` -> `api`) instead
/// of being split into single letters. Any character that is not ASCII
/// alphanumeric is replaced by a single `_`, and leading/trailing underscores
/// are removed from the result.
fn underscore_segment(value: &str) -> String {
    let chars: Vec<char> = value.chars().collect();
    let mut out = String::with_capacity(value.len() + 4);

    for (idx, &ch) in chars.iter().enumerate() {
        if ch.is_ascii_uppercase() {
            let prev = if idx > 0 { Some(chars[idx - 1]) } else { None };
            let next = chars.get(idx + 1).copied();

            let after_lower_or_digit =
                matches!(prev, Some(p) if p.is_ascii_lowercase() || p.is_ascii_digit());
            let starts_word_after_acronym = matches!(prev, Some(p) if p.is_ascii_uppercase())
                && matches!(next, Some(n) if n.is_ascii_lowercase());

            if (after_lower_or_digit || starts_word_after_acronym) && !out.ends_with('_') {
                out.push('_');
            }
            out.push(ch.to_ascii_lowercase());
        } else if ch.is_ascii_alphanumeric() {
            // Already a lowercase letter or a digit.
            out.push(ch);
        } else if !out.ends_with('_') {
            // Collapse any run of other characters into one separator.
            out.push('_');
        }
    }

    out.trim_matches('_').to_string()
}