#![cfg_attr(coverage_nightly, coverage(off))]
use crate::services::agent_context::AgentContextIndex;
/// How a function participates in the PTX/CUDA pipeline.
///
/// Fieldless, so it is `Copy` and can be used as a hash-set / map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PtxRole {
    /// Produces PTX or device code (inline asm, PTX directives, `.ptx`/`.cu`/`.cuh` files).
    Emitter,
    /// Loads or compiles PTX modules (e.g. `cuModuleLoad`, `load_ptx`).
    Loader,
    /// Inspects PTX characteristics (barriers, register pressure, shared memory).
    Analyzer,
    /// Directly calls a classified PTX function without matching any keyword itself.
    Consumer,
}
impl std::fmt::Display for PtxRole {
    /// Writes the lowercase role label (`emitter`, `loader`, `analyzer`, `consumer`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match *self {
            PtxRole::Emitter => "emitter",
            PtxRole::Loader => "loader",
            PtxRole::Analyzer => "analyzer",
            PtxRole::Consumer => "consumer",
        };
        f.write_str(label)
    }
}
#[derive(Debug, Clone)]
pub struct PtxFlowNode {
pub project: String,
pub function_name: String,
pub file_path: String,
pub role: PtxRole,
pub func_idx: usize,
}
/// A caller → callee relationship between two nodes of a [`PtxFlowResult`].
///
/// Both fields index into `PtxFlowResult::nodes`, not into the agent-context index.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct PtxFlowEdge {
    /// Index of the calling node.
    pub from_idx: usize,
    /// Index of the called node.
    pub to_idx: usize,
}
/// The complete PTX dataflow graph returned by `trace_ptx_dataflow`.
///
/// `Default` yields an empty graph (no nodes, no edges).
#[derive(Debug, Clone, Default)]
pub struct PtxFlowResult {
    /// All classified and consumer nodes.
    pub nodes: Vec<PtxFlowNode>,
    /// Caller → callee edges; their indices point into `nodes`.
    pub edges: Vec<PtxFlowEdge>,
}
/// Substrings that mark a function as a PTX *emitter*.
///
/// Covers Rust inline/global asm, C-style `asm volatile`, PTX header
/// directives (`.version`, `.target sm_`), CUDA kernel/device qualifiers, and
/// project-specific PTX builder names. Matched as plain, case-sensitive
/// substrings of the function source by `source_matches_any`.
const EMITTER_KEYWORDS: &[&str] = &[
"asm!(",
"asm volatile",
"global_asm!",
".version ",
".target sm_",
"__global__",
"__device__",
"emit_ptx",
"ptx_builder",
];
/// Substrings that mark a function as a PTX *loader* (module load/compile).
///
/// Only consulted when no emitter keyword or PTX file extension matched.
/// NOTE(review): `"cuModuleLoadData"` is redundant for substring matching,
/// since any source containing it also contains `"cuModuleLoad"`. It is kept
/// as documentation of intent.
const LOADER_KEYWORDS: &[&str] = &[
"cuModuleLoad",
"cuModuleLoadData",
"create_shader_module",
"load_ptx",
"ptx_module",
"load_module",
"compile_ptx",
];
/// Substrings that mark a function as a PTX *analyzer*.
///
/// These are names of PTX diagnostics and metrics (barriers, register pressure,
/// shared memory). They are only consulted when neither the emitter nor the
/// loader checks matched.
const ANALYZER_KEYWORDS: &[&str] = &[
"barrier_count",
"register_pressure",
"shared_memory_size",
"ptx_analysis",
"detect_ptx_barrier",
"detect_shared_memory",
"ptx_diagnostic",
];
/// Returns `true` if `source` contains at least one entry of `keywords`.
///
/// Matching is plain, case-sensitive substring search. An empty keyword list
/// never matches, and an empty keyword matches any source.
fn source_matches_any(source: &str, keywords: &[&str]) -> bool {
    for &keyword in keywords {
        if source.contains(keyword) {
            return true;
        }
    }
    false
}
/// Classifies a function's relationship to PTX code, or returns `None`.
///
/// Checks run in priority order and the first match wins:
/// 1. the file ends in `.ptx`, `.cu` or `.cuh`, or the source contains an
///    emitter keyword → `PtxRole::Emitter`;
/// 2. the source contains a loader keyword → `PtxRole::Loader`;
/// 3. the source contains an analyzer keyword → `PtxRole::Analyzer`.
///
/// Keyword matching is plain, case-sensitive substring search over the whole
/// source text, including comments and string literals.
pub fn classify_ptx_role(source: &str, file_path: &str) -> Option<PtxRole> {
    const PTX_FILE_SUFFIXES: [&str; 3] = [".ptx", ".cu", ".cuh"];

    let in_ptx_file = PTX_FILE_SUFFIXES
        .iter()
        .any(|&suffix| file_path.ends_with(suffix));

    if in_ptx_file || source_matches_any(source, EMITTER_KEYWORDS) {
        Some(PtxRole::Emitter)
    } else if source_matches_any(source, LOADER_KEYWORDS) {
        Some(PtxRole::Loader)
    } else if source_matches_any(source, ANALYZER_KEYWORDS) {
        Some(PtxRole::Analyzer)
    } else {
        None
    }
}
/// Returns the project a file belongs to, i.e. its leading directory component.
///
/// A leading `./` is ignored, so `"gpu-kernels/src/lib.rs"` and
/// `"./gpu-kernels/src/lib.rs"` both yield `"gpu-kernels"`. Paths with no
/// leading directory component yield `"local"`. That covers a bare file name,
/// an absolute path starting with `/`, and an empty string. (Previously
/// `split('/').next()` could never be `None`, so the `"local"` fallback was
/// unreachable. Absolute paths produced an empty project name, and bare file
/// names became their own project.)
fn ptx_project_name(file_path: &str) -> &str {
    let relative = file_path.trim_start_matches("./");
    match relative.split_once('/') {
        Some((head, _)) if !head.is_empty() => head,
        _ => "local",
    }
}

/// Builds a flow node for `func`, which sits at position `idx` in the index.
fn make_node(
    func: &crate::services::agent_context::FunctionEntry,
    idx: usize,
    role: PtxRole,
) -> PtxFlowNode {
    PtxFlowNode {
        project: ptx_project_name(&func.file_path).to_string(),
        function_name: func.function_name.clone(),
        file_path: func.file_path.clone(),
        role,
        func_idx: idx,
    }
}
/// Classifies every indexed function and keeps the PTX-related ones.
///
/// Returns the new nodes together with a parallel vector of their function
/// indices. Both are in `all_functions()` order, so entry `k` of the second
/// vector equals `nodes[k].func_idx`.
fn classify_ptx_nodes(index: &AgentContextIndex) -> (Vec<PtxFlowNode>, Vec<usize>) {
    index
        .all_functions()
        .iter()
        .enumerate()
        .filter_map(|(func_idx, func)| {
            classify_ptx_role(&func.source, &func.file_path)
                .map(|role| (make_node(func, func_idx, role), func_idx))
        })
        .unzip()
}
/// Appends a `PtxRole::Consumer` node for each direct caller of a classified
/// PTX function that is not itself classified.
///
/// `node_func_idx` holds the function indices of the classified nodes. It
/// serves as both the starting set and the exclusion set. Each caller is
/// added at most once, in the order first encountered. Only one hop is
/// followed, so callers of consumers are not added.
fn find_consumer_nodes(
    index: &AgentContextIndex,
    nodes: &mut Vec<PtxFlowNode>,
    node_func_idx: &[usize],
) {
    use std::collections::HashSet;

    let classified: HashSet<usize> = node_func_idx.iter().copied().collect();
    let mut added: HashSet<usize> = HashSet::new();

    let caller_lists = node_func_idx
        .iter()
        .filter_map(|&ptx_func| index.called_by_indices(ptx_func));

    for callers in caller_lists {
        for &caller in callers {
            // `||` short-circuits, so `added` is only updated for unclassified callers.
            if classified.contains(&caller) || !added.insert(caller) {
                continue;
            }
            nodes.push(make_node(
                &index.all_functions()[caller],
                caller,
                PtxRole::Consumer,
            ));
        }
    }
}
/// Connects flow nodes with caller → callee edges.
///
/// For every node, each callee reported by `index.calls_indices` that is also
/// a node yields an edge whose indices point into `nodes`. Callees outside the
/// node set are ignored and self-edges are dropped. No deduplication is done,
/// so repeated callee entries produce repeated edges.
fn build_flow_edges(index: &AgentContextIndex, nodes: &[PtxFlowNode]) -> Vec<PtxFlowEdge> {
    let node_of_func: std::collections::HashMap<usize, usize> = nodes
        .iter()
        .enumerate()
        .map(|(position, node)| (node.func_idx, position))
        .collect();
    let lookup = &node_of_func;

    nodes
        .iter()
        .enumerate()
        .filter_map(|(from_idx, node)| {
            index
                .calls_indices(node.func_idx)
                .map(|callees| (from_idx, callees))
        })
        .flat_map(|(from_idx, callees)| {
            callees.iter().filter_map(move |callee| {
                lookup
                    .get(callee)
                    .copied()
                    .filter(|&to_idx| to_idx != from_idx)
                    .map(|to_idx| PtxFlowEdge { from_idx, to_idx })
            })
        })
        .collect()
}
/// Builds the PTX dataflow graph for every function in `index`.
///
/// 1. Classify functions as emitter, loader or analyzer by keyword and file type.
/// 2. Add their direct, unclassified callers as consumers.
/// 3. Connect the resulting nodes with caller → callee edges.
pub fn trace_ptx_dataflow(index: &AgentContextIndex) -> PtxFlowResult {
    let (mut graph_nodes, ptx_funcs) = classify_ptx_nodes(index);
    find_consumer_nodes(index, &mut graph_nodes, &ptx_funcs);
    // Fields are evaluated in order: edges only borrow the nodes, so they are
    // built before `graph_nodes` is moved into the result.
    PtxFlowResult {
        edges: build_flow_edges(index, &graph_nodes),
        nodes: graph_nodes,
    }
}
pub fn format_ptx_flow_text(result: &PtxFlowResult) -> String {
let mut out = String::new();
out.push_str(&format!(
"\x1b[1;4mPTX Dataflow\x1b[0m ({} nodes, {} edges)\n\n",
result.nodes.len(),
result.edges.len()
));
if result.nodes.is_empty() {
out.push_str(" No PTX-related functions found in workspace.\n");
return out;
}
for role in &[
PtxRole::Emitter,
PtxRole::Loader,
PtxRole::Analyzer,
PtxRole::Consumer,
] {
let role_nodes: Vec<_> = result.nodes.iter().filter(|n| &n.role == role).collect();
if role_nodes.is_empty() {
continue;
}
let role_color = match role {
PtxRole::Emitter => "\x1b[1;31m",
PtxRole::Loader => "\x1b[1;33m",
PtxRole::Analyzer => "\x1b[1;36m",
PtxRole::Consumer => "\x1b[1;32m",
};
out.push_str(&format!(
" {role_color}{role}\x1b[0m ({}):\n",
role_nodes.len()
));
for node in &role_nodes {
out.push_str(&format!(
" [{project}] {name} \x1b[2m{path}\x1b[0m\n",
project = node.project,
name = node.function_name,
path = node.file_path
));
}
out.push('\n');
}
if !result.edges.is_empty() {
out.push_str(" \x1b[1mDataflow chains:\x1b[0m\n");
for edge in &result.edges {
let from = &result.nodes[edge.from_idx];
let to = &result.nodes[edge.to_idx];
out.push_str(&format!(" [{src}] {src_fn} \x1b[2m({src_role})\x1b[0m → [{dst}] {dst_fn} \x1b[2m({dst_role})\x1b[0m\n",
src = from.project, src_fn = from.function_name, src_role = from.role,
dst = to.project, dst_fn = to.function_name, dst_role = to.role));
}
out.push('\n');
}
out
}
/// Serializes a `PtxFlowResult` to a compact JSON string.
///
/// Shape: `{"ptx_flow": {"nodes": [...], "edges": [...]}}`.
/// - Each node has `project`, `function_name`, `file_path` and `role` (the
///   lowercase `Display` label). `func_idx` is not emitted.
/// - Each edge has `from`, `to` (function names) and `from_project`,
///   `to_project`.
///
/// Edges whose `from_idx`/`to_idx` fall outside `result.nodes` are omitted
/// rather than panicking.
pub fn format_ptx_flow_json(result: &PtxFlowResult) -> String {
    let nodes: Vec<serde_json::Value> = result
        .nodes
        .iter()
        .map(|n| {
            serde_json::json!({
                "project": n.project,
                "function_name": n.function_name,
                "file_path": n.file_path,
                "role": n.role.to_string(),
            })
        })
        .collect();
    let edges: Vec<serde_json::Value> = result
        .edges
        .iter()
        .filter_map(|e| {
            // Skip dangling edges instead of panicking on out-of-range indices.
            let from = result.nodes.get(e.from_idx)?;
            let to = result.nodes.get(e.to_idx)?;
            Some(serde_json::json!({
                "from": from.function_name,
                "to": to.function_name,
                "from_project": from.project,
                "to_project": to.project,
            }))
        })
        .collect();
    serde_json::json!({
        "ptx_flow": { "nodes": nodes, "edges": edges }
    })
    .to_string()
}