use crate::detectors::file_cache::FileContentCache;
use crate::detectors::taint::{TaintAnalyzer, TaintCategory, TaintPath};
use crate::graph::{GraphQuery, GraphQueryExt};
use crate::parsers::lightweight::Language;
use rayon::prelude::*;
use std::collections::HashMap;
use std::path::Path;
use std::sync::{Arc, Mutex};
use tracing::{debug, info};
/// Every taint category the centralized runner analyzes.
///
/// Both the cross-function pass and the intra-function pass iterate this list,
/// and the intra-function result map is pre-seeded with one (possibly empty)
/// entry per category listed here.
const ALL_CATEGORIES: &[TaintCategory] = &[
    TaintCategory::SqlInjection,
    TaintCategory::CommandInjection,
    TaintCategory::Xss,
    TaintCategory::Ssrf,
    TaintCategory::PathTraversal,
    TaintCategory::CodeInjection,
    TaintCategory::LogInjection,
    TaintCategory::Xxe,
];
/// Taint paths for every category, as produced by `run_centralized_taint`.
///
/// Paths are keyed by [`TaintCategory`] and kept in two maps depending on which
/// analysis found them.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct CentralizedTaintResults {
    /// Paths from the cross-function tracer
    /// (`TaintAnalyzer::trace_taint_with_functions`).
    pub cross_function: HashMap<TaintCategory, Vec<TaintPath>>,
    /// Paths from per-function analysis (`TaintAnalyzer::analyze_intra_function`).
    pub intra_function: HashMap<TaintCategory, Vec<TaintPath>>,
}
impl CentralizedTaintResults {
    /// Returns an owned copy of every path recorded for `category`.
    ///
    /// Cross-function paths come first, followed by intra-function paths. A
    /// category missing from either map simply contributes nothing.
    pub fn paths_for(&self, category: TaintCategory) -> Vec<TaintPath> {
        let cross = self.cross_function.get(&category);
        let intra = self.intra_function.get(&category);
        cross
            .into_iter()
            .chain(intra)
            .flat_map(|bucket| bucket.iter().cloned())
            .collect()
    }

    /// Returns an owned copy of only the intra-function paths for `category`,
    /// or an empty vector when none were stored.
    pub fn intra_paths_for(&self, category: TaintCategory) -> Vec<TaintPath> {
        match self.intra_function.get(&category) {
            Some(bucket) => bucket.clone(),
            None => Vec::new(),
        }
    }
}
/// Runs cross-function and intra-function taint analysis for every category in
/// `ALL_CATEGORIES` and bundles the paths into a [`CentralizedTaintResults`].
///
/// The cross-function pass (one rayon task per category) runs on a scoped
/// thread while the intra-function pass runs on the calling thread. The two
/// passes share one [`TaintAnalyzer`] and one function list. Totals and elapsed
/// time are logged at `info` level.
///
/// * `graph` – code graph providing functions and the string interner.
/// * `repository_path` – root used to resolve each function's relative path.
/// * `file_cache` – optional shared file-content cache. A private cache is
///   created when `None`.
///
/// # Panics
///
/// A panic in the intra-function pass propagates normally. A panic in the
/// cross-function pass is re-raised on the calling thread with its original
/// payload.
pub fn run_centralized_taint(
    graph: &dyn GraphQuery,
    repository_path: &Path,
    file_cache: Option<&Arc<FileContentCache>>,
) -> CentralizedTaintResults {
    let analyzer = TaintAnalyzer::new();
    let start = std::time::Instant::now();
    let functions = graph.get_functions_shared();

    let (cross_function, intra_function) = std::thread::scope(|s| {
        let cross_handle = s.spawn(|| -> HashMap<TaintCategory, Vec<TaintPath>> {
            ALL_CATEGORIES
                .par_iter()
                .map(|&category| {
                    let paths =
                        analyzer.trace_taint_with_functions(graph, category, Some(&functions));
                    (category, paths)
                })
                .collect()
        });
        let intra = run_intra_all_categories(&analyzer, graph, repository_path, file_cache);
        // Re-raise the worker's panic with its original payload instead of
        // wrapping it in a new message, so the real cause reaches the caller.
        let cross = cross_handle
            .join()
            .unwrap_or_else(|payload| std::panic::resume_unwind(payload));
        (cross, intra)
    });

    let elapsed = start.elapsed();
    let total_cross: usize = cross_function.values().map(Vec::len).sum();
    let total_intra: usize = intra_function.values().map(Vec::len).sum();
    info!(
        "Centralized taint: {} cross + {} intra paths across {} categories in {:?}",
        total_cross,
        total_intra,
        ALL_CATEGORIES.len(),
        elapsed,
    );

    CentralizedTaintResults {
        cross_function,
        intra_function,
    }
}
/// Runs intra-function taint analysis for every relevant category over every
/// function in the graph and returns the paths grouped by category.
///
/// Functions are grouped by source file so each file is fetched only once. Files
/// are read through `file_cache`, or through a fresh [`FileContentCache`] when it
/// is `None`.
///
/// A file is skipped when:
/// * it cannot be read, or
/// * no category matches both its extension and `file_might_be_relevant`.
///
/// A function is skipped when its line range is missing, inverted, or extends
/// past the end of the file as currently read.
///
/// The returned map always holds an entry, possibly empty, for every category
/// in `ALL_CATEGORIES`. Each category's paths are stably sorted by:
/// 1. source file,
/// 2. source line,
/// 3. source function,
/// 4. sink file,
/// 5. sink function.
///
/// This keeps the output independent of parallel scheduling.
fn run_intra_all_categories(
    analyzer: &TaintAnalyzer,
    graph: &dyn GraphQuery,
    repository_path: &Path,
    file_cache: Option<&Arc<FileContentCache>>,
) -> HashMap<TaintCategory, Vec<TaintPath>> {
    let i = graph.interner();
    let functions = graph.get_functions_shared();
    let shared_cache = file_cache
        .cloned()
        .unwrap_or_else(|| Arc::new(FileContentCache::new()));

    // Bucket function indices by file path; functions without a path are ignored.
    let mut by_file: HashMap<&str, Vec<usize>> = HashMap::new();
    for (idx, func) in functions.iter().enumerate() {
        let path = func.path(i);
        if !path.is_empty() {
            by_file.entry(path).or_default().push(idx);
        }
    }
    let file_groups: Vec<(&str, Vec<usize>)> = by_file.into_iter().collect();

    // Analyze files in parallel. Each task returns its own batch and rayon
    // gathers the batches, so workers never contend on a shared lock.
    let per_file: Vec<Vec<(TaintCategory, Vec<TaintPath>)>> = file_groups
        .par_iter()
        .map(|(file_path, func_indices)| {
            let mut file_results: Vec<(TaintCategory, Vec<TaintPath>)> = Vec::new();
            let full_path = repository_path.join(file_path);
            let ext = full_path.extension().and_then(|e| e.to_str()).unwrap_or("");
            let language = Language::from_extension(ext);
            let content: Arc<String> = match shared_cache.get_or_read(&full_path) {
                Some(c) => c,
                None => return file_results,
            };

            // Pre-filter: only keep categories registered for this extension
            // and accepted by `file_might_be_relevant` for this content.
            let relevant_categories: Vec<TaintCategory> = ALL_CATEGORIES
                .iter()
                .copied()
                .filter(|cat| {
                    cat.relevant_extensions().contains(&ext) && cat.file_might_be_relevant(&content)
                })
                .collect();
            if relevant_categories.is_empty() {
                return file_results;
            }

            let lines: Vec<&str> = content.lines().collect();
            for &func_idx in func_indices {
                let func = &functions[func_idx];
                let line_start = func.line_start as usize;
                // Checked conversion: negative values (and, on 32-bit targets,
                // values above usize::MAX) become 0 and are rejected below,
                // instead of wrapping or truncating as an `as` cast would.
                let line_end = usize::try_from(func.get_i64("lineEnd").unwrap_or(0)).unwrap_or(0);
                // Line numbers are 1-based; since line_start >= 1 here, an
                // end of 0 is also caught by `line_end < line_start`.
                if line_start == 0 || line_end < line_start || line_end > lines.len() {
                    continue;
                }
                let func_body = lines[line_start - 1..line_end].join("\n");
                for &category in &relevant_categories {
                    let paths = analyzer.analyze_intra_function(
                        &func_body,
                        func.node_name(i),
                        func.path(i),
                        line_start,
                        language,
                        category,
                    );
                    if !paths.is_empty() {
                        file_results.push((category, paths));
                    }
                }
            }
            file_results
        })
        .collect();

    // Seed every category so callers always see all keys, then merge batches.
    let mut final_results: HashMap<TaintCategory, Vec<TaintPath>> =
        ALL_CATEGORIES.iter().map(|&cat| (cat, Vec::new())).collect();
    for (category, paths) in per_file.into_iter().flatten() {
        if let Some(bucket) = final_results.get_mut(&category) {
            bucket.extend(paths);
        }
    }

    // Stable sort so ordering does not depend on which file finished first.
    for paths in final_results.values_mut() {
        paths.sort_by(|a, b| {
            a.source_file
                .cmp(&b.source_file)
                .then_with(|| a.source_line.cmp(&b.source_line))
                .then_with(|| a.source_function.cmp(&b.source_function))
                .then_with(|| a.sink_file.cmp(&b.sink_file))
                .then_with(|| a.sink_function.cmp(&b.sink_function))
        });
    }
    final_results
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Pins the number of scheduled categories and checks that XXE is among them.
    #[test]
    fn test_all_categories_covered() {
        let expected_count = 8;
        assert_eq!(ALL_CATEGORIES.len(), expected_count);
        assert!(ALL_CATEGORIES.iter().any(|cat| *cat == TaintCategory::Xxe));
    }

    /// Lookups on empty result maps must yield empty vectors rather than panic.
    #[test]
    fn test_centralized_results_paths_for_empty() {
        let empty = CentralizedTaintResults {
            cross_function: HashMap::default(),
            intra_function: HashMap::default(),
        };
        let combined = empty.paths_for(TaintCategory::SqlInjection);
        let intra_only = empty.intra_paths_for(TaintCategory::CommandInjection);
        assert!(combined.is_empty());
        assert!(intra_only.is_empty());
    }
}