/// A single entry in the static recipe registry.
///
/// Every field borrows `'static` data, so a `Recipe` is cheap to copy and compare.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Recipe {
    /// Identifier looked up by [`recipe_by_id`] (first entry with an equal `id` wins).
    pub id: &'static str,
    /// Title text for the recipe.
    pub title: &'static str,
    /// Recipe body text (markup format not established here — TODO confirm).
    pub body: &'static str,
    /// Keyword strings associated with the recipe; not consulted by the lookup helpers
    /// in this module.
    pub keywords: &'static [&'static str],
    /// Verbs this recipe relates to; [`recipes_for_verb`] matches them by exact equality.
    pub related_verbs: &'static [&'static str],
    /// Trace-shape requirements checked by [`Recipe::matches_trace_shape`].
    ///
    /// Entries are AND-ed together; each entry is a `|`-separated list of predicate names
    /// of which at least one must hold. An empty slice imposes no requirement.
    pub trace_shape: &'static [&'static str],
}
/// Boolean flags describing a trace's shape, evaluated against [`Recipe::trace_shape`].
///
/// Each field is read by the trace-shape predicate of the same name. The `Default` value
/// has every flag `false`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct TraceShape {
    /// Read by the `has_kernels` predicate.
    pub has_kernels: bool,
    /// Read by the `has_memcpy` predicate.
    pub has_memcpy: bool,
    /// Read by the `has_nvtx` predicate.
    pub has_nvtx: bool,
    /// Read by the `has_target_info` predicate.
    pub has_target_info: bool,
    /// Read by the `multi_device` predicate.
    pub multi_device: bool,
    /// Read by the `multi_process` predicate.
    pub multi_process: bool,
    /// Read by the `has_graph_trace` predicate.
    pub has_graph_trace: bool,
    /// Read by the `has_graph_nodes` predicate.
    pub has_graph_nodes: bool,
}
impl Recipe {
    /// Returns `true` when `shape` satisfies every group in [`Recipe::trace_shape`].
    ///
    /// Groups are AND-ed together; inside a group the `|`-separated predicate names are
    /// OR-ed. An empty `trace_shape` contains no groups and therefore always matches.
    pub fn matches_trace_shape(&self, shape: &TraceShape) -> bool {
        // A single group holds as soon as one of its alternants evaluates to true.
        let group_satisfied = |group: &str| {
            group
                .split('|')
                .any(|predicate| eval_predicate(predicate, shape))
        };
        self.trace_shape.iter().copied().all(group_satisfied)
    }
}
/// Evaluates a single trace-shape predicate name against `shape`.
///
/// Leading and trailing whitespace is ignored, so a group written as `"a | b"` behaves
/// like `"a|b"`. Unknown names evaluate to `false`, which means the alternant containing
/// a misspelled name never holds.
fn eval_predicate(name: &str, shape: &TraceShape) -> bool {
    // The caller splits groups on `|` without trimming. Trim here so spacing around the
    // separator cannot turn a valid name into an unknown one.
    match name.trim() {
        "has_kernels" => shape.has_kernels,
        "has_memcpy" => shape.has_memcpy,
        "has_nvtx" => shape.has_nvtx,
        "has_target_info" => shape.has_target_info,
        "multi_device" => shape.multi_device,
        "multi_process" => shape.multi_process,
        "has_graph_trace" => shape.has_graph_trace,
        "has_graph_nodes" => shape.has_graph_nodes,
        // Unrecognised predicate names never match.
        _ => false,
    }
}
// Registry generated at build time into `OUT_DIR` (presumably by this crate's build
// script — confirm). The included file is expected to define `RECIPES`, which the
// accessors below iterate and return as a `&'static [Recipe]`.
include!(concat!(env!("OUT_DIR"), "/recipes_generated.rs"));
/// Returns every registered recipe whose `related_verbs` contains `verb`.
///
/// Matching is exact string equality. The returned iterator borrows `verb`, so the
/// signature carries an explicit `'_` bound, consistent with [`recipes_for_trace_shape`].
pub fn recipes_for_verb(verb: &str) -> impl Iterator<Item = &'static Recipe> + '_ {
    RECIPES
        .iter()
        .filter(move |r| r.related_verbs.contains(&verb))
}
/// Returns every registered recipe whose trace-shape requirements are met by `shape`.
///
/// The returned iterator borrows `shape` and cannot outlive it.
pub fn recipes_for_trace_shape(shape: &TraceShape) -> impl Iterator<Item = &'static Recipe> + '_ {
    // Name the predicate so the pipeline reads as "registry filtered by fit".
    let fits = move |recipe: &&'static Recipe| recipe.matches_trace_shape(shape);
    RECIPES.iter().filter(fits)
}
/// Looks up a recipe by its `id`.
///
/// Scans the registry in order and returns the first recipe whose `id` equals `id`, or
/// `None` when no recipe matches.
#[must_use]
pub fn recipe_by_id(id: &str) -> Option<&'static Recipe> {
    RECIPES.iter().find(|r| r.id == id)
}
/// Returns the complete recipe registry as a `'static` slice, in registry order.
#[must_use]
pub fn all_recipes() -> &'static [Recipe] {
    RECIPES
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn registry_is_populated() {
        assert!(
            all_recipes().len() >= 8,
            "expected at least 8 recipes; got {}",
            all_recipes().len(),
        );
    }

    #[test]
    fn recipe_by_id_returns_none_for_unknown_slug() {
        assert!(recipe_by_id("nope-not-a-recipe").is_none());
    }

    #[test]
    fn recipes_for_verb_filters_correctly() {
        assert!(recipes_for_verb("stats").count() > 0);
        assert_eq!(recipes_for_verb("does-not-exist").count(), 0);
    }

    #[test]
    fn recipes_for_trace_shape_filters_correctly() {
        let empty = TraceShape::default();
        let full = TraceShape {
            has_kernels: true,
            has_memcpy: true,
            has_nvtx: true,
            has_target_info: true,
            multi_device: true,
            multi_process: true,
            has_graph_trace: true,
            has_graph_nodes: true,
        };
        let unrestricted = recipes_for_trace_shape(&empty).count();
        let all_match = recipes_for_trace_shape(&full).count();
        assert!(
            all_match >= unrestricted,
            "full-shape match count must dominate the unrestricted count",
        );
        // Every predicate is a positive flag, so the all-true shape should satisfy every
        // recipe, provided each recipe only uses recognised predicate names.
        assert_eq!(all_match, all_recipes().len());
    }

    #[test]
    fn empty_trace_shape_matches_anything() {
        let r = Recipe {
            id: "x",
            title: "x",
            body: "x",
            keywords: &["x"],
            related_verbs: &["stats"],
            trace_shape: &[],
        };
        assert!(r.matches_trace_shape(&TraceShape::default()));
        assert!(r.matches_trace_shape(&TraceShape {
            has_nvtx: true,
            ..Default::default()
        }));
    }

    #[test]
    fn or_group_matches_when_any_alternant_holds() {
        let r = Recipe {
            id: "or-x",
            title: "x",
            body: "x",
            keywords: &["x"],
            related_verbs: &["graph-replays"],
            trace_shape: &["has_graph_trace|has_graph_nodes"],
        };
        assert!(!r.matches_trace_shape(&TraceShape::default()));
        assert!(r.matches_trace_shape(&TraceShape {
            has_graph_trace: true,
            ..Default::default()
        }));
        assert!(r.matches_trace_shape(&TraceShape {
            has_graph_nodes: true,
            ..Default::default()
        }));
    }

    #[test]
    fn and_of_or_groups_combines_correctly() {
        let r = Recipe {
            id: "and-or",
            title: "x",
            body: "x",
            keywords: &["x"],
            related_verbs: &["graph-replays"],
            trace_shape: &["has_graph_trace|has_graph_nodes", "has_nvtx"],
        };
        assert!(!r.matches_trace_shape(&TraceShape::default()));
        assert!(!r.matches_trace_shape(&TraceShape {
            has_graph_trace: true,
            ..Default::default()
        }));
        assert!(r.matches_trace_shape(&TraceShape {
            has_graph_nodes: true,
            has_nvtx: true,
            ..Default::default()
        }));
    }

    #[test]
    fn graph_replay_recipes_match_their_capture_modes() {
        let ids_for = |shape: &TraceShape| {
            recipes_for_trace_shape(shape)
                .map(|r| r.id)
                .collect::<Vec<_>>()
        };
        let only_trace_mode = TraceShape {
            has_graph_trace: true,
            ..Default::default()
        };
        let only_node_mode = TraceShape {
            has_graph_nodes: true,
            ..Default::default()
        };
        let trace_ids = ids_for(&only_trace_mode);
        assert!(
            trace_ids.contains(&"graph-replay-survey"),
            "graph-replay-survey should match has_graph_trace; got {trace_ids:?}",
        );
        assert!(
            !trace_ids.contains(&"graph-replay-hotspots"),
            "graph-replay-hotspots must NOT match has_graph_trace alone; got {trace_ids:?}",
        );
        let node_ids = ids_for(&only_node_mode);
        assert!(
            node_ids.contains(&"graph-replay-survey"),
            "graph-replay-survey should match has_graph_nodes; got {node_ids:?}",
        );
        assert!(
            node_ids.contains(&"graph-replay-hotspots"),
            "graph-replay-hotspots should match has_graph_nodes; got {node_ids:?}",
        );
        let none_ids = ids_for(&TraceShape::default());
        assert!(!none_ids.contains(&"graph-replay-survey"));
        assert!(!none_ids.contains(&"graph-replay-hotspots"));
    }

    /// A single-predicate `trace_shape` paired with the setter that makes exactly that
    /// predicate hold.
    type PredicateCase = (&'static [&'static str], fn(&mut TraceShape));

    #[test]
    fn all_predicates_evaluate() {
        const ALL: &[PredicateCase] = &[
            (&["has_kernels"], |s| s.has_kernels = true),
            (&["has_memcpy"], |s| s.has_memcpy = true),
            (&["has_nvtx"], |s| s.has_nvtx = true),
            (&["has_target_info"], |s| s.has_target_info = true),
            (&["multi_device"], |s| s.multi_device = true),
            (&["multi_process"], |s| s.multi_process = true),
            (&["has_graph_trace"], |s| s.has_graph_trace = true),
            (&["has_graph_nodes"], |s| s.has_graph_nodes = true),
        ];
        // Each predicate must hold for its own flag and for no other single flag. This
        // catches a name wired to the wrong field, and a predicate that is always true.
        for (i, (trace_shape, _)) in ALL.iter().enumerate() {
            let r = Recipe {
                id: "x",
                title: "x",
                body: "x",
                keywords: &["x"],
                related_verbs: &["stats"],
                trace_shape,
            };
            assert!(
                !r.matches_trace_shape(&TraceShape::default()),
                "predicate `{trace_shape:?}` matched the all-false shape",
            );
            for (j, (_, set)) in ALL.iter().enumerate() {
                let mut shape = TraceShape::default();
                set(&mut shape);
                assert_eq!(
                    r.matches_trace_shape(&shape),
                    i == j,
                    "predicate `{trace_shape:?}` evaluated wrongly against {shape:?}",
                );
            }
        }
    }
}