pub mod shader;
use crate::cpu::{is_dangerous_combination, is_dataflow_edge};
use crate::ir::{TaintGraph, NodeId};
use crate::labels::TaintLabel;
use crate::lib_types::{TaintFinding, Severity};
use crate::error::{Error, Result};
use crate::gpu::shader::SHADER_SRC;
use bytemuck::{Pod, Zeroable};
use std::sync::Arc;
// Capacity of the GPU findings buffer; the count read back from the shader is
// clamped to this, so at most this many findings are decoded per dispatch.
const DEFAULT_MAX_FINDINGS: u32 = 1024;
// Traversal depth limit handed to the shader via `GpuParams::max_depth`.
const DEFAULT_MAX_DEPTH: u32 = 64;
/// One source→sink reachability hit as written by the compute shader.
///
/// `#[repr(C)]` plus `Pod`/`Zeroable` allow the findings buffer to be decoded
/// with `bytemuck::cast_slice` after readback; the field order and types must
/// stay in sync with the matching struct in `shader::SHADER_SRC`.
#[repr(C)]
#[derive(Copy, Clone, Debug, Pod, Zeroable)]
struct GpuFinding {
    // Graph node id where the tainted path starts (a source node).
    start_node: u32,
    // Graph node id of the sink that was reached.
    sink_node: u32,
    // Depth at which the sink was found (bounded by `GpuParams::max_depth`).
    depth: u32,
    // Index into the dispatched source list; not consumed on the CPU side here.
    source_idx: u32,
}
/// Uniform parameter block uploaded to the shader.
///
/// `#[repr(C)]` layout must mirror the WGSL uniform struct; the trailing pads
/// keep the size a multiple of 16 bytes as required for uniform buffers.
#[repr(C)]
#[derive(Copy, Clone, Debug, Pod, Zeroable)]
struct GpuParams {
    // Number of source nodes dispatched (one shader invocation each).
    num_sources: u32,
    // Total node count of the taint graph.
    num_nodes: u32,
    // Capacity of the findings buffer (shader must not write past this).
    max_findings: u32,
    // Traversal depth cutoff for the GPU walk.
    max_depth: u32,
    // u32 words per per-source visited bitset: ceil(num_nodes / 32).
    words_per_source: u32,
    _pad0: u32,
    _pad1: u32,
    _pad2: u32,
}
/// GPU execution context: the wgpu device/queue pair and the compiled
/// taint-propagation compute pipeline. Built once with [`GpuContext::new`]
/// and reused across [`GpuContext::analyze_gpu`] calls.
pub struct GpuContext {
    device: Arc<wgpu::Device>,
    queue: Arc<wgpu::Queue>,
    pipeline: wgpu::ComputePipeline,
}
impl GpuContext {
pub async fn new() -> Result<Self> {
let instance = wgpu::Instance::default();
let adapter = instance
.request_adapter(&wgpu::RequestAdapterOptions::default())
.await
.ok_or_else(|| Error::Gpu("Failed to find wgpu adapter".into()))?;
let (device, queue) = adapter
.request_device(&wgpu::DeviceDescriptor::default(), None)
.await?;
let shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
label: Some("Taint Shader"),
source: wgpu::ShaderSource::Wgsl(SHADER_SRC.into()),
});
let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
label: Some("Taint Pipeline"),
layout: None,
module: &shader,
entry_point: "main",
});
Ok(Self {
device: Arc::new(device),
queue: Arc::new(queue),
pipeline,
})
}
/// Run the taint-propagation shader over `graph` and decode the raw hits
/// into [`TaintFinding`]s.
///
/// Flow: serialize the graph to flat buffers, upload, zero the visited
/// bitsets, dispatch one compute pass, copy findings + counter back through
/// staging buffers, map them, then filter and translate each hit.
///
/// # Errors
/// [`Error::Gpu`] on buffer-mapping failures; [`Error::Analysis`] when the
/// shader reports node ids or labels inconsistent with `graph`.
pub async fn analyze_gpu(&self, graph: &TaintGraph) -> Result<Vec<TaintFinding>> {
    // Flat buffer representation (nodes, CSR-style edge offsets/targets)
    // produced by the graph itself.
    let (node_buf_data, edge_offset_data, edge_target_data) = graph.to_gpu_buffers();
    let source_indices: Vec<u32> = graph.sources().map(|n| n.id).collect();
    // No sources: nothing can be tainted, skip the GPU round-trip entirely.
    if source_indices.is_empty() {
        return Ok(Vec::new());
    }
    let max_findings = DEFAULT_MAX_FINDINGS;
    let num_nodes = graph.node_count() as u32;
    // One bit per node per source, packed into u32 words.
    let words_per_source = num_nodes.div_ceil(32);
    let params = GpuParams {
        num_sources: source_indices.len() as u32,
        num_nodes,
        max_findings,
        max_depth: DEFAULT_MAX_DEPTH,
        words_per_source,
        _pad0: 0,
        _pad1: 0,
        _pad2: 0,
    };
    // Read-only graph inputs for the shader.
    let node_buffer = self.create_buffer(&node_buf_data, wgpu::BufferUsages::STORAGE);
    let offset_buffer = self.create_buffer(&edge_offset_data, wgpu::BufferUsages::STORAGE);
    let target_buffer = self.create_buffer(&edge_target_data, wgpu::BufferUsages::STORAGE);
    let source_buffer = self.create_buffer(&source_indices, wgpu::BufferUsages::STORAGE);
    // Output buffer the shader appends `GpuFinding`s into; COPY_SRC so it
    // can be staged back to the CPU.
    let findings_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
        label: Some("Findings Buffer"),
        size: (max_findings as usize * std::mem::size_of::<GpuFinding>()) as u64,
        usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
        mapped_at_creation: false,
    });
    // Single u32 counter of findings written (COPY_DST allows host resets).
    let findings_count_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
        label: Some("Findings Count Buffer"),
        size: 4,
        usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC | wgpu::BufferUsages::COPY_DST,
        mapped_at_creation: false,
    });
    let params_buffer = self.create_buffer(&[params], wgpu::BufferUsages::UNIFORM);
    // Per-source visited bitsets; `.max(16)` guards against a zero-sized
    // buffer when the graph has no nodes.
    let visited_size = ((source_indices.len() as u32 * words_per_source * 4) as u64).max(16);
    let visited_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
        label: Some("Visited Buffer"),
        size: visited_size,
        usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
        mapped_at_creation: false,
    });
    // Zero the visited bitsets before the pass runs.
    let mut init_encoder = self.device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: Some("Init Visited") });
    init_encoder.clear_buffer(&visited_buffer, 0, Some(visited_size));
    self.queue.submit(Some(init_encoder.finish()));
    // Pipeline uses an auto layout, so the binding numbers below must match
    // the @binding declarations in SHADER_SRC.
    let bind_group_layout = self.pipeline.get_bind_group_layout(0);
    let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: None,
        layout: &bind_group_layout,
        entries: &[
            wgpu::BindGroupEntry { binding: 0, resource: node_buffer.as_entire_binding() },
            wgpu::BindGroupEntry { binding: 1, resource: offset_buffer.as_entire_binding() },
            wgpu::BindGroupEntry { binding: 2, resource: target_buffer.as_entire_binding() },
            wgpu::BindGroupEntry { binding: 3, resource: source_buffer.as_entire_binding() },
            wgpu::BindGroupEntry { binding: 4, resource: findings_buffer.as_entire_binding() },
            wgpu::BindGroupEntry { binding: 5, resource: findings_count_buffer.as_entire_binding() },
            wgpu::BindGroupEntry { binding: 6, resource: params_buffer.as_entire_binding() },
            wgpu::BindGroupEntry { binding: 7, resource: visited_buffer.as_entire_binding() },
        ],
    });
    let mut encoder = self.device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
    {
        // Scoped so the pass is ended (dropped) before the copy commands below.
        let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { label: None, timestamp_writes: None });
        cpass.set_pipeline(&self.pipeline);
        cpass.set_bind_group(0, &bind_group, &[]);
        // One thread per source; the divisor 64 presumably matches the
        // shader's @workgroup_size — TODO confirm against SHADER_SRC.
        let workgroup_count = (source_indices.len() as u32).div_ceil(64);
        cpass.dispatch_workgroups(workgroup_count, 1, 1);
    }
    // MAP_READ staging buffers: storage buffers cannot be mapped directly.
    let staging_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
        label: Some("Staging Buffer"),
        size: findings_buffer.size(),
        usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
        mapped_at_creation: false,
    });
    let staging_count_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
        label: Some("Staging Count Buffer"),
        size: 4,
        usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
        mapped_at_creation: false,
    });
    encoder.copy_buffer_to_buffer(&findings_buffer, 0, &staging_buffer, 0, findings_buffer.size());
    encoder.copy_buffer_to_buffer(&findings_count_buffer, 0, &staging_count_buffer, 0, 4);
    self.queue.submit(Some(encoder.finish()));
    // map_async callbacks fire during poll(); channels carry the results
    // back to this (blocking) context.
    let (tx, rx) = std::sync::mpsc::channel();
    let staging_slice = staging_buffer.slice(..);
    staging_slice.map_async(wgpu::MapMode::Read, move |v| { let _ = tx.send(v); });
    let (tx_c, rx_c) = std::sync::mpsc::channel();
    let staging_count_slice = staging_count_buffer.slice(..);
    staging_count_slice.map_async(wgpu::MapMode::Read, move |v| { let _ = tx_c.send(v); });
    // Block until the GPU finishes and both map callbacks have run.
    self.device.poll(wgpu::Maintain::Wait);
    rx.recv().map_err(|_| Error::Gpu("GPU buffer map channel closed".into()))?
        .map_err(|e| Error::Gpu(format!("Map failed: {e:?}")))?;
    rx_c.recv().map_err(|_| Error::Gpu("GPU count buffer map channel closed".into()))?
        .map_err(|e| Error::Gpu(format!("Count map failed: {e:?}")))?;
    let count_data = staging_count_slice.get_mapped_range();
    let count = u32::from_ne_bytes(count_data[..4].try_into().map_err(|_| Error::Gpu("count bytes conversion failed".into()))?);
    // The shader's atomic counter may exceed capacity; only `max_findings`
    // entries were actually stored, so clamp before decoding.
    let count = count.min(max_findings);
    let data = staging_slice.get_mapped_range();
    let gpu_findings: &[GpuFinding] = bytemuck::cast_slice(&data);
    let mut results = Vec::new();
    for &f in gpu_findings.iter().take(count as usize) {
        let source_node = graph.node(f.start_node).ok_or_else(|| Error::Analysis("Invalid source ID".into()))?;
        let sink_node = graph.node(f.sink_node).ok_or_else(|| Error::Analysis("Invalid sink ID".into()))?;
        // A reported start node must carry a source label (possibly `Both`).
        let source_enum = match source_node.label {
            Some(TaintLabel::Source(s)) => s,
            Some(TaintLabel::Both(s, _)) => s,
            Some(TaintLabel::Sanitizer(_)) => {
                return Err(Error::Analysis("Source node unexpectedly labeled sanitizer".into()))
            }
            _ => return Err(Error::Analysis("Source node has no source label".into())),
        };
        // Likewise the reported end node must carry a sink label.
        let sink_enum = match sink_node.label {
            Some(TaintLabel::Sink(s)) => s,
            Some(TaintLabel::Both(_, s)) => s,
            Some(TaintLabel::Sanitizer(_)) => {
                return Err(Error::Analysis("Sink node unexpectedly labeled sanitizer".into()))
            }
            _ => return Err(Error::Analysis("Sink node has no sink label".into())),
        };
        // Category strings come from the optional label set; "unknown" when
        // either the label set or the specific entry is missing.
        let source_category = graph
            .label_set()
            .and_then(|labels| labels.sources.get(source_enum))
            .map(|s| s.category.as_str())
            .unwrap_or("unknown");
        let sink_category = graph
            .label_set()
            .and_then(|labels| labels.sinks.get(sink_enum))
            .map(|s| s.category.as_str())
            .unwrap_or("unknown");
        let has_label_set = graph.label_set().is_some();
        // Without a label set every hit is kept; otherwise keep known-bad
        // source/sink combinations or any high-risk sink category.
        let dangerous = !has_label_set
            || is_dangerous_combination(source_category, sink_category)
            || matches!(sink_category, "exec" | "sql" | "network" | "xss" | "file");
        if !dangerous {
            continue;
        }
        // Category-derived severity for real combinations; Medium for hits
        // kept only because of the sink-category allowlist above.
        let severity = if !has_label_set || is_dangerous_combination(source_category, sink_category) {
            severity_for_sink(graph, sink_enum)
        } else {
            Severity::Medium
        };
        // The shader only reports endpoints; recover a concrete path on CPU.
        let path = self.reconstruct_path_cpu(graph, f.start_node, f.sink_node);
        results.push(TaintFinding {
            source: source_enum,
            sink: sink_enum,
            path,
            severity,
        });
    }
    Ok(results)
}
/// Upload `data` into a freshly created device buffer with the given `usage`.
///
/// wgpu rejects zero-sized buffers, so an empty slice is substituted with a
/// 16-byte zero placeholder; the shader never reads past the real counts.
fn create_buffer<T: bytemuck::Pod>(&self, data: &[T], usage: wgpu::BufferUsages) -> wgpu::Buffer {
    use wgpu::util::DeviceExt;
    static PLACEHOLDER: [u8; 16] = [0u8; 16];
    let raw: &[u8] = bytemuck::cast_slice(data);
    let contents = if raw.is_empty() { &PLACEHOLDER[..] } else { raw };
    self.device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
        label: None,
        contents,
        usage,
    })
}
/// Best-effort CPU recovery of a dataflow path from `start` to `end`.
///
/// Breadth-first search over dataflow edges only, refusing to expand
/// sanitizer-labeled nodes (taint is considered cut there). If the walk
/// never reaches `end` — e.g. the GPU found a route this conservative
/// search does not — the degenerate `[start, end]` pair is returned so
/// callers still see the endpoints.
fn reconstruct_path_cpu(&self, graph: &TaintGraph, start: NodeId, end: NodeId) -> Vec<NodeId> {
    use std::collections::{HashMap, VecDeque};
    // `predecessor` doubles as the visited set (start is never inserted).
    let mut predecessor: HashMap<NodeId, NodeId> = HashMap::new();
    let mut frontier = VecDeque::from([start]);
    while let Some(node) = frontier.pop_front() {
        if node == end {
            // Walk the predecessor chain back to `start`, then flip it.
            let mut path = vec![end];
            let mut cursor = end;
            while cursor != start {
                match predecessor.get(&cursor) {
                    Some(&prev) => {
                        cursor = prev;
                        path.push(cursor);
                    }
                    None => break,
                }
            }
            path.reverse();
            return path;
        }
        // Sanitizers terminate propagation: never expand their successors.
        let blocked = graph
            .node(node)
            .and_then(|n| n.label)
            .is_some_and(|l| l.is_sanitizer());
        if blocked {
            continue;
        }
        for (next, kind) in graph.edges_from(node) {
            if is_dataflow_edge(&kind) && next != start && !predecessor.contains_key(&next) {
                predecessor.insert(next, node);
                frontier.push_back(next);
            }
        }
    }
    vec![start, end]
}
}
/// Map a sink's configured category string to a severity bucket.
///
/// Sinks with no label set, an out-of-range index, or an unrecognized
/// category fall through to [`Severity::Low`].
fn severity_for_sink(graph: &TaintGraph, sink_idx: usize) -> Severity {
    let category = graph
        .label_set()
        .and_then(|labels| labels.sinks.get(sink_idx))
        .map_or("other", |sink| sink.category.as_str());
    match category {
        "exec" | "sql" => Severity::Critical,
        "network" | "xss" => Severity::High,
        "file" => Severity::Medium,
        _ => Severity::Low,
    }
}