use super::config::AuthAnalysisRules;
use super::model::AuthorizationModel;
use crate::utils::project::{FrameworkContext, rust_file_imports_web_framework};
use std::path::Path;
use tree_sitter::Tree;
pub mod actix_web;
pub mod axum;
pub mod common;
pub mod django;
pub mod echo;
pub mod express;
pub mod fastify;
pub mod flask;
pub mod gin;
pub mod koa;
pub mod rails;
pub mod rocket;
pub mod sinatra;
pub mod spring;
/// A framework-specific extractor that builds an [`AuthorizationModel`] from one parsed file.
///
/// `extract_authorization_model` asks every registered extractor whether it
/// `supports` the current file and merges the models of those that do.
pub trait AuthExtractor {
/// Returns `true` when this extractor should run for a file written in `lang`.
///
/// `framework_ctx` is optional project-level context; it is `None` when no
/// such context is available to the caller.
fn supports(&self, lang: &str, framework_ctx: Option<&FrameworkContext>) -> bool;
/// Extracts a partial authorization model from `tree`.
///
/// `bytes` is passed alongside the tree; presumably it is the source text the
/// tree was parsed from (TODO confirm with implementors). `rules` carries the
/// configured analysis rules. `extract_authorization_model` overwrites the
/// returned model's `lang` field before merging, so implementors need not set it.
fn extract(
&self,
tree: &Tree,
bytes: &[u8],
path: &Path,
rules: &AuthAnalysisRules,
) -> AuthorizationModel;
}
/// Builds the combined authorization model for a single parsed file.
///
/// Each registered framework extractor is consulted in registry order. Every
/// extractor whose [`AuthExtractor::supports`] accepts `lang` and
/// `framework_ctx` contributes its model, stamped with `lang`, and the models
/// are merged. Afterwards the Rust web-framework signal is recorded and units
/// that share a source span are collapsed into one.
pub fn extract_authorization_model(
    lang: &str,
    framework_ctx: Option<&FrameworkContext>,
    tree: &Tree,
    bytes: &[u8],
    path: &Path,
    rules: &AuthAnalysisRules,
) -> AuthorizationModel {
    // Registry order decides the order in which units and routes are appended.
    let registry: &[&dyn AuthExtractor] = &[
        &express::ExpressExtractor,
        &koa::KoaExtractor,
        &fastify::FastifyExtractor,
        &gin::GinExtractor,
        &echo::EchoExtractor,
        &flask::FlaskExtractor,
        &django::DjangoExtractor,
        &spring::SpringExtractor,
        &rails::RailsExtractor,
        &sinatra::SinatraExtractor,
        &axum::AxumExtractor,
        &actix_web::ActixWebExtractor,
        &rocket::RocketExtractor,
    ];

    let mut merged = AuthorizationModel {
        lang: lang.to_string(),
        ..Default::default()
    };

    // `filter` is lazy, so `supports` and `extract` still alternate per extractor.
    for candidate in registry
        .iter()
        .filter(|candidate| candidate.supports(lang, framework_ctx))
    {
        let mut partial = candidate.extract(tree, bytes, path, rules);
        // Extractors are not trusted to set `lang`; the caller's value wins.
        partial.lang = merged.lang.clone();
        merged.extend(partial);
    }

    merged.lang_web_framework_signal = compute_web_framework_signal(lang, framework_ctx, bytes);
    deduplicate_units_by_span(&mut merged);
    merged
}
/// Decides whether a Rust file should be treated as part of a web application.
///
/// Returns `None` for any language other than `"rust"` / `"rs"`. For Rust:
/// - `Some(true)` if the project context reports a Rust web framework;
/// - otherwise `Some(true)` if the file's own `bytes` import one, as judged by
///   `rust_file_imports_web_framework`;
/// - otherwise whatever the project context reported (`Some(false)` or `None`).
fn compute_web_framework_signal(
    lang: &str,
    framework_ctx: Option<&FrameworkContext>,
    bytes: &[u8],
) -> Option<bool> {
    let is_rust = lang == "rust" || lang == "rs";
    if !is_rust {
        return None;
    }

    let from_project = framework_ctx.and_then(|ctx| ctx.lang_has_web_framework("rust"));
    match from_project {
        Some(true) => Some(true),
        // The per-file import scan only runs when the project is not already positive.
        _ if rust_file_imports_web_framework(bytes) => Some(true),
        fallback => fallback,
    }
}
/// Collapses analysis units that share the same source span into a single unit.
///
/// Several extractors can support the same language, so one function may be
/// reported more than once.
///
/// - **Choosing the winner:** for each span, the first unit seen is kept.
///   The exception is a later `RouteHandler` unit, which replaces an earlier
///   non-route unit.
/// - **Merging checks:** auth checks carried by discarded units are merged into
///   the surviving unit. A check is skipped if one with the same `span` and
///   `callee` is already present.
/// - **Remapping routes:** every route whose `unit_idx` is in range is rewritten
///   to the surviving unit's position in the compacted `units` vector. This
///   includes routes that pointed at a discarded duplicate.
fn deduplicate_units_by_span(model: &mut AuthorizationModel) {
    use crate::auth_analysis::model::{AnalysisUnit, AnalysisUnitKind, AuthCheck};
    use std::collections::hash_map::Entry;
    use std::collections::HashMap;

    // Pass 1: pick the winning (old) index for every distinct span.
    let mut winner_by_span: HashMap<(usize, usize), usize> =
        HashMap::with_capacity(model.units.len());
    for (idx, unit) in model.units.iter().enumerate() {
        match winner_by_span.entry(unit.span) {
            Entry::Vacant(slot) => {
                slot.insert(idx);
            }
            Entry::Occupied(mut slot) => {
                // A route handler is the more specific classification, so it
                // displaces an earlier non-route unit for the same span.
                let prev_kind = model.units[*slot.get()].kind;
                if prev_kind != AnalysisUnitKind::RouteHandler
                    && unit.kind == AnalysisUnitKind::RouteHandler
                {
                    slot.insert(idx);
                }
            }
        }
    }

    // Old index of the surviving unit, for every old index.
    let winner_of: Vec<usize> = model
        .units
        .iter()
        .enumerate()
        .map(|(idx, unit)| winner_by_span.get(&unit.span).copied().unwrap_or(idx))
        .collect();

    // Pass 2: winners keep their relative order in the compacted vector.
    let mut new_pos_of_winner: HashMap<usize, usize> =
        HashMap::with_capacity(winner_by_span.len());
    for (idx, &winner) in winner_of.iter().enumerate() {
        if winner == idx {
            let next = new_pos_of_winner.len();
            new_pos_of_winner.insert(idx, next);
        }
    }
    // Map every old index (winner or duplicate) to its winner's new position.
    // Each winner `w` satisfies `winner_of[w] == w`, so the lookup cannot miss.
    let remap: Vec<usize> = winner_of
        .iter()
        .map(|winner| new_pos_of_winner[winner])
        .collect();

    // Pass 3: move winners out without cloning, and set aside the duplicates' checks.
    let units = std::mem::take(&mut model.units);
    let mut surviving: Vec<AnalysisUnit> = Vec::with_capacity(new_pos_of_winner.len());
    let mut orphaned_checks: Vec<(usize, Vec<AuthCheck>)> = Vec::new();
    for (idx, mut unit) in units.into_iter().enumerate() {
        if winner_of[idx] == idx {
            surviving.push(unit);
        } else {
            orphaned_checks.push((remap[idx], std::mem::take(&mut unit.auth_checks)));
        }
    }

    // Pass 4: fold the duplicates' checks into their winners, skipping exact repeats.
    for (target, checks) in orphaned_checks {
        let winner = &mut surviving[target];
        for check in checks {
            let already_present = winner
                .auth_checks
                .iter()
                .any(|existing| existing.span == check.span && existing.callee == check.callee);
            if !already_present {
                winner.auth_checks.push(check);
            }
        }
    }

    model.units = surviving;

    // Pass 5: retarget routes. A route that referenced a discarded duplicate
    // now points at the unit that absorbed it instead of keeping a stale index.
    // Out-of-range indices are left untouched.
    for route in &mut model.routes {
        if let Some(&new_idx) = remap.get(route.unit_idx) {
            route.unit_idx = new_idx;
        }
    }
}