Skip to main content

svod_runtime/
execution_plan.rs

1//! Pre-compiled execution plan for kernel execution.
2//!
3//! `ExecutionPlan` separates one-time preparation (kernel compilation, buffer
4//! allocation) from fast repeated execution.
5//!
6//! ```text
7//! ┌─────────────────────────────────────────────────────────┐
8//! │              PREPARATION (one-time)                      │
9//! │  Schedule → instantiate → compile_kernels → build()     │
10//! │                       ↓                                  │
11//! │                ExecutionPlan                             │
12//! └─────────────────────────────────────────────────────────┘
13//!                         ↓
14//! ┌─────────────────────────────────────────────────────────┐
15//! │              EXECUTION (fast path)                       │
16//! │  dependency-ordered PreparedOp execution                 │
17//! └─────────────────────────────────────────────────────────┘
18//! ```
19//!
20//! # Example
21//!
22//! ```ignore
23//! let plan = tensor.prepare()?;
24//! plan.execute()?;
25//! let output = plan.output_buffer();
26//! ```
27
28use std::cmp::Reverse;
29use std::collections::{BinaryHeap, HashMap, HashSet};
30use std::sync::Arc;
31use std::time::Instant;
32
33use rayon::prelude::*;
34use smallvec::SmallVec;
35use svod_device::device::ProgramSpec;
36use svod_device::{Buffer, BufferId};
37use svod_dtype::DeviceSpec;
38use svod_ir::{CustomFunctionKind, Op, UOp};
39
40use crate::error::Result;
41use crate::kernel_cache::CachedKernel;
42use crate::profiler::KernelProfile;
43
/// Launch geometry handed to a compiled program: `(global_size, local_size)`.
///
/// Both halves are optional 3-D extents. `ExecutionPlan::kernel_launch_sizes` always fills in the
/// global half; the local half is `None` when no local size was resolved for the kernel.
type RuntimeLaunchSizes = (
    Option<[usize; 3]>, // global size
    Option<[usize; 3]>, // local size
);
45
46// ============================================================================
47// Core Structures
48// ============================================================================
49
50/// A pre-compiled kernel ready for execution.
51///
52/// Variable values are stored as positional `vals: Vec<i64>` rather than a named
53/// HashMap, matching Tinygrad's `vals: tuple[int, ...]` parameter style.
54#[derive(Clone)]
55pub struct PreparedKernel {
56    /// Unique identifier (from original AST).
57    pub id: u64,
58
59    pub ast: Arc<UOp>,
60
61    /// Compiled kernel program (Arc-shared from cache).
62    pub kernel: Arc<CachedKernel>,
63
64    /// Device this kernel executes on.
65    pub device: DeviceSpec,
66
67    /// Indices into `ExecutionPlan::buffers` for this kernel's buffers.
68    /// Ordered as expected by the kernel (matches codegen buffer order).
69    pub buffer_indices: Vec<usize>,
70
71    /// Indices of output buffers within `buffer_indices`.
72    pub output_indices: Vec<usize>,
73
74    /// Variable values in positional order (matches `var_names` in CachedKernel).
75    pub vals: Vec<i64>,
76
77    /// Fixed variable bindings captured at prepare time.
78    ///
79    /// These mirror Tinygrad's `fixedvars` semantics: values fixed by scheduling
80    /// (for example from bound ranges) are not overridden by `execute_with_vars`.
81    pub fixedvars: HashMap<String, i64>,
82
83    /// Kernel IDs that must complete before this one (dependencies).
84    pub dependencies: Vec<u64>,
85
86    /// Pre-computed raw buffer addresses for low-allocation execution.
87    /// Computed once during prepare(), stable for the lifetime of ExecutionPlan.
88    /// SAFETY: Pointers are valid as long as ExecutionPlan owns the buffers.
89    pub buffer_ptrs: Vec<usize>,
90
91    /// Pre-computed buffer IDs for dependency tracking.
92    pub buffer_ids: Vec<BufferId>,
93
94    /// Cached `(name, min_val, max_val)` triples for every `DefineVar` reachable
95    /// from `ast`. Populated at construction so `validate_runtime_var_bounds`
96    /// doesn't re-toposort on every execute call.
97    pub runtime_vars: Vec<RuntimeVar>,
98}
99
/// Declared bounds of one `DefineVar` consumed by a prepared op.
///
/// A runtime value supplied for `name` is accepted only when it lies in the inclusive range
/// `[min_val, max_val]`.
///
/// Equality and hashing are derived so bound lists can be compared or deduplicated.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct RuntimeVar {
    /// Variable name as declared by the `DefineVar` node.
    pub name: String,
    /// Smallest accepted value (inclusive).
    pub min_val: i64,
    /// Largest accepted value (inclusive).
    pub max_val: i64,
}
107
108/// Walk `root` and collect bounds for every reachable `DefineVar`.
109pub fn collect_runtime_vars(root: &Arc<UOp>) -> Vec<RuntimeVar> {
110    let mut vars = Vec::new();
111    let mut seen = std::collections::HashSet::new();
112    for node in root.toposort() {
113        if let Op::DefineVar { name, min_val, max_val } = node.op()
114            && seen.insert(name.clone())
115        {
116            vars.push(RuntimeVar { name: name.clone(), min_val: *min_val, max_val: *max_val });
117        }
118    }
119    vars
120}
121
/// Prepared buffer-to-buffer copy operation.
///
/// Equality is derived so prepared schedules can be compared directly.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct PreparedCopy {
    /// Unique operation identifier; other ops name it in their `dependencies`.
    pub id: u64,

    /// Indices into `ExecutionPlan::buffers`: `[dst, src]`. Execution requires at least two
    /// entries and ignores any beyond the second.
    pub buffer_indices: Vec<usize>,

    /// Operation IDs that must complete before this copy.
    pub dependencies: Vec<u64>,
}
134
/// Prepared zero-copy buffer view operation.
///
/// Moves no data: executing it only verifies that the output buffer aliases the base buffer's
/// storage at the expected byte window. Equality is derived so prepared schedules can be compared.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct PreparedBufferView {
    /// Unique operation identifier; other ops name it in their `dependencies`.
    pub id: u64,

    /// Indices into `ExecutionPlan::buffers`: `[out, base]` — the view buffer first, then the
    /// buffer it aliases.
    pub buffer_indices: Vec<usize>,

    /// Expected byte offset of the view, relative to the base buffer's own offset.
    pub byte_offset: usize,

    /// Expected byte size of the view.
    pub byte_size: usize,

    /// Operation IDs that must complete before this view is consumed.
    pub dependencies: Vec<u64>,
}
154
/// Prepared custom runtime function operation.
///
/// At execute time the referenced buffers are cloned out of the plan and passed, together with
/// `attrs` and the merged variable bindings, to `crate::custom_function::run_custom_function`.
#[derive(Clone, Debug)]
pub struct PreparedCustomFunction {
    /// Unique operation identifier; other ops name it in their `dependencies`.
    pub id: u64,

    /// Explicit custom function kind (for example: `EncDec`).
    pub kind: CustomFunctionKind,

    /// Runtime descriptor attributes encoded by the IR body.
    pub attrs: SmallVec<[Arc<UOp>; 4]>,

    /// Indices into `ExecutionPlan::buffers`, in the order the buffers are passed to the function.
    pub buffer_indices: Vec<usize>,

    /// Variable values bound for this operation. They are layered over the plan's runtime
    /// bindings (so they win over same-named values from `execute_with_vars`), and variables named
    /// here are skipped by runtime bounds validation.
    pub fixedvars: HashMap<String, i64>,

    /// Operation IDs that must complete before this custom function runs.
    pub dependencies: Vec<u64>,

    /// Cached `(name, min_val, max_val)` triples for every `DefineVar`
    /// reachable from `attrs`. Populated at construction so
    /// `validate_runtime_var_bounds` doesn't re-toposort on every execute call.
    pub runtime_vars: Vec<RuntimeVar>,
}
181
182/// Prepared execution item.
183#[derive(Clone, Debug)]
184pub enum PreparedOp {
185    /// Compiled kernel/program operation.
186    CompiledProgram(PreparedKernel),
187
188    /// Direct buffer copy operation.
189    BufferCopy(PreparedCopy),
190
191    /// Zero-copy view aliasing operation.
192    BufferView(PreparedBufferView),
193
194    /// Runtime custom function operation.
195    CustomFunction(PreparedCustomFunction),
196}
197
198fn op_identity(op: &PreparedOp) -> (u64, Vec<u64>) {
199    match op {
200        PreparedOp::CompiledProgram(kernel) => (kernel.id, kernel.dependencies.clone()),
201        PreparedOp::BufferCopy(copy) => (copy.id, copy.dependencies.clone()),
202        PreparedOp::BufferView(view) => (view.id, view.dependencies.clone()),
203        PreparedOp::CustomFunction(custom) => (custom.id, custom.dependencies.clone()),
204    }
205}
206
207fn validate_var_bound(name: &str, value: i64, min_val: i64, max_val: i64) -> Result<()> {
208    if value < min_val || value > max_val {
209        return Err(crate::error::Error::Execution {
210            reason: format!("variable {name}={value} is outside bounds [{min_val}, {max_val}]"),
211        });
212    }
213    Ok(())
214}
215
/// Adjacency-list form of the prepared-op dependency DAG, indexed by op position.
struct DependencyGraph {
    /// `successors[i]` lists the positions of ops waiting on op `i` (one entry per edge).
    successors: Vec<Vec<usize>>,
    /// Count of incoming edges for each op position.
    in_degree: Vec<usize>,
    /// Identifier of the op at each position, used when reporting errors.
    op_ids: Vec<u64>,
}
221
222fn build_dependency_graph(ops: &[PreparedOp], instance_deps_per_op: Option<&[Vec<usize>]>) -> Result<DependencyGraph> {
223    if let Some(instance_deps) = instance_deps_per_op
224        && instance_deps.len() != ops.len()
225    {
226        return Err(crate::error::Error::Execution {
227            reason: format!(
228                "prepared op instance dependency table length mismatch: ops={}, instance_deps={}",
229                ops.len(),
230                instance_deps.len()
231            ),
232        });
233    }
234
235    let mut op_ids = Vec::with_capacity(ops.len());
236    let mut deps_per_op = Vec::with_capacity(ops.len());
237    let mut id_counts: HashMap<u64, usize> = HashMap::with_capacity(ops.len());
238
239    for op in ops {
240        let (op_id, deps) = op_identity(op);
241        op_ids.push(op_id);
242        deps_per_op.push(deps);
243        *id_counts.entry(op_id).or_insert(0) += 1;
244    }
245
246    let has_duplicate_ids = id_counts.values().any(|&count| count > 1);
247
248    let mut in_degree = vec![0usize; ops.len()];
249    let mut successors: Vec<Vec<usize>> = vec![Vec::new(); ops.len()];
250
251    if !has_duplicate_ids {
252        let mut id_to_idx: HashMap<u64, usize> = HashMap::with_capacity(ops.len());
253        for (idx, &op_id) in op_ids.iter().enumerate() {
254            id_to_idx.insert(op_id, idx);
255        }
256
257        for (idx, deps) in deps_per_op.iter().enumerate() {
258            for dep in deps {
259                let Some(&dep_idx) = id_to_idx.get(dep) else {
260                    return Err(crate::error::Error::Execution {
261                        reason: format!("prepared op {} depends on unknown op id {}", op_ids[idx], dep),
262                    });
263                };
264                in_degree[idx] += 1;
265                successors[dep_idx].push(idx);
266            }
267        }
268    } else {
269        // Expanded schedules may contain repeated op IDs for per-iteration items.
270        // Resolve dependencies against the most recent prior op with that ID.
271        let mut last_seen: HashMap<u64, usize> = HashMap::with_capacity(ops.len());
272
273        for (idx, deps) in deps_per_op.iter().enumerate() {
274            for dep in deps {
275                let Some(&dep_idx) = last_seen.get(dep) else {
276                    return Err(crate::error::Error::Execution {
277                        reason: format!(
278                            "prepared op {} depends on unknown prior op id {} (duplicate-id schedule mode)",
279                            op_ids[idx], dep
280                        ),
281                    });
282                };
283                in_degree[idx] += 1;
284                successors[dep_idx].push(idx);
285            }
286
287            last_seen.insert(op_ids[idx], idx);
288        }
289    }
290
291    if let Some(instance_deps_per_op) = instance_deps_per_op {
292        for (idx, instance_deps) in instance_deps_per_op.iter().enumerate() {
293            for &dep_idx in instance_deps {
294                if dep_idx >= ops.len() {
295                    return Err(crate::error::Error::Execution {
296                        reason: format!("prepared op {} depends on unknown op index {}", op_ids[idx], dep_idx),
297                    });
298                }
299                if dep_idx == idx {
300                    return Err(crate::error::Error::Execution {
301                        reason: format!("prepared op {} cannot depend on itself by op index {}", op_ids[idx], dep_idx),
302                    });
303                }
304                in_degree[idx] += 1;
305                successors[dep_idx].push(idx);
306            }
307        }
308    }
309
310    Ok(DependencyGraph { op_ids, in_degree, successors })
311}
312
#[cfg(test)]
fn compute_mixed_op_order(ops: &[PreparedOp]) -> Result<Vec<usize>> {
    // Test-only shorthand: id-based dependencies only, no positional dependency table.
    let no_instance_deps: &[Vec<usize>] = &[];
    compute_mixed_op_order_with_instance_dependencies(ops, no_instance_deps)
}
317
318fn compute_mixed_op_order_with_instance_dependencies(
319    ops: &[PreparedOp],
320    instance_deps_per_op: &[Vec<usize>],
321) -> Result<Vec<usize>> {
322    let instance_deps = (!instance_deps_per_op.is_empty()).then_some(instance_deps_per_op);
323    let DependencyGraph { op_ids, mut in_degree, successors } = build_dependency_graph(ops, instance_deps)?;
324
325    let mut ready: BinaryHeap<Reverse<usize>> = BinaryHeap::new();
326    for (idx, &deg) in in_degree.iter().enumerate() {
327        if deg == 0 {
328            ready.push(Reverse(idx));
329        }
330    }
331
332    let mut order = Vec::with_capacity(ops.len());
333    while let Some(Reverse(idx)) = ready.pop() {
334        order.push(idx);
335        for &succ in &successors[idx] {
336            in_degree[succ] -= 1;
337            if in_degree[succ] == 0 {
338                ready.push(Reverse(succ));
339            }
340        }
341    }
342
343    if order.len() != ops.len() {
344        let blocked: Vec<u64> = in_degree
345            .iter()
346            .enumerate()
347            .filter_map(|(idx, &deg)| if deg > 0 { Some(op_ids[idx]) } else { None })
348            .collect();
349        return Err(crate::error::Error::Execution {
350            reason: format!("cycle detected in prepared op dependencies: blocked_ids={blocked:?}"),
351        });
352    }
353
354    Ok(order)
355}
356
#[cfg(test)]
fn compute_execution_levels(ops: &[PreparedOp]) -> Result<Vec<Vec<usize>>> {
    // Test-only shorthand: id-based dependencies only, no positional dependency table.
    let no_instance_deps: &[Vec<usize>] = &[];
    compute_execution_levels_with_instance_dependencies(ops, no_instance_deps)
}
361
362fn compute_execution_levels_with_instance_dependencies(
363    ops: &[PreparedOp],
364    instance_deps_per_op: &[Vec<usize>],
365) -> Result<Vec<Vec<usize>>> {
366    let instance_deps = (!instance_deps_per_op.is_empty()).then_some(instance_deps_per_op);
367    let DependencyGraph { op_ids, mut in_degree, successors } = build_dependency_graph(ops, instance_deps)?;
368
369    let mut ready: BinaryHeap<Reverse<usize>> = BinaryHeap::new();
370    for (idx, &deg) in in_degree.iter().enumerate() {
371        if deg == 0 {
372            ready.push(Reverse(idx));
373        }
374    }
375
376    let mut levels: Vec<Vec<usize>> = Vec::new();
377    let mut visited = 0usize;
378
379    while !ready.is_empty() {
380        let mut level: Vec<usize> = Vec::new();
381        while let Some(Reverse(idx)) = ready.pop() {
382            level.push(idx);
383        }
384
385        let mut next_ready: BinaryHeap<Reverse<usize>> = BinaryHeap::new();
386        for &idx in &level {
387            visited += 1;
388            for &succ in &successors[idx] {
389                in_degree[succ] -= 1;
390                if in_degree[succ] == 0 {
391                    next_ready.push(Reverse(succ));
392                }
393            }
394        }
395
396        levels.push(level);
397        ready = next_ready;
398    }
399
400    if visited != ops.len() {
401        let blocked: Vec<u64> = in_degree
402            .iter()
403            .enumerate()
404            .filter_map(|(idx, &deg)| if deg > 0 { Some(op_ids[idx]) } else { None })
405            .collect();
406        return Err(crate::error::Error::Execution {
407            reason: format!("cycle detected in prepared op dependencies: blocked_ids={blocked:?}"),
408        });
409    }
410
411    Ok(levels)
412}
413
414/// Pre-compiled execution plan for a computation graph.
415///
416/// Created once via `prepare()`, then executed multiple times.
417/// The plan owns all its buffers and compiled kernels.
418pub struct ExecutionPlan {
419    /// Prepared operations in schedule order.
420    ops: Vec<PreparedOp>,
421
422    /// Concrete op-index dependencies parallel to `ops`.
423    op_instance_dependencies: Vec<Vec<usize>>,
424
425    /// Precomputed dependency-safe operation order.
426    op_order: Vec<usize>,
427
428    /// Topological levels of dependency-independent operations.
429    op_levels: Vec<Vec<usize>>,
430
431    /// ALL buffers owned by this plan (inputs, intermediates, outputs).
432    buffers: Vec<Buffer>,
433
434    /// Mapping: AST id → buffer index (for kernel buffer binding).
435    ast_to_buffer: HashMap<u64, usize>,
436
437    /// Indices of output buffers in `buffers` (matches SINK source order).
438    output_buffer_indices: Vec<usize>,
439
440    /// Primary device for this plan.
441    device: DeviceSpec,
442
443    /// Last dynamic variable bindings supplied through `execute_with_vars`.
444    runtime_var_vals: HashMap<String, i64>,
445
446    /// Additional UOp IDs registered as aliases that need cleanup.
447    alias_ids: Vec<u64>,
448}
449
450// ============================================================================
451// ExecutionPlan Implementation
452// ============================================================================
453
454impl ExecutionPlan {
455    fn kernel_launch_sizes(kernel: &PreparedKernel) -> Result<RuntimeLaunchSizes> {
456        let mut vars: HashMap<&str, i64> =
457            HashMap::with_capacity(kernel.kernel.var_names.len() + kernel.fixedvars.len());
458        for (idx, name) in kernel.kernel.var_names.iter().enumerate() {
459            let value = kernel.vals.get(idx).copied().ok_or_else(|| crate::error::Error::Execution {
460                reason: format!(
461                    "Kernel {} has {} var names but only {} values",
462                    kernel.id,
463                    kernel.kernel.var_names.len(),
464                    kernel.vals.len()
465                ),
466            })?;
467            vars.insert(name.as_str(), value);
468        }
469        for (name, value) in &kernel.fixedvars {
470            vars.insert(name.as_str(), *value);
471        }
472
473        let dims =
474            ProgramSpec::resolve_launch_dims(&kernel.kernel.global_size, kernel.kernel.local_size.as_ref(), &vars)
475                .map_err(|e| crate::error::Error::Execution {
476                    reason: format!("Kernel {} launch dimensions failed: {e}", kernel.id),
477                })?;
478        Ok((Some(dims.global_size), dims.local_size))
479    }
480
481    fn kernel_uses_cpu_threading(kernel: &PreparedKernel) -> Result<bool> {
482        if !matches!(kernel.device, DeviceSpec::Cpu) {
483            return Ok(false);
484        }
485        let (global_size, _) = Self::kernel_launch_sizes(kernel)?;
486        Ok(global_size.map(|[x, _, _]| x > 1).unwrap_or(false))
487    }
488
489    #[inline]
490    fn execute_kernel(kernel: &PreparedKernel) -> Result<()> {
491        let buffer_ptrs: SmallVec<[*mut u8; 8]> = kernel.buffer_ptrs.iter().map(|&ptr| ptr as *mut u8).collect();
492        let (global_size, local_size) = Self::kernel_launch_sizes(kernel)?;
493        unsafe {
494            kernel
495                .kernel
496                .program
497                .execute(&buffer_ptrs, &kernel.vals, global_size, local_size)
498                .map_err(|e| crate::error::Error::Execution { reason: format!("Kernel {} failed: {}", kernel.id, e) })
499        }
500    }
501
502    fn validate_runtime_var_bounds(&self, var_vals: &[(&str, i64)]) -> Result<()> {
503        let vals_map: HashMap<&str, i64> = var_vals.iter().copied().collect();
504        for op in &self.ops {
505            match op {
506                PreparedOp::CompiledProgram(kernel) => {
507                    for var in &kernel.runtime_vars {
508                        if kernel.fixedvars.contains_key(&var.name) || var.name == "core_id" {
509                            continue;
510                        }
511                        if let Some(&value) = vals_map.get(var.name.as_str()) {
512                            validate_var_bound(&var.name, value, var.min_val, var.max_val)?;
513                        }
514                    }
515                }
516                PreparedOp::CustomFunction(custom) => {
517                    for var in &custom.runtime_vars {
518                        if custom.fixedvars.contains_key(&var.name) || var.name == "core_id" {
519                            continue;
520                        }
521                        if let Some(&value) = vals_map.get(var.name.as_str()) {
522                            validate_var_bound(&var.name, value, var.min_val, var.max_val)?;
523                        }
524                    }
525                }
526                PreparedOp::BufferCopy(_) | PreparedOp::BufferView(_) => {}
527            }
528        }
529        Ok(())
530    }
531
532    fn update_runtime_var_vals(&mut self, var_vals: &[(&str, i64)]) -> Result<()> {
533        self.validate_runtime_var_bounds(var_vals)?;
534
535        let vals_map: HashMap<&str, i64> = var_vals.iter().copied().collect();
536        for &(name, value) in var_vals {
537            if name == "core_id" {
538                continue;
539            }
540            self.runtime_var_vals.insert(name.to_string(), value);
541        }
542        for op in &mut self.ops {
543            if let PreparedOp::CompiledProgram(kernel) = op {
544                for (idx, name) in kernel.kernel.var_names.iter().enumerate() {
545                    if kernel.fixedvars.contains_key(name) || name == "core_id" {
546                        continue;
547                    }
548                    if let Some(&v) = vals_map.get(name.as_str()) {
549                        let Some(slot) = kernel.vals.get_mut(idx) else {
550                            return Err(crate::error::Error::Execution {
551                                reason: format!(
552                                    "Kernel {} has {} var names but only {} values",
553                                    kernel.id,
554                                    kernel.kernel.var_names.len(),
555                                    kernel.vals.len()
556                                ),
557                            });
558                        };
559                        *slot = v;
560                    }
561                }
562            }
563        }
564        Ok(())
565    }
566
567    #[inline]
568    fn execute_copy(&self, copy: &PreparedCopy) -> Result<()> {
569        if copy.buffer_indices.len() < 2 {
570            return Err(crate::error::Error::Execution {
571                reason: format!(
572                    "Copy op {} requires at least two buffer indices (dst, src), got {}",
573                    copy.id,
574                    copy.buffer_indices.len()
575                ),
576            });
577        }
578        let dst_idx = copy.buffer_indices[0];
579        let src_idx = copy.buffer_indices[1];
580
581        if dst_idx >= self.buffers.len() || src_idx >= self.buffers.len() {
582            return Err(crate::error::Error::Execution {
583                reason: format!(
584                    "Copy op {} buffer index out of range: dst={}, src={}, total_buffers={}",
585                    copy.id,
586                    dst_idx,
587                    src_idx,
588                    self.buffers.len()
589                ),
590            });
591        }
592
593        let mut dst = self.buffers[dst_idx].clone();
594        let src = &self.buffers[src_idx];
595        dst.copy_from(src)
596            .map_err(|e| crate::error::Error::Execution { reason: format!("Copy op {} failed: {}", copy.id, e) })
597    }
598
599    #[inline]
600    fn execute_buffer_view(&self, view: &PreparedBufferView) -> Result<()> {
601        if view.buffer_indices.len() < 2 {
602            return Err(crate::error::Error::Execution {
603                reason: format!(
604                    "BufferView op {} requires at least two buffer indices (out, base), got {}",
605                    view.id,
606                    view.buffer_indices.len()
607                ),
608            });
609        }
610        let out_idx = view.buffer_indices[0];
611        let base_idx = view.buffer_indices[1];
612
613        if out_idx >= self.buffers.len() || base_idx >= self.buffers.len() {
614            return Err(crate::error::Error::Execution {
615                reason: format!(
616                    "BufferView op {} buffer index out of range: out={}, base={}, total_buffers={}",
617                    view.id,
618                    out_idx,
619                    base_idx,
620                    self.buffers.len()
621                ),
622            });
623        }
624
625        let out = &self.buffers[out_idx];
626        let base = &self.buffers[base_idx];
627        let expected_offset = base.offset() + view.byte_offset;
628
629        if out.storage_id() != base.storage_id() || out.offset() != expected_offset || out.size() != view.byte_size {
630            return Err(crate::error::Error::Execution {
631                reason: format!(
632                    "BufferView op {} mismatch: out(storage={:?},off={},size={}) base(storage={:?},off={}) expected(off={},size={})",
633                    view.id,
634                    out.storage_id(),
635                    out.offset(),
636                    out.size(),
637                    base.storage_id(),
638                    base.offset(),
639                    expected_offset,
640                    view.byte_size,
641                ),
642            });
643        }
644        Ok(())
645    }
646
647    #[inline]
648    fn execute_custom_function(&self, custom: &PreparedCustomFunction) -> Result<()> {
649        let mut buffers = Vec::with_capacity(custom.buffer_indices.len());
650        for &idx in &custom.buffer_indices {
651            let Some(buffer) = self.buffers.get(idx) else {
652                return Err(crate::error::Error::Execution {
653                    reason: format!(
654                        "Custom function op {} ({:?}) buffer index out of range: idx={}, total_buffers={}",
655                        custom.id,
656                        custom.kind,
657                        idx,
658                        self.buffers.len()
659                    ),
660                });
661            };
662            buffers.push(buffer.clone());
663        }
664
665        let mut vars = self.runtime_var_vals.clone();
666        vars.extend(custom.fixedvars.iter().map(|(k, v)| (k.clone(), *v)));
667
668        crate::custom_function::run_custom_function(&custom.kind, &custom.attrs, &mut buffers, &vars).map_err(|e| {
669            // Pass typed `Unsupported` errors through unchanged so callers can match on `kind`.
670            // Other errors are wrapped with op context for debugging.
671            match e {
672                crate::error::Error::Unsupported { .. } => e,
673                other => crate::error::Error::Execution {
674                    reason: format!("Custom function op {} ({:?}) failed: {other}", custom.id, custom.kind),
675                },
676            }
677        })
678    }
679
680    #[inline]
681    fn execute_op(&self, op: &PreparedOp) -> Result<()> {
682        match op {
683            PreparedOp::CompiledProgram(kernel) => Self::execute_kernel(kernel),
684            PreparedOp::BufferCopy(copy) => self.execute_copy(copy),
685            PreparedOp::BufferView(view) => self.execute_buffer_view(view),
686            PreparedOp::CustomFunction(custom) => self.execute_custom_function(custom),
687        }
688    }
689
690    #[inline]
691    fn op_requires_serial(op: &PreparedOp) -> bool {
692        match op {
693            PreparedOp::CompiledProgram(kernel) => !kernel.kernel.host_parallel_safe,
694            PreparedOp::BufferCopy(_) | PreparedOp::BufferView(_) | PreparedOp::CustomFunction(_) => true,
695        }
696    }
697
698    #[inline]
699    fn compiled_kernel_at(&self, idx: usize) -> Option<&PreparedKernel> {
700        match &self.ops[idx] {
701            PreparedOp::CompiledProgram(kernel) => Some(kernel),
702            _ => None,
703        }
704    }
705
706    fn kernels_conflict(lhs: &PreparedKernel, rhs: &PreparedKernel) -> bool {
707        let lhs_outputs: HashSet<BufferId> =
708            lhs.output_indices.iter().filter_map(|&out_idx| lhs.buffer_ids.get(out_idx).copied()).collect();
709        let rhs_outputs: HashSet<BufferId> =
710            rhs.output_indices.iter().filter_map(|&out_idx| rhs.buffer_ids.get(out_idx).copied()).collect();
711
712        if !lhs_outputs.is_disjoint(&rhs_outputs) {
713            return true;
714        }
715
716        let lhs_reads: HashSet<BufferId> = lhs
717            .buffer_ids
718            .iter()
719            .enumerate()
720            .filter_map(|(idx, &buf)| (!lhs.output_indices.contains(&idx)).then_some(buf))
721            .collect();
722        let rhs_reads: HashSet<BufferId> = rhs
723            .buffer_ids
724            .iter()
725            .enumerate()
726            .filter_map(|(idx, &buf)| (!rhs.output_indices.contains(&idx)).then_some(buf))
727            .collect();
728
729        !lhs_outputs.is_disjoint(&rhs_reads) || !rhs_outputs.is_disjoint(&lhs_reads)
730    }
731
732    fn partition_parallel_safe_group(&self, indices: &[usize]) -> Result<Vec<Vec<usize>>> {
733        let mut groups: Vec<Vec<usize>> = Vec::new();
734
735        for &idx in indices {
736            let Some(kernel) = self.compiled_kernel_at(idx) else {
737                return Err(crate::error::Error::Execution {
738                    reason: format!("parallel partition expected compiled kernel at op index {idx}"),
739                });
740            };
741
742            let mut placed = false;
743            for group in &mut groups {
744                let has_conflict = group.iter().any(|&existing_idx| {
745                    self.compiled_kernel_at(existing_idx)
746                        .map(|existing| Self::kernels_conflict(existing, kernel))
747                        .unwrap_or(true)
748                });
749                if !has_conflict {
750                    group.push(idx);
751                    placed = true;
752                    break;
753                }
754            }
755
756            if !placed {
757                groups.push(vec![idx]);
758            }
759        }
760
761        Ok(groups)
762    }
763
764    fn execute_parallel_group(&self, indices: &[usize]) -> Result<()> {
765        if indices.len() <= 1 {
766            if let Some(&idx) = indices.first() {
767                self.execute_op(&self.ops[idx])?;
768            }
769            return Ok(());
770        }
771
772        let has_threaded_cpu_kernel = indices.iter().try_fold(false, |acc, &idx| {
773            let Some(kernel) = self.compiled_kernel_at(idx) else {
774                return Err(crate::error::Error::Execution {
775                    reason: format!("parallel execution expected compiled kernel at op index {idx}"),
776                });
777            };
778            Ok(acc || Self::kernel_uses_cpu_threading(kernel)?)
779        })?;
780
781        if has_threaded_cpu_kernel {
782            for &idx in indices {
783                let Some(kernel) = self.compiled_kernel_at(idx) else {
784                    return Err(crate::error::Error::Execution {
785                        reason: format!("parallel execution expected compiled kernel at op index {idx}"),
786                    });
787                };
788                Self::execute_kernel(kernel)?;
789            }
790            return Ok(());
791        }
792
793        indices
794            .par_iter()
795            .map(|&idx| {
796                let Some(kernel) = self.compiled_kernel_at(idx) else {
797                    return Err(crate::error::Error::Execution {
798                        reason: format!("parallel execution expected compiled kernel at op index {idx}"),
799                    });
800                };
801                Self::execute_kernel(kernel)
802            })
803            .collect::<Result<Vec<_>>>()?;
804
805        Ok(())
806    }
807
808    fn execute_parallel_group_profiled(&self, indices: &[usize]) -> Result<Vec<(usize, KernelProfile)>> {
809        if indices.len() <= 1 {
810            let mut profiles = Vec::new();
811            if let Some(&idx) = indices.first() {
812                let Some(kernel) = self.compiled_kernel_at(idx) else {
813                    return Err(crate::error::Error::Execution {
814                        reason: format!("profiled execution expected compiled kernel at op index {idx}"),
815                    });
816                };
817                let start = Instant::now();
818                Self::execute_kernel(kernel)?;
819                profiles.push((
820                    idx,
821                    KernelProfile {
822                        kernel: Arc::clone(&kernel.kernel),
823                        device: kernel.device.clone(),
824                        num_buffers: kernel.buffer_ptrs.len(),
825                        elapsed: start.elapsed(),
826                    },
827                ));
828            }
829            return Ok(profiles);
830        }
831
832        let has_threaded_cpu_kernel = indices.iter().try_fold(false, |acc, &idx| {
833            let Some(kernel) = self.compiled_kernel_at(idx) else {
834                return Err(crate::error::Error::Execution {
835                    reason: format!("profiled execution expected compiled kernel at op index {idx}"),
836                });
837            };
838            Ok(acc || Self::kernel_uses_cpu_threading(kernel)?)
839        })?;
840
841        if has_threaded_cpu_kernel {
842            let mut profiles = Vec::with_capacity(indices.len());
843            for &idx in indices {
844                let Some(kernel) = self.compiled_kernel_at(idx) else {
845                    return Err(crate::error::Error::Execution {
846                        reason: format!("profiled execution expected compiled kernel at op index {idx}"),
847                    });
848                };
849                let start = Instant::now();
850                Self::execute_kernel(kernel)?;
851                profiles.push((
852                    idx,
853                    KernelProfile {
854                        kernel: Arc::clone(&kernel.kernel),
855                        device: kernel.device.clone(),
856                        num_buffers: kernel.buffer_ptrs.len(),
857                        elapsed: start.elapsed(),
858                    },
859                ));
860            }
861            return Ok(profiles);
862        }
863
864        let mut profiles = indices
865            .par_iter()
866            .map(|&idx| {
867                let Some(kernel) = self.compiled_kernel_at(idx) else {
868                    return Err(crate::error::Error::Execution {
869                        reason: format!("profiled execution expected compiled kernel at op index {idx}"),
870                    });
871                };
872                let start = Instant::now();
873                Self::execute_kernel(kernel)?;
874                Ok((
875                    idx,
876                    KernelProfile {
877                        kernel: Arc::clone(&kernel.kernel),
878                        device: kernel.device.clone(),
879                        num_buffers: kernel.buffer_ptrs.len(),
880                        elapsed: start.elapsed(),
881                    },
882                ))
883            })
884            .collect::<Result<Vec<_>>>()?;
885
886        profiles.sort_by_key(|(idx, _)| *idx);
887        Ok(profiles)
888    }
889
890    /// Get the first (or only) output buffer after execution.
891    ///
892    /// Returns `None` for plans with no output buffers (for example, plans
893    /// constructed before `set_output_buffer*` is called).
894    pub fn output_buffer(&self) -> Option<&Buffer> {
895        self.output_buffer_indices.first().and_then(|&i| self.buffers.get(i))
896    }
897
898    /// Get output buffer by position (matches SINK source order for batch).
899    ///
900    /// Returns `None` if `position` is out of range.
901    pub fn output_buffer_at(&self, position: usize) -> Option<&Buffer> {
902        self.output_buffer_indices.get(position).and_then(|&i| self.buffers.get(i))
903    }
904
905    /// Get all output buffers.
906    pub fn output_buffers(&self) -> Vec<&Buffer> {
907        self.output_buffer_indices.iter().map(|&i| &self.buffers[i]).collect()
908    }
909
910    /// Number of outputs in this plan.
911    pub fn num_outputs(&self) -> usize {
912        self.output_buffer_indices.len()
913    }
914
915    /// Get a buffer by AST id (for reading intermediate results).
916    pub fn buffer(&self, ast_id: u64) -> Option<&Buffer> {
917        self.ast_to_buffer.get(&ast_id).map(|&idx| &self.buffers[idx])
918    }
919
920    /// Get a mutable buffer by AST id (for `copyin()` on input buffers).
921    pub fn buffer_mut_by_id(&mut self, ast_id: u64) -> Option<&mut Buffer> {
922        self.ast_to_buffer.get(&ast_id).copied().map(|idx| &mut self.buffers[idx])
923    }
924
925    /// Get the primary device for this plan.
926    pub fn device(&self) -> &DeviceSpec {
927        &self.device
928    }
929
930    /// Get all buffers owned by this plan.
931    pub fn buffers(&self) -> &[Buffer] {
932        &self.buffers
933    }
934
935    /// Get mutable access to all buffers owned by this plan.
936    pub fn buffers_mut(&mut self) -> &mut [Buffer] {
937        &mut self.buffers
938    }
939
940    /// Get a mutable buffer by its index in the buffers array.
941    pub fn buffer_at_mut(&mut self, index: usize) -> Option<&mut Buffer> {
942        self.buffers.get_mut(index)
943    }
944
945    /// Get all prepared kernels.
946    pub fn prepared_kernels(&self) -> Vec<&PreparedKernel> {
947        self.ops
948            .iter()
949            .filter_map(|op| match op {
950                PreparedOp::CompiledProgram(kernel) => Some(kernel),
951                _ => None,
952            })
953            .collect()
954    }
955
956    /// Get all prepared operations in schedule order.
957    pub fn prepared_ops(&self) -> &[PreparedOp] {
958        &self.ops
959    }
960
961    /// Iterate over compiled kernels (for inspecting generated source code).
962    pub fn kernels(&self) -> impl Iterator<Item = &CachedKernel> {
963        self.ops.iter().filter_map(|op| match op {
964            PreparedOp::CompiledProgram(kernel) => Some(kernel.kernel.as_ref()),
965            _ => None,
966        })
967    }
968
969    /// Execute the plan.
970    ///
971    /// Uses dependency-aware operation ordering for all prepared op types.
972    pub fn execute(&self) -> Result<()> {
973        for level in &self.op_levels {
974            let mut pending_parallel: Vec<usize> = Vec::new();
975
976            for &idx in level {
977                let op = &self.ops[idx];
978                if Self::op_requires_serial(op) {
979                    if !pending_parallel.is_empty() {
980                        let groups = self.partition_parallel_safe_group(&pending_parallel)?;
981                        for group in groups {
982                            self.execute_parallel_group(&group)?;
983                        }
984                        pending_parallel.clear();
985                    }
986                    self.execute_op(op)?;
987                } else {
988                    pending_parallel.push(idx);
989                }
990            }
991
992            if !pending_parallel.is_empty() {
993                let groups = self.partition_parallel_safe_group(&pending_parallel)?;
994                for group in groups {
995                    self.execute_parallel_group(&group)?;
996                }
997            }
998        }
999        Ok(())
1000    }
1001
1002    /// Execute the plan with per-kernel timing.
1003    ///
1004    /// Returns a [`KernelProfile`] for each kernel in execution order.
1005    ///
1006    /// # Example
1007    ///
1008    /// ```ignore
1009    /// let plan = tensor.prepare()?;
1010    /// let profiles = plan.execute_profiled()?;
1011    ///
1012    /// // Sort by time descending
1013    /// let mut sorted = profiles;
1014    /// sorted.sort_by(|a, b| b.elapsed.cmp(&a.elapsed));
1015    /// for p in &sorted[..10.min(sorted.len())] {
1016    ///     println!("{:>8.3}ms  {}", p.elapsed.as_secs_f64() * 1000.0, p.kernel.entry_point);
1017    /// }
1018    /// ```
1019    pub fn execute_profiled(&self) -> Result<Vec<KernelProfile>> {
1020        let mut profiles = Vec::new();
1021        for level in &self.op_levels {
1022            let mut pending_parallel: Vec<usize> = Vec::new();
1023
1024            for &idx in level {
1025                match &self.ops[idx] {
1026                    PreparedOp::CompiledProgram(kernel) if kernel.kernel.host_parallel_safe => {
1027                        pending_parallel.push(idx);
1028                    }
1029                    PreparedOp::CompiledProgram(kernel) => {
1030                        if !pending_parallel.is_empty() {
1031                            let groups = self.partition_parallel_safe_group(&pending_parallel)?;
1032                            for group in groups {
1033                                let mut prof = self.execute_parallel_group_profiled(&group)?;
1034                                profiles.extend(prof.drain(..).map(|(_, p)| p));
1035                            }
1036                            pending_parallel.clear();
1037                        }
1038
1039                        let start = Instant::now();
1040                        Self::execute_kernel(kernel)?;
1041                        profiles.push(KernelProfile {
1042                            kernel: Arc::clone(&kernel.kernel),
1043                            device: kernel.device.clone(),
1044                            num_buffers: kernel.buffer_ptrs.len(),
1045                            elapsed: start.elapsed(),
1046                        });
1047                    }
1048                    PreparedOp::BufferCopy(copy) => {
1049                        if !pending_parallel.is_empty() {
1050                            let groups = self.partition_parallel_safe_group(&pending_parallel)?;
1051                            for group in groups {
1052                                let mut prof = self.execute_parallel_group_profiled(&group)?;
1053                                profiles.extend(prof.drain(..).map(|(_, p)| p));
1054                            }
1055                            pending_parallel.clear();
1056                        }
1057                        self.execute_copy(copy)?;
1058                    }
1059                    PreparedOp::BufferView(view) => {
1060                        if !pending_parallel.is_empty() {
1061                            let groups = self.partition_parallel_safe_group(&pending_parallel)?;
1062                            for group in groups {
1063                                let mut prof = self.execute_parallel_group_profiled(&group)?;
1064                                profiles.extend(prof.drain(..).map(|(_, p)| p));
1065                            }
1066                            pending_parallel.clear();
1067                        }
1068                        self.execute_buffer_view(view)?;
1069                    }
1070                    PreparedOp::CustomFunction(custom) => {
1071                        if !pending_parallel.is_empty() {
1072                            let groups = self.partition_parallel_safe_group(&pending_parallel)?;
1073                            for group in groups {
1074                                let mut prof = self.execute_parallel_group_profiled(&group)?;
1075                                profiles.extend(prof.drain(..).map(|(_, p)| p));
1076                            }
1077                            pending_parallel.clear();
1078                        }
1079                        self.execute_custom_function(custom)?;
1080                    }
1081                }
1082            }
1083
1084            if !pending_parallel.is_empty() {
1085                let groups = self.partition_parallel_safe_group(&pending_parallel)?;
1086                for group in groups {
1087                    let mut prof = self.execute_parallel_group_profiled(&group)?;
1088                    profiles.extend(prof.drain(..).map(|(_, p)| p));
1089                }
1090            }
1091        }
1092        Ok(profiles)
1093    }
1094
1095    /// Re-execute the plan with different variable bindings.
1096    ///
1097    /// The kernel code is NOT recompiled; only the `vals` passed to each kernel
1098    /// are updated. Buffers must be allocated to max variable values (which is
1099    /// the default when using `Variable::bind()`).
1100    ///
1101    /// # Safety contract
1102    ///
1103    /// Variable values **must** fall within `[min_val, max_val]` bounds defined
1104    /// at `Variable::new()` time. Exceeding `max_val` causes out-of-bounds buffer
1105    /// access (buffers are allocated to `max_val`). Use `Variable::bind()` to
1106    /// validate bounds before calling this method.
1107    ///
1108    /// Variables not present in `var_vals` keep their existing values from
1109    /// `prepare()` (or the previous `execute_with_vars` call). Internal
1110    /// variables like `core_id` are left untouched.
1111    pub fn execute_with_vars(&mut self, var_vals: &[(&str, i64)]) -> Result<()> {
1112        self.update_runtime_var_vals(var_vals)?;
1113        self.execute()
1114    }
1115
1116    /// Re-execute the plan with different variable bindings and per-kernel timing.
1117    ///
1118    /// Updates kernel `vals` the same way as [`Self::execute_with_vars`] and then
1119    /// executes via [`Self::execute_profiled`].
1120    pub fn execute_with_vars_profiled(&mut self, var_vals: &[(&str, i64)]) -> Result<Vec<KernelProfile>> {
1121        self.update_runtime_var_vals(var_vals)?;
1122        self.execute_profiled()
1123    }
1124
1125    /// Get the first output buffer index.
1126    pub fn output_buffer_idx(&self) -> usize {
1127        self.output_buffer_indices[0]
1128    }
1129
1130    /// Get the AST ID to buffer index mapping.
1131    pub fn ast_to_buffer_map(&self) -> &HashMap<u64, usize> {
1132        &self.ast_to_buffer
1133    }
1134
1135    /// Release intermediate buffers from the global buffer registry.
1136    ///
1137    /// Call this after you're done executing the plan to free intermediate
1138    /// buffers from the global registry. The output buffer is preserved.
1139    pub fn release_intermediate_buffers<F>(&self, remove_fn: F)
1140    where
1141        F: Fn(u64),
1142    {
1143        self.release_buffers_impl(remove_fn, true);
1144    }
1145
1146    /// Release ALL buffers from the global registry, including the output.
1147    pub fn release_all_buffers<F>(&self, remove_fn: F)
1148    where
1149        F: Fn(u64),
1150    {
1151        self.release_buffers_impl(remove_fn, false);
1152    }
1153
1154    fn release_buffers_impl<F>(&self, remove_fn: F, skip_output: bool)
1155    where
1156        F: Fn(u64),
1157    {
1158        let output_buf_ids: std::collections::HashSet<u64> = if skip_output {
1159            self.output_buffer_indices.iter().filter_map(|&idx| self.buffers.get(idx).map(|b| b.id().0)).collect()
1160        } else {
1161            std::collections::HashSet::new()
1162        };
1163
1164        for (&ast_id, &buf_idx) in &self.ast_to_buffer {
1165            if skip_output && output_buf_ids.contains(&self.buffers[buf_idx].id().0) {
1166                continue;
1167            }
1168            remove_fn(ast_id);
1169        }
1170
1171        for &alias_id in &self.alias_ids {
1172            remove_fn(alias_id);
1173        }
1174    }
1175}
1176
1177impl std::fmt::Debug for ExecutionPlan {
1178    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1179        let kernel_count = self.ops.iter().filter(|op| matches!(op, PreparedOp::CompiledProgram(_))).count();
1180        f.debug_struct("ExecutionPlan")
1181            .field("ops", &self.ops.len())
1182            .field("op_instance_dependencies", &self.op_instance_dependencies.len())
1183            .field("op_order", &self.op_order.len())
1184            .field("kernels", &kernel_count)
1185            .field("buffers", &self.buffers.len())
1186            .field("device", &self.device)
1187            .finish()
1188    }
1189}
1190
1191impl std::fmt::Debug for PreparedKernel {
1192    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1193        f.debug_struct("PreparedKernel")
1194            .field("id", &self.id)
1195            .field("device", &self.device)
1196            .field("buffer_indices", &self.buffer_indices)
1197            .field("output_indices", &self.output_indices)
1198            .field("vals", &self.vals)
1199            .field("fixedvars", &self.fixedvars)
1200            .field("dependencies", &self.dependencies)
1201            .finish()
1202    }
1203}
1204
1205// ============================================================================
1206// Builder for ExecutionPlan
1207// ============================================================================
1208
/// Builder for creating ExecutionPlan from schedule data.
///
/// Collect ops, buffers and output indices, then call `build()` to validate
/// them and produce an `ExecutionPlan`.
pub struct ExecutionPlanBuilder {
    /// Prepared operations, in schedule (insertion) order.
    ops: Vec<PreparedOp>,
    /// Per-op concrete op-index dependencies; entry `i` belongs to `ops[i]`.
    op_instance_dependencies: Vec<Vec<usize>>,
    /// Buffers owned by the plan; indexed by the values returned from `add_buffer`.
    buffers: Vec<Buffer>,
    /// AST/buffer UOp id -> index into `buffers` (several ids may share an index).
    ast_to_buffer: HashMap<u64, usize>,
    /// Indices into `buffers` of the plan outputs, in output-position order.
    output_buffer_indices: Vec<usize>,
    /// Primary device of the resulting plan.
    device: DeviceSpec,
    /// Extra ids handed to the removal callback when the plan's buffers are released.
    alias_ids: Vec<u64>,
}
1219
1220impl ExecutionPlanBuilder {
1221    /// Create a new builder.
1222    pub fn new(device: DeviceSpec) -> Self {
1223        Self {
1224            ops: Vec::new(),
1225            op_instance_dependencies: Vec::new(),
1226            buffers: Vec::new(),
1227            ast_to_buffer: HashMap::new(),
1228            output_buffer_indices: Vec::new(),
1229            device,
1230            alias_ids: Vec::new(),
1231        }
1232    }
1233
1234    /// Add alias IDs that need cleanup.
1235    pub fn add_alias_ids(&mut self, ids: impl IntoIterator<Item = u64>) {
1236        self.alias_ids.extend(ids);
1237    }
1238
1239    /// Add a buffer to the plan. Returns the buffer index.
1240    pub fn add_buffer(&mut self, ast_id: u64, buffer: Buffer) -> usize {
1241        let idx = self.buffers.len();
1242        self.buffers.push(buffer);
1243        self.ast_to_buffer.insert(ast_id, idx);
1244        idx
1245    }
1246
1247    /// Map an additional AST/buffer UOp ID to an existing buffer index.
1248    pub fn map_buffer(&mut self, ast_id: u64, idx: usize) {
1249        self.ast_to_buffer.insert(ast_id, idx);
1250    }
1251
1252    /// Replace a buffer at the given index (for BUFFER_VIEW sub-buffer views).
1253    pub fn replace_buffer(&mut self, idx: usize, buffer: Buffer) {
1254        self.buffers[idx] = buffer;
1255    }
1256
1257    /// Set single output buffer index.
1258    pub fn set_output_buffer(&mut self, idx: usize) {
1259        self.output_buffer_indices = vec![idx];
1260    }
1261
1262    /// Set multiple output buffer indices (batch scheduling).
1263    pub fn set_output_buffers(&mut self, indices: Vec<usize>) {
1264        self.output_buffer_indices = indices;
1265    }
1266
1267    /// Compatibility helper: add a compiled kernel as a prepared operation.
1268    ///
1269    /// The canonical builder path is `add_op(PreparedOp::...)`.
1270    pub fn add_kernel(&mut self, kernel: PreparedKernel) {
1271        self.add_op(PreparedOp::CompiledProgram(kernel));
1272    }
1273
1274    /// Add a prepared operation in schedule order.
1275    pub fn add_op(&mut self, op: PreparedOp) {
1276        self.add_op_with_instance_dependencies(op, Vec::new());
1277    }
1278
1279    /// Add a prepared operation with concrete op-index dependencies.
1280    pub fn add_op_with_instance_dependencies(&mut self, op: PreparedOp, instance_dependencies: Vec<usize>) {
1281        self.ops.push(op);
1282        self.op_instance_dependencies.push(instance_dependencies);
1283    }
1284
1285    /// Build the ExecutionPlan.
1286    ///
1287    /// Finalizes by computing pre-allocated buffer pointers and buffer IDs
1288    /// for zero-allocation execution.
1289    pub fn build(mut self) -> Result<ExecutionPlan> {
1290        for op in &mut self.ops {
1291            let PreparedOp::CompiledProgram(kernel) = op else {
1292                continue;
1293            };
1294
1295            if kernel.output_indices.is_empty() {
1296                return Err(crate::error::Error::Execution {
1297                    reason: format!("CompiledProgram {} has no output indices", kernel.id),
1298                });
1299            }
1300            for &out_idx in &kernel.output_indices {
1301                if out_idx >= kernel.buffer_indices.len() {
1302                    return Err(crate::error::Error::Execution {
1303                        reason: format!(
1304                            "CompiledProgram {} output index out of range: output_idx={}, kernel_buffers={}",
1305                            kernel.id,
1306                            out_idx,
1307                            kernel.buffer_indices.len()
1308                        ),
1309                    });
1310                }
1311            }
1312
1313            let mut buffer_ptrs = Vec::with_capacity(kernel.buffer_indices.len());
1314            let mut buffer_ids = Vec::with_capacity(kernel.buffer_indices.len());
1315
1316            for &idx in &kernel.buffer_indices {
1317                let Some(buffer) = self.buffers.get(idx) else {
1318                    return Err(crate::error::Error::Execution {
1319                        reason: format!(
1320                            "CompiledProgram {} buffer index out of range: idx={}, total_buffers={}",
1321                            kernel.id,
1322                            idx,
1323                            self.buffers.len()
1324                        ),
1325                    });
1326                };
1327                buffer_ptrs.push(unsafe { buffer.as_raw_ptr() } as usize);
1328                buffer_ids.push(buffer.id());
1329            }
1330
1331            kernel.buffer_ptrs = buffer_ptrs;
1332            kernel.buffer_ids = buffer_ids;
1333        }
1334
1335        if self.output_buffer_indices.is_empty() && !self.buffers.is_empty() {
1336            return Err(crate::error::Error::Execution {
1337                reason: "execution plan output buffers must be set explicitly".to_string(),
1338            });
1339        }
1340
1341        let op_order = compute_mixed_op_order_with_instance_dependencies(&self.ops, &self.op_instance_dependencies)?;
1342        let op_levels = compute_execution_levels_with_instance_dependencies(&self.ops, &self.op_instance_dependencies)?;
1343
1344        Ok(ExecutionPlan {
1345            ops: self.ops,
1346            op_instance_dependencies: self.op_instance_dependencies,
1347            op_order,
1348            op_levels,
1349            buffers: self.buffers,
1350            ast_to_buffer: self.ast_to_buffer,
1351            output_buffer_indices: self.output_buffer_indices,
1352            device: self.device,
1353            runtime_var_vals: HashMap::new(),
1354            alias_ids: self.alias_ids,
1355        })
1356    }
1357}
1358
// Unit tests live in a separate file, `test/unit/execution_plan.rs`, resolved
// via `#[path]` relative to this source file's directory.
#[cfg(test)]
#[path = "test/unit/execution_plan.rs"]
mod tests;