Skip to main content

morok_ir/uop/constructors/
graph.rs

1//! Graph organization: sink, group, assign, contiguous.
2//!
3//! This module contains graph organization and optimization operations:
4//! - Graph structure: sink, group
5//! - Assignment: assign
6//! - Dependencies: after
7//! - Materialization: detach, contiguous, contiguous_backward
8//! - Optimization hints: precast
9//! - Custom code: custom, customi
10
11use std::sync::Arc;
12
13use morok_dtype::DType;
14use smallvec::SmallVec;
15
16use crate::op::Op;
17use crate::uop::UOp;
18
19impl UOp {
20    // =========================================================================
21    // Graph Structure
22    // =========================================================================
23
24    /// Create a sink operation (graph termination).
25    ///
26    /// Sink marks outputs that must be evaluated. All sources are dependencies.
27    pub fn sink(sources: Vec<Arc<Self>>) -> Arc<Self> {
28        Self::new(Op::Sink { sources: SmallVec::from_vec(sources) }, DType::Void)
29    }
30
31    /// Create a group operation (merging/organizing related ops).
32    ///
33    /// Group is a NOOP that helps organize related operations together.
34    /// It passes through the first source while ensuring all sources are dependencies.
35    pub fn group(sources: Vec<Arc<Self>>) -> Arc<Self> {
36        let dtype = if sources.is_empty() { DType::Void } else { sources[0].dtype.clone() };
37        Self::new(Op::Group { sources: SmallVec::from_vec(sources) }, dtype)
38    }
39
40    // =========================================================================
41    // Assignment
42    // =========================================================================
43
44    /// In-place assignment.
45    ///
46    /// # Arguments
47    /// * `target` - The INDEX operation for the assignment destination
48    /// * `value` - The value to assign
49    pub fn assign(target: Arc<Self>, value: Arc<Self>) -> Arc<Self> {
50        Self::assign_with_mops(target, value, None)
51    }
52
53    /// In-place assignment with movement ops chain.
54    ///
55    /// The `movement_ops` parameter captures shape transformations from the
56    /// original target, used during bufferize_to_store to apply the same
57    /// transformations to the result buffer.
58    pub fn assign_with_mops(target: Arc<Self>, value: Arc<Self>, movement_ops: Option<Arc<Self>>) -> Arc<Self> {
59        let dtype = target.dtype();
60        Self::new(Op::Assign { target, value, movement_ops }, dtype)
61    }
62
63    // =========================================================================
64    // Dependencies
65    // =========================================================================
66
67    /// Ordering constraint: self depends on deps.
68    ///
69    /// # Arguments
70    /// * `deps` - Dependencies that must complete before this value is used
71    ///
72    /// # Panics (debug only)
73    /// Panics if self is a control flow node (Range, End)
74    pub fn after(self: &Arc<Self>, deps: SmallVec<[Arc<Self>; 4]>) -> Arc<Self> {
75        #[cfg(debug_assertions)]
76        debug_assert!(
77            !matches!(self.op(), Op::Range { .. } | Op::End { .. }),
78            "AFTER passthrough must be data-producing node, got {:?} (id={})",
79            self.op(),
80            self.id
81        );
82
83        let dtype = self.dtype();
84        Self::new(Op::After { passthrough: self.clone(), deps }, dtype)
85    }
86
87    // =========================================================================
88    // Materialization
89    // =========================================================================
90
91    /// Detach from gradient flow / force materialization.
92    pub fn detach(self: &Arc<Self>) -> Arc<Self> {
93        let dtype = self.dtype();
94        Self::new(Op::Detach { src: self.clone() }, dtype)
95    }
96
97    /// Ensure contiguous memory layout.
98    ///
99    /// Elides the CONTIGUOUS wrapper when the source is already contiguous:
100    /// - Already a CONTIGUOUS node (no double wrapping)
101    /// - Has buffer identity (BUFFER, or RESHAPE/MULTI chain to BUFFER)
102    ///
103    /// Based on Tinygrad's `UOp.contiguous()` (ops.py:463-466).
104    pub fn contiguous(self: &Arc<Self>) -> Arc<Self> {
105        if matches!(self.op(), Op::Contiguous { .. }) {
106            return self.clone();
107        }
108        if self.has_buffer_identity() {
109            return self.clone();
110        }
111        let dtype = self.dtype();
112        Self::new(Op::Contiguous { src: self.clone(), opts: smallvec::SmallVec::new() }, dtype)
113    }
114
115    /// Ensure contiguous memory layout with optimization hints.
116    ///
117    /// The hints are extracted during rangeify and passed to the optimizer.
118    /// Based on Tinygrad's CONTIGUOUS.arg which carries Opt tuples.
119    pub fn contiguous_with_opts(
120        self: &Arc<Self>,
121        opts: smallvec::SmallVec<[crate::types::ContiguousHint; 4]>,
122    ) -> Arc<Self> {
123        let dtype = self.dtype();
124        Self::new(Op::Contiguous { src: self.clone(), opts }, dtype)
125    }
126
127    /// Contiguous backward pass.
128    pub fn contiguous_backward(self: &Arc<Self>) -> Arc<Self> {
129        let dtype = self.dtype();
130        Self::new(Op::ContiguousBackward { src: self.clone() }, dtype)
131    }
132
133    // =========================================================================
134    // Optimization Hints
135    // =========================================================================
136
137    /// Optimizer hint to force materialization before type conversion.
138    ///
139    /// Inserted before BITCAST to ensure the source is rendered separately
140    /// in codegen (prevents invalid cast fusion).
141    pub fn precast(self: &Arc<Self>) -> Arc<Self> {
142        let dtype = self.dtype();
143        Self::new(Op::Precast { src: self.clone() }, dtype)
144    }
145
146    // =========================================================================
147    // Custom Code
148    // =========================================================================
149
150    /// Inject custom code as a statement in the generated kernel.
151    ///
152    /// `deps` are UOps whose rendered names can be referenced in `code`.
153    /// `dtype` specifies the result type (often Void for statements).
154    pub fn custom(deps: SmallVec<[Arc<Self>; 4]>, code: String, dtype: DType) -> Arc<Self> {
155        Self::new(Op::Custom { deps, code }, dtype)
156    }
157
158    /// Inject custom code as an inline expression.
159    ///
160    /// Unlike `custom` (statement), this is substituted directly into expressions.
161    /// `deps` provide values to reference; result has specified `dtype`.
162    pub fn customi(deps: SmallVec<[Arc<Self>; 4]>, code: String, dtype: DType) -> Arc<Self> {
163        Self::new(Op::CustomI { deps, code }, dtype)
164    }
165}