use super::types::*;
use rma_common::Language;
use std::borrow::Cow;
use std::collections::{HashMap, HashSet};
/// Framework knowledge for a single language, merged from every selected
/// [`FrameworkProfile`] into lookup tables queried during analysis.
///
/// The pattern sets (`source_functions`, `sink_methods`, ...) hold the bare
/// pattern strings from the profiles for quick membership tests. The `all_*`
/// vectors keep the full definitions in merge order, so callers can recover the
/// matching [`SourceDef`] / [`SinkDef`] / [`SanitizerDef`].
#[derive(Debug)]
pub struct MergedKnowledge {
    /// Language the knowledge was built for.
    pub language: Language,
    /// Names of the merged profiles, in the order they were merged.
    pub active_frameworks: Vec<&'static str>,
    /// Patterns of `SourceKind::FunctionCall` sources.
    source_functions: HashSet<&'static str>,
    /// Patterns of `SourceKind::MemberAccess` sources.
    source_members: HashSet<&'static str>,
    /// Patterns of `SourceKind::TypeExtractor` sources.
    source_type_extractors: HashSet<&'static str>,
    /// `SourceKind::MethodOnType` sources: type pattern -> method names.
    source_method_on_type: HashMap<&'static str, Vec<&'static str>>,
    /// `true` once any merged profile declares a `SourceKind::Parameter` source.
    parameters_are_sources: bool,
    /// Every source definition, in merge order.
    all_sources: Vec<&'static SourceDef>,
    /// Patterns of `SinkKind::FunctionCall` sinks.
    sink_functions: HashSet<&'static str>,
    /// Patterns of `SinkKind::MethodCall` sinks.
    sink_methods: HashSet<&'static str>,
    /// Patterns of `SinkKind::MacroInvocation` sinks.
    sink_macros: HashSet<&'static str>,
    /// Patterns of `SinkKind::ResponseBody` sinks.
    sink_response_bodies: HashSet<&'static str>,
    /// Every sink definition, in merge order.
    all_sinks: Vec<&'static SinkDef>,
    /// Patterns of `SanitizerKind::Function` and `SanitizerKind::TemplateEngine`.
    sanitizer_functions: HashSet<&'static str>,
    /// Patterns of `SanitizerKind::MethodCall` sanitizers.
    sanitizer_methods: HashSet<&'static str>,
    /// Patterns of `SanitizerKind::Macro` sanitizers.
    sanitizer_macros: HashSet<&'static str>,
    /// Sanitizer pattern -> every distinct `sanitizes` target declared for it.
    /// Multi-valued so that profiles declaring the same pattern with different
    /// targets do not overwrite each other.
    sanitizer_targets: HashMap<&'static str, Vec<&'static str>>,
    /// Every sanitizer definition, in merge order.
    all_sanitizers: Vec<&'static SanitizerDef>,
    /// Names of all safe patterns.
    safe_patterns: HashSet<&'static str>,
    /// Every safe-pattern definition, in merge order.
    all_safe_patterns: Vec<&'static SafePattern>,
    /// Every dangerous-pattern definition, in merge order.
    all_dangerous_patterns: Vec<&'static DangerousPattern>,
    /// Every resource-type definition, in merge order.
    all_resource_types: Vec<&'static ResourceType>,
}

impl MergedKnowledge {
    /// Builds the merged knowledge for `language` from `profiles`.
    ///
    /// Profiles are merged in the given order, so `active_frameworks` and the
    /// `all_*` definition lists preserve that order.
    pub fn from_profiles(language: Language, profiles: Vec<&'static FrameworkProfile>) -> Self {
        // Reuse the empty state instead of repeating every field initialiser.
        let mut merged = Self::empty(language);
        merged.active_frameworks.reserve(profiles.len());
        for profile in profiles {
            merged.merge_profile(profile);
        }
        merged
    }

    /// Returns knowledge for `language` with no frameworks merged in.
    /// Every lookup table is empty, so all queries answer negatively.
    pub fn empty(language: Language) -> Self {
        Self {
            language,
            active_frameworks: Vec::new(),
            source_functions: HashSet::new(),
            source_members: HashSet::new(),
            source_type_extractors: HashSet::new(),
            source_method_on_type: HashMap::new(),
            parameters_are_sources: false,
            all_sources: Vec::new(),
            sink_functions: HashSet::new(),
            sink_methods: HashSet::new(),
            sink_macros: HashSet::new(),
            sink_response_bodies: HashSet::new(),
            all_sinks: Vec::new(),
            sanitizer_functions: HashSet::new(),
            sanitizer_methods: HashSet::new(),
            sanitizer_macros: HashSet::new(),
            sanitizer_targets: HashMap::new(),
            all_sanitizers: Vec::new(),
            safe_patterns: HashSet::new(),
            all_safe_patterns: Vec::new(),
            all_dangerous_patterns: Vec::new(),
            all_resource_types: Vec::new(),
        }
    }

    /// Folds one profile's definitions into the lookup tables.
    fn merge_profile(&mut self, profile: &'static FrameworkProfile) {
        self.active_frameworks.push(profile.name);

        for source in profile.sources {
            self.all_sources.push(source);
            match &source.pattern {
                SourceKind::FunctionCall(name) => {
                    self.source_functions.insert(name);
                }
                SourceKind::MemberAccess(path) => {
                    self.source_members.insert(path);
                }
                SourceKind::TypeExtractor(name) => {
                    self.source_type_extractors.insert(name);
                }
                SourceKind::MethodOnType {
                    type_pattern,
                    method,
                } => {
                    self.source_method_on_type
                        .entry(type_pattern)
                        .or_default()
                        .push(method);
                }
                SourceKind::Parameter => {
                    self.parameters_are_sources = true;
                }
            }
        }

        for sink in profile.sinks {
            self.all_sinks.push(sink);
            match &sink.pattern {
                SinkKind::FunctionCall(name) => {
                    self.sink_functions.insert(name);
                }
                SinkKind::MethodCall(name) => {
                    self.sink_methods.insert(name);
                }
                SinkKind::MacroInvocation(name) => {
                    self.sink_macros.insert(name);
                }
                SinkKind::ResponseBody(name) => {
                    self.sink_response_bodies.insert(name);
                }
                // No lookup set for these kinds; property sinks are found by
                // scanning `all_sinks` (see `is_sink_property`).
                SinkKind::PropertyAssignment(_) | SinkKind::TemplateInsertion => {}
            }
        }

        for sanitizer in profile.sanitizers {
            self.all_sanitizers.push(sanitizer);
            let key = match &sanitizer.pattern {
                SanitizerKind::Function(name) => {
                    self.sanitizer_functions.insert(name);
                    *name
                }
                SanitizerKind::MethodCall(name) => {
                    self.sanitizer_methods.insert(name);
                    *name
                }
                SanitizerKind::Macro(name) => {
                    self.sanitizer_macros.insert(name);
                    *name
                }
                // Template-engine sanitizers share the function table, so
                // `is_sanitizer` matches them by substring like functions.
                SanitizerKind::TemplateEngine(name) => {
                    self.sanitizer_functions.insert(name);
                    *name
                }
            };
            // Keep every distinct target; a plain insert would let a later
            // profile silently replace an earlier profile's target.
            let targets = self.sanitizer_targets.entry(key).or_default();
            if !targets.contains(&sanitizer.sanitizes) {
                targets.push(sanitizer.sanitizes);
            }
        }

        for pattern in profile.safe_patterns {
            self.safe_patterns.insert(pattern.name);
            self.all_safe_patterns.push(pattern);
        }
        self.all_dangerous_patterns
            .extend(profile.dangerous_patterns.iter());
        self.all_resource_types
            .extend(profile.resource_types.iter());
    }

    /// Returns `true` if `func_name` equals a function-call source pattern or
    /// ends with one (so a qualified path like `std::env::var` matches `env::var`).
    #[inline]
    pub fn is_source_function(&self, func_name: &str) -> bool {
        self.source_functions.contains(func_name)
            || self
                .source_functions
                .iter()
                .any(|&source| func_name.ends_with(source))
    }

    /// Returns `true` if `member_path` contains any member-access source
    /// pattern as a substring.
    #[inline]
    pub fn is_source_member(&self, member_path: &str) -> bool {
        // The exact lookup is a fast path; a substring hit also covers
        // exact and suffix matches.
        self.source_members.contains(member_path)
            || self
                .source_members
                .iter()
                .any(|&source| member_path.contains(source))
    }

    /// Returns `true` if `type_name` contains any type-extractor source pattern.
    #[inline]
    pub fn is_source_type_extractor(&self, type_name: &str) -> bool {
        self.source_type_extractors.contains(type_name)
            || self
                .source_type_extractors
                .iter()
                .any(|&t| type_name.contains(t))
    }

    /// Returns `true` if any merged profile declares function parameters as sources.
    #[inline]
    pub fn parameters_are_sources(&self) -> bool {
        self.parameters_are_sources
    }

    /// Returns the first source (in merge order) whose pattern occurs in `pattern`.
    ///
    /// For `MethodOnType` sources only the method name is compared.
    /// `Parameter` sources never match.
    pub fn get_source(&self, pattern: &str) -> Option<&'static SourceDef> {
        self.all_sources
            .iter()
            .find(|s| match &s.pattern {
                SourceKind::FunctionCall(p) => pattern.contains(*p),
                SourceKind::MemberAccess(p) => pattern.contains(*p),
                SourceKind::TypeExtractor(p) => pattern.contains(*p),
                SourceKind::MethodOnType { method, .. } => pattern.contains(*method),
                SourceKind::Parameter => false,
            })
            .copied()
    }

    /// Returns `true` if `func_name` contains any function-call sink pattern
    /// as a substring.
    #[inline]
    pub fn is_sink_function(&self, func_name: &str) -> bool {
        // A substring hit also covers exact and suffix matches.
        self.sink_functions.contains(func_name)
            || self
                .sink_functions
                .iter()
                .any(|&sink| func_name.contains(sink))
    }

    /// Returns `true` if `method_name` equals or ends with a method-call sink pattern.
    #[inline]
    pub fn is_sink_method(&self, method_name: &str) -> bool {
        self.sink_methods.contains(method_name)
            || self
                .sink_methods
                .iter()
                .any(|&sink| method_name.ends_with(sink))
    }

    /// Returns `true` if `macro_name` exactly equals a macro-invocation sink pattern.
    #[inline]
    pub fn is_sink_macro(&self, macro_name: &str) -> bool {
        self.sink_macros.contains(macro_name)
    }

    /// Returns `true` if `prop_name` exactly equals a property-assignment sink pattern.
    pub fn is_sink_property(&self, prop_name: &str) -> bool {
        self.all_sinks.iter().any(|s| match &s.pattern {
            SinkKind::PropertyAssignment(p) => prop_name == *p,
            _ => false,
        })
    }

    /// Returns the first sink (in merge order) matching `pattern`.
    ///
    /// Property sinks must match exactly; other kinds match when their pattern
    /// occurs in `pattern`. `TemplateInsertion` sinks never match.
    pub fn get_sink(&self, pattern: &str) -> Option<&'static SinkDef> {
        self.all_sinks
            .iter()
            .find(|s| match &s.pattern {
                SinkKind::FunctionCall(p) => pattern.contains(*p),
                SinkKind::MethodCall(p) => pattern.contains(*p),
                SinkKind::MacroInvocation(p) => pattern.contains(*p),
                SinkKind::PropertyAssignment(p) => pattern == *p,
                SinkKind::ResponseBody(p) => pattern.contains(*p),
                SinkKind::TemplateInsertion => false,
            })
            .copied()
    }

    /// Returns `true` if `func_name` is recognised as a sanitizer.
    ///
    /// The following checks are tried in order:
    /// - an exact match against any function, method or macro pattern;
    /// - a substring match against function (and template-engine) patterns;
    /// - a suffix match against method patterns.
    #[inline]
    pub fn is_sanitizer(&self, func_name: &str) -> bool {
        self.sanitizer_functions.contains(func_name)
            || self.sanitizer_methods.contains(func_name)
            || self.sanitizer_macros.contains(func_name)
            || self
                .sanitizer_functions
                .iter()
                .any(|&sanitizer| func_name.contains(sanitizer))
            || self
                .sanitizer_methods
                .iter()
                .any(|&sanitizer| func_name.ends_with(sanitizer))
    }

    /// Returns `true` when some sanitizer whose pattern occurs in `func_name`
    /// is declared to neutralise `taint_type`, or everything (`"*"`).
    ///
    /// Every matching pattern is considered. Several patterns can occur in the
    /// same call name, and the answer must not depend on `HashMap` iteration order.
    pub fn sanitizes_type(&self, func_name: &str, taint_type: &str) -> bool {
        self.sanitizer_targets.iter().any(|(&pattern, targets)| {
            func_name.contains(pattern)
                && targets
                    .iter()
                    .any(|&target| target == taint_type || target == "*")
        })
    }

    /// Returns the first sanitizer (in merge order) matching `func_name`.
    ///
    /// Macro sanitizers must match exactly; other kinds match when their
    /// pattern occurs in `func_name`.
    pub fn get_sanitizer(&self, func_name: &str) -> Option<&'static SanitizerDef> {
        self.all_sanitizers
            .iter()
            .find(|s| match &s.pattern {
                SanitizerKind::Function(p) => func_name.contains(*p),
                SanitizerKind::MethodCall(p) => func_name.contains(*p),
                SanitizerKind::Macro(p) => func_name == *p,
                SanitizerKind::TemplateEngine(p) => func_name.contains(*p),
            })
            .copied()
    }

    /// Returns `true` if a safe pattern with exactly this name was merged.
    #[inline]
    pub fn is_safe_pattern(&self, pattern_name: &str) -> bool {
        self.safe_patterns.contains(pattern_name)
    }

    /// All dangerous-pattern definitions, in merge order.
    pub fn dangerous_patterns(&self) -> &[&'static DangerousPattern] {
        &self.all_dangerous_patterns
    }

    /// All resource-type definitions, in merge order.
    pub fn resource_types(&self) -> &[&'static ResourceType] {
        &self.all_resource_types
    }

    /// Number of source definitions merged, including duplicates across profiles.
    pub fn source_count(&self) -> usize {
        self.all_sources.len()
    }

    /// Number of sink definitions merged, including duplicates across profiles.
    pub fn sink_count(&self) -> usize {
        self.all_sinks.len()
    }

    /// Number of sanitizer definitions merged, including duplicates across profiles.
    pub fn sanitizer_count(&self) -> usize {
        self.all_sanitizers.len()
    }

    /// Returns `true` if at least one profile was merged.
    pub fn has_frameworks(&self) -> bool {
        !self.active_frameworks.is_empty()
    }

    /// Sanitizers whose `sanitizes` target is `sanitize_type` or `"*"`, in merge order.
    pub fn sanitizers_for_type(&self, sanitize_type: &str) -> Vec<&'static SanitizerDef> {
        self.all_sanitizers
            .iter()
            .filter(|s| s.sanitizes == sanitize_type || s.sanitizes == "*")
            .copied()
            .collect()
    }

    /// All sanitizer definitions, in merge order.
    pub fn all_sanitizer_defs(&self) -> &[&'static SanitizerDef] {
        &self.all_sanitizers
    }

    /// Returns the patterns of function-call, member-access and type-extractor
    /// sources, in that group order.
    ///
    /// Order within each group follows `HashSet` iteration and is unspecified.
    /// `MethodOnType` and `Parameter` sources are not included.
    pub fn all_source_patterns(&self) -> Vec<Cow<'static, str>> {
        self.source_functions
            .iter()
            .chain(&self.source_members)
            .chain(&self.source_type_extractors)
            .map(|&p| Cow::Borrowed(p))
            .collect()
    }
}
pub struct KnowledgeBuilder {
language: Language,
}
impl KnowledgeBuilder {
pub fn new(language: Language) -> Self {
Self { language }
}
pub fn from_content(&self, content: &str) -> MergedKnowledge {
let profiles = super::detect_frameworks(self.language, content);
MergedKnowledge::from_profiles(self.language, profiles)
}
pub fn from_imports(&self, imports: &[&str]) -> MergedKnowledge {
let profiles = super::detect_frameworks_from_imports(self.language, imports);
MergedKnowledge::from_profiles(self.language, profiles)
}
pub fn all_profiles(&self) -> MergedKnowledge {
let profiles = super::profiles_for_language(self.language);
MergedKnowledge::from_profiles(self.language, profiles)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The bundled JavaScript profiles recognise common browser/Express
    /// sources, DOM sinks and sanitizers.
    #[test]
    fn test_merged_knowledge_javascript() {
        let kb = KnowledgeBuilder::new(Language::JavaScript).all_profiles();

        assert!(kb.has_frameworks());
        assert!(kb.source_count() > 0);
        assert!(kb.sink_count() > 0);
        assert!(kb.sanitizer_count() > 0);

        for member in ["req.query", "req.body", "location.search"] {
            assert!(kb.is_source_member(member));
        }
        assert!(kb.is_sink_property("innerHTML"));
        assert!(kb.is_sink_function("eval"));
        for sanitizer in ["DOMPurify.sanitize", "encodeURIComponent"] {
            assert!(kb.is_sanitizer(sanitizer));
        }
    }

    /// The bundled Rust profiles cover env sources, process sinks and ammonia.
    #[test]
    fn test_merged_knowledge_rust() {
        let kb = KnowledgeBuilder::new(Language::Rust).all_profiles();

        assert!(kb.has_frameworks());
        assert!(kb.is_source_function("env::var"));
        assert!(kb.is_sink_function("Command::new"));
        assert!(kb.is_sanitizer("ammonia::clean"));
    }

    /// Content-based detection activates the Express profile.
    #[test]
    fn test_framework_detection() {
        let snippet = r#"
import express from 'express';
const app = express();
app.get('/user', (req, res) => {
const query = req.query.name;
});
"#;
        let kb = KnowledgeBuilder::new(Language::JavaScript).from_content(snippet);

        assert!(kb.active_frameworks.contains(&"express"));
    }

    /// Empty knowledge has no frameworks and answers every query negatively.
    #[test]
    fn test_empty_knowledge() {
        let kb = MergedKnowledge::empty(Language::Unknown);

        assert!(!kb.has_frameworks());
        assert_eq!(kb.source_count(), 0);
        assert_eq!(kb.sink_count(), 0);
        assert!(!kb.is_source_function("anything"));
        assert!(!kb.is_sink_function("anything"));
    }
}