// pyrograph 0.1.0
//
// GPU-accelerated taint analysis for supply chain malware detection.
// Documentation
use std::collections::{HashMap, HashSet};

use crate::ir::{EdgeKind, NodeId, NodeKind};

use super::visitor::PythonParser;

/// One level of the lexical scope stack kept by `PythonParser`.
///
/// Holds the graph node bound to each name in this scope, together with the
/// names this scope has registered as `global` or `nonlocal`.
/// `PythonParser::resolve` consults those registrations to redirect lookups
/// to outer scopes.
pub(super) struct Scope {
    /// Name → graph node currently bound to it in this scope.
    bindings: HashMap<String, NodeId>,
    /// Names registered through `Scope::add_global`.
    globals: HashSet<String>,
    /// Names registered through `Scope::add_nonlocal`.
    nonlocals: HashSet<String>,
}

impl Scope {
    /// Creates a scope with no bindings and no `global`/`nonlocal` names.
    pub(super) fn new() -> Self {
        let bindings = HashMap::new();
        let globals = HashSet::new();
        let nonlocals = HashSet::new();
        Self {
            bindings,
            globals,
            nonlocals,
        }
    }

    /// Binds `name` to `id` in this scope, replacing any earlier binding.
    pub(super) fn define(&mut self, name: String, id: NodeId) {
        self.bindings.insert(name, id);
    }

    /// Returns the node bound to `name` in this scope only; outer scopes are
    /// not searched here.
    pub(super) fn resolve(&self, name: &str) -> Option<NodeId> {
        let found = self.bindings.get(name);
        found.copied()
    }

    /// Registers `name` as global for this scope.
    pub(super) fn add_global(&mut self, name: String) {
        self.globals.insert(name);
    }

    /// Registers `name` as nonlocal for this scope.
    pub(super) fn add_nonlocal(&mut self, name: String) {
        self.nonlocals.insert(name);
    }

    /// Whether `name` was registered via [`Scope::add_global`].
    pub(super) fn is_global(&self, name: &str) -> bool {
        self.globals.contains(name)
    }

    /// Whether `name` was registered via [`Scope::add_nonlocal`].
    pub(super) fn is_nonlocal(&self, name: &str) -> bool {
        self.nonlocals.contains(name)
    }

    /// Read-only view of every name → node binding made in this scope.
    pub(super) fn bindings(&self) -> &HashMap<String, NodeId> {
        &self.bindings
    }
}

impl PythonParser {
    /// Returns the node for the member path `full`, adding an assignment
    /// flow edge from `base` into it.
    ///
    /// If `full` already resolves through the scope stack, that node is
    /// reused. Otherwise a new `NodeKind::Variable` node named `full` is
    /// created, its `alias` is set to `full`, and it is bound in the current
    /// scope.
    ///
    /// NOTE(review): the `base → member` edge is added again on every call for
    /// an existing member. Whether duplicates collapse depends on
    /// `graph.add_edge` — confirm.
    pub(super) fn member_node(&mut self, base: NodeId, full: &str) -> NodeId {
        let member = match self.resolve(full) {
            Some(existing) => existing,
            None => {
                let fresh = self.graph.add_node(NodeKind::Variable, full.to_string(), None);
                if let Some(node) = self.graph.node_mut(fresh) {
                    node.alias = Some(full.to_string());
                }
                self.cur().define(full.to_string(), fresh);
                fresh
            }
        };
        self.flow(base, member);
        member
    }

    /// Returns the node bound to `name`, creating a fresh
    /// `NodeKind::Variable` node when no binding is visible.
    ///
    /// A freshly created node is bound in the scope that [`Self::resolve`]
    /// will search for `name` on the next lookup:
    /// - the outermost scope, if the current scope registered `name` as global;
    /// - the nearest enclosing scope, if it registered `name` as nonlocal
    ///   (falls back to the current scope when there is no enclosing one);
    /// - otherwise the current scope.
    ///
    /// Binding a global/nonlocal name in the current scope would hide it from
    /// `resolve`. Every later lookup would then mint a new, disconnected node,
    /// breaking flows through that name.
    pub(super) fn var(&mut self, name: &str) -> NodeId {
        if let Some(id) = self.resolve(name) {
            return id;
        }
        let id = self.graph.add_node(NodeKind::Variable, name.to_string(), None);

        // `cur()` lazily creates the outermost scope, so the stack is
        // non-empty from here on.
        let (declared_global, declared_nonlocal) = {
            let scope = self.cur();
            (scope.is_global(name), scope.is_nonlocal(name))
        };
        let innermost = self.scopes.len().saturating_sub(1);
        let target = if declared_global {
            // `resolve` consults only `scopes.first()` for global names.
            0
        } else if declared_nonlocal {
            // `resolve` skips the innermost scope for nonlocal names.
            innermost.saturating_sub(1)
        } else {
            innermost
        };
        match self.scopes.get_mut(target) {
            Some(scope) => scope.define(name.to_string(), id),
            None => self.cur().define(name.to_string(), id),
        }
        id
    }

    /// Adds a new `NodeKind::Literal` node labelled `name` and returns its id.
    ///
    /// The node is not bound in any scope, so each call creates a distinct node.
    pub(super) fn literal(&mut self, name: &str) -> NodeId {
        let label = name.to_string();
        self.graph.add_node(NodeKind::Literal, label, None)
    }

    /// Returns the innermost scope.
    ///
    /// If the stack is empty, a fresh bottom scope is pushed first, so the
    /// stack is never empty after this returns.
    pub(super) fn cur(&mut self) -> &mut Scope {
        let stack_is_empty = self.scopes.is_empty();
        if stack_is_empty {
            self.scopes.push(Scope::new());
        }
        // Cannot fail: the branch above guarantees at least one scope.
        #[allow(clippy::expect_used)]
        self.scopes.last_mut().expect("scope stack is initialized")
    }

    /// Looks up the node bound to `name`, honouring the innermost scope's
    /// `global` / `nonlocal` registrations.
    ///
    /// - Registered global in the innermost scope: only the outermost scope
    ///   is searched.
    /// - Registered nonlocal: every scope except the innermost, searched
    ///   innermost-first.
    /// - Otherwise: every scope, innermost-first.
    ///
    /// Returns `None` when the scope stack is empty or nothing matches.
    pub(super) fn resolve(&self, name: &str) -> Option<NodeId> {
        let mut outward = self.scopes.iter().rev();
        let innermost = outward.next()?;
        if innermost.is_global(name) {
            return self.scopes.first()?.resolve(name);
        }
        let local = if innermost.is_nonlocal(name) {
            None
        } else {
            innermost.resolve(name)
        };
        local.or_else(|| outward.find_map(|scope| scope.resolve(name)))
    }

    /// Adds an `EdgeKind::Assignment` edge from `from` to `to`.
    pub(super) fn flow(&mut self, from: NodeId, to: NodeId) {
        let kind = EdgeKind::Assignment;
        self.graph.add_edge(from, to, kind);
    }

    /// Returns the canonical name of node `id`.
    ///
    /// This is the node's `alias` when set, otherwise its `name`. When `id`
    /// is not in the graph, the result is an empty string.
    pub(super) fn canonical_name(&self, id: NodeId) -> String {
        // One lookup and one allocation: borrow the alias (or fall back to
        // the name) and copy only the chosen string.
        self.graph
            .node(id)
            .map(|node| node.alias.as_deref().unwrap_or(node.name.as_str()).to_string())
            .unwrap_or_default()
    }

    /// Returns an alias for node `id`.
    ///
    /// - If the node's `alias` field is set, that value is returned.
    /// - Otherwise `Variable`, `Call` and `Import` nodes fall back to their
    ///   `name`, and `Literal` nodes yield `None`.
    /// - An id that is not in the graph also yields `None`.
    pub(super) fn alias_from(&self, id: NodeId) -> Option<String> {
        let node = self.graph.node(id)?;
        if let Some(alias) = &node.alias {
            return Some(alias.clone());
        }
        match node.kind {
            NodeKind::Literal => None,
            NodeKind::Variable | NodeKind::Call | NodeKind::Import => Some(node.name.clone()),
        }
    }
}