1use std::collections::{HashMap, HashSet};
14use std::sync::Arc;
15
16use indexmap::IndexMap;
17use morok_ir::{AxisType, Op, SInt, UOp, UOpKey};
18use smallvec::SmallVec;
19use tracing::{debug, trace};
20
/// Tuning knobs for the partial-contiguous ("pcontig") heuristic.
///
/// NOTE(review): the pass that reads this config is not visible in this file; the
/// per-field notes below are inferred from field names only — confirm against the consumer.
#[derive(Debug, Clone)]
pub struct PcontigConfig {
    /// Heuristic level (default `2`). Presumably higher means more aggressive — TODO confirm.
    pub level: u8,
    /// Buffer-count threshold used by the heuristic (default `3`).
    pub max_buffers_threshold: usize,
    /// Output/input ratio threshold used by the heuristic (default `10.0`).
    pub out_in_ratio_threshold: f64,
}

impl Default for PcontigConfig {
    /// Stock settings: `level = 2`, `max_buffers_threshold = 3`, `out_in_ratio_threshold = 10.0`.
    fn default() -> Self {
        let level = 2;
        let max_buffers_threshold = 3;
        let out_in_ratio_threshold = 10.0;
        PcontigConfig { level, max_buffers_threshold, out_in_ratio_threshold }
    }
}
41
/// Configuration for `split_reduceop`.
///
/// A reduction is split into a partial reduce (to a larger intermediate output) followed
/// by a second reduce over the split factor.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SplitReduceOpConfig {
    /// Minimum `input_size / output_size` ratio before a reduction is considered for splitting.
    pub split_threshold: usize,
    /// log2 of the largest allowed element count of the partial reduce's output;
    /// see [`SplitReduceOpConfig::max_output_size`].
    pub output_size_bits: u32,
    /// Largest split divisor tried (divisors are tried from largest to smallest).
    pub max_divisor: usize,
    /// Smallest split divisor tried.
    pub min_divisor: usize,
    /// Master switch. When `false`, `split_reduceop` never rewrites anything.
    pub enabled: bool,
}

impl Default for SplitReduceOpConfig {
    fn default() -> Self {
        Self { split_threshold: 32768, output_size_bits: 22, max_divisor: 256, min_divisor: 8, enabled: true }
    }
}

impl SplitReduceOpConfig {
    /// Largest permitted element count of the partial reduce's output, i.e. `2^output_size_bits`.
    ///
    /// Saturates to `usize::MAX` when `output_size_bits >= usize::BITS` rather than
    /// overflowing (which would panic in debug builds and wrap to `0` in release builds).
    pub fn max_output_size(&self) -> usize {
        2_usize.saturating_pow(self.output_size_bits)
    }
}
68
/// State shared across the kernel-split pipeline.
///
/// `run_kernel_split_pipeline` creates one of these, passes it as the rewrite context
/// of its `pm_add_buffers` stage, and returns it to the caller alongside the result.
#[derive(Clone)]
pub struct KernelContext {
    /// Next value returned by `next_global`.
    pub global_counter: usize,
    /// Next value returned by `next_local`.
    pub local_counter: usize,
    /// Original buffer UOp -> replacement UOp, recorded by `map_buffer`.
    pub buffer_map: HashMap<UOpKey, Arc<UOp>>,
    /// DEFINE_VAR name -> (the DEFINE_VAR UOp, the value passed to `add_var`).
    pub vars: HashMap<String, (Arc<UOp>, Option<i64>)>,
    /// Next value returned by `next_range`.
    pub range_counter: usize,
}
89
90impl KernelContext {
91 pub fn new() -> Self {
92 Self { global_counter: 0, local_counter: 0, buffer_map: HashMap::new(), vars: HashMap::new(), range_counter: 0 }
93 }
94
95 pub fn next_global(&mut self) -> usize {
96 let id = self.global_counter;
97 self.global_counter += 1;
98 id
99 }
100
101 pub fn next_local(&mut self) -> usize {
102 let id = self.local_counter;
103 self.local_counter += 1;
104 id
105 }
106
107 pub fn next_range(&mut self) -> usize {
108 let id = self.range_counter;
109 self.range_counter += 1;
110 id
111 }
112
113 pub fn has_buffer(&self, buf: &Arc<UOp>) -> bool {
114 self.buffer_map.contains_key(&UOpKey(buf.clone()))
115 }
116
117 pub fn get_buffer(&self, buf: &Arc<UOp>) -> Option<&Arc<UOp>> {
118 self.buffer_map.get(&UOpKey(buf.clone()))
119 }
120
121 pub fn map_buffer(&mut self, original: Arc<UOp>, replacement: Arc<UOp>) {
122 self.buffer_map.insert(UOpKey(original), replacement);
123 }
124
125 pub fn add_var(&mut self, var: Arc<UOp>, value: Option<i64>) {
127 if let Op::DefineVar { name, .. } = var.op() {
128 self.vars.insert(name.clone(), (var, value));
129 }
130 }
131}
132
133impl Default for KernelContext {
134 fn default() -> Self {
135 Self::new()
136 }
137}
138
/// Per-kernel rewrite context used by `split_store`.
///
/// `split_store` creates a fresh one per STORE. It threads it through the local
/// rewrite, then builds the KERNEL's sources from the UOps in `map` followed by the
/// UOps in `vars`.
#[derive(Default)]
pub struct LocalAddBufferContext {
    /// Next value returned by `next_param_slot`.
    pub param_slot: usize,
    /// Buffer UOp -> UOp recorded by `map_buffer`; `IndexMap` keeps insertion order.
    pub map: IndexMap<UOpKey, Arc<UOp>>,
    /// DEFINE_VAR name -> (the DEFINE_VAR UOp, the value passed to `add_var`).
    pub vars: HashMap<String, (Arc<UOp>, Option<i64>)>,
    /// Next value returned by `next_range`.
    pub range: usize,
    /// Contiguous hints gathered during the rewrite.
    /// NOTE(review): nothing in this file reads this field — confirm where it is consumed.
    pub opts: Vec<morok_ir::ContiguousHint>,
}
165
166impl LocalAddBufferContext {
167 pub fn new() -> Self {
168 Self::default()
169 }
170
171 pub fn next_param_slot(&mut self) -> usize {
173 let id = self.param_slot;
174 self.param_slot += 1;
175 id
176 }
177
178 pub fn next_range(&mut self) -> usize {
180 let id = self.range;
181 self.range += 1;
182 id
183 }
184
185 pub fn add_var(&mut self, var: Arc<UOp>, value: Option<i64>) {
187 if let Op::DefineVar { name, .. } = var.op() {
188 self.vars.insert(name.clone(), (var, value));
189 }
190 }
191
192 pub fn map_buffer(&mut self, buf: Arc<UOp>, after: Arc<UOp>) {
194 self.map.insert(UOpKey(buf), after);
195 }
196
197 pub fn has_buffer(&self, buf: &Arc<UOp>) -> bool {
199 self.map.contains_key(&UOpKey(buf.clone()))
200 }
201}
202
/// Metadata marker attached (via `with_metadata`) to the SINK that `split_store`
/// builds as a kernel's AST.
///
/// The rewrites in `run_kernel_split_pipeline` and `split_all_stores` return
/// `RewriteResult::Gate` for any SINK carrying this marker, so already-split kernel
/// bodies are not processed again.
/// NOTE(review): the precise `Gate` semantics live in `morok_ir::pattern` —
/// presumably "stop descending here"; confirm.
#[derive(Debug, Clone)]
pub struct KernelAstMarker;
216
217fn extract_stored_value(ret: &Arc<UOp>) -> &Arc<UOp> {
222 match ret.op() {
223 Op::Store { value, .. } => value,
224 Op::End { computation, .. } => match computation.op() {
225 Op::Store { value, .. } => value,
226 _ => ret,
227 },
228 _ => ret,
229 }
230}
231
232pub fn split_store(_ctx: &mut Vec<Arc<UOp>>, x: &Arc<UOp>) -> Option<Arc<UOp>> {
237 use super::patterns::{local_to_param_patterns, rangeify_codegen_patterns};
238 use crate::rewrite::graph_rewrite_bottom_up;
239
240 trace!(uop_id = x.id, op = ?std::mem::discriminant(x.op()), "split_store: entering");
241
242 #[allow(clippy::mutable_key_type)] let in_scope = x.in_scope_ranges();
245 let has_non_outer =
246 in_scope.iter().any(|r| matches!(r.0.op(), Op::Range { axis_type, .. } if *axis_type != AxisType::Outer));
247 if has_non_outer {
248 return None;
249 }
250
251 if let Op::End { ranges, .. } = x.op()
254 && let Some(r) = ranges.first()
255 && matches!(r.op(), Op::Range { axis_type: AxisType::Outer, .. })
256 {
257 return None;
258 }
259
260 let is_valid = match x.op() {
262 Op::Store { .. } => true,
263 Op::End { computation, .. } => matches!(computation.op(), Op::Store { .. }),
264 _ => false,
265 };
266 if !is_valid {
267 return None;
268 }
269
270 let mut lctx = LocalAddBufferContext::new();
272
273 let ret = {
280 use std::sync::LazyLock;
281 static PM_CTX_DEP: LazyLock<crate::TypedPatternMatcher<LocalAddBufferContext>> =
282 LazyLock::new(|| local_to_param_patterns() + rangeify_codegen_patterns());
283 graph_rewrite_bottom_up(&*PM_CTX_DEP, x.clone(), &mut lctx)
284 };
285
286 let stored = extract_stored_value(&ret);
289 let ast = if matches!(stored.op(), Op::Copy { .. } | Op::BufferView { .. }) {
290 stored.clone()
291 } else {
292 UOp::sink(vec![ret]).with_metadata(KernelAstMarker)
295 };
296
297 let sources: SmallVec<[Arc<UOp>; 4]> =
300 lctx.map.values().cloned().chain(lctx.vars.values().map(|(uop, _)| uop.clone())).collect();
301
302 let kernel = UOp::kernel(sources.clone(), ast.clone());
303 debug!(
304 kernel_id = kernel.id,
305 num_sources = sources.len(),
306 map_size = lctx.map.len(),
307 vars_size = lctx.vars.len(),
308 "split_store: created kernel"
309 );
310
311 Some(kernel)
312}
313
/// Adds ordering edges between AFTER nodes so that a buffer is not reassigned before
/// a kernel that reads it has been scheduled.
///
/// The graph is walked in toposort order, and the latest AFTER seen for each buffer id
/// is remembered. Take an AFTER `u` whose deps contain a KERNEL. Suppose that kernel
/// reads a raw BUFFER/PARAM `s`, and an AFTER `a` for `s` was already seen. Then `a` is
/// rebuilt as `a.passthrough.after(a.deps + [u])`, i.e. `a` now also waits for `u`.
/// Presumably this guards a write-after-read hazard — NOTE(review): confirm intent.
/// All rebuilt nodes are applied with a single `substitute` at the end.
///
/// # Panics
/// Panics with "cycle detected in graph" if `u`'s subtree already contains an AFTER
/// for `s`, because making `a` depend on `u` would then form a cycle.
fn fix_assign(root: &Arc<UOp>) -> Arc<UOp> {
    // Buffer id -> most recent AFTER seen for that buffer (updated when rebuilt).
    let mut kernel_assign: HashMap<u64, Arc<UOp>> = HashMap::new();
    // Original AFTER -> rebuilt AFTER with the extra dependency.
    #[allow(clippy::mutable_key_type)]
    let mut assign_rep: HashMap<UOpKey, Arc<UOp>> = HashMap::new();

    for u in root.toposort() {
        let Op::After { passthrough, deps } = u.op() else {
            continue;
        };

        // Record `u` as the current AFTER for the buffer it writes.
        let buf_id = passthrough.buf_uop().id;
        kernel_assign.insert(buf_id, u.clone());

        // Only AFTERs produced by a kernel carry read sets to inspect.
        let Some(kernel) = deps.iter().find(|d| matches!(d.op(), Op::Kernel { .. })).cloned() else {
            continue;
        };

        let Op::Kernel { sources, .. } = kernel.op() else {
            continue;
        };

        for s in sources {
            // Only raw buffer reads matter here; other sources are skipped.
            if !matches!(s.op(), Op::Buffer { .. } | Op::Param { .. }) {
                continue;
            }
            let s_buf_id = s.buf_uop().id;
            // The kernel reading its own output buffer needs no extra edge.
            if s_buf_id == buf_id {
                continue;
            }
            let Some(a) = kernel_assign.get(&s_buf_id) else {
                continue;
            };

            // `a` and `u` share a dependency node (presumably the same multi-output kernel):
            // no ordering edge is needed.
            if let Op::After { deps: a_deps, .. } = a.op()
                && a_deps.iter().any(|ad| deps.iter().any(|ud| Arc::ptr_eq(ad, ud)))
            {
                continue;
            }

            // If `u` already depends on an AFTER of `s`, adding `u` to `a`'s deps would loop.
            if u.any_in_subtree(|x| matches!(x.op(), Op::After { .. }) && x.buf_uop().id == s_buf_id) {
                panic!(
                    "cycle detected in graph: kernel for buffer {} reads buffer {} which has AFTER in its tree",
                    buf_id, s_buf_id
                );
            }

            // Rebuild `a` with `u` appended to its deps, and make later lookups see the rebuilt node.
            if let Op::After { passthrough: a_passthrough, deps: a_deps } = a.op() {
                let mut new_deps = a_deps.clone();
                new_deps.push(u.clone());
                let new_a = a_passthrough.after(new_deps);
                assign_rep.insert(UOpKey(a.clone()), new_a.clone());
                kernel_assign.insert(s_buf_id, new_a);
            }
        }
    }

    if assign_rep.is_empty() { root.clone() } else { root.substitute(&assign_rep) }
}
388
389pub fn run_kernel_split_pipeline(root: Arc<UOp>) -> (Arc<UOp>, KernelContext) {
401 use super::transforms::pm_add_buffers_patterns;
402 use crate::rewrite::graph_rewrite_bottom_up;
403
404 let mut ctx = KernelContext::new();
405
406 let t_stage = std::time::Instant::now();
408 let after_buffers = {
409 use morok_ir::op::pattern_derived::OpKey;
410 use morok_ir::pattern::RewriteResult;
411 let mut matcher = pm_add_buffers_patterns();
412 matcher.add(&[OpKey::Sink], |node, _ctx| {
415 if node.metadata::<KernelAstMarker>().is_some() {
416 RewriteResult::Gate(node.clone())
417 } else {
418 RewriteResult::NoMatch
419 }
420 });
421 graph_rewrite_bottom_up(&matcher, root, &mut ctx)
422 };
423 tracing::debug!(elapsed_ms = t_stage.elapsed().as_millis() as u64, "kernel split: pm_add_buffers complete");
424
425 trace!(tree = %after_buffers.tree_full(), "after pm_add_buffers");
426
427 let t_stage = std::time::Instant::now();
433 let after_ctx_free = graph_rewrite_bottom_up(super::transforms::pm_flatten_range(), after_buffers, &mut ());
434 tracing::debug!(
435 elapsed_ms = t_stage.elapsed().as_millis() as u64,
436 "kernel split: pm_flatten_range pre-pass complete"
437 );
438
439 let t_stage = std::time::Instant::now();
440 let after_split = split_all_stores(&after_ctx_free);
441 tracing::debug!(elapsed_ms = t_stage.elapsed().as_millis() as u64, "kernel split: split_all_stores complete");
442
443 let t_stage = std::time::Instant::now();
444 let result = fix_assign(&after_split);
445 tracing::debug!(elapsed_ms = t_stage.elapsed().as_millis() as u64, "kernel split: fix_assign complete");
446
447 (result, ctx)
448}
449
450fn split_all_stores(root: &Arc<UOp>) -> Arc<UOp> {
459 use morok_ir::op::pattern_derived::OpKey;
460 use morok_ir::pattern::RewriteResult;
461 use morok_ir::rewrite::graph_rewrite_bottom_up;
462
463 let mut matcher = crate::patterns! {
465 @context Vec<Arc<UOp>>;
466 node @ Store { index: _, value: _ } => |node, ctx| split_store(ctx, node),
467 node @ End { computation, .. }
468 if matches!(computation.op(), Op::Store { .. } | Op::End { .. })
469 => |node, ctx| split_store(ctx, node),
470 };
471 matcher.add(&[OpKey::Sink], |node, _ctx| {
474 if node.metadata::<KernelAstMarker>().is_some() {
475 RewriteResult::Gate(node.clone())
476 } else {
477 RewriteResult::NoMatch
478 }
479 });
480
481 let mut ctx = Vec::new();
482 graph_rewrite_bottom_up(&matcher, root.clone(), &mut ctx)
483}
484
485pub fn collect_range_ids(indexed: &Arc<UOp>) -> Vec<usize> {
491 let mut range_ids: Vec<usize> = indexed
492 .toposort()
493 .into_iter()
494 .filter_map(|node| if let Op::Range { axis_id, .. } = node.op() { Some(axis_id.value()) } else { None })
495 .collect();
496
497 range_ids.sort_unstable();
498 range_ids.dedup();
499 range_ids
500}
501
/// One way to split a reduction, produced by `find_split_candidates`.
///
/// Input axis `dimension` is reshaped into `divisor x (size / divisor)`. The `divisor`
/// factor is moved to the last axis and reduced in a second pass.
#[derive(Debug, Clone)]
struct SplitCandidate {
    /// Index of the reduced input axis that gets split.
    dimension: usize,
    /// Size of the factor split off and reduced in the second pass.
    divisor: usize,
    /// Element count of the first (partial) reduce's output: original output size * `divisor`.
    /// Only recorded; nothing in this file reads it.
    #[allow(dead_code)]
    output_size: usize,
}
509
510fn detect_expanded_dimensions(source: &Arc<UOp>, input_shape: &[SInt]) -> Vec<bool> {
511 let ranges: Vec<Arc<UOp>> = input_shape
512 .iter()
513 .enumerate()
514 .map(|(axis_id, dim)| match dim {
515 SInt::Const(n) if *n > 1 => {
516 let end = UOp::index_const(*n as i64);
517 UOp::range_axis(end, morok_ir::AxisId::Unrenumbered(axis_id), morok_ir::AxisType::Loop)
518 }
519 _ => UOp::index_const(0),
520 })
521 .collect();
522
523 let indexed = match UOp::index().buffer(Arc::clone(source)).indices(ranges).call() {
524 Ok(idx) => idx,
525 Err(_) => return vec![false; input_shape.len()],
526 };
527
528 let base = source.base();
529 let noop = UOp::noop();
530 #[allow(clippy::mutable_key_type)]
531 let mut substitutions = HashMap::new();
532 substitutions.insert(UOpKey(base), noop);
533
534 let substituted = indexed.substitute(&substitutions);
535
536 use super::patterns::{movement_op_patterns, pm_syntactic_sugar};
537 use crate::rewrite::graph_rewrite_bottom_up;
538
539 use std::sync::LazyLock;
541 static PM_MOPS: LazyLock<crate::TypedPatternMatcher> =
542 LazyLock::new(|| movement_op_patterns() + pm_syntactic_sugar());
543 let transformed = graph_rewrite_bottom_up(&*PM_MOPS, substituted, &mut ());
544
545 let surviving_range_ids = collect_range_ids(&transformed);
546 let surviving_set: HashSet<usize> = surviving_range_ids.into_iter().collect();
547
548 input_shape.iter().enumerate().map(|(axis_id, _)| !surviving_set.contains(&axis_id)).collect()
549}
550
551fn find_split_candidates(
552 reduce: &Arc<UOp>,
553 input_shape: &[SInt],
554 is_expanded: &[bool],
555 config: &SplitReduceOpConfig,
556) -> Vec<SplitCandidate> {
557 let Op::ReduceAxis { axes: reduce_axes, .. } = reduce.op() else {
558 return vec![];
559 };
560
561 let output_shape = match reduce.shape() {
562 Ok(Some(shape)) => shape,
563 _ => return vec![],
564 };
565
566 let output_size: usize = output_shape.iter().filter_map(|s| s.as_const()).product();
567
568 let mut candidates = Vec::new();
569
570 for &axis in reduce_axes {
571 if axis >= is_expanded.len() || is_expanded[axis] {
572 continue;
573 }
574
575 let dim_size = match &input_shape[axis] {
576 SInt::Const(n) => *n,
577 SInt::Symbolic(_) | SInt::Infer => continue,
578 };
579
580 for divisor in (config.min_divisor..=config.max_divisor).rev() {
581 if dim_size % divisor != 0 {
582 continue;
583 }
584
585 let new_output_size = output_size * divisor;
586
587 if new_output_size > config.max_output_size() {
588 continue;
589 }
590
591 candidates.push(SplitCandidate { dimension: axis, divisor, output_size: new_output_size });
592 }
593 }
594
595 candidates
596}
597
598fn apply_split_transformation(
599 source: &Arc<UOp>,
600 reduce: &Arc<UOp>,
601 candidate: &SplitCandidate,
602 input_shape: &[SInt],
603) -> Option<Arc<UOp>> {
604 let Op::ReduceAxis { reduce_op, axes: reduce_axes, .. } = reduce.op() else {
605 return None;
606 };
607
608 let dim_to_split = candidate.dimension;
609 let divisor = candidate.divisor;
610 let dim_size = input_shape[dim_to_split].as_const()?;
611 let remainder = dim_size / divisor;
612
613 let mut splitted_shape: SmallVec<[SInt; 4]> = SmallVec::new();
614 for (i, dim) in input_shape.iter().enumerate() {
615 if i == dim_to_split {
616 splitted_shape.push(SInt::Const(divisor));
617 splitted_shape.push(SInt::Const(remainder));
618 } else {
619 splitted_shape.push(dim.clone());
620 }
621 }
622
623 let reshaped = source.try_reshape(&splitted_shape).ok()?;
624
625 let mut permutation: Vec<usize> = (0..splitted_shape.len()).filter(|&i| i != dim_to_split).collect();
626 permutation.push(dim_to_split);
627
628 let permuted = reshaped.try_permute(permutation.clone()).ok()?;
629
630 let adjusted_axes: Vec<usize> = reduce_axes
631 .iter()
632 .map(|&axis| {
633 if axis < dim_to_split {
634 axis
635 } else if axis == dim_to_split {
636 dim_to_split + 1
637 } else {
638 axis + 1
639 }
640 })
641 .collect();
642
643 let permuted_axes: Vec<usize> =
644 adjusted_axes.iter().map(|&old_axis| permutation.iter().position(|&p| p == old_axis).unwrap()).collect();
645
646 let first_reduce = permuted.try_reduce_axis(*reduce_op, permuted_axes).ok()?;
647
648 let contiguous = first_reduce.contiguous();
649
650 let output_shape = contiguous.shape().ok()??;
651 let split_axis = output_shape.len() - 1;
652
653 let second_reduce = contiguous.try_reduce_axis(*reduce_op, vec![split_axis]).ok()?;
654
655 let final_shape = reduce.shape().ok()??;
656
657 second_reduce.try_reshape(final_shape).ok()
658}
659
660pub fn split_reduceop(reduce: &Arc<UOp>, config: &SplitReduceOpConfig) -> Option<Arc<UOp>> {
662 if !config.enabled {
663 return None;
664 }
665
666 let Op::ReduceAxis { src: source, .. } = reduce.op() else {
667 return None;
668 };
669
670 let input_shape = source.shape().ok()??;
671 let output_shape = reduce.shape().ok()??;
672
673 if !input_shape.iter().all(|s| s.is_const()) {
674 return None;
675 }
676
677 let input_size: usize = input_shape.iter().map(|s| s.as_const().unwrap()).product();
678 let output_size: usize = output_shape.iter().map(|s| s.as_const().unwrap()).product();
679
680 if output_size == 0 {
681 return None;
682 }
683
684 let ratio = input_size / output_size;
685 if ratio < config.split_threshold {
686 return None;
687 }
688
689 let is_expanded = detect_expanded_dimensions(source, input_shape);
690 let candidates = find_split_candidates(reduce, input_shape, &is_expanded, config);
691
692 if candidates.is_empty() {
693 return None;
694 }
695
696 apply_split_transformation(source, reduce, &candidates[0], input_shape)
697}