use crate::flow::FlowContext;
use rma_common::{Confidence, Finding, FindingCategory, FindingSource, Language};
use rma_parser::ParsedFile;
use tree_sitter::Node;
pub trait Rule: Send + Sync {
fn id(&self) -> &str;
fn description(&self) -> &str;
fn applies_to(&self, lang: Language) -> bool;
fn check(&self, parsed: &ParsedFile) -> Vec<Finding>;
fn check_with_flow(&self, parsed: &ParsedFile, _flow: &FlowContext) -> Vec<Finding> {
self.check(parsed)
}
fn uses_flow(&self) -> bool {
false
}
}
#[derive(Debug, Clone, Copy)]
pub enum SnippetContext {
SingleLine,
Expression,
Statement,
Block,
Function,
TaintFlow {
source_line: usize,
sink_line: usize,
},
MultiLine { start_line: usize, end_line: usize },
}
impl SnippetContext {
pub fn from_node(node: &Node, rule_id: &str) -> Self {
let start_line = node.start_position().row;
let end_line = node.end_position().row;
let line_span = end_line - start_line;
let node_kind = node.kind();
let is_flow_rule = rule_id.contains("injection")
|| rule_id.contains("taint")
|| rule_id.contains("xss")
|| rule_id.contains("traversal");
if line_span == 0 {
return Self::SingleLine;
}
if is_flow_rule && line_span > 1 {
return Self::MultiLine {
start_line,
end_line,
};
}
match node_kind {
"function_declaration"
| "function_definition"
| "method_definition"
| "fn_item"
| "function_item"
| "arrow_function" => Self::Function,
"if_statement" | "if_expression" | "for_statement" | "while_statement"
| "try_statement" | "match_expression" | "switch_statement" => Self::Block,
"expression_statement"
| "return_statement"
| "variable_declaration"
| "let_declaration"
| "assignment_expression" => Self::Statement,
"call_expression" | "member_expression" | "binary_expression" => Self::Expression,
_ => {
if line_span <= 3 {
Self::MultiLine {
start_line,
end_line,
}
} else if line_span <= 10 {
Self::Block
} else {
Self::Function
}
}
}
}
pub fn char_limit(&self) -> usize {
match self {
Self::SingleLine => 200,
Self::Expression => 300,
Self::Statement => 400,
Self::Block => 800,
Self::Function => 1200,
Self::TaintFlow {
source_line,
sink_line,
} => {
let distance = sink_line.saturating_sub(*source_line);
(300 + distance * 50).min(1500)
}
Self::MultiLine {
start_line,
end_line,
} => {
let lines = end_line.saturating_sub(*start_line) + 1;
(lines * 100).clamp(200, 1000)
}
}
}
pub fn context_lines(&self) -> usize {
match self {
Self::SingleLine => 0,
Self::Expression => 1,
Self::Statement => 1,
Self::Block => 2,
Self::Function => 3,
Self::TaintFlow { .. } => 2,
Self::MultiLine { .. } => 1,
}
}
}
pub fn extract_smart_snippet(node: &Node, content: &str, rule_id: &str) -> Option<String> {
let ctx = SnippetContext::from_node(node, rule_id);
let limit = ctx.char_limit();
let context_lines = ctx.context_lines();
let text = node.utf8_text(content.as_bytes()).ok()?;
let char_count = text.chars().count();
if char_count <= limit {
if context_lines > 0 {
return Some(extract_with_context(node, content, context_lines));
}
return Some(text.to_string());
}
match ctx {
SnippetContext::SingleLine | SnippetContext::Expression => {
let truncated: String = text.chars().take(limit).collect();
Some(format!("{}...", truncated.trim_end()))
}
SnippetContext::Function => {
extract_function_snippet(text, limit)
}
SnippetContext::Block => {
extract_block_snippet(text, limit)
}
SnippetContext::TaintFlow {
source_line,
sink_line,
} => {
extract_flow_snippet(content, source_line, sink_line, limit)
}
_ => {
extract_head_tail_snippet(text, limit)
}
}
}
fn extract_with_context(node: &Node, content: &str, context_lines: usize) -> String {
let lines: Vec<&str> = content.lines().collect();
let start_line = node.start_position().row;
let end_line = node.end_position().row;
let ctx_start = start_line.saturating_sub(context_lines);
let ctx_end = (end_line + context_lines).min(lines.len().saturating_sub(1));
lines[ctx_start..=ctx_end].join("\n")
}
fn extract_function_snippet(text: &str, limit: usize) -> Option<String> {
let lines: Vec<&str> = text.lines().collect();
if lines.len() <= 5 {
return Some(text.to_string());
}
let head: String = lines[..3].join("\n");
let tail: String = lines[lines.len() - 2..].join("\n");
let result = format!(
"{}\n // ... ({} lines omitted)\n{}",
head,
lines.len() - 5,
tail
);
if result.chars().count() <= limit {
Some(result)
} else {
let truncated: String = text.chars().take(limit).collect();
Some(format!("{}...", truncated))
}
}
fn extract_block_snippet(text: &str, limit: usize) -> Option<String> {
let lines: Vec<&str> = text.lines().collect();
if lines.len() <= 6 {
return Some(text.to_string());
}
let head: String = lines[..2].join("\n");
let tail: String = lines[lines.len() - 2..].join("\n");
let result = format!("{}\n // ... ({} lines)\n{}", head, lines.len() - 4, tail);
if result.chars().count() <= limit {
Some(result)
} else {
let truncated: String = text.chars().take(limit).collect();
Some(format!("{}...", truncated))
}
}
fn extract_flow_snippet(
content: &str,
source_line: usize,
sink_line: usize,
limit: usize,
) -> Option<String> {
let lines: Vec<&str> = content.lines().collect();
if source_line >= lines.len() || sink_line >= lines.len() {
return None;
}
let source = lines.get(source_line).unwrap_or(&"");
let sink = lines.get(sink_line).unwrap_or(&"");
let distance = sink_line.saturating_sub(source_line);
let result = if distance <= 3 {
lines[source_line..=sink_line].join("\n")
} else {
format!(
"{}\n // ... taint flows through {} lines ...\n{}",
source.trim(),
distance - 1,
sink.trim()
)
};
if result.chars().count() <= limit {
Some(result)
} else {
Some(format!(
"{}...",
result.chars().take(limit).collect::<String>()
))
}
}
fn extract_head_tail_snippet(text: &str, limit: usize) -> Option<String> {
let chars: Vec<char> = text.chars().collect();
let total = chars.len();
if total <= limit {
return Some(text.to_string());
}
let head_len = (limit * 6) / 10;
let tail_len = limit - head_len - 20;
let head: String = chars[..head_len].iter().collect();
let tail: String = chars[total - tail_len..].iter().collect();
Some(format!("{}...{}", head.trim_end(), tail.trim_start()))
}
pub fn create_finding_at_line(
rule_id: &str,
path: &std::path::Path,
line: usize,
snippet: &str,
severity: rma_common::Severity,
message: &str,
language: Language,
) -> Finding {
let mut finding = Finding {
id: format!("{}-{}-1", rule_id, line),
rule_id: rule_id.to_string(),
message: message.to_string(),
severity,
location: rma_common::SourceLocation::new(path.to_path_buf(), line, 1, line, snippet.len()),
language,
snippet: Some(snippet.to_string()),
suggestion: None,
fix: None,
confidence: Confidence::Medium,
category: infer_category(rule_id),
subcategory: None,
technology: None,
impact: None,
likelihood: None,
source: infer_source(rule_id),
fingerprint: None,
properties: None,
occurrence_count: None,
additional_locations: None,
};
finding.compute_fingerprint();
finding
}
pub fn create_finding(
rule_id: &str,
node: &Node,
path: &std::path::Path,
content: &str,
severity: rma_common::Severity,
message: &str,
language: Language,
) -> Finding {
let start = node.start_position();
let end = node.end_position();
let snippet = extract_smart_snippet(node, content, rule_id);
let mut finding = Finding {
id: format!("{}-{}-{}", rule_id, start.row, start.column),
rule_id: rule_id.to_string(),
message: message.to_string(),
severity,
location: rma_common::SourceLocation::new(
path.to_path_buf(),
start.row + 1,
start.column + 1,
end.row + 1,
end.column + 1,
),
language,
snippet,
suggestion: None,
fix: None,
confidence: Confidence::Medium,
category: infer_category(rule_id),
subcategory: None,
technology: None,
impact: None,
likelihood: None,
source: infer_source(rule_id),
fingerprint: None,
properties: None,
occurrence_count: None,
additional_locations: None,
};
finding.compute_fingerprint();
finding
}
#[allow(clippy::too_many_arguments)]
pub fn create_finding_with_confidence(
rule_id: &str,
node: &Node,
path: &std::path::Path,
content: &str,
severity: rma_common::Severity,
message: &str,
language: Language,
confidence: Confidence,
) -> Finding {
let mut finding = create_finding(rule_id, node, path, content, severity, message, language);
finding.confidence = confidence;
finding
}
fn infer_source(rule_id: &str) -> FindingSource {
if let Some(suffix) = rule_id.split('/').nth(1) {
if suffix.starts_with("gen-pysa-") {
return FindingSource::Pysa;
}
if suffix.starts_with("gen-") {
return FindingSource::Codeql;
}
}
FindingSource::Builtin
}
fn infer_category(rule_id: &str) -> FindingCategory {
if rule_id.contains("injection")
|| rule_id.contains("xss")
|| rule_id.contains("unsafe")
|| rule_id.contains("secret")
|| rule_id.contains("crypto")
|| rule_id.contains("traversal")
|| rule_id.contains("eval")
|| rule_id.contains("exec")
|| rule_id.contains("transmute")
|| rule_id.contains("deserialization")
{
return FindingCategory::Security;
}
if rule_id.contains("complexity")
|| rule_id.contains("long-function")
|| rule_id.contains("unwrap")
|| rule_id.contains("expect")
|| rule_id.contains("panic")
{
return FindingCategory::Quality;
}
if rule_id.contains("todo")
|| rule_id.contains("fixme")
|| rule_id.contains("console")
|| rule_id.contains("style")
{
return FindingCategory::Style;
}
if rule_id.starts_with("security/") || rule_id.contains("/sql") || rule_id.contains("/command")
{
FindingCategory::Security
} else if rule_id.starts_with("quality/") || rule_id.contains("/quality") {
FindingCategory::Quality
} else if rule_id.starts_with("style/") || rule_id.starts_with("generic/todo") {
FindingCategory::Style
} else {
FindingCategory::Quality }
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_infer_source_builtin() {
assert_eq!(infer_source("js/sql-injection"), FindingSource::Builtin);
assert_eq!(infer_source("go/pg-orm-sqli"), FindingSource::Builtin);
assert_eq!(infer_source("java/ecb-cipher"), FindingSource::Builtin);
assert_eq!(
infer_source("python/eval-injection"),
FindingSource::Builtin
);
assert_eq!(infer_source("generic/todo-fixme"), FindingSource::Builtin);
}
#[test]
fn test_infer_source_codeql() {
assert_eq!(infer_source("go/gen-manual"), FindingSource::Codeql);
assert_eq!(infer_source("java/gen-manual"), FindingSource::Codeql);
assert_eq!(infer_source("java/gen-ai-manual"), FindingSource::Codeql);
assert_eq!(
infer_source("javascript/gen-sql-injection"),
FindingSource::Codeql
);
assert_eq!(
infer_source("javascript/gen-credentials-key"),
FindingSource::Codeql
);
assert_eq!(
infer_source("javascript/gen-path-injection"),
FindingSource::Codeql
);
assert_eq!(infer_source("cpp/gen-manual"), FindingSource::Codeql);
}
#[test]
fn test_infer_source_pysa() {
assert_eq!(infer_source("python/gen-pysa-xss"), FindingSource::Pysa);
assert_eq!(
infer_source("python/gen-pysa-authentication"),
FindingSource::Pysa
);
assert_eq!(
infer_source("python/gen-pysa-execargsink"),
FindingSource::Pysa
);
assert_eq!(
infer_source("python/gen-pysa-remotecodeexecution"),
FindingSource::Pysa
);
assert_eq!(
infer_source("python/gen-pysa-filesystem_readwrite"),
FindingSource::Pysa
);
}
#[test]
fn test_infer_source_no_slash() {
assert_eq!(infer_source("some-rule"), FindingSource::Builtin);
assert_eq!(infer_source("gen-manual"), FindingSource::Builtin);
}
}