use crate::auth_analysis::config::{AuthAnalysisRules, canonical_name, matches_name, strip_quotes};
use crate::auth_analysis::model::{
AnalysisUnit, AnalysisUnitKind, AuthCheck, AuthCheckKind, AuthorizationModel, CallSite,
Framework, HttpMethod, OperationKind, RouteRegistration, SensitiveOperation, SinkClass,
ValueRef, ValueSourceKind,
};
use crate::labels::bare_method_name;
use std::collections::{HashMap, HashSet};
use std::path::Path;
use tree_sitter::Node;
/// Records every top-level function-like definition of a parsed file as an
/// [`AnalysisUnit`] in `model`.
///
/// File-wide metadata ([`FileMeta`]) is scanned once up front and shared by
/// every unit built while walking the root's named children (and whatever
/// `collect_top_level_from_node` descends into from there).
pub fn collect_top_level_units(
    root: Node<'_>,
    bytes: &[u8],
    rules: &AuthAnalysisRules,
    model: &mut AuthorizationModel,
) {
    let meta = FileMeta::scan(root, bytes);
    let top_level = (0..root.named_child_count()).filter_map(|i| root.named_child(i as u32));
    for item in top_level {
        collect_top_level_from_node(item, bytes, rules, model, &meta);
    }
}
/// Recursively collects function-like analysis units from a top-level node
/// into `model.units`, dispatching on the node kind:
///
/// - function / method definitions (several grammars' node kinds) become
///   `AnalysisUnitKind::Function` units named via `function_name`;
/// - a `decorated_definition` wrapping a `function_definition` becomes a unit
///   unless `python_decorated_definition_is_background_task` reports it;
/// - `lexical_declaration` / `variable_declaration` declarators whose
///   initializer is function-like become units via
///   `function_unit_from_var_declarator`;
/// - container nodes (exports, program/module roots, class declarations and
///   bodies) are descended into;
/// - a Ruby `class` is descended into, but methods in its body that are
///   non-public or registered as filter callbacks are skipped.
///
/// Any other node kind is ignored.
fn collect_top_level_from_node(
    node: Node<'_>,
    bytes: &[u8],
    rules: &AuthAnalysisRules,
    model: &mut AuthorizationModel,
    file_meta: &FileMeta,
) {
    match node.kind() {
        // Function/method definition node kinds from the supported grammars.
        "function_declaration"
        | "function_definition"
        | "method_declaration"
        | "function_item"
        | "method"
        | "singleton_method" => {
            model.units.push(build_function_unit_with_meta(
                node,
                AnalysisUnitKind::Function,
                function_name(node, bytes),
                bytes,
                rules,
                Some(file_meta),
            ));
        }
        // Only decorated *functions* are units; other decorated definitions
        // fail the guard and fall through to the catch-all arm.
        "decorated_definition"
            if decorated_definition_child(node)
                .is_some_and(|definition| definition.kind() == "function_definition") =>
        {
            // Functions flagged as background tasks are not collected.
            if python_decorated_definition_is_background_task(node, bytes) {
                return;
            }
            model.units.push(build_function_unit_with_meta(
                node,
                AnalysisUnitKind::Function,
                function_name(node, bytes),
                bytes,
                rules,
                Some(file_meta),
            ));
        }
        // `const f = <function-like>`: each qualifying declarator is a unit.
        "lexical_declaration" | "variable_declaration" => {
            for idx in 0..node.named_child_count() {
                let Some(child) = node.named_child(idx as u32) else {
                    continue;
                };
                if child.kind() == "variable_declarator"
                    && let Some(unit) =
                        function_unit_from_var_declarator(child, bytes, rules, Some(file_meta))
                {
                    model.units.push(unit);
                }
            }
        }
        // Exported declarations are treated like their unexported form.
        "export_statement" => {
            for idx in 0..node.named_child_count() {
                let Some(child) = node.named_child(idx as u32) else {
                    continue;
                };
                if child.is_named() {
                    collect_top_level_from_node(child, bytes, rules, model, file_meta);
                }
            }
        }
        // Generic containers: recurse into every named child.
        "program" | "source_file" | "module" | "class_declaration" | "class_body"
        | "body_statement" => {
            for idx in 0..node.named_child_count() {
                let Some(child) = node.named_child(idx as u32) else {
                    continue;
                };
                collect_top_level_from_node(child, bytes, rules, model, file_meta);
            }
        }
        // Ruby class: compute visibility and callback names once for the
        // body, then filter its direct `method` children with them.
        "class" => {
            let body = node.child_by_field_name("body");
            let visibility = body
                .map(|b| ruby_method_visibility(b, bytes))
                .unwrap_or_default();
            let callbacks = body
                .map(|b| ruby_callback_target_names(b, bytes))
                .unwrap_or_default();
            for idx in 0..node.named_child_count() {
                let Some(child) = node.named_child(idx as u32) else {
                    continue;
                };
                if Some(child) == body {
                    for body_idx in 0..child.named_child_count() {
                        let Some(grand) = child.named_child(body_idx as u32) else {
                            continue;
                        };
                        if grand.kind() == "method" {
                            let name = function_name(grand, bytes).unwrap_or_default();
                            // Private/protected methods and filter callbacks
                            // are not collected as units.
                            if !name.is_empty()
                                && ruby_method_is_callback_or_private(
                                    &name,
                                    &visibility,
                                    &callbacks,
                                )
                            {
                                continue;
                            }
                        }
                        collect_top_level_from_node(grand, bytes, rules, model, file_meta);
                    }
                } else {
                    // Non-body children (e.g. the class name) only yield units
                    // if they match one of the arms above.
                    collect_top_level_from_node(child, bytes, rules, model, file_meta);
                }
            }
        }
        _ => {}
    }
}
/// Visibility of a Ruby method, as set by `public` / `protected` / `private`
/// directives inside a class body.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RubyVisibility {
    /// The visibility in effect until a directive changes it.
    Public,
    /// Set by a `protected` directive.
    Protected,
    /// Set by a `private` directive.
    Private,
}
pub fn ruby_method_visibility(
body: Node<'_>,
bytes: &[u8],
) -> std::collections::HashMap<String, RubyVisibility> {
use crate::auth_analysis::config::matches_name;
use std::collections::HashMap;
let mut map: HashMap<String, RubyVisibility> = HashMap::new();
let mut current = RubyVisibility::Public;
for child in named_children(body) {
match child.kind() {
"identifier" => {
if let Some(vis) = ruby_visibility_for_directive(text(child, bytes).trim()) {
current = vis;
}
}
"call" => {
let callee_full = call_name(child, bytes);
let callee = bare_method_name(&callee_full);
let Some(target_vis) = ruby_visibility_for_directive(callee) else {
continue;
};
let arguments = child.child_by_field_name("arguments");
let args: Vec<Node<'_>> = arguments
.map(|node| named_children(node))
.unwrap_or_default();
if args.is_empty() {
current = target_vis;
continue;
}
let mut targeted_any = false;
for arg in args {
for name in ruby_symbol_names(arg, bytes) {
if name.is_empty() {
continue;
}
map.insert(name, target_vis);
targeted_any = true;
}
if arg.kind() == "method"
&& let Some(name_node) = arg.child_by_field_name("name")
{
let name = text(name_node, bytes);
if !name.is_empty() {
map.insert(name, target_vis);
targeted_any = true;
}
}
}
if !targeted_any {
current = target_vis;
}
let _ = matches_name;
}
"method" => {
if let Some(name_node) = child.child_by_field_name("name") {
let name = text(name_node, bytes);
if !name.is_empty() {
map.insert(name, current);
}
}
}
_ => {}
}
}
map
}
/// Maps a Ruby visibility keyword (`private`, `protected`, `public`) to the
/// corresponding [`RubyVisibility`]; any other name yields `None`.
fn ruby_visibility_for_directive(name: &str) -> Option<RubyVisibility> {
    let vis = if name == "private" {
        RubyVisibility::Private
    } else if name == "protected" {
        RubyVisibility::Protected
    } else if name == "public" {
        RubyVisibility::Public
    } else {
        return None;
    };
    Some(vis)
}
pub fn ruby_callback_target_names(
body: Node<'_>,
bytes: &[u8],
) -> std::collections::HashSet<String> {
use std::collections::HashSet;
let mut targets: HashSet<String> = HashSet::new();
for child in named_children(body) {
if child.kind() != "call" {
continue;
}
let callee_full = call_name(child, bytes);
let callee = bare_method_name(&callee_full);
if !ruby_is_filter_callback_directive(callee) {
continue;
}
let Some(arguments) = child.child_by_field_name("arguments") else {
continue;
};
for arg in named_children(arguments) {
if arg.kind() == "pair" {
continue;
}
for name in ruby_symbol_names(arg, bytes) {
if name.is_empty() {
continue;
}
targets.insert(name);
}
}
}
targets
}
/// Returns `true` if `name` is a filter-callback directive such as
/// `before_action` or `skip_around_filter`.
///
/// Recognised names have the exact, case-sensitive shape
/// `[prefix]<timing><suffix>`, where the optional prefix is one of
/// `prepend_`, `append_` or `skip_`; the timing is `before`, `after` or
/// `around`; and the suffix is `_action` or `_filter`.
fn ruby_is_filter_callback_directive(name: &str) -> bool {
    const PREFIXES: [&str; 3] = ["prepend_", "append_", "skip_"];
    // At most one prefix can apply: none of them is a prefix of another.
    let unprefixed = PREFIXES
        .iter()
        .find_map(|prefix| name.strip_prefix(*prefix))
        .unwrap_or(name);
    let timing = unprefixed
        .strip_suffix("_action")
        .or_else(|| unprefixed.strip_suffix("_filter"));
    matches!(timing, Some("before" | "after" | "around"))
}
/// Extracts method names from a Ruby directive argument.
///
/// Symbols, hash-key symbols, identifiers and strings yield one name, with
/// surrounding quotes stripped (via `strip_quotes`) and leading `:`
/// characters removed. Arrays are flattened recursively. Every other node
/// kind yields nothing.
fn ruby_symbol_names(node: Node<'_>, bytes: &[u8]) -> Vec<String> {
    if node.kind() == "array" {
        let mut names = Vec::new();
        for element in named_children(node) {
            names.extend(ruby_symbol_names(element, bytes));
        }
        return names;
    }
    if !matches!(
        node.kind(),
        "simple_symbol" | "hash_key_symbol" | "identifier" | "string"
    ) {
        return Vec::new();
    }
    let raw = text(node, bytes);
    let name = strip_quotes(&raw).trim_start_matches(':').to_string();
    vec![name]
}
/// Returns `true` when the Ruby method `name` is non-public according to
/// `visibility` (methods absent from the map count as public) or is listed
/// in `callbacks`.
pub fn ruby_method_is_callback_or_private(
    name: &str,
    visibility: &std::collections::HashMap<String, RubyVisibility>,
    callbacks: &std::collections::HashSet<String>,
) -> bool {
    let non_public = matches!(
        visibility.get(name),
        Some(RubyVisibility::Private | RubyVisibility::Protected)
    );
    non_public || callbacks.contains(name)
}
/// Builds a `Function` unit for a `variable_declarator` whose initializer
/// satisfies `is_function_like`.
///
/// The unit spans the initializer (not the declarator) and is named after
/// the declared variable; a missing or empty name leaves the unit unnamed.
/// Returns `None` when there is no initializer or it is not function-like.
fn function_unit_from_var_declarator(
    node: Node<'_>,
    bytes: &[u8],
    rules: &AuthAnalysisRules,
    file_meta: Option<&FileMeta>,
) -> Option<AnalysisUnit> {
    let initializer = node
        .child_by_field_name("value")
        .filter(|value| is_function_like(*value))?;
    let declared_name = match node.child_by_field_name("name") {
        Some(name_node) => Some(text(name_node, bytes)).filter(|n| !n.is_empty()),
        None => None,
    };
    let unit = build_function_unit_with_meta(
        initializer,
        AnalysisUnitKind::Function,
        declared_name,
        bytes,
        rules,
        file_meta,
    );
    Some(unit)
}
/// A route handler that has been resolved to a unit in the authorization
/// model.
pub struct ResolvedHandler {
    /// Index of the handler's unit in `AuthorizationModel::units`.
    pub unit_idx: usize,
    /// Span of the handler node, as returned by `span`.
    pub span: (usize, usize),
    /// Handler parameter names, as computed by
    /// `function_params_route_handler`.
    pub params: Vec<String>,
    /// 1-based line on which the handler node starts.
    pub line: usize,
}
/// Calls `visit` on `node` and then on every named descendant, in
/// depth-first pre-order: a node before its children, children left to
/// right.
///
/// Uses an explicit work stack instead of recursion.
pub fn visit_named_nodes(node: Node<'_>, visit: &mut impl FnMut(Node<'_>)) {
    let mut pending = vec![node];
    while let Some(current) = pending.pop() {
        visit(current);
        // Push children in reverse so the leftmost one is popped next.
        pending.extend(named_children(current).into_iter().rev());
    }
}
/// Resolves `handler_expr` to its defining node and records it as the route
/// handler unit named `route_name`.
///
/// If a `Function` unit with the same span already exists (for instance one
/// collected by [`collect_top_level_units`]), it is promoted in place to a
/// `RouteHandler`, renamed, and its params replaced. Otherwise a new
/// `RouteHandler` unit is built and appended. In both cases the unit's params
/// are the route-handler params of the definition.
///
/// Returns `None` when the handler expression cannot be resolved.
pub fn attach_route_handler(
    root: Node<'_>,
    handler_expr: Node<'_>,
    route_name: String,
    bytes: &[u8],
    rules: &AuthAnalysisRules,
    model: &mut AuthorizationModel,
) -> Option<ResolvedHandler> {
    let handler_node = resolve_handler_node(root, handler_expr, bytes)?;
    let line = handler_node.start_position().row + 1;
    let handler_span = span(handler_node);
    let params = function_params_route_handler(function_definition_node(handler_node), bytes);

    if let Some((unit_idx, existing)) = model
        .units
        .iter_mut()
        .enumerate()
        .find(|(_, u)| u.kind == AnalysisUnitKind::Function && u.span == handler_span)
    {
        existing.kind = AnalysisUnitKind::RouteHandler;
        existing.name = Some(route_name);
        existing.params = params.clone();
        return Some(ResolvedHandler {
            unit_idx,
            span: handler_span,
            params,
            line,
        });
    }

    // File metadata is only needed to build a fresh unit. Scanning it here
    // (rather than up front) avoids a whole-tree walk on the promotion path.
    let file_meta = FileMeta::scan(root, bytes);
    let unit_idx = model.units.len();
    let mut unit = build_function_unit_with_meta(
        handler_node,
        AnalysisUnitKind::RouteHandler,
        Some(route_name),
        bytes,
        rules,
        Some(&file_meta),
    );
    unit.params = params.clone();
    model.units.push(unit);
    Some(ResolvedHandler {
        unit_idx,
        span: handler_span,
        params,
        line,
    })
}
/// Per-file facts gathered once and handed to every unit built from the
/// file.
#[derive(Default, Debug, Clone)]
pub struct FileMeta {
    /// Names collected by `scan_trpc_aliases_visit` — presumably local
    /// aliases of tRPC builders (TODO confirm against that helper).
    pub trpc_alias_names: HashSet<String>,
}
impl FileMeta {
    /// Builds the metadata for the tree rooted at `root` by running
    /// `scan_trpc_aliases_visit` from the root.
    pub fn scan(root: Node<'_>, bytes: &[u8]) -> Self {
        let mut meta = Self::default();
        scan_trpc_aliases_visit(root, bytes, &mut meta.trpc_alias_names);
        meta
    }
}
pub fn push_route_registration(
model: &mut AuthorizationModel,
framework: Framework,
method: HttpMethod,
path: String,
file: &Path,
handler: ResolvedHandler,
middleware_calls: Vec<CallSite>,
) {
model.routes.push(RouteRegistration {
framework,
method,
path,
middleware: middleware_names(&middleware_calls),
handler_span: handler.span,
handler_params: handler.params,
file: file.to_path_buf(),
line: handler.line,
unit_idx: handler.unit_idx,
middleware_calls,
});
}
/// Returns the callee name of each middleware call site, in order.
pub fn middleware_names(middleware_calls: &[CallSite]) -> Vec<String> {
    let mut names = Vec::with_capacity(middleware_calls.len());
    for site in middleware_calls {
        names.push(site.name.clone());
    }
    names
}
/// Resolves a route's handler expression to the node that defines it.
///
/// Function-like expressions are their own definition. An expression
/// accepted by `is_handler_reference` is reduced to the last `.`-separated
/// segment of its callee text (`controllers.list` → `list`) and looked up
/// among the file's top-level functions. Anything else, or an empty last
/// segment, yields `None`.
pub fn resolve_handler_node<'tree>(
    root: Node<'tree>,
    handler_expr: Node<'tree>,
    bytes: &[u8],
) -> Option<Node<'tree>> {
    if is_function_like(handler_expr) {
        return Some(handler_expr);
    }
    let referenced = is_handler_reference(handler_expr).then(|| callee_name(handler_expr, bytes))?;
    let last_segment = match referenced.rsplit_once('.') {
        Some((_, tail)) => tail,
        None => referenced.as_str(),
    };
    if last_segment.is_empty() {
        None
    } else {
        find_top_level_function_node(root, last_segment, bytes)
    }
}
/// Searches the named children of `root` for a function definition called
/// `name`, returning the first match in source order.
fn find_top_level_function_node<'tree>(
    root: Node<'tree>,
    name: &str,
    bytes: &[u8],
) -> Option<Node<'tree>> {
    (0..root.named_child_count())
        .filter_map(|i| root.named_child(i as u32))
        .find_map(|child| find_top_level_function_node_in_child(child, name, bytes))
}
/// Checks whether `node` defines a function called `name` and returns the
/// defining node.
///
/// - Named function/method definitions match on their own name.
/// - A `decorated_definition` matches when it wraps a `function_definition`
///   with that name; the decorated node itself is returned.
/// - In a `lexical_declaration` / `variable_declaration`, the first
///   declarator named `name` whose initializer is function-like yields the
///   initializer node.
/// - Exports, program/source-file roots and class declarations/bodies are
///   searched child by child; the first match wins.
fn find_top_level_function_node_in_child<'tree>(
    node: Node<'tree>,
    name: &str,
    bytes: &[u8],
) -> Option<Node<'tree>> {
    /// Named children of `parent`, in source order.
    fn kids<'t>(parent: Node<'t>) -> impl Iterator<Item = Node<'t>> {
        (0..parent.named_child_count()).filter_map(move |i| parent.named_child(i as u32))
    }
    let names_match =
        |candidate: Node<'tree>| function_name(candidate, bytes).as_deref() == Some(name);

    match node.kind() {
        "function_declaration" | "function_definition" | "method_declaration"
        | "function_item" => names_match(node).then_some(node),
        "decorated_definition" => {
            let inner = decorated_definition_child(node)?;
            (inner.kind() == "function_definition" && names_match(node)).then_some(node)
        }
        "lexical_declaration" | "variable_declaration" => kids(node)
            .filter(|decl| decl.kind() == "variable_declarator")
            .filter(|decl| {
                decl.child_by_field_name("name")
                    .is_some_and(|var_name| text(var_name, bytes) == name)
            })
            .filter_map(|decl| decl.child_by_field_name("value"))
            .find(|value| is_function_like(*value)),
        "export_statement" => kids(node)
            .filter(|child| child.is_named())
            .find_map(|child| find_top_level_function_node_in_child(child, name, bytes)),
        "program" | "source_file" | "class_declaration" | "class_body" => kids(node)
            .find_map(|child| find_top_level_function_node_in_child(child, name, bytes)),
        _ => None,
    }
}
pub fn build_function_unit(
node: Node<'_>,
kind: AnalysisUnitKind,
name: Option<String>,
bytes: &[u8],
rules: &AuthAnalysisRules,
) -> AnalysisUnit {
build_function_unit_with_meta(node, kind, name, bytes, rules, None)
}
pub fn build_function_unit_with_meta(
node: Node<'_>,
kind: AnalysisUnitKind,
name: Option<String>,
bytes: &[u8],
rules: &AuthAnalysisRules,
file_meta: Option<&FileMeta>,
) -> AnalysisUnit {
let definition = function_definition_node(node);
let params = function_params(definition, bytes);
let preseeded_bounded = python_int_bounded_typed_params(definition, bytes);
let line = node.start_position().row + 1;
let mut state = UnitState::default();
if let Some(receiver_name) = method_receiver_name(definition, bytes) {
state.non_sink_vars.insert(receiver_name);
}
if let Some(meta) = file_meta {
state.trpc_alias_names = meta.trpc_alias_names.clone();
}
collect_unit_state(node, bytes, rules, &mut state);
dedup_value_refs(&mut state.value_refs);
let context_inputs: Vec<ValueRef> = state
.value_refs
.iter()
.filter(|value| {
matches!(
value.source_kind,
ValueSourceKind::RequestParam
| ValueSourceKind::RequestBody
| ValueSourceKind::RequestQuery
| ValueSourceKind::Session
)
})
.cloned()
.collect();
AnalysisUnit {
kind,
name,
span: span(node),
params,
context_inputs,
call_sites: state.call_sites,
auth_checks: state.auth_checks,
operations: state.operations,
value_refs: state.value_refs,
condition_texts: state.condition_texts,
line,
row_field_vars: state.row_field_vars,
var_alias_chain: state.var_alias_chain,
row_population_data: state.row_population_data,
self_actor_vars: state.self_actor_vars,
self_actor_id_vars: state.self_actor_id_vars,
authorized_sql_vars: state.authorized_sql_vars,
const_bound_vars: state.const_bound_vars,
typed_bounded_vars: preseeded_bounded,
typed_bounded_dto_fields: std::collections::HashMap::new(),
self_scoped_session_bases: state.self_scoped_session_bases,
}
}
/// Accumulator filled by `collect_unit_state` while walking one unit's
/// subtree. Most fields are moved into the resulting [`AnalysisUnit`].
#[derive(Default)]
struct UnitState {
    /// Every call encountered, in visit order.
    call_sites: Vec<CallSite>,
    /// Authorization checks found in calls and conditions.
    auth_checks: Vec<AuthCheck>,
    /// Token lookups and sink operations.
    operations: Vec<SensitiveOperation>,
    /// Value references extracted from visited nodes.
    value_refs: Vec<ValueRef>,
    /// Source text of every recorded condition.
    condition_texts: Vec<String>,
    /// Variables passed to `classify_sink_class` as non-sink receivers.
    non_sink_vars: HashSet<String>,
    /// Variable → row receiver name (from `collect_row_field_binding`).
    row_field_vars: HashMap<String, String>,
    /// Variable → dotted member chain it was first bound to.
    var_alias_chain: HashMap<String, String>,
    /// Variable → (line of the call it was bound to, value refs of that
    /// call's arguments).
    row_population_data: HashMap<String, (usize, Vec<ValueRef>)>,
    /// Variables presumably bound to the current actor.
    self_actor_vars: HashSet<String>,
    /// Variables presumably bound to the current actor's id.
    self_actor_id_vars: HashSet<String>,
    /// Variables recorded by the SQL-authorization collectors.
    authorized_sql_vars: HashSet<String>,
    /// Variables assigned a pure literal somewhere in the unit.
    const_bound_vars: HashSet<String>,
    /// Dotted chains treated as self-scoped session bases.
    self_scoped_session_bases: HashSet<String>,
    /// tRPC alias names copied from the file metadata.
    trpc_alias_names: HashSet<String>,
}
/// Walks `node` and its whole subtree, accumulating per-unit facts into
/// `state`.
///
/// For each node, in order:
/// 1. kind-specific collectors run (calls, conditions, bindings, parameters);
/// 2. the node's value refs are appended to `state.value_refs`;
/// 3. every named child is visited recursively.
///
/// Step 2 runs at every depth, so the same reference may be appended more
/// than once. `build_function_unit_with_meta` runs `dedup_value_refs`
/// afterwards.
///
/// Some bindings reach `collect_const_string_binding` twice. For example, a
/// `lexical_declaration` delegates to its `variable_declarator` children,
/// which the recursion then visits again. That collector only inserts into
/// the `const_bound_vars` set, so the repeat is idempotent.
fn collect_unit_state(
    node: Node<'_>,
    bytes: &[u8],
    rules: &AuthAnalysisRules,
    state: &mut UnitState,
) {
    match node.kind() {
        // Call node kinds across the supported grammars.
        "call_expression" | "call" | "method_invocation" | "method_call_expression" => {
            collect_call(node, bytes, rules, state)
        }
        // Conditional/loop constructs that expose a `condition` field.
        "if_statement" | "elif_clause" | "while_statement" | "do_statement" | "if" | "unless"
        | "if_modifier" | "unless_modifier" | "while_modifier" | "until_modifier"
        | "while_expression" => {
            if let Some(condition) = node.child_by_field_name("condition") {
                collect_condition(condition, bytes, rules, state);
            }
        }
        "if_expression" => {
            if let Some(condition) = node.child_by_field_name("condition") {
                collect_condition(condition, bytes, rules, state);
            }
            // `if_expression` additionally gets the ownership-equality check.
            detect_ownership_equality_check(node, bytes, state);
        }
        // The whole expression is treated as the condition here.
        "conditional_expression" => collect_condition(node, bytes, rules, state),
        // `let` bindings feed every binding-based collector.
        "let_declaration" => {
            collect_non_sink_binding(node, bytes, rules, state);
            collect_row_field_binding(node, bytes, state);
            collect_member_alias_binding(node, bytes, state);
            collect_row_population(node, bytes, state);
            collect_self_actor_binding(node, bytes, rules, state);
            collect_self_actor_id_binding(node, bytes, state);
            collect_sql_authorized_binding(node, bytes, rules, state);
            propagate_sql_authorized_through_field_read(node, bytes, state);
            collect_const_string_binding(node, bytes, state);
        }
        "variable_declarator" => {
            collect_self_actor_binding(node, bytes, rules, state);
            collect_self_actor_id_binding(node, bytes, state);
            collect_const_string_binding(node, bytes, state);
        }
        // Declarations/assignments that may bind a literal; plain
        // assignments may also bind a call result (row population).
        "short_var_declaration"
        | "const_declaration"
        | "var_declaration"
        | "var_spec"
        | "lexical_declaration"
        | "local_variable_declaration"
        | "assignment"
        | "assignment_expression"
        | "augmented_assignment"
        | "expression_statement" => {
            collect_const_string_binding(node, bytes, state);
            if matches!(node.kind(), "assignment" | "assignment_expression") {
                collect_row_population(node, bytes, state);
            }
        }
        "for_expression" => {
            collect_for_row_binding(node, bytes, state);
        }
        "parameter" => {
            collect_typed_extractor_self_actor(node, bytes, state);
        }
        "required_parameter" | "optional_parameter" => {
            collect_trpc_ctx_param(node, bytes, state);
        }
        _ => {}
    }
    for value in extract_value_refs(node, bytes) {
        state.value_refs.push(value);
    }
    for idx in 0..node.named_child_count() {
        let Some(child) = node.named_child(idx as u32) else {
            continue;
        };
        collect_unit_state(child, bytes, rules, state);
    }
}
/// Records a call expression into `state`.
///
/// - Every call with a non-empty callee becomes a [`CallSite`].
/// - A callee accepted by `rules.is_authorization_check` also becomes an
///   [`AuthCheck`].
/// - A token lookup or a recognised sink also becomes a
///   [`SensitiveOperation`].
///
/// Check/operation subjects are the receiver's value refs followed by every
/// argument's value refs, in argument order.
///
/// Operation kind:
/// - token lookups → `TokenLookup` (no sink class);
/// - `DbCrossTenantRead` sinks → `Read`;
/// - `InMemoryLocal` sinks → `Mutation` for mutating callees, else `Read`;
/// - other sinks → `Read` only for non-mutating read callees, else `Mutation`.
fn collect_call(node: Node<'_>, bytes: &[u8], rules: &AuthAnalysisRules, state: &mut UnitState) {
    let callee = call_name(node, bytes);
    if callee.is_empty() {
        return;
    }
    let args = node
        .child_by_field_name("arguments")
        .map(named_children)
        .unwrap_or_default();
    let line = node.start_position().row + 1;
    let node_text = text(node, bytes);
    let string_args: Vec<String> = args.iter().map(|arg| text(*arg, bytes)).collect();
    // Extract each argument's value refs once; they feed both the call site
    // and the subject list below.
    let args_value_refs: Vec<Vec<ValueRef>> = args
        .iter()
        .map(|arg| extract_value_refs(*arg, bytes))
        .collect();
    let mut subjects: Vec<ValueRef> = call_receiver_subjects(node, bytes);
    subjects.extend(args_value_refs.iter().flatten().cloned());

    state.call_sites.push(CallSite {
        name: callee.clone(),
        args: string_args.clone(),
        span: span(node),
        args_value_refs,
    });
    if rules.is_authorization_check(&callee) {
        state.auth_checks.push(AuthCheck {
            kind: classify_auth_check(&callee, rules),
            callee: callee.clone(),
            subjects: subjects.clone(),
            span: span(node),
            line,
            args: string_args,
            condition_text: None,
            is_route_level: false,
        });
    }
    let (op_kind, sink_class) = if rules.is_token_lookup_call(&callee, &node_text) {
        (Some(OperationKind::TokenLookup), None)
    } else if let Some(class) = rules.classify_sink_class(&callee, &state.non_sink_vars) {
        let op = match class {
            SinkClass::DbCrossTenantRead => OperationKind::Read,
            SinkClass::InMemoryLocal => {
                if rules.is_mutation(&callee) {
                    OperationKind::Mutation
                } else {
                    OperationKind::Read
                }
            }
            _ => {
                if rules.is_read(&callee) && !rules.is_mutation(&callee) {
                    OperationKind::Read
                } else {
                    OperationKind::Mutation
                }
            }
        };
        (Some(op), Some(class))
    } else {
        (None, None)
    };
    if let Some(kind) = op_kind {
        state.operations.push(SensitiveOperation {
            kind,
            sink_class,
            callee,
            subjects,
            span: span(node),
            line,
            text: node_text,
        });
    }
}
/// Records a condition expression and any token-validation checks it
/// implies.
///
/// A non-empty condition's source text is appended to `condition_texts`.
/// Up to two checks may then be recorded, in this order:
/// - a `TokenExpiry` check, when `rules.has_expiry_field` accepts the text;
/// - a `TokenRecipient` check, when `rules.has_recipient_field` accepts it.
///
/// Both checks use the pseudo-callee `"(condition)"` and carry the
/// condition's value refs as subjects.
fn collect_condition(
    node: Node<'_>,
    bytes: &[u8],
    rules: &AuthAnalysisRules,
    state: &mut UnitState,
) {
    let condition_text = text(node, bytes);
    if condition_text.is_empty() {
        return;
    }
    state.condition_texts.push(condition_text.clone());
    let subjects = extract_value_refs(node, bytes);
    let line = node.start_position().row + 1;
    let detected = [
        (AuthCheckKind::TokenExpiry, rules.has_expiry_field(&condition_text)),
        (AuthCheckKind::TokenRecipient, rules.has_recipient_field(&condition_text)),
    ];
    for (kind, matched) in detected {
        if !matched {
            continue;
        }
        state.auth_checks.push(AuthCheck {
            kind,
            callee: "(condition)".into(),
            subjects: subjects.clone(),
            span: span(node),
            line,
            args: Vec::new(),
            condition_text: Some(condition_text.clone()),
            is_route_level: false,
        });
    }
}
/// Adds a `let` binding's variable to `non_sink_vars` in either of two cases:
/// - its declared type is accepted by `rules.is_non_sink_receiver_type`;
/// - its initializer is accepted by `value_is_non_sink_constructor`.
///
/// The type check runs first and short-circuits the initializer check.
fn collect_non_sink_binding(
    node: Node<'_>,
    bytes: &[u8],
    rules: &AuthAnalysisRules,
    state: &mut UnitState,
) {
    let Some(var_name) = node
        .child_by_field_name("pattern")
        .and_then(|pattern| first_identifier_name(pattern, bytes))
        .filter(|name| !name.is_empty())
    else {
        return;
    };
    let typed_non_sink = node
        .child_by_field_name("type")
        .is_some_and(|ty| rules.is_non_sink_receiver_type(&text(ty, bytes)));
    if typed_non_sink
        || node
            .child_by_field_name("value")
            .is_some_and(|value| value_is_non_sink_constructor(value, bytes, rules))
    {
        state.non_sink_vars.insert(var_name);
    }
}
/// Returns the text of the first identifier-like node found in a pre-order
/// walk of `node` (starting with `node` itself). Identifier-like nodes with
/// empty text are skipped.
fn first_identifier_name(node: Node<'_>, bytes: &[u8]) -> Option<String> {
    let identifier_like = matches!(
        node.kind(),
        "identifier"
            | "shorthand_property_identifier_pattern"
            | "instance_variable"
            | "class_variable"
            | "global_variable"
    );
    identifier_like
        .then(|| text(node, bytes))
        .filter(|value| !value.is_empty())
        .or_else(|| {
            (0..node.named_child_count())
                .filter_map(|i| node.named_child(i as u32))
                .find_map(|child| first_identifier_name(child, bytes))
        })
}
/// Returns `true` when an initializer constructs a non-sink value. That is:
/// - a call whose callee is accepted by `rules.is_non_sink_constructor_callee`;
/// - a `vec!` / `smallvec!` macro (any path prefix);
/// - either of the above nested inside a `try_expression`,
///   `await_expression` or `reference_expression`.
fn value_is_non_sink_constructor(node: Node<'_>, bytes: &[u8], rules: &AuthAnalysisRules) -> bool {
    let kind = node.kind();
    if matches!(
        kind,
        "call_expression" | "call" | "method_invocation" | "method_call_expression"
    ) {
        return rules.is_non_sink_constructor_callee(&call_name(node, bytes));
    }
    if kind == "macro_invocation" {
        let macro_path = node
            .child_by_field_name("macro")
            .map(|m| text(m, bytes))
            .unwrap_or_default();
        let macro_name = macro_path.rsplit("::").next().unwrap_or(macro_path.as_str());
        return macro_name == "vec" || macro_name == "smallvec";
    }
    if matches!(kind, "try_expression" | "await_expression" | "reference_expression") {
        return (0..node.named_child_count())
            .filter_map(|i| node.named_child(i as u32))
            .any(|child| value_is_non_sink_constructor(child, bytes, rules));
    }
    false
}
/// For a `let` binding whose initializer has a row receiver (as reported by
/// `extract_row_receiver_name`), maps the bound variable to that receiver
/// name in `row_field_vars`.
fn collect_row_field_binding(node: Node<'_>, bytes: &[u8], state: &mut UnitState) {
    let Some(var_name) = node
        .child_by_field_name("pattern")
        .and_then(|pattern| first_identifier_name(pattern, bytes))
        .filter(|name| !name.is_empty())
    else {
        return;
    };
    let Some(row_name) = node
        .child_by_field_name("value")
        .and_then(|value| extract_row_receiver_name(value, bytes))
    else {
        return;
    };
    state.row_field_vars.insert(var_name, row_name);
}
/// Records a `let` binding whose initializer is a member/field access as an
/// alias of that access chain.
///
/// The initializer is first passed through `unwrap_try_like`, and the chain
/// must have at least two segments. The chain is stored dot-joined (e.g.
/// `req.body.id`) in `var_alias_chain`. The first alias recorded for a
/// variable wins.
fn collect_member_alias_binding(node: Node<'_>, bytes: &[u8], state: &mut UnitState) {
    let Some(var_name) = node
        .child_by_field_name("pattern")
        .and_then(|pattern| first_identifier_name(pattern, bytes))
        .filter(|name| !name.is_empty())
    else {
        return;
    };
    let Some(value) = node.child_by_field_name("value") else {
        return;
    };
    let access = unwrap_try_like(value);
    let is_member_access = matches!(
        access.kind(),
        "member_expression"
            | "attribute"
            | "selector_expression"
            | "field_expression"
            | "field_access"
    );
    if !is_member_access {
        return;
    }
    let segments = member_chain(access, bytes);
    if segments.len() >= 2 {
        state
            .var_alias_chain
            .entry(var_name)
            .or_insert_with(|| segments.join("."));
    }
}
fn collect_row_population(node: Node<'_>, bytes: &[u8], state: &mut UnitState) {
let Some(pattern) = node
.child_by_field_name("pattern")
.or_else(|| node.child_by_field_name("left"))
else {
return;
};
let Some(var_name) = first_identifier_name(pattern, bytes) else {
return;
};
if var_name.is_empty() {
return;
}
let Some(value) = node
.child_by_field_name("value")
.or_else(|| node.child_by_field_name("right"))
else {
return;
};
let call_node = unwrap_try_like(value);
if !matches!(
call_node.kind(),
"call_expression" | "call" | "method_invocation" | "method_call_expression"
) {
return;
}
let args = call_node
.child_by_field_name("arguments")
.map(named_children)
.unwrap_or_default();
let mut arg_refs: Vec<ValueRef> = Vec::new();
for arg in args {
arg_refs.extend(extract_value_refs(arg, bytes));
}
let call_line = call_node.start_position().row + 1;
state
.row_population_data
.insert(var_name, (call_line, arg_refs));
}
fn collect_self_actor_binding(
node: Node<'_>,
bytes: &[u8],
rules: &AuthAnalysisRules,
state: &mut UnitState,
) {
let Some(pattern) = node
.child_by_field_name("pattern")
.or_else(|| node.child_by_field_name("name"))
else {
return;
};
let Some(value) = node.child_by_field_name("value") else {
return;
};
if pattern.kind() == "object_pattern" {
collect_destructured_self_actor_binding(pattern, value, bytes, rules, state);
return;
}
let Some(var_name) = first_identifier_name(pattern, bytes) else {
return;
};
if var_name.is_empty() {
return;
}
if value_is_self_actor_call(value, bytes, rules) {
state.self_actor_vars.insert(var_name);
}
}
/// Handles object-destructuring bindings such as
/// `const { user } = await getServerSession()`.
///
/// Each entry of `pattern` yields a `(key, local)` pair:
/// - shorthand `{ a }` gives `("a", "a")`;
/// - defaulted `{ a = x }` gives `("a", "a")`;
/// - renamed `{ a: b }` gives `("a", "b")`.
///
/// Entries are routed by the right-hand side's classification (see
/// `classify_destructure_rhs` / `process_destructure_entry`). Separately,
/// `lookup_trpc_ctx_destructure_match` may report that the right-hand side's
/// `.user` is a known self-scoped session base. In that case, a `user` key
/// (ASCII case-insensitive) marks its local as a self-actor variable.
fn collect_destructured_self_actor_binding(
    pattern: Node<'_>,
    value: Node<'_>,
    bytes: &[u8],
    rules: &AuthAnalysisRules,
    state: &mut UnitState,
) {
    let kind = classify_destructure_rhs(value, bytes, rules);
    // Only the presence of a match matters; the matched chain text is unused.
    let is_trpc_ctx = lookup_trpc_ctx_destructure_match(value, bytes, state).is_some();
    if kind == DestructureRhsKind::None && !is_trpc_ctx {
        return;
    }
    for idx in 0..pattern.named_child_count() {
        let Some(child) = pattern.named_child(idx as u32) else {
            continue;
        };
        let (key, local) = match child.kind() {
            "shorthand_property_identifier_pattern" => {
                let name = text(child, bytes);
                (name.clone(), name)
            }
            "object_assignment_pattern" => {
                let Some(left) = child.child_by_field_name("left") else {
                    continue;
                };
                let name = if matches!(
                    left.kind(),
                    "identifier" | "shorthand_property_identifier_pattern"
                ) {
                    text(left, bytes)
                } else {
                    first_identifier_name(left, bytes).unwrap_or_default()
                };
                (name.clone(), name)
            }
            "pair_pattern" => {
                let key_node = child.child_by_field_name("key");
                let local_node = child.child_by_field_name("value");
                let (Some(k), Some(v)) = (key_node, local_node) else {
                    continue;
                };
                let key = text(k, bytes);
                let local = first_identifier_name(v, bytes).unwrap_or_default();
                (key, local)
            }
            _ => continue,
        };
        if kind != DestructureRhsKind::None {
            process_destructure_entry(&key, &local, kind, state);
        }
        if is_trpc_ctx && key.eq_ignore_ascii_case("user") && !local.is_empty() {
            state.self_actor_vars.insert(local);
        }
    }
}
/// Classification of the right-hand side of an object-destructuring binding,
/// as produced by `classify_destructure_rhs`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum DestructureRhsKind {
    /// A self-actor call or session provider; `user`-like keys are actors.
    SessionContainer,
    /// A self-scoped session base chain; `id`-like keys are actor ids.
    SelfActorBase,
    /// Not recognised.
    None,
}
/// Returns the dotted chain text `X` of `node` when `X.user` is one of the
/// unit's self-scoped session bases.
///
/// Returns `None` when no bases are known, `node` has no (non-empty) chain
/// text, or `X.user` is not a known base.
fn lookup_trpc_ctx_destructure_match(
    node: Node<'_>,
    bytes: &[u8],
    state: &UnitState,
) -> Option<String> {
    if state.self_scoped_session_bases.is_empty() {
        return None;
    }
    chain_text_from_value(node, bytes)
        .filter(|chain| !chain.is_empty())
        .filter(|chain| {
            state
                .self_scoped_session_bases
                .contains(&format!("{chain}.user"))
        })
}
/// Returns the dotted access-chain text of a value expression.
///
/// - An identifier yields its text.
/// - A member/field access or scoped identifier yields its `member_chain`
///   joined with `.`.
/// - For wrappers (casts, parentheses, non-null, `await`, `?`), the
///   `value`/`expression` field is tried first, then every named child in
///   order.
///
/// Returns `None` when nothing yields non-empty text.
fn chain_text_from_value(node: Node<'_>, bytes: &[u8]) -> Option<String> {
    let kind = node.kind();
    if kind == "identifier" {
        return Some(text(node, bytes)).filter(|t| !t.is_empty());
    }
    if matches!(
        kind,
        "field_expression" | "member_expression" | "field_access" | "scoped_identifier"
    ) {
        let segments = member_chain(node, bytes);
        return (!segments.is_empty()).then(|| segments.join("."));
    }
    let is_wrapper = matches!(
        kind,
        "type_cast_expression"
            | "as_expression"
            | "cast_expression"
            | "parenthesized_expression"
            | "non_null_expression"
            | "await_expression"
            | "try_expression"
    );
    if !is_wrapper {
        return None;
    }
    node.child_by_field_name("value")
        .or_else(|| node.child_by_field_name("expression"))
        .and_then(|inner| chain_text_from_value(inner, bytes))
        .or_else(|| {
            (0..node.named_child_count())
                .filter_map(|i| node.named_child(i as u32))
                .find_map(|child| chain_text_from_value(child, bytes))
        })
}
/// Classifies the right-hand side of an object-destructuring binding.
///
/// - Self-actor calls and session providers → `SessionContainer`.
/// - Self-scoped session base chains → `SelfActorBase`.
/// - Anything else → `None`.
///
/// The checks run in that order.
fn classify_destructure_rhs(
    node: Node<'_>,
    bytes: &[u8],
    rules: &AuthAnalysisRules,
) -> DestructureRhsKind {
    if value_is_self_actor_call(node, bytes, rules) || value_is_session_provider_chain(node, bytes)
    {
        DestructureRhsKind::SessionContainer
    } else if value_is_self_actor_base_chain(node, bytes) {
        DestructureRhsKind::SelfActorBase
    } else {
        DestructureRhsKind::None
    }
}
/// Applies one destructured `(key, local)` entry according to `kind`.
///
/// Keys are compared ASCII case-insensitively:
/// - from a `SessionContainer`, keys `user` / `currentuser` / `current_user`
///   add `local` to `self_actor_vars`;
/// - from a `SelfActorBase`, keys `id` / `userid` / `user_id` / `uid` add
///   `local` to `self_actor_id_vars`.
///
/// Empty keys or locals are ignored.
fn process_destructure_entry(
    key: &str,
    local: &str,
    kind: DestructureRhsKind,
    state: &mut UnitState,
) {
    if key.is_empty() || local.is_empty() {
        return;
    }
    let normalized = key.to_ascii_lowercase();
    let target = match kind {
        DestructureRhsKind::SessionContainer
            if matches!(normalized.as_str(), "user" | "currentuser" | "current_user") =>
        {
            &mut state.self_actor_vars
        }
        DestructureRhsKind::SelfActorBase
            if matches!(normalized.as_str(), "id" | "userid" | "user_id" | "uid") =>
        {
            &mut state.self_actor_id_vars
        }
        _ => return,
    };
    target.insert(local.to_owned());
}
/// Returns `true` when `node` evaluates to a session container. That is:
/// - a known session access chain (`ctx.session`, `ctx.state`,
///   `req.session`, `request.session`, `session`);
/// - the bare identifier `session`;
/// - a call to a known session getter (`getServerSession`, `getSession`,
///   `getServerSideSession`, `unstable_getServerSession`);
/// - any of the above inside a cast, parentheses, non-null, `await` or `?`
///   wrapper.
fn value_is_session_provider_chain(node: Node<'_>, bytes: &[u8]) -> bool {
    const PROVIDER_CHAINS: [&str; 5] = [
        "ctx.session",
        "ctx.state",
        "req.session",
        "request.session",
        "session",
    ];
    const PROVIDER_CALLS: [&str; 4] = [
        "getServerSession",
        "getSession",
        "getServerSideSession",
        "unstable_getServerSession",
    ];
    match node.kind() {
        "field_expression" | "member_expression" | "field_access" | "scoped_identifier" => {
            let segments = member_chain(node, bytes);
            if segments.is_empty() {
                return false;
            }
            let joined = segments.join(".");
            PROVIDER_CHAINS.iter().any(|candidate| joined == *candidate)
        }
        "identifier" => text(node, bytes) == "session",
        "call_expression" | "call" => {
            let callee = call_name(node, bytes);
            let last = bare_method_name(&callee);
            PROVIDER_CALLS.iter().any(|candidate| *candidate == last)
        }
        "type_cast_expression"
        | "as_expression"
        | "cast_expression"
        | "parenthesized_expression"
        | "non_null_expression"
        | "await_expression"
        | "try_expression" => {
            let via_field = node
                .child_by_field_name("value")
                .or_else(|| node.child_by_field_name("expression"))
                .is_some_and(|inner| value_is_session_provider_chain(inner, bytes));
            via_field
                || (0..node.named_child_count())
                    .filter_map(|i| node.named_child(i as u32))
                    .any(|child| value_is_session_provider_chain(child, bytes))
        }
        _ => false,
    }
}
/// Returns `true` when `node` is a member/field access chain whose dot-joined
/// text is accepted by `is_self_scoped_session_base_text`.
///
/// Such a chain may also sit inside a cast, parentheses, non-null, `await`
/// or `?` wrapper. For wrappers, the `value`/`expression` field is tried
/// first, then every named child.
fn value_is_self_actor_base_chain(node: Node<'_>, bytes: &[u8]) -> bool {
    let kind = node.kind();
    if matches!(
        kind,
        "field_expression" | "member_expression" | "field_access" | "scoped_identifier"
    ) {
        let segments = member_chain(node, bytes);
        return !segments.is_empty() && is_self_scoped_session_base_text(&segments.join("."));
    }
    let is_wrapper = matches!(
        kind,
        "type_cast_expression"
            | "as_expression"
            | "cast_expression"
            | "parenthesized_expression"
            | "non_null_expression"
            | "await_expression"
            | "try_expression"
    );
    if !is_wrapper {
        return false;
    }
    let via_field = node
        .child_by_field_name("value")
        .or_else(|| node.child_by_field_name("expression"))
        .is_some_and(|inner| value_is_self_actor_base_chain(inner, bytes));
    via_field
        || (0..node.named_child_count())
            .filter_map(|i| node.named_child(i as u32))
            .any(|child| value_is_self_actor_base_chain(child, bytes))
}
/// Adds variables bound directly to a pure literal (per
/// `rhs_is_pure_literal`) to `state.const_bound_vars`.
///
/// Shapes handled, checked in this order:
/// 1. `assignment` / `assignment_expression` / `augmented_assignment`: every
///    identifier on the left is recorded when the right side is a literal.
/// 2. `short_var_declaration` / `var_spec` / `const_spec` (tree-sitter-go
///    node kinds): left identifiers are paired positionally with right-hand
///    expressions, and each one paired with a literal is recorded.
/// 3. `var_declaration` / `const_declaration`: each named child is processed
///    recursively.
/// 4. Nodes with a `pattern`/`name` field and a `value` field (e.g. `let`
///    bindings, declarators): recorded when the value is a literal.
/// 5. Otherwise, declarator/spec/assignment children are processed
///    recursively.
fn collect_const_string_binding(node: Node<'_>, bytes: &[u8], state: &mut UnitState) {
    // Shape 1: assignments. Field names differ between grammars, so several
    // candidates are tried for each side.
    if matches!(
        node.kind(),
        "assignment" | "assignment_expression" | "augmented_assignment"
    ) {
        let lhs = node
            .child_by_field_name("left")
            .or_else(|| node.child_by_field_name("name"))
            .or_else(|| node.child_by_field_name("target"));
        let rhs = node
            .child_by_field_name("right")
            .or_else(|| node.child_by_field_name("value"));
        // NOTE(review): for `augmented_assignment` (e.g. `x += "lit"`) this
        // records `x` although its new value also depends on its old value —
        // confirm this is intended (e.g. literal-only string building).
        if let (Some(lhs), Some(rhs)) = (lhs, rhs)
            && rhs_is_pure_literal(rhs)
        {
            for var in collect_lhs_idents(lhs, bytes) {
                state.const_bound_vars.insert(var);
            }
        }
        return;
    }
    // Shape 2: parallel LHS/RHS lists; the i-th identifier is bound to the
    // i-th expression.
    if matches!(
        node.kind(),
        "short_var_declaration" | "var_spec" | "const_spec"
    ) {
        // NOTE(review): if a spec repeats the `name` field (several names),
        // `child_by_field_name` returns only the first one.
        let left = node.child_by_field_name("left").or_else(|| {
            node.child_by_field_name("name")
        });
        let right = node.child_by_field_name("right").or_else(|| {
            node.child_by_field_name("value")
                .or_else(|| node.child_by_field_name("default"))
        });
        if let (Some(left), Some(right)) = (left, right) {
            let lhs_idents = collect_lhs_idents(left, bytes);
            // Split an `expression_list` into its expressions, dropping comma
            // and parenthesis tokens so indices line up with the LHS.
            let rhs_exprs: Vec<Node<'_>> = if right.kind() == "expression_list" {
                let mut cursor = right.walk();
                right
                    .children(&mut cursor)
                    .filter(|c| !matches!(c.kind(), "," | "(" | ")"))
                    .collect()
            } else {
                vec![right]
            };
            for (idx, var) in lhs_idents.into_iter().enumerate() {
                if let Some(expr) = rhs_exprs.get(idx)
                    && rhs_is_pure_literal(*expr)
                {
                    state.const_bound_vars.insert(var);
                }
            }
        }
        return;
    }
    // Shape 3: declaration groups — process each child spec.
    if matches!(node.kind(), "var_declaration" | "const_declaration") {
        for idx in 0..node.named_child_count() {
            if let Some(child) = node.named_child(idx as u32) {
                collect_const_string_binding(child, bytes, state);
            }
        }
        return;
    }
    // Shape 4: single binding with `pattern`/`name` and `value` fields.
    let pattern = node
        .child_by_field_name("pattern")
        .or_else(|| node.child_by_field_name("name"));
    let value = node.child_by_field_name("value");
    if let (Some(pattern), Some(value)) = (pattern, value)
        && rhs_is_pure_literal(value)
    {
        for var in collect_lhs_idents(pattern, bytes) {
            state.const_bound_vars.insert(var);
        }
        return;
    }
    // Shape 5: wrappers (e.g. lexical declarations, expression statements)
    // — recurse into binding-like children only.
    for idx in 0..node.named_child_count() {
        let Some(child) = node.named_child(idx as u32) else {
            continue;
        };
        if matches!(
            child.kind(),
            "variable_declarator"
                | "init_declarator"
                | "var_spec"
                | "const_spec"
                | "assignment"
                | "assignment_expression"
        ) {
            collect_const_string_binding(child, bytes, state);
        }
    }
}
/// Returns `true` when `node` is a literal whose value is fixed in the source:
/// a string/char/number/boolean/nil literal, or a string/template with no
/// interpolated expressions.
///
/// Parentheses, casts, `as` expressions, and `&` references are peeled first,
/// however deeply nested. Plain `string` nodes are also checked for
/// interpolation, because some grammars (e.g. Python f-strings) expose
/// `interpolation` children under an ordinary `string` node.
fn rhs_is_pure_literal(node: Node<'_>) -> bool {
    let mut inner = node;
    while matches!(
        inner.kind(),
        "parenthesized_expression"
            | "type_cast_expression"
            | "as_expression"
            | "cast_expression"
            | "reference_expression"
    ) {
        // Prefer the grammar's `value`/`expression` field. Fall back to the
        // first named child for wrappers that expose no field (e.g. parens).
        let Some(next) = inner
            .child_by_field_name("value")
            .or_else(|| inner.child_by_field_name("expression"))
            .or_else(|| inner.named_child(0))
        else {
            break;
        };
        inner = next;
    }
    match inner.kind() {
        "string_literal"
        | "raw_string_literal"
        | "interpreted_string_literal"
        | "rune_literal"
        | "integer_literal"
        | "int_literal"
        | "float_literal"
        | "true"
        | "false"
        | "boolean_literal"
        | "nil"
        | "null"
        | "null_literal"
        | "none"
        | "character_literal" => true,
        "string" | "template_string" | "template_literal" => !template_has_interpolation(inner),
        _ => false,
    }
}
/// Returns `true` if any direct named child of `node` is an interpolation
/// (`template_substitution`, `interpolation`, or `string_interpolation`).
fn template_has_interpolation(node: Node<'_>) -> bool {
    (0..node.named_child_count())
        .filter_map(|idx| node.named_child(idx as u32))
        .any(|child| {
            matches!(
                child.kind(),
                "template_substitution" | "interpolation" | "string_interpolation"
            )
        })
}
/// Collects the identifier names bound by an assignment target.
///
/// A bare `identifier` yields itself. Otherwise, direct `identifier` children
/// are taken, and tuple/list/pattern-list/expression-list children (plus
/// field-identifier leaves) are searched recursively.
fn collect_lhs_idents(node: Node<'_>, bytes: &[u8]) -> Vec<String> {
    if node.kind() == "identifier" {
        return vec![text(node, bytes)];
    }
    let mut names = Vec::new();
    let children = (0..node.named_child_count()).filter_map(|idx| node.named_child(idx as u32));
    for child in children {
        match child.kind() {
            "identifier" => names.push(text(child, bytes)),
            "tuple_pattern"
            | "expression_list"
            | "pattern_list"
            | "list_pattern"
            | "field_identifier"
            | "shorthand_field_identifier" => names.extend(collect_lhs_idents(child, bytes)),
            _ => {}
        }
    }
    names
}
/// If `node` binds a variable (via its `pattern` or `name` field) to the id of
/// the current actor, records that variable in `state.self_actor_id_vars`.
///
/// "Id of the current actor" means either an id-like field read off a known
/// self-actor variable or an id-like field on a self-scoped session chain.
fn collect_self_actor_id_binding(node: Node<'_>, bytes: &[u8], state: &mut UnitState) {
    let binding = node
        .child_by_field_name("pattern")
        .or_else(|| node.child_by_field_name("name"));
    let Some(bound_name) = binding
        .and_then(|target| first_identifier_name(target, bytes))
        .filter(|name| !name.is_empty())
    else {
        return;
    };
    let Some(rhs) = node.child_by_field_name("value") else {
        return;
    };
    let reads_actor_id = value_is_self_actor_id_field(rhs, bytes, &state.self_actor_vars)
        || value_is_self_scoped_session_id_chain(rhs, bytes);
    if reads_actor_id {
        state.self_actor_id_vars.insert(bound_name);
    }
}
/// Returns `true` when `node` reads an id-like field (see
/// `is_self_actor_id_field_name`) directly off a variable in `actor_vars`,
/// e.g. `user.id` where `user` is a known self-actor binding.
///
/// Cast, parenthesis, `?`, `await`, `&`, and non-null (`x!`) wrappers are
/// looked through. For call nodes (e.g. `user.id.to_string()`), both the
/// callee and the callee's receiver are checked.
fn value_is_self_actor_id_field(
    node: Node<'_>,
    bytes: &[u8],
    actor_vars: &HashSet<String>,
) -> bool {
    match node.kind() {
        "field_expression" | "member_expression" | "field_access" | "scoped_identifier" => {
            let receiver = node
                .child_by_field_name("value")
                .or_else(|| node.child_by_field_name("object"));
            let field = node
                .child_by_field_name("field")
                .or_else(|| node.child_by_field_name("property"))
                .or_else(|| node.child_by_field_name("name"));
            let (Some(receiver), Some(field)) = (receiver, field) else {
                return false;
            };
            let receiver_name = text(receiver, bytes);
            let field_name = text(field, bytes);
            actor_vars.contains(&receiver_name) && is_self_actor_id_field_name(&field_name)
        }
        // `non_null_expression` (TS `user.id!`) is unwrapped as well, matching
        // the wrapper set accepted by the session-id chain check.
        "type_cast_expression"
        | "as_expression"
        | "cast_expression"
        | "parenthesized_expression"
        | "try_expression"
        | "await_expression"
        | "reference_expression"
        | "non_null_expression" => {
            let value = node
                .child_by_field_name("value")
                .or_else(|| node.child_by_field_name("expression"));
            if let Some(v) = value
                && value_is_self_actor_id_field(v, bytes, actor_vars)
            {
                return true;
            }
            for idx in 0..node.named_child_count() {
                let Some(child) = node.named_child(idx as u32) else {
                    continue;
                };
                if value_is_self_actor_id_field(child, bytes, actor_vars) {
                    return true;
                }
            }
            false
        }
        "call_expression" | "call" | "method_invocation" | "method_call_expression" => {
            let receiver = node
                .child_by_field_name("function")
                .or_else(|| node.child_by_field_name("object"));
            if let Some(r) = receiver {
                if value_is_self_actor_id_field(r, bytes, actor_vars) {
                    return true;
                }
                // `user.id.to_string()`: the callee is `user.id.to_string`,
                // so also test its receiver `user.id`.
                if let Some(inner) = r
                    .child_by_field_name("value")
                    .or_else(|| r.child_by_field_name("object"))
                    && value_is_self_actor_id_field(inner, bytes, actor_vars)
                {
                    return true;
                }
            }
            false
        }
        _ => false,
    }
}
/// Returns `true` for field names that identify the acting user, compared
/// ASCII case-insensitively: `id`, `user_id`, `userid`, `uid`, `email`,
/// `username`, `handle`.
fn is_self_actor_id_field_name(field: &str) -> bool {
    const ID_FIELDS: [&str; 7] = ["id", "user_id", "userid", "uid", "email", "username", "handle"];
    ID_FIELDS
        .iter()
        .any(|candidate| field.eq_ignore_ascii_case(candidate))
}
/// Returns `true` when `node` reads an id-like field (see
/// `is_self_actor_id_field_name`) off a self-scoped session base, i.e. the
/// shape `req.user.id`.
///
/// For member chains, the last segment must be id-like. The remaining base
/// segments must both classify as `ValueSourceKind::Session` (via
/// `classify_member_chain`) and be listed by
/// `is_self_scoped_session_base_text`.
///
/// Cast, parenthesis, `?`, `await`, `&`, and non-null wrappers are looked
/// through. For call nodes (e.g. `req.user.id.toString()`), the callee and
/// the callee's receiver are checked.
fn value_is_self_scoped_session_id_chain(node: Node<'_>, bytes: &[u8]) -> bool {
    match node.kind() {
        "field_expression" | "member_expression" | "field_access" | "scoped_identifier" => {
            let chain = member_chain(node, bytes);
            // Need at least `<base>.<field>`.
            if chain.len() < 2 {
                return false;
            }
            let field = chain.last().expect("len >= 2");
            if !is_self_actor_id_field_name(field) {
                return false;
            }
            // Everything except the trailing id field is the session base.
            let base_chain = &chain[..chain.len() - 1];
            let base = base_chain.join(".");
            classify_member_chain(base_chain) == ValueSourceKind::Session
                && is_self_scoped_session_base_text(&base)
        }
        "type_cast_expression"
        | "as_expression"
        | "cast_expression"
        | "parenthesized_expression"
        | "try_expression"
        | "await_expression"
        | "reference_expression"
        | "non_null_expression" => {
            // Try the wrapped expression via its field first, then fall back
            // to scanning every named child.
            let value = node
                .child_by_field_name("value")
                .or_else(|| node.child_by_field_name("expression"));
            if let Some(v) = value
                && value_is_self_scoped_session_id_chain(v, bytes)
            {
                return true;
            }
            for idx in 0..node.named_child_count() {
                let Some(child) = node.named_child(idx as u32) else {
                    continue;
                };
                if value_is_self_scoped_session_id_chain(child, bytes) {
                    return true;
                }
            }
            false
        }
        "call_expression" | "call" | "method_invocation" | "method_call_expression" => {
            let receiver = node
                .child_by_field_name("function")
                .or_else(|| node.child_by_field_name("object"));
            if let Some(r) = receiver {
                if value_is_self_scoped_session_id_chain(r, bytes) {
                    return true;
                }
                // `req.user.id.toString()`: the callee ends in `toString`, so
                // also test the callee's receiver (`req.user.id`).
                if let Some(inner) = r
                    .child_by_field_name("value")
                    .or_else(|| r.child_by_field_name("object"))
                    && value_is_self_scoped_session_id_chain(inner, bytes)
                {
                    return true;
                }
            }
            false
        }
        _ => false,
    }
}
/// Returns `true` when `base` is a dotted session path that always refers to
/// the requesting user, such as `req.user` or `ctx.session.user`.
/// The comparison is exact and case-sensitive.
fn is_self_scoped_session_base_text(base: &str) -> bool {
    const SELF_SCOPED_BASES: &[&str] = &[
        "req.session.user",
        "request.session.user",
        "session.user",
        "req.session.currentUser",
        "request.session.currentUser",
        "session.currentUser",
        "req.user",
        "request.user",
        "req.currentUser",
        "request.currentUser",
        "ctx.session.user",
        "ctx.session.currentUser",
        "ctx.state.user",
        "ctx.state.currentUser",
    ];
    SELF_SCOPED_BASES.contains(&base)
}
/// Returns `true` when `node` is (or wraps) a call to a configured login
/// guard or authorization check, whose result therefore identifies the
/// current actor.
///
/// `?`, `await`, `&`, parentheses, and `match` expressions are searched
/// through their named children.
fn value_is_self_actor_call(node: Node<'_>, bytes: &[u8], rules: &AuthAnalysisRules) -> bool {
    let kind = node.kind();
    if matches!(
        kind,
        "call_expression" | "call" | "method_invocation" | "method_call_expression"
    ) {
        let callee = call_name(node, bytes);
        if callee.is_empty() {
            return false;
        }
        return rules.is_login_guard(&callee) || rules.is_authorization_check(&callee);
    }
    let searches_children = matches!(
        kind,
        "try_expression"
            | "await_expression"
            | "reference_expression"
            | "parenthesized_expression"
            | "match_expression"
    );
    searches_children
        && (0..node.named_child_count())
            .filter_map(|idx| node.named_child(idx as u32))
            .any(|child| value_is_self_actor_call(child, bytes, rules))
}
/// If `node` declares a variable whose `type` annotation names a self-actor
/// extractor type (see `is_self_actor_type_text`), records the variable in
/// `state.self_actor_vars`.
fn collect_typed_extractor_self_actor(node: Node<'_>, bytes: &[u8], state: &mut UnitState) {
    let Some(bound_name) = node
        .child_by_field_name("pattern")
        .and_then(|pattern| first_identifier_name(pattern, bytes))
        .filter(|name| !name.is_empty())
    else {
        return;
    };
    let Some(annotation) = node.child_by_field_name("type") else {
        return;
    };
    if is_self_actor_type_text(&text(annotation, bytes)) {
        state.self_actor_vars.insert(bound_name);
    }
}
/// Recognises `let x = <chain ending in an ACL-authorized SQL query>` and
/// records the binding as authorized.
///
/// On a match, `x` is added to `state.authorized_sql_vars`. A `Membership`
/// auth check with callee `"(sql ACL)"` is also pushed; its subjects are the
/// query's bind arguments plus `x` itself.
fn collect_sql_authorized_binding(
    node: Node<'_>,
    bytes: &[u8],
    rules: &AuthAnalysisRules,
    state: &mut UnitState,
) {
    let analysis_enabled = !rules.acl_tables.is_empty() || sql_direct_user_id_enabled();
    if !analysis_enabled {
        return;
    }
    let Some(bound_name) = node
        .child_by_field_name("pattern")
        .and_then(|pattern| first_identifier_name(pattern, bytes))
        .filter(|name| !name.is_empty())
    else {
        return;
    };
    let Some(initializer) = node.child_by_field_name("value") else {
        return;
    };
    let Some((sql_call, mut subjects)) =
        find_authorized_sql_call_in_chain(initializer, bytes, rules)
    else {
        return;
    };
    state.authorized_sql_vars.insert(bound_name.clone());
    subjects.push(ValueRef {
        source_kind: ValueSourceKind::Identifier,
        name: bound_name,
        base: None,
        field: None,
        index: None,
        span: span(node),
    });
    state.auth_checks.push(AuthCheck {
        kind: AuthCheckKind::Membership,
        callee: "(sql ACL)".into(),
        subjects,
        span: span(sql_call),
        line: node.start_position().row + 1,
        args: Vec::new(),
        condition_text: None,
        is_route_level: false,
    });
}
/// Feature switch for SQL authorization analysis that does not depend on
/// configured ACL tables. It is currently always on.
const fn sql_direct_user_id_enabled() -> bool {
    true
}
/// Walks a method-call chain from its outermost call toward its receiver,
/// looking for a query-style call (see `is_sql_prepare_method`). That call's
/// first argument must be a string literal that
/// `sql_semantics::classify_sql_query` accepts against `rules.acl_tables`.
///
/// On success, returns the matching call node together with value references
/// extracted from every non-string-literal argument seen along the walk,
/// including the matching call's own arguments.
///
/// Returns `None` when:
/// - a non-call node is reached;
/// - the first query-style call's SQL is not a recognised literal or is not
///   classified;
/// - no receiver can be found;
/// - 16 steps pass without a match.
fn find_authorized_sql_call_in_chain<'tree>(
    value: Node<'tree>,
    bytes: &[u8],
    rules: &AuthAnalysisRules,
) -> Option<(Node<'tree>, Vec<ValueRef>)> {
    let mut bind_arg_refs: Vec<ValueRef> = Vec::new();
    // `?` / `await` / `&` / parentheses are peeled at every step.
    let mut cur = unwrap_try_like(value);
    let mut steps = 0;
    // Bounded walk: at most 16 calls along the chain are inspected.
    while steps < 16 {
        steps += 1;
        if !matches!(
            cur.kind(),
            "call_expression" | "call" | "method_invocation" | "method_call_expression"
        ) {
            return None;
        }
        // Gather bind arguments from this call; string literals (the SQL text
        // itself) are skipped.
        if let Some(args_node) = cur.child_by_field_name("arguments") {
            for arg in named_children(args_node) {
                if matches!(
                    arg.kind(),
                    "string_literal" | "raw_string_literal" | "string"
                ) {
                    continue;
                }
                bind_arg_refs.extend(extract_value_refs(arg, bytes));
            }
        }
        let callee = call_name(cur, bytes);
        let last_segment = bare_method_name(&callee);
        // The first query-style method decides the outcome: either its literal
        // SQL is classified as authorized, or the whole chain is rejected.
        if is_sql_prepare_method(last_segment) {
            let args = cur
                .child_by_field_name("arguments")
                .map(named_children)
                .unwrap_or_default();
            if let Some(first_arg) = args.first().copied()
                && let Some(literal) = collect_string_literal_text(first_arg, bytes)
                && crate::auth_analysis::sql_semantics::classify_sql_query(
                    &literal,
                    &rules.acl_tables,
                )
                .is_some()
            {
                return Some((cur, bind_arg_refs));
            }
            return None;
        }
        // Step to the receiver. Try a direct `receiver` field, then the
        // callee's object/operand/argument/value field, then a direct
        // `object` field.
        let next = cur
            .child_by_field_name("receiver")
            .or_else(|| {
                cur.child_by_field_name("function").and_then(|fun| {
                    fun.child_by_field_name("object")
                        .or_else(|| fun.child_by_field_name("operand"))
                        .or_else(|| fun.child_by_field_name("argument"))
                        .or_else(|| fun.child_by_field_name("value"))
                })
            })
            .or_else(|| cur.child_by_field_name("object"));
        let next = next?;
        cur = unwrap_try_like(next);
    }
    None
}
/// Returns `true` for method names that take SQL text as their first argument
/// (`prepare`, `query*`, `fetch*`, `execute`, `exec`). The match is exact and
/// case-sensitive.
fn is_sql_prepare_method(method: &str) -> bool {
    const SQL_ENTRY_POINTS: &[&str] = &[
        "prepare",
        "query",
        "query_one",
        "query_all",
        "query_as",
        "query_map",
        "query_row",
        "query_scalar",
        "fetch",
        "fetch_one",
        "fetch_all",
        "fetch_optional",
        "fetch_scalar",
        "execute",
        "exec",
    ];
    SQL_ENTRY_POINTS.iter().any(|name| *name == method)
}
/// Returns the textual value of a string-literal node, or `None` for any other
/// node kind.
///
/// For `string_literal` / `raw_string_literal`, the `string_content` pieces are
/// concatenated, with `escape_sequence` children decoded in between (see
/// `push_decoded_string_escape`). So `"SELECT id\nFROM t"` yields
/// `SELECT id<newline>FROM t` rather than `SELECT idFROM t`.
///
/// If the literal has no `string_content` child, the quote-stripped source
/// text is returned. `string` / `template_string` /
/// `interpreted_string_literal` nodes also return quote-stripped source text.
fn collect_string_literal_text(node: Node<'_>, bytes: &[u8]) -> Option<String> {
    match node.kind() {
        "string_literal" | "raw_string_literal" => {
            let mut buf = String::new();
            let mut found = false;
            for child in named_children(node) {
                match child.kind() {
                    "string_content" => {
                        buf.push_str(&text(child, bytes));
                        found = true;
                    }
                    // Escapes sit between content pieces; dropping them glues
                    // adjacent SQL keywords together.
                    "escape_sequence" => {
                        push_decoded_string_escape(&mut buf, &text(child, bytes));
                    }
                    _ => {}
                }
            }
            if found {
                Some(buf)
            } else {
                Some(strip_quotes(&text(node, bytes)))
            }
        }
        "string" | "template_string" | "interpreted_string_literal" => {
            Some(strip_quotes(&text(node, bytes)))
        }
        _ => None,
    }
}

/// Appends the character an escape sequence stands for to `buf`.
///
/// - Common single-character escapes are decoded exactly.
/// - A backslash-newline line continuation contributes nothing.
/// - Any other escape (`\u{..}`, `\x..`, ...) becomes one space, so
///   neighbouring words stay separated.
fn push_decoded_string_escape(buf: &mut String, escape: &str) {
    match escape {
        "\\n" => buf.push('\n'),
        "\\r" => buf.push('\r'),
        "\\t" => buf.push('\t'),
        "\\\\" => buf.push('\\'),
        "\\\"" => buf.push('"'),
        "\\'" => buf.push('\''),
        continuation
            if continuation.starts_with("\\\n") || continuation.starts_with("\\\r\n") => {}
        _ => buf.push(' '),
    }
}
/// For a `for row in rows.iter()`-style loop, remembers that `row` came from
/// `rows` (in `state.row_field_vars`). If `rows` is an authorized SQL result,
/// `row` is marked authorized too.
fn collect_for_row_binding(node: Node<'_>, bytes: &[u8], state: &mut UnitState) {
    let Some(row_var) = node
        .child_by_field_name("pattern")
        .and_then(|pattern| first_identifier_name(pattern, bytes))
        .filter(|name| !name.is_empty())
    else {
        return;
    };
    let Some(source_var) = node
        .child_by_field_name("value")
        .and_then(|iterable| single_iter_source_name(iterable, bytes))
    else {
        return;
    };
    let source_is_authorized = state.authorized_sql_vars.contains(&source_var);
    state.row_field_vars.insert(row_var.clone(), source_var);
    if source_is_authorized {
        state.authorized_sql_vars.insert(row_var);
    }
}
/// Resolves a loop iterable to the single variable it iterates over.
///
/// Accepts a bare identifier, a reference or parenthesised expression around
/// one, or an iterator-adapter call (`iter`, `iter_mut`, `into_iter`,
/// `values`, `keys`, `drain`) on one. Returns `None` otherwise.
fn single_iter_source_name(node: Node<'_>, bytes: &[u8]) -> Option<String> {
    const ITER_ADAPTERS: &[&str] = &["iter", "iter_mut", "into_iter", "values", "keys", "drain"];
    match node.kind() {
        "identifier" => Some(text(node, bytes)).filter(|name| !name.is_empty()),
        "reference_expression" | "parenthesized_expression" => (0..node.named_child_count())
            .filter_map(|idx| node.named_child(idx as u32))
            .find_map(|child| single_iter_source_name(child, bytes)),
        "call_expression" | "call" | "method_invocation" | "method_call_expression" => {
            let callee = call_name(node, bytes);
            if !ITER_ADAPTERS.contains(&bare_method_name(&callee)) {
                return None;
            }
            let via_function = || {
                node.child_by_field_name("function").and_then(|fun| {
                    fun.child_by_field_name("object")
                        .or_else(|| fun.child_by_field_name("operand"))
                        .or_else(|| fun.child_by_field_name("argument"))
                        .or_else(|| fun.child_by_field_name("value"))
                })
            };
            let receiver = node
                .child_by_field_name("receiver")
                .or_else(via_function)
                .or_else(|| node.child_by_field_name("object"))?;
            single_iter_source_name(receiver, bytes)
        }
        _ => None,
    }
}
/// For `let x = row.field` (or `row.method()`), marks `x` as authorized when
/// the single-identifier receiver `row` is already in
/// `state.authorized_sql_vars`.
fn propagate_sql_authorized_through_field_read(
    node: Node<'_>,
    bytes: &[u8],
    state: &mut UnitState,
) {
    let Some(bound_name) = node
        .child_by_field_name("pattern")
        .and_then(|pattern| first_identifier_name(pattern, bytes))
        .filter(|name| !name.is_empty())
    else {
        return;
    };
    let Some(receiver) = node
        .child_by_field_name("value")
        .and_then(|rhs| extract_row_receiver_name(rhs, bytes))
    else {
        return;
    };
    if state.authorized_sql_vars.contains(&receiver) {
        state.authorized_sql_vars.insert(bound_name);
    }
}
/// Returns `true` when a type annotation names an extractor that always
/// yields the current actor. That is `Authenticated`, `Identity`,
/// `Principal`, or a `CurrentUser`-style name accepted by
/// `matches_self_actor_user_form`.
///
/// References (`&T`, `& mut T`), lifetimes (`&'a T`), and path prefixes are
/// stripped. Generic arguments are cut off *before* the path is split, so
/// `crate::auth::Authenticated<crate::User>` resolves to `Authenticated`
/// rather than `User>`.
fn is_self_actor_type_text(ty: &str) -> bool {
    let mut trimmed = ty.trim().trim_start_matches('&').trim_start();
    // Skip an explicit lifetime such as `'a` in `&'a CurrentUser`.
    if let Some(after_tick) = trimmed.strip_prefix('\'') {
        trimmed = after_tick
            .split_once(char::is_whitespace)
            .map(|(_, rest)| rest)
            .unwrap_or("")
            .trim_start();
    }
    let trimmed = trimmed.trim_start_matches("mut ").trim();
    // Drop generic arguments first so `::` inside them cannot be mistaken
    // for the outer path separator.
    let head = trimmed.split('<').next().unwrap_or(trimmed).trim();
    let base = head.rsplit("::").next().unwrap_or(head).trim();
    if matches!(base, "Authenticated" | "Identity" | "Principal") {
        return true;
    }
    matches_self_actor_user_form(base)
}
/// Returns `true` for type names shaped `<Prefix>User` or
/// `<Prefix>User<Suffix>`.
///
/// Prefix is one of `Local`, `Current`, `Session`, `Authenticated`, `Auth`,
/// `LoggedIn`, `Admin`. Suffix is one of `View`, `Info`, `Context`,
/// `Session`, `Token`. Examples: `CurrentUser`, `AuthUserContext`.
fn matches_self_actor_user_form(base: &str) -> bool {
    const PREFIXES: &[&str] = &[
        "Local",
        "Current",
        "Session",
        "Authenticated",
        "Auth",
        "LoggedIn",
        "Admin",
    ];
    const SUFFIXES: &[&str] = &["View", "Info", "Context", "Session", "Token"];
    PREFIXES
        .iter()
        .filter_map(|prefix| base.strip_prefix(prefix))
        .filter_map(|rest| rest.strip_prefix("User"))
        .any(|tail| tail.is_empty() || SUFFIXES.contains(&tail))
}
/// Returns the single-identifier receiver of a field read or method call
/// (after peeling `?`/`await`/`&`/parentheses), e.g. `row` in `row.id` or
/// `row.get("id")`.
fn extract_row_receiver_name(node: Node<'_>, bytes: &[u8]) -> Option<String> {
    let node = unwrap_try_like(node);
    let kind = node.kind();
    if matches!(
        kind,
        "call_expression" | "call" | "method_invocation" | "method_call_expression"
    ) {
        let callee = node
            .child_by_field_name("function")
            .or_else(|| node.child_by_field_name("method"))?;
        return single_ident_receiver(callee, bytes)
            .or_else(|| single_ident_from_call_receiver(node, bytes));
    }
    let is_field_read = matches!(
        kind,
        "field_expression"
            | "member_expression"
            | "attribute"
            | "selector_expression"
            | "field_access"
    );
    if is_field_read {
        single_ident_receiver(node, bytes)
    } else {
        None
    }
}
/// Returns the name of `node`'s object when it is a plain identifier.
/// The object is the first present of the `value`, `object`, `operand`, and
/// `receiver` fields.
fn single_ident_receiver(node: Node<'_>, bytes: &[u8]) -> Option<String> {
    ["value", "object", "operand", "receiver"]
        .into_iter()
        .find_map(|field| node.child_by_field_name(field))
        .and_then(|object| single_ident_text(object, bytes))
}
/// Returns the call's `receiver` (or, failing that, `object`) name when it
/// is a plain identifier.
fn single_ident_from_call_receiver(node: Node<'_>, bytes: &[u8]) -> Option<String> {
    let receiver = match node.child_by_field_name("receiver") {
        Some(receiver) => receiver,
        None => node.child_by_field_name("object")?,
    };
    single_ident_text(receiver, bytes)
}
/// Returns the non-empty source text of `node` when it is an identifier-like
/// leaf (`identifier`, `shorthand_property_identifier`, `field_identifier`).
fn single_ident_text(node: Node<'_>, bytes: &[u8]) -> Option<String> {
    let is_ident = matches!(
        node.kind(),
        "identifier" | "shorthand_property_identifier" | "field_identifier"
    );
    if !is_ident {
        return None;
    }
    let name = text(node, bytes);
    (!name.is_empty()).then_some(name)
}
/// Strips any stack of `?`, `await`, `&`, and parenthesis wrappers.
///
/// Each wrapper is followed through its `expression` field, or its first
/// named child when that field is absent. Stops at the first non-wrapper, or
/// at a wrapper with no inner node.
fn unwrap_try_like(node: Node<'_>) -> Node<'_> {
    let is_wrapper = matches!(
        node.kind(),
        "try_expression"
            | "await_expression"
            | "reference_expression"
            | "parenthesized_expression"
    );
    if !is_wrapper {
        return node;
    }
    match node
        .child_by_field_name("expression")
        .or_else(|| node.named_child(0))
    {
        Some(inner) => unwrap_try_like(inner),
        None => node,
    }
}
/// Recognises an inline ownership guard of the form
/// `if owner_field != self_actor_id { return ... }`, or the `==` form whose
/// `else` branch exits. Records it as an `AuthCheckKind::Ownership` check.
///
/// Both operand orders are accepted. The failing branch must contain an
/// early exit:
/// - for `!=` / `!==` / `ne`, the consequence;
/// - for `==` / `===` / `eq`, the `else` branch.
///
/// When the owner field was read from a loop row variable
/// (`state.row_field_vars`), the row variable is added as a subject. If that
/// row has recorded population data (`state.row_population_data`), its line
/// and argument refs are used as well.
fn detect_ownership_equality_check(if_node: Node<'_>, bytes: &[u8], state: &mut UnitState) {
    let Some(condition_raw) = if_node.child_by_field_name("condition") else {
        return;
    };
    let Some(consequence) = if_node.child_by_field_name("consequence") else {
        return;
    };
    let alternative = if_node.child_by_field_name("alternative");
    let condition = unwrap_parens_local(condition_raw);
    if condition.kind() != "binary_expression" {
        return;
    }
    let Some(operator) = binary_operator_text(condition, bytes) else {
        return;
    };
    // JS/TS ownership checks are usually written with strict (in)equality,
    // so accept `!==`/`===` alongside the loose and word forms.
    let is_ne = matches!(operator.as_str(), "!=" | "!==" | "ne");
    let is_eq = matches!(operator.as_str(), "==" | "===" | "eq");
    if !is_ne && !is_eq {
        return;
    }
    let Some((left, right)) = binary_operands(condition) else {
        return;
    };
    // The branch taken when the ids do NOT match must bail out; otherwise the
    // comparison is not acting as a guard.
    let fail_branch = if is_ne {
        consequence
    } else if let Some(alt) = alternative {
        resolve_else_block(alt)
    } else {
        return;
    };
    if !branch_has_early_exit(fail_branch) {
        return;
    }
    let left_refs = extract_value_refs(left, bytes);
    let right_refs = extract_value_refs(right, bytes);
    // Accept either operand order: `owner == actor` or `actor == owner`.
    let (owner_ref, _self_ref) = match (
        pick_owner_field_ref(&left_refs),
        pick_self_actor_ref(&right_refs),
    ) {
        (Some(o), Some(s)) => (o, s),
        _ => match (
            pick_owner_field_ref(&right_refs),
            pick_self_actor_ref(&left_refs),
        ) {
            (Some(o), Some(s)) => (o, s),
            _ => return,
        },
    };
    let row_binding = state.row_field_vars.get(&owner_ref.name).cloned();
    let if_line = if_node.start_position().row + 1;
    let if_span = span(if_node);
    let condition_text = text(condition, bytes);
    let (check_line, mut subjects) = match row_binding
        .as_ref()
        .and_then(|row| state.row_population_data.get(row).map(|v| (row, v)))
    {
        Some((row, (row_line, arg_refs))) => {
            let mut subjects = arg_refs.clone();
            subjects.push(ValueRef {
                source_kind: ValueSourceKind::Identifier,
                name: row.clone(),
                base: None,
                field: None,
                index: None,
                span: if_span,
            });
            (*row_line, subjects)
        }
        None => match row_binding.as_ref() {
            Some(row) => (
                if_line,
                vec![ValueRef {
                    source_kind: ValueSourceKind::Identifier,
                    name: row.clone(),
                    base: None,
                    field: None,
                    index: None,
                    span: if_span,
                }],
            ),
            None => (if_line, Vec::new()),
        },
    };
    subjects.push(owner_ref);
    state.auth_checks.push(AuthCheck {
        kind: AuthCheckKind::Ownership,
        callee: "(row ownership equality)".into(),
        subjects,
        span: if_span,
        line: check_line,
        args: Vec::new(),
        condition_text: Some(condition_text),
        is_route_level: false,
    });
}
fn unwrap_parens_local(node: Node<'_>) -> Node<'_> {
if node.kind() == "parenthesized_expression"
&& let Some(inner) = node.named_child(0)
{
return unwrap_parens_local(inner);
}
node
}
/// Returns the operator token of a binary expression.
///
/// Uses the `operator` field when present and non-empty; otherwise returns
/// the first non-empty anonymous child.
fn binary_operator_text(node: Node<'_>, bytes: &[u8]) -> Option<String> {
    let from_field = node
        .child_by_field_name("operator")
        .map(|op| text(op, bytes))
        .filter(|op| !op.is_empty());
    if from_field.is_some() {
        return from_field;
    }
    let mut cursor = node.walk();
    for token in node.children(&mut cursor).filter(|child| !child.is_named()) {
        let candidate = text(token, bytes);
        if !candidate.is_empty() {
            return Some(candidate);
        }
    }
    None
}
/// Returns the `(left, right)` operands of a binary expression.
///
/// Uses the `left`/`right` fields when both exist; otherwise accepts a node
/// with exactly two named children.
fn binary_operands<'tree>(node: Node<'tree>) -> Option<(Node<'tree>, Node<'tree>)> {
    let by_field = node
        .child_by_field_name("left")
        .zip(node.child_by_field_name("right"));
    if by_field.is_some() {
        return by_field;
    }
    match named_children(node).as_slice() {
        [lhs, rhs] => Some((*lhs, *rhs)),
        _ => None,
    }
}
/// For an `else_clause`, returns its first named child (the else body).
/// Any other node, or an empty clause, is returned unchanged.
fn resolve_else_block(alt: Node<'_>) -> Node<'_> {
    if alt.kind() != "else_clause" {
        return alt;
    }
    named_children(alt).into_iter().next().unwrap_or(alt)
}
fn branch_has_early_exit(branch: Node<'_>) -> bool {
named_children(branch).into_iter().any(node_is_early_exit)
}
/// Returns `true` when `node` leaves the current function: a `return`, a
/// `throw`/`raise`, or an expression statement wrapping one (e.g. Rust's
/// `return Err(..);`).
fn node_is_early_exit(node: Node<'_>) -> bool {
    match node.kind() {
        "return_expression" | "return_statement" => true,
        // Throwing or raising exits a handler just as surely as returning.
        "throw_statement" | "raise_statement" => true,
        "expression_statement" => named_children(node).into_iter().any(node_is_early_exit),
        _ => false,
    }
}
/// Returns `true` when `subject` names an owner-like field such as `userId`,
/// `owner_id`, or `createdBy`.
///
/// The raw name depends on the subject kind:
/// - array-index subjects use `base`, else `name`;
/// - all others use `field`, else `base`, else `name`.
///
/// The raw name is normalised with `canonical_name` and compared against a
/// fixed key list.
pub(super) fn is_owner_field_subject(subject: &ValueRef) -> bool {
    const OWNER_KEYS: &[&str] = &[
        "userid",
        "ownerid",
        "authorid",
        "createdby",
        "uploaderid",
        "updatedby",
        "submittedby",
        "assignedto",
        "creatorid",
        "postedby",
    ];
    let raw = if matches!(subject.source_kind, ValueSourceKind::ArrayIndex) {
        subject.base.as_deref().unwrap_or(&subject.name)
    } else {
        subject
            .field
            .as_deref()
            .or(subject.base.as_deref())
            .unwrap_or(&subject.name)
    };
    let key = canonical_name(raw);
    OWNER_KEYS.contains(&key.as_str())
}
/// Returns `true` when `subject` denotes the current actor.
///
/// That is either a session value rooted at a self-scoped session base, or
/// an `id` field (case-insensitive) read from a base whose last segment is
/// `user`, `current_user`, `currentUser`, `actor`, or `current_actor`.
pub(super) fn is_self_actor_subject(subject: &ValueRef) -> bool {
    let is_self_session = subject.source_kind == ValueSourceKind::Session
        && subject
            .base
            .as_deref()
            .is_some_and(is_self_session_base_local);
    if is_self_session {
        return true;
    }
    let reads_id = subject
        .field
        .as_deref()
        .is_some_and(|field| field.eq_ignore_ascii_case("id"));
    if !reads_id {
        return false;
    }
    subject.base.as_deref().is_some_and(|base| {
        let tail = base.rsplit('.').next().unwrap_or(base);
        matches!(
            tail,
            "user" | "current_user" | "currentUser" | "actor" | "current_actor"
        )
    })
}
fn is_self_session_base_local(base: &str) -> bool {
matches!(
base,
"req.session.user"
| "request.session.user"
| "session.user"
| "req.session.currentUser"
| "request.session.currentUser"
| "session.currentUser"
| "req.user"
| "request.user"
| "req.currentUser"
| "request.currentUser"
| "ctx.session.user"
| "ctx.session.currentUser"
| "ctx.state.user"
| "ctx.state.currentUser"
)
}
fn pick_owner_field_ref(refs: &[ValueRef]) -> Option<ValueRef> {
refs.iter().find(|v| is_owner_field_subject(v)).cloned()
}
fn pick_self_actor_ref(refs: &[ValueRef]) -> Option<ValueRef> {
refs.iter().find(|v| is_self_actor_subject(v)).cloned()
}
/// Maps a guard callee name to its `AuthCheckKind`. Checks are applied in
/// order:
/// 1. admin guard (rules or `isAdmin`);
/// 2. login guard;
/// 3. membership helper names;
/// 4. ownership helper names;
/// 5. otherwise `Other`.
fn classify_auth_check(callee: &str, rules: &AuthAnalysisRules) -> AuthCheckKind {
    const MEMBERSHIP_NAMES: &[&str] = &[
        "checkMembership",
        "hasWorkspaceMembership",
        "isMember",
        "requireMembership",
        "check_membership",
        "has_membership",
        "has_membership?",
        "require_membership",
        "ensure_membership",
        "member_of?",
        "member?",
    ];
    const OWNERSHIP_NAMES: &[&str] = &[
        "checkOwnership",
        "isOwner",
        "requireOwnership",
        "check_ownership",
        "has_ownership",
        "require_ownership",
        "ensure_ownership",
        "is_owner",
        "owner?",
        "owns?",
    ];
    let matches_any = |names: &[&str]| names.iter().any(|name| matches_name(callee, name));
    if rules.is_admin_guard(callee, &[]) || matches_name(callee, "isAdmin") {
        return AuthCheckKind::AdminGuard;
    }
    if rules.is_login_guard(callee) {
        return AuthCheckKind::LoginGuard;
    }
    if matches_any(MEMBERSHIP_NAMES) {
        return AuthCheckKind::Membership;
    }
    if matches_any(OWNERSHIP_NAMES) {
        return AuthCheckKind::Ownership;
    }
    AuthCheckKind::Other
}
/// Returns the non-empty `name` of the function definition underlying `node`
/// (as located by `function_definition_node`).
pub fn function_name(node: Node<'_>, bytes: &[u8]) -> Option<String> {
    let name_node = function_definition_node(node).child_by_field_name("name")?;
    let name = text(name_node, bytes);
    (!name.is_empty()).then_some(name)
}
fn python_decorated_definition_is_background_task(node: Node<'_>, bytes: &[u8]) -> bool {
for idx in 0..node.named_child_count() {
let Some(child) = node.named_child(idx as u32) else {
continue;
};
if child.kind() != "decorator" {
continue;
}
let Some(inner) = child.named_child(0) else {
continue;
};
let callee_text = match inner.kind() {
"call" => {
let Some(function) = inner.child_by_field_name("function") else {
continue;
};
text(function, bytes)
}
"identifier" | "attribute" | "scoped_identifier" => text(inner, bytes),
_ => continue,
};
let last = callee_text.rsplit('.').next().unwrap_or(&callee_text);
if matches!(
last,
"task" | "shared_task" | "periodic_task" | "instrumented_task" | "receiver"
) {
return true;
}
}
false
}
/// Returns the parameter names of `node`, with untyped id-like Python
/// parameters filtered out (see `collect_param_names`).
fn function_params(node: Node<'_>, bytes: &[u8]) -> Vec<String> {
    let mut names = Vec::new();
    if let Some(parameters) = node.child_by_field_name("parameters") {
        collect_param_names(parameters, bytes, false, &mut names);
    }
    names
}
/// Returns the parameter names of a route handler `node`, keeping id-like
/// typed parameters (see `collect_param_names`).
pub fn function_params_route_handler(node: Node<'_>, bytes: &[u8]) -> Vec<String> {
    let mut names = Vec::new();
    if let Some(parameters) = node.child_by_field_name("parameters") {
        collect_param_names(parameters, bytes, true, &mut names);
    }
    names
}
/// Returns the names of typed Python parameters whose annotation is
/// integer-bounded (see `python_type_text_is_integer_bounded`).
///
/// For each `typed_parameter` / `typed_default_parameter`, the name is its
/// first non-empty `identifier` child and the annotation is its last `type`
/// child.
fn python_int_bounded_typed_params(node: Node<'_>, bytes: &[u8]) -> HashSet<String> {
    let Some(params_node) = node.child_by_field_name("parameters") else {
        return HashSet::new();
    };
    let mut bounded = HashSet::new();
    let params = (0..params_node.named_child_count())
        .filter_map(|idx| params_node.named_child(idx as u32));
    for param in params {
        if !matches!(param.kind(), "typed_parameter" | "typed_default_parameter") {
            continue;
        }
        let mut param_name: Option<String> = None;
        let mut annotation: Option<String> = None;
        let parts = (0..param.named_child_count()).filter_map(|idx| param.named_child(idx as u32));
        for part in parts {
            match part.kind() {
                "identifier" if param_name.is_none() => {
                    let candidate = text(part, bytes);
                    if !candidate.is_empty() {
                        param_name = Some(candidate);
                    }
                }
                "type" => annotation = Some(text(part, bytes)),
                _ => {}
            }
        }
        if let (Some(name), Some(ty)) = (param_name, annotation)
            && python_type_text_is_integer_bounded(&ty)
        {
            bounded.insert(name);
        }
    }
    bounded
}
/// Returns `true` when a Python type annotation can only carry numeric or
/// boolean values (or `None`), so it cannot hold free-form text.
///
/// Accepted forms:
/// - `int`, `bool`, `float`;
/// - `X | Y` unions whose members are all accepted or `None`;
/// - `Optional[...]` / `Union[...]` whose arguments are all accepted or `None`;
/// - `tuple[...]` whose arguments are all accepted or `...`;
/// - the listed collection/mapping generics when *every* type argument is
///   accepted.
///
/// `Annotated[...]` is always rejected.
///
/// `|` and `,` are only split at bracket depth zero, so nested annotations
/// such as `list[int | None]` and `Dict[int, List[int]]` parse correctly, and
/// `Union[int, str]` is rejected.
fn python_type_text_is_integer_bounded(text: &str) -> bool {
    let trimmed = text.trim();
    let alternatives = split_python_type_top_level(trimmed, '|');
    if alternatives.len() > 1 {
        return alternatives
            .into_iter()
            .map(str::trim)
            .all(|alt| alt == "None" || python_type_text_is_integer_bounded(alt));
    }
    if matches!(trimmed, "int" | "bool" | "float") {
        return true;
    }
    let Some((head, rest)) = trimmed.split_once('[') else {
        return false;
    };
    let Some(inner) = rest.strip_suffix(']') else {
        return false;
    };
    let args: Vec<&str> = split_python_type_top_level(inner, ',')
        .into_iter()
        .map(str::trim)
        .collect();
    match head.trim() {
        "Annotated" | "typing.Annotated" => false,
        "Optional" | "typing.Optional" | "Union" | "typing.Union" => args
            .iter()
            .all(|arg| *arg == "None" || python_type_text_is_integer_bounded(arg)),
        // `tuple[int, ...]` is a homogeneous variable-length tuple.
        "tuple" | "Tuple" | "typing.Tuple" => args
            .iter()
            .all(|arg| *arg == "..." || python_type_text_is_integer_bounded(arg)),
        "list" | "List" | "typing.List" | "set" | "Set" | "typing.Set" | "frozenset"
        | "Frozenset" | "FrozenSet" | "typing.FrozenSet" | "Sequence" | "typing.Sequence"
        | "Iterable" | "typing.Iterable" | "Iterator" | "typing.Iterator" | "Collection"
        | "typing.Collection" | "dict" | "Dict" | "typing.Dict" | "Mapping"
        | "typing.Mapping" => args
            .iter()
            .all(|arg| python_type_text_is_integer_bounded(arg)),
        _ => false,
    }
}

/// Splits `text` on `sep`, ignoring separators nested inside `[...]` or
/// `(...)`.
fn split_python_type_top_level(text: &str, sep: char) -> Vec<&str> {
    let mut parts = Vec::new();
    let mut depth: usize = 0;
    let mut start = 0;
    for (idx, ch) in text.char_indices() {
        match ch {
            '[' | '(' => depth += 1,
            ']' | ')' => depth = depth.saturating_sub(1),
            c if c == sep && depth == 0 => {
                parts.push(&text[start..idx]);
                start = idx + c.len_utf8();
            }
            _ => {}
        }
    }
    parts.push(&text[start..]);
    parts
}
/// Collects into `out` the names of `type` alias / `interface` declarations
/// whose body text mentions a tRPC context marker (see
/// `body_text_references_trpc_marker`).
///
/// Only program/module/namespace/export/ambient/declaration/block containers
/// are descended into; any other node kind ends the walk at that node.
fn scan_trpc_aliases_visit(node: Node<'_>, bytes: &[u8], out: &mut HashSet<String>) {
    let kind = node.kind();
    if matches!(kind, "type_alias_declaration" | "interface_declaration") {
        let body = node
            .child_by_field_name("value")
            .or_else(|| node.child_by_field_name("body"));
        let mentions_marker =
            body.is_some_and(|body| body_text_references_trpc_marker(&text(body, bytes)));
        if mentions_marker && let Some(name_node) = node.child_by_field_name("name") {
            let alias = text(name_node, bytes);
            if !alias.is_empty() {
                out.insert(alias);
            }
        }
        return;
    }
    let is_container = matches!(
        kind,
        "program"
            | "source_file"
            | "module"
            | "export_statement"
            | "namespace_declaration"
            | "module_declaration"
            | "internal_module"
            | "ambient_declaration"
            | "lexical_declaration"
            | "variable_declaration"
            | "statement_block"
    );
    if !is_container {
        return;
    }
    let children = (0..node.named_child_count()).filter_map(|idx| node.named_child(idx as u32));
    for child in children {
        scan_trpc_aliases_visit(child, bytes, out);
    }
}
/// Returns `true` when `body_text` contains one of the known tRPC context
/// type names as a substring (case-sensitive).
fn body_text_references_trpc_marker(body_text: &str) -> bool {
    const MARKERS: [&str; 4] = [
        "TrpcSessionUser",
        "TRPCContext",
        "ProtectedTRPCContext",
        "TrpcContext",
    ];
    MARKERS.iter().any(|marker| body_text.contains(marker))
}
/// For a parameter whose `type` annotation looks like tRPC procedure options
/// (per `type_text_is_trpc_options` and `state.trpc_alias_names`), records in
/// `state.self_scoped_session_bases` the expressions that refer to the
/// session user:
///
/// - `({ ctx }: Opts)` or `({ ctx = ... }: Opts)` → `ctx.user`
/// - `({ ctx: c }: Opts)` → `c.user`
/// - `(opts: Opts)` (any non-object pattern) → `opts.ctx.user`
///
/// The `ctx` key is matched ASCII case-insensitively. The recorded base uses
/// the name as written in the source.
fn collect_trpc_ctx_param(node: Node<'_>, bytes: &[u8], state: &mut UnitState) {
    let Some(pattern) = node.child_by_field_name("pattern") else {
        return;
    };
    let Some(ty_node) = node.child_by_field_name("type") else {
        return;
    };
    let ty_text = text(ty_node, bytes);
    if !type_text_is_trpc_options(&ty_text, &state.trpc_alias_names) {
        return;
    }
    // Destructured parameter: look for a `ctx` property among the pattern's entries.
    if pattern.kind() == "object_pattern" {
        for idx in 0..pattern.named_child_count() {
            let Some(child) = pattern.named_child(idx as u32) else {
                continue;
            };
            match child.kind() {
                // `{ ctx }`
                "shorthand_property_identifier_pattern" => {
                    let name = text(child, bytes);
                    if name.eq_ignore_ascii_case("ctx") {
                        state
                            .self_scoped_session_bases
                            .insert(format!("{name}.user"));
                    }
                }
                // `{ ctx = default }`
                "object_assignment_pattern" => {
                    if let Some(left) = child.child_by_field_name("left") {
                        let name = if matches!(
                            left.kind(),
                            "identifier" | "shorthand_property_identifier_pattern"
                        ) {
                            text(left, bytes)
                        } else {
                            first_identifier_name(left, bytes).unwrap_or_default()
                        };
                        if name.eq_ignore_ascii_case("ctx") {
                            state
                                .self_scoped_session_bases
                                .insert(format!("{name}.user"));
                        }
                    }
                }
                // `{ ctx: local }`: the session user is reachable via the local alias.
                "pair_pattern" => {
                    let key_node = child.child_by_field_name("key");
                    let local_node = child.child_by_field_name("value");
                    if let (Some(k), Some(v)) = (key_node, local_node) {
                        let key = text(k, bytes);
                        let local = first_identifier_name(v, bytes).unwrap_or_default();
                        if !local.is_empty() && key.eq_ignore_ascii_case("ctx") {
                            state
                                .self_scoped_session_bases
                                .insert(format!("{local}.user"));
                        }
                    }
                }
                _ => {}
            }
        }
        return;
    }
    // Whole options object bound to one name: the user lives at `<name>.ctx.user`.
    if let Some(name) = first_identifier_name(pattern, bytes)
        && !name.is_empty()
    {
        state
            .self_scoped_session_bases
            .insert(format!("{name}.ctx.user"));
    }
}
/// Returns `true` when `ty_text` looks like a tRPC procedure-options type.
///
/// That holds when any of the following is true:
/// - the text mentions a built-in tRPC marker;
/// - the head before any `<` (with a leading `:` stripped) is a collected
///   alias;
/// - any collected alias appears anywhere as a whole identifier.
///
/// Every occurrence of an alias is checked, not just the first, so
/// `MyCtx | Ctx` still matches alias `Ctx`. Identifier boundaries treat ASCII
/// letters, digits, `_` and `$` as identifier characters.
fn type_text_is_trpc_options(ty_text: &str, trpc_alias_names: &HashSet<String>) -> bool {
    if body_text_references_trpc_marker(ty_text) {
        return true;
    }
    let trimmed = ty_text.trim_start_matches(':').trim();
    if trimmed.is_empty() {
        return false;
    }
    let head = trimmed.split('<').next().unwrap_or(trimmed).trim();
    if trpc_alias_names.contains(head) {
        return true;
    }
    let raw = ty_text.as_bytes();
    let is_ident_byte = |b: u8| b.is_ascii_alphanumeric() || b == b'_' || b == b'$';
    trpc_alias_names
        .iter()
        .filter(|alias| !alias.is_empty())
        .any(|alias| {
            ty_text.match_indices(alias.as_str()).any(|(idx, _)| {
                let before_ok = idx == 0 || !is_ident_byte(raw[idx - 1]);
                let end = idx + alias.len();
                let after_ok = end >= raw.len() || !is_ident_byte(raw[end]);
                before_ok && after_ok
            })
        })
}
/// Returns the parameter name bound by a method's `receiver` clause, if any.
pub fn method_receiver_name(node: Node<'_>, bytes: &[u8]) -> Option<String> {
    node.child_by_field_name("receiver")
        .and_then(|receiver| extract_receiver_param_name(receiver, bytes))
}
/// Finds the first non-empty `name` field in a depth-first, pre-order walk of
/// `node` and its named descendants.
fn extract_receiver_param_name(node: Node<'_>, bytes: &[u8]) -> Option<String> {
    let own_name = node
        .child_by_field_name("name")
        .map(|name_node| text(name_node, bytes))
        .filter(|name| !name.is_empty());
    own_name.or_else(|| {
        (0..node.named_child_count())
            .filter_map(|idx| node.named_child(idx as u32))
            .find_map(|child| extract_receiver_param_name(child, bytes))
    })
}
/// Recursively collects parameter names under `node` into `out`, skipping
/// duplicates and empty names.
///
/// - Identifier-like leaves are pushed directly.
/// - `parameter_declaration` / `variadic_parameter_declaration` nodes with a
///   `name` push every `name` identifier. They are skipped entirely when the
///   type is `context.Context` / `context.CancelFunc` (see
///   `is_go_non_user_input_type`).
/// - `parameter`: recurse into its `pattern` field if present, else into all
///   named children.
/// - `default_parameter` / `typed_parameter` / `typed_default_parameter`:
///   - if a `name` field exists, recurse into it (no id-like filtering on
///     this path);
///   - otherwise take the first `identifier` child, dropping id-like names
///     (see `is_python_id_like_typed_param`) unless `include_id_like_typed`
///     is set.
/// - Anything else: recurse into all named children.
fn collect_param_names(
    node: Node<'_>,
    bytes: &[u8],
    include_id_like_typed: bool,
    out: &mut Vec<String>,
) {
    match node.kind() {
        "identifier" | "property_identifier" | "shorthand_property_identifier_pattern" => {
            let name = text(node, bytes);
            if !name.is_empty() && !out.contains(&name) {
                out.push(name);
            }
        }
        "parameter_declaration" | "variadic_parameter_declaration"
            if node.child_by_field_name("name").is_some() =>
        {
            // Framework plumbing types are not user input; ignore the whole declaration.
            if let Some(type_node) = node.child_by_field_name("type")
                && is_go_non_user_input_type(type_node, bytes)
            {
                return;
            }
            // A single declaration can name several parameters (`a, b int`).
            let mut cursor = node.walk();
            for child in node.children_by_field_name("name", &mut cursor) {
                if child.kind() == "identifier" {
                    let name = text(child, bytes);
                    if !name.is_empty() && !out.contains(&name) {
                        out.push(name);
                    }
                }
            }
        }
        "parameter" => {
            if let Some(pattern) = node.child_by_field_name("pattern") {
                collect_param_names(pattern, bytes, include_id_like_typed, out);
                return;
            }
            for idx in 0..node.named_child_count() {
                let Some(child) = node.named_child(idx as u32) else {
                    continue;
                };
                collect_param_names(child, bytes, include_id_like_typed, out);
            }
        }
        "default_parameter" | "typed_parameter" | "typed_default_parameter" => {
            if let Some(name) = node.child_by_field_name("name") {
                collect_param_names(name, bytes, include_id_like_typed, out);
                return;
            }
            for idx in 0..node.named_child_count() {
                let Some(child) = node.named_child(idx as u32) else {
                    continue;
                };
                if child.kind() == "identifier" {
                    let name_text = text(child, bytes);
                    let is_id_like = is_python_id_like_typed_param(&name_text);
                    if !name_text.is_empty()
                        && !out.contains(&name_text)
                        && (include_id_like_typed || !is_id_like)
                    {
                        out.push(name_text);
                    }
                    // Only the first identifier is the parameter name; stop even if it was filtered.
                    return;
                }
            }
        }
        _ => {
            for idx in 0..node.named_child_count() {
                let Some(child) = node.named_child(idx as u32) else {
                    continue;
                };
                collect_param_names(child, bytes, include_id_like_typed, out);
            }
        }
    }
}
/// Reports whether a Go parameter type node is `context.Context` or `context.CancelFunc`.
///
/// A single level of pointer indirection (`*context.Context`) is unwrapped first. The
/// pointee is taken from the pointer's `type` field, falling back to its first named
/// child. Any type that is not a `qualified_type` yields `false`.
fn is_go_non_user_input_type(type_node: Node<'_>, bytes: &[u8]) -> bool {
    let target = if type_node.kind() == "pointer_type" {
        type_node
            .child_by_field_name("type")
            .or_else(|| type_node.named_child(0))
            .unwrap_or(type_node)
    } else {
        type_node
    };
    if target.kind() != "qualified_type" {
        return false;
    }
    // Text of a field on the qualified type, or "" when the field is absent.
    let field_text = |field: &str| -> String {
        target
            .child_by_field_name(field)
            .map(|n| text(n, bytes))
            .unwrap_or_default()
    };
    field_text("package") == "context"
        && matches!(field_text("name").as_str(), "Context" | "CancelFunc")
}
/// Heuristic: does a parameter name look like an identifier/ID reference?
///
/// Case-insensitive. A name qualifies when it ends in `id` or `ids`. The `id` suffix alone
/// already covers `id`, `user_id` and camelCase `userId`.
///
/// NOTE(review): the bare `id` suffix also matches ordinary words such as `valid` or
/// `paid`; confirm that breadth is intended.
fn is_python_id_like_typed_param(name: &str) -> bool {
    let lowered = name.to_ascii_lowercase();
    ["id", "ids"]
        .iter()
        .any(|suffix| lowered.ends_with(*suffix))
}
/// Returns `true` when `node.kind()` is one of the function-, method-, closure- or
/// block-shaped kinds recognised across the supported grammars.
///
/// NOTE(review): `block` and `do_block` are presumably meant for Ruby blocks. Any grammar
/// that reuses the `block` kind for a plain statement block will also match here; confirm
/// callers expect that.
pub fn is_function_like(node: Node<'_>) -> bool {
    const FUNCTION_LIKE_KINDS: &[&str] = &[
        "function_declaration",
        "function_expression",
        "arrow_function",
        "function_definition",
        "method_declaration",
        "function_item",
        "closure_expression",
        "func_literal",
        "decorated_definition",
        "method",
        "singleton_method",
        "block",
        "do_block",
    ];
    FUNCTION_LIKE_KINDS.contains(&node.kind())
}
/// Returns `true` when `node` could denote a route handler.
///
/// This covers an inline function-like node (see `is_function_like`) and a plain or
/// qualified name that refers to one (identifier, member/attribute/selector/field access,
/// Rust scoped path, Ruby constant or scope resolution).
pub fn is_handler_reference(node: Node<'_>) -> bool {
    let names_a_handler = matches!(
        node.kind(),
        "identifier"
            | "member_expression"
            | "attribute"
            | "selector_expression"
            | "field_expression"
            | "scoped_identifier"
            | "field_access"
            | "constant"
            | "scope_resolution"
    );
    names_a_handler || is_function_like(node)
}
pub fn call_site_from_node(node: Node<'_>, bytes: &[u8]) -> CallSite {
if matches!(
node.kind(),
"call_expression" | "call" | "method_invocation" | "method_call_expression"
) {
let name = call_name(node, bytes);
let arg_nodes = node
.child_by_field_name("arguments")
.map(named_children)
.unwrap_or_default();
let args = arg_nodes.iter().map(|arg| text(*arg, bytes)).collect();
let args_value_refs = arg_nodes
.iter()
.map(|arg| extract_value_refs(*arg, bytes))
.collect();
CallSite {
name,
args,
span: span(node),
args_value_refs,
}
} else {
CallSite {
name: text(node, bytes),
args: Vec::new(),
span: span(node),
args_value_refs: Vec::new(),
}
}
}
/// Converts a value node into call sites, dropping any whose name is empty.
///
/// An `array` / `list` / `tuple` yields one candidate per named element. Any other node
/// is treated as a single candidate.
pub fn call_sites_from_value(node: Node<'_>, bytes: &[u8]) -> Vec<CallSite> {
    let candidates = if matches!(node.kind(), "array" | "list" | "tuple") {
        named_children(node)
    } else {
        vec![node]
    };
    candidates
        .into_iter()
        .map(|candidate| call_site_from_node(candidate, bytes))
        .filter(|call| !call.name.is_empty())
        .collect()
}
/// Classifies a call site as an authentication/authorization check, if the rules say so.
///
/// Rules are consulted in priority order: admin guard (which also sees the arguments),
/// then login guard, then generic authorization check (refined by `classify_auth_check`).
/// Returns `None` when no rule matches. The resulting check carries the call's name,
/// span and arguments, has no subjects or condition text, and is not route-level.
pub fn auth_check_from_call_site(
    call: &CallSite,
    line: usize,
    rules: &AuthAnalysisRules,
) -> Option<AuthCheck> {
    let callee = &call.name;
    let kind = if rules.is_admin_guard(callee, &call.args) {
        Some(AuthCheckKind::AdminGuard)
    } else if rules.is_login_guard(callee) {
        Some(AuthCheckKind::LoginGuard)
    } else if rules.is_authorization_check(callee) {
        Some(classify_auth_check(callee, rules))
    } else {
        None
    }?;
    Some(AuthCheck {
        kind,
        callee: callee.clone(),
        args: call.args.clone(),
        span: call.span,
        line,
        subjects: Vec::new(),
        condition_text: None,
        is_route_level: false,
    })
}
pub fn extract_value_refs(node: Node<'_>, bytes: &[u8]) -> Vec<ValueRef> {
match node.kind() {
"member_expression"
| "attribute"
| "selector_expression"
| "field_expression"
| "field_access" => member_value_ref(node, bytes).into_iter().collect(),
"subscript_expression" | "subscript" | "element_reference" | "index_expression" => {
subscript_value_ref(node, bytes).into_iter().collect()
}
"call_expression" | "call" | "method_invocation" | "method_call_expression" => {
call_value_ref(node, bytes)
.map(|value| vec![value])
.unwrap_or_else(|| {
let mut refs = Vec::new();
for idx in 0..node.named_child_count() {
let Some(child) = node.named_child(idx as u32) else {
continue;
};
refs.extend(extract_value_refs(child, bytes));
}
refs
})
}
"identifier"
| "instance_variable"
| "class_variable"
| "global_variable" => vec![ValueRef {
source_kind: ValueSourceKind::Identifier,
name: text(node, bytes),
base: None,
field: None,
index: None,
span: span(node),
}],
"keyword_argument"
| "keyword_arg"
| "named_argument"
| "named_arg" => {
if let Some(value) = node
.child_by_field_name("value")
.or_else(|| node.child_by_field_name("argument"))
{
extract_value_refs(value, bytes)
} else {
Vec::new()
}
}
_ => {
let mut refs = Vec::new();
for idx in 0..node.named_child_count() {
let Some(child) = node.named_child(idx as u32) else {
continue;
};
refs.extend(extract_value_refs(child, bytes));
}
refs
}
}
}
/// Turns a call node into a single value reference when the call reads a value.
///
/// Framework accessors recognised by `accessor_call_value_ref` win first. Otherwise
/// only zero-argument calls with a non-empty member chain qualify, and they are treated
/// like a member read: name is the dotted chain, field is its last segment, base is the
/// remaining prefix (if any), and the kind comes from `classify_member_chain`.
///
/// NOTE(review): `member_chain` only decomposes `call`, `method_invocation` and
/// `method_call_expression`. For other call kinds (e.g. `call_expression`) the chain is
/// the whole call text as a single segment; confirm that is intended.
fn call_value_ref(node: Node<'_>, bytes: &[u8]) -> Option<ValueRef> {
    let callee = call_name(node, bytes);
    let args = node
        .child_by_field_name("arguments")
        .map(named_children)
        .unwrap_or_default();
    let chain = member_chain(node, bytes);
    let accessor = accessor_call_value_ref(node, &callee, &chain, &args, bytes);
    if accessor.is_some() {
        return accessor;
    }
    // Calls with arguments are not treated as plain property reads.
    if !args.is_empty() {
        return None;
    }
    let (last, parents) = chain.split_last()?;
    Some(ValueRef {
        source_kind: classify_member_chain(&chain),
        name: chain.join("."),
        base: (!parents.is_empty()).then(|| parents.join(".")),
        field: Some(last.clone()),
        index: None,
        span: span(node),
    })
}
/// Builds a value reference for a member/attribute access node.
///
/// The dotted `member_chain` becomes the name. Its last segment is the field, and
/// everything before it (when non-empty) is the base. Returns `None` for an empty chain.
fn member_value_ref(node: Node<'_>, bytes: &[u8]) -> Option<ValueRef> {
    let segments = member_chain(node, bytes);
    let (field, parents) = segments.split_last()?;
    let base = if parents.is_empty() {
        None
    } else {
        Some(parents.join("."))
    };
    Some(ValueRef {
        source_kind: classify_member_chain(&segments),
        name: segments.join("."),
        base,
        field: Some(field.clone()),
        index: None,
        span: span(node),
    })
}
/// Assigns a [`ValueSourceKind`] to a member chain.
///
/// Checks run in priority order and the first match wins:
/// request param, request body, request query, session/auth context. Next, a chain whose
/// root is `invitation` / `token` / `invite` (case-insensitive) is a token field. Anything
/// else is a plain member field.
fn classify_member_chain(chain: &[String]) -> ValueSourceKind {
    if matches_request_param(chain) {
        return ValueSourceKind::RequestParam;
    }
    if matches_request_body(chain) {
        return ValueSourceKind::RequestBody;
    }
    if matches_request_query(chain) {
        return ValueSourceKind::RequestQuery;
    }
    if matches_session_context(chain) {
        return ValueSourceKind::Session;
    }
    let token_rooted = chain
        .first()
        .map(|root| root.to_ascii_lowercase())
        .is_some_and(|root| matches!(root.as_str(), "invitation" | "token" | "invite"));
    if token_rooted {
        ValueSourceKind::TokenField
    } else {
        ValueSourceKind::MemberField
    }
}
/// Does the (case-insensitive) chain read a route/path parameter?
///
/// Accepted shapes: `params…`, `self.params…`, and `req|request|ctx.params.<x>…`
/// (the last form needs at least three segments).
fn matches_request_param(chain: &[String]) -> bool {
    let lower = lower_segments(chain);
    let segments: Vec<&str> = lower.iter().map(String::as_str).collect();
    matches!(
        segments.as_slice(),
        ["params", ..] | ["self", "params", ..] | ["req" | "request" | "ctx", "params", _, ..]
    )
}
/// Does the (case-insensitive) chain read a field of the request body / form data?
///
/// Accepted shapes (each needs a field segment after the container):
/// `req|request.(body|form|json|values|post|data).<x>…`, `ctx.request.body.<x>…`,
/// and `ctx.body.<x>…`.
fn matches_request_body(chain: &[String]) -> bool {
    let lower = lower_segments(chain);
    let segments: Vec<&str> = lower.iter().map(String::as_str).collect();
    matches!(
        segments.as_slice(),
        ["req" | "request", "body" | "form" | "json" | "values" | "post" | "data", _, ..]
            | ["ctx", "request", "body", _, ..]
            | ["ctx", "body", _, ..]
    )
}
/// Does the (case-insensitive) chain read a query-string field?
///
/// Accepted shapes (each needs a field segment after the container):
/// `req|request.(query|args|get).<x>…`, `ctx.query.<x>…`, and `ctx.request.query.<x>…`.
fn matches_request_query(chain: &[String]) -> bool {
    let lower = lower_segments(chain);
    let segments: Vec<&str> = lower.iter().map(String::as_str).collect();
    matches!(
        segments.as_slice(),
        ["req" | "request", "query" | "args" | "get", _, ..]
            | ["ctx", "query", _, ..]
            | ["ctx", "request", "query", _, ..]
    )
}
/// Does the (case-insensitive) chain refer to the authenticated session / current user?
///
/// - A bare `session` root counts, except for exactly `session.<verb>` where `<verb>` is an
///   ORM session method (see `is_orm_session_verb`), e.g. `session.commit`;
/// - roots such as `current_user`, `principal` or `securitycontext` always count;
/// - `req|request.(session|user|currentuser)…`;
/// - `self.(request|session|current_user).(session|user|currentuser)…`;
/// - `ctx.(session|state).<x>…`.
fn matches_session_context(chain: &[String]) -> bool {
    let lower = lower_segments(chain);
    let segments: Vec<&str> = lower.iter().map(String::as_str).collect();
    match segments.as_slice() {
        ["session"] => true,
        // Two-segment `session.x` is ambiguous with an ORM session handle.
        ["session", member] => !is_orm_session_verb(member),
        ["session", _, _, ..] => true,
        [
            "current_user" | "current_account" | "current_member" | "securitycontext"
            | "principal" | "authentication",
            ..,
        ] => true,
        ["req" | "request", "session" | "user" | "currentuser", ..] => true,
        [
            "self",
            "request" | "session" | "current_user",
            "session" | "user" | "currentuser",
            ..,
        ] => true,
        ["ctx", "session" | "state", _, ..] => true,
        _ => false,
    }
}
/// Returns `true` for method names of an ORM database session (SQLAlchemy-style
/// `commit`, `query`, `execute`, …).
///
/// `matches_session_context` uses this to tell `session.commit` apart from an auth
/// session read. The comparison is exact and case-sensitive.
fn is_orm_session_verb(segment: &str) -> bool {
    const ORM_SESSION_VERBS: &[&str] = &[
        "commit",
        "rollback",
        "flush",
        "refresh",
        "merge",
        "expunge",
        "expunge_all",
        "close",
        "begin",
        "begin_nested",
        "query",
        "scalar",
        "scalars",
        "execute",
        "exec",
        "exec_driver_sql",
        "add",
        "add_all",
        "delete",
        "bulk_save_objects",
        "bulk_insert_mappings",
        "bulk_update_mappings",
        "configure",
        "info",
    ];
    ORM_SESSION_VERBS.contains(&segment)
}
/// Builds a value reference for a subscript/index expression such as `req.params["id"]`.
///
/// The object and index come from named fields (`object`/`value`/`operand` and
/// `index`/`subscript`). When either is missing, the first two named children are used
/// instead, and `None` is returned if there are fewer than two. A base that classifies
/// as a request/session source keeps that kind and is named `base.key` (quotes
/// stripped). Otherwise the ref is an `ArrayIndex` named `base[index]` with the raw
/// index text.
fn subscript_value_ref(node: Node<'_>, bytes: &[u8]) -> Option<ValueRef> {
    let object_field = node
        .child_by_field_name("object")
        .or_else(|| node.child_by_field_name("value"))
        .or_else(|| node.child_by_field_name("operand"));
    let index_field = node
        .child_by_field_name("index")
        .or_else(|| node.child_by_field_name("subscript"));
    let (object, index) = match (object_field, index_field) {
        (Some(object), Some(index)) => (object, index),
        _ => {
            let mut positional = named_children(node).into_iter();
            (positional.next()?, positional.next()?)
        }
    };
    let base_chain = member_chain(object, bytes);
    let index_text = text(index, bytes);
    let key = strip_quotes(&index_text);
    let (base, source_kind) = if base_chain.is_empty() {
        (text(object, bytes), ValueSourceKind::ArrayIndex)
    } else {
        let kind = match classify_member_chain(&base_chain) {
            ValueSourceKind::MemberField => ValueSourceKind::ArrayIndex,
            other => other,
        };
        (base_chain.join("."), kind)
    };
    let name = if source_kind == ValueSourceKind::ArrayIndex {
        format!("{base}[{index_text}]")
    } else {
        format!("{base}.{key}")
    };
    Some(ValueRef {
        source_kind,
        name,
        base: Some(base),
        field: Some(key),
        index: Some(index_text),
        span: span(node),
    })
}
/// Flattens a member access / method call / scoped path into its segments.
///
/// For example, `req.params.id` becomes `["req", "params", "id"]`. The prefix is
/// expanded recursively; the final segment is the text of the member/method/name node
/// and is omitted when missing or empty. Fields consulted per kind:
/// - `call`: prefix `receiver`; segment `method`, then `name`;
/// - `method_invocation` / `method_call_expression`: prefix `object`, then `receiver`;
///   segment `name`, then `method`;
/// - `scope_resolution`: prefix `scope`; segment `name`;
/// - `scoped_identifier`: prefix `path`; segment `name`;
/// - member/attribute/selector/field accesses: prefix `object`/`value`/`operand`/`argument`;
///   segment `property`/`attribute`/`field`/`name`.
///
/// Any other node is a leaf: its full source text as one segment (empty text → empty chain).
pub fn member_chain(node: Node<'_>, bytes: &[u8]) -> Vec<String> {
    /// `member_chain` of the first present prefix field, followed by the text of the
    /// first present segment field when that text is non-empty.
    fn qualified_chain(
        node: Node<'_>,
        bytes: &[u8],
        prefix_fields: &[&str],
        segment_fields: &[&str],
    ) -> Vec<String> {
        let mut chain = prefix_fields
            .iter()
            .find_map(|field| node.child_by_field_name(*field))
            .map(|prefix| member_chain(prefix, bytes))
            .unwrap_or_default();
        let segment = segment_fields
            .iter()
            .find_map(|field| node.child_by_field_name(*field))
            .map(|last| text(last, bytes))
            .unwrap_or_default();
        if !segment.is_empty() {
            chain.push(segment);
        }
        chain
    }

    match node.kind() {
        "call" => qualified_chain(node, bytes, &["receiver"], &["method", "name"]),
        "method_invocation" | "method_call_expression" => {
            qualified_chain(node, bytes, &["object", "receiver"], &["name", "method"])
        }
        "scope_resolution" => qualified_chain(node, bytes, &["scope"], &["name"]),
        "scoped_identifier" => qualified_chain(node, bytes, &["path"], &["name"]),
        "member_expression"
        | "attribute"
        | "selector_expression"
        | "field_expression"
        | "field_access" => qualified_chain(
            node,
            bytes,
            &["object", "value", "operand", "argument"],
            &["property", "attribute", "field", "name"],
        ),
        _ => {
            let leaf = text(node, bytes);
            if leaf.is_empty() {
                Vec::new()
            } else {
                vec![leaf]
            }
        }
    }
}
/// Renders the callee part of a call as a dotted name.
///
/// Member accesses, scoped paths and nested calls are flattened through `member_chain`
/// and joined with `.`. Every other node (including bare identifiers/constants) is
/// returned as its raw source text.
pub fn callee_name(node: Node<'_>, bytes: &[u8]) -> String {
    let is_chain_like = matches!(
        node.kind(),
        "member_expression"
            | "attribute"
            | "selector_expression"
            | "field_expression"
            | "scoped_identifier"
            | "field_access"
            | "scope_resolution"
            | "call"
            | "method_invocation"
            | "method_call_expression"
    );
    if is_chain_like {
        member_chain(node, bytes).join(".")
    } else {
        text(node, bytes)
    }
}
/// Returns the dotted name of the function or method a call invokes.
///
/// Non-call nodes are delegated to `callee_name`. Calls with a `function` field use that
/// field's `callee_name`. Otherwise the method text (`method`, then `name`) is prefixed
/// with the receiver's dotted chain (`receiver`/`object`/`scope`/`argument`) when that
/// chain is non-empty. A call without method text falls back to its full source text.
pub fn call_name(node: Node<'_>, bytes: &[u8]) -> String {
    let is_call = matches!(
        node.kind(),
        "call_expression" | "call" | "method_invocation" | "method_call_expression"
    );
    if !is_call {
        return callee_name(node, bytes);
    }
    if let Some(function) = node.child_by_field_name("function") {
        return callee_name(function, bytes);
    }
    let method = ["method", "name"]
        .into_iter()
        .find_map(|field| node.child_by_field_name(field))
        .map(|child| text(child, bytes))
        .unwrap_or_default();
    if method.is_empty() {
        return text(node, bytes);
    }
    let receiver = ["receiver", "object", "scope", "argument"]
        .into_iter()
        .find_map(|field| node.child_by_field_name(field))
        .map(|child| member_chain(child, bytes).join("."))
        .filter(|dotted| !dotted.is_empty());
    match receiver {
        Some(receiver) => format!("{receiver}.{method}"),
        None => method,
    }
}
/// Splits a member-like node into `(object_text, property_text)`.
///
/// The object is the first present field among `object`, `operand`, `value`, `receiver`,
/// `argument`. The property is the first present among `property`, `field`, `attribute`,
/// `name`. Returns `None` when either side is missing.
pub fn member_target(node: Node<'_>, bytes: &[u8]) -> Option<(String, String)> {
    let object = ["object", "operand", "value", "receiver", "argument"]
        .into_iter()
        .find_map(|field| node.child_by_field_name(field))?;
    let property = ["property", "field", "attribute", "name"]
        .into_iter()
        .find_map(|field| node.child_by_field_name(field))?;
    Some((text(object, bytes), text(property, bytes)))
}
/// Maps a routing method name (ASCII case-insensitive) to an [`HttpMethod`].
///
/// `all` and `any` both map to `HttpMethod::All`, and `use` maps to middleware mounting
/// (`HttpMethod::Use`). Unknown names give `None`.
pub fn http_method_from_name(name: &str) -> Option<HttpMethod> {
    let is = |candidate: &str| name.eq_ignore_ascii_case(candidate);
    if is("get") {
        Some(HttpMethod::Get)
    } else if is("post") {
        Some(HttpMethod::Post)
    } else if is("put") {
        Some(HttpMethod::Put)
    } else if is("delete") {
        Some(HttpMethod::Delete)
    } else if is("patch") {
        Some(HttpMethod::Patch)
    } else if is("all") || is("any") {
        Some(HttpMethod::All)
    } else if is("use") {
        Some(HttpMethod::Use)
    } else {
        None
    }
}
/// Joins a mount prefix and a route path with exactly one `/` between them.
///
/// All trailing slashes of `prefix` and all leading slashes of `route` are dropped first.
/// An empty route yields the trimmed prefix, and two empty sides yield `/`. An empty
/// prefix produces a `/`-rooted route. The prefix is otherwise not forced to start
/// with `/`.
pub fn join_route_paths(prefix: &str, route: &str) -> String {
    let head = prefix.trim_end_matches('/');
    let tail = route.trim_start_matches('/');
    if !tail.is_empty() {
        return format!("{head}/{tail}");
    }
    if head.is_empty() {
        "/".to_string()
    } else {
        head.to_string()
    }
}
/// Extracts the value references of the object a call is made on.
///
/// The call's own `receiver`/`object`/`argument` field is tried first. Failing that, the
/// `object`/`operand`/`argument` field of its `function` child is used. Returns an empty
/// list when no receiver can be found.
fn call_receiver_subjects(node: Node<'_>, bytes: &[u8]) -> Vec<ValueRef> {
    let direct = ["receiver", "object", "argument"]
        .into_iter()
        .find_map(|field| node.child_by_field_name(field));
    let receiver = direct.or_else(|| {
        let function = node.child_by_field_name("function")?;
        ["object", "operand", "argument"]
            .into_iter()
            .find_map(|field| function.child_by_field_name(field))
    });
    receiver
        .map(|receiver| extract_value_refs(receiver, bytes))
        .unwrap_or_default()
}
/// Returns the quote-stripped contents of a string-literal node, or `None` for any other
/// node kind.
///
/// Template strings are included; their text is passed to `strip_quotes` as-is.
pub fn string_literal_value(node: Node<'_>, bytes: &[u8]) -> Option<String> {
    let is_string_literal = matches!(
        node.kind(),
        "string"
            | "template_string"
            | "string_literal"
            | "interpreted_string_literal"
            | "raw_string_literal"
    );
    is_string_literal.then(|| strip_quotes(&text(node, bytes)))
}
/// Looks up a property of an object-literal node by any of the given key names.
///
/// Children are scanned in order. The first `pair` whose quote-stripped key is in
/// `names` ends the search and yields that pair's `value` field (which may be `None`).
/// The first shorthand property / identifier whose text is in `names` yields the child
/// itself. Returns `None` for non-`object` nodes or when nothing matches.
pub fn object_property_value<'tree>(
    node: Node<'tree>,
    bytes: &[u8],
    names: &[&str],
) -> Option<Node<'tree>> {
    if node.kind() != "object" {
        return None;
    }
    let is_wanted = |key: &str| names.iter().any(|name| *name == key);
    for child in named_children(node) {
        // `Some(result)` means "stop scanning and return `result`".
        let verdict = match child.kind() {
            "pair" => child
                .child_by_field_name("key")
                .filter(|key| is_wanted(strip_quotes(&text(*key, bytes)).as_str()))
                .map(|_| child.child_by_field_name("value")),
            "shorthand_property_identifier" | "identifier" => {
                is_wanted(text(child, bytes).as_str()).then_some(Some(child))
            }
            _ => None,
        };
        if let Some(result) = verdict {
            return result;
        }
    }
    None
}
/// Returns the wrapped definition (the `definition` field) of a decorated definition node.
///
/// This is `None` when `node` has no such field, e.g. because it is not a
/// `decorated_definition`.
pub fn decorated_definition_child(node: Node<'_>) -> Option<Node<'_>> {
    node.child_by_field_name("definition")
}
pub fn function_definition_node(node: Node<'_>) -> Node<'_> {
decorated_definition_child(node).unwrap_or(node)
}
/// Collects the named children of `node`, in order, into a `Vec`.
pub fn named_children(node: Node<'_>) -> Vec<Node<'_>> {
    (0..node.named_child_count())
        .filter_map(|idx| node.named_child(idx as u32))
        .collect()
}
/// Returns the source text covered by `node` as an owned `String`.
///
/// Yields an empty string if the bytes are not valid UTF-8.
pub fn text(node: Node<'_>, bytes: &[u8]) -> String {
    node.utf8_text(bytes)
        .map(String::from)
        .unwrap_or_default()
}
/// Returns `node`'s byte range as a `(start, end)` tuple (end exclusive).
pub fn span(node: Node<'_>) -> (usize, usize) {
    let range = node.byte_range();
    (range.start, range.end)
}
fn dedup_value_refs(values: &mut Vec<ValueRef>) {
let mut deduped = Vec::new();
for value in values.drain(..) {
if !deduped
.iter()
.any(|existing: &ValueRef| existing.name == value.name && existing.span == value.span)
{
deduped.push(value);
}
}
*values = deduped;
}
/// Returns an ASCII-lowercased copy of every chain segment, preserving order.
fn lower_segments(chain: &[String]) -> Vec<String> {
    let mut lowered = Vec::with_capacity(chain.len());
    for segment in chain {
        lowered.push(segment.to_ascii_lowercase());
    }
    lowered
}
/// Recognises accessor-style calls and converts them into a [`ValueRef`].
///
/// Dispatch is on the bare method name of `callee` (from `bare_method_name`, presumably
/// the last dotted segment; verify) and is case-sensitive:
/// - `Param` / `PathParam` → `RequestParam`;
/// - `Query` / `QueryParam` / `DefaultQuery` / `getParameter` / `getQueryString` → `RequestQuery`;
/// - `PostForm` / `FormValue` / `DefaultPostForm` → `RequestBody`;
/// - `Get` / `GetString` / `MustGet` / `getAttribute` → `Session`;
/// - `getXxx` (longer than `get`) on a chain whose first segment is `invitation`, `token`
///   or `invite` (ASCII case-insensitive) → `TokenField`;
/// - anything else → `None`.
///
/// The field is the string-literal value of the first argument. For `TokenField` calls
/// without one, it is the method name after `get`. Its first character is
/// ASCII-lowercased, and an empty field is dropped. The base is `"session"` for `Session`.
/// Otherwise it is the chain without its last segment, or the chain's only segment (if
/// any). The name is `base.field`, or just `field` when the base is absent/empty, or the
/// raw `callee` when there is no field.
fn accessor_call_value_ref(
    node: Node<'_>,
    callee: &str,
    chain: &[String],
    args: &[Node<'_>],
    bytes: &[u8],
) -> Option<ValueRef> {
    let method = bare_method_name(callee);
    // Key passed as the first argument, e.g. the `"id"` in `Param("id")`.
    let field = args
        .first()
        .and_then(|arg| string_literal_value(*arg, bytes));
    let source_kind = match method {
        "Param" | "PathParam" => Some(ValueSourceKind::RequestParam),
        "Query" | "QueryParam" | "DefaultQuery" | "getParameter" | "getQueryString" => {
            Some(ValueSourceKind::RequestQuery)
        }
        "PostForm" | "FormValue" | "DefaultPostForm" => Some(ValueSourceKind::RequestBody),
        "Get" | "GetString" | "MustGet" | "getAttribute" => Some(ValueSourceKind::Session),
        // Getter on an invitation/token object, e.g. `getEmail`.
        _ if chain.first().is_some_and(|segment| {
            matches!(
                segment.to_ascii_lowercase().as_str(),
                "invitation" | "token" | "invite"
            )
        }) && method.starts_with("get")
            && method.len() > 3 =>
        {
            Some(ValueSourceKind::TokenField)
        }
        _ => None,
    }?;
    let normalized_field = field
        // Token getters name their field in the method itself (`getEmail` → `Email`);
        // slicing at byte 3 is safe because the ASCII prefix `get` was just checked.
        .or_else(|| {
            if source_kind == ValueSourceKind::TokenField && method.starts_with("get") {
                Some(method[3..].to_string())
            } else {
                None
            }
        })
        // Lower-case the first character (`Email` → `email`).
        .map(|field| {
            let mut chars = field.chars();
            let Some(first) = chars.next() else {
                return field;
            };
            format!("{}{}", first.to_ascii_lowercase(), chars.as_str())
        })
        .filter(|field| !field.is_empty());
    let base = match source_kind {
        ValueSourceKind::Session => Some("session".to_string()),
        _ if chain.len() > 1 => Some(chain[..chain.len() - 1].join(".")),
        _ => chain.first().cloned(),
    };
    let name = if let Some(field) = normalized_field.as_deref() {
        match base.as_deref() {
            Some(base) if !base.is_empty() => format!("{base}.{field}"),
            _ => field.to_string(),
        }
    } else {
        callee.to_string()
    };
    Some(ValueRef {
        source_kind,
        name,
        base,
        field: normalized_field,
        index: None,
        span: span(node),
    })
}
#[cfg(test)]
mod tests {
use super::{is_owner_field_subject, is_self_actor_subject, is_self_actor_type_text};
use crate::auth_analysis::model::{ValueRef, ValueSourceKind};
#[test]
fn is_self_actor_type_text_matches_known_wrappers() {
assert!(is_self_actor_type_text("Authenticated"));
assert!(is_self_actor_type_text("Identity"));
assert!(is_self_actor_type_text("Principal"));
assert!(is_self_actor_type_text("CurrentUser"));
assert!(is_self_actor_type_text("SessionUser"));
assert!(is_self_actor_type_text("AuthUser"));
assert!(is_self_actor_type_text("AdminUser"));
assert!(is_self_actor_type_text("AuthenticatedUser"));
assert!(is_self_actor_type_text("LocalUserView"));
assert!(is_self_actor_type_text("LocalUser"));
assert!(is_self_actor_type_text("LoggedInUser"));
assert!(is_self_actor_type_text("CurrentUserContext"));
assert!(is_self_actor_type_text("AuthenticatedUserSession"));
assert!(is_self_actor_type_text("SessionUserToken"));
assert!(is_self_actor_type_text("AdminUserInfo"));
assert!(is_self_actor_type_text("crate::auth::CurrentUser"));
assert!(is_self_actor_type_text("crate::user::LocalUserView"));
assert!(is_self_actor_type_text("&CurrentUser"));
assert!(is_self_actor_type_text("&mut AuthUser"));
assert!(is_self_actor_type_text("CurrentUser<Admin>"));
assert!(is_self_actor_type_text("LocalUserView<Admin>"));
assert!(!is_self_actor_type_text("User"));
assert!(!is_self_actor_type_text("UserPreferences"));
assert!(!is_self_actor_type_text("UserView"));
assert!(!is_self_actor_type_text("PaymentUser"));
assert!(!is_self_actor_type_text("CurrentUserPreferences"));
assert!(!is_self_actor_type_text("Db"));
assert!(!is_self_actor_type_text("Path<(i64,)>"));
assert!(!is_self_actor_type_text("Json<Body>"));
assert!(!is_self_actor_type_text("RequireAuth"));
assert!(!is_self_actor_type_text("RequireLogin"));
}
fn ident(name: &str) -> ValueRef {
ValueRef {
source_kind: ValueSourceKind::Identifier,
name: name.to_string(),
base: None,
field: None,
index: None,
span: (0, 0),
}
}
fn member(base: &str, field: &str) -> ValueRef {
ValueRef {
source_kind: ValueSourceKind::MemberField,
name: format!("{base}.{field}"),
base: Some(base.to_string()),
field: Some(field.to_string()),
index: None,
span: (0, 0),
}
}
fn session(base: &str, field: &str) -> ValueRef {
ValueRef {
source_kind: ValueSourceKind::Session,
name: format!("{base}.{field}"),
base: Some(base.to_string()),
field: Some(field.to_string()),
index: None,
span: (0, 0),
}
}
#[test]
fn is_owner_field_subject_matches_known_column_names() {
assert!(is_owner_field_subject(&ident("owner_id")));
assert!(is_owner_field_subject(&ident("user_id")));
assert!(is_owner_field_subject(&ident("author_id")));
assert!(is_owner_field_subject(&ident("created_by")));
assert!(is_owner_field_subject(&member("row", "owner_id")));
assert!(!is_owner_field_subject(&ident("group_id")));
assert!(!is_owner_field_subject(&ident("doc_id")));
assert!(!is_owner_field_subject(&ident("user")));
}
#[test]
fn is_self_actor_subject_matches_known_self_shapes() {
assert!(is_self_actor_subject(&member("user", "id")));
assert!(is_self_actor_subject(&member("current_user", "id")));
assert!(is_self_actor_subject(&session("req.user", "id")));
assert!(is_self_actor_subject(&session("ctx.session.user", "id")));
assert!(!is_self_actor_subject(&member("user", "workspace_id")));
assert!(!is_self_actor_subject(&member("target", "id")));
assert!(!is_self_actor_subject(&ident("user_id")));
}
#[test]
fn type_text_is_trpc_options_matches_alias_and_inline_marker() {
use super::type_text_is_trpc_options;
use std::collections::HashSet;
let mut aliases = HashSet::new();
aliases.insert("GetOptions".to_string());
aliases.insert("UpdateOptions".to_string());
assert!(type_text_is_trpc_options(
": { ctx: { user: NonNullable<TrpcSessionUser> } }",
&aliases
));
assert!(type_text_is_trpc_options(
": { user: TrpcSessionUser }",
&HashSet::new()
));
assert!(type_text_is_trpc_options(": GetOptions", &aliases));
assert!(type_text_is_trpc_options("GetOptions", &aliases));
assert!(type_text_is_trpc_options(": Promise<GetOptions>", &aliases));
assert!(type_text_is_trpc_options(
": NonNullable<UpdateOptions>",
&aliases
));
assert!(!type_text_is_trpc_options(": OtherOptions", &aliases));
assert!(!type_text_is_trpc_options(": Promise<Foo>", &aliases));
assert!(!type_text_is_trpc_options(": SomeRandomType", &aliases));
assert!(!type_text_is_trpc_options(": MyGetOptionsX", &aliases));
}
#[test]
fn body_text_references_trpc_marker_recognises_known_markers() {
use super::body_text_references_trpc_marker as bm;
assert!(bm("type X = { user: NonNullable<TrpcSessionUser> }"));
assert!(bm("interface Ctx extends TRPCContext { ... }"));
assert!(bm("type Ctx = ProtectedTRPCContext"));
assert!(bm("export type Y = { ctx: TrpcContext }"));
assert!(!bm("type X = { user: User }"));
assert!(!bm("type X = SessionContext"));
assert!(!bm("type X = { foo: SomeContext }"));
}
#[test]
fn is_self_scoped_session_base_text_matches_known_session_bases() {
use super::is_self_scoped_session_base_text as bt;
assert!(bt("req.user"));
assert!(bt("request.user"));
assert!(bt("req.session.user"));
assert!(bt("req.session.currentUser"));
assert!(bt("session.user"));
assert!(bt("session.currentUser"));
assert!(bt("ctx.session.user"));
assert!(bt("ctx.state.user"));
assert!(!bt("req.body"));
assert!(!bt("req.params"));
assert!(!bt("ctx.user"));
assert!(!bt("data.user"));
assert!(!bt("user"));
}
#[test]
fn matches_session_context_denylists_orm_session_verbs() {
use super::matches_session_context as msc;
let v = |chain: &[&str]| chain.iter().map(|s| s.to_string()).collect::<Vec<_>>();
assert!(msc(&v(&["session", "user"])));
assert!(msc(&v(&["session", "user_id"])));
assert!(msc(&v(&["session", "id"])));
assert!(msc(&v(&["session", "uid"])));
assert!(msc(&v(&["session", "email"])));
assert!(msc(&v(&["session", "currentUser"])));
assert!(msc(&v(&["session", "workspace_id"])));
assert!(msc(&v(&["session", "project_id"])));
assert!(msc(&v(&["session", "role"])));
assert!(msc(&v(&["session", "currentWorkspaceID"])));
assert!(!msc(&v(&["session", "commit"])));
assert!(!msc(&v(&["session", "rollback"])));
assert!(!msc(&v(&["session", "scalar"])));
assert!(!msc(&v(&["session", "scalars"])));
assert!(!msc(&v(&["session", "add"])));
assert!(!msc(&v(&["session", "delete"])));
assert!(!msc(&v(&["session", "execute"])));
assert!(!msc(&v(&["session", "flush"])));
assert!(!msc(&v(&["session", "query"])));
assert!(!msc(&v(&["session", "merge"])));
assert!(!msc(&v(&["session", "refresh"])));
assert!(!msc(&v(&["session", "close"])));
assert!(msc(&v(&["session"])));
assert!(msc(&v(&["req", "session", "user"])));
assert!(msc(&v(&["request", "session"])));
assert!(msc(&v(&["current_user", "id"])));
assert!(msc(&v(&["current_user", "preferences"])));
}
#[test]
fn collect_param_names_rust_skips_type_segment_idents() {
use super::function_params;
let mut parser = tree_sitter::Parser::new();
parser
.set_language(&tree_sitter::Language::from(tree_sitter_rust::LANGUAGE))
.unwrap();
let src = b"unsafe fn remove_tasks(tasks: &[Task], dst: &std::path::Path, sz: usize) {}";
let tree = parser.parse(src.as_slice(), None).unwrap();
let func = tree
.root_node()
.child(0)
.expect("source_file should have a function");
let params = function_params(func, src);
assert_eq!(
params,
vec!["tasks".to_string(), "dst".to_string(), "sz".to_string()],
"type-segment idents (`std`, `path`, `Path`) must NOT pollute the param-name set"
);
}
#[test]
fn collect_param_names_rust_handles_request_typed_params() {
use super::function_params;
let mut parser = tree_sitter::Parser::new();
parser
.set_language(&tree_sitter::Language::from(tree_sitter_rust::LANGUAGE))
.unwrap();
let src = b"fn handle(req: &Request<Body>, state: AppState) -> Response { todo!() }";
let tree = parser.parse(src.as_slice(), None).unwrap();
let func = tree.root_node().child(0).expect("function");
let params = function_params(func, src);
assert_eq!(
params,
vec!["req".to_string(), "state".to_string()],
"type idents `Request`/`Body`/`Response`/`AppState` must not leak as params"
);
}
#[test]
fn collect_param_names_rust_destructured_pattern_picks_up_bindings() {
use super::function_params;
let mut parser = tree_sitter::Parser::new();
parser
.set_language(&tree_sitter::Language::from(tree_sitter_rust::LANGUAGE))
.unwrap();
let src = b"fn split((a, b): (u32, u32)) {}";
let tree = parser.parse(src.as_slice(), None).unwrap();
let func = tree.root_node().child(0).expect("function");
let params = function_params(func, src);
assert!(params.contains(&"a".to_string()), "got {:?}", params);
assert!(params.contains(&"b".to_string()), "got {:?}", params);
assert!(!params.contains(&"u32".to_string()), "got {:?}", params);
}
#[test]
fn collect_param_names_go_drops_context_context_param() {
use super::function_params;
let mut parser = tree_sitter::Parser::new();
parser
.set_language(&tree_sitter::Language::from(tree_sitter_go::LANGUAGE))
.unwrap();
let src = b"package x\nfunc GetPackage(ctx context.Context, info *PackageInfo) {}\n";
let tree = parser.parse(src.as_slice(), None).unwrap();
let func = (0..tree.root_node().named_child_count())
.filter_map(|i| tree.root_node().named_child(i as u32))
.find(|n| n.kind() == "function_declaration")
.expect("file should have a function_declaration");
let params = function_params(func, src);
assert!(
!params.contains(&"ctx".to_string()),
"ctx context.Context must be dropped: got {:?}",
params
);
assert!(
!params.contains(&"context".to_string()) && !params.contains(&"Context".to_string()),
"type-segment idents must not leak: got {:?}",
params
);
assert!(
params.contains(&"info".to_string()),
"non-context typed params keep their name: got {:?}",
params
);
assert!(
!params.contains(&"PackageInfo".to_string()),
"type-segment idents must not leak from non-context params either: got {:?}",
params
);
}
#[test]
fn collect_param_names_go_keeps_framework_context_param() {
use super::function_params;
let mut parser = tree_sitter::Parser::new();
parser
.set_language(&tree_sitter::Language::from(tree_sitter_go::LANGUAGE))
.unwrap();
let src = b"package x\nfunc Handle(ctx *context.APIContext) {}\n";
let tree = parser.parse(src.as_slice(), None).unwrap();
let func = (0..tree.root_node().named_child_count())
.filter_map(|i| tree.root_node().named_child(i as u32))
.find(|n| n.kind() == "function_declaration")
.expect("file should have a function_declaration");
let params = function_params(func, src);
assert!(
params.contains(&"ctx".to_string()),
"framework-bearing ctx must survive: got {:?}",
params
);
}
#[test]
fn collect_param_names_go_multi_name_param_decl() {
use super::function_params;
let mut parser = tree_sitter::Parser::new();
parser
.set_language(&tree_sitter::Language::from(tree_sitter_go::LANGUAGE))
.unwrap();
let src = b"package x\nfunc Add(a, b int, ctx context.Context) {}\n";
let tree = parser.parse(src.as_slice(), None).unwrap();
let func = (0..tree.root_node().named_child_count())
.filter_map(|i| tree.root_node().named_child(i as u32))
.find(|n| n.kind() == "function_declaration")
.expect("file should have a function_declaration");
let params = function_params(func, src);
assert!(params.contains(&"a".to_string()), "got {:?}", params);
assert!(params.contains(&"b".to_string()), "got {:?}", params);
assert!(!params.contains(&"ctx".to_string()), "got {:?}", params);
assert!(!params.contains(&"int".to_string()), "got {:?}", params);
}
/// Tests for the Ruby class-body helpers.
///
/// They cover visibility tracking (bare and targeted `private` / `protected` /
/// `public` directives) and collection of callback target names
/// (`before_action`, `skip_before_action`, `before_filter`). They also check the
/// predicate that combines both.
mod ruby_visibility_and_callbacks {
    use super::super::{
        RubyVisibility, ruby_callback_target_names, ruby_method_is_callback_or_private,
        ruby_method_visibility,
    };
    use tree_sitter::{Node, Parser, Tree};

    /// Parses `src` with the Ruby grammar.
    ///
    /// Returns the syntax tree together with an owned copy of the source bytes,
    /// which the helpers under test read node text from.
    fn parse_ruby(src: &str) -> (Tree, Vec<u8>) {
        let bytes = src.as_bytes().to_vec();
        let mut ruby_parser = Parser::new();
        ruby_parser
            .set_language(&tree_sitter::Language::from(tree_sitter_ruby::LANGUAGE))
            .unwrap();
        let tree = ruby_parser.parse(bytes.as_slice(), None).expect("parse");
        (tree, bytes)
    }

    /// Runs a depth-first, pre-order search for the first `class` node at or
    /// below `node` and returns that class's `body` field.
    ///
    /// The search stops at the first class it meets, even if that class has no
    /// body.
    fn class_body_in(node: Node<'_>) -> Option<Node<'_>> {
        if node.kind() == "class" {
            return node.child_by_field_name("body");
        }
        (0..node.named_child_count())
            .filter_map(|idx| node.named_child(idx as u32))
            .find_map(class_body_in)
    }

    /// Returns the body of the first class in `tree`, panicking with `"body"` if
    /// there is none.
    fn first_class_body(tree: &Tree) -> Node<'_> {
        class_body_in(tree.root_node()).expect("body")
    }

    #[test]
    fn bare_private_directive_marks_subsequent_methods_private() {
        let src = "class C\n def public_a; end\n private\n def helper_b; end\n def helper_c; end\nend\n";
        let (tree, bytes) = parse_ruby(src);
        let vis = ruby_method_visibility(first_class_body(&tree), &bytes);
        // Methods before the bare `private` stay public; everything after it is private.
        let expected = [
            ("public_a", RubyVisibility::Public),
            ("helper_b", RubyVisibility::Private),
            ("helper_c", RubyVisibility::Private),
        ];
        for (method, visibility) in expected {
            assert_eq!(vis.get(method).copied(), Some(visibility));
        }
    }

    #[test]
    fn targeted_private_marks_only_named_methods() {
        let src = "class C\n def a; end\n def b; end\n def c; end\n private :a, :c\nend\n";
        let (tree, bytes) = parse_ruby(src);
        let vis = ruby_method_visibility(first_class_body(&tree), &bytes);
        // `private :a, :c` affects only the listed symbols; `b` stays public.
        let expected = [
            ("a", RubyVisibility::Private),
            ("b", RubyVisibility::Public),
            ("c", RubyVisibility::Private),
        ];
        for (method, visibility) in expected {
            assert_eq!(vis.get(method).copied(), Some(visibility));
        }
    }

    #[test]
    fn public_directive_re_opens_visibility() {
        let src = "class C\n private\n def a; end\n public\n def b; end\nend\n";
        let (tree, bytes) = parse_ruby(src);
        let vis = ruby_method_visibility(first_class_body(&tree), &bytes);
        let expected = [
            ("a", RubyVisibility::Private),
            ("b", RubyVisibility::Public),
        ];
        for (method, visibility) in expected {
            assert_eq!(vis.get(method).copied(), Some(visibility));
        }
    }

    #[test]
    fn protected_directive_recognised() {
        let src = "class C\n protected\n def helper; end\nend\n";
        let (tree, bytes) = parse_ruby(src);
        let vis = ruby_method_visibility(first_class_body(&tree), &bytes);
        let helper_visibility = vis.get("helper").copied();
        assert_eq!(helper_visibility, Some(RubyVisibility::Protected));
    }

    #[test]
    fn before_action_collects_callback_target_names() {
        let src = "class C\n before_action :set_account\n before_action :set_user, only: [:show, :update]\n def show; end\n def set_account; end\n def set_user; end\nend\n";
        let (tree, bytes) = parse_ruby(src);
        let callbacks = ruby_callback_target_names(first_class_body(&tree), &bytes);
        // The symbol passed to `before_action` is collected...
        for target in ["set_account", "set_user"] {
            assert!(callbacks.contains(target));
        }
        // ...but the `only:` key and the actions it lists are not.
        for non_target in ["show", "update", "only"] {
            assert!(!callbacks.contains(non_target));
        }
    }

    #[test]
    fn before_action_block_form_yields_no_targets() {
        let src =
            "class C\n before_action do\n require_login\n end\n def show; end\nend\n";
        let (tree, bytes) = parse_ruby(src);
        let callbacks = ruby_callback_target_names(first_class_body(&tree), &bytes);
        // A block-form callback names no method symbol, so nothing is collected.
        assert!(callbacks.is_empty(), "got {:?}", callbacks);
    }

    #[test]
    fn skip_before_action_target_collected() {
        let src = "class C\n skip_before_action :authenticate_user!, only: [:index]\n def index; end\nend\n";
        let (tree, bytes) = parse_ruby(src);
        let callbacks = ruby_callback_target_names(first_class_body(&tree), &bytes);
        assert!(callbacks.contains("authenticate_user!"));
    }

    #[test]
    fn legacy_before_filter_alias_collected() {
        let src = "class C\n before_filter :legacy_helper\n def legacy_helper; end\nend\n";
        let (tree, bytes) = parse_ruby(src);
        let callbacks = ruby_callback_target_names(first_class_body(&tree), &bytes);
        assert!(callbacks.contains("legacy_helper"));
    }

    #[test]
    fn callback_target_or_private_predicate_combines_layers() {
        let src = "class C\n before_action :set_account\n def show; end\n def set_account; end\n private\n def helper; end\nend\n";
        let (tree, bytes) = parse_ruby(src);
        let body = first_class_body(&tree);
        let visibility = ruby_method_visibility(body, &bytes);
        let callbacks = ruby_callback_target_names(body, &bytes);
        // `show` is a public non-callback method.
        // `set_account` is flagged as a callback target, and `helper` as private.
        let cases = [("show", false), ("set_account", true), ("helper", true)];
        for (method, expected) in cases {
            assert_eq!(
                ruby_method_is_callback_or_private(method, &visibility, &callbacks),
                expected
            );
        }
    }
}
}