morok_ir/uop/constructors/graph.rs
1//! Graph organization: sink, group, assign, contiguous.
2//!
3//! This module contains graph organization and optimization operations:
4//! - Graph structure: sink, group
5//! - Assignment: assign
6//! - Dependencies: after
7//! - Materialization: detach, contiguous, contiguous_backward
8//! - Optimization hints: precast
9//! - Custom code: custom, customi
10
11use std::sync::Arc;
12
13use morok_dtype::DType;
14use smallvec::SmallVec;
15
16use crate::op::Op;
17use crate::uop::UOp;
18
19impl UOp {
20 // =========================================================================
21 // Graph Structure
22 // =========================================================================
23
24 /// Create a sink operation (graph termination).
25 ///
26 /// Sink marks outputs that must be evaluated. All sources are dependencies.
27 pub fn sink(sources: Vec<Arc<Self>>) -> Arc<Self> {
28 Self::new(Op::Sink { sources: SmallVec::from_vec(sources) }, DType::Void)
29 }
30
31 /// Create a group operation (merging/organizing related ops).
32 ///
33 /// Group is a NOOP that helps organize related operations together.
34 /// It passes through the first source while ensuring all sources are dependencies.
35 pub fn group(sources: Vec<Arc<Self>>) -> Arc<Self> {
36 let dtype = if sources.is_empty() { DType::Void } else { sources[0].dtype.clone() };
37 Self::new(Op::Group { sources: SmallVec::from_vec(sources) }, dtype)
38 }
39
40 // =========================================================================
41 // Assignment
42 // =========================================================================
43
44 /// In-place assignment.
45 ///
46 /// # Arguments
47 /// * `target` - The INDEX operation for the assignment destination
48 /// * `value` - The value to assign
49 pub fn assign(target: Arc<Self>, value: Arc<Self>) -> Arc<Self> {
50 Self::assign_with_mops(target, value, None)
51 }
52
53 /// In-place assignment with movement ops chain.
54 ///
55 /// The `movement_ops` parameter captures shape transformations from the
56 /// original target, used during bufferize_to_store to apply the same
57 /// transformations to the result buffer.
58 pub fn assign_with_mops(target: Arc<Self>, value: Arc<Self>, movement_ops: Option<Arc<Self>>) -> Arc<Self> {
59 let dtype = target.dtype();
60 Self::new(Op::Assign { target, value, movement_ops }, dtype)
61 }
62
63 // =========================================================================
64 // Dependencies
65 // =========================================================================
66
67 /// Ordering constraint: self depends on deps.
68 ///
69 /// # Arguments
70 /// * `deps` - Dependencies that must complete before this value is used
71 ///
72 /// # Panics (debug only)
73 /// Panics if self is a control flow node (Range, End)
74 pub fn after(self: &Arc<Self>, deps: SmallVec<[Arc<Self>; 4]>) -> Arc<Self> {
75 #[cfg(debug_assertions)]
76 debug_assert!(
77 !matches!(self.op(), Op::Range { .. } | Op::End { .. }),
78 "AFTER passthrough must be data-producing node, got {:?} (id={})",
79 self.op(),
80 self.id
81 );
82
83 let dtype = self.dtype();
84 Self::new(Op::After { passthrough: self.clone(), deps }, dtype)
85 }
86
87 // =========================================================================
88 // Materialization
89 // =========================================================================
90
91 /// Detach from gradient flow / force materialization.
92 pub fn detach(self: &Arc<Self>) -> Arc<Self> {
93 let dtype = self.dtype();
94 Self::new(Op::Detach { src: self.clone() }, dtype)
95 }
96
97 /// Ensure contiguous memory layout.
98 ///
99 /// Elides the CONTIGUOUS wrapper when the source is already contiguous:
100 /// - Already a CONTIGUOUS node (no double wrapping)
101 /// - Has buffer identity (BUFFER, or RESHAPE/MULTI chain to BUFFER)
102 ///
103 /// Based on Tinygrad's `UOp.contiguous()` (ops.py:463-466).
104 pub fn contiguous(self: &Arc<Self>) -> Arc<Self> {
105 if matches!(self.op(), Op::Contiguous { .. }) {
106 return self.clone();
107 }
108 if self.has_buffer_identity() {
109 return self.clone();
110 }
111 let dtype = self.dtype();
112 Self::new(Op::Contiguous { src: self.clone(), opts: smallvec::SmallVec::new() }, dtype)
113 }
114
115 /// Ensure contiguous memory layout with optimization hints.
116 ///
117 /// The hints are extracted during rangeify and passed to the optimizer.
118 /// Based on Tinygrad's CONTIGUOUS.arg which carries Opt tuples.
119 pub fn contiguous_with_opts(
120 self: &Arc<Self>,
121 opts: smallvec::SmallVec<[crate::types::ContiguousHint; 4]>,
122 ) -> Arc<Self> {
123 let dtype = self.dtype();
124 Self::new(Op::Contiguous { src: self.clone(), opts }, dtype)
125 }
126
127 /// Contiguous backward pass.
128 pub fn contiguous_backward(self: &Arc<Self>) -> Arc<Self> {
129 let dtype = self.dtype();
130 Self::new(Op::ContiguousBackward { src: self.clone() }, dtype)
131 }
132
133 // =========================================================================
134 // Optimization Hints
135 // =========================================================================
136
137 /// Optimizer hint to force materialization before type conversion.
138 ///
139 /// Inserted before BITCAST to ensure the source is rendered separately
140 /// in codegen (prevents invalid cast fusion).
141 pub fn precast(self: &Arc<Self>) -> Arc<Self> {
142 let dtype = self.dtype();
143 Self::new(Op::Precast { src: self.clone() }, dtype)
144 }
145
146 // =========================================================================
147 // Custom Code
148 // =========================================================================
149
150 /// Inject custom code as a statement in the generated kernel.
151 ///
152 /// `deps` are UOps whose rendered names can be referenced in `code`.
153 /// `dtype` specifies the result type (often Void for statements).
154 pub fn custom(deps: SmallVec<[Arc<Self>; 4]>, code: String, dtype: DType) -> Arc<Self> {
155 Self::new(Op::Custom { deps, code }, dtype)
156 }
157
158 /// Inject custom code as an inline expression.
159 ///
160 /// Unlike `custom` (statement), this is substituted directly into expressions.
161 /// `deps` provide values to reference; result has specified `dtype`.
162 pub fn customi(deps: SmallVec<[Arc<Self>; 4]>, code: String, dtype: DType) -> Arc<Self> {
163 Self::new(Op::CustomI { deps, code }, dtype)
164 }
165}