//! pyrograph 0.1.0
//!
//! GPU-accelerated taint analysis for supply chain malware detection.
pub mod shader;
use crate::cpu::{is_dangerous_combination, is_dataflow_edge};
use crate::ir::{TaintGraph, NodeId};
use crate::labels::TaintLabel;
use crate::lib_types::{TaintFinding, Severity};
use crate::error::{Error, Result};
use crate::gpu::shader::SHADER_SRC;
use bytemuck::{Pod, Zeroable};
use std::sync::Arc;

/// Maximum taint findings the GPU kernel can report per dispatch.
/// Sized for worst-case: a highly-connected graph where every source
/// reaches every sink. At internet scale this caps GPU memory usage
/// to ~16 KiB per dispatch (1024 × 16 bytes per GpuFinding).
/// Forwarded to the kernel via `GpuParams::max_findings`.
const DEFAULT_MAX_FINDINGS: u32 = 1024;

/// Maximum BFS depth for taint propagation on GPU. Limits kernel
/// runtime on pathological graphs (e.g., deeply nested callback chains).
/// 64 covers realistic code; obfuscated malware rarely exceeds 20.
/// Forwarded to the kernel via `GpuParams::max_depth`.
const DEFAULT_MAX_DEPTH: u32 = 64;

/// One source→sink taint flow reported by the GPU kernel.
///
/// 4 × u32 = 16 bytes; `#[repr(C)]` + `Pod`/`Zeroable` allow casting the
/// raw readback bytes directly with `bytemuck::cast_slice`. Presumably
/// mirrors a matching struct in `SHADER_SRC` — confirm if the WGSL changes.
#[repr(C)]
#[derive(Copy, Clone, Debug, Pod, Zeroable)]
struct GpuFinding {
    /// Node ID where the tainted flow starts (a source node).
    start_node: u32,
    /// Node ID of the sink the taint reached.
    sink_node: u32,
    /// Depth at which the sink was reached (not read on the CPU side).
    depth: u32,
    /// Index into the dispatched source list (not read on the CPU side).
    source_idx: u32,
}

/// Uniform parameter block uploaded to the compute shader.
///
/// Trailing `_pad` fields keep the struct a fixed 32 bytes for uniform
/// buffer layout. Presumably mirrors a matching params struct in
/// `SHADER_SRC` — confirm if the WGSL changes.
#[repr(C)]
#[derive(Copy, Clone, Debug, Pod, Zeroable)]
struct GpuParams {
    /// Number of entries in the source-index storage buffer.
    num_sources: u32,
    /// Total node count of the graph.
    num_nodes: u32,
    /// Capacity of the findings buffer (`DEFAULT_MAX_FINDINGS`).
    max_findings: u32,
    /// BFS depth cutoff (`DEFAULT_MAX_DEPTH`).
    max_depth: u32,
    /// u32 words per per-source visited bitset: `ceil(num_nodes / 32)`.
    words_per_source: u32,
    _pad0: u32,
    _pad1: u32,
    _pad2: u32,
}

/// Owns the wgpu device/queue and the compiled taint-analysis compute
/// pipeline. Created once via `GpuContext::new` and reused across
/// `analyze_gpu` calls (which only take `&self`).
pub struct GpuContext {
    device: Arc<wgpu::Device>,
    queue: Arc<wgpu::Queue>,
    pipeline: wgpu::ComputePipeline,
}

impl GpuContext {
    /// Initializes a wgpu adapter/device/queue and compiles the taint
    /// propagation compute pipeline from the embedded WGSL source.
    ///
    /// # Errors
    /// Returns [`Error::Gpu`] when no adapter is available, or propagates
    /// the device-request error.
    pub async fn new() -> Result<Self> {
        let instance = wgpu::Instance::default();
        let adapter = instance
            .request_adapter(&wgpu::RequestAdapterOptions::default())
            .await
            .ok_or_else(|| Error::Gpu("Failed to find wgpu adapter".into()))?;

        let (device, queue) = adapter
            .request_device(&wgpu::DeviceDescriptor::default(), None)
            .await?;

        let shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
            label: Some("Taint Shader"),
            source: wgpu::ShaderSource::Wgsl(SHADER_SRC.into()),
        });

        // `layout: None` lets wgpu infer the bind group layout from the
        // shader; `analyze_gpu` retrieves it via `get_bind_group_layout(0)`.
        let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
            label: Some("Taint Pipeline"),
            layout: None,
            module: &shader,
            entry_point: "main",
        });

        Ok(Self {
            device: Arc::new(device),
            queue: Arc::new(queue),
            pipeline,
        })
    }

    /// Runs taint propagation for `graph` on the GPU and converts the raw
    /// kernel output into scored [`TaintFinding`]s.
    ///
    /// Pipeline: upload the CSR graph and source list, dispatch one thread
    /// per source (workgroup size 64), read back up to
    /// [`DEFAULT_MAX_FINDINGS`] (source, sink) pairs, then filter and score
    /// them on the CPU.
    ///
    /// # Errors
    /// Returns [`Error::Gpu`] on buffer-mapping failures and
    /// [`Error::Analysis`] when the kernel reports node IDs or labels
    /// inconsistent with `graph`.
    pub async fn analyze_gpu(&self, graph: &TaintGraph) -> Result<Vec<TaintFinding>> {
        let source_indices: Vec<u32> = graph.sources().map(|n| n.id).collect();
        if source_indices.is_empty() {
            // No taint sources: nothing can flow. Bail out before building
            // the CSR buffers so the empty case costs nothing.
            return Ok(Vec::new());
        }
        let (node_buf_data, edge_offset_data, edge_target_data) = graph.to_gpu_buffers();

        let max_findings = DEFAULT_MAX_FINDINGS;
        let num_nodes = graph.node_count() as u32;
        // One bit per node in each source's visited set, packed into u32 words.
        let words_per_source = num_nodes.div_ceil(32);
        let params = GpuParams {
            num_sources: source_indices.len() as u32,
            num_nodes,
            max_findings,
            max_depth: DEFAULT_MAX_DEPTH,
            words_per_source,
            _pad0: 0,
            _pad1: 0,
            _pad2: 0,
        };

        let node_buffer = self.create_buffer(&node_buf_data, wgpu::BufferUsages::STORAGE);
        let offset_buffer = self.create_buffer(&edge_offset_data, wgpu::BufferUsages::STORAGE);
        let target_buffer = self.create_buffer(&edge_target_data, wgpu::BufferUsages::STORAGE);
        let source_buffer = self.create_buffer(&source_indices, wgpu::BufferUsages::STORAGE);

        let findings_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("Findings Buffer"),
            size: (max_findings as usize * std::mem::size_of::<GpuFinding>()) as u64,
            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
            mapped_at_creation: false,
        });

        let findings_count_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("Findings Count Buffer"),
            size: 4,
            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });

        let params_buffer = self.create_buffer(&[params], wgpu::BufferUsages::UNIFORM);

        // Per-source visited bitsets. Computed in u64 so very large
        // graphs (sources × words × 4 bytes > u32::MAX) cannot silently
        // wrap; clamped to 16 because wgpu storage buffers must not be
        // 0-sized.
        let visited_size =
            ((source_indices.len() as u64) * u64::from(words_per_source) * 4).max(16);
        let visited_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("Visited Buffer"),
            size: visited_size,
            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });

        let bind_group_layout = self.pipeline.get_bind_group_layout(0);
        let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
            label: None,
            layout: &bind_group_layout,
            entries: &[
                wgpu::BindGroupEntry { binding: 0, resource: node_buffer.as_entire_binding() },
                wgpu::BindGroupEntry { binding: 1, resource: offset_buffer.as_entire_binding() },
                wgpu::BindGroupEntry { binding: 2, resource: target_buffer.as_entire_binding() },
                wgpu::BindGroupEntry { binding: 3, resource: source_buffer.as_entire_binding() },
                wgpu::BindGroupEntry { binding: 4, resource: findings_buffer.as_entire_binding() },
                wgpu::BindGroupEntry { binding: 5, resource: findings_count_buffer.as_entire_binding() },
                wgpu::BindGroupEntry { binding: 6, resource: params_buffer.as_entire_binding() },
                wgpu::BindGroupEntry { binding: 7, resource: visited_buffer.as_entire_binding() },
            ],
        });

        // Record the visited-buffer zero-init, the dispatch, and the
        // readback copies in ONE encoder/submission: commands execute in
        // recorded order, so this is equivalent to (and cheaper than)
        // submitting a separate init encoder first.
        let mut encoder = self.device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
        encoder.clear_buffer(&visited_buffer, 0, Some(visited_size));
        {
            let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { label: None, timestamp_writes: None });
            cpass.set_pipeline(&self.pipeline);
            cpass.set_bind_group(0, &bind_group, &[]);
            // One thread per source; WGSL workgroup size is 64.
            let workgroup_count = (source_indices.len() as u32).div_ceil(64);
            cpass.dispatch_workgroups(workgroup_count, 1, 1);
        }

        let staging_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("Staging Buffer"),
            size: findings_buffer.size(),
            usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });

        let staging_count_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("Staging Count Buffer"),
            size: 4,
            usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });

        encoder.copy_buffer_to_buffer(&findings_buffer, 0, &staging_buffer, 0, findings_buffer.size());
        encoder.copy_buffer_to_buffer(&findings_count_buffer, 0, &staging_count_buffer, 0, 4);

        self.queue.submit(Some(encoder.finish()));

        // Map both staging buffers; `poll(Wait)` blocks until the callbacks
        // have fired, so the channel receives are non-blocking afterwards.
        let (tx, rx) = std::sync::mpsc::channel();
        let staging_slice = staging_buffer.slice(..);
        staging_slice.map_async(wgpu::MapMode::Read, move |v| { let _ = tx.send(v); });

        let (tx_c, rx_c) = std::sync::mpsc::channel();
        let staging_count_slice = staging_count_buffer.slice(..);
        staging_count_slice.map_async(wgpu::MapMode::Read, move |v| { let _ = tx_c.send(v); });

        self.device.poll(wgpu::Maintain::Wait);

        rx.recv().map_err(|_| Error::Gpu("GPU buffer map channel closed".into()))?
            .map_err(|e| Error::Gpu(format!("Map failed: {e:?}")))?;
        rx_c.recv().map_err(|_| Error::Gpu("GPU count buffer map channel closed".into()))?
            .map_err(|e| Error::Gpu(format!("Count map failed: {e:?}")))?;

        let count_data = staging_count_slice.get_mapped_range();
        let count = u32::from_ne_bytes(count_data[..4].try_into().map_err(|_| Error::Gpu("count bytes conversion failed".into()))?);
        // The kernel's counter may overshoot the cap; clamp so we never
        // index past the findings buffer.
        let count = count.min(max_findings);

        let data = staging_slice.get_mapped_range();
        let gpu_findings: &[GpuFinding] = bytemuck::cast_slice(&data);

        // Hoisted out of the loop: `label_set()` returns Option<&_> (Copy),
        // so one lookup serves every finding.
        let label_set = graph.label_set();
        let has_label_set = label_set.is_some();

        let mut results = Vec::new();
        for &f in gpu_findings.iter().take(count as usize) {
            let source_node = graph.node(f.start_node).ok_or_else(|| Error::Analysis("Invalid source ID".into()))?;
            let sink_node = graph.node(f.sink_node).ok_or_else(|| Error::Analysis("Invalid sink ID".into()))?;

            let source_enum = match source_node.label {
                Some(TaintLabel::Source(s)) => s,
                Some(TaintLabel::Both(s, _)) => s,
                Some(TaintLabel::Sanitizer(_)) => {
                    return Err(Error::Analysis("Source node unexpectedly labeled sanitizer".into()))
                }
                _ => return Err(Error::Analysis("Source node has no source label".into())),
            };

            let sink_enum = match sink_node.label {
                Some(TaintLabel::Sink(s)) => s,
                Some(TaintLabel::Both(_, s)) => s,
                Some(TaintLabel::Sanitizer(_)) => {
                    return Err(Error::Analysis("Sink node unexpectedly labeled sanitizer".into()))
                }
                _ => return Err(Error::Analysis("Sink node has no sink label".into())),
            };

            // Taint coloring: filter out benign source→sink combinations.
            let source_category = label_set
                .and_then(|labels| labels.sources.get(source_enum))
                .map(|s| s.category.as_str())
                .unwrap_or("unknown");
            let sink_category = label_set
                .and_then(|labels| labels.sinks.get(sink_enum))
                .map(|s| s.category.as_str())
                .unwrap_or("unknown");

            // Without a label set we cannot classify, so report everything.
            let dangerous = !has_label_set
                || is_dangerous_combination(source_category, sink_category)
                || matches!(sink_category, "exec" | "sql" | "network" | "xss" | "file");
            if !dangerous {
                continue;
            }

            // Known-dangerous pairings get category-based severity; flows
            // kept only by the sink-category allowlist default to Medium.
            let severity = if !has_label_set || is_dangerous_combination(source_category, sink_category) {
                severity_for_sink(graph, sink_enum)
            } else {
                Severity::Medium
            };

            // The kernel reports only endpoints; rebuild a witness path on CPU.
            let path = self.reconstruct_path_cpu(graph, f.start_node, f.sink_node);

            results.push(TaintFinding {
                source: source_enum,
                sink: sink_enum,
                path,
                severity,
            });
        }

        Ok(results)
    }

    /// Builds an init buffer; empty slices use 16 zero bytes so wgpu never sees 0-byte STORAGE init
    /// (e.g. CSR `edge_target` with no edges).
    fn create_buffer<T: bytemuck::Pod>(&self, data: &[T], usage: wgpu::BufferUsages) -> wgpu::Buffer {
        use wgpu::util::DeviceExt;
        let contents: &[u8] = bytemuck::cast_slice(data);
        self.device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
            label: None,
            contents: if contents.is_empty() { &[0u8; 16] } else { contents },
            usage,
        })
    }

    /// CPU-side BFS from `start` to `end`, restricted to dataflow edges and
    /// never expanding through sanitizer-labeled nodes, used to reconstruct
    /// a witness path for a GPU-reported finding.
    ///
    /// Falls back to `[start, end]` when no such CPU path exists (the GPU
    /// and CPU edge filters can disagree); the finding is still reported.
    fn reconstruct_path_cpu(&self, graph: &TaintGraph, start: NodeId, end: NodeId) -> Vec<NodeId> {
        use std::collections::hash_map::Entry;
        use std::collections::{HashMap, VecDeque};

        let mut queue = VecDeque::from([start]);
        let mut parent: HashMap<NodeId, NodeId> = HashMap::new();

        while let Some(current) = queue.pop_front() {
            if current == end {
                // Walk parent links back to `start`, then reverse into
                // source→sink order.
                let mut path = vec![end];
                let mut node = end;
                while node != start {
                    if let Some(&p) = parent.get(&node) {
                        node = p;
                        path.push(node);
                    } else {
                        break;
                    }
                }
                path.reverse();
                return path;
            }

            // Do not expand through sanitizers or non-dataflow edges.
            if graph
                .node(current)
                .and_then(|n| n.label)
                .is_some_and(|l| l.is_sanitizer())
            {
                continue;
            }

            for (neighbor, edge_kind) in graph.edges_from(current) {
                if !is_dataflow_edge(&edge_kind) {
                    continue;
                }
                // Entry API: one hash lookup instead of contains_key + insert.
                if neighbor != start {
                    if let Entry::Vacant(slot) = parent.entry(neighbor) {
                        slot.insert(current);
                        queue.push_back(neighbor);
                    }
                }
            }
        }
        vec![start, end]
    }
}

/// Maps a sink's configured category string to a report severity.
///
/// When the graph has no label set, or `sink_idx` is out of range, the
/// category defaults to `"other"`, which scores `Severity::Low`.
fn severity_for_sink(graph: &TaintGraph, sink_idx: usize) -> Severity {
    let category = match graph.label_set().and_then(|ls| ls.sinks.get(sink_idx)) {
        Some(sink) => sink.category.as_str(),
        None => "other",
    };

    match category {
        "exec" | "sql" => Severity::Critical,
        "network" | "xss" => Severity::High,
        "file" => Severity::Medium,
        _ => Severity::Low,
    }
}