1use std::collections::{HashMap, HashSet};
15use std::sync::Arc;
16
17use super::context::RangeifyContext;
18use super::indexing::IndexingContext;
19use super::kernel::KernelContext;
20use morok_ir::shape::Shape;
21use morok_ir::{AddrSpace, AxisType, BufferizeOpts, ConstValue, DType, Op, UOp, UOpKey};
22use smallvec::{SmallVec, smallvec};
23
24pub struct AddTagsCtx {
30 pub uop_list: Vec<Arc<UOp>>,
32 excluded: HashSet<UOpKey>,
34}
35
36impl Default for AddTagsCtx {
37 fn default() -> Self {
38 Self::new()
39 }
40}
41
42impl AddTagsCtx {
43 pub fn new() -> Self {
44 Self { uop_list: Vec::new(), excluded: HashSet::new() }
45 }
46}
47
48fn should_skip_tag(op: &Op) -> bool {
52 matches!(
53 op,
54 Op::Param { .. }
55 | Op::Const(_)
56 | Op::Device(_)
57 | Op::Unique(_)
58 | Op::DefineVar { .. }
59 | Op::Bind { .. }
60 | Op::End { .. }
61 | Op::Range { .. }
62 ) || op.is_movement()
63}
64
/// Builds the matcher that assigns every taggable UOp a sequential tag.
///
/// A node is left untouched when it:
/// - already carries a tag, or was excluded because it belongs to a kernel AST;
/// - is rejected by `should_skip_tag`;
/// - has the `Index` scalar dtype;
/// - is an `MStack`/`MSelect` whose sources are all `Param`s.
///
/// Otherwise the node is appended to `ctx.uop_list` and tagged with its index in that list.
/// When a `Kernel` is visited, every node of its `ast` is added to `ctx.excluded`. The kernel
/// itself may still be tagged.
pub fn add_tags_patterns() -> crate::TypedPatternMatcher<AddTagsCtx> {
    crate::patterns! {
        @context AddTagsCtx;
        x => {
            // Already tagged, or excluded as part of a kernel AST: nothing to do.
            if x.tag().is_some() || ctx.excluded.contains(&UOpKey(x.clone())) { return None; }
            // Record the kernel's AST nodes so they are never tagged later.
            if let Op::Kernel { ast, .. } = x.op() {
                for u in ast.toposort() {
                    ctx.excluded.insert(UOpKey(u));
                }
            }
            if should_skip_tag(x.op()) { return None; }
            if x.dtype().base() == morok_dtype::ScalarDType::Index { return None; }
            // MStack/MSelect built purely from params are not tagged.
            if matches!(x.op(), Op::MStack { .. } | Op::MSelect { .. })
                && x.op().sources().iter().all(|s| matches!(s.op(), Op::Param { .. }))
            {
                return None;
            }
            // Tag value == index of the node in `uop_list`.
            ctx.uop_list.push(x.clone());
            Some(x.with_tag(smallvec![ctx.uop_list.len() - 1]))
        },
    }
}
95
96pub fn rangeify(
105 sink: Arc<UOp>,
106 pcontig_config: Option<&super::kernel::PcontigConfig>,
107) -> morok_ir::Result<(Arc<UOp>, RangeifyContext)> {
108 let result = rangeify_with_map(sink, pcontig_config)?;
109 Ok((result.sink, result.context))
110}
111
/// Output of [`rangeify_with_map`].
pub struct RangeifyResult {
    /// The graph after all rangeify stages.
    pub sink: Arc<UOp>,
    /// Context carrying the range counter from range assignment. `range_map` is left empty
    /// by `rangeify_with_map`.
    pub context: RangeifyContext,
    /// UOps tagged by `add_tags_patterns`; a node's tag value is its index in this list.
    pub uop_list: Vec<Arc<UOp>>,
}
122
123#[allow(clippy::mutable_key_type)]
141#[tracing::instrument(skip_all)]
142pub fn rangeify_with_map(
143 sink: Arc<UOp>,
144 pcontig_config: Option<&super::kernel::PcontigConfig>,
145) -> morok_ir::Result<RangeifyResult> {
146 let t_stage = std::time::Instant::now();
149 let mut tag_ctx = AddTagsCtx::new();
150 let mut sink = crate::rewrite::graph_rewrite_bottom_up(&add_tags_patterns(), sink, &mut tag_ctx);
151 let uop_list = tag_ctx.uop_list;
152 tracing::debug!(
153 tagged_count = uop_list.len(),
154 node_count = sink.node_count(),
155 elapsed_ms = t_stage.elapsed().as_millis() as u64,
156 "add_tags complete"
157 );
158
159 let t_stage = std::time::Instant::now();
163 let early_combined = super::patterns::early_rewrites().with_context::<super::patterns::ReplaceContiguousCtx>()
164 + super::patterns::replace_contiguous();
165 let mut contig_ctx = super::patterns::ReplaceContiguousCtx::new();
166 sink = crate::rewrite::graph_rewrite_bottom_up(&early_combined, sink, &mut contig_ctx);
167 tracing::debug!(
168 uop.tree = sink.tree(),
169 node_count = sink.node_count(),
170 elapsed_ms = t_stage.elapsed().as_millis() as u64,
171 "early rewrites + replace contiguous complete"
172 );
173
174 let t_stage = std::time::Instant::now();
176 let mut split_config = super::kernel::SplitReduceOpConfig::default();
177 let split_matcher = super::patterns::split_reduceop_patterns();
178 sink = crate::rewrite::graph_rewrite(&split_matcher, sink, &mut split_config);
179 tracing::debug!(
180 uop.tree = sink.tree(),
181 node_count = sink.node_count(),
182 elapsed_ms = t_stage.elapsed().as_millis() as u64,
183 "split reduceops complete"
184 );
185
186 let t_stage = std::time::Instant::now();
192 let (rangeified, indexing_ctx) = super::indexing::run_rangeify(sink)?;
193 sink = rangeified;
194 tracing::debug!(
195 uop.tree = sink.tree(),
196 node_count = sink.node_count(),
197 elapsed_ms = t_stage.elapsed().as_millis() as u64,
198 "Stage 0: range assignment + apply rangeify complete"
199 );
200
201 {
210 use super::kernel::PcontigConfig;
211 let t_stage = std::time::Instant::now();
212 use std::sync::LazyLock;
213 static MEGA_PASS: LazyLock<crate::TypedPatternMatcher<PcontigConfig>> = LazyLock::new(|| {
214 crate::symbolic::symbolic().with_context::<PcontigConfig>()
215 + super::patterns::pm_reduce_simplify().with_context()
216 + super::patterns::buffer_folding().with_context()
217 + super::patterns::dead_axis_removal().with_context()
218 + super::patterns::movement_op_patterns().with_context()
220 + super::patterns::buffer_removal_with_pcontig()
221 });
222 let mega_pass = &*MEGA_PASS;
223 tracing::debug!(
224 total_patterns = mega_pass.len(),
225 wildcard_count = mega_pass.wildcard_count(),
226 indexed_buckets = mega_pass.indexed_count(),
227 "mega-pass pattern stats"
228 );
229 let mut pcontig = pcontig_config.cloned().unwrap_or_default();
230 sink = crate::rewrite::graph_rewrite(mega_pass, sink, &mut pcontig);
231 tracing::debug!(
232 node_count = sink.node_count(),
233 elapsed_ms = t_stage.elapsed().as_millis() as u64,
234 "mega-pass complete"
235 );
236 }
237
238 if let Op::Sink { sources } = sink.op() {
246 let filtered: Vec<Arc<UOp>> = sources
247 .iter()
248 .filter(|s| {
249 let valid_op = matches!(
250 s.base().op(),
251 Op::Bufferize { .. } | Op::MStack { .. } | Op::Const(_) | Op::Param { .. } | Op::After { .. }
252 );
253 valid_op
254 })
255 .cloned()
256 .collect();
257 if !filtered.is_empty() && filtered.len() != sources.len() {
258 tracing::debug!(
259 original = sources.len(),
260 filtered = filtered.len(),
261 "SINK cleanup: removed invalid-type sources after mega-pass"
262 );
263 sink = UOp::sink(filtered);
264 }
265 }
266
267 if let Some(device) = super::patterns::extract_device_from_graph(&sink)
269 && let Some(limit) = device.max_buffers()
270 {
271 let t_stage = std::time::Instant::now();
272 let limit_matcher = super::patterns::buffer_limit_patterns(limit);
273 sink = crate::rewrite::graph_rewrite(&limit_matcher, sink, &mut ());
274 tracing::debug!(
275 uop.tree = sink.tree(),
276 elapsed_ms = t_stage.elapsed().as_millis() as u64,
277 "Stage 7b: buffer limit enforcement complete"
278 );
279 }
280
281 let rangeify_ctx = RangeifyContext { range_counter: indexing_ctx.range_counter(), range_map: HashMap::new() };
287
288 Ok(RangeifyResult { sink, context: rangeify_ctx, uop_list })
289}
290
/// Returns the cached matcher that applies `flatten_range_impl` to every `End`, `Reduce` and
/// `Store` node with a non-empty range list.
///
/// NOTE(review): `flatten_range_impl` is defined elsewhere. Presumably it normalizes the
/// node's range list; confirm at its definition.
pub fn pm_flatten_range() -> &'static crate::TypedPatternMatcher {
    crate::cached_patterns! {
        r @ End { computation: _, ranges } if !ranges.is_empty() => |r| flatten_range_impl(r),
        r @ Reduce { src: _, ranges, reduce_op: _ } if !ranges.is_empty() => |r| flatten_range_impl(r),
        r @ Store { index: _, value: _, ranges } if !ranges.is_empty() => |r| flatten_range_impl(r),
    }
}
302
/// State for [`pm_split_ranges`].
#[derive(Default)]
pub struct SplitRangesContext {
    /// Range UOp id -> constant modulus the range is used with; these ranges are to be split.
    pub marked_ranges: HashMap<u64, i64>,
    /// Range UOp ids that must never be split (ranges reachable from an image-dtype store index).
    protected_ranges: HashSet<u64>,
}
318
/// Builds the matcher that splits ranges used under a constant modulus.
///
/// Rules:
/// 1. `Range % c`, where the range end is a constant divisible by the positive constant `c`,
///    records the range in `ctx.marked_ranges`. It does not rewrite and returns `None`.
/// 2. A `Store` whose index buffer has an `Image` dtype protects every range reachable from
///    that index and removes it from `marked_ranges`, so it is never split.
/// 3. Once any range is marked, the `Sink` is rewritten by `do_split_ranges_substitute`. Each
///    marked range with end `N` and modulus `c` becomes `outer * c + inner`, where `outer`
///    runs to `N / c` and `inner` runs to `c`.
///
/// NOTE(review): rule 3 fires on the SINK. This presumes the rewrite visits the SINK after
/// the Mod/Store nodes below it; confirm against the rewrite engine's traversal order.
pub fn pm_split_ranges() -> crate::TypedPatternMatcher<SplitRangesContext> {
    crate::patterns! {
        @context SplitRangesContext;

        _modop @ Mod(r @ Range { end, axis_id: _, axis_type: _ }, c @ Const(_))
            if is_divisible_range_end(end, c) => |r, c| {
            mark_range_mod(ctx, r, c);
            None },

        _store @ Store { index: idx @ Index { buffer: buf, indices: _, gate: _ }, value: _, ranges: _ }
            if is_image_dtype(buf) => |idx| {
            protect_ranges_for_image(ctx, idx);
            None },

        sink @ Sink { sources: _ } if !ctx.marked_ranges.is_empty() => |sink| {
            do_split_ranges_substitute(ctx, sink)
        },
    }
}
360
361fn is_image_dtype(buf: &Arc<UOp>) -> bool {
363 matches!(buf.dtype(), DType::Image { .. })
364}
365
366fn protect_ranges_for_image(ctx: &mut SplitRangesContext, idx: &Arc<UOp>) {
368 for node in idx.toposort() {
369 if matches!(node.op(), Op::Range { .. }) {
370 ctx.protected_ranges.insert(node.id);
371 ctx.marked_ranges.remove(&node.id);
373 }
374 }
375}
376
377fn const_uop_to_i64(c: &Arc<UOp>) -> Option<i64> {
379 match c.op() {
380 Op::Const(cv) => match cv.0 {
381 ConstValue::Int(v) => Some(v),
382 ConstValue::UInt(v) => Some(v as i64),
383 _ => None,
384 },
385 _ => None,
386 }
387}
388
389fn is_divisible_range_end(end: &Arc<UOp>, c: &Arc<UOp>) -> bool {
391 let Some(end_val) = const_uop_to_i64(end) else {
392 return false;
393 };
394 let Some(mod_val) = const_uop_to_i64(c) else {
395 return false;
396 };
397 mod_val > 0 && end_val % mod_val == 0
398}
399
400fn mark_range_mod(ctx: &mut SplitRangesContext, r: &Arc<UOp>, c: &Arc<UOp>) {
402 if ctx.marked_ranges.contains_key(&r.id) || ctx.protected_ranges.contains(&r.id) {
404 return;
405 }
406 if let Some(mod_val) = const_uop_to_i64(c) {
407 ctx.marked_ranges.insert(r.id, mod_val);
408 }
409}
410
411fn do_split_ranges_substitute(ctx: &mut SplitRangesContext, sink: &Arc<UOp>) -> Option<Arc<UOp>> {
418 use morok_ir::AxisId;
419 use morok_ir::rewrite::graph_rewrite_bottom_up;
420
421 if ctx.marked_ranges.is_empty() {
422 return None;
423 }
424
425 let mut subs: HashMap<u64, Arc<UOp>> = HashMap::new();
427
428 let topo = sink.toposort();
430
431 let mut max_axis_id: usize = 0;
433 for uop in &topo {
434 if let Op::Range { axis_id, .. } = uop.op() {
435 max_axis_id = max_axis_id.max(axis_id.value());
436 }
437 }
438 let mut next_id = max_axis_id + 1;
439
440 for uop in &topo {
441 if ctx.protected_ranges.contains(&uop.id) {
443 continue;
444 }
445 if let Some(&mod_val) = ctx.marked_ranges.get(&uop.id)
446 && let Op::Range { end, axis_type, .. } = uop.op()
447 {
448 let Some(end_val) = const_uop_to_i64(end) else {
449 continue;
450 };
451
452 let outer_end = end_val / mod_val;
454 let outer_range = UOp::range_axis(UOp::index_const(outer_end), AxisId::Renumbered(next_id), *axis_type);
455 next_id += 1;
456
457 let inner_range = UOp::range_axis(UOp::index_const(mod_val), AxisId::Renumbered(next_id), *axis_type);
459 next_id += 1;
460
461 let mod_const = UOp::index_const(mod_val);
463 let outer_scaled = outer_range.mul(&mod_const);
464 let combined = outer_scaled.add(&inner_range);
465
466 subs.insert(uop.id, combined);
467 }
468 }
469
470 if subs.is_empty() {
471 return None;
472 }
473
474 let substitute_pm = crate::patterns! {
476 r @ Range { end: _, axis_id: _, axis_type: _ } if subs.contains_key(&r.id) => {
477 subs.get(&r.id).cloned()
478 },
479 };
480
481 let result = graph_rewrite_bottom_up(&substitute_pm, sink.clone(), &mut ());
482
483 ctx.marked_ranges.clear();
485
486 Some(result)
487}
488
489pub fn transform_sources_with_bufferize(x: &Arc<UOp>, ctx: &mut IndexingContext) -> Option<Vec<Arc<UOp>>> {
495 if matches!(x.op(), Op::Bufferize { .. } | Op::Index { .. } | Op::After { .. }) {
496 return None;
497 }
498
499 let sources = x.op().sources();
500 if sources.is_empty() {
501 return None;
502 }
503
504 let input_ranges = if let Some((ranges, _)) = ctx.get_ranges(x) { ranges.clone() } else { Vec::new() };
507
508 let mut new_sources = Vec::with_capacity(sources.len());
509 let mut any_changed = false;
510
511 for src in sources.iter() {
512 let new_src = transform_single_source(x, src, &input_ranges, ctx);
513 if !Arc::ptr_eq(&new_src, src) {
514 any_changed = true;
515 }
516 new_sources.push(new_src);
517 }
518
519 if any_changed { Some(new_sources) } else { None }
520}
521
522fn flatten_bufferize(bufferize: &Arc<UOp>) -> Option<Arc<UOp>> {
532 let Op::Bufferize { compute, ranges, opts } = bufferize.op() else { return None };
533 if ranges.len() <= 1 {
534 return None;
535 }
536 let shape: Vec<morok_ir::SInt> = ranges
538 .iter()
539 .map(|r| match r.op() {
540 Op::Range { end, .. } => morok_ir::SInt::from(end.clone()),
541 _ => morok_ir::SInt::from(1usize),
542 })
543 .collect();
544
545 let flat_shape = vec![morok_ir::sint_prod(&shape)];
547 let ranges_vec: Vec<Arc<UOp>> = ranges.iter().cloned().collect();
548 let flat_indices = super::indexing::apply_reshape_ranges(&flat_shape, &shape, &ranges_vec);
549 assert_eq!(flat_indices.len(), 1, "flatten_bufferize: expected 1 flat index, got {}", flat_indices.len());
550 let flat_buf = UOp::bufferize(compute.clone(), vec![flat_indices[0].clone()], opts.clone());
552
553 let shape_smallvec: Shape = shape.iter().cloned().collect();
555 let reshaped = flat_buf.try_reshape(&shape_smallvec).expect("flatten_bufferize: try_reshape failed");
556
557 let has_symbolic =
560 ranges.iter().any(|r| matches!(r.op(), Op::Range { end, .. } if !matches!(end.op(), Op::Const(_))));
561
562 if has_symbolic {
563 let sym_ranges: Vec<(morok_ir::SInt, morok_ir::SInt)> = ranges
564 .iter()
565 .map(|r| match r.op() {
566 Op::Range { end, .. } => (morok_ir::SInt::from(0usize), morok_ir::SInt::from(end.clone())),
567 _ => (morok_ir::SInt::from(0usize), morok_ir::SInt::from(1usize)),
568 })
569 .collect();
570 Some(reshaped.try_shrink(&sym_ranges).expect("flatten_bufferize: try_shrink failed for symbolic ranges"))
571 } else {
572 Some(reshaped)
573 }
574}
575
576pub(crate) fn push_movement_through_after(mop: &Arc<UOp>, deps: &SmallVec<[Arc<UOp>; 4]>) -> Option<Arc<UOp>> {
582 let inner_src = &mop.op().sources()[0];
583 let new_after = inner_src.after(deps.clone());
584 let new_op = match mop.op() {
587 Op::Reshape { new_shape, .. } => Op::Reshape { src: new_after, new_shape: new_shape.clone() },
588 Op::Permute { axes, .. } => Op::Permute { src: new_after, axes: axes.clone() },
589 Op::Expand { new_shape, .. } => Op::Expand { src: new_after, new_shape: new_shape.clone() },
590 Op::Pad { begin_pads, end_pads, .. } => {
591 Op::Pad { src: new_after, begin_pads: begin_pads.clone(), end_pads: end_pads.clone() }
592 }
593 Op::Shrink { begins, ends, .. } => Op::Shrink { src: new_after, begins: begins.clone(), ends: ends.clone() },
594 Op::Flip { axes, .. } => Op::Flip { src: new_after, axes: axes.clone() },
595 _ => return None,
596 };
597 Some(UOp::new(new_op, mop.dtype()))
598}
599
600pub(crate) fn transform_single_source(
610 consumer: &Arc<UOp>,
611 src: &Arc<UOp>,
612 input_ranges: &[Arc<UOp>],
613 ctx: &mut IndexingContext,
614) -> Arc<UOp> {
615 if matches!(
621 src.op(),
622 Op::Buffer { .. }
623 | Op::Param { .. }
624 | Op::BufferView { .. }
625 | Op::MStack { .. }
626 | Op::MSelect { .. }
627 | Op::After { .. }
628 ) {
629 if !input_ranges.is_empty() {
630 return UOp::index()
631 .buffer(Arc::clone(src))
632 .indices(input_ranges.to_vec())
633 .call()
634 .expect("Failed to create INDEX for buffer source");
635 }
636 return Arc::clone(src);
637 }
638
639 let realize_axes_opt = ctx.get_realize_axes(src).cloned();
641 if let Some(ref realize_axes) = realize_axes_opt {
642 let (_, output_ranges) = ctx.get_ranges(src).expect("Realized op must have ranges");
643
644 let closed_ranges: Vec<_> = output_ranges
645 .iter()
646 .enumerate()
647 .filter(|(i, _)| realize_axes.contains(i))
648 .map(|(_, r)| Arc::clone(r))
649 .collect();
650
651 let is_copy_consumer = matches!(consumer.op(), Op::Copy { .. });
653 let is_always_contiguous_src = super::indexing::is_always_contiguous(src);
654 let removable = !is_copy_consumer && !is_always_contiguous_src;
655 let addrspace = if output_ranges.len() == realize_axes.len() { AddrSpace::Global } else { AddrSpace::Local };
656 tracing::debug!(
657 src_id = src.id,
658 src_op = src.op().as_ref(),
659 consumer_id = consumer.id,
660 consumer_op = consumer.op().as_ref(),
661 realize_axes = ?realize_axes,
662 output_ranges_len = output_ranges.len(),
663 addrspace = ?addrspace,
664 removable = removable,
665 "BUFFERIZE decision"
666 );
667 let device = src.device_spec();
669 let opts = BufferizeOpts { device, addrspace, removable };
670
671 let buf_tag = if addrspace == AddrSpace::Global { src.tag().clone() } else { None };
673 let bufferized = UOp::bufferize(Arc::clone(src), closed_ranges.clone(), opts);
674 let bufferized = if let Some(t) = buf_tag { bufferized.with_tag(t) } else { bufferized };
675
676 let index_ranges: Vec<_> = input_ranges
677 .iter()
678 .enumerate()
679 .filter(|(i, _)| realize_axes.contains(i))
680 .map(|(_, r)| Arc::clone(r))
681 .collect();
682
683 if !index_ranges.is_empty() {
684 return UOp::index()
687 .buffer(bufferized)
688 .indices(index_ranges)
689 .call()
690 .expect("Failed to create INDEX after BUFFERIZE");
691 } else {
692 return bufferized;
693 }
694 }
695
696 Arc::clone(src)
698}
699
700fn apply_movement_ops_chain(result: &Arc<UOp>, chain: &Arc<UOp>) -> Option<Arc<UOp>> {
712 let mut mops = Vec::new();
713 let mut walk = chain.clone();
714
715 while walk.op().is_movement() {
717 mops.push(walk.clone());
718 walk = match walk.op() {
720 Op::Reshape { src, .. }
721 | Op::Permute { src, .. }
722 | Op::Expand { src, .. }
723 | Op::Pad { src, .. }
724 | Op::Shrink { src, .. }
725 | Op::Flip { src, .. } => src.clone(),
726 _ => break,
727 };
728 }
729
730 let mut current = result.clone();
732 for mop in mops.into_iter().rev() {
733 current = apply_single_movement_op(¤t, mop.op())?;
734 }
735
736 Some(current)
737}
738
739fn apply_single_movement_op(uop: &Arc<UOp>, op: &Op) -> Option<Arc<UOp>> {
744 match op {
745 Op::Reshape { new_shape, .. } => {
746 let shape = extract_shape_from_uop(new_shape)?;
747 uop.try_reshape(&shape).ok()
748 }
749 Op::Permute { axes, .. } => uop.try_permute(axes.clone()).ok(),
750 Op::Expand { new_shape, .. } => {
751 let shape = extract_shape_from_uop(new_shape)?;
752 uop.try_expand(&shape).ok()
753 }
754 Op::Pad { begin_pads, end_pads, .. } => {
755 let begins = extract_shape_from_uop(begin_pads)?;
756 let ends = extract_shape_from_uop(end_pads)?;
757 let padding: Vec<_> = begins.into_iter().zip(ends).collect();
758 uop.try_pad(&padding).ok()
759 }
760 Op::Shrink { begins, ends, .. } => {
761 let begin_shape = extract_shape_from_uop(begins)?;
762 let end_shape = extract_shape_from_uop(ends)?;
763 let ranges: Vec<_> = begin_shape.into_iter().zip(end_shape).collect();
764 uop.try_shrink(&ranges).ok()
765 }
766 Op::Flip { axes, .. } => uop.try_flip(axes.clone()).ok(),
767 _ => None,
768 }
769}
770
771fn extract_shape_from_uop(shape_uop: &Arc<UOp>) -> Option<Shape> {
774 use morok_ir::SInt;
775 match shape_uop.op() {
776 Op::Vectorize { elements } => Some(elements.iter().cloned().map(SInt::from).collect()),
778 Op::Const(const_hash) => match const_hash.0 {
780 ConstValue::Int(v) if v >= 0 => Some(smallvec![SInt::from(v as usize)]),
781 ConstValue::UInt(v) => Some(smallvec![SInt::from(v as usize)]),
782 _ => None,
783 },
784 Op::VConst { values } => {
786 let mut dims = smallvec![];
787 for val in values {
788 match val {
789 ConstValue::Int(v) if *v >= 0 => dims.push(SInt::from(*v as usize)),
790 ConstValue::UInt(v) => dims.push(SInt::from(*v as usize)),
791 _ => return None,
792 }
793 }
794 Some(dims)
795 }
796 _ => None,
797 }
798}
799
800fn create_loop_range_from_outer(outer: &Arc<UOp>, size: usize) -> Option<Arc<UOp>> {
802 use morok_ir::AxisType;
803 let Op::Range { axis_id, .. } = outer.op() else {
804 return None;
805 };
806 Some(UOp::range_axis(UOp::index_const(size as i64), *axis_id, AxisType::Loop))
807}
808
809fn reduce_op_to_binary(op: morok_ir::ReduceOp, lhs: &Arc<UOp>, rhs: &Arc<UOp>) -> Option<Arc<UOp>> {
811 use morok_ir::types::{BinaryOp, ReduceOp};
812 let dtype = lhs.dtype();
813 Some(match op {
814 ReduceOp::Add => UOp::new(Op::Binary(BinaryOp::Add, lhs.clone(), rhs.clone()), dtype),
815 ReduceOp::Mul => UOp::new(Op::Binary(BinaryOp::Mul, lhs.clone(), rhs.clone()), dtype),
816 ReduceOp::Max => UOp::new(Op::Binary(BinaryOp::Max, lhs.clone(), rhs.clone()), dtype),
817 ReduceOp::Min => {
818 let cond = UOp::new(Op::Binary(BinaryOp::Lt, lhs.clone(), rhs.clone()), morok_dtype::DType::Bool);
820 UOp::try_where(cond, lhs.clone(), rhs.clone()).expect("reduce_op_to_binary: try_where failed for Min")
821 }
822 })
823}
824
825fn calculate_size_from_ranges(ranges: &SmallVec<[Arc<UOp>; 4]>) -> usize {
831 if ranges.is_empty() {
832 return 1;
833 }
834
835 ranges
836 .iter()
837 .map(|r| {
838 let vmax = r.vmax();
840 match vmax {
841 ConstValue::Int(v) if *v >= 0 => (*v + 1) as usize,
842 ConstValue::UInt(v) => (*v + 1) as usize,
843 other => panic!(
844 "Cannot allocate buffer: range vmax resolved to {:?}. \
845 Buffers require concrete sizes (Tinygrad: 'no symbolic sized buffers')",
846 other
847 ),
848 }
849 })
850 .product()
851}
852
853fn sort_ranges_by_axis_id(ranges: &SmallVec<[Arc<UOp>; 4]>) -> SmallVec<[Arc<UOp>; 4]> {
862 let mut sorted: Vec<_> = ranges.iter().cloned().collect();
863 sorted.sort_by_key(|r| {
864 if let Op::Range { axis_id, axis_type, .. } = r.op() {
865 (axis_id.value(), axis_type_ordinal(*axis_type))
867 } else {
868 (usize::MAX, u8::MAX)
869 }
870 });
871 sorted.into()
872}
873
874fn axis_type_ordinal(at: AxisType) -> u8 {
877 match at {
878 AxisType::Outer => 0,
879 AxisType::Global => 1,
880 AxisType::Warp => 2,
881 AxisType::Local => 3,
882 AxisType::Loop => 4,
883 AxisType::GroupReduce => 5,
884 AxisType::Reduce => 6,
885 AxisType::Upcast => 7,
886 AxisType::Unroll => 8,
887 AxisType::Thread => 9,
888 AxisType::Placeholder => 10,
889 }
890}
891
892fn collect_range_uops(ranges: &SmallVec<[Arc<UOp>; 4]>) -> SmallVec<[Arc<UOp>; 4]> {
901 let mut collected = SmallVec::new();
902 for r in ranges.iter() {
903 if matches!(r.op(), Op::Range { .. }) {
904 collected.push(r.clone());
905 } else if !matches!(r.op(), Op::Const(_)) {
906 for rng in r.ranges().iter() {
907 if !collected.iter().any(|c: &Arc<UOp>| c.id == rng.id) {
908 collected.push(rng.clone());
909 }
910 }
911 }
912 }
913 collected
914}
915
916pub fn bufferize_to_store(bufferize_op: &Arc<UOp>, ctx: &mut KernelContext, allow_locals: bool) -> Option<Arc<UOp>> {
925 let (compute, ranges, opts) = match bufferize_op.op() {
926 Op::Bufferize { compute, ranges, opts } => {
927 tracing::debug!(
928 bufferize_id = bufferize_op.id,
929 compute_id = compute.id,
930 ranges_len = ranges.len(),
931 allow_locals = allow_locals,
932 "bufferize_to_store: CONVERTING BUFFERIZE to STORE→AFTER"
933 );
934 (compute, ranges, opts)
935 }
936 _ => return None,
937 };
938
939 let size = calculate_size_from_ranges(ranges);
941 let base_dtype = match bufferize_op.dtype() {
942 DType::Ptr { base, .. } => (*base).clone(),
943 other => other,
944 };
945
946 let sdtype = base_dtype.clone().ptr(Some(size), opts.addrspace);
951
952 let end_ranges: SmallVec<[Arc<UOp>; 4]> = sort_ranges_by_axis_id(&collect_range_uops(ranges));
955
956 if let Op::Assign { target, value, movement_ops } = compute.op() {
961 let Op::Index { buffer, indices, gate } = target.op() else {
963 return None;
964 };
965
966 let store_target = UOp::index()
968 .buffer(buffer.clone())
969 .indices(indices.to_vec())
970 .maybe_gate(gate.clone())
971 .dtype(sdtype.clone())
972 .call()
973 .expect("bufferize_to_store: failed to create INDEX for ASSIGN target");
974
975 let store = store_target.store_value(value.clone());
977 let do_store = if end_ranges.is_empty() { store } else { store.end(end_ranges.clone()) };
978
979 let mut result = buffer.after(smallvec![do_store]);
981 if let Some(mops_chain) = movement_ops {
982 result = apply_movement_ops_chain(&result, mops_chain)?;
983 }
984
985 ctx.map_buffer(bufferize_op.clone(), result.clone());
986 return Some(result);
987 }
988
989 if let Op::Reduce { src: reduce_src, ranges: reduce_ranges, reduce_op } = compute.op() {
994 if reduce_ranges.len() == 1
997 && let Op::Range { axis_type, .. } = reduce_ranges[0].op()
998 && *axis_type == AxisType::Outer
999 {
1000 if opts.addrspace != AddrSpace::Global {
1002 return None;
1003 }
1004
1005 let outer_range = reduce_ranges[0].clone();
1006 let device = opts.device.clone().unwrap_or(morok_ir::DeviceSpec::Cpu);
1007
1008 let buf = UOp::new_buffer(device, size, base_dtype.clone());
1010
1011 let zero_range = create_loop_range_from_outer(&outer_range, size)?;
1013
1014 use crate::symbolic::dce::reduce_identity;
1016 let identity = reduce_identity(*reduce_op, base_dtype.clone());
1017
1018 let zero_idx = UOp::index()
1020 .buffer(buf.clone())
1021 .indices(vec![zero_range.clone()])
1022 .dtype(sdtype.clone())
1023 .call()
1024 .expect("bufferize_to_store: failed to create INDEX for OUTER REDUCE zero-init");
1025 let zero_store = zero_idx.store_value(identity).end(smallvec![zero_range.clone()]);
1026 let buf_zeroed = buf.after(smallvec![zero_store]);
1027
1028 debug_assert!(
1031 ranges.len() <= 1 || ranges.iter().all(|r| matches!(r.op(), Op::Const(_))),
1032 "bufferize_to_store: unexpected multi-range in OUTER REDUCE after flatten_bufferize"
1033 );
1034 let idx = if ranges.len() == 1 && !matches!(ranges[0].op(), Op::Const(_)) {
1035 ranges[0].clone()
1036 } else if !end_ranges.is_empty() {
1037 sort_ranges_by_axis_id(&end_ranges)[0].clone()
1038 } else {
1039 UOp::index_const(0)
1040 };
1041
1042 let sorted_end_ranges = sort_ranges_by_axis_id(&collect_range_uops(ranges));
1044
1045 let buf_idx = UOp::index()
1047 .buffer(buf_zeroed.clone())
1048 .indices(vec![idx])
1049 .dtype(sdtype.clone())
1050 .call()
1051 .expect("bufferize_to_store: failed to create INDEX for OUTER REDUCE accumulation");
1052 let loaded = UOp::load().buffer(buf_zeroed.clone()).index(buf_idx.clone()).call();
1053 let accumulated = reduce_op_to_binary(*reduce_op, &loaded, reduce_src)?;
1054
1055 let do_store = buf_idx.store_value(accumulated).end(sorted_end_ranges).end(smallvec![outer_range]);
1057
1058 let result = buf_zeroed.after(smallvec![do_store]);
1059 ctx.map_buffer(bufferize_op.clone(), result.clone());
1060 return Some(result);
1061 }
1062 }
1063
1064 if !allow_locals && opts.addrspace == AddrSpace::Local {
1073 return None;
1074 }
1075 let effective_addrspace = opts.addrspace;
1076
1077 let buffer = if let Some(existing_buffer) = ctx.get_buffer(bufferize_op) {
1078 existing_buffer.clone()
1079 } else if effective_addrspace == AddrSpace::Global {
1080 let device = opts.device.clone().unwrap_or(morok_ir::DeviceSpec::Cpu);
1083 UOp::new_buffer(device, size, base_dtype.clone())
1084 } else {
1085 let local_ptr_dtype = base_dtype.clone().ptr(Some(size), opts.addrspace);
1087 let local_id = ctx.next_local();
1088 UOp::define_local(local_id, local_ptr_dtype)
1089 };
1090
1091 let active_ranges: SmallVec<[Arc<UOp>; 4]> = collect_range_uops(ranges);
1099
1100 let sorted_ranges = sort_ranges_by_axis_id(&active_ranges);
1102
1103 let vcount = compute.dtype().vcount();
1106 let store_buffer = if vcount > 1 { buffer.broadcast(vcount) } else { buffer.clone() };
1107
1108 let store_target = if !sorted_ranges.is_empty() {
1109 assert!(
1113 ranges.len() <= 1 || ranges.iter().all(|r| matches!(r.op(), Op::Const(_))),
1114 "bufferize_to_store: unexpected multi-range in general path after flatten_bufferize"
1115 );
1116 let idx = if ranges.len() == 1 && !matches!(ranges[0].op(), Op::Const(_)) {
1117 ranges[0].clone()
1119 } else {
1120 sorted_ranges[0].clone()
1122 };
1123 UOp::index()
1124 .buffer(store_buffer)
1125 .indices(vec![idx])
1126 .dtype(sdtype.clone())
1127 .call()
1128 .expect("Failed to create INDEX for BUFFERIZE-to-STORE conversion")
1129 } else {
1130 UOp::index()
1132 .buffer(store_buffer)
1133 .indices(vec![UOp::index_const(0)])
1134 .dtype(sdtype.clone())
1135 .call()
1136 .expect("Failed to create INDEX for scalar STORE")
1137 };
1138
1139 let store = store_target.store_value(compute.clone());
1149
1150 let end_ranges: SmallVec<[Arc<UOp>; 4]> = sorted_ranges.clone();
1156
1157 let mut do_store = if !end_ranges.is_empty() { store.end(end_ranges) } else { store };
1158
1159 if opts.addrspace == AddrSpace::Local {
1160 do_store = do_store.barrier(SmallVec::new());
1161 }
1162
1163 let result = buffer.after(SmallVec::from_elem(do_store, 1));
1164 ctx.map_buffer(bufferize_op.clone(), result.clone());
1165
1166 Some(result)
1167}
1168
1169#[allow(clippy::mutable_key_type)]
1175pub(crate) fn partition_reduce_ranges(
1176 ranges: &SmallVec<[Arc<UOp>; 4]>,
1177 src_ranges: &HashSet<UOpKey>,
1178) -> (SmallVec<[Arc<UOp>; 4]>, Vec<Arc<UOp>>) {
1179 let mut parented = SmallVec::new();
1180 let mut unparented = Vec::new();
1181
1182 for range in ranges {
1183 let key = UOpKey(Arc::clone(range));
1184 if src_ranges.contains(&key) {
1185 parented.push(Arc::clone(range));
1186 } else {
1187 unparented.push(Arc::clone(range));
1188 }
1189 }
1190
1191 (parented, unparented)
1192}
1193
1194pub(crate) fn get_range_size(range: &Arc<UOp>) -> Option<Arc<UOp>> {
1195 if let Op::Range { end, .. } = range.op() { Some(Arc::clone(end)) } else { None }
1196}
1197
/// Tries to eliminate each range in `ranges` from `src` by symbolic rewriting with `pm`.
///
/// For every range, in order:
/// 1. Collect the nodes of the current expression whose `in_scope_ranges()` contains the
///    range. Bail out (`None`) if any of them is a `Reduce` or `Store`.
/// 2. Replace each child of those nodes that lies outside that set with a fresh `DefineVar`
///    named `in{N}`, bounded by the child's `vmin`/`vmax` (cast to `i64`). Constants,
///    variables, device-less params and local definitions are left in place.
/// 3. Wrap the substituted expression in an `Add` reduce over the range and rewrite it with
///    `pm`.
/// 4. If any `Range` survives the rewrite, give up (`None`). Otherwise substitute the
///    original children back for the placeholder variables and continue with the next range.
///
/// Returns `None` for an empty `ranges` list.
///
/// NOTE(review): the synthetic reduce is always `ReduceOp::Add` regardless of the caller's
/// reduce op; confirm callers only use this for additive reductions.
// clippy flags `UOpKey` keys as a mutable key type; the maps are local to this call.
#[allow(clippy::mutable_key_type)]
fn reduce_collapse_with(src: &Arc<UOp>, ranges: &[Arc<UOp>], pm: &crate::TypedPatternMatcher<()>) -> Option<Arc<UOp>> {
    use morok_ir::ReduceOp;

    if ranges.is_empty() {
        return None;
    }

    let mut u = Arc::clone(src);

    for range in ranges {
        let range_key = UOpKey(range.clone());
        // Nodes whose in-scope ranges include `range`.
        let in_scope: HashSet<UOpKey> =
            u.toposort_filtered(|node| node.in_scope_ranges().contains(&range_key)).into_iter().map(UOpKey).collect();

        if in_scope.iter().any(|k| matches!(k.0.op(), Op::Reduce { .. } | Op::Store { .. })) {
            return None;
        }

        // Out-of-scope children -> bounded placeholder variables.
        let mut replaces: HashMap<UOpKey, Arc<UOp>> = HashMap::new();
        for node in &in_scope {
            node.0.op().map_child(|child| {
                let key = UOpKey(child.clone());
                if in_scope.contains(&key) || replaces.contains_key(&key) {
                    return;
                }
                if matches!(
                    child.op(),
                    Op::Const(_)
                        | Op::VConst { .. }
                        | Op::DefineVar { .. }
                        | Op::Param { device: None, .. }
                        | Op::DefineLocal { .. }
                ) {
                    return;
                }
                let vmin = match child.vmin() {
                    ConstValue::Int(i) => *i,
                    ConstValue::UInt(u) => *u as i64,
                    ConstValue::Float(f) => *f as i64,
                    ConstValue::Bool(b) => *b as i64,
                };
                let vmax = match child.vmax() {
                    ConstValue::Int(i) => *i,
                    ConstValue::UInt(u) => *u as i64,
                    ConstValue::Float(f) => *f as i64,
                    ConstValue::Bool(b) => *b as i64,
                };
                let var = UOp::define_var(format!("in{}", replaces.len()), vmin, vmax).with_dtype(child.dtype());
                replaces.insert(key, var);
            });
        }

        let substituted = u.substitute(&replaces);
        let synthetic_reduce = substituted.reduce(smallvec![range.clone()], ReduceOp::Add);

        let result = crate::rewrite::graph_rewrite(pm, synthetic_reduce, &mut ());

        // Any remaining Range means the reduction could not be collapsed.
        let has_range = result.toposort().iter().any(|x| matches!(x.op(), Op::Range { .. }));
        if has_range {
            return None;
        }

        // Swap the placeholder variables back for the original children.
        let reverse: HashMap<UOpKey, Arc<UOp>> = replaces.into_iter().map(|(k, v)| (UOpKey(v), k.0)).collect();
        u = result.substitute(&reverse);
    }

    Some(u)
}
1289
1290pub fn reduce_collapse(src: &Arc<UOp>, ranges: &[Arc<UOp>]) -> Option<Arc<UOp>> {
1294 reduce_collapse_with(src, ranges, super::patterns::build_reduce_collapse_matcher())
1295}
1296
1297pub fn reduce_load_collapse(src: &Arc<UOp>, ranges: &[Arc<UOp>]) -> Option<Arc<UOp>> {
1303 reduce_collapse_with(src, ranges, super::patterns::build_reduce_load_collapse_matcher())
1304}
1305
1306pub(crate) fn cast_to_dtype(value: &Arc<UOp>, target_dtype: &morok_dtype::DType) -> Option<Arc<UOp>> {
1307 use morok_dtype::DType;
1308
1309 let scalar_type = match target_dtype {
1310 DType::Scalar(s) => DType::Scalar(*s),
1311 DType::Vector { scalar, .. } => DType::Scalar(*scalar),
1312 _ => return None,
1313 };
1314
1315 let casted = value.cast(scalar_type);
1316
1317 if target_dtype.is_vector() {
1318 let count = target_dtype.count();
1319 let elements: SmallVec<[Arc<UOp>; 4]> = (0..count).map(|_| casted.clone()).collect();
1320 Some(UOp::vectorize(elements))
1321 } else {
1322 Some(casted)
1323 }
1324}
1325
1326pub fn simplify_merge_adjacent(u: &Arc<UOp>) -> Option<Arc<UOp>> {
1347 use crate::passes::linearize_index::count_divmod;
1348
1349 let ended_ranges = match u.op() {
1351 Op::End { computation: _, ranges } => ranges.clone(),
1352 Op::Reduce { ranges, .. } => ranges.clone(),
1353 _ => return None,
1354 };
1355
1356 if ended_ranges.len() < 2 {
1357 return None;
1358 }
1359
1360 let reduce_ranges: Vec<SmallVec<[Arc<UOp>; 4]>> = u
1362 .toposort()
1363 .iter()
1364 .filter_map(|dep| match dep.op() {
1365 Op::Reduce { ranges, .. } => Some(ranges.clone()),
1366 _ => None,
1367 })
1368 .collect();
1369
1370 let mut current = Arc::clone(u);
1373 let mut changed = false;
1374
1375 let pairs: Vec<(usize, usize)> = if matches!(u.op(), Op::End { .. }) {
1377 (0..ended_ranges.len() - 1).map(|i| (i, i + 1)).collect()
1378 } else {
1379 let mut perms = Vec::new();
1380 for i in 0..ended_ranges.len() {
1381 for j in 0..ended_ranges.len() {
1382 if i != j {
1383 perms.push((i, j));
1384 }
1385 }
1386 }
1387 perms
1388 };
1389
1390 for (i0, i1) in pairs {
1391 let r0 = &ended_ranges[i0];
1392 let r1 = &ended_ranges[i1];
1393
1394 let (r0_axis_type, r0_end) = match r0.op() {
1395 Op::Range { end, axis_type, .. } => (axis_type, end),
1396 _ => continue,
1397 };
1398 let (r1_axis_type, r1_end) = match r1.op() {
1399 Op::Range { end, axis_type, .. } => (axis_type, end),
1400 _ => continue,
1401 };
1402
1403 if r0_axis_type != r1_axis_type {
1404 continue;
1405 }
1406
1407 let valid_reduce_scope = reduce_ranges.iter().all(|rngs| {
1409 let r0_in = rngs.iter().any(|rng| Arc::ptr_eq(rng, r0));
1410 let r1_in = rngs.iter().any(|rng| Arc::ptr_eq(rng, r1));
1411 r0_in == r1_in
1412 });
1413 if !valid_reduce_scope {
1414 continue;
1415 }
1416
1417 if let Some(v) = const_uop_to_i64(r0_end)
1418 && v <= 0
1419 {
1420 continue;
1421 }
1422 if let Some(v) = const_uop_to_i64(r1_end)
1423 && v <= 0
1424 {
1425 continue;
1426 }
1427 if let (Some(s0), Some(s1)) = (const_uop_to_i64(r0_end), const_uop_to_i64(r1_end))
1428 && s0.checked_mul(s1).is_none()
1429 {
1430 continue;
1431 }
1432
1433 let merged_size_uop = r0_end.mul(r1_end);
1434 let merged_range = r0.with_sources(vec![merged_size_uop]);
1435
1436 let new_r0 = merged_range.idiv(r1_end);
1437 let new_r1 = merged_range.mod_(r1_end);
1438
1439 #[allow(clippy::mutable_key_type)]
1440 let mut subs: HashMap<UOpKey, Arc<UOp>> = HashMap::new();
1441 subs.insert(UOpKey(r0.clone()), new_r0);
1442 subs.insert(UOpKey(r1.clone()), new_r1);
1443
1444 let rewritten = current.substitute(&subs);
1446 static MERGE_SYM: std::sync::LazyLock<crate::TypedPatternMatcher> =
1447 std::sync::LazyLock::new(|| crate::symbolic::symbolic().clone() + pm_flatten_range().clone());
1448 let simplified = crate::rewrite::graph_rewrite(&*MERGE_SYM, rewritten, &mut ());
1449
1450 let original_divmod = count_divmod(¤t);
1452 let new_divmod = count_divmod(&simplified);
1453
1454 if new_divmod <= original_divmod {
1455 current = simplified;
1456 changed = true;
1457 }
1458 }
1459
1460 if changed { Some(current) } else { None }
1461}
1462
/// Cached pattern matcher that applies `simplify_merge_adjacent` to END and REDUCE nodes.
///
/// Nodes with an empty range list are excluded by the guards; there is nothing to merge there.
pub fn pm_simplify_ranges() -> &'static crate::TypedPatternMatcher {
    crate::cached_patterns! {
        // END: `simplify_merge_adjacent` only tries adjacent range pairs.
        u @ End { computation: _, ranges } if !ranges.is_empty() => |u| simplify_merge_adjacent(u),
        // REDUCE: `simplify_merge_adjacent` tries every ordered pair of reduced ranges.
        u @ Reduce { src: _, ranges, reduce_op: _ } if !ranges.is_empty() => |u| simplify_merge_adjacent(u),
    }
}
1474
1475pub fn flatten_range_impl(r: &Arc<UOp>) -> Option<Arc<UOp>> {
1481 let off = match r.op() {
1482 Op::Reduce { .. } => 1,
1483 Op::Store { .. } => 2, Op::End { .. } => 1,
1485 _ => return None,
1486 };
1487
1488 let original_sources = r.op().sources();
1489 let original_ranges: Vec<&Arc<UOp>> = original_sources.iter().skip(off).collect();
1490 let mut all_range_sources: Vec<Arc<UOp>> = original_ranges.iter().map(|r| (*r).clone()).collect();
1491
1492 let innermost_computation = if matches!(r.op(), Op::End { .. }) {
1493 let mut computation = Arc::clone(&original_sources[0]);
1494
1495 while matches!(computation.op(), Op::End { .. }) {
1496 all_range_sources.extend(computation.op().sources().iter().skip(1).cloned());
1497 computation = Arc::clone(&computation.op().sources()[0]);
1498 }
1499
1500 Some(computation)
1501 } else {
1502 None
1503 };
1504
1505 if all_range_sources.is_empty() {
1506 return None;
1507 }
1508
1509 let sink = UOp::sink(all_range_sources);
1510 let new_ranges: Vec<Arc<UOp>> =
1511 sink.toposort().into_iter().filter(|uop| matches!(uop.op(), Op::Range { .. })).collect();
1512
1513 if new_ranges.is_empty() {
1514 return None;
1515 }
1516
1517 if new_ranges.len() == original_ranges.len()
1519 && innermost_computation.as_ref().is_none_or(|c| Arc::ptr_eq(c, &original_sources[0]))
1520 && new_ranges.iter().zip(original_ranges.iter()).all(|(a, b)| Arc::ptr_eq(a, *b))
1521 {
1522 return None; }
1524
1525 let mut new_sources: Vec<Arc<UOp>> =
1526 if let Some(inner_comp) = innermost_computation { vec![inner_comp] } else { original_sources[..off].to_vec() };
1527 new_sources.extend(new_ranges);
1528
1529 Some(r.with_sources(new_sources))
1530}
1531
1532#[allow(clippy::mutable_key_type)]
1534pub fn flatten_ranges(root: &Arc<UOp>) -> Arc<UOp> {
1535 let mut replacements: HashMap<UOpKey, Arc<UOp>> = HashMap::new();
1536
1537 for node in root.toposort() {
1538 if let Some(flattened) = flatten_range_impl(&node) {
1539 replacements.insert(UOpKey(node.clone()), flattened);
1540 }
1541 }
1542
1543 root.substitute(&replacements)
1544}
1545
/// How a buffer is accessed inside a kernel, as recorded by `find_bufs`.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum OpAccessType {
    /// The buffer is read through a LOAD.
    Load,
    /// The buffer is written through a store.
    Store,
}
1556
1557pub fn as_buf(uop: &Arc<UOp>) -> Arc<UOp> {
1559 match uop.op() {
1560 Op::MSelect { buffer, .. } => buffer.clone(),
1561 Op::MStack { buffers } if !buffers.is_empty() => buffers[0].clone(),
1562 Op::After { passthrough, .. } => passthrough.clone(),
1563 _ => uop.clone(),
1564 }
1565}
1566
1567#[allow(clippy::mutable_key_type)]
1569pub fn find_bufs(store: &Arc<UOp>) -> HashMap<UOpKey, OpAccessType> {
1570 let mut ret: HashMap<UOpKey, OpAccessType> = HashMap::new();
1571
1572 let nodes = store.toposort_filtered(|uop| !matches!(uop.op(), Op::After { .. }));
1573
1574 for node in nodes {
1575 if let Op::Load { buffer, .. } = node.op() {
1576 let buf = as_buf(buffer);
1577 let buf_key = UOpKey(buf.clone());
1578
1579 if let Some(&existing_access) = ret.get(&buf_key)
1580 && existing_access != OpAccessType::Load
1581 {
1582 panic!(
1583 "buffer accessed with conflicting ops: {:?} (existing: {:?}, new: {:?})",
1584 buf,
1585 existing_access,
1586 OpAccessType::Load
1587 );
1588 }
1589
1590 ret.insert(buf_key, OpAccessType::Load);
1591 }
1592
1593 if let Some(buffer) = node.store_buffer() {
1594 let buf = as_buf(buffer);
1595 let buf_key = UOpKey(buf.clone());
1596
1597 if let Some(&existing_access) = ret.get(&buf_key)
1598 && existing_access != OpAccessType::Store
1599 {
1600 panic!(
1601 "buffer accessed with conflicting ops: {:?} (existing: {:?}, new: {:?})",
1602 buf,
1603 existing_access,
1604 OpAccessType::Store
1605 );
1606 }
1607
1608 ret.insert(buf_key, OpAccessType::Store);
1609 }
1610 }
1611
1612 ret
1613}
1614
1615fn late_buffer_view(compute: &Arc<UOp>, bufferize: &Arc<UOp>) -> Option<Arc<UOp>> {
1623 use morok_ir::uop::cached_property::CachedProperty;
1624 use morok_ir::uop::properties::VminVmaxProperty;
1625
1626 let Op::Bufferize { opts, ranges, .. } = bufferize.op() else { return None };
1627
1628 if !matches!(&opts.device, Some(d) if d.is_disk()) {
1630 return None;
1631 }
1632
1633 let size: usize = ranges
1635 .iter()
1636 .map(|r| {
1637 if let Op::Range { end, .. } = r.op()
1638 && let (_, morok_ir::ConstValue::Int(v)) = VminVmaxProperty::get(end)
1639 {
1640 return *v as usize;
1641 }
1642 if let Op::Const(_) = r.op() {
1643 return 1; }
1645 1
1646 })
1647 .product();
1648
1649 let mut x = compute.clone();
1654 loop {
1655 if x.op().sources().iter().any(|s| matches!(s.op(), Op::Index { .. })) {
1657 break;
1658 }
1659 if matches!(x.op(), Op::BitCast { .. } | Op::Contiguous { .. }) {
1661 x = x.op().sources().first()?.clone();
1662 continue;
1663 }
1664 if matches!(x.op(), Op::Unary(..) | Op::Binary(..) | Op::Ternary(..) | Op::Cast { .. }) {
1666 return None;
1667 }
1668 x = x.op().sources().first()?.clone();
1669 }
1670 let index = x.op().sources().iter().find(|s| matches!(s.op(), Op::Index { .. }))?.clone();
1671
1672 let offset: usize = if let Op::Index { indices, .. } = index.op() {
1674 if indices.is_empty() {
1675 0
1677 } else {
1678 let mut total: i64 = 0;
1680 for idx in indices.iter() {
1681 let (vmin, _) = VminVmaxProperty::get(idx);
1682 if let morok_ir::ConstValue::Int(v) = vmin {
1683 total += v;
1684 }
1685 }
1686 total.max(0) as usize
1687 }
1688 } else {
1689 0
1690 };
1691
1692 let base = index.base();
1694
1695 let buffer_view = UOp::new(Op::BufferView { buffer: base, size, offset }, compute.dtype());
1697
1698 let new_sources: Vec<Arc<UOp>> = std::iter::once(buffer_view).chain(ranges.iter().cloned()).collect();
1700 Some(UOp::bufferize(new_sources[0].clone(), new_sources[1..].to_vec(), opts.clone()))
1701}
1702
/// Pattern matcher that lowers BUFFERIZE nodes to buffers (non-local variant).
///
/// Rules, in the order written:
/// 1. Flatten a BUFFERIZE with more than one range via `flatten_bufferize`.
/// 2. Push movement ops through INDEX, AFTER and END.
/// 3. Turn a BUFFERIZE of a BitCast/Contiguous computation into a buffer view
///    (`late_buffer_view`, which only fires for disk devices).
/// 4. Otherwise lower the BUFFERIZE with `bufferize_to_store(.., false)`.
///
/// NOTE(review): the meaning of the trailing `false` flag is defined by `bufferize_to_store`.
/// `pm_add_buffers_local_patterns` is identical except that it passes `true`.
pub fn pm_add_buffers_patterns() -> crate::TypedPatternMatcher<super::kernel::KernelContext> {
    crate::patterns! {
        @context super::kernel::KernelContext;
        // Multi-range BUFFERIZE: collapse its ranges first.
        buf @ Bufferize { compute: _ } if matches!(buf.op(), Op::Bufferize { ranges, .. } if ranges.len() > 1)
            => |buf, _ctx| { flatten_bufferize(buf) },
        // INDEX into a movement op: fold the movement into the index expressions.
        Index { buffer: mop, indices, gate } if mop.op().is_movement()
            => |mop, indices, gate, _ctx| {
                super::patterns::transform_movement_through_index(mop, indices, gate)
            },
        // AFTER whose passthrough is a movement op.
        After { passthrough: mop, deps } if mop.op().is_movement()
            => |mop, deps, _ctx| {
                push_movement_through_after(mop, deps)
            },
        // END of a movement op: end its first source with the same ranges instead.
        End { computation: mop, ranges } if mop.op().is_movement()
            => |mop, ranges, _ctx| {
                let src = &mop.op().sources()[0];
                Some(src.end(ranges.clone()))
            },
        // BitCast/Contiguous computations may become a buffer view (disk devices only).
        buf @ Bufferize { compute }
            if matches!(compute.op(), Op::BitCast { .. } | Op::Contiguous { .. })
            => |buf, compute, _ctx| late_buffer_view(compute, buf),
        // Catch-all BUFFERIZE lowering. NOTE(review): listed last, presumably so the more
        // specific BUFFERIZE rules above take precedence; confirm matcher ordering semantics.
        buf @ Bufferize { compute: _ } => |buf, ctx| {
            bufferize_to_store(buf, ctx, false)
        },
    }
}
1742
/// Pattern matcher that lowers BUFFERIZE nodes to buffers (local variant).
///
/// Rules, in the order written:
/// 1. Flatten a BUFFERIZE with more than one range via `flatten_bufferize`.
/// 2. Push movement ops through INDEX, AFTER and END.
/// 3. Turn a BUFFERIZE of a BitCast/Contiguous computation into a buffer view
///    (`late_buffer_view`, which only fires for disk devices).
/// 4. Otherwise lower the BUFFERIZE with `bufferize_to_store(.., true)`.
///
/// NOTE(review): the meaning of the trailing `true` flag is defined by `bufferize_to_store`.
/// `pm_add_buffers_patterns` is identical except that it passes `false`.
pub fn pm_add_buffers_local_patterns() -> crate::TypedPatternMatcher<super::kernel::KernelContext> {
    crate::patterns! {
        @context super::kernel::KernelContext;
        // Multi-range BUFFERIZE: collapse its ranges first.
        buf @ Bufferize { compute: _ } if matches!(buf.op(), Op::Bufferize { ranges, .. } if ranges.len() > 1)
            => |buf, _ctx| { flatten_bufferize(buf) },
        // INDEX into a movement op: fold the movement into the index expressions.
        Index { buffer: mop, indices, gate } if mop.op().is_movement()
            => |mop, indices, gate, _ctx| {
                super::patterns::transform_movement_through_index(mop, indices, gate)
            },
        // AFTER whose passthrough is a movement op.
        After { passthrough: mop, deps } if mop.op().is_movement()
            => |mop, deps, _ctx| {
                push_movement_through_after(mop, deps)
            },
        // END of a movement op: end its first source with the same ranges instead.
        End { computation: mop, ranges } if mop.op().is_movement()
            => |mop, ranges, _ctx| {
                let src = &mop.op().sources()[0];
                Some(src.end(ranges.clone()))
            },
        // BitCast/Contiguous computations may become a buffer view (disk devices only).
        buf @ Bufferize { compute }
            if matches!(compute.op(), Op::BitCast { .. } | Op::Contiguous { .. })
            => |buf, compute, _ctx| late_buffer_view(compute, buf),
        // Catch-all BUFFERIZE lowering. NOTE(review): listed last, presumably so the more
        // specific BUFFERIZE rules above take precedence; confirm matcher ordering semantics.
        buf @ Bufferize { compute: _ } => |buf, ctx| {
            bufferize_to_store(buf, ctx, true)
        },
    }
}