use std::collections::{HashMap, HashSet, VecDeque};
use std::sync::Arc;
use svod_device::Buffer;
use svod_device::device::Device;
use svod_device::registry;
use svod_dtype::{DType, DeviceSpec};
use svod_ir::{Op, UOp};
use tracing::{debug, trace};
use crate::error::*;
use crate::{Error, Result};
use snafu::ResultExt;
/// Walks from `src` along each node's first source until a node that can act
/// as a callable argument is reached.
///
/// The walk ends at the first node whose op is AFTER, BUFFER, PARAM, MSELECT,
/// MSTACK or BIND. It also ends, returning the node reached so far, when that
/// node has no sources or when its first source is the node itself.
fn canonicalize_callable_source(src: &Arc<UOp>) -> Arc<UOp> {
    let mut node = src.clone();
    loop {
        let is_terminal = matches!(
            node.op(),
            Op::After { .. }
                | Op::Buffer { .. }
                | Op::Param { .. }
                | Op::MSelect { .. }
                | Op::MStack { .. }
                | Op::Bind { .. }
        );
        if is_terminal {
            return node;
        }
        let operands = node.op().sources();
        let Some(first) = operands.first() else {
            // Leaf node: nothing further to follow.
            return node;
        };
        if Arc::ptr_eq(&node, first) {
            // A self-referencing first source would loop forever.
            return node;
        }
        node = (*first).clone();
    }
}
/// Returns the id of the buffer UOp that primarily backs `src`, if any.
///
/// `src` is canonicalized first. BUFFER/PARAM/AFTER sources report their own
/// `buf_uop()` id. An MSELECT over an MSTACK reports the selected member's
/// buffer id, falling back to its own `buf_uop()` id when the index is out of
/// range or the selected-from node is not an MSTACK. A bare MSTACK reports its
/// first member (or `None` when it has no members). Everything else, BIND
/// included, yields `None`.
fn source_primary_buffer_id(src: &Arc<UOp>) -> Option<u64> {
    let canonical = canonicalize_callable_source(src);
    match canonical.op() {
        Op::MStack { buffers } => buffers.first().map(|member| member.buf_uop().id),
        Op::MSelect { buffer, device_index } => {
            let selected = match buffer.op() {
                Op::MStack { buffers } => buffers.get(*device_index).map(|member| member.buf_uop().id),
                _ => None,
            };
            Some(selected.unwrap_or_else(|| canonical.buf_uop().id))
        }
        Op::Buffer { .. } | Op::Param { .. } | Op::After { .. } => Some(canonical.buf_uop().id),
        _ => None,
    }
}
/// Adds to `out` the ids of the CALL nodes that the AFTER dependency `dep`
/// stands for.
///
/// * CALL contributes its own id.
/// * END contributes the id of the CALL it wraps; wrapping anything else is an error.
/// * STORE contributes nothing.
/// * AFTER is expanded recursively through its own dependencies.
///
/// # Errors
/// Returns an IR-construction error for any other op, or for an END that does
/// not wrap a CALL.
fn collect_callable_dep_ids(dep: &Arc<UOp>, out: &mut HashSet<u64>) -> Result<()> {
    let callable_id = match dep.op() {
        Op::Call { .. } => dep.id,
        Op::End { computation, .. } => {
            if !matches!(computation.op(), Op::Call { .. }) {
                return IrConstructionSnafu {
                    details: format!("AFTER dependency END must wrap CALL, got {:?}", computation.op()),
                }
                .fail();
            }
            computation.id
        }
        Op::Store { .. } => return Ok(()),
        Op::After { deps, .. } => {
            for nested in deps {
                collect_callable_dep_ids(nested, out)?;
            }
            return Ok(());
        }
        other => {
            return IrConstructionSnafu {
                details: format!("AFTER dependency must be CALL/END(CALL)/STORE/AFTER, got {other:?}"),
            }
            .fail();
        }
    };
    out.insert(callable_id);
    Ok(())
}
/// Result of [`split_after_dependencies`]: `(kernel dependencies, nested AFTER dependencies)`.
type AfterDependencySplit = (Vec<Arc<UOp>>, Vec<Arc<UOp>>);

/// Partitions the direct dependencies of an AFTER node.
///
/// CALL and END(CALL) dependencies go into the first list, nested AFTER
/// dependencies into the second, and STORE dependencies are dropped.
///
/// # Errors
/// Returns an IR-construction error when `after` is not an AFTER node or when a
/// dependency is none of CALL / END(CALL) / STORE / AFTER.
fn split_after_dependencies(after: &Arc<UOp>) -> Result<AfterDependencySplit> {
    let deps = match after.op() {
        Op::After { deps, .. } => deps,
        _ => {
            return IrConstructionSnafu {
                details: format!("expected AFTER when splitting dependencies, got {:?}", after.op()),
            }
            .fail();
        }
    };

    let mut kernel_deps = Vec::new();
    let mut nested_afters = Vec::new();
    for dep in deps {
        let is_kernel = match dep.op() {
            Op::Call { .. } => true,
            Op::End { computation, .. } => matches!(computation.op(), Op::Call { .. }),
            _ => false,
        };
        if is_kernel {
            kernel_deps.push(dep.clone());
            continue;
        }
        match dep.op() {
            Op::After { .. } => nested_afters.push(dep.clone()),
            Op::Store { .. } => {}
            other => {
                return IrConstructionSnafu {
                    details: format!("AFTER dependency must be CALL/END(CALL)/STORE/AFTER, got {other:?}"),
                }
                .fail();
            }
        }
    }
    Ok((kernel_deps, nested_afters))
}
/// Collects into `out` the ids of every CALL that the callable input `src`
/// depends on.
///
/// `src` is canonicalized first. AFTER nodes contribute their kernel
/// dependencies and are followed through nested AFTERs; MSTACK and MSELECT are
/// followed into the buffers they reference; BUFFER, PARAM and BIND contribute
/// nothing.
///
/// # Errors
/// Returns an IR-construction error when the canonical source is any other op,
/// or when a dependency list is malformed.
fn collect_source_dependency_callable_ids(src: &Arc<UOp>, out: &mut HashSet<u64>) -> Result<()> {
    let canonical = canonicalize_callable_source(src);
    match canonical.op() {
        Op::Buffer { .. } | Op::Param { .. } | Op::Bind { .. } => {}
        Op::MSelect { buffer, .. } => collect_source_dependency_callable_ids(buffer, out)?,
        Op::MStack { buffers } => {
            for member in buffers {
                collect_source_dependency_callable_ids(member, out)?;
            }
        }
        Op::After { .. } => {
            let (kernel_deps, nested_afters) = split_after_dependencies(&canonical)?;
            for kernel in &kernel_deps {
                collect_callable_dep_ids(kernel, out)?;
            }
            for nested in &nested_afters {
                collect_source_dependency_callable_ids(nested, out)?;
            }
        }
        other => {
            return IrConstructionSnafu {
                details: format!("input to callable must resolve to AFTER/BUFFER/PARAM/MSELECT/MSTACK/BIND, got {other:?}"),
            }
            .fail();
        }
    }
    Ok(())
}
/// Returns an owned copy of a CALL node's arguments, or `None` when `callable`
/// is not a CALL.
fn callable_sources(callable: &Arc<UOp>) -> Option<Vec<Arc<UOp>>> {
    let Op::Call { args, .. } = callable.op() else {
        return None;
    };
    Some(args.iter().cloned().collect())
}
/// Gathers the ids of all RANGE nodes closed by an END that wraps one of the
/// scheduled callables.
///
/// Only END nodes whose computation is a CALL listed in `callable_ids` are
/// considered; non-RANGE entries in their range list are ignored.
fn collect_scheduled_range_ids(root: &Arc<UOp>, callable_ids: &HashSet<u64>) -> HashSet<u64> {
    let mut range_ids = HashSet::new();
    for node in root.toposort_call_aware(false) {
        if let Op::End { computation, ranges } = node.op() {
            let wraps_scheduled_call =
                matches!(computation.op(), Op::Call { .. }) && callable_ids.contains(&computation.id);
            if wraps_scheduled_call {
                range_ids.extend(ranges.iter().filter(|r| matches!(r.op(), Op::Range { .. })).map(|r| r.id));
            }
        }
    }
    range_ids
}
/// Lists the CALL arguments that bind a variable to a scheduled RANGE.
///
/// Every BIND argument must wrap a DEFINE_VAR; this is checked before the
/// bound value is inspected. A binding is reported only when its value is a
/// RANGE whose id is in `scheduled_range_ids`. Non-BIND arguments are skipped.
///
/// # Errors
/// Fails when `callable` is not a CALL, or when a BIND wraps something other
/// than DEFINE_VAR.
fn collect_call_bound_ranges(callable: &Arc<UOp>, scheduled_range_ids: &HashSet<u64>) -> Result<Vec<BoundRangeRef>> {
    let args = match callable.op() {
        Op::Call { args, .. } => args,
        _ => return ExpectedCallableOpSnafu.fail(),
    };

    let mut bound = Vec::new();
    for arg in args {
        if let Op::Bind { var, value } = arg.op() {
            let name = match var.op() {
                Op::DefineVar { name, .. } => name,
                _ => {
                    return IrConstructionSnafu {
                        details: format!("CALL BIND source must wrap DEFINE_VAR, got {:?}", var.op()),
                    }
                    .fail();
                }
            };
            let binds_scheduled_range =
                matches!(value.op(), Op::Range { .. }) && scheduled_range_ids.contains(&value.id);
            if binds_scheduled_range {
                bound.push(BoundRangeRef { var_name: name.clone(), range_uop: value.clone() });
            }
        }
    }
    Ok(bound)
}
/// Flattens the graph under `root` into a linear RANGE / CALL / END control
/// stream, in `toposort_call_aware(false)` order.
///
/// * A RANGE is emitted when its id is in `scheduled_range_ids`.
/// * A CALL is emitted when its id is in `callable_ids`.
/// * An END wrapping a scheduled CALL is emitted when it closes exactly one
///   RANGE; with no RANGE it is dropped.
///
/// # Errors
/// Fails when an END(CALL) closes more than one RANGE, or when the resulting
/// stream is empty.
fn collect_linear_sched_ops_internal(
    root: &Arc<UOp>,
    callable_ids: &HashSet<u64>,
    scheduled_range_ids: &HashSet<u64>,
) -> Result<Vec<LinearSchedOp>> {
    let mut stream = Vec::new();
    for node in root.toposort_call_aware(false) {
        match node.op() {
            Op::Range { .. } => {
                if scheduled_range_ids.contains(&node.id) {
                    stream.push(LinearSchedOp::Range { range: node.clone() });
                }
            }
            Op::Call { .. } => {
                if callable_ids.contains(&node.id) {
                    stream.push(LinearSchedOp::Call { kernel_id: node.id });
                }
            }
            Op::End { computation, ranges } => {
                let closes_scheduled_call =
                    matches!(computation.op(), Op::Call { .. }) && callable_ids.contains(&computation.id);
                if !closes_scheduled_call {
                    continue;
                }
                let wrappers: Vec<&Arc<UOp>> = ranges.iter().filter(|r| matches!(r.op(), Op::Range { .. })).collect();
                match wrappers.len() {
                    0 => {}
                    1 => stream.push(LinearSchedOp::End { range: Arc::clone(wrappers[0]), kernel_id: computation.id }),
                    count => {
                        return IrConstructionSnafu {
                            details: format!(
                                "END(CALL) must close at most one wrapper range in strict scheduler, got {}",
                                count
                            ),
                        }
                        .fail();
                    }
                }
            }
            _ => {}
        }
    }

    if stream.is_empty() {
        return IrConstructionSnafu { details: "strict scheduler produced empty linear control stream".to_string() }
            .fail();
    }
    Ok(stream)
}
/// Expands the linear control stream of `root` into the concrete, ordered list
/// of kernel invocations.
///
/// The stream produced by `collect_linear_sched_ops_internal` is first
/// validated and then interpreted: a RANGE starts a loop at its `vmin`, an END
/// either jumps back to the op following its RANGE (incrementing the counter)
/// or falls through once the counter has reached `vmax`, and every CALL that is
/// executed yields one `KernelInvocation` carrying the current counter values of
/// its bound ranges.
///
/// # Errors
/// Returns an IR-construction error when a declared range has no END, when an
/// item binds a range that is absent from the stream, when an END/CALL refers
/// to an unknown kernel id, when an END or CALL refers to a range that has not
/// been entered, or when range bounds cannot be resolved.
fn collect_kernel_invocations(
root: &Arc<UOp>,
items: &[PreScheduleItem],
scheduled_range_ids: &HashSet<u64>,
) -> Result<Vec<KernelInvocation>> {
let callable_ids: HashSet<u64> = items.iter().map(|it| it.kernel.id).collect();
let linear_ops = collect_linear_sched_ops_internal(root, &callable_ids, scheduled_range_ids)?;
// Kernel id -> ranges whose counters must be captured when that kernel is invoked.
let bound_ranges_by_kernel: HashMap<u64, &[BoundRangeRef]> =
items.iter().map(|it| (it.kernel.id, it.bound_ranges.as_slice())).collect();
// --- Validation pass: every RANGE in the stream needs a matching END. ---
let mut declared_ranges: HashSet<u64> = HashSet::new();
let mut ended_ranges: HashSet<u64> = HashSet::new();
for op in &linear_ops {
match op {
LinearSchedOp::Range { range } => {
declared_ranges.insert(range.id);
}
LinearSchedOp::End { range, .. } => {
ended_ranges.insert(range.id);
}
LinearSchedOp::Call { .. } => {}
}
}
for &rid in &declared_ranges {
if !ended_ranges.contains(&rid) {
return IrConstructionSnafu { details: format!("schedule range {rid} is missing END in strict scheduler") }
.fail();
}
}
// Every range an item binds must actually appear in the stream.
for item in items {
for br in &item.bound_ranges {
if !declared_ranges.contains(&br.range_uop.id) {
return IrConstructionSnafu {
details: format!(
"CALL {} bound variable '{}' references schedule range {} missing from linear schedule",
item.kernel.id, br.var_name, br.range_uop.id
),
}
.fail();
}
}
}
// --- Interpretation pass ---
let mut invocations = Vec::new();
// Range id -> current loop counter value.
let mut in_ranges: HashMap<u64, i64> = HashMap::new();
// Range id -> stream index of the first op inside the loop body (jump target).
let mut range_ptrs: HashMap<u64, usize> = HashMap::new();
// Range id -> cached (vmin, vmax) so bounds are resolved once per range.
let mut range_bounds: HashMap<u64, (i64, i64)> = HashMap::new();
let mut sched_ptr = 0usize;
while sched_ptr < linear_ops.len() {
match &linear_ops[sched_ptr] {
LinearSchedOp::Range { range } => {
let bounds = if let Some(bounds) = range_bounds.get(&range.id).copied() {
bounds
} else {
let bounds = schedule_range_bounds(range)?;
range_bounds.insert(range.id, bounds);
bounds
};
// (Re)entering the loop resets the counter to vmin; this also happens
// each time an enclosing loop jumps back past this RANGE.
in_ranges.insert(range.id, bounds.0);
range_ptrs.insert(range.id, sched_ptr + 1);
}
LinearSchedOp::End { range, kernel_id } => {
if !bound_ranges_by_kernel.contains_key(kernel_id) {
return IrConstructionSnafu {
details: format!("linear END references unknown CALL id {kernel_id}"),
}
.fail();
}
let (_, vmax) = if let Some(bounds) = range_bounds.get(&range.id).copied() {
bounds
} else {
let bounds = schedule_range_bounds(range)?;
range_bounds.insert(range.id, bounds);
bounds
};
let Some(cur) = in_ranges.get_mut(&range.id) else {
return IrConstructionSnafu {
details: format!("END references schedule range {} that is not active", range.id),
}
.fail();
};
// vmax is inclusive: keep looping while the counter is still below it.
if *cur < vmax {
*cur += 1;
let Some(jump_ptr) = range_ptrs.get(&range.id).copied() else {
return IrConstructionSnafu {
details: format!("missing loop jump pointer for schedule range {}", range.id),
}
.fail();
};
// Jump back to the loop body; `continue` skips the increment at the bottom.
sched_ptr = jump_ptr;
continue;
}
// Loop finished: fall through. The counter entry is left in `in_ranges`
// holding vmax.
}
LinearSchedOp::Call { kernel_id } => {
let Some(bound_ranges) = bound_ranges_by_kernel.get(kernel_id) else {
return IrConstructionSnafu {
details: format!("linear CALL references unknown kernel id {kernel_id}"),
}
.fail();
};
// Snapshot the current counter of every range this kernel binds.
let mut fixedvars = HashMap::new();
for br in *bound_ranges {
let Some(value) = in_ranges.get(&br.range_uop.id).copied() else {
return IrConstructionSnafu {
details: format!(
"CALL {} bound variable '{}' references inactive schedule range {}",
kernel_id, br.var_name, br.range_uop.id
),
}
.fail();
};
fixedvars.insert(br.var_name.clone(), value);
}
invocations.push(KernelInvocation { kernel_id: *kernel_id, fixedvars });
}
}
sched_ptr += 1;
}
Ok(invocations)
}
/// Computes, for every callable, the set of indices (into `callables`) of the
/// callables it depends on.
///
/// Two kinds of edges are recorded:
/// 1. data edges, from the producers reachable through each CALL argument;
/// 2. ordering edges, from AFTER nodes under `root`: every kernel listed
///    directly on an AFTER depends on the producers reachable through that
///    AFTER's nested AFTER dependencies.
///
/// Self-edges are never recorded.
///
/// # Errors
/// Fails when a dependency refers to a callable id not in `callables`, or when
/// the dependency structure is malformed.
fn analyze_callable_dependencies(callables: &[Arc<UOp>], root: &Arc<UOp>) -> Result<Vec<HashSet<usize>>> {
    /// Resolves producer ids to indices and records them as dependencies of `consumer_idx`.
    fn record_producers(
        callable_idx: &HashMap<u64, usize>,
        dependencies: &mut [HashSet<usize>],
        consumer_idx: usize,
        dep_ids: HashSet<u64>,
    ) -> Result<()> {
        for dep_id in dep_ids {
            let Some(&producer_idx) = callable_idx.get(&dep_id) else {
                return IrConstructionSnafu {
                    details: format!("callable dependency references unknown callable id {dep_id}"),
                }
                .fail();
            };
            if producer_idx != consumer_idx {
                dependencies[consumer_idx].insert(producer_idx);
            }
        }
        Ok(())
    }

    let callable_idx: HashMap<u64, usize> = callables.iter().enumerate().map(|(i, c)| (c.id, i)).collect();
    let mut dependencies: Vec<HashSet<usize>> = vec![HashSet::new(); callables.len()];

    // Data edges: producers reachable through each callable's arguments.
    for (consumer_idx, callable) in callables.iter().enumerate() {
        let mut dep_ids = HashSet::new();
        for src in callable_sources(callable).into_iter().flatten() {
            collect_source_dependency_callable_ids(&src, &mut dep_ids)?;
        }
        record_producers(&callable_idx, &mut dependencies, consumer_idx, dep_ids)?;
    }

    // Ordering edges: kernels attached to an AFTER depend on its nested AFTERs.
    for node in root.toposort() {
        if !matches!(node.op(), Op::After { .. }) {
            continue;
        }
        let (kernels, after_deps) = split_after_dependencies(&node)?;
        for kernel in kernels {
            let callable_id = match kernel.op() {
                Op::Call { .. } => kernel.id,
                Op::End { computation, .. } => computation.id,
                _ => unreachable!("split_after_dependencies only returns CALL/END(CALL) kernels"),
            };
            let Some(&consumer_idx) = callable_idx.get(&callable_id) else {
                return IrConstructionSnafu {
                    details: format!("AFTER dependency references unknown callable id {}", callable_id),
                }
                .fail();
            };
            let mut dep_ids = HashSet::new();
            for dep in &after_deps {
                collect_source_dependency_callable_ids(dep, &mut dep_ids)?;
            }
            record_producers(&callable_idx, &mut dependencies, consumer_idx, dep_ids)?;
        }
    }
    Ok(dependencies)
}
/// Caller-supplied buffers, keyed by the id of the buffer UOp they back
/// (the scheduler looks entries up by `buf_uop().id`).
pub type InputBuffers = HashMap<u64, Buffer>;
/// A CALL argument that binds a variable to a scheduled RANGE.
#[derive(Clone, Debug)]
pub struct BoundRangeRef {
/// Name taken from the DEFINE_VAR wrapped by the BIND.
pub var_name: String,
/// The RANGE UOp the variable is bound to.
pub range_uop: Arc<UOp>,
}
/// One entry of the linear control stream interpreted by
/// `collect_kernel_invocations`.
#[derive(Clone, Debug)]
enum LinearSchedOp {
/// Start of a schedule-level loop over `range`.
Range { range: Arc<UOp> },
/// Invocation of the kernel whose CALL UOp has id `kernel_id`.
Call { kernel_id: u64 },
/// End of the loop over `range`; `kernel_id` is the CALL wrapped by the END node.
End { range: Arc<UOp>, kernel_id: u64 },
}
/// A single execution of a kernel, produced by unrolling the schedule loops.
#[derive(Clone, Debug)]
pub struct KernelInvocation {
/// Id of the CALL UOp being invoked.
pub kernel_id: u64,
/// Loop-counter value of each bound range at the time of the invocation,
/// keyed by bound variable name.
pub fixedvars: HashMap<String, i64>,
}
/// A fully instantiated kernel invocation: the kernel, its buffers and the
/// variable values it runs with.
#[derive(Clone)]
pub struct ScheduleItem {
/// The CALL UOp this item was created from.
pub kernel: Arc<UOp>,
/// Body of the CALL.
pub ast: Arc<UOp>,
/// Buffers resolved for the CALL's sources, in source order (BIND sources contribute none).
pub buffers: Vec<Buffer>,
/// Buffer UOp ids, parallel to `buffers`.
pub buffer_uop_ids: Vec<u64>,
/// Caller-provided variable values used by the AST, overlaid with this invocation's loop values.
pub fixedvars: HashMap<String, i64>,
/// Names of the variables whose values come from schedule loops for this invocation.
pub loop_var_names: HashSet<String>,
/// Ids of the callables this kernel depends on.
pub dependencies: Vec<u64>,
/// Left empty by `instantiate_schedule`; presumably filled in by a later pass — TODO confirm.
pub instance_dependencies: Vec<usize>,
/// Ids of source UOps that resolved to a buffer registered under a different (canonical) id.
pub alias_registered_ids: Vec<u64>,
}
/// Ordered list of instantiated kernel invocations.
pub type Schedule = Vec<ScheduleItem>;
/// Buffer-independent description of one kernel, produced by
/// `create_pre_schedule`.
#[derive(Clone)]
pub struct PreScheduleItem {
/// The CALL UOp.
pub kernel: Arc<UOp>,
/// Body of the CALL.
pub ast: Arc<UOp>,
/// Arguments of the CALL.
pub sources: Vec<Arc<UOp>>,
/// Ids of the callables this kernel depends on, sorted ascending.
pub dependencies: Vec<u64>,
/// Variables of this kernel that are bound to scheduled RANGEs.
pub bound_ranges: Vec<BoundRangeRef>,
}
/// Result of `create_pre_schedule`: everything needed to build a schedule
/// except the concrete buffers and caller-provided variable values.
#[derive(Clone)]
pub struct PreSchedule {
/// One entry per callable, in dependency order.
pub items: Vec<PreScheduleItem>,
/// Kernel invocations in execution order, with schedule loops already unrolled.
pub invocations: Vec<KernelInvocation>,
/// `buf_uop()` of each SINK source, or of the root itself when the root is not a SINK.
pub output_buffer_uops: Vec<Arc<UOp>>,
}
/// Return type of `sort_callables_by_dependencies`: the callables in
/// dependency order, plus a map from callable id to the sorted ids of the
/// callables it depends on.
type SortedCallables = (Vec<Arc<UOp>>, HashMap<u64, Vec<u64>>);
/// Output of `instantiate_schedule`.
pub struct ScheduleResult {
/// The instantiated schedule, one item per kernel invocation.
pub items: Schedule,
/// Buffer UOp ids of the schedule's outputs.
pub output_uop_ids: Vec<u64>,
}
/// Buffers resolved for one callable by `collect_callable_buffers`.
struct CallableBuffers {
/// Resolved buffers, in source order.
buffers: Vec<Buffer>,
/// Buffer UOp ids, parallel to `buffers`.
uop_ids: Vec<u64>,
/// Sorted, de-duplicated ids of UOps that alias one of the resolved buffers.
alias_ids: Vec<u64>,
}
/// Orders `callables` so that every callable comes after the callables it
/// depends on (Kahn's algorithm).
///
/// Returns the ordered callables together with a map from each callable id to
/// the ascending-sorted ids of its direct dependencies.
///
/// # Errors
/// Propagates errors from `analyze_callable_dependencies` and fails with a
/// dependency-cycle error when not every callable can be ordered.
fn sort_callables_by_dependencies(callables: &[Arc<UOp>], root: &Arc<UOp>) -> Result<SortedCallables> {
    debug!(num_callables = callables.len(), "sorting callables by dependencies");
    let dependencies = analyze_callable_dependencies(callables, root)?;

    // in_degree[i]: number of not-yet-emitted producers of callable i.
    let mut in_degree: Vec<usize> = dependencies.iter().map(|deps| deps.len()).collect();
    // dependents[p]: callables that consume the output of callable p.
    let mut dependents: Vec<Vec<usize>> = vec![vec![]; callables.len()];
    for (consumer, deps) in dependencies.iter().enumerate() {
        for &producer in deps {
            dependents[producer].push(consumer);
        }
    }

    // Seed the queue with every callable that has no producers. The pattern
    // destructures the `(usize, &usize)` pair yielded by `enumerate()`.
    let mut queue: VecDeque<usize> =
        in_degree.iter().enumerate().filter(|&(_, &deg)| deg == 0).map(|(idx, _)| idx).collect();

    let mut sorted_indices = Vec::with_capacity(callables.len());
    while let Some(idx) = queue.pop_front() {
        sorted_indices.push(idx);
        for &dependent in &dependents[idx] {
            in_degree[dependent] -= 1;
            if in_degree[dependent] == 0 {
                queue.push_back(dependent);
            }
        }
    }

    // Anything left unordered is part of a cycle.
    if sorted_indices.len() < callables.len() {
        return DependencyCyclesSnafu.fail();
    }

    let sorted: Vec<Arc<UOp>> = sorted_indices.iter().map(|&idx| callables[idx].clone()).collect();
    let dependency_ids_by_callable: HashMap<u64, Vec<u64>> = callables
        .iter()
        .enumerate()
        .map(|(idx, callable)| {
            let mut deps: Vec<u64> = dependencies[idx].iter().map(|&dep_idx| callables[dep_idx].id).collect();
            deps.sort_unstable();
            (callable.id, deps)
        })
        .collect();
    debug!(num_sorted = sorted.len(), "callables sorted");
    Ok((sorted, dependency_ids_by_callable))
}
pub fn create_pre_schedule(transformed: Arc<UOp>) -> Result<PreSchedule> {
let mut callables = Vec::new();
for node in transformed.toposort_call_aware(false) {
if matches!(node.op(), Op::Call { .. }) {
callables.push(node);
}
}
if callables.is_empty() {
return NoKernelsFoundSnafu.fail();
}
let (callables, dependency_ids_by_callable) = sort_callables_by_dependencies(&callables, &transformed)?;
let callable_ids: HashSet<u64> = callables.iter().map(|c| c.id).collect();
let scheduled_range_ids = collect_scheduled_range_ids(&transformed, &callable_ids);
let mut items = Vec::with_capacity(callables.len());
for callable_uop in callables {
let Op::Call { body, args, .. } = callable_uop.op() else {
unreachable!("filtered to only call wrappers above")
};
let dependencies = dependency_ids_by_callable.get(&callable_uop.id).cloned().unwrap_or_default();
let bound_ranges = collect_call_bound_ranges(&callable_uop, &scheduled_range_ids)?;
items.push(PreScheduleItem {
kernel: callable_uop.clone(),
ast: body.clone(),
sources: args.iter().cloned().collect(),
dependencies,
bound_ranges,
});
}
let invocations = collect_kernel_invocations(&transformed, &items, &scheduled_range_ids)?;
let output_buffer_uops: Vec<Arc<UOp>> = match transformed.op() {
Op::Sink { sources, .. } => sources.iter().map(|src| src.buf_uop()).collect(),
_ => vec![transformed.buf_uop()],
};
Ok(PreSchedule { items, invocations, output_buffer_uops })
}
/// Turns a [`PreSchedule`] into a concrete schedule by resolving buffers and
/// variable values.
///
/// For every pre-schedule item a template is built holding its resolved
/// buffers and the subset of `var_vals` whose names appear as DEFINE_VAR in the
/// item's AST. Each invocation then clones its kernel's template and overlays
/// the invocation's loop-variable values on top of those base values.
///
/// # Errors
/// Propagates buffer-resolution errors, fails when an invocation refers to an
/// unknown kernel id, and fails with an empty-schedule error when there are no
/// invocations.
pub fn instantiate_schedule(
    pre_schedule: &PreSchedule,
    input_buffers: &InputBuffers,
    var_vals: &HashMap<String, i64>,
) -> Result<ScheduleResult> {
    // Buffers allocated so far, shared across items so producers and consumers
    // see the same buffer for the same UOp id.
    let mut allocated_buffers: HashMap<u64, Buffer> = HashMap::new();
    let mut templates: HashMap<u64, ScheduleItemTemplate> = HashMap::with_capacity(pre_schedule.items.len());

    for item in &pre_schedule.items {
        let kb = collect_callable_buffers(&item.sources, &item.ast, input_buffers, &mut allocated_buffers)?;
        debug!(callable.id = item.kernel.id, num_sources = item.sources.len(), "Schedule item created");

        let fixedvars: HashMap<String, i64> = if var_vals.is_empty() {
            HashMap::new()
        } else {
            // The AST is only walked when there are values to filter; with no
            // `var_vals` the traversal result would be unused.
            let nodes = item.ast.toposort();
            let ast_var_names: HashSet<&str> = nodes
                .iter()
                .filter_map(|n| match n.op() {
                    Op::DefineVar { name, .. } => Some(name.as_str()),
                    _ => None,
                })
                .collect();
            var_vals
                .iter()
                .filter(|(name, _)| ast_var_names.contains(name.as_str()))
                .map(|(k, v)| (k.clone(), *v))
                .collect()
        };

        templates.insert(
            item.kernel.id,
            ScheduleItemTemplate {
                kernel: item.kernel.clone(),
                ast: item.ast.clone(),
                buffers: kb.buffers,
                buffer_uop_ids: kb.uop_ids,
                dependencies: item.dependencies.clone(),
                alias_registered_ids: kb.alias_ids,
                base_fixedvars: fixedvars,
            },
        );
    }

    let mut schedule = Vec::with_capacity(pre_schedule.invocations.len());
    for invocation in &pre_schedule.invocations {
        let Some(template) = templates.get(&invocation.kernel_id) else {
            return IrConstructionSnafu {
                details: format!("invocation references unknown kernel id {}", invocation.kernel_id),
            }
            .fail();
        };
        // Loop values take precedence over caller-provided values of the same name.
        let mut fixedvars = template.base_fixedvars.clone();
        fixedvars.extend(invocation.fixedvars.iter().map(|(k, v)| (k.clone(), *v)));
        let loop_var_names: HashSet<String> = invocation.fixedvars.keys().cloned().collect();
        schedule.push(ScheduleItem {
            kernel: template.kernel.clone(),
            ast: template.ast.clone(),
            buffers: template.buffers.clone(),
            buffer_uop_ids: template.buffer_uop_ids.clone(),
            fixedvars,
            loop_var_names,
            dependencies: template.dependencies.clone(),
            instance_dependencies: Vec::new(),
            alias_registered_ids: template.alias_registered_ids.clone(),
        });
    }

    if schedule.is_empty() {
        return EmptyScheduleSnafu.fail();
    }
    let output_uop_ids: Vec<u64> = pre_schedule.output_buffer_uops.iter().map(|u| u.buf_uop().id).collect();
    Ok(ScheduleResult { items: schedule, output_uop_ids })
}
/// Convenience wrapper: builds the pre-schedule for `transformed` and
/// immediately instantiates it with the given buffers and variable values.
///
/// # Errors
/// Propagates any error from `create_pre_schedule` or `instantiate_schedule`.
pub fn create_schedule(
    transformed: Arc<UOp>,
    input_buffers: &InputBuffers,
    var_vals: &HashMap<String, i64>,
) -> Result<ScheduleResult> {
    create_pre_schedule(transformed).and_then(|pre| instantiate_schedule(&pre, input_buffers, var_vals))
}
/// Picks the device on which new buffers for a callable are allocated.
///
/// Sources are scanned in order; the first one whose primary buffer is already
/// known (allocated buffers are consulted before input buffers) and whose
/// device spec is not a disk decides the device. When no source qualifies, the
/// CPU device is used.
///
/// # Errors
/// Fails when the device factory cannot produce the selected device.
fn find_first_input_buffer_device(
    sources: &[Arc<UOp>],
    input_buffers: &InputBuffers,
    allocated_buffers: &HashMap<u64, Buffer>,
) -> Result<Arc<Device>> {
    let alloc_registry = registry::registry();
    for src in sources {
        let Some(buf_id) = source_primary_buffer_id(src) else {
            continue;
        };
        let known = allocated_buffers.get(&buf_id).cloned().or_else(|| input_buffers.get(&buf_id).cloned());
        let Some(buffer) = known else {
            continue;
        };
        let device_spec = buffer.allocator().device_spec();
        if !device_spec.is_disk() {
            return svod_runtime::DEVICE_FACTORIES.device(&device_spec, alloc_registry).context(DeviceFactorySnafu);
        }
    }
    svod_runtime::DEVICE_FACTORIES.device(&DeviceSpec::Cpu, alloc_registry).context(DeviceFactorySnafu)
}
/// Resolves (and, where necessary, allocates) the buffer behind every source of
/// a callable.
///
/// Sources are processed in order. Each is canonicalized, then handled by op:
/// * AFTER / MSELECT / MSTACK: the buffer must already exist in
///   `allocated_buffers` or `input_buffers`, otherwise `BufferNotFound` is returned.
/// * DEFINE_LOCAL: a fresh buffer is allocated on the target device.
/// * BUFFER / PARAM: an existing input or allocated buffer is reused; if there
///   is none, a new buffer is allocated and recorded in `allocated_buffers`.
/// * BIND: contributes no buffer.
///
/// Ids of UOps that resolve to a buffer registered under a different id are
/// collected as aliases (sorted and de-duplicated in the result).
///
/// # Errors
/// Returns `BufferNotFound` as described above and an IR-construction error for
/// unsupported source ops or multi-device sources without a primary buffer id.
fn collect_callable_buffers(
sources: &[Arc<UOp>],
ast: &Arc<UOp>,
input_buffers: &InputBuffers,
allocated_buffers: &mut HashMap<u64, Buffer>,
) -> Result<CallableBuffers> {
// Device used for any buffer this function has to allocate.
let target_device = find_first_input_buffer_device(sources, input_buffers, allocated_buffers)?;
let mut buffers = Vec::new();
let mut uop_ids = Vec::new();
let mut alias_ids = Vec::new();
for src in sources {
let canonical_src = canonicalize_callable_source(src);
// The original source is an alias whenever canonicalization moved to another node.
if canonical_src.id != src.id {
alias_ids.push(src.id);
}
match canonical_src.op() {
Op::After { passthrough, .. } => {
// The buffer is identified by the passthrough's buffer UOp.
let buf_id = passthrough.buf_uop().id;
if buf_id != canonical_src.id {
alias_ids.push(canonical_src.id);
}
let existing = allocated_buffers.get(&buf_id).cloned().or_else(|| input_buffers.get(&buf_id).cloned());
if let Some(buffer) = existing {
trace!(
buf_id,
buffer.id = ?buffer.id(),
"Found shared buffer from AFTER"
);
// Make an input buffer visible to later lookups in `allocated_buffers` too.
allocated_buffers.entry(buf_id).or_insert_with(|| buffer.clone());
buffers.push(buffer);
uop_ids.push(buf_id);
} else {
trace!(buf_id, "after buffer not found in allocated_buffers or input_buffers");
return Err(Error::BufferNotFound { uop_id: buf_id });
}
}
Op::MSelect { .. } | Op::MStack { .. } => {
let Some(canonical_id) = source_primary_buffer_id(&canonical_src) else {
return IrConstructionSnafu {
details: format!(
"multi-device callable source must resolve a primary buffer id: source_id={}, op={:?}",
canonical_src.id,
canonical_src.op()
),
}
.fail();
};
if canonical_id != canonical_src.id {
alias_ids.push(canonical_src.id);
}
let existing =
allocated_buffers.get(&canonical_id).cloned().or_else(|| input_buffers.get(&canonical_id).cloned());
if let Some(buffer) = existing {
trace!(canonical_id, buffer.id = ?buffer.id(), "Found shared buffer from MSELECT/MSTACK source");
allocated_buffers.entry(canonical_id).or_insert_with(|| buffer.clone());
buffers.push(buffer);
uop_ids.push(canonical_id);
} else {
trace!(canonical_id, "multi-device source buffer not found in allocated_buffers or input_buffers");
return Err(Error::BufferNotFound { uop_id: canonical_id });
}
}
Op::DefineLocal(_id) => {
// Always allocates a new buffer; element type and size come from the pointer dtype.
let ptr_dtype = canonical_src.dtype();
let size = compute_buffer_size(ast, &canonical_src)?;
let scalar_dtype = match ptr_dtype {
svod_dtype::DType::Ptr { base, .. } => *base,
other => {
return ExpectedPtrDtypeSnafu { context: "DEFINE_LOCAL", actual: other.clone() }.fail();
}
};
let buffer =
Buffer::new(target_device.allocator.clone(), scalar_dtype.clone(), vec![size], Default::default());
allocated_buffers.insert(canonical_src.id, buffer.clone());
buffers.push(buffer);
uop_ids.push(canonical_src.id);
}
Op::Buffer { size, .. } | Op::Param { size, .. } => {
let canonical_id = canonical_src.buf_uop().id;
if canonical_id != canonical_src.id {
alias_ids.push(canonical_src.id);
}
// Lookup order: input buffers first, then previously allocated buffers;
// each is tried under the buffer UOp id and then under the node's own id.
if let Some(buffer) =
input_buffers.get(&canonical_id).cloned().or_else(|| input_buffers.get(&canonical_src.id).cloned())
{
buffers.push(buffer);
uop_ids.push(canonical_id);
} else if let Some(buffer) = allocated_buffers
.get(&canonical_id)
.cloned()
.or_else(|| allocated_buffers.get(&canonical_src.id).cloned())
{
buffers.push(buffer);
uop_ids.push(canonical_id);
} else {
// Not seen before: allocate it now and remember it for later callables.
trace!(src.id = canonical_src.id, canonical_id, size, "Allocating output BUFFER/PARAM");
let scalar_dtype = canonical_src.dtype();
let buffer = Buffer::new(
target_device.allocator.clone(),
scalar_dtype.clone(),
vec![*size],
Default::default(),
);
allocated_buffers.insert(canonical_id, buffer.clone());
buffers.push(buffer);
uop_ids.push(canonical_id);
}
}
Op::Bind { .. } => {
// Variable bindings carry no buffer.
continue;
}
other => {
return IrConstructionSnafu {
details: format!("unsupported callable source op for buffer collection: {other:?}"),
}
.fail();
}
}
}
alias_ids.sort_unstable();
alias_ids.dedup();
Ok(CallableBuffers { buffers, uop_ids, alias_ids })
}
/// Per-kernel data shared by all invocations of that kernel; cloned into a
/// `ScheduleItem` for each invocation by `instantiate_schedule`.
#[derive(Clone)]
struct ScheduleItemTemplate {
/// The CALL UOp.
kernel: Arc<UOp>,
/// Body of the CALL.
ast: Arc<UOp>,
/// Buffers resolved for the CALL's sources.
buffers: Vec<Buffer>,
/// Buffer UOp ids, parallel to `buffers`.
buffer_uop_ids: Vec<u64>,
/// Ids of the callables this kernel depends on.
dependencies: Vec<u64>,
/// Ids of UOps that alias one of the resolved buffers.
alias_registered_ids: Vec<u64>,
/// Caller-provided variable values whose names appear as DEFINE_VAR in `ast`.
base_fixedvars: HashMap<String, i64>,
}
/// Resolves the concrete `(vmin, vmax)` bounds of a schedule-level RANGE.
///
/// # Errors
/// Fails when `range` is not a RANGE, when either bound is not a concrete
/// integer, or when `vmax` is smaller than `vmin`.
fn schedule_range_bounds(range: &Arc<UOp>) -> Result<(i64, i64)> {
    if !matches!(range.op(), Op::Range { .. }) {
        return IrConstructionSnafu {
            details: format!("expected RANGE for schedule loop control, got {:?}", range.op()),
        }
        .fail();
    }

    let vmin = match range.vmin().try_int() {
        Some(value) => value,
        None => {
            return IrConstructionSnafu {
                details: format!("schedule range vmin must be concrete integer, got {:?}", range.vmin()),
            }
            .fail();
        }
    };
    let vmax = match range.vmax().try_int() {
        Some(value) => value,
        None => {
            return IrConstructionSnafu {
                details: format!("schedule range vmax must be concrete integer, got {:?}", range.vmax()),
            }
            .fail();
        }
    };

    if vmin > vmax {
        return IrConstructionSnafu { details: format!("invalid schedule range bounds: vmin={vmin}, vmax={vmax}") }
            .fail();
    }
    Ok((vmin, vmax))
}
/// Reads the element count of a buffer definition from its pointer dtype.
///
/// `_ast` is accepted for interface compatibility and is not used.
///
/// # Errors
/// Fails when the dtype is not a pointer, or when the pointer dtype carries no
/// size.
fn compute_buffer_size(_ast: &Arc<UOp>, buffer_def: &Arc<UOp>) -> Result<usize> {
    let dtype = buffer_def.dtype();
    let DType::Ptr { size, .. } = dtype else {
        return ExpectedPtrDtypeSnafu { context: "buffer_size", actual: dtype.clone() }.fail();
    };
    match size {
        Some(len) => Ok(len),
        None => BufferPtrNoSizeSnafu.fail(),
    }
}