use std::collections::{HashMap, HashSet};
use std::hash::{Hash, Hasher};
use svod_schedule::{Scheduler, apply_post_optimization_with_renderer, beam_search_cached, prepare_scheduler};
use tracing::{debug, trace};
use crate::{
PrepareConfig, Result, Tensor,
error::{
BatchOutputMismatchSnafu, CompileKernelSnafu, CreateProgramSnafu, DeviceSnafu, EmptyScheduleSnafu,
ExecutionSnafu, IrConstructionSnafu, OptimizeSnafu, RangeifySnafu, RenderKernelSnafu, ShapeUnknownSnafu,
UOpSnafu,
},
schedule::ScheduleItem,
};
use snafu::{OptionExt, ResultExt};
use std::sync::Arc;
use std::time::Duration;
use svod_device::{Buffer, device::Device};
use svod_ir::pattern::is_any_const;
use svod_ir::{DeviceSpec, Op, UOp, UOpKey};
use svod_runtime::{
ExecutionPlan, ExecutionPlanBuilder, PreparedBufferView, PreparedCopy, PreparedCustomFunction, PreparedKernel,
PreparedOp,
};
/// Returns the positions in `tensors` that still need to be scheduled.
///
/// A tensor is pending when it has no buffer identity yet, is not a constant, and has at
/// least one element.
fn collect_pending_indices(tensors: &[&mut Tensor]) -> Vec<usize> {
    let mut pending = Vec::new();
    for (position, tensor) in tensors.iter().enumerate() {
        let already_backed = tensor.uop().has_buffer_identity();
        if already_backed || is_any_const(&tensor.uop()) || tensor.has_zero_elements() {
            continue;
        }
        pending.push(position);
    }
    pending
}
/// Identifies the backing storage of a device buffer.
///
/// The key is the buffer id, byte offset, byte size, and dtype. `prepare_execution_plan` uses
/// it to map distinct buffer uop ids that resolve to the same storage onto one plan buffer slot.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct BufferStorageKey {
    /// Raw id of the underlying buffer (`buffer.id().0`).
    id: u64,
    /// Offset reported by `buffer.offset()`.
    offset: usize,
    /// Size reported by `buffer.size()`.
    size: usize,
    /// Element dtype of the buffer.
    dtype: svod_dtype::DType,
}
impl Tensor {
pub fn realize(&mut self) -> Result<()> {
if self.uop().has_buffer_identity() {
self.ensure_buffer();
return Ok(());
}
if is_any_const(&self.uop()) {
let contiguous_uop = self.uop().contiguous();
self.set_uop(contiguous_uop);
}
if self.has_zero_elements() {
return Ok(());
}
let old_uop = self.uop();
let input_buffer_ids: HashSet<u64> = collect_input_buffers(&old_uop).keys().copied().collect();
let t_prep = std::time::Instant::now();
let plan = self.prepare()?;
let prep_ms = t_prep.elapsed().as_millis();
let t_exec = std::time::Instant::now();
plan.execute().context(ExecutionSnafu)?;
let exec_ms = t_exec.elapsed().as_millis();
debug!(prep_ms, exec_ms, "realize complete");
self.finalize_realize(&plan, &old_uop)?;
let realized_uop = self.uop();
if !Arc::ptr_eq(&old_uop, &realized_uop) {
#[allow(clippy::mutable_key_type)]
let becomes_map = HashMap::from([(UOpKey(old_uop), realized_uop)]);
crate::tensor_registry::apply_map_to_tensors(&becomes_map);
}
plan.release_intermediate_buffers(|uop_id| {
if !input_buffer_ids.contains(&uop_id) {
crate::tensor_registry::remove_buffer(uop_id);
}
});
Ok(())
}
pub fn realize_with(&mut self, config: &PrepareConfig) -> Result<()> {
if self.uop().has_buffer_identity() {
self.ensure_buffer();
return Ok(());
}
if is_any_const(&self.uop()) {
let contiguous_uop = self.uop().contiguous();
self.set_uop(contiguous_uop);
}
if self.has_zero_elements() {
return Ok(());
}
let old_uop = self.uop();
let input_buffer_ids: HashSet<u64> = collect_input_buffers(&old_uop).keys().copied().collect();
let t_prep = std::time::Instant::now();
let plan = self.prepare_with(config)?;
let prep_ms = t_prep.elapsed().as_millis();
let t_exec = std::time::Instant::now();
plan.execute().context(ExecutionSnafu)?;
let exec_ms = t_exec.elapsed().as_millis();
debug!(prep_ms, exec_ms, "realize_with complete");
self.finalize_realize(&plan, &old_uop)?;
let realized_uop = self.uop();
if !Arc::ptr_eq(&old_uop, &realized_uop) {
#[allow(clippy::mutable_key_type)]
let becomes_map = HashMap::from([(UOpKey(old_uop), realized_uop)]);
crate::tensor_registry::apply_map_to_tensors(&becomes_map);
}
plan.release_intermediate_buffers(|uop_id| {
if !input_buffer_ids.contains(&uop_id) {
crate::tensor_registry::remove_buffer(uop_id);
}
});
Ok(())
}
fn finalize_realize(&mut self, plan: &ExecutionPlan, uop: &Arc<UOp>) -> Result<()> {
let output_buf = plan.output_buffer().expect("realized plan must have an output buffer").clone();
trace!(
buffer.id = ?output_buf.id(),
buffer.size = output_buf.size(),
"Realized output buffer"
);
let output_dtype = uop.dtype();
let output_device = output_buf.allocator().device_spec();
let num_elements = output_buf.size() / output_dtype.bytes();
let buffer_uop = UOp::new_buffer(output_device, num_elements, output_dtype.clone());
let output_buf_arc = Arc::new(output_buf);
crate::tensor_registry::register_buffer(buffer_uop.id, self.entry.id, output_buf_arc.clone());
let shape = uop.shape().context(UOpSnafu)?.context(ShapeUnknownSnafu)?;
let realized_uop = buffer_uop.try_reshape(shape).context(UOpSnafu)?;
debug!(
buffer_uop.id = buffer_uop.id,
num_elements,
shape = ?shape,
realized_uop.id = realized_uop.id,
realized_uop.base_id = realized_uop.base().id,
"Tensor realized"
);
self.set_uop(realized_uop);
self.entry.set_buffer(Arc::clone(&output_buf_arc));
self.buffer = Some(output_buf_arc);
Ok(())
}
pub fn prepare(&mut self) -> Result<ExecutionPlan> {
self.prepare_with(&PrepareConfig::from_env())
}
pub fn prepare_with(&mut self, config: &PrepareConfig) -> Result<ExecutionPlan> {
let t_total = std::time::Instant::now();
let uop = self.uop();
let sink = UOp::sink(vec![uop.contiguous()]);
let schedule_result = schedule_result_from_sink_with_cache(sink, extract_var_vals(&uop)?, config)?;
let plan = prepare_execution_plan(&schedule_result, config)?;
self.wire_output_tensor(&plan, &uop)?;
debug!(total_ms = t_total.elapsed().as_millis() as u64, "prepare: total");
Ok(plan)
}
fn wire_output_tensor(&mut self, plan: &ExecutionPlan, uop: &Arc<UOp>) -> Result<()> {
if plan.num_outputs() > 0 {
let buf = Arc::new(plan.output_buffer().expect("plan with num_outputs > 0 must expose output").clone());
let dtype = uop.dtype();
let device = buf.allocator().device_spec();
let buffer_uop = UOp::new_buffer(device, buf.size() / dtype.bytes(), dtype);
crate::tensor_registry::register_buffer(buffer_uop.id, self.entry.id, buf.clone());
let shape = uop.shape().context(UOpSnafu)?.context(ShapeUnknownSnafu)?;
self.set_uop(buffer_uop.try_reshape(shape).context(UOpSnafu)?);
self.entry.set_buffer(buf.clone());
self.buffer = Some(buf);
}
Ok(())
}
pub fn realize_batch<'a>(tensors: impl IntoIterator<Item = &'a mut Tensor>) -> Result<()> {
Self::realize_batch_with(tensors, &PrepareConfig::from_env())
}
pub fn realize_batch_with<'a>(
tensors: impl IntoIterator<Item = &'a mut Tensor>,
config: &PrepareConfig,
) -> Result<()> {
let mut tensors: Vec<&mut Tensor> = tensors.into_iter().collect();
if tensors.is_empty() {
return Ok(());
}
for t in &mut tensors {
if t.uop().has_buffer_identity() {
t.ensure_buffer();
}
}
for t in &mut tensors {
if !t.uop().has_buffer_identity() && is_any_const(&t.uop()) {
let contiguous_uop = t.uop().contiguous();
t.set_uop(contiguous_uop);
}
}
let pending_indices = collect_pending_indices(&tensors);
if pending_indices.is_empty() {
return Ok(());
}
let old_uops: Vec<Arc<UOp>> = pending_indices.iter().map(|&i| tensors[i].uop()).collect();
let mut all_input_buffers = crate::schedule::InputBuffers::new();
for uop in &old_uops {
all_input_buffers.extend(collect_input_buffers(uop));
}
let input_ids: HashSet<u64> = all_input_buffers.keys().copied().collect();
let contiguouses: Vec<Arc<UOp>> = old_uops.iter().map(|u| u.contiguous()).collect();
let sink = UOp::sink(contiguouses);
let mut var_vals = HashMap::new();
for uop in &old_uops {
let extracted = extract_var_vals(uop)?;
merge_var_vals_checked(&mut var_vals, &extracted, "realize_batch input collection")?;
}
let schedule_result = schedule_result_from_sink_with_cache(sink, var_vals, config)?;
let t_prep = std::time::Instant::now();
let plan = prepare_execution_plan(&schedule_result, config)?;
let prep_ms = t_prep.elapsed().as_millis();
let t_exec = std::time::Instant::now();
plan.execute().context(ExecutionSnafu)?;
let exec_ms = t_exec.elapsed().as_millis();
debug!(prep_ms, exec_ms, num_outputs = pending_indices.len(), "realize_batch complete");
snafu::ensure!(
plan.num_outputs() >= pending_indices.len(),
BatchOutputMismatchSnafu { expected: pending_indices.len(), actual: plan.num_outputs() }
);
#[allow(clippy::mutable_key_type)]
let mut becomes_map = HashMap::new();
for (buf_idx, &orig_idx) in pending_indices.iter().enumerate() {
let output_buf = plan.output_buffer_at(buf_idx).expect("buf_idx in range").clone();
let old_uop = &old_uops[buf_idx];
let output_dtype = old_uop.dtype();
let output_device = output_buf.allocator().device_spec();
let num_elements = output_buf.size() / output_dtype.bytes();
let buffer_uop = UOp::new_buffer(output_device, num_elements, output_dtype);
let buf_arc = Arc::new(output_buf);
let t = &mut tensors[orig_idx];
crate::tensor_registry::register_buffer(buffer_uop.id, t.entry.id, buf_arc.clone());
let shape = old_uop.shape().context(UOpSnafu)?.context(ShapeUnknownSnafu)?;
let realized_uop = buffer_uop.try_reshape(shape).context(UOpSnafu)?;
t.set_uop(realized_uop.clone());
t.entry.set_buffer(Arc::clone(&buf_arc));
t.buffer = Some(buf_arc);
becomes_map.insert(UOpKey(old_uop.clone()), realized_uop);
}
crate::tensor_registry::apply_map_to_tensors(&becomes_map);
plan.release_intermediate_buffers(|id| {
if !input_ids.contains(&id) {
crate::tensor_registry::remove_buffer(id);
}
});
Ok(())
}
pub fn prepare_batch<'a>(tensors: impl IntoIterator<Item = &'a mut Tensor>) -> Result<ExecutionPlan> {
Self::prepare_batch_with(tensors, &PrepareConfig::from_env())
}
pub fn prepare_batch_with<'a>(
tensors: impl IntoIterator<Item = &'a mut Tensor>,
config: &PrepareConfig,
) -> Result<ExecutionPlan> {
let mut tensors: Vec<&mut Tensor> = tensors.into_iter().collect();
if tensors.is_empty() {
return EmptyScheduleSnafu.fail();
}
for t in &mut tensors {
if t.uop().has_buffer_identity() {
t.ensure_buffer();
}
}
for t in &mut tensors {
if !t.uop().has_buffer_identity() && is_any_const(&t.uop()) {
let contiguous_uop = t.uop().contiguous();
t.set_uop(contiguous_uop);
}
}
let pending_indices = collect_pending_indices(&tensors);
if pending_indices.is_empty() {
return EmptyScheduleSnafu.fail();
}
let uops: Vec<Arc<UOp>> = pending_indices.iter().map(|&i| tensors[i].uop()).collect();
let mut var_vals = HashMap::new();
for uop in &uops {
let extracted = extract_var_vals(uop)?;
merge_var_vals_checked(&mut var_vals, &extracted, "prepare_batch input collection")?;
}
let contiguouses: Vec<Arc<UOp>> = uops.iter().map(|u| u.contiguous()).collect();
let sink = UOp::sink(contiguouses);
let schedule_result = schedule_result_from_sink_with_cache(sink, var_vals, config)?;
let plan = prepare_execution_plan(&schedule_result, config)?;
for (buf_idx, &orig_idx) in pending_indices.iter().enumerate() {
if buf_idx >= plan.num_outputs() {
break;
}
let output_buf = plan.output_buffer_at(buf_idx).expect("buf_idx in range").clone();
let buf_arc = Arc::new(output_buf);
let old_uop = &uops[buf_idx];
let output_dtype = old_uop.dtype();
let output_device = buf_arc.allocator().device_spec();
let num_elements = buf_arc.size() / output_dtype.bytes();
let buffer_uop = UOp::new_buffer(output_device, num_elements, output_dtype);
let t = &mut tensors[orig_idx];
crate::tensor_registry::register_buffer(buffer_uop.id, t.entry.id, buf_arc.clone());
let shape = old_uop.shape().context(UOpSnafu)?.context(ShapeUnknownSnafu)?;
let realized_uop = buffer_uop.try_reshape(shape).context(UOpSnafu)?;
t.set_uop(realized_uop);
t.entry.set_buffer(Arc::clone(&buf_arc));
t.buffer = Some(buf_arc);
}
Ok(plan)
}
}
/// Records `name = val` in `var_vals` unless `name` is already bound to a different value.
///
/// Rebinding to the same value is a no-op. On conflict, the map is left untouched and
/// `Err((previous, requested))` is returned.
fn try_bind_var_val(var_vals: &mut HashMap<String, i64>, name: &str, val: i64) -> std::result::Result<(), (i64, i64)> {
    match var_vals.get(name) {
        Some(&prev) if prev != val => Err((prev, val)),
        Some(_) => Ok(()),
        None => {
            var_vals.insert(name.to_owned(), val);
            Ok(())
        }
    }
}
/// Binds `name = val`, turning a conflicting prior binding into an `IrConstruction` error
/// that mentions `context`.
fn insert_var_val_checked(var_vals: &mut HashMap<String, i64>, name: &str, val: i64, context: &str) -> Result<()> {
    try_bind_var_val(var_vals, name, val).or_else(|(prev, val)| {
        IrConstructionSnafu { details: format!("bind mismatch on {name}, {prev} != {val} ({context})") }.fail()
    })
}
/// Merges every binding of `src` into `dst` and stops at the first conflicting value.
fn merge_var_vals_checked(dst: &mut HashMap<String, i64>, src: &HashMap<String, i64>, context: &str) -> Result<()> {
    src.iter().try_for_each(|(name, &val)| insert_var_val_checked(dst, name, val, context))
}
/// Collects the integer values of every `BIND(DEFINE_VAR, CONST)` reachable from `root`.
///
/// # Errors
/// Fails when the same variable is bound to two different values.
fn extract_var_vals(root: &Arc<UOp>) -> Result<HashMap<String, i64>> {
    let mut bound = HashMap::new();
    for node in root.toposort() {
        let Op::Bind { var, value } = node.op() else {
            continue;
        };
        let Op::DefineVar { name, .. } = var.op() else {
            continue;
        };
        let Op::Const(cv) = value.op() else {
            continue;
        };
        let Some(int_value) = cv.0.try_int() else {
            continue;
        };
        insert_var_val_checked(&mut bound, name, int_value, "bind extraction")?;
    }
    Ok(bound)
}
/// Returns whether `SVOD_DISABLE_SCHEDULE_CACHE` is exactly `"1"`.
///
/// The variable is read once per process and cached afterwards.
fn schedule_cache_disabled_by_env() -> bool {
    static CACHED: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
    *CACHED.get_or_init(|| matches!(std::env::var("SVOD_DISABLE_SCHEDULE_CACHE").as_deref(), Ok("1")))
}
/// Builds a schedule for `sink`, reusing a cached pre-schedule when one matches.
///
/// How it works:
/// - The sink is first normalized with `normalize_for_schedule_cache`.
/// - The cache key is the normalized graph's content hash plus the codegen resolved from its
///   parameter buffers.
/// - On a miss, bind placeholders are restored, and the graph is rangeified, split into
///   kernels, turned into a pre-schedule, and inserted into the cache.
/// - Either way, the pre-schedule is re-bound to this call's concrete buffers and instantiated
///   with the merged variable values.
///
/// Caching is bypassed when `config.disable_schedule_cache` is set or
/// `SVOD_DISABLE_SCHEDULE_CACHE=1`.
fn schedule_result_from_sink_with_cache(
    sink: Arc<UOp>,
    mut var_vals: HashMap<String, i64>,
    config: &PrepareConfig,
) -> Result<crate::schedule::ScheduleResult> {
    let bypass_cache = config.disable_schedule_cache || schedule_cache_disabled_by_env();
    if bypass_cache {
        return schedule_result_from_sink_uncached(sink, var_vals, config);
    }

    let norm = normalize_for_schedule_cache(&sink)?;
    merge_var_vals_checked(&mut var_vals, &norm.var_vals, "schedule cache normalization")?;
    let codegen = resolve_codegen(&norm.param_buffers, config)?;
    let key = (crate::schedule_cache::content_hash(&norm.normalized), codegen);
    let cache = crate::schedule_cache::schedule_cache();

    let lookup = {
        let guard = cache.guard();
        cache.get(&key, &guard).cloned()
    };
    let cached = if let Some(hit) = lookup {
        debug!("schedule cache hit");
        hit
    } else {
        // Miss: schedule the normalized graph with real BIND nodes put back in place.
        let schedule_root = restore_bind_placeholders_for_schedule(&norm.normalized, &norm);
        let rangeified = svod_schedule::rangeify_with_map(schedule_root, None).context(RangeifySnafu)?;
        let (kernel_graph, _) = svod_schedule::try_get_kernel_graph(rangeified.sink).context(RangeifySnafu)?;
        let pre_schedule = crate::schedule::create_pre_schedule(kernel_graph)?;
        let fresh = Arc::new(crate::schedule_cache::CachedSchedule { pre_schedule: Arc::new(pre_schedule) });
        let guard = cache.guard();
        cache.insert(key, Arc::clone(&fresh), &guard);
        fresh
    };

    let restored = restore_post_schedule_pre_schedule(&cached.pre_schedule, &norm);
    let inputs = build_schedule_input_buffers(&restored, &norm);
    let scheduled = crate::schedule::instantiate_schedule(&restored, &inputs, &var_vals)?;
    Ok(scheduled)
}
/// Schedules `sink` directly without touching the schedule cache.
///
/// Input buffers are looked up in the tensor registry from the kernel graph.
fn schedule_result_from_sink_uncached(
    sink: Arc<UOp>,
    var_vals: HashMap<String, i64>,
    _config: &PrepareConfig,
) -> Result<crate::schedule::ScheduleResult> {
    let rangeified = svod_schedule::rangeify_with_map(sink, None).context(RangeifySnafu)?;
    let (graph, _) = svod_schedule::try_get_kernel_graph(rangeified.sink).context(RangeifySnafu)?;
    let pre = crate::schedule::create_pre_schedule(graph.clone())?;
    let inputs = collect_input_buffers(&graph);
    let scheduled = crate::schedule::instantiate_schedule(&pre, &inputs, &var_vals)?;
    Ok(scheduled)
}
/// Result of `normalize_for_schedule_cache`: a cache-friendly graph plus what is needed to
/// map it back to concrete nodes.
pub(crate) struct ScheduleCacheNormalization {
    /// Sink with BUFFER / BUFFER_VIEW / integer BIND nodes replaced by PARAM slots and UNIQUE
    /// nodes replaced by LUNIQUE slots.
    pub normalized: Arc<UOp>,
    /// Original node for each PARAM slot (index == slot).
    pub param_values: Vec<Arc<UOp>>,
    /// `(id, node)` for each original BUFFER node, sorted and deduplicated by id.
    pub param_buffers: Vec<(u64, Arc<UOp>)>,
    /// Original UNIQUE node for each LUNIQUE slot (index == slot).
    pub unique_values: Vec<Arc<UOp>>,
    /// Variable values collected from the BIND nodes that were parameterized.
    pub var_vals: HashMap<String, i64>,
}
/// Mutable state threaded through the PARAM-normalization rewrite in
/// `normalize_for_schedule_cache`.
pub(crate) struct NormalizeScheduleCacheCtx {
    /// Original node id -> assigned PARAM slot.
    pub param_map: HashMap<u64, usize>,
    /// Original node per slot, in slot order.
    pub param_values: Vec<Arc<UOp>>,
    /// Every BUFFER node matched. May contain duplicates; the caller sorts and dedups by id.
    pub param_buffers: Vec<(u64, Arc<UOp>)>,
    /// Variable values from the BIND nodes rewritten so far.
    pub var_vals: HashMap<String, i64>,
    /// First conflicting-bind message, reported as an error after the rewrite finishes.
    pub bind_mismatch: Option<String>,
}
/// Rewrites `sink` into a form whose content hash is independent of concrete buffers,
/// bound-variable values, and unique ids, so it can key the schedule cache.
///
/// There are two passes:
/// 1. BUFFER, BUFFER_VIEW, and `BIND(DEFINE_VAR, int CONST)` nodes become PARAM placeholders.
///    Each distinct original node id gets one slot. Buffer views and binds use a CPU device
///    placeholder; binds use size 0.
/// 2. UNIQUE nodes become `LUNIQUE(slot)`.
///
/// # Errors
/// Returns `IrConstruction` if the same variable is bound to two different values.
pub(crate) fn normalize_for_schedule_cache(sink: &Arc<UOp>) -> Result<ScheduleCacheNormalization> {
    use svod_ir::op::pattern_derived::OpKey;
    use svod_ir::pattern::{RewriteResult, SimplifiedPatternMatcher};
    use svod_ir::rewrite::graph_rewrite;

    /// State for the second pass, which numbers UNIQUE nodes.
    struct UniqueNormalizationCtx {
        unique_map: HashMap<u64, usize>,
        unique_values: Vec<Arc<UOp>>,
    }

    /// Returns the PARAM placeholder for `node`, allocating a slot the first time its id is seen.
    fn to_param(
        node: &Arc<UOp>,
        ctx: &mut NormalizeScheduleCacheCtx,
        size: usize,
        device: Option<Arc<UOp>>,
    ) -> Arc<UOp> {
        let slot = match ctx.param_map.get(&node.id) {
            Some(&existing) => existing,
            None => {
                let fresh = ctx.param_values.len();
                ctx.param_values.push(node.clone());
                ctx.param_map.insert(node.id, fresh);
                fresh
            }
        };
        UOp::param(slot, size, node.dtype(), device)
    }

    let mut param_matcher = SimplifiedPatternMatcher::<NormalizeScheduleCacheCtx>::new();
    param_matcher.add(&[OpKey::Buffer], |node, ctx| {
        let Op::Buffer { size, device, .. } = node.op() else {
            return RewriteResult::NoMatch;
        };
        ctx.param_buffers.push((node.id, node.clone()));
        RewriteResult::Rewritten(to_param(node, ctx, *size, Some(device.clone())))
    });
    param_matcher.add(&[OpKey::BufferView], |node, ctx| {
        let Op::BufferView { size, .. } = node.op() else {
            return RewriteResult::NoMatch;
        };
        RewriteResult::Rewritten(to_param(node, ctx, *size, Some(UOp::device(DeviceSpec::Cpu))))
    });
    param_matcher.add(&[OpKey::Bind], |node, ctx| {
        let Op::Bind { var, value } = node.op() else {
            return RewriteResult::NoMatch;
        };
        let (Op::DefineVar { name, .. }, Op::Const(cv)) = (var.op(), value.op()) else {
            return RewriteResult::NoMatch;
        };
        let Some(val) = cv.0.try_int() else {
            return RewriteResult::NoMatch;
        };
        match try_bind_var_val(&mut ctx.var_vals, name, val) {
            Ok(()) => RewriteResult::Rewritten(to_param(node, ctx, 0, Some(UOp::device(DeviceSpec::Cpu)))),
            Err((prev, val)) => {
                // Only the first conflict is reported; the node is left un-normalized.
                if ctx.bind_mismatch.is_none() {
                    ctx.bind_mismatch = Some(format!("bind mismatch on variable {name}: {prev} vs {val}"));
                }
                RewriteResult::NoMatch
            }
        }
    });

    let mut ctx = NormalizeScheduleCacheCtx {
        param_map: HashMap::new(),
        param_values: Vec::new(),
        param_buffers: Vec::new(),
        var_vals: HashMap::new(),
        bind_mismatch: None,
    };
    let params_replaced = graph_rewrite(&param_matcher, sink.clone(), &mut ctx);
    if let Some(details) = ctx.bind_mismatch.take() {
        return IrConstructionSnafu { details }.fail();
    }

    let mut unique_matcher = SimplifiedPatternMatcher::<UniqueNormalizationCtx>::new();
    unique_matcher.add(&[OpKey::Unique], |node, uctx| {
        if !matches!(node.op(), Op::Unique(_)) {
            return RewriteResult::NoMatch;
        }
        let slot = match uctx.unique_map.get(&node.id) {
            Some(&existing) => existing,
            None => {
                let fresh = uctx.unique_values.len();
                uctx.unique_values.push(node.clone());
                uctx.unique_map.insert(node.id, fresh);
                fresh
            }
        };
        RewriteResult::Rewritten(UOp::lunique(Some(slot)))
    });
    let mut unique_ctx = UniqueNormalizationCtx { unique_map: HashMap::new(), unique_values: Vec::new() };
    let normalized = graph_rewrite(&unique_matcher, params_replaced, &mut unique_ctx);

    let mut param_buffers = ctx.param_buffers;
    param_buffers.sort_unstable_by_key(|&(id, _)| id);
    param_buffers.dedup_by_key(|(id, _)| *id);

    Ok(ScheduleCacheNormalization {
        normalized,
        param_values: ctx.param_values,
        param_buffers,
        unique_values: unique_ctx.unique_values,
        var_vals: ctx.var_vals,
    })
}
/// Maps a normalized (cache-shaped) graph back onto concrete nodes.
///
/// - Device-carrying PARAM slots are replaced by their recursively restored original node.
/// - BUFFER nodes whose unique is an `LUNIQUE(slot)` get a fresh runtime buffer. The buffer is
///   shared by every occurrence of that slot within this call.
/// - Bare `LUNIQUE(slot)` nodes become the recorded UNIQUE, or a fresh buffer id if the slot is
///   unknown.
#[allow(clippy::mutable_key_type)]
pub(crate) fn restore_post_schedule_cache(root: &Arc<UOp>, normalization: &ScheduleCacheNormalization) -> Arc<UOp> {
    let mut replacements: HashMap<UOpKey, Arc<UOp>> = HashMap::new();
    let mut fresh_by_slot: HashMap<usize, Arc<UOp>> = HashMap::new();
    for node in root.toposort() {
        let replacement = match node.op() {
            Op::Param { slot, device: Some(_), .. } => normalization
                .param_values
                .get(*slot)
                .map(|original| restore_post_schedule_cache(original, normalization)),
            Op::Buffer { unique, device, size } => match unique.op() {
                Op::LUnique(slot) => Some(
                    fresh_by_slot
                        .entry(*slot)
                        .or_insert_with(|| {
                            UOp::new(
                                Op::Buffer { unique: UOp::buffer_id(None), device: device.clone(), size: *size },
                                node.dtype(),
                            )
                        })
                        .clone(),
                ),
                _ => None,
            },
            Op::LUnique(slot) => Some(
                normalization.unique_values.get(*slot).cloned().unwrap_or_else(|| UOp::buffer_id(None)),
            ),
            _ => None,
        };
        if let Some(restored) = replacement {
            replacements.insert(UOpKey(node.clone()), restored);
        }
    }
    root.substitute(&replacements)
}
/// Puts the original BIND nodes back in place of their PARAM placeholders.
///
/// Buffer placeholders are left untouched. Returns `root` unchanged (same `Arc`) when there is
/// nothing to restore.
#[allow(clippy::mutable_key_type)]
fn restore_bind_placeholders_for_schedule(root: &Arc<UOp>, normalization: &ScheduleCacheNormalization) -> Arc<UOp> {
    let bind_restorations: HashMap<UOpKey, Arc<UOp>> = root
        .toposort()
        .into_iter()
        .filter_map(|node| {
            let Op::Param { slot, device: Some(_), .. } = node.op() else {
                return None;
            };
            let original = normalization.param_values.get(*slot)?;
            matches!(original.op(), Op::Bind { .. }).then(|| (UOpKey(node.clone()), original.clone()))
        })
        .collect();
    if bind_restorations.is_empty() {
        root.clone()
    } else {
        root.substitute(&bind_restorations)
    }
}
fn restore_post_schedule_pre_schedule(
pre_schedule: &crate::schedule::PreSchedule,
normalization: &ScheduleCacheNormalization,
) -> crate::schedule::PreSchedule {
let mut flat_buf_uops = Vec::new();
let mut source_counts = Vec::with_capacity(pre_schedule.items.len());
for item in &pre_schedule.items {
source_counts.push(item.sources.len());
flat_buf_uops.extend(item.sources.iter().cloned());
}
let outputs_offset = flat_buf_uops.len();
flat_buf_uops.extend(pre_schedule.output_buffer_uops.iter().cloned());
if flat_buf_uops.is_empty() {
return pre_schedule.clone();
}
let restored_flat = match restore_post_schedule_cache(&UOp::sink(flat_buf_uops), normalization).op() {
Op::Sink { sources, .. } => sources.iter().cloned().collect::<Vec<_>>(),
_ => unreachable!("sink substitution must preserve SINK root"),
};
let mut cursor = 0usize;
let mut restored_items = Vec::with_capacity(pre_schedule.items.len());
for (item, source_count) in pre_schedule.items.iter().zip(source_counts) {
let end = cursor + source_count;
let sources = restored_flat[cursor..end].to_vec();
cursor = end;
let ast = restore_post_schedule_cache(&item.ast, normalization);
restored_items.push(crate::schedule::PreScheduleItem {
kernel: item.kernel.clone(),
ast,
sources,
dependencies: item.dependencies.clone(),
bound_ranges: item.bound_ranges.clone(),
});
}
let output_buffer_uops = restored_flat[outputs_offset..].to_vec();
crate::schedule::PreSchedule {
items: restored_items,
invocations: pre_schedule.invocations.clone(),
output_buffer_uops,
}
}
/// Collects the registered device buffers behind every BUFFER-backed item source of
/// `pre_schedule`, keyed by buffer uop id.
///
/// `_normalization` is currently unused.
fn build_schedule_input_buffers(
    pre_schedule: &crate::schedule::PreSchedule,
    _normalization: &ScheduleCacheNormalization,
) -> crate::schedule::InputBuffers {
    pre_schedule
        .items
        .iter()
        .flat_map(|item| item.sources.iter())
        .map(|src| src.buf_uop())
        .filter(|buf| matches!(buf.op(), Op::Buffer { .. }))
        .filter_map(|buf| crate::tensor_registry::get_buffer(buf.id).map(|buffer| (buf.id, buffer)))
        .collect()
}
/// Returns the registered device buffers for every BUFFER node reachable from `root`,
/// keyed by uop id. BUFFER nodes with no registry entry are skipped.
fn collect_input_buffers(root: &Arc<UOp>) -> crate::schedule::InputBuffers {
    root.toposort()
        .into_iter()
        .filter(|node| matches!(node.op(), Op::Buffer { .. }))
        .filter_map(|node| crate::tensor_registry::get_buffer(node.id).map(|buf| (node.id, buf)))
        .collect()
}
/// Translates a program's output slots (`outs`) into positions within its `globals` list.
///
/// The result is sorted and deduplicated. When a slot appears more than once in `globals`, its
/// last position is used.
///
/// # Errors
/// Returns `IrConstruction` in these cases:
/// - there are no buffers, or `globals` or `outs` is empty;
/// - an output slot is missing from `globals`;
/// - a position is `>= num_buffers`.
fn output_indices_from_program_metadata(globals: &[usize], outs: &[usize], num_buffers: usize) -> Result<Vec<usize>> {
    let precondition_error = if num_buffers == 0 {
        Some("cannot map outputs for kernel with zero buffers")
    } else if globals.is_empty() {
        Some("ProgramSpec.globals is empty")
    } else if outs.is_empty() {
        Some("ProgramSpec.outs is empty")
    } else {
        None
    };
    if let Some(details) = precondition_error {
        return IrConstructionSnafu { details: details.to_string() }.fail();
    }

    let position_of_slot: HashMap<usize, usize> =
        globals.iter().enumerate().map(|(position, &slot)| (slot, position)).collect();

    let mut output_indices = outs
        .iter()
        .map(|&slot| {
            let Some(&position) = position_of_slot.get(&slot) else {
                return IrConstructionSnafu {
                    details: format!("ProgramSpec.outs slot {slot} not found in ProgramSpec.globals={globals:?}"),
                }
                .fail();
            };
            if position >= num_buffers {
                return IrConstructionSnafu {
                    details: format!(
                        "ProgramSpec output index {position} (slot {slot}) out of range for {num_buffers} buffers"
                    ),
                }
                .fail();
            }
            Ok(position)
        })
        .collect::<Result<Vec<usize>>>()?;
    output_indices.sort_unstable();
    output_indices.dedup();
    if output_indices.is_empty() {
        return IrConstructionSnafu { details: "ProgramSpec output mapping resolved to empty set".to_string() }.fail();
    }
    Ok(output_indices)
}
fn resolve_item_buffer_indices(item: &ScheduleItem, uop_id_to_idx: &HashMap<u64, usize>) -> Result<Vec<usize>> {
let mut indices = Vec::with_capacity(item.buffer_uop_ids.len());
for &uop_id in &item.buffer_uop_ids {
let Some(idx) = uop_id_to_idx.get(&uop_id).copied() else {
return Err(crate::error::Error::BufferNotFound { uop_id });
};
indices.push(idx);
}
Ok(indices)
}
/// Reorders `item`'s plan buffer indices into the argument order given by the compiled
/// program's `globals` (each entry is a position into the item's buffer list).
///
/// # Errors
/// Propagates `BufferNotFound`. Returns `IrConstruction` when a `globals` position is out of
/// range.
fn resolve_compiled_kernel_buffer_indices(
    item: &ScheduleItem,
    uop_id_to_idx: &HashMap<u64, usize>,
    globals: &[usize],
) -> Result<Vec<usize>> {
    let resolved = resolve_item_buffer_indices(item, uop_id_to_idx)?;
    globals
        .iter()
        .map(|&position| match resolved.get(position) {
            Some(&idx) => Ok(idx),
            None => IrConstructionSnafu {
                details: format!(
                    "ProgramSpec.globals position {position} out of range for CALL {} buffer list len {} (buffer_uop_ids={:?})",
                    item.kernel.id,
                    resolved.len(),
                    item.buffer_uop_ids
                ),
            }
            .fail(),
        })
        .collect()
}
/// Key of the compiled-kernel cache.
///
/// Components: (AST content hash, target device, codegen cache key, optimizer-config
/// fingerprint), as assembled in `prepare_execution_plan`.
type OptKey = (u64, DeviceSpec, &'static str, u64);
/// Bounded cache of compiled kernels with first-in-first-out eviction.
struct OptCacheState {
    /// Concurrent key -> compiled-kernel map.
    map: papaya::HashMap<OptKey, Arc<svod_runtime::kernel_cache::CachedKernel>>,
    /// Keys in insertion order; the front is evicted first once `cap` is exceeded.
    fifo: parking_lot::Mutex<std::collections::VecDeque<OptKey>>,
    /// Maximum number of tracked keys.
    cap: usize,
}
impl OptCacheState {
    /// Capacity used when `SVOD_OPT_CACHE_MAX` is unset, unparsable, or zero.
    const DEFAULT_CAP: usize = 4096;

    /// Creates an empty cache whose capacity comes from `SVOD_OPT_CACHE_MAX`.
    fn new() -> Self {
        let configured = std::env::var("SVOD_OPT_CACHE_MAX")
            .ok()
            .and_then(|raw| raw.parse::<usize>().ok());
        let cap = match configured {
            Some(n) if n > 0 => n,
            _ => Self::DEFAULT_CAP,
        };
        Self {
            map: papaya::HashMap::new(),
            fifo: parking_lot::Mutex::new(std::collections::VecDeque::new()),
            cap,
        }
    }

    /// Stores `val` under `key`.
    ///
    /// Only keys that were not already present are queued for eviction. The oldest queued
    /// keys are dropped while the queue exceeds `cap`.
    fn insert(&self, key: OptKey, val: Arc<svod_runtime::kernel_cache::CachedKernel>) {
        let guard = self.map.guard();
        let replaced_existing = self.map.insert(key.clone(), val, &guard).is_some();
        if replaced_existing {
            return;
        }
        let mut order = self.fifo.lock();
        order.push_back(key);
        let overflow = order.len().saturating_sub(self.cap);
        for evicted in order.drain(..overflow) {
            self.map.remove(&evicted, &guard);
        }
    }
}
pub(crate) fn runtime_effect_ast(ast: &Arc<UOp>) -> &Arc<UOp> {
match ast.op() {
Op::End { computation, .. }
if matches!(computation.op(), Op::Copy { .. } | Op::BufferView { .. } | Op::CustomFunction { .. }) =>
{
computation
}
_ => ast,
}
}
/// Hashes `config.optimizer` with `DefaultHasher` so optimizer settings can take part in the
/// compiled-kernel cache key.
fn optimizer_config_fingerprint(config: &PrepareConfig) -> u64 {
    use std::collections::hash_map::DefaultHasher;
    let mut state = DefaultHasher::new();
    Hash::hash(&config.optimizer, &mut state);
    state.finish()
}
fn prepare_execution_plan(
schedule_result: &crate::schedule::ScheduleResult,
config: &PrepareConfig,
) -> Result<ExecutionPlan> {
let mut schedule_items = schedule_result.items.clone();
let planner_mode = crate::memory_planner::mode_from_env();
let output_buffer_ids = collect_output_buffer_ids(&schedule_items, &schedule_result.output_uop_ids);
let planner_result = crate::memory_planner::memory_planner(&schedule_items, &output_buffer_ids, planner_mode);
if !planner_result.buffer_replace.is_empty() {
trace!(
replacements = planner_result.buffer_replace.len(),
buffers_reused = planner_result.buffers_reused,
memory_saved_bytes = planner_result.memory_saved,
"applying memory planner buffer replacements"
);
crate::memory_planner::apply_reuse_dependencies(&mut schedule_items, &planner_result.reuse_dependencies);
crate::memory_planner::apply_buffer_replacements(&mut schedule_items, &planner_result.buffer_replace);
}
debug!(num_items = schedule_items.len(), "schedule items ready for execution plan");
let alloc_registry = svod_device::registry::registry();
let plan_device = if !schedule_items.is_empty() {
let device_spec = schedule_items
.iter()
.flat_map(|item| item.buffers.iter().map(|b| b.allocator().device_spec()))
.find(|spec| !spec.is_disk())
.unwrap_or(DeviceSpec::Cpu);
config.resolve_device(&device_spec, alloc_registry)?
} else {
return EmptyScheduleSnafu.fail();
};
let optimizer_fingerprint = optimizer_config_fingerprint(config);
let mut builder = ExecutionPlanBuilder::new(plan_device.device.clone());
let mut uop_id_to_idx: HashMap<u64, usize> = HashMap::new();
let mut storage_to_idx: HashMap<BufferStorageKey, usize> = HashMap::new();
let buffer_view_output_uop_ids: HashSet<u64> = schedule_items
.iter()
.filter_map(|item| {
if matches!(runtime_effect_ast(&item.ast).op(), Op::BufferView { .. }) {
item.buffer_uop_ids.first().copied()
} else {
None
}
})
.collect();
for item in &schedule_items {
for (buffer, &uop_id) in item.buffers.iter().zip(item.buffer_uop_ids.iter()) {
buffer.ensure_allocated().context(DeviceSnafu)?;
if uop_id_to_idx.contains_key(&uop_id) {
continue;
}
let storage_key = BufferStorageKey {
id: buffer.id().0,
offset: buffer.offset(),
size: buffer.size(),
dtype: buffer.dtype(),
};
let idx = if !buffer_view_output_uop_ids.contains(&uop_id) {
if let Some(&existing_idx) = storage_to_idx.get(&storage_key) {
builder.map_buffer(uop_id, existing_idx);
existing_idx
} else {
let new_idx = builder.add_buffer(uop_id, buffer.clone());
storage_to_idx.insert(storage_key, new_idx);
new_idx
}
} else {
builder.add_buffer(uop_id, buffer.clone())
};
uop_id_to_idx.insert(uop_id, idx);
}
builder.add_alias_ids(item.alias_registered_ids.iter().copied());
}
static OPT_CACHE: std::sync::OnceLock<OptCacheState> = std::sync::OnceLock::new();
let opt_state = OPT_CACHE.get_or_init(OptCacheState::new);
let opt_cache = &opt_state.map;
let opt_guard = opt_cache.guard();
for item in &schedule_items {
let runtime_ast = runtime_effect_ast(&item.ast);
if matches!(runtime_ast.op(), Op::Copy { .. }) {
let buffer_indices = resolve_item_buffer_indices(item, &uop_id_to_idx)?;
builder.add_op_with_instance_dependencies(
PreparedOp::BufferCopy(PreparedCopy {
id: item.kernel.id,
buffer_indices,
dependencies: item.dependencies.clone(),
}),
item.instance_dependencies.clone(),
);
continue;
}
if let Op::BufferView { size, offset, .. } = runtime_ast.op() {
let buffer_indices = resolve_item_buffer_indices(item, &uop_id_to_idx)?;
if item.buffers.len() >= 2 && item.buffer_uop_ids.len() >= 2 && buffer_indices.len() >= 2 {
let base = &item.buffers[1];
let byte_offset = offset * base.dtype().bytes();
let byte_size = size * runtime_ast.dtype().bytes();
let view = base.view(byte_offset, byte_size).map_err(|e| crate::error::Error::IrConstruction {
details: format!(
"BUFFER_VIEW failed for kernel {}: base_buffer_id={}, byte_offset={}, byte_size={}: {e}",
item.kernel.id,
base.id().0,
byte_offset,
byte_size
),
})?;
let output_uop_id = item.buffer_uop_ids[0];
if let Some(&idx) = uop_id_to_idx.get(&output_uop_id) {
builder.replace_buffer(idx, view);
}
builder.add_op_with_instance_dependencies(
PreparedOp::BufferView(PreparedBufferView {
id: item.kernel.id,
buffer_indices,
byte_offset,
byte_size,
dependencies: item.dependencies.clone(),
}),
item.instance_dependencies.clone(),
);
}
continue;
}
if let Op::CustomFunction { kind, attrs } = runtime_ast.op() {
let buffer_indices = resolve_item_buffer_indices(item, &uop_id_to_idx)?;
let runtime_vars = attrs.iter().flat_map(svod_runtime::execution_plan::collect_runtime_vars).collect();
builder.add_op_with_instance_dependencies(
PreparedOp::CustomFunction(PreparedCustomFunction {
id: item.kernel.id,
kind: kind.clone(),
attrs: attrs.clone(),
buffer_indices,
fixedvars: item.fixedvars.clone(),
dependencies: item.dependencies.clone(),
runtime_vars,
}),
item.instance_dependencies.clone(),
);
continue;
}
let item_device_spec = item
.buffers
.iter()
.map(|b| b.allocator().device_spec())
.find(|spec| !spec.is_disk())
.unwrap_or(DeviceSpec::Cpu);
let item_device = config.resolve_device(&item_device_spec, alloc_registry)?;
let item_codegen: &'static str = item_device.compiler.cache_key();
let opt_key = (
crate::schedule_cache::content_hash(&item.ast),
item_device.device.clone(),
item_codegen,
optimizer_fingerprint,
);
let cached = if let Some(cached) = opt_cache.get(&opt_key, &opt_guard) {
Arc::clone(cached)
} else {
let optimizer_renderer = get_optimizer_renderer(&item_device);
let optimized_ast = if let svod_schedule::OptStrategy::Beam { .. } = config.optimizer.strategy {
beam_search_optimize(
item.ast.clone(),
&optimizer_renderer,
&item_device,
&item.buffers,
&config.optimizer,
)?
} else {
svod_schedule::optimize_kernel_with_config(item.ast.clone(), &optimizer_renderer, &config.optimizer)
};
let kernel_name =
optimized_ast.metadata::<svod_schedule::optimizer::KernelInfo>().map(|info| info.function_name());
let ast_decomposed = match item_device.renderer.decompositor() {
Some(matcher) => svod_ir::decompositions::decompose_with(&optimized_ast, &matcher),
None => optimized_ast,
};
let program = svod_codegen::program_pipeline::program_from_sink(ast_decomposed, item_device.device.clone());
let result = svod_runtime::kernel_cache::get_or_compile_kernel(
crate::schedule_cache::content_hash(&program),
item_codegen,
|| {
let (spec, compiled) = compile_with_program_pipeline_components(
program.clone(),
item_device.renderer.as_ref(),
item_device.compiler.as_ref(),
kernel_name.as_deref(),
)?;
let program = (item_device.runtime)(&compiled).context(CreateProgramSnafu)?;
Ok(svod_runtime::kernel_cache::CachedKernel {
program,
device: item_codegen.to_string(),
code: spec.src.clone(),
entry_point: spec.name.clone(),
var_names: spec.var_names.clone(),
globals: spec.globals.clone(),
outs: spec.outs.clone(),
ins: spec.ins.clone(),
host_parallel_safe: matches!(item_device.device, DeviceSpec::Cpu),
global_size: spec.global_size.clone(),
local_size: spec.local_size.clone(),
})
},
)?;
opt_state.insert(opt_key, Arc::clone(&result));
result
};
let buffer_indices = resolve_compiled_kernel_buffer_indices(item, &uop_id_to_idx, &cached.globals)?;
trace!(kernel.ast_id = item.ast.id, num_buffers = item.buffers.len(), "kernel buffer mapping");
let vals: Vec<i64> =
cached.var_names.iter().map(|name| item.fixedvars.get(name).copied().unwrap_or(0)).collect();
let non_overridable_fixedvars = collect_non_overridable_fixedvars(item);
let output_indices = output_indices_from_program_metadata(&cached.globals, &cached.outs, buffer_indices.len())
.map_err(|e| crate::error::Error::IrConstruction {
details: format!(
"invalid ProgramSpec output metadata for kernel id {} (globals={:?}, outs={:?}, num_buffers={}): {e}",
item.kernel.id,
cached.globals,
cached.outs,
buffer_indices.len()
),
})?;
let runtime_vars = svod_runtime::execution_plan::collect_runtime_vars(&item.ast);
let prepared = PreparedKernel {
id: item.kernel.id,
ast: item.ast.clone(),
kernel: cached,
device: item_device.device.clone(),
buffer_indices,
output_indices,
vals,
fixedvars: non_overridable_fixedvars,
dependencies: item.dependencies.clone(),
buffer_ptrs: Vec::new(), buffer_ids: Vec::new(), runtime_vars,
};
builder.add_op_with_instance_dependencies(
PreparedOp::CompiledProgram(prepared),
item.instance_dependencies.clone(),
);
}
let mut output_buffer_indices = Vec::with_capacity(schedule_result.output_uop_ids.len());
for &uop_id in &schedule_result.output_uop_ids {
let Some(idx) = uop_id_to_idx.get(&uop_id).copied() else {
return Err(crate::error::Error::BufferNotFound { uop_id });
};
output_buffer_indices.push(idx);
}
if output_buffer_indices.is_empty() {
return IrConstructionSnafu { details: "prepare_execution_plan produced no output buffer indices".to_string() }
.fail();
}
builder.set_output_buffers(output_buffer_indices);
builder.build().context(ExecutionSnafu)
}
/// Return the ids of every buffer in `schedule` that is paired (via `buffer_uop_ids`) with one of
/// the UOp ids listed in `output_uop_ids`.
///
/// Buffers and UOp ids are paired positionally per schedule item; if the two lists differ in
/// length, the extra entries of the longer one are ignored.
fn collect_output_buffer_ids(schedule: &crate::schedule::Schedule, output_uop_ids: &[u64]) -> HashSet<u64> {
    // Set form of `output_uop_ids` for constant-time membership checks below.
    let wanted: HashSet<u64> = output_uop_ids.iter().copied().collect();
    let mut ids = HashSet::new();
    for item in schedule {
        ids.extend(
            item.buffers
                .iter()
                .zip(item.buffer_uop_ids.iter())
                .filter(|(_, uop_id)| wanted.contains(*uop_id))
                .map(|(buffer, _)| buffer.id().0),
        );
    }
    ids
}
/// Collect the fixed values of the item's loop variables.
///
/// Only names listed in `item.loop_var_names` that also have an entry in `item.fixedvars` are
/// included. The result is stored on the prepared kernel as its `fixedvars`.
fn collect_non_overridable_fixedvars(item: &ScheduleItem) -> HashMap<String, i64> {
    item.loop_var_names
        .iter()
        .filter_map(|name| item.fixedvars.get(name).copied().map(|value| (name.clone(), value)))
        .collect()
}
fn compile_with_program_pipeline_components(
kernel_ast: Arc<UOp>,
renderer: &dyn svod_device::device::Renderer,
compiler: &dyn svod_device::device::Compiler,
kernel_name: Option<&str>,
) -> Result<(svod_device::device::ProgramSpec, svod_device::device::CompiledSpec)> {
let mut program = match kernel_ast.op() {
Op::Program { .. } => kernel_ast,
other => {
return IrConstructionSnafu {
details: format!("compile_with_program_pipeline_components expects PROGRAM input, got {other:?}"),
}
.fail();
}
};
program = svod_codegen::program_pipeline::get_program(
&program,
renderer,
compiler,
kernel_name,
svod_codegen::program_pipeline::ProgramTarget::Source,
)
.context(RenderKernelSnafu)?;
let rendered_entry = svod_device::device::ProgramSpec::from_uop(&program).map(|spec| spec.name).map_err(|e| {
crate::error::Error::IrConstruction { details: format!("PROGRAM pipeline produced invalid SOURCE stage: {e}") }
})?;
let (program, compiled) =
svod_codegen::program_pipeline::do_compile(&program, compiler).context(CompileKernelSnafu)?;
let spec =
svod_device::device::ProgramSpec::from_uop(&program).map_err(|e| crate::error::Error::IrConstruction {
details: format!(
"PROGRAM pipeline produced invalid ProgramSpec after compile (entry='{}'): {e}",
rendered_entry
),
})?;
Ok((spec, compiled))
}
/// Choose the device for a set of parameter buffers and return its compiler's cache key.
///
/// The device spec is chosen in this order:
/// 1. The first parameter whose registered buffer has a non-disk allocator.
/// 2. Otherwise, the first parameter whose UOp is a `Buffer` on a non-disk `Device`.
/// 3. Otherwise, `DeviceSpec::Cpu`.
///
/// The spec is then resolved through `config`.
pub(crate) fn resolve_codegen(param_buffers: &[(u64, Arc<UOp>)], config: &PrepareConfig) -> Result<&'static str> {
    let registry = svod_device::registry::registry();
    let registered_spec = param_buffers.iter().find_map(|(id, _)| {
        crate::tensor_registry::get_buffer(*id)
            .map(|buffer| buffer.allocator().device_spec())
            .filter(|spec| !spec.is_disk())
    });
    let spec = match registered_spec {
        Some(spec) => spec,
        // Fall back to the device recorded on the Buffer UOps themselves.
        None => param_buffers
            .iter()
            .find_map(|(_, uop)| match uop.op() {
                Op::Buffer { device, .. } => match device.op() {
                    Op::Device(spec) if !spec.is_disk() => Some(spec.clone()),
                    _ => None,
                },
                _ => None,
            })
            .unwrap_or(DeviceSpec::Cpu),
    };
    let resolved = config.resolve_device(&spec, registry)?;
    Ok(resolved.compiler.cache_key())
}
fn get_optimizer_renderer(device: &Device) -> svod_schedule::OptimizerRenderer {
match device.device {
DeviceSpec::Cpu => {
if std::env::var("SVOD_AMX").as_deref() == Ok("1") {
svod_schedule::OptimizerRenderer::apple_amx()
} else {
svod_schedule::OptimizerRenderer::cpu()
}
}
DeviceSpec::Cuda { .. } => svod_schedule::OptimizerRenderer::cuda(),
DeviceSpec::Metal { .. } => svod_schedule::OptimizerRenderer::metal(),
_ => svod_schedule::OptimizerRenderer::cpu(),
}
}
/// Count how often each op kind (its `AsRef<str>` name) occurs in `ops`.
///
/// Returns the `top_k` most frequent kinds as `(name, count)` pairs, highest count first.
/// Ties are broken by name (ascending), so the result is deterministic. This file uses it for
/// the `[BEAM drop] too_many_uops` diagnostics.
pub(crate) fn count_top_ops(ops: &[Arc<UOp>], top_k: usize) -> Vec<(String, usize)> {
    // Key by the borrowed op name; only the `top_k` survivors become owned `String`s.
    let mut counts: HashMap<&str, usize> = HashMap::new();
    for u in ops {
        *counts.entry(u.op().as_ref()).or_insert(0) += 1;
    }
    let mut ranked: Vec<(&str, usize)> = counts.into_iter().collect();
    // `HashMap` iteration order is unspecified. Sorting by count alone would make the order of
    // equal counts (and which of them survive `top_k`) vary between runs. Names are unique
    // keys, so this unstable sort is fully deterministic.
    ranked.sort_unstable_by(|(a_name, a_count), (b_name, b_count)| {
        b_count.cmp(a_count).then_with(|| a_name.cmp(b_name))
    });
    ranked.into_iter().take(top_k).map(|(name, count)| (name.to_owned(), count)).collect()
}
/// Render `(name, count)` pairs as `"name=count"` entries separated by `", "`.
///
/// An empty slice renders as an empty string.
pub(crate) fn fmt_op_counts(counts: &[(String, usize)]) -> String {
    let mut rendered = String::new();
    for (position, (op_name, count)) in counts.iter().enumerate() {
        if position != 0 {
            rendered.push_str(", ");
        }
        rendered.push_str(&format!("{op_name}={count}"));
    }
    rendered
}
/// Run a beam search over optimization schedules for `ast`, timing each candidate on `device`.
/// Returns the post-optimized AST of the best schedule found.
///
/// Setup:
/// - Every buffer in `buffers` is allocated up front, because candidates are benchmarked
///   against the buffers' raw pointers.
///
/// Per-candidate worker (each runs on its own detached thread):
/// - Derives the optimized AST, with post-optimization memoized across candidates.
/// - Linearizes and rejects candidates with more than `beam.max_uops` linear UOps.
/// - Compiles and creates the runtime program, then benchmarks on a global size shrunk to at
///   most `MAX_TEST_GLOBAL_SIZE` items. The timing is scaled back up by the shrink ratio.
///
/// Timeouts and dropped candidates:
/// - The caller waits at most `BEAM_TIMEOUT_SEC` seconds (default 10) for the worker's first
///   message. After `CompileDone` it waits for the benchmark result without a timeout.
/// - Any failure (linearize/compile/runtime error, too many UOps, panic, compile timeout) drops
///   the candidate (`None`).
/// - If `BEAM_LOG_SURPASS_MAX` is set, the drop reason is printed to stderr.
///
/// NOTE: while the search runs, the process-global panic hook is replaced with a no-op, so
/// panics on unrelated threads are silenced during that window as well.
///
/// # Errors
/// - A device error (via `DeviceSnafu`) if a buffer cannot be allocated.
/// - An optimize error (via `OptimizeSnafu`) if `beam_search_cached` fails.
///
/// # Panics
/// A panic raised by `beam_search_cached` itself is propagated after the previous panic hook
/// has been reinstated.
fn beam_search_optimize(
    ast: Arc<UOp>,
    renderer: &svod_schedule::OptimizerRenderer,
    device: &Device,
    buffers: &[Buffer],
    optimizer_config: &svod_schedule::OptimizerConfig,
) -> Result<Arc<UOp>> {
    let beam_config = &optimizer_config.beam;
    let scheduler = prepare_scheduler(ast, renderer);
    // Candidates are benchmarked against the real buffers' raw pointers, so they must exist.
    for buf in buffers {
        buf.ensure_allocated().context(DeviceSnafu)?;
    }
    let buffers: Vec<Buffer> = buffers.to_vec();
    let bench_config = svod_runtime::BenchmarkConfig::default();
    let dev_renderer = device.renderer.clone();
    let dev_compiler = device.compiler.clone();
    let dev_runtime = device.runtime.clone();
    let dev_device = device.device.clone();
    let max_uops = beam_config.max_uops;
    svod_runtime::warmup_thread_pool();
    // Upper bound on how long to wait for a worker's first message (compile done or final result).
    let compile_timeout =
        Duration::from_secs(std::env::var("BEAM_TIMEOUT_SEC").ok().and_then(|s| s.parse().ok()).unwrap_or(10));
    let log_surpass = std::env::var("BEAM_LOG_SURPASS_MAX").is_ok();
    // Memoizes post-optimization keyed by the raw AST's content hash; shared by all workers.
    let post_opt_cache: Arc<papaya::HashMap<u64, Arc<UOp>>> = Arc::new(papaya::HashMap::new());
    let compile_and_time = |s: &Scheduler, early_stop: Option<Duration>| -> Option<svod_schedule::CandidateMetrics> {
        use std::panic::{AssertUnwindSafe, catch_unwind};
        use std::sync::mpsc;
        let s_owned = s.clone();
        let renderer_c = renderer.clone();
        let dev_renderer_c = dev_renderer.clone();
        let dev_compiler_c = dev_compiler.clone();
        let dev_runtime_c = dev_runtime.clone();
        let dev_device_c = dev_device.clone();
        let buffers_c = buffers.clone();
        let bench_config_c = bench_config.clone();
        let max_uops_c = max_uops;
        let post_opt_cache_c = Arc::clone(&post_opt_cache);
        let log_surpass_c = log_surpass;
        let opts_snapshot: Vec<svod_schedule::optimizer::Opt> = s_owned.applied_opts.clone();
        // The worker sends at most two messages: `CompileDone` once the program exists, then
        // `Final`. A capacity of 2 means neither send blocks, even after the receiver is gone.
        enum WorkerMsg {
            CompileDone,
            Final(Option<svod_schedule::CandidateMetrics>),
        }
        let (tx, rx) = mpsc::sync_channel::<WorkerMsg>(2);
        let tx_compile = tx.clone();
        let _ = std::thread::spawn(move || {
            let result = catch_unwind(AssertUnwindSafe(|| {
                let raw_ast = s_owned.get_optimized_ast(None);
                let cache_key = raw_ast.content_hash;
                let cache_pin = post_opt_cache_c.pin();
                let optimized = if let Some(cached) = cache_pin.get(&cache_key) {
                    cached.clone()
                } else {
                    let opt = apply_post_optimization_with_renderer(raw_ast, Some(&renderer_c));
                    cache_pin.insert(cache_key, opt.clone());
                    opt
                };
                let kernel_name =
                    optimized.metadata::<svod_schedule::optimizer::KernelInfo>().map(|info| info.function_name());
                let ir_hash = svod_schedule::hash_post_codegen_ir(&optimized);
                let compute_ops = svod_schedule::compute_ops_estimate(&optimized);
                let decomposed = match dev_renderer_c.decompositor() {
                    Some(m) => svod_ir::decompositions::decompose_with(&optimized, &m),
                    None => optimized,
                };
                let mut program = svod_codegen::program_pipeline::program_from_sink(decomposed, dev_device_c.clone());
                program = match svod_codegen::program_pipeline::do_linearize(&program) {
                    Ok(p) => p,
                    Err(e) => {
                        if log_surpass_c {
                            eprintln!("[BEAM drop] linearize_err: {e:?} opts={opts_snapshot:?}");
                        }
                        return None;
                    }
                };
                let (linear_uops_count, top_op_counts) = if let svod_ir::Op::Program { linear: Some(linear), .. } =
                    program.op()
                    && let svod_ir::Op::Linear { ops } = linear.op()
                {
                    (ops.len(), if log_surpass_c { count_top_ops(ops, 8) } else { Vec::new() })
                } else {
                    (0, Vec::new())
                };
                if linear_uops_count > max_uops_c {
                    if log_surpass_c {
                        eprintln!(
                            "[BEAM drop] too_many_uops: linear={linear_uops_count} max={max_uops_c} opts={opts_snapshot:?} top_ops=[{}]",
                            fmt_op_counts(&top_op_counts)
                        );
                    }
                    return None;
                }
                let (spec, compiled) = match compile_with_program_pipeline_components(
                    program,
                    dev_renderer_c.as_ref(),
                    dev_compiler_c.as_ref(),
                    kernel_name.as_deref(),
                ) {
                    Ok(v) => v,
                    Err(e) => {
                        if log_surpass_c {
                            eprintln!("[BEAM drop] compile_err: {e:?} opts={opts_snapshot:?}");
                        }
                        return None;
                    }
                };
                let program = match (dev_runtime_c)(&compiled) {
                    Ok(p) => p,
                    Err(e) => {
                        if log_surpass_c {
                            eprintln!("[BEAM drop] runtime_err: {e:?} opts={opts_snapshot:?}");
                        }
                        return None;
                    }
                };
                // From here on the caller stops applying `compile_timeout` to this candidate.
                let _ = tx_compile.send(WorkerMsg::CompileDone);
                let buffer_ptrs: Vec<*mut u8> = buffers_c.iter().map(|b| unsafe { b.as_raw_ptr() }).collect();
                // Symbolic variables (other than `core_id`) are pinned to the midpoint of their range.
                let mut user_var_vals: HashMap<&str, i64> = HashMap::new();
                for v in &spec.vars {
                    if v.name != "core_id" {
                        user_var_vals.insert(v.name.as_str(), (v.min + v.max) / 2);
                    }
                }
                let launch_dims = spec.launch_dims(&user_var_vals).ok()?;
                let vals: Vec<i64> =
                    spec.var_names.iter().map(|n| user_var_vals.get(n.as_str()).copied().unwrap_or(0)).collect();
                // Cap the benchmark launch: repeatedly halve the last dimension larger than 16
                // until the product fits, then scale timings by the original/shrunk ratio.
                const MAX_TEST_GLOBAL_SIZE: usize = 65536;
                let mut test_global_size = launch_dims.global_size;
                let original_size: usize = test_global_size.iter().product();
                while test_global_size.iter().product::<usize>() > MAX_TEST_GLOBAL_SIZE {
                    let mut halved = false;
                    for j in (0..test_global_size.len()).rev() {
                        if test_global_size[j] > 16 {
                            test_global_size[j] /= 2;
                            halved = true;
                            break;
                        }
                    }
                    if !halved {
                        break;
                    }
                }
                let shrunk_size: usize = test_global_size.iter().product();
                let factor: f64 = if shrunk_size > 0 { original_size as f64 / shrunk_size as f64 } else { 1.0 };
                let mut bench_config = bench_config_c.clone();
                // The early-stop budget refers to the full-size launch, so shrink it by the same ratio.
                bench_config.early_stop = early_stop.map(|t| {
                    let nanos = t.as_nanos() as f64 / factor;
                    Duration::from_nanos(nanos.min(u64::MAX as f64) as u64)
                });
                bench_config.clear_l2 = renderer_c.device.has_hardware_cache_invalidate();
                let result = unsafe {
                    svod_runtime::benchmark_kernel(
                        program.as_ref(),
                        &buffer_ptrs,
                        &vals,
                        Some(test_global_size),
                        launch_dims.local_size,
                        &bench_config,
                    )
                    .ok()?
                };
                let scaled_nanos = (result.min.as_nanos() as f64 * factor).min(u64::MAX as f64);
                let timing = Duration::from_nanos(scaled_nanos as u64);
                Some(svod_schedule::CandidateMetrics { timing, ir_hash, compute_ops })
            }));
            let final_result = match result {
                Ok(opt) => opt,
                Err(_) => {
                    if log_surpass_c {
                        eprintln!("[BEAM drop] panic_in_worker opts={opts_snapshot:?}");
                    }
                    None
                }
            };
            let _ = tx.send(WorkerMsg::Final(final_result));
        });
        match rx.recv_timeout(compile_timeout) {
            Ok(WorkerMsg::CompileDone) => {
                match rx.recv() {
                    Ok(WorkerMsg::Final(metrics)) => metrics,
                    _ => None,
                }
            }
            Ok(WorkerMsg::Final(metrics)) => metrics,
            Err(_) => {
                if log_surpass {
                    eprintln!("[BEAM drop] compile_timeout opts={:?}", s.applied_opts);
                }
                None
            }
        }
    };
    // Worker panics are caught above; silence the default hook's stderr output for them while the
    // search runs. The hook is process-global, so this is kept as short as possible.
    let prev_hook = std::panic::take_hook();
    std::panic::set_hook(Box::new(|_| {}));
    // If the search itself panics, catch it so the previous hook is always reinstated. Otherwise
    // the no-op hook would stay installed and silence every later panic in the process. A drop
    // guard cannot do this: `set_hook` panics when called from a panicking thread.
    let search = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
        beam_search_cached(scheduler, beam_config, compile_and_time)
    }));
    std::panic::set_hook(prev_hook);
    let result = match search {
        Ok(result) => result,
        Err(payload) => std::panic::resume_unwind(payload),
    };
    let result = result.context(OptimizeSnafu)?;
    tracing::debug!(
        opts = ?result.scheduler.applied_opts,
        timing = ?result.timing,
        iterations = result.iterations,
        "beam_search_optimize: completed"
    );
    let raw_ast = result.scheduler.get_optimized_ast(None);
    Ok(apply_post_optimization_with_renderer(raw_ast, Some(renderer)))
}
#[cfg(test)]
#[path = "test/unit/realize_internal.rs"]
mod tests;