1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
//! Graph organization: sink, group, assign, contiguous.
//!
//! This module contains graph organization and optimization operations:
//! - Graph structure: sink, group
//! - Assignment: assign
//! - Dependencies: after
//! - Materialization: detach, contiguous, contiguous_backward
//! - Optimization hints: precast
//! - Custom code: custom, customi
use std::sync::Arc;
use morok_dtype::DType;
use smallvec::SmallVec;
use crate::op::Op;
use crate::uop::UOp;
impl UOp {
    // =========================================================================
    // Graph Structure
    // =========================================================================

    /// Create a sink operation (graph termination).
    ///
    /// Sink marks outputs that must be evaluated. All `sources` become
    /// dependencies of the sink. The resulting node always has dtype `Void`.
    pub fn sink(sources: Vec<Arc<Self>>) -> Arc<Self> {
        Self::new(Op::Sink { sources: SmallVec::from_vec(sources) }, DType::Void)
    }

    /// Create a group operation (merging/organizing related ops).
    ///
    /// Group is a NOOP that helps organize related operations together.
    /// It passes through the first source while ensuring all sources are
    /// dependencies. The resulting node takes the dtype of the first source,
    /// or `Void` when `sources` is empty.
    pub fn group(sources: Vec<Arc<Self>>) -> Arc<Self> {
        // Resolve the dtype before `sources` is moved into the op.
        let dtype = sources.first().map_or(DType::Void, |first| first.dtype());
        Self::new(Op::Group { sources: SmallVec::from_vec(sources) }, dtype)
    }

    // =========================================================================
    // Assignment
    // =========================================================================

    /// In-place assignment.
    ///
    /// Equivalent to [`assign_with_mops`](Self::assign_with_mops) with no
    /// movement-ops chain.
    ///
    /// # Arguments
    /// * `target` - The INDEX operation for the assignment destination
    /// * `value` - The value to assign
    pub fn assign(target: Arc<Self>, value: Arc<Self>) -> Arc<Self> {
        Self::assign_with_mops(target, value, None)
    }

    /// In-place assignment with movement ops chain.
    ///
    /// The `movement_ops` parameter captures shape transformations from the
    /// original target, used during bufferize_to_store to apply the same
    /// transformations to the result buffer.
    ///
    /// The resulting node takes the dtype of `target`.
    pub fn assign_with_mops(target: Arc<Self>, value: Arc<Self>, movement_ops: Option<Arc<Self>>) -> Arc<Self> {
        let dtype = target.dtype();
        Self::new(Op::Assign { target, value, movement_ops }, dtype)
    }

    // =========================================================================
    // Dependencies
    // =========================================================================

    /// Ordering constraint: self depends on deps.
    ///
    /// The resulting AFTER node passes `self` through (same dtype) while
    /// recording `deps` as dependencies.
    ///
    /// # Arguments
    /// * `deps` - Dependencies that must complete before this value is used
    ///
    /// # Panics
    /// In builds with debug assertions enabled, panics if `self` is a
    /// control flow node (`Range`, `End`).
    pub fn after(self: &Arc<Self>, deps: SmallVec<[Arc<Self>; 4]>) -> Arc<Self> {
        // `debug_assert!` is already compiled out when debug assertions are
        // disabled, so no extra `#[cfg(debug_assertions)]` gate is needed.
        debug_assert!(
            !matches!(self.op(), Op::Range { .. } | Op::End { .. }),
            "AFTER passthrough must be data-producing node, got {:?} (id={})",
            self.op(),
            self.id
        );
        let dtype = self.dtype();
        Self::new(Op::After { passthrough: Arc::clone(self), deps }, dtype)
    }

    // =========================================================================
    // Materialization
    // =========================================================================

    /// Detach from gradient flow / force materialization.
    pub fn detach(self: &Arc<Self>) -> Arc<Self> {
        let dtype = self.dtype();
        Self::new(Op::Detach { src: Arc::clone(self) }, dtype)
    }

    /// Ensure contiguous memory layout.
    ///
    /// Elides the CONTIGUOUS wrapper (returning `self` unchanged) when the
    /// source is already contiguous:
    /// - Already a CONTIGUOUS node (no double wrapping)
    /// - Has buffer identity (BUFFER, or RESHAPE/MULTI chain to BUFFER)
    ///
    /// Otherwise wraps `self` in a CONTIGUOUS node with no optimization hints.
    ///
    /// Based on Tinygrad's `UOp.contiguous()` (ops.py:463-466).
    pub fn contiguous(self: &Arc<Self>) -> Arc<Self> {
        // Cheap op check first; `has_buffer_identity` is only consulted if needed.
        if matches!(self.op(), Op::Contiguous { .. }) || self.has_buffer_identity() {
            return Arc::clone(self);
        }
        let dtype = self.dtype();
        Self::new(Op::Contiguous { src: Arc::clone(self), opts: SmallVec::new() }, dtype)
    }

    /// Ensure contiguous memory layout with optimization hints.
    ///
    /// The hints are extracted during rangeify and passed to the optimizer.
    /// Based on Tinygrad's CONTIGUOUS.arg which carries Opt tuples.
    ///
    /// Unlike [`contiguous`](Self::contiguous), this always creates a new
    /// CONTIGUOUS node, even if `self` is already contiguous, so the supplied
    /// `opts` are never dropped.
    pub fn contiguous_with_opts(
        self: &Arc<Self>,
        opts: SmallVec<[crate::types::ContiguousHint; 4]>,
    ) -> Arc<Self> {
        let dtype = self.dtype();
        Self::new(Op::Contiguous { src: Arc::clone(self), opts }, dtype)
    }

    /// Contiguous backward pass.
    pub fn contiguous_backward(self: &Arc<Self>) -> Arc<Self> {
        let dtype = self.dtype();
        Self::new(Op::ContiguousBackward { src: Arc::clone(self) }, dtype)
    }

    // =========================================================================
    // Optimization Hints
    // =========================================================================

    /// Optimizer hint to force materialization before type conversion.
    ///
    /// Inserted before BITCAST to ensure the source is rendered separately
    /// in codegen (prevents invalid cast fusion). The resulting node keeps
    /// the dtype of `self`.
    pub fn precast(self: &Arc<Self>) -> Arc<Self> {
        let dtype = self.dtype();
        Self::new(Op::Precast { src: Arc::clone(self) }, dtype)
    }

    // =========================================================================
    // Custom Code
    // =========================================================================

    /// Inject custom code as a statement in the generated kernel.
    ///
    /// `deps` are UOps whose rendered names can be referenced in `code`.
    /// `dtype` specifies the result type (often Void for statements).
    pub fn custom(deps: SmallVec<[Arc<Self>; 4]>, code: String, dtype: DType) -> Arc<Self> {
        Self::new(Op::Custom { deps, code }, dtype)
    }

    /// Inject custom code as an inline expression.
    ///
    /// Unlike [`custom`](Self::custom) (statement), this is substituted
    /// directly into expressions. `deps` provide values to reference; the
    /// result has the specified `dtype`.
    pub fn customi(deps: SmallVec<[Arc<Self>; 4]>, code: String, dtype: DType) -> Arc<Self> {
        Self::new(Op::CustomI { deps, code }, dtype)
    }
}