pub mod registry;
pub mod python;
pub mod typescript;
pub mod go;
pub mod rust;
pub use registry::{
get_go_sources, get_python_sources, get_rust_sources, get_sources_for_language,
get_typescript_sources, MatchStrategy, SourceRegistry, TaintSource,
};
pub use go::GoSourceDetector;
pub use python::PythonSourceDetector;
pub use rust::RustSourceDetector;
pub use typescript::TypeScriptSourceDetector;
use crate::security::taint::types::{Location, TaintLabel};
use serde::{Deserialize, Serialize};
use std::collections::HashSet;
/// Category of a detected taint source.
///
/// Each variant maps to a set of taint labels (`to_taint_labels`), a numeric
/// weight (`severity_weight`), and a short human-readable description
/// (`description`, also used by the `Display` impl).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum SourceKind {
/// HTTP request parameter.
RequestParam,
/// HTTP request body.
RequestBody,
/// HTTP header.
HttpHeader,
/// HTTP cookie.
Cookie,
/// Uploaded file.
FileUpload,
/// URL path segment.
UrlPath,
/// Standard input.
Stdin,
/// Command line argument.
ProcessArgs,
/// Environment variable.
Environment,
/// File content.
FileRead,
/// HTTP response.
HttpResponse,
/// Socket data.
SocketRecv,
/// Database query result.
DatabaseResult,
/// Deserialized data.
Deserialized,
/// External API response.
ExternalApi,
/// WebSocket message.
WebSocketMessage,
/// User input that does not fall under a more specific variant.
GenericUserInput,
}
impl SourceKind {
pub fn to_taint_labels(&self) -> HashSet<TaintLabel> {
let mut labels = HashSet::new();
match self {
SourceKind::RequestParam | SourceKind::RequestBody => {
labels.insert(TaintLabel::UserInput);
}
SourceKind::HttpHeader => {
labels.insert(TaintLabel::HttpHeader);
}
SourceKind::Cookie => {
labels.insert(TaintLabel::Cookie);
}
SourceKind::FileUpload => {
labels.insert(TaintLabel::UserInput);
labels.insert(TaintLabel::FileContent);
}
SourceKind::UrlPath => {
labels.insert(TaintLabel::UrlData);
}
SourceKind::Stdin => {
labels.insert(TaintLabel::Stdin);
}
SourceKind::ProcessArgs => {
labels.insert(TaintLabel::ProcessArgs);
}
SourceKind::Environment => {
labels.insert(TaintLabel::Environment);
}
SourceKind::FileRead => {
labels.insert(TaintLabel::FileContent);
}
SourceKind::HttpResponse | SourceKind::ExternalApi => {
labels.insert(TaintLabel::NetworkData);
labels.insert(TaintLabel::ExternalApi);
}
SourceKind::SocketRecv => {
labels.insert(TaintLabel::NetworkData);
}
SourceKind::DatabaseResult => {
labels.insert(TaintLabel::DatabaseQuery);
}
SourceKind::Deserialized => {
labels.insert(TaintLabel::DeserializedData);
}
SourceKind::WebSocketMessage => {
labels.insert(TaintLabel::NetworkData);
labels.insert(TaintLabel::UserInput);
}
SourceKind::GenericUserInput => {
labels.insert(TaintLabel::UserInput);
}
}
labels
}
pub fn severity_weight(&self) -> u8 {
match self {
SourceKind::RequestParam | SourceKind::RequestBody => 10,
SourceKind::ProcessArgs | SourceKind::Stdin => 9,
SourceKind::Cookie | SourceKind::UrlPath => 8,
SourceKind::HttpHeader | SourceKind::WebSocketMessage => 7,
SourceKind::HttpResponse | SourceKind::ExternalApi => 7,
SourceKind::Deserialized => 8,
SourceKind::FileUpload | SourceKind::FileRead => 6,
SourceKind::SocketRecv => 6,
SourceKind::DatabaseResult => 5,
SourceKind::Environment => 4,
SourceKind::GenericUserInput => 5,
}
}
pub fn description(&self) -> &'static str {
match self {
SourceKind::RequestParam => "HTTP request parameter",
SourceKind::RequestBody => "HTTP request body",
SourceKind::HttpHeader => "HTTP header",
SourceKind::Cookie => "HTTP cookie",
SourceKind::FileUpload => "uploaded file",
SourceKind::UrlPath => "URL path segment",
SourceKind::Stdin => "standard input",
SourceKind::ProcessArgs => "command line argument",
SourceKind::Environment => "environment variable",
SourceKind::FileRead => "file content",
SourceKind::HttpResponse => "HTTP response",
SourceKind::SocketRecv => "socket data",
SourceKind::DatabaseResult => "database query result",
SourceKind::Deserialized => "deserialized data",
SourceKind::ExternalApi => "external API response",
SourceKind::WebSocketMessage => "WebSocket message",
SourceKind::GenericUserInput => "user input",
}
}
}
/// Formats a `SourceKind` as its human-readable description.
impl std::fmt::Display for SourceKind {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `pad` honors the caller's width/fill/alignment/precision flags
        // (e.g. `{:<24}`), which `write!(f, "{}", ..)` would silently discard.
        // Output for a plain `{}` is unchanged.
        f.pad(self.description())
    }
}
/// A single taint source found in scanned code.
///
/// Created with `DetectedSource::new` and refined through the chained
/// `with_*` / `in_handler_function` builder methods.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DetectedSource {
/// Category of the source; determines the labels returned by `taint_labels`.
pub kind: SourceKind,
/// Location associated with this source.
pub location: Location,
/// Text of the source expression (the tests use e.g. `request.args.get('id')`).
pub expression: String,
/// Name set through `with_assignment`; `None` by default.
/// Presumably the variable that receives the tainted value — confirm with callers.
pub assigned_to: Option<String>,
/// Confidence score; `new` starts at 1.0 and `with_confidence` clamps to `[0.0, 1.0]`.
pub confidence: f64,
/// Framework name set through `with_framework`; `None` by default.
pub framework: Option<String>,
/// Free-form context text set through `with_context`; `None` by default.
pub context: Option<String>,
/// `true` once `in_handler_function` has been called; `false` by default.
pub in_handler: bool,
/// Function name recorded by `in_handler_function`; `None` by default.
pub enclosing_function: Option<String>,
}
impl DetectedSource {
    /// Creates a source of the given `kind` at `location` for `expression`.
    ///
    /// Defaults: confidence `1.0`, not inside a handler, and every optional
    /// field set to `None`.
    pub fn new(kind: SourceKind, location: Location, expression: impl Into<String>) -> Self {
        Self {
            kind,
            location,
            expression: expression.into(),
            assigned_to: None,
            confidence: 1.0,
            framework: None,
            context: None,
            in_handler: false,
            enclosing_function: None,
        }
    }

    /// Records `var` in `assigned_to`.
    pub fn with_assignment(mut self, var: impl Into<String>) -> Self {
        self.assigned_to = Some(var.into());
        self
    }

    /// Sets the confidence, forcing it into `[0.0, 1.0]`.
    ///
    /// Out-of-range values are clamped to the nearest bound. NaN is stored as
    /// `0.0`: `f64::clamp` would pass NaN through unchanged, which would break
    /// the range guarantee and every later comparison against the score.
    pub fn with_confidence(mut self, confidence: f64) -> Self {
        self.confidence = if confidence.is_nan() {
            0.0
        } else {
            confidence.clamp(0.0, 1.0)
        };
        self
    }

    /// Records the framework name.
    pub fn with_framework(mut self, framework: impl Into<String>) -> Self {
        self.framework = Some(framework.into());
        self
    }

    /// Records free-form context text.
    pub fn with_context(mut self, context: impl Into<String>) -> Self {
        self.context = Some(context.into());
        self
    }

    /// Marks the source as being inside a handler and records `func_name` as
    /// the enclosing function.
    pub fn in_handler_function(mut self, func_name: impl Into<String>) -> Self {
        self.in_handler = true;
        self.enclosing_function = Some(func_name.into());
        self
    }

    /// Returns the taint labels for this source's `kind`.
    pub fn taint_labels(&self) -> HashSet<TaintLabel> {
        self.kind.to_taint_labels()
    }
}
/// Static description of a code pattern that is treated as a taint source.
///
/// Built through the `const` constructors `method_call`, `property_access`,
/// and `function_call`.
#[derive(Debug, Clone)]
pub struct SourcePattern {
/// Identifier of the pattern (the tests use e.g. `flask_request_args`).
pub name: &'static str,
/// Kind of source this pattern represents.
pub kind: SourceKind,
/// Receiver object name; `None` for patterns built with `function_call`.
pub object: Option<&'static str>,
/// Method, property, or function name, depending on the constructor used.
pub method: &'static str,
/// `true` only for patterns built with `property_access`.
pub is_property: bool,
/// Confidence for the pattern: 0.95 from `method_call` and
/// `property_access`, caller-supplied from `function_call`.
pub confidence: f64,
/// Framework name, if any; always `None` from `function_call`.
pub framework: Option<&'static str>,
}
impl SourcePattern {
pub const fn method_call(
name: &'static str,
kind: SourceKind,
object: &'static str,
method: &'static str,
framework: Option<&'static str>,
) -> Self {
Self {
name,
kind,
object: Some(object),
method,
is_property: false,
confidence: 0.95,
framework,
}
}
pub const fn property_access(
name: &'static str,
kind: SourceKind,
object: &'static str,
property: &'static str,
framework: Option<&'static str>,
) -> Self {
Self {
name,
kind,
object: Some(object),
method: property,
is_property: true,
confidence: 0.95,
framework,
}
}
pub const fn function_call(
name: &'static str,
kind: SourceKind,
function: &'static str,
confidence: f64,
) -> Self {
Self {
name,
kind,
object: None,
method: function,
is_property: false,
confidence,
framework: None,
}
}
}
/// Description of a handler function found during a scan.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HandlerInfo {
/// Handler function name (the tests use e.g. `get_user`).
pub name: String,
/// First line of the handler. NOTE(review): presumably 1-based — confirm with the detectors.
pub start_line: usize,
/// Last line of the handler. NOTE(review): inclusive vs. exclusive is not shown here — confirm.
pub end_line: usize,
/// Route associated with the handler, if any (the tests use e.g. `/users/<int:id>`).
pub route: Option<String>,
/// Method names for the handler (the tests use `GET`); presumably HTTP methods.
pub methods: Vec<String>,
/// Framework name (the tests use e.g. `flask`).
pub framework: String,
/// Handler parameters that are considered tainted.
pub tainted_params: Vec<TaintedParameter>,
}
/// A handler parameter that is considered tainted.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaintedParameter {
/// Parameter name.
pub name: String,
/// Kind of source the parameter represents.
pub kind: SourceKind,
/// Index of the parameter. NOTE(review): presumably its position in the
/// handler's parameter list (the tests use 0) — confirm with the detectors.
pub index: usize,
/// Annotation attached to the parameter, if any.
pub annotation: Option<String>,
}
/// Accumulated results of scanning one file for taint sources.
///
/// `SourceScanResult::new` starts with empty `sources`, `handlers`, and
/// `errors`; the `add_*` methods append to them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SourceScanResult {
/// File identifier as passed to `new` (the tests use e.g. `test.py`).
pub file: String,
/// Sources recorded through `add_source`, in insertion order.
pub sources: Vec<DetectedSource>,
/// Handlers recorded through `add_handler`, in insertion order.
pub handlers: Vec<HandlerInfo>,
/// Language name as passed to `new` (the tests use e.g. `python`).
pub language: String,
/// Error messages recorded through `add_error`, in insertion order.
pub errors: Vec<String>,
}
impl SourceScanResult {
pub fn new(file: impl Into<String>, language: impl Into<String>) -> Self {
Self {
file: file.into(),
sources: Vec::new(),
handlers: Vec::new(),
language: language.into(),
errors: Vec::new(),
}
}
pub fn add_source(&mut self, source: DetectedSource) {
self.sources.push(source);
}
pub fn add_handler(&mut self, handler: HandlerInfo) {
self.handlers.push(handler);
}
pub fn add_error(&mut self, error: impl Into<String>) {
self.errors.push(error.into());
}
pub fn has_sources(&self) -> bool {
!self.sources.is_empty()
}
pub fn sources_by_kind(&self, kind: SourceKind) -> Vec<&DetectedSource> {
self.sources.iter().filter(|s| s.kind == kind).collect()
}
pub fn handler_sources(&self) -> Vec<&DetectedSource> {
self.sources.iter().filter(|s| s.in_handler).collect()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Each kind maps to the expected taint labels, with no extras.
    #[test]
    fn test_source_kind_to_labels() {
        let labels = SourceKind::RequestParam.to_taint_labels();
        assert!(labels.contains(&TaintLabel::UserInput));
        assert_eq!(labels.len(), 1);

        let labels = SourceKind::Cookie.to_taint_labels();
        assert!(labels.contains(&TaintLabel::Cookie));
        assert_eq!(labels.len(), 1);

        let labels = SourceKind::FileUpload.to_taint_labels();
        assert!(labels.contains(&TaintLabel::UserInput));
        assert!(labels.contains(&TaintLabel::FileContent));
        assert_eq!(labels.len(), 2);
    }

    /// Relative ordering of the severity weights.
    #[test]
    fn test_source_kind_severity() {
        assert!(
            SourceKind::RequestParam.severity_weight() > SourceKind::Environment.severity_weight()
        );
        assert!(
            SourceKind::ProcessArgs.severity_weight()
                > SourceKind::DatabaseResult.severity_weight()
        );
    }

    /// `Display` output is exactly the description text.
    #[test]
    fn test_source_kind_display() {
        assert_eq!(SourceKind::Cookie.to_string(), "HTTP cookie");
        assert_eq!(
            format!("{}", SourceKind::Stdin),
            SourceKind::Stdin.description()
        );
    }

    /// Every builder method stores its argument in the matching field.
    #[test]
    fn test_detected_source_builder() {
        let loc = Location::new("test.py", 10, 5);
        let source = DetectedSource::new(SourceKind::RequestParam, loc, "request.args.get('id')")
            .with_assignment("user_id")
            .with_confidence(0.9)
            .with_framework("flask")
            .with_context("id parameter")
            .in_handler_function("get_user");

        assert_eq!(source.kind, SourceKind::RequestParam);
        assert_eq!(source.expression, "request.args.get('id')");
        assert_eq!(source.assigned_to, Some("user_id".to_string()));
        assert!((source.confidence - 0.9).abs() < f64::EPSILON);
        assert_eq!(source.framework, Some("flask".to_string()));
        // `with_context` was previously called but never verified.
        assert_eq!(source.context, Some("id parameter".to_string()));
        assert!(source.in_handler);
        assert_eq!(source.enclosing_function, Some("get_user".to_string()));
        assert!(source.taint_labels().contains(&TaintLabel::UserInput));
    }

    /// Defaults set by `new`, and clamping of out-of-range confidences.
    #[test]
    fn test_detected_source_defaults_and_clamping() {
        let loc = Location::new("test.py", 1, 1);
        let source = DetectedSource::new(SourceKind::Stdin, loc, "input()");

        assert!(source.assigned_to.is_none());
        assert!(source.framework.is_none());
        assert!(source.context.is_none());
        assert!(source.enclosing_function.is_none());
        assert!(!source.in_handler);
        assert!((source.confidence - 1.0).abs() < f64::EPSILON);

        let too_high = source.clone().with_confidence(1.5);
        assert!((too_high.confidence - 1.0).abs() < f64::EPSILON);

        let too_low = source.with_confidence(-0.5);
        assert!(too_low.confidence.abs() < f64::EPSILON);
    }

    /// Constructors set the receiver, member name, and property flag.
    #[test]
    fn test_source_pattern() {
        let pattern = SourcePattern::method_call(
            "flask_request_args",
            SourceKind::RequestParam,
            "request",
            "args",
            Some("flask"),
        );
        assert_eq!(pattern.object, Some("request"));
        assert_eq!(pattern.method, "args");
        assert!(!pattern.is_property);
        assert_eq!(pattern.framework, Some("flask"));

        let prop_pattern = SourcePattern::property_access(
            "django_request_GET",
            SourceKind::RequestParam,
            "request",
            "GET",
            Some("django"),
        );
        assert!(prop_pattern.is_property);
        assert_eq!(prop_pattern.method, "GET");

        let fn_pattern =
            SourcePattern::function_call("py_input", SourceKind::Stdin, "input", 0.8);
        assert_eq!(fn_pattern.object, None);
        assert_eq!(fn_pattern.method, "input");
        assert!(!fn_pattern.is_property);
        assert!((fn_pattern.confidence - 0.8).abs() < f64::EPSILON);
        assert_eq!(fn_pattern.framework, None);
    }

    /// Accumulation and filtering on a scan result.
    #[test]
    fn test_scan_result() {
        let mut result = SourceScanResult::new("test.py", "python");
        assert!(!result.has_sources());
        assert_eq!(result.file, "test.py");
        assert_eq!(result.language, "python");

        let loc = Location::new("test.py", 10, 5);
        let source = DetectedSource::new(SourceKind::RequestParam, loc, "request.args");
        result.add_source(source);
        assert!(result.has_sources());
        assert_eq!(result.sources_by_kind(SourceKind::RequestParam).len(), 1);
        assert!(result.sources_by_kind(SourceKind::Cookie).is_empty());
        assert!(result.handler_sources().is_empty());

        let handler_loc = Location::new("test.py", 12, 1);
        let in_handler = DetectedSource::new(SourceKind::Cookie, handler_loc, "request.cookies")
            .in_handler_function("view");
        result.add_source(in_handler);
        assert_eq!(result.sources.len(), 2);
        assert_eq!(result.sources_by_kind(SourceKind::Cookie).len(), 1);
        assert_eq!(result.handler_sources().len(), 1);
        assert_eq!(result.handler_sources()[0].kind, SourceKind::Cookie);

        result.add_error("parse failure");
        assert_eq!(result.errors, vec!["parse failure".to_string()]);
    }

    /// `HandlerInfo` stores its fields, and a result accepts it.
    #[test]
    fn test_handler_info() {
        let handler = HandlerInfo {
            name: "get_user".to_string(),
            start_line: 10,
            end_line: 20,
            route: Some("/users/<int:id>".to_string()),
            methods: vec!["GET".to_string()],
            framework: "flask".to_string(),
            tainted_params: vec![TaintedParameter {
                name: "id".to_string(),
                kind: SourceKind::UrlPath,
                index: 0,
                annotation: None,
            }],
        };
        assert_eq!(handler.name, "get_user");
        assert_eq!(handler.tainted_params.len(), 1);
        assert_eq!(handler.tainted_params[0].kind, SourceKind::UrlPath);

        let mut result = SourceScanResult::new("test.py", "python");
        result.add_handler(handler);
        assert_eq!(result.handlers.len(), 1);
        assert_eq!(result.handlers[0].name, "get_user");
    }

    /// The registry re-exports are usable from this module.
    #[test]
    fn test_registry_reexport() {
        let registry = get_python_sources();
        assert!(!registry.is_empty());
        let matches = registry.find_matches("request.args.get('id')");
        assert!(!matches.is_empty());
    }
}