//! pyrograph 0.1.0
//!
//! GPU-accelerated taint analysis for supply chain malware detection.
//! This module loads taint source/sink/sanitizer rule definitions from
//! TOML files and labels graph nodes against them.
use std::path::Path;
use serde::Deserialize;
use crate::error::Result;

/// Definition of a taint source loaded from TOML (a `[[source]]` table).
#[derive(Debug, Clone, Deserialize)]
pub struct SourceDef {
    /// Stable rule identifier, e.g. "process-env".
    pub id: String,
    /// Pattern matched against node names at word boundaries.
    pub pattern: String,
    /// Free-form category tag, e.g. "credential" or "network".
    pub category: String,
}

/// Definition of a taint sink loaded from TOML (a `[[sink]]` table).
#[derive(Debug, Clone, Deserialize)]
pub struct SinkDef {
    /// Stable rule identifier, e.g. "fetch".
    pub id: String,
    /// Pattern matched against node names at word boundaries.
    pub pattern: String,
    /// Free-form category tag, e.g. "network" or "exec".
    pub category: String,
}

/// Definition of a taint sanitizer loaded from TOML (a `[[sanitizer]]`
/// table). Sanitizers have no category field.
#[derive(Debug, Clone, Deserialize)]
pub struct SanitizerDef {
    /// Stable rule identifier, e.g. "JSON.parse".
    pub id: String,
    /// Pattern matched against node names at word boundaries.
    pub pattern: String,
}

/// A collection of sources and sinks loaded from one or more TOML rule files.
///
/// Positions in these vectors are the indices carried by [`TaintLabel`],
/// so element order is significant (files are loaded in sorted filename
/// order — see `load_labels`).
#[derive(Debug, Clone)]
pub struct LabelSet {
    /// Source rules, in file-then-declaration order.
    pub sources: Vec<SourceDef>,
    /// Sink rules, in file-then-declaration order.
    pub sinks: Vec<SinkDef>,
    /// Sanitizer rules, in file-then-declaration order.
    pub sanitizers: Vec<SanitizerDef>,
}

/// Taint classification for a node, using indices into a [`LabelSet`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TaintLabel {
    /// Node matched `LabelSet::sources[i]`.
    Source(usize),
    /// Node matched `LabelSet::sinks[i]`.
    Sink(usize),
    /// Node matched both a source and a sink rule: `(source_idx, sink_idx)`.
    Both(usize, usize),
    /// Node matched `LabelSet::sanitizers[i]` and no source/sink rule
    /// (source/sink take precedence in `label_node`).
    Sanitizer(usize),
}

impl TaintLabel {
    pub fn is_source(&self) -> bool {
        matches!(self, TaintLabel::Source(_) | TaintLabel::Both(_, _))
    }

    pub fn is_sink(&self) -> bool {
        matches!(self, TaintLabel::Sink(_) | TaintLabel::Both(_, _))
    }

    pub fn is_sanitizer(&self) -> bool {
        matches!(self, TaintLabel::Sanitizer(_))
    }
}

/// Deserialization target for one TOML rule file. Each field corresponds
/// to a `[[source]]` / `[[sink]]` / `[[sanitizer]]` array of tables and is
/// `None` when the file omits that section.
#[derive(Debug, Deserialize)]
struct RuleFile {
    source: Option<Vec<SourceDef>>,
    sink: Option<Vec<SinkDef>>,
    sanitizer: Option<Vec<SanitizerDef>>,
}

/// Load all `.toml` files from `dir` and merge their sources and sinks.
///
/// Files are processed in sorted filename order so that the rule indices
/// recorded in [`TaintLabel`] are stable across runs.
/// If `dir` does not exist, an empty [`LabelSet`] is returned.
pub fn load_labels(dir: &Path) -> Result<LabelSet> {
    let mut labels = LabelSet {
        sources: Vec::new(),
        sinks: Vec::new(),
        sanitizers: Vec::new(),
    };

    if !dir.exists() {
        return Ok(labels);
    }

    // Collect the matching entries first so they can be visited in a
    // deterministic (sorted-by-filename) order. Unreadable directory
    // entries are skipped.
    let mut rule_files: Vec<_> = std::fs::read_dir(dir)?
        .filter_map(|entry| entry.ok())
        .filter(|entry| {
            entry.path().extension().and_then(|s| s.to_str()) == Some("toml")
        })
        .collect();
    rule_files.sort_by_key(|entry| entry.file_name());

    for entry in rule_files {
        let path = entry.path();
        let text = std::fs::read_to_string(&path)?;
        let parsed: RuleFile = toml::from_str(&text)
            .map_err(|e| crate::error::Error::Analysis(format!("TOML parse error in {:?}: {}", path, e)))?;
        // Missing sections deserialize as None; treat them as empty lists.
        labels.sources.extend(parsed.source.unwrap_or_default());
        labels.sinks.extend(parsed.sink.unwrap_or_default());
        labels.sanitizers.extend(parsed.sanitizer.unwrap_or_default());
    }

    Ok(labels)
}

/// Normalize a node name so that `require(child_process).exec` becomes
/// `child_process.exec`, allowing TOML patterns to match across `require()`
/// wrappers.
///
/// Quote characters surrounding the module name are stripped as well, so
/// `require('child_process').exec` (as produced when the call site used a
/// string literal) also normalizes to `child_process.exec`.
///
/// Names that do not start with `require(` — or that lack a `).` after it —
/// are returned unchanged.
fn normalize_for_label(name: &str) -> String {
    if let Some(inner) = name.strip_prefix("require(") {
        if let Some(idx) = inner.find(").") {
            // Strip surrounding string-literal quotes ('  "  `) so quoted
            // and unquoted module names normalize to the same form.
            let module = inner[..idx].trim_matches(|c| c == '\'' || c == '"' || c == '`');
            let suffix = &inner[idx + 2..];
            return format!("{}.{}", module, suffix);
        }
    }
    name.to_string()
}

/// True for bytes that form "words": ASCII alphanumerics and `_`.
///
/// Used by `matches_at_word_boundary` to decide whether a match sits at a
/// word boundary; any non-word byte (`.`, `(`, `)`, `[`, `]`, `/`, space,
/// quotes, ...) acts as a boundary. (The previous doc comment here
/// described `matches_at_word_boundary` but was attached to this function.)
fn is_word_char(b: u8) -> bool {
    matches!(b, b'a'..=b'z' | b'A'..=b'Z' | b'0'..=b'9' | b'_')
}

/// Check whether `pattern` occurs in `name` at a word boundary.
///
/// A word boundary is the start/end of the string or a non-word byte
/// (`.`, `(`, `)`, `[`, `]`, space, etc). This prevents "eval" from
/// matching "medieval" or "fetch" from matching "prefetchPolicy".
///
/// Patterns whose final byte is itself a non-word char (like "https://",
/// "/bin/sh", ".ssh/id_rsa") require a boundary only BEFORE the match:
/// they are path/URL patterns where the content after the pattern is part
/// of the hit (e.g. "https://evil.com"). The empty pattern never matches.
fn matches_at_word_boundary(name: &str, pattern: &str) -> bool {
    let last = match pattern.as_bytes().last() {
        Some(&b) => b,
        None => return false,
    };
    let need_suffix_boundary = is_word_char(last);
    let bytes = name.as_bytes();

    name.match_indices(pattern).any(|(start, _)| {
        let end = start + pattern.len();
        let boundary_before = start == 0 || !is_word_char(bytes[start - 1]);
        let boundary_after =
            !need_suffix_boundary || end >= bytes.len() || !is_word_char(bytes[end]);
        boundary_before && boundary_after
    })
}

/// Check whether `node_name` matches any source or sink pattern in `label_set`.
///
/// Uses word-boundary matching to prevent false positives from substring
/// collisions (e.g., "eval" must not match "medieval"). Patterns are
/// checked against the raw name first, then against a normalized form with
/// `require(...)` wrappers stripped; the first name form that matches any
/// rule determines the rule index. A node matching both a source and a
/// sink rule is reported as [`TaintLabel::Both`]; sanitizer labels apply
/// only when the node is neither.
pub fn label_node(label_set: &LabelSet, node_name: &str) -> Option<TaintLabel> {
    let normalized = normalize_for_label(node_name);
    let candidates = [node_name, normalized.as_str()];

    let mut source_hit = None;
    let mut sink_hit = None;
    let mut sanitizer_hit = None;
    // Raw name takes priority: only consult the normalized form for a rule
    // list that the raw name did not match.
    for name in candidates {
        if source_hit.is_none() {
            source_hit = label_set
                .sources
                .iter()
                .position(|rule| matches_at_word_boundary(name, &rule.pattern));
        }
        if sink_hit.is_none() {
            sink_hit = label_set
                .sinks
                .iter()
                .position(|rule| matches_at_word_boundary(name, &rule.pattern));
        }
        if sanitizer_hit.is_none() {
            sanitizer_hit = label_set
                .sanitizers
                .iter()
                .position(|rule| matches_at_word_boundary(name, &rule.pattern));
        }
    }

    if let (Some(s), Some(k)) = (source_hit, sink_hit) {
        Some(TaintLabel::Both(s, k))
    } else if let Some(s) = source_hit {
        Some(TaintLabel::Source(s))
    } else if let Some(k) = sink_hit {
        Some(TaintLabel::Sink(k))
    } else {
        sanitizer_hit.map(TaintLabel::Sanitizer)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;

    // Writes one TOML rule file into a temp dir and checks that sources
    // and sinks land in the right LabelSet vectors (sanitizers omitted
    // from the file => empty vector).
    #[test]
    fn test_load_labels_from_toml() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("test.toml");
        let mut file = std::fs::File::create(&path).unwrap();
        write!(
            file,
            r#"[[source]]
id = "process-env"
pattern = "process.env"
category = "credential"

[[sink]]
id = "fetch"
pattern = "fetch"
category = "network"
"#
        )
        .unwrap();

        let label_set = load_labels(dir.path()).unwrap();
        assert_eq!(label_set.sources.len(), 1);
        assert_eq!(label_set.sinks.len(), 1);
        assert_eq!(label_set.sanitizers.len(), 0);
        assert_eq!(label_set.sources[0].id, "process-env");
        assert_eq!(label_set.sinks[0].id, "fetch");
    }

    // Covers each label kind plus the None case, and checks that the
    // returned indices point at the matching rule in the LabelSet.
    #[test]
    fn test_label_node_matches_patterns() {
        let label_set = LabelSet {
            sources: vec![
                SourceDef {
                    id: "src1".into(),
                    pattern: "process.env".into(),
                    category: "cred".into(),
                },
                SourceDef {
                    id: "src2".into(),
                    pattern: ".npmrc".into(),
                    category: "cred".into(),
                },
            ],
            sinks: vec![
                SinkDef {
                    id: "snk1".into(),
                    pattern: "fetch".into(),
                    category: "net".into(),
                },
                SinkDef {
                    id: "snk2".into(),
                    pattern: "eval".into(),
                    category: "exec".into(),
                },
            ],
            sanitizers: vec![SanitizerDef {
                id: "san1".into(),
                pattern: "JSON.parse".into(),
            }],
        };

        assert_eq!(
            label_node(&label_set, "process.env"),
            Some(TaintLabel::Source(0))
        );
        // Pattern may match mid-name as long as it sits at word boundaries
        // ('.' is a boundary char).
        assert_eq!(
            label_node(&label_set, "foo.process.env.bar"),
            Some(TaintLabel::Source(0))
        );
        assert_eq!(
            label_node(&label_set, ".npmrc"),
            Some(TaintLabel::Source(1))
        );
        assert_eq!(label_node(&label_set, "fetch"), Some(TaintLabel::Sink(0)));
        assert_eq!(label_node(&label_set, "eval"), Some(TaintLabel::Sink(1)));
        assert_eq!(
            label_node(&label_set, "JSON.parse"),
            Some(TaintLabel::Sanitizer(0))
        );
        assert_eq!(label_node(&label_set, "unknown"), None);
    }

    // A node matching both a source and a sink rule must be labeled Both,
    // carrying both rule indices.
    #[test]
    fn test_label_node_both() {
        let label_set = LabelSet {
            sources: vec![SourceDef {
                id: "src1".into(),
                pattern: "X".into(),
                category: "c".into(),
            }],
            sinks: vec![SinkDef {
                id: "snk1".into(),
                pattern: "X".into(),
                category: "c".into(),
            }],
            sanitizers: vec![],
        };
        assert_eq!(label_node(&label_set, "X"), Some(TaintLabel::Both(0, 0)));
    }

    // Word-boundary matching: patterns embedded inside longer identifiers
    // must not fire, but exact-word occurrences must.
    #[test]
    fn test_label_node_does_not_match_substrings() {
        let label_set = LabelSet {
            sources: vec![],
            sinks: vec![
                SinkDef {
                    id: "fetch".into(),
                    pattern: "fetch".into(),
                    category: "network".into(),
                },
                SinkDef {
                    id: "eval".into(),
                    pattern: "eval".into(),
                    category: "exec".into(),
                },
            ],
            sanitizers: vec![],
        };

        assert_eq!(label_node(&label_set, "prefetch"), None);
        assert_eq!(label_node(&label_set, "medieval"), None);
        // Verify "eval" does not match legitimate words containing it as a substring.
        assert_eq!(label_node(&label_set, "evaluate"), None);
        assert_eq!(label_node(&label_set, "evaluation"), None);
        // But the exact word must still match.
        assert_eq!(label_node(&label_set, "eval"), Some(TaintLabel::Sink(1)));
    }

    // Patterns ending in a non-word char ("https://", "/bin/sh") use
    // prefix-only boundary rules; word-char-terminated patterns ("curl")
    // still require a boundary on both sides.
    #[test]
    fn test_url_and_path_patterns_match_inside_strings() {
        let label_set = LabelSet {
            sources: vec![
                SourceDef { id: "https".into(), pattern: "https://".into(), category: "net".into() },
                SourceDef { id: "http".into(), pattern: "http://".into(), category: "net".into() },
                SourceDef { id: "ssh".into(), pattern: ".ssh/id_rsa".into(), category: "file".into() },
                SourceDef { id: "sh".into(), pattern: "/bin/sh".into(), category: "shell".into() },
                SourceDef { id: "curl".into(), pattern: "curl".into(), category: "shell".into() },
            ],
            sinks: vec![],
            sanitizers: vec![],
        };

        // URL patterns must match inside quoted strings (tree-sitter includes quotes)
        assert!(label_node(&label_set, "\"https://evil.com\"").is_some());
        assert!(label_node(&label_set, "\"http://evil.com/exfil\"").is_some());
        assert!(label_node(&label_set, "https://evil.com").is_some());

        // Path patterns must match inside strings
        assert!(label_node(&label_set, "\"/home/user/.ssh/id_rsa\"").is_some());
        assert!(label_node(&label_set, "/bin/sh").is_some());

        // Word-boundary patterns must NOT match substrings
        assert!(label_node(&label_set, "curling").is_none());
        assert!(label_node(&label_set, "procurl").is_none());
        assert!(label_node(&label_set, "run_curl").is_none()); // _ is a word char
        // But must match at word boundaries
        assert!(label_node(&label_set, "curl").is_some());
        assert!(label_node(&label_set, "run.curl").is_some()); // . is not a word char
    }

    // An existing directory containing no .toml files yields empty vectors.
    #[test]
    fn test_empty_dir_returns_empty_label_set() {
        let dir = tempfile::tempdir().unwrap();
        let label_set = load_labels(dir.path()).unwrap();
        assert!(label_set.sources.is_empty());
        assert!(label_set.sinks.is_empty());
        assert!(label_set.sanitizers.is_empty());
    }

    // A nonexistent directory is not an error: load_labels returns an
    // empty LabelSet rather than Err.
    #[test]
    fn test_missing_dir_returns_empty_label_set() {
        let label_set = load_labels(Path::new("/does/not/exist")).unwrap();
        assert!(label_set.sources.is_empty());
        assert!(label_set.sinks.is_empty());
        assert!(label_set.sanitizers.is_empty());
    }
}