use std::cell::RefCell;
use std::ops::Range;
use crate::ir::{BufferAccess, DataType, MemoryKind, Node, Program};
use crate::optimizer::AdapterCaps;
use crate::program_caps::{self, RequiredCapabilities};
use crate::validate::{validate_with_options, ValidationOptions};
pub mod fusion;
mod policy;
mod strategy;
pub use policy::{PolicyRoute, SchedulingPolicy};
pub use strategy::{
AccuracyStrategy, AutotuneStrategy, DispatchStrategy, FusionStrategy, LayoutStrategy,
ProvenanceStrategy, ReadbackStrategy, StrategyPlan,
};
thread_local! {
    /// Per-thread scratch buffer that `canonical_program_fingerprint` serializes the
    /// program's wire form into before hashing it. The buffer is only `clear()`ed (never
    /// shrunk) between uses, so its capacity is reused by later calls on the same thread.
    static PLAN_WIRE_SCRATCH: RefCell<Vec<u8>> = const { RefCell::new(Vec::new()) };
}
/// Optional execution features the planner can switch on for a program.
///
/// `track_decisions` in this module emits exactly one [`TrackDecision`] per variant; the
/// activation rule for each is noted on the variant.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
#[non_exhaustive]
pub enum InnovationTrack {
    /// Active when the fusion plan marks the program as a batch-fusion candidate.
    WholeProgramFusion,
    /// Active when `SchedulingPolicy::standard().use_persistent_runtime(node_count)` holds.
    PersistentExecution,
    /// Active when the accuracy plan recommends a shadow reference run.
    DifferentialAccuracy,
    /// Active when the autotune plan recommends autotuning.
    ConformanceGuidedAutotune,
    /// Currently never activated by this planner.
    GpuResidentProvenance,
    /// Active when the memory plan has a non-zero static byte total.
    DataLayoutCompiler,
    /// Active when trimmed output ranges avoid reading back some output bytes.
    ReadbackMinimization,
}
/// The planner's on/off verdict for a single [`InnovationTrack`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct TrackDecision {
    /// Which track this decision is about.
    pub track: InnovationTrack,
    /// Whether the track is enabled for the planned program.
    pub active: bool,
    /// Short static label explaining the decision (e.g. "fusion", "device profile").
    pub reason: &'static str,
}
/// Complete, deterministic plan describing how a validated [`Program`] should execute.
///
/// Produced by [`plan_with_options_for_adapter`] and its convenience wrappers.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ExecutionPlan {
    /// BLAKE3 hash of the program's wire encoding (see `canonical_program_fingerprint`).
    pub program_fingerprint: [u8; 32],
    /// Capabilities the program needs, as reported by `program_caps::scan`.
    pub required_capabilities: RequiredCapabilities,
    /// Node/region statistics and batch-fusion eligibility.
    pub fusion: FusionPlan,
    /// Per-buffer layout plus static-size and readback byte totals.
    pub memory: MemoryPlan,
    /// Region-trace settings derived from the program's region structure.
    pub provenance: ProvenancePlan,
    /// Whether a shadow reference execution is recommended.
    pub accuracy: AccuracyPlan,
    /// Workgroup, tile, vector-pack and unroll recommendations.
    pub autotune: AutotunePlan,
    /// Strategy selection built by `StrategyPlan::from_parts` from the sub-plans above.
    pub strategy: StrategyPlan,
    /// One decision per [`InnovationTrack`], in the order `track_decisions` emits them.
    pub tracks: Vec<TrackDecision>,
}
impl ExecutionPlan {
    /// Reports whether `track` was switched on for this plan.
    ///
    /// Looks only at decisions flagged as active and returns `true` if any of them names
    /// `track`. A track that is absent or present-but-inactive yields `false`.
    #[must_use]
    pub fn track_active(&self, track: InnovationTrack) -> bool {
        self.tracks
            .iter()
            .filter(|candidate| candidate.active)
            .any(|candidate| candidate.track == track)
    }
}
/// Reasons the planner refuses to build an [`ExecutionPlan`].
#[derive(Debug, thiserror::Error)]
pub enum PlanError {
    /// The program failed validation, has an output buffer without a static size, or
    /// could not be serialized to its wire form for fingerprinting.
    #[error("non-canonical program: {source}")]
    NonCanonicalProgram {
        /// Underlying crate error. It is exposed as this variant's `source()` because the
        /// field is named `source`.
        source: crate::error::Error,
    },
    /// An output buffer's byte range is reversed (`start > end`) or ends past the
    /// buffer's static byte size.
    #[error(
        "invalid output range for buffer {name}: {start}..{end} exceeds full size {full_size}. Fix: keep output byte ranges ordered and inside the declared buffer size."
    )]
    InvalidOutputRange {
        /// Name of the offending output buffer.
        name: String,
        /// Requested range start, in bytes.
        start: usize,
        /// Requested range end (exclusive), in bytes.
        end: usize,
        /// Static size of the buffer in bytes (`count * element size`).
        full_size: u64,
    },
}
/// Plans `program` using default validation options and conservative adapter limits.
///
/// # Errors
///
/// Returns any [`PlanError`] produced by [`plan_for_adapter`].
pub fn plan(program: &Program) -> Result<ExecutionPlan, PlanError> {
    let conservative = AdapterCaps::conservative();
    plan_for_adapter(program, &conservative)
}
/// Plans `program` for a specific adapter, using default validation options.
///
/// # Errors
///
/// Returns any [`PlanError`] produced by [`plan_with_options_for_adapter`].
pub fn plan_for_adapter(
    program: &Program,
    adapter_caps: &AdapterCaps,
) -> Result<ExecutionPlan, PlanError> {
    let default_options = ValidationOptions::default();
    plan_with_options_for_adapter(program, default_options, adapter_caps)
}
/// Plans `program` with caller-supplied validation options and conservative adapter limits.
///
/// # Errors
///
/// Returns any [`PlanError`] produced by [`plan_with_options_for_adapter`].
pub fn plan_with_options(
    program: &Program,
    options: ValidationOptions<'_>,
) -> Result<ExecutionPlan, PlanError> {
    let conservative = AdapterCaps::conservative();
    plan_with_options_for_adapter(program, options, &conservative)
}
/// Builds the full [`ExecutionPlan`] for `program` on the adapter described by `adapter_caps`.
///
/// The program is validated first. The sub-plans are then derived in dependency order:
/// fusion and memory first, then the fingerprint, provenance, accuracy and autotune plans.
/// The strategy and track decisions are computed last, from the sub-plans.
///
/// # Errors
///
/// * [`PlanError::NonCanonicalProgram`] if validation fails, an output buffer has no
///   static size, or the program cannot be serialized for fingerprinting.
/// * [`PlanError::InvalidOutputRange`] if an output buffer's byte range is malformed.
pub fn plan_with_options_for_adapter(
    program: &Program,
    options: ValidationOptions<'_>,
    adapter_caps: &AdapterCaps,
) -> Result<ExecutionPlan, PlanError> {
    validate_program_for_plan(program, options)?;

    let caps = program_caps::scan(program);
    let fusion_part = fusion_plan(program);
    // The memory plan's errors take precedence over fingerprinting errors.
    let memory_part = memory_plan(program)?;
    let fingerprint = canonical_program_fingerprint(program)?;

    let provenance_part = provenance_plan(program, &fusion_part);
    let accuracy_part = accuracy_plan(&caps, &provenance_part);
    let autotune_part = autotune_plan(program, &caps, &fusion_part, adapter_caps);

    let strategy_part = StrategyPlan::from_parts(
        &fusion_part,
        &memory_part,
        &provenance_part,
        &accuracy_part,
        &autotune_part,
    );
    let track_list = track_decisions(
        &fusion_part,
        &memory_part,
        &provenance_part,
        &accuracy_part,
        &autotune_part,
    );

    Ok(ExecutionPlan {
        program_fingerprint: fingerprint,
        required_capabilities: caps,
        fusion: fusion_part,
        memory: memory_part,
        provenance: provenance_part,
        accuracy: accuracy_part,
        autotune: autotune_part,
        strategy: strategy_part,
        tracks: track_list,
    })
}
/// Ensures `program` is valid before planning.
///
/// A program that is already structurally validated skips re-validation, unless the
/// caller requested backend-specific checks via `options.backend` or
/// `options.backend_capabilities`.
///
/// # Errors
///
/// Returns [`PlanError::NonCanonicalProgram`] wrapping a `WireFormatValidation` error
/// whose message joins every validation message with `"; "`.
fn validate_program_for_plan(
    program: &Program,
    options: ValidationOptions<'_>,
) -> Result<(), PlanError> {
    let backend_checks_requested =
        options.backend.is_some() || options.backend_capabilities.is_some();
    if !backend_checks_requested && program.is_structurally_validated() {
        return Ok(());
    }

    let report = validate_with_options(program, options);
    if report.errors.is_empty() {
        return Ok(());
    }

    // Reserve exactly enough room for every message plus one two-byte separator between each.
    let separator = "; ";
    let total_message_bytes: usize = report
        .errors
        .iter()
        .map(|error| error.message().len())
        .sum();
    let separator_bytes = separator.len() * report.errors.len().saturating_sub(1);

    let mut messages = String::with_capacity(total_message_bytes + separator_bytes);
    let mut glue = "";
    for error in report.errors.iter() {
        messages.push_str(glue);
        messages.push_str(error.message());
        glue = separator;
    }

    Err(PlanError::NonCanonicalProgram {
        source: crate::error::Error::WireFormatValidation {
            message: format!(
                "canonical execution plan validation failed: {messages}. Fix: repair the Program before planning."
            ),
        },
    })
}
/// Summarizes the program's structure for fusion decisions.
///
/// A program is a batch-fusion candidate when it is composable with itself and its
/// top level is wrapped in a region.
fn fusion_plan(program: &Program) -> FusionPlan {
    let top_level_regions = program.stats().top_level_regions as usize;
    let node_count = count_nodes(program.entry());
    let entry_op_id = program.entry_op_id().map(|id| id.to_owned());

    let composable = !program.is_non_composable_with_self();
    let batch_fusion_candidate = composable && program.is_top_level_region_wrapped();

    FusionPlan {
        entry_op_id,
        top_level_regions,
        node_count,
        batch_fusion_candidate,
    }
}
/// Counts every node in `nodes`, including all nodes nested inside `If` (both branches),
/// `Loop`, `Block` and `Region` bodies.
///
/// Uses an explicit work stack, so deeply nested programs do not grow the call stack.
fn count_nodes(nodes: &[Node]) -> usize {
    let mut pending: Vec<&[Node]> = vec![nodes];
    let mut total = 0usize;
    while let Some(level) = pending.pop() {
        total += level.len();
        for node in level {
            match node {
                Node::If {
                    then, otherwise, ..
                } => {
                    pending.push(then);
                    pending.push(otherwise);
                }
                Node::Loop { body, .. } | Node::Block(body) => pending.push(body),
                // Kept as its own arm: the Region body may not share the Loop/Block body type.
                Node::Region { body, .. } => pending.push(body),
                _ => {}
            }
        }
    }
    total
}
/// Hashes the program's wire encoding with BLAKE3.
///
/// The wire bytes are written into the thread-local `PLAN_WIRE_SCRATCH` buffer so that
/// repeated planning on one thread reuses the same allocation.
///
/// # Errors
///
/// Returns [`PlanError::NonCanonicalProgram`] if the program cannot be serialized.
fn canonical_program_fingerprint(program: &Program) -> Result<[u8; 32], PlanError> {
    PLAN_WIRE_SCRATCH.with(|cell| {
        let mut buffer = cell.borrow_mut();
        buffer.clear();
        match program.to_wire_into(&mut buffer) {
            Ok(_) => Ok(*blake3::hash(&buffer).as_bytes()),
            Err(source) => Err(PlanError::NonCanonicalProgram { source }),
        }
    })
}
/// Builds the per-buffer memory layout and the byte totals for `program`.
///
/// * A buffer with a non-zero `count` has a static size of `count * element size`.
///   Element types whose size is unknown are treated as 4 bytes wide. A zero `count`
///   marks the buffer as dynamic.
/// * Every output buffer must have a static size. If it declares an output byte range,
///   only that range counts as visible readback, and the rest counts as avoided readback.
///
/// # Errors
///
/// * [`PlanError::NonCanonicalProgram`] if an output buffer has no static size.
/// * [`PlanError::InvalidOutputRange`] if an output range is reversed or ends past the
///   buffer's static size.
fn memory_plan(program: &Program) -> Result<MemoryPlan, PlanError> {
    let declared = program.buffers();
    let mut buffers = Vec::new();
    let mut static_total = 0u64;
    let mut dynamic_count = 0usize;
    let mut visible_total = 0u64;
    let mut avoided_total = 0u64;

    for buffer in declared.iter() {
        let element_count = buffer.count();
        let element_bytes = buffer.element().size_bytes().unwrap_or(4) as u64;
        let static_size =
            (element_count > 0).then(|| u64::from(element_count) * element_bytes);
        match static_size {
            Some(bytes) => static_total += bytes,
            None => dynamic_count += 1,
        }

        let output_range = buffer.output_byte_range();
        if buffer.is_output() {
            let full_size = static_size.unwrap_or(0);
            if full_size == 0 {
                return Err(PlanError::NonCanonicalProgram {
                    source: crate::error::Error::WireFormatValidation {
                        message: format!(
                            "canonical execution plan requires static output buffer `{}` size. Fix: set BufferDecl::output(...).with_count(n) before planning.",
                            buffer.name()
                        ),
                    },
                });
            }
            let visible = match &output_range {
                None => full_size,
                Some(range) if range.start > range.end || range.end as u64 > full_size => {
                    return Err(PlanError::InvalidOutputRange {
                        name: buffer.name().to_string(),
                        start: range.start,
                        end: range.end,
                        full_size,
                    });
                }
                Some(range) => (range.end - range.start) as u64,
            };
            visible_total += visible;
            avoided_total += full_size.saturating_sub(visible);
        }

        buffers.push(BufferPlan {
            name: buffer.name().to_string(),
            binding: buffer.binding(),
            access: buffer.access(),
            kind: buffer.kind(),
            element: buffer.element(),
            count: element_count,
            static_size_bytes: static_size,
            output_range,
        });
    }

    Ok(MemoryPlan {
        buffers,
        static_bytes: static_total,
        dynamic_buffers: dynamic_count,
        visible_readback_bytes: visible_total,
        avoided_readback_bytes: avoided_total,
    })
}
/// Derives region-trace settings from the program's region structure.
///
/// Region traces are emitted exactly when the program's top level is region-wrapped.
/// The fusion plan is accepted for interface symmetry but is not consulted.
fn provenance_plan(program: &Program, _fusion: &FusionPlan) -> ProvenancePlan {
    let wrapped = program.is_top_level_region_wrapped();
    let region_count = program.stats().region_count as usize;
    ProvenancePlan {
        top_level_region_wrapped: wrapped,
        region_count,
        emit_region_trace: wrapped,
    }
}
/// Decides whether a shadow reference execution should be recommended.
///
/// Programs that require subgroup operations get a shadow reference ("subgroup
/// semantics"). All others report "baseline". The provenance plan is not consulted.
fn accuracy_plan(caps: &RequiredCapabilities, _provenance: &ProvenancePlan) -> AccuracyPlan {
    let needs_shadow = caps.subgroup_ops;
    let reason = if needs_shadow {
        "subgroup semantics"
    } else {
        "baseline"
    };
    AccuracyPlan {
        shadow_reference_recommended: needs_shadow,
        reason,
    }
}
/// Derives dispatch-shape recommendations for `program` on the given adapter.
///
/// Workgroup width, tile, vector-pack width and unroll depth all come from
/// [`SchedulingPolicy::standard`], fed with three inputs:
/// * the program's parallel region size;
/// * the smallest static, non-shared buffer element count (see `infer_static_problem_size`);
/// * `adapter_caps`.
///
/// Autotuning is recommended in two cases:
/// * the adapter carries a device profile: a non-zero ideal unroll depth, a non-zero ideal
///   vector-pack width, or an ideal tile with no zero dimension;
/// * the policy deems the program large.
///
/// `fusion` must be the [`FusionPlan`] computed for this same `program`. Its `node_count`
/// is reused instead of walking the node tree a second time. `track_decisions` reads the
/// same field, so both decisions stay consistent.
fn autotune_plan(
    program: &Program,
    _caps: &RequiredCapabilities,
    fusion: &FusionPlan,
    adapter_caps: &AdapterCaps,
) -> AutotunePlan {
    let policy = SchedulingPolicy::standard();
    let node_count = fusion.node_count;
    let parallel_region_size = program.parallel_region_size();
    let problem_size = infer_static_problem_size(program);

    // Only the X dimension is tuned; Y and Z stay at 1.
    let recommended_workgroup_size = [
        policy.select_workgroup_x(parallel_region_size[0], problem_size, adapter_caps),
        1,
        1,
    ];
    let recommended_tile =
        policy.select_workgroup_tile(parallel_region_size, problem_size, adapter_caps);
    let recommended_vector_pack_bits = policy.select_vector_pack_bits(32, adapter_caps);
    let recommended_unroll_depth = policy.select_unroll_depth(None, adapter_caps);

    let large_program = policy.recommend_autotune(node_count);
    let profile_driven = adapter_caps.ideal_unroll_depth > 0
        || adapter_caps.ideal_vector_pack_bits > 0
        || !adapter_caps.ideal_workgroup_tile.contains(&0);

    // A device profile is the more specific reason, so it wins when both apply.
    let reason = if profile_driven {
        "device profile"
    } else if large_program {
        "large program"
    } else {
        "none"
    };

    AutotunePlan {
        recommended: large_program || profile_driven,
        parallel_region_size,
        recommended_workgroup_size,
        recommended_tile,
        recommended_vector_pack_bits,
        recommended_unroll_depth,
        reason,
    }
}
/// Returns the smallest non-zero element count among non-shared buffers.
///
/// Returns `None` when no buffer has a static, non-shared element count.
fn infer_static_problem_size(program: &Program) -> Option<u32> {
    let mut smallest: Option<u32> = None;
    for buffer in program.buffers().iter() {
        let count = buffer.count();
        if count == 0 || matches!(buffer.kind(), MemoryKind::Shared) {
            continue;
        }
        smallest = Some(smallest.map_or(count, |current| current.min(count)));
    }
    smallest
}
/// Produces one [`TrackDecision`] for every [`InnovationTrack`], in declaration order.
///
/// Each track's activation is read directly off the corresponding sub-plan.
/// `GpuResidentProvenance` is always reported as inactive.
fn track_decisions(
    fusion: &FusionPlan,
    memory: &MemoryPlan,
    _provenance: &ProvenancePlan,
    accuracy: &AccuracyPlan,
    autotune: &AutotunePlan,
) -> Vec<TrackDecision> {
    let persistent = SchedulingPolicy::standard().use_persistent_runtime(fusion.node_count);

    let mut decisions = Vec::with_capacity(7);
    decisions.push(track_decision(
        InnovationTrack::WholeProgramFusion,
        fusion.batch_fusion_candidate,
        "fusion",
    ));
    decisions.push(track_decision(
        InnovationTrack::PersistentExecution,
        persistent,
        "persistent",
    ));
    decisions.push(track_decision(
        InnovationTrack::DifferentialAccuracy,
        accuracy.shadow_reference_recommended,
        accuracy.reason,
    ));
    decisions.push(track_decision(
        InnovationTrack::ConformanceGuidedAutotune,
        autotune.recommended,
        autotune.reason,
    ));
    decisions.push(track_decision(
        InnovationTrack::GpuResidentProvenance,
        false,
        "none",
    ));
    decisions.push(track_decision(
        InnovationTrack::DataLayoutCompiler,
        memory.static_bytes > 0,
        "layout",
    ));
    decisions.push(track_decision(
        InnovationTrack::ReadbackMinimization,
        memory.avoided_readback_bytes > 0,
        "trimmed readback",
    ));
    decisions
}
/// Bundles a track, its activation flag and its reason label into a [`TrackDecision`].
fn track_decision(track: InnovationTrack, active: bool, reason: &'static str) -> TrackDecision {
    let decision = TrackDecision {
        track,
        active,
        reason,
    };
    decision
}
/// Structural summary of the program used for fusion decisions (built by `fusion_plan`).
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct FusionPlan {
    /// Owned copy of the program's entry op id, if it has one.
    pub entry_op_id: Option<String>,
    /// Number of top-level regions, from the program's stats.
    pub top_level_regions: usize,
    /// Total node count, including nodes nested in If/Loop/Block/Region bodies.
    pub node_count: usize,
    /// `true` when the program is composable with itself and is top-level region-wrapped.
    pub batch_fusion_candidate: bool,
}
/// Buffer layout plus byte accounting for a program (built by `memory_plan`).
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct MemoryPlan {
    /// One entry per declared buffer, in declaration order.
    pub buffers: Vec<BufferPlan>,
    /// Sum of `count * element size` over buffers with a non-zero count.
    /// Element types of unknown size are counted as 4 bytes.
    pub static_bytes: u64,
    /// Number of buffers declared with a zero count.
    pub dynamic_buffers: usize,
    /// Bytes that host readback of output buffers will actually transfer.
    pub visible_readback_bytes: u64,
    /// Output bytes skipped because an output byte range trims the readback.
    pub avoided_readback_bytes: u64,
}
/// Planned view of a single declared buffer.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct BufferPlan {
    /// Buffer name as declared.
    pub name: String,
    /// Binding index as declared.
    pub binding: u32,
    /// Declared access mode.
    pub access: BufferAccess,
    /// Declared memory kind.
    pub kind: MemoryKind,
    /// Element data type.
    pub element: DataType,
    /// Declared element count; `0` marks a dynamically sized buffer.
    pub count: u32,
    /// `count * element size` in bytes, or `None` when `count` is 0.
    pub static_size_bytes: Option<u64>,
    /// Output byte range as declared. `memory_plan` checks it only for output buffers.
    pub output_range: Option<Range<usize>>,
}
/// Region-trace settings (built by `provenance_plan`).
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ProvenancePlan {
    /// Whether the program's top level is wrapped in a region.
    pub top_level_region_wrapped: bool,
    /// Total region count, from the program's stats.
    pub region_count: usize,
    /// Whether to emit a region trace; currently mirrors `top_level_region_wrapped`.
    pub emit_region_trace: bool,
}
/// Accuracy-checking recommendation (built by `accuracy_plan`).
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct AccuracyPlan {
    /// `true` when the program requires subgroup operations.
    pub shadow_reference_recommended: bool,
    /// `"subgroup semantics"` when a shadow reference is recommended, otherwise `"baseline"`.
    pub reason: &'static str,
}
/// Dispatch-shape recommendations (built by `autotune_plan`).
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct AutotunePlan {
    /// Whether autotuning is recommended (device profile present or large program).
    pub recommended: bool,
    /// The program's declared parallel region size.
    pub parallel_region_size: [u32; 3],
    /// Suggested workgroup size; only X is tuned, Y and Z are 1.
    pub recommended_workgroup_size: [u32; 3],
    /// Suggested workgroup tile from the scheduling policy.
    pub recommended_tile: [u32; 3],
    /// Suggested vector packing width in bits.
    pub recommended_vector_pack_bits: u32,
    /// Suggested loop unroll depth.
    pub recommended_unroll_depth: u32,
    /// `"device profile"`, `"large program"` or `"none"`.
    pub reason: &'static str,
}
// Unit tests: end-to-end planning of small wrapped programs, plus direct checks of the
// `count_nodes` helper.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::{BufferDecl, DataType, Expr, Node, Program};

    // Minimal wrapped program with two statically sized u32 buffers (4-element input,
    // 1-element output) and a single store node. The intent appears to be `out[0] =
    // input[0]`, assuming `Node::store(buffer, index, value)` argument order.
    fn trivial_program() -> Program {
        Program::wrapped(
            vec![
                BufferDecl::read("input", 0, DataType::U32).with_count(4),
                BufferDecl::output("out", 1, DataType::U32).with_count(1),
            ],
            [1, 1, 1],
            vec![Node::store(
                "out",
                Expr::u32(0),
                Expr::load("input", Expr::u32(0)),
            )],
        )
    }

    // Both buffers have static counts, so the memory plan has static bytes and no
    // dynamic buffers.
    #[test]
    fn plan_succeeds_on_trivial_program() {
        let p = trivial_program();
        let exec_plan = plan(&p).expect("Fix: plan should succeed on trivial program; restore this invariant before continuing.");
        assert!(exec_plan.memory.static_bytes > 0);
        assert_eq!(exec_plan.memory.dynamic_buffers, 0);
    }

    // Planning the same program twice must produce the same wire-encoding hash.
    #[test]
    fn plan_fingerprint_is_deterministic() {
        let p = trivial_program();
        let plan1 = plan(&p).unwrap();
        let plan2 = plan(&p).unwrap();
        assert_eq!(plan1.program_fingerprint, plan2.program_fingerprint);
    }

    // Flat list: each node counts once.
    #[test]
    fn count_nodes_simple() {
        let nodes = vec![
            Node::let_bind("x", Expr::u32(1)),
            Node::let_bind("y", Expr::u32(2)),
        ];
        assert_eq!(count_nodes(&nodes), 2);
    }

    // One If node plus one node in `then` and two in `otherwise` = 4.
    #[test]
    fn count_nodes_nested_if() {
        let nodes = vec![Node::if_then_else(
            Expr::u32(1),
            vec![Node::let_bind("x", Expr::u32(1))],
            vec![
                Node::let_bind("a", Expr::u32(2)),
                Node::let_bind("b", Expr::u32(3)),
            ],
        )];
        assert_eq!(count_nodes(&nodes), 4);
    }

    // `track_decisions` never activates GpuResidentProvenance.
    #[test]
    fn track_active_returns_false_for_inactive() {
        let p = trivial_program();
        let exec_plan = plan(&p).unwrap();
        assert!(!exec_plan.track_active(InnovationTrack::GpuResidentProvenance));
    }

    // A tiny program is expected to select the persistent-runtime dispatch strategy.
    #[test]
    fn plan_tiny_program_uses_persistent_dispatch() {
        let p = trivial_program();
        let exec_plan = plan(&p).unwrap();
        assert_eq!(
            exec_plan.strategy.dispatch,
            DispatchStrategy::PersistentRuntime
        );
    }

    // Two adapter profiles that differ only in their ideal unroll / vector-pack / tile
    // settings must yield different autotune recommendations for the same program.
    #[test]
    fn device_profile_changes_autotune_recommendations() {
        let p = Program::wrapped(
            vec![BufferDecl::output("out", 0, DataType::U32).with_count(4096)],
            [1, 1, 1],
            vec![Node::store("out", Expr::gid_x(), Expr::u32(1))],
        );
        let compact = AdapterCaps {
            max_workgroup_size: [256, 256, 64],
            max_invocations_per_workgroup: 256,
            subgroup_size: 32,
            ideal_unroll_depth: 4,
            ideal_vector_pack_bits: 64,
            ideal_workgroup_tile: [8, 8, 1],
            ..AdapterCaps::conservative()
        };
        let wide = AdapterCaps {
            ideal_unroll_depth: 8,
            ideal_vector_pack_bits: 128,
            ideal_workgroup_tile: [16, 16, 1],
            ..compact
        };
        let compact_plan = plan_for_adapter(&p, &compact).unwrap();
        let wide_plan = plan_for_adapter(&p, &wide).unwrap();
        assert_eq!(compact_plan.autotune.recommended_workgroup_size, [64, 1, 1]);
        assert_eq!(wide_plan.autotune.recommended_workgroup_size, [256, 1, 1]);
        assert_eq!(compact_plan.autotune.recommended_tile, [8, 8, 1]);
        assert_eq!(wide_plan.autotune.recommended_tile, [16, 16, 1]);
        assert_eq!(compact_plan.autotune.recommended_vector_pack_bits, 64);
        assert_eq!(wide_plan.autotune.recommended_vector_pack_bits, 128);
        assert_eq!(compact_plan.autotune.recommended_unroll_depth, 4);
        assert_eq!(wide_plan.autotune.recommended_unroll_depth, 8);
    }
}