Skip to main content

morok_schedule/rangeify/
kernel.rs

1//! Consolidated kernel splitting and pipeline orchestration.
2//!
3//! This module contains:
4//! - KernelContext for tracking state during kernel splitting
5//! - split_store for splitting computation at STORE boundaries
6//! - run_kernel_split_pipeline for full pipeline orchestration
7//! - PcontigConfig for partial contiguous buffer removal
8//! - Two-stage reduction splitting (split_reduceop)
9//!
10//! Consolidated from: kernel_context.rs, split_kernel.rs, pipeline.rs,
11//! buffer_cost.rs, split_reduceop.rs
12
13use std::collections::{HashMap, HashSet};
14use std::sync::Arc;
15
16use indexmap::IndexMap;
17use morok_ir::{AxisType, Op, SInt, UOp, UOpKey};
18use smallvec::SmallVec;
19use tracing::{debug, trace};
20
21// ============================================================================
22// CONFIGURATION
23// ============================================================================
24
/// Tuning knobs for the partial-contiguous optimization.
///
/// `Default` yields `level = 2`, `max_buffers_threshold = 3` and
/// `out_in_ratio_threshold = 10.0`.
#[derive(Debug, Clone)]
pub struct PcontigConfig {
    /// Optimization level: 0 = disabled, 1 = basic, 2 = enabled (default), 3+ = aggressive.
    pub level: u8,
    /// Maximum number of buffers before BUFFERIZE is kept (default: 3).
    pub max_buffers_threshold: usize,
    /// Maximum output/input ratio allowed for partial contiguous (default: 10.0).
    pub out_in_ratio_threshold: f64,
}

impl Default for PcontigConfig {
    fn default() -> Self {
        let level = 2;
        let max_buffers_threshold = 3;
        let out_in_ratio_threshold = 10.0;
        PcontigConfig { level, max_buffers_threshold, out_in_ratio_threshold }
    }
}
41
/// Configuration for the `split_reduceop` optimization.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SplitReduceOpConfig {
    /// Minimum input/output size ratio to trigger splitting (default: 32768)
    pub split_threshold: usize,
    /// Max output buffer size as 2^N elements (default: 22 = 4M elements)
    pub output_size_bits: u32,
    /// Max split divisor (default: 256)
    pub max_divisor: usize,
    /// Min split divisor (default: 8)
    pub min_divisor: usize,
    /// Enable/disable the optimization (default: true)
    pub enabled: bool,
}

impl Default for SplitReduceOpConfig {
    fn default() -> Self {
        Self { split_threshold: 32768, output_size_bits: 22, max_divisor: 256, min_divisor: 8, enabled: true }
    }
}

impl SplitReduceOpConfig {
    /// Maximum number of output elements a split may produce: `2^output_size_bits`.
    ///
    /// Saturates at `usize::MAX` when `output_size_bits` is at least the width of `usize`.
    /// A plain `pow` would panic (debug) or wrap to 0 (release) in that case.
    pub fn max_output_size(&self) -> usize {
        2_usize.saturating_pow(self.output_size_bits)
    }
}
68
69// ============================================================================
70// KERNEL CONTEXT
71// ============================================================================
72
73/// Context for tracking state during kernel splitting.
74///
75/// Simplified from original 8 fields to 6 fields, removing Morok-specific
76/// `kernel_deps` and `buffer_id_mapping` that are no longer needed after
77/// aligning with fix_assign approach.
78#[derive(Clone)]
79pub struct KernelContext {
80    pub global_counter: usize,
81    pub local_counter: usize,
82    pub buffer_map: HashMap<UOpKey, Arc<UOp>>,
83    /// Bound variables: maps variable name → (DEFINE_VAR UOp, optional bound value).
84    /// Populated when BIND(DEFINE_VAR, CONST) is stripped during kernel splitting.
85    /// The UOp is kept for kernel sources; the i64 is the concrete bound value (None for OUTER ranges).
86    pub vars: HashMap<String, (Arc<UOp>, Option<i64>)>,
87    pub range_counter: usize,
88}
89
90impl KernelContext {
91    pub fn new() -> Self {
92        Self { global_counter: 0, local_counter: 0, buffer_map: HashMap::new(), vars: HashMap::new(), range_counter: 0 }
93    }
94
95    pub fn next_global(&mut self) -> usize {
96        let id = self.global_counter;
97        self.global_counter += 1;
98        id
99    }
100
101    pub fn next_local(&mut self) -> usize {
102        let id = self.local_counter;
103        self.local_counter += 1;
104        id
105    }
106
107    pub fn next_range(&mut self) -> usize {
108        let id = self.range_counter;
109        self.range_counter += 1;
110        id
111    }
112
113    pub fn has_buffer(&self, buf: &Arc<UOp>) -> bool {
114        self.buffer_map.contains_key(&UOpKey(buf.clone()))
115    }
116
117    pub fn get_buffer(&self, buf: &Arc<UOp>) -> Option<&Arc<UOp>> {
118        self.buffer_map.get(&UOpKey(buf.clone()))
119    }
120
121    pub fn map_buffer(&mut self, original: Arc<UOp>, replacement: Arc<UOp>) {
122        self.buffer_map.insert(UOpKey(original), replacement);
123    }
124
125    /// Track a bound variable with its DEFINE_VAR UOp and concrete value.
126    pub fn add_var(&mut self, var: Arc<UOp>, value: Option<i64>) {
127        if let Op::DefineVar { name, .. } = var.op() {
128            self.vars.insert(name.clone(), (var, value));
129        }
130    }
131}
132
133impl Default for KernelContext {
134    fn default() -> Self {
135        Self::new()
136    }
137}
138
139// ============================================================================
140// LOCAL ADD BUFFER CONTEXT (Per-Kernel Context)
141// ============================================================================
142
143/// Per-kernel context for tracking state during kernel splitting.
144///
145/// Based on `LocalAddBufferContext`.
146/// This is used within `split_store` for each individual kernel being created.
147///
148/// IMPORTANT: Uses IndexMap for `map` to maintain insertion order.
149/// This is critical because PARAM slot indices are assigned in the order
150/// patterns match, and kernel sources must be in the same order for correct
151/// buffer indexing during execution.
152#[derive(Default)]
153pub struct LocalAddBufferContext {
154    /// PARAM slot counter (`dg`)
155    pub param_slot: usize,
156    /// Buffer → AFTER mapping (IndexMap maintains insertion order)
157    pub map: IndexMap<UOpKey, Arc<UOp>>,
158    /// Bound variables: name → (DEFINE_VAR UOp, optional bound value).
159    pub vars: HashMap<String, (Arc<UOp>, Option<i64>)>,
160    /// Range renumber counter
161    pub range: usize,
162    /// Optimization hints extracted from CONTIGUOUS.opts (ctx.opts)
163    pub opts: Vec<morok_ir::ContiguousHint>,
164}
165
166impl LocalAddBufferContext {
167    pub fn new() -> Self {
168        Self::default()
169    }
170
171    /// Get next PARAM slot index (`ctx.dg`).
172    pub fn next_param_slot(&mut self) -> usize {
173        let id = self.param_slot;
174        self.param_slot += 1;
175        id
176    }
177
178    /// Get next range renumber index.
179    pub fn next_range(&mut self) -> usize {
180        let id = self.range;
181        self.range += 1;
182        id
183    }
184
185    /// Track a bound variable with its DEFINE_VAR UOp and concrete value.
186    pub fn add_var(&mut self, var: Arc<UOp>, value: Option<i64>) {
187        if let Op::DefineVar { name, .. } = var.op() {
188            self.vars.insert(name.clone(), (var, value));
189        }
190    }
191
192    /// Map a buffer to its AFTER wrapper.
193    pub fn map_buffer(&mut self, buf: Arc<UOp>, after: Arc<UOp>) {
194        self.map.insert(UOpKey(buf), after);
195    }
196
197    /// Check if buffer is already mapped.
198    pub fn has_buffer(&self, buf: &Arc<UOp>) -> bool {
199        self.map.contains_key(&UOpKey(buf.clone()))
200    }
201}
202
203// ============================================================================
204// SPLIT KERNEL
205// ============================================================================
206
/// Metadata tag attached to the SINK that roots a finished kernel AST.
///
/// Stands in for the `KernelInfo` arg upstream puts on kernel SINKs
/// (`ret = ret.sink(arg=KernelInfo(...))`). The SINK gate (`pm_gate_kernel_sink`)
/// looks for this tag so that bottom-up traversal skips kernel ASTs that are
/// already formed.
#[derive(Clone, Debug)]
pub struct KernelAstMarker;
216
217/// Extract the stored value from a STORE/END(STORE) structure.
218///
219/// Used to check if the stored value is COPY/BUFFER_VIEW without traversing
220/// the entire subgraph.
221fn extract_stored_value(ret: &Arc<UOp>) -> &Arc<UOp> {
222    match ret.op() {
223        Op::Store { value, .. } => value,
224        Op::End { computation, .. } => match computation.op() {
225            Op::Store { value, .. } => value,
226            _ => ret,
227        },
228        _ => ret,
229    }
230}
231
232/// Split STORE and END operations into individual kernels.
233///
234/// Based on split_store.
235/// Simplified from 280 lines to ~80 lines using LocalAddBufferContext.
236pub fn split_store(_ctx: &mut Vec<Arc<UOp>>, x: &Arc<UOp>) -> Option<Arc<UOp>> {
237    use super::patterns::{local_to_param_patterns, rangeify_codegen_patterns};
238    use crate::rewrite::graph_rewrite_bottom_up;
239
240    trace!(uop_id = x.id, op = ?std::mem::discriminant(x.op()), "split_store: entering");
241
242    // Guard 1: Skip if has non-OUTER ranges
243    #[allow(clippy::mutable_key_type)] // UOp uses Arc<OnceLock> for caching, but keys hash by ID
244    let in_scope = x.in_scope_ranges();
245    let has_non_outer =
246        in_scope.iter().any(|r| matches!(r.0.op(), Op::Range { axis_type, .. } if *axis_type != AxisType::Outer));
247    if has_non_outer {
248        return None;
249    }
250
251    // Guard 2: Skip END where FIRST range is OUTER
252    // `if x.op is Ops.END and x.src[1].arg[0] == AxisType.OUTER: return None`
253    if let Op::End { ranges, .. } = x.op()
254        && let Some(r) = ranges.first()
255        && matches!(r.op(), Op::Range { axis_type: AxisType::Outer, .. })
256    {
257        return None;
258    }
259
260    // Verify operation type (only STORE and END(STORE) are valid)
261    let is_valid = match x.op() {
262        Op::Store { .. } => true,
263        Op::End { computation, .. } => matches!(computation.op(), Op::Store { .. }),
264        _ => false,
265    };
266    if !is_valid {
267        return None;
268    }
269
270    // Per-kernel context (LocalAddBufferContext)
271    let mut lctx = LocalAddBufferContext::new();
272
273    // Context-dependent rewrite per kernel.
274    //
275    // Context-free patterns (movement_op, syntactic_sugar, flatten_range) were already
276    // applied in run_kernel_split_pipeline's pre-pass. Here we only run patterns that
277    // need LocalAddBufferContext (Buffer/Param→codegen PARAM, Bind, After, Range renumber,
278    // NOOP→zero, Contiguous→extract opts).
279    let ret = {
280        use std::sync::LazyLock;
281        static PM_CTX_DEP: LazyLock<crate::TypedPatternMatcher<LocalAddBufferContext>> =
282            LazyLock::new(|| local_to_param_patterns() + rangeify_codegen_patterns());
283        graph_rewrite_bottom_up(&*PM_CTX_DEP, x.clone(), &mut lctx)
284    };
285
286    // Check for COPY/BUFFER_VIEW directly on the stored value.
287    // No graph traversal needed — just walk the STORE/END structure.
288    let stored = extract_stored_value(&ret);
289    let ast = if matches!(stored.op(), Op::Copy { .. } | Op::BufferView { .. }) {
290        stored.clone()
291    } else {
292        // Mark AST SINK with KernelAstMarker — matches `ret.sink(arg=KernelInfo(...))`
293        // The gate (`pm_gate_kernel_sink`) checks for this marker to skip the kernel AST subtree.
294        UOp::sink(vec![ret]).with_metadata(KernelAstMarker)
295    };
296
297    // Build KERNEL from context
298    // Sources: lctx.map.values() (buffer → AFTER mappings) + DEFINE_VAR UOps from vars
299    let sources: SmallVec<[Arc<UOp>; 4]> =
300        lctx.map.values().cloned().chain(lctx.vars.values().map(|(uop, _)| uop.clone())).collect();
301
302    let kernel = UOp::kernel(sources.clone(), ast.clone());
303    debug!(
304        kernel_id = kernel.id,
305        num_sources = sources.len(),
306        map_size = lctx.map.len(),
307        vars_size = lctx.vars.len(),
308        "split_store: created kernel"
309    );
310
311    Some(kernel)
312}
313
314/// Fix inter-kernel dependencies (like fix_assign).
315///
316/// Based on upstream.
317/// When kernel B reads from a buffer that kernel A writes to, this function
318/// ensures kernel B's AFTER node depends on kernel A's AFTER node.
319///
320/// Uses buf_uop() to walk through AFTER chains and get underlying buffer IDs.
321fn fix_assign(root: &Arc<UOp>) -> Arc<UOp> {
322    // Map buf_uop().id -> AFTER node that produces it
323    let mut kernel_assign: HashMap<u64, Arc<UOp>> = HashMap::new();
324    #[allow(clippy::mutable_key_type)]
325    let mut assign_rep: HashMap<UOpKey, Arc<UOp>> = HashMap::new();
326
327    for u in root.toposort() {
328        let Op::After { passthrough, deps } = u.op() else {
329            continue;
330        };
331
332        // Use buf_uop() to get underlying buffer ID (handles AFTER chains)
333        let buf_id = passthrough.buf_uop().id;
334        kernel_assign.insert(buf_id, u.clone());
335
336        // Get kernel from deps
337        let Some(kernel) = deps.iter().find(|d| matches!(d.op(), Op::Kernel { .. })).cloned() else {
338            continue;
339        };
340
341        let Op::Kernel { sources, .. } = kernel.op() else {
342            continue;
343        };
344
345        for s in sources {
346            // Check kernel sources for buffer dependencies
347            if !matches!(s.op(), Op::Buffer { .. } | Op::Param { .. }) {
348                continue;
349            }
350            let s_buf_id = s.buf_uop().id;
351            if s_buf_id == buf_id {
352                continue;
353            }
354            let Some(a) = kernel_assign.get(&s_buf_id) else {
355                continue;
356            };
357
358            // Same-kernel check (a.src[1] is u.src[1])
359            // Skip if both AFTERs belong to the same kernel — avoids spurious WAR deps
360            // between outputs of the same multi-output kernel.
361            if let Op::After { deps: a_deps, .. } = a.op()
362                && a_deps.iter().any(|ad| deps.iter().any(|ud| Arc::ptr_eq(ad, ud)))
363            {
364                continue;
365            }
366
367            // Cycle detection
368            if u.any_in_subtree(|x| matches!(x.op(), Op::After { .. }) && x.buf_uop().id == s_buf_id) {
369                panic!(
370                    "cycle detected in graph: kernel for buffer {} reads buffer {} which has AFTER in its tree",
371                    buf_id, s_buf_id
372                );
373            }
374
375            // Add dependency: a.replace(src=a.src+(u,))
376            if let Op::After { passthrough: a_passthrough, deps: a_deps } = a.op() {
377                let mut new_deps = a_deps.clone();
378                new_deps.push(u.clone());
379                let new_a = a_passthrough.after(new_deps);
380                assign_rep.insert(UOpKey(a.clone()), new_a.clone());
381                kernel_assign.insert(s_buf_id, new_a);
382            }
383        }
384    }
385
386    if assign_rep.is_empty() { root.clone() } else { root.substitute(&assign_rep) }
387}
388
389// ============================================================================
390// PIPELINE
391// ============================================================================
392
393/// Run the kernel splitting pipeline.
394///
395/// Based on get_rangeify_map.
396/// Simplified from ~200 lines to ~40 lines.
397///
398/// # Returns
399/// Returns `(result, KernelContext)` tuple for backward compatibility with 30+ callers.
400pub fn run_kernel_split_pipeline(root: Arc<UOp>) -> (Arc<UOp>, KernelContext) {
401    use super::transforms::pm_add_buffers_patterns;
402    use crate::rewrite::graph_rewrite_bottom_up;
403
404    let mut ctx = KernelContext::new();
405
406    // PASS 1: bufferize → store (pm_gate_kernel_sink + pm_add_buffers + pm_add_range_tags, bottom_up=True)
407    let t_stage = std::time::Instant::now();
408    let after_buffers = {
409        use morok_ir::op::pattern_derived::OpKey;
410        use morok_ir::pattern::RewriteResult;
411        let mut matcher = pm_add_buffers_patterns();
412        // Gate on SINK with KernelAstMarker to skip already-formed kernel ASTs
413        // (pm_gate_kernel_sink gates on SINK with KernelInfo arg)
414        matcher.add(&[OpKey::Sink], |node, _ctx| {
415            if node.metadata::<KernelAstMarker>().is_some() {
416                RewriteResult::Gate(node.clone())
417            } else {
418                RewriteResult::NoMatch
419            }
420        });
421        graph_rewrite_bottom_up(&matcher, root, &mut ctx)
422    };
423    tracing::debug!(elapsed_ms = t_stage.elapsed().as_millis() as u64, "kernel split: pm_add_buffers complete");
424
425    trace!(tree = %after_buffers.tree_full(), "after pm_add_buffers");
426
427    // Pre-run pm_flatten_range on the FULL graph ONCE before kernel splitting.
428    //
429    // split_store includes pm_flatten_range but NOT pm_mops/pm_syntactic_sugar
430    // (those were already applied in earlier pipeline stages). Running flatten_range once
431    // on the full graph avoids redundant per-kernel traversals on overlapping subgraphs.
432    let t_stage = std::time::Instant::now();
433    let after_ctx_free = graph_rewrite_bottom_up(super::transforms::pm_flatten_range(), after_buffers, &mut ());
434    tracing::debug!(
435        elapsed_ms = t_stage.elapsed().as_millis() as u64,
436        "kernel split: pm_flatten_range pre-pass complete"
437    );
438
439    let t_stage = std::time::Instant::now();
440    let after_split = split_all_stores(&after_ctx_free);
441    tracing::debug!(elapsed_ms = t_stage.elapsed().as_millis() as u64, "kernel split: split_all_stores complete");
442
443    let t_stage = std::time::Instant::now();
444    let result = fix_assign(&after_split);
445    tracing::debug!(elapsed_ms = t_stage.elapsed().as_millis() as u64, "kernel split: fix_assign complete");
446
447    (result, ctx)
448}
449
450/// Split all STORE/END operations into KERNELs.
451///
452/// Matches upstream:
453///   `graph_rewrite(tsink, pm_gate_kernel_sink + split_kernels, bottom_up=True)`
454///
455/// All patterns run in bpm (Stage 0, see ORIGINAL children):
456/// - Gate on KERNEL nodes to prevent descending into already-split subtrees
457/// - STORE/END → KERNEL via split_store
458fn split_all_stores(root: &Arc<UOp>) -> Arc<UOp> {
459    use morok_ir::op::pattern_derived::OpKey;
460    use morok_ir::pattern::RewriteResult;
461    use morok_ir::rewrite::graph_rewrite_bottom_up;
462
463    // Combined gate + split in bpm (pm_gate_kernel_sink + split_kernels, bottom_up=True)
464    let mut matcher = crate::patterns! {
465        @context Vec<Arc<UOp>>;
466        node @ Store { index: _, value: _ } => |node, ctx| split_store(ctx, node),
467        node @ End { computation, .. }
468            if matches!(computation.op(), Op::Store { .. } | Op::End { .. })
469            => |node, ctx| split_store(ctx, node),
470    };
471    // Gate on SINK with KernelAstMarker to skip already-formed kernel ASTs
472    // (pm_gate_kernel_sink gates on SINK with KernelInfo arg)
473    matcher.add(&[OpKey::Sink], |node, _ctx| {
474        if node.metadata::<KernelAstMarker>().is_some() {
475            RewriteResult::Gate(node.clone())
476        } else {
477            RewriteResult::NoMatch
478        }
479    });
480
481    let mut ctx = Vec::new();
482    graph_rewrite_bottom_up(&matcher, root.clone(), &mut ctx)
483}
484
485// ============================================================================
486// SPLIT REDUCEOP
487// ============================================================================
488
489/// Extract all RANGE axis IDs from a UOp tree.
490pub fn collect_range_ids(indexed: &Arc<UOp>) -> Vec<usize> {
491    let mut range_ids: Vec<usize> = indexed
492        .toposort()
493        .into_iter()
494        .filter_map(|node| if let Op::Range { axis_id, .. } = node.op() { Some(axis_id.value()) } else { None })
495        .collect();
496
497    range_ids.sort_unstable();
498    range_ids.dedup();
499    range_ids
500}
501
/// One way to split a reduce: divide input axis `dimension` by `divisor`.
#[derive(Clone, Debug)]
struct SplitCandidate {
    /// Reduced input axis that gets split.
    dimension: usize,
    /// Factor pulled out of that axis. It is moved last and reduced in the second stage.
    divisor: usize,
    /// Output element count of the first-stage reduce (original output size × `divisor`).
    /// Kept for diagnostics; the transformation itself never reads it.
    #[allow(dead_code)]
    output_size: usize,
}
509
510fn detect_expanded_dimensions(source: &Arc<UOp>, input_shape: &[SInt]) -> Vec<bool> {
511    let ranges: Vec<Arc<UOp>> = input_shape
512        .iter()
513        .enumerate()
514        .map(|(axis_id, dim)| match dim {
515            SInt::Const(n) if *n > 1 => {
516                let end = UOp::index_const(*n as i64);
517                UOp::range_axis(end, morok_ir::AxisId::Unrenumbered(axis_id), morok_ir::AxisType::Loop)
518            }
519            _ => UOp::index_const(0),
520        })
521        .collect();
522
523    let indexed = match UOp::index().buffer(Arc::clone(source)).indices(ranges).call() {
524        Ok(idx) => idx,
525        Err(_) => return vec![false; input_shape.len()],
526    };
527
528    let base = source.base();
529    let noop = UOp::noop();
530    #[allow(clippy::mutable_key_type)]
531    let mut substitutions = HashMap::new();
532    substitutions.insert(UOpKey(base), noop);
533
534    let substituted = indexed.substitute(&substitutions);
535
536    use super::patterns::{movement_op_patterns, pm_syntactic_sugar};
537    use crate::rewrite::graph_rewrite_bottom_up;
538
539    // pm_mops + pm_syntactic_sugar (early movement ops, bottom_up=True)
540    use std::sync::LazyLock;
541    static PM_MOPS: LazyLock<crate::TypedPatternMatcher> =
542        LazyLock::new(|| movement_op_patterns() + pm_syntactic_sugar());
543    let transformed = graph_rewrite_bottom_up(&*PM_MOPS, substituted, &mut ());
544
545    let surviving_range_ids = collect_range_ids(&transformed);
546    let surviving_set: HashSet<usize> = surviving_range_ids.into_iter().collect();
547
548    input_shape.iter().enumerate().map(|(axis_id, _)| !surviving_set.contains(&axis_id)).collect()
549}
550
551fn find_split_candidates(
552    reduce: &Arc<UOp>,
553    input_shape: &[SInt],
554    is_expanded: &[bool],
555    config: &SplitReduceOpConfig,
556) -> Vec<SplitCandidate> {
557    let Op::ReduceAxis { axes: reduce_axes, .. } = reduce.op() else {
558        return vec![];
559    };
560
561    let output_shape = match reduce.shape() {
562        Ok(Some(shape)) => shape,
563        _ => return vec![],
564    };
565
566    let output_size: usize = output_shape.iter().filter_map(|s| s.as_const()).product();
567
568    let mut candidates = Vec::new();
569
570    for &axis in reduce_axes {
571        if axis >= is_expanded.len() || is_expanded[axis] {
572            continue;
573        }
574
575        let dim_size = match &input_shape[axis] {
576            SInt::Const(n) => *n,
577            SInt::Symbolic(_) | SInt::Infer => continue,
578        };
579
580        for divisor in (config.min_divisor..=config.max_divisor).rev() {
581            if dim_size % divisor != 0 {
582                continue;
583            }
584
585            let new_output_size = output_size * divisor;
586
587            if new_output_size > config.max_output_size() {
588                continue;
589            }
590
591            candidates.push(SplitCandidate { dimension: axis, divisor, output_size: new_output_size });
592        }
593    }
594
595    candidates
596}
597
598fn apply_split_transformation(
599    source: &Arc<UOp>,
600    reduce: &Arc<UOp>,
601    candidate: &SplitCandidate,
602    input_shape: &[SInt],
603) -> Option<Arc<UOp>> {
604    let Op::ReduceAxis { reduce_op, axes: reduce_axes, .. } = reduce.op() else {
605        return None;
606    };
607
608    let dim_to_split = candidate.dimension;
609    let divisor = candidate.divisor;
610    let dim_size = input_shape[dim_to_split].as_const()?;
611    let remainder = dim_size / divisor;
612
613    let mut splitted_shape: SmallVec<[SInt; 4]> = SmallVec::new();
614    for (i, dim) in input_shape.iter().enumerate() {
615        if i == dim_to_split {
616            splitted_shape.push(SInt::Const(divisor));
617            splitted_shape.push(SInt::Const(remainder));
618        } else {
619            splitted_shape.push(dim.clone());
620        }
621    }
622
623    let reshaped = source.try_reshape(&splitted_shape).ok()?;
624
625    let mut permutation: Vec<usize> = (0..splitted_shape.len()).filter(|&i| i != dim_to_split).collect();
626    permutation.push(dim_to_split);
627
628    let permuted = reshaped.try_permute(permutation.clone()).ok()?;
629
630    let adjusted_axes: Vec<usize> = reduce_axes
631        .iter()
632        .map(|&axis| {
633            if axis < dim_to_split {
634                axis
635            } else if axis == dim_to_split {
636                dim_to_split + 1
637            } else {
638                axis + 1
639            }
640        })
641        .collect();
642
643    let permuted_axes: Vec<usize> =
644        adjusted_axes.iter().map(|&old_axis| permutation.iter().position(|&p| p == old_axis).unwrap()).collect();
645
646    let first_reduce = permuted.try_reduce_axis(*reduce_op, permuted_axes).ok()?;
647
648    let contiguous = first_reduce.contiguous();
649
650    let output_shape = contiguous.shape().ok()??;
651    let split_axis = output_shape.len() - 1;
652
653    let second_reduce = contiguous.try_reduce_axis(*reduce_op, vec![split_axis]).ok()?;
654
655    let final_shape = reduce.shape().ok()??;
656
657    second_reduce.try_reshape(final_shape).ok()
658}
659
660/// Split large REDUCE_AXIS into two stages.
661pub fn split_reduceop(reduce: &Arc<UOp>, config: &SplitReduceOpConfig) -> Option<Arc<UOp>> {
662    if !config.enabled {
663        return None;
664    }
665
666    let Op::ReduceAxis { src: source, .. } = reduce.op() else {
667        return None;
668    };
669
670    let input_shape = source.shape().ok()??;
671    let output_shape = reduce.shape().ok()??;
672
673    if !input_shape.iter().all(|s| s.is_const()) {
674        return None;
675    }
676
677    let input_size: usize = input_shape.iter().map(|s| s.as_const().unwrap()).product();
678    let output_size: usize = output_shape.iter().map(|s| s.as_const().unwrap()).product();
679
680    if output_size == 0 {
681        return None;
682    }
683
684    let ratio = input_size / output_size;
685    if ratio < config.split_threshold {
686        return None;
687    }
688
689    let is_expanded = detect_expanded_dimensions(source, input_shape);
690    let candidates = find_split_candidates(reduce, input_shape, &is_expanded, config);
691
692    if candidates.is_empty() {
693        return None;
694    }
695
696    apply_split_transformation(source, reduce, &candidates[0], input_shape)
697}