1use core::fmt::Display;
2
3use super::{Branch, CoopMma, NonSemantic, Plane, Synchronization, Type, Variable};
4use crate::{
5 Arithmetic, AtomicOp, Bitwise, Metadata, OperationArgs, OperationReflect, Operator, TmaOps,
6 comparison::Comparison,
7};
8use crate::{BarrierOps, SourceLoc, TypeHash};
9use alloc::{
10 format,
11 string::{String, ToString},
12 vec::Vec,
13};
14use derive_more::derive::From;
15
/// A single IR operation, grouped by operation family.
///
/// Most variants wrap a nested operation type and are tagged
/// `#[operation(nested)]`; `Copy` and `Free` carry a bare [`Variable`].
///
/// NOTE(review): `Copy` opts out of the `derive_more::From` conversion with
/// `#[from(ignore)]`, but `Free` (which also wraps a `Variable`) does not.
/// Confirm whether `From<Variable> for Operation` ends up generated for `Free`;
/// if so, `variable.into()` would silently build a `Free` operation.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, TypeHash, PartialEq, Eq, Hash, From, OperationReflect)]
// Presumably names the opcode enum generated by `OperationReflect` — confirm.
#[operation(opcode_name = OpCode)]
// `dead_code`: presumably some variants are unused under certain feature
// configurations — TODO confirm. `large_enum_variant`: payloads differ in size
// and are stored inline rather than boxed.
#[allow(dead_code, missing_docs, clippy::large_enum_variant)]
pub enum Operation {
    /// Copies a variable. Excluded from the derived `From` conversions.
    #[operation(pure)]
    #[from(ignore)]
    Copy(Variable),
    /// Arithmetic operations.
    #[operation(nested)]
    Arithmetic(Arithmetic),
    /// Comparison operations.
    #[operation(nested)]
    Comparison(Comparison),
    /// Bitwise operations.
    #[operation(nested)]
    Bitwise(Bitwise),
    /// General operators (indexing, casts, memory copies, ...).
    #[operation(nested)]
    Operator(Operator),
    /// Atomic operations.
    #[operation(nested)]
    Atomic(AtomicOp),
    /// Metadata queries.
    #[operation(nested)]
    Metadata(Metadata),
    /// Control-flow operations.
    #[operation(nested)]
    Branch(Branch),
    /// Synchronization operations.
    #[operation(nested)]
    Synchronization(Synchronization),
    /// Plane-level operations.
    #[operation(nested)]
    Plane(Plane),
    /// Cooperative matrix operations.
    #[operation(nested)]
    CoopMma(CoopMma),
    /// Barrier operations.
    #[operation(nested)]
    Barrier(BarrierOps),
    /// TMA operations.
    #[operation(nested)]
    Tma(TmaOps),
    /// Non-semantic instructions; presumably comments/debug info — confirm.
    #[operation(nested)]
    NonSemantic(NonSemantic),
    /// Displayed as `free(<var>)`; presumably releases the variable — confirm.
    Free(Variable),
}
62
/// One IR instruction: an [`Operation`] plus the variable it writes, if any.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, PartialEq, Eq, Hash, TypeHash)]
pub struct Instruction {
    /// Variable receiving the instruction's result; `None` for instructions
    /// created with `Instruction::no_out`.
    pub out: Option<Variable>,
    /// Optional source location; both constructors initialise it to `None`.
    pub source_loc: Option<SourceLoc>,
    /// The operation to execute.
    pub operation: Operation,
}
71
72impl Instruction {
73 pub fn new(operation: impl Into<Operation>, out: Variable) -> Self {
74 Instruction {
75 out: Some(out),
76 operation: operation.into(),
77 source_loc: None,
78 }
79 }
80
81 pub fn no_out(operation: impl Into<Operation>) -> Self {
82 Instruction {
83 out: None,
84 operation: operation.into(),
85 source_loc: None,
86 }
87 }
88
89 pub fn out(&self) -> Variable {
90 self.out.unwrap()
91 }
92
93 pub fn ty(&self) -> Type {
94 self.out().ty
95 }
96}
97
98impl Display for Instruction {
99 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
100 match &self.operation {
101 Operation::Operator(Operator::CopyMemory(op)) => write!(
102 f,
103 "copy_mem({}[{}], {}[{}])",
104 self.out(),
105 op.out_index,
106 op.input,
107 op.in_index
108 ),
109 Operation::Operator(Operator::CopyMemoryBulk(op)) => write!(
110 f,
111 "copy_mem_bulk({}[{}], {}[{}], {})",
112 self.out(),
113 op.out_index,
114 op.input,
115 op.in_index,
116 op.len
117 ),
118 Operation::Operator(Operator::IndexAssign(op)) => {
119 write!(
120 f,
121 "{}[{}] = {} : ({}, {}) -> ({})",
122 self.out(),
123 op.index,
124 op.value,
125 op.index.ty,
126 op.value.ty,
127 self.out().ty,
128 )
129 }
130 Operation::Operator(Operator::UncheckedIndexAssign(op)) => {
131 write!(
132 f,
133 "unchecked {}[{}] = {} : ({}, {}) -> ({})",
134 self.out(),
135 op.index,
136 op.value,
137 op.index.ty,
138 op.value.ty,
139 self.out().ty,
140 )
141 }
142 Operation::Operator(Operator::Cast(op)) => {
143 write!(
144 f,
145 "{} = cast<{}>({}) : ({}) -> ({})",
146 self.out(),
147 self.ty(),
148 op.input,
149 op.input.ty,
150 self.out().ty,
151 )
152 }
153 Operation::Operator(Operator::Reinterpret(op)) => {
154 write!(f, "{} = bitcast<{}>({})", self.out(), self.ty(), op.input)
155 }
156 _ => {
157 if let Some(out) = self.out {
158 let mut vars_str = String::new();
159 for (i, var) in self.operation.args().unwrap_or_default().iter().enumerate() {
160 if i != 0 {
161 vars_str.push_str(", ");
162 }
163 vars_str.push_str(&var.ty.to_string());
164 }
165 write!(
166 f,
167 "{out} = {} : ({}) -> ({})",
168 self.operation, vars_str, out.ty
169 )
170 } else {
171 write!(f, "{}", self.operation)
172 }
173 }
174 }
175 }
176}
177
178impl Display for Operation {
179 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
180 match self {
181 Operation::Arithmetic(arithmetic) => write!(f, "{arithmetic}"),
182 Operation::Comparison(comparison) => write!(f, "{comparison}"),
183 Operation::Bitwise(bitwise) => write!(f, "{bitwise}"),
184 Operation::Operator(operator) => write!(f, "{operator}"),
185 Operation::Atomic(atomic) => write!(f, "{atomic}"),
186 Operation::Metadata(metadata) => write!(f, "{metadata}"),
187 Operation::Branch(branch) => write!(f, "{branch}"),
188 Operation::Synchronization(synchronization) => write!(f, "{synchronization}"),
189 Operation::Plane(plane) => write!(f, "{plane}"),
190 Operation::CoopMma(coop_mma) => write!(f, "{coop_mma}"),
191 Operation::Copy(variable) => write!(f, "{variable}"),
192 Operation::NonSemantic(non_semantic) => write!(f, "{non_semantic}"),
193 Operation::Barrier(barrier_ops) => write!(f, "{barrier_ops}"),
194 Operation::Tma(tma_ops) => write!(f, "{tma_ops}"),
195 Operation::Free(var) => write!(f, "free({var})"),
196 }
197 }
198}
199
/// Formats trailing variadic arguments for display.
///
/// Every argument is rendered with a leading `", "`, so the result can be
/// appended directly after a fixed argument list. An empty slice yields an
/// empty string.
pub fn fmt_vararg(args: &[impl Display]) -> String {
    args.iter().map(|arg| format!(", {arg}")).collect()
}
212
/// Operands for reading `list` at `index`.
///
/// NOTE(review): presumably the payload of the indexing operators in
/// `Operator` — confirm.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, TypeHash, PartialEq, Eq, Hash, OperationArgs)]
#[allow(missing_docs)]
pub struct IndexOperator {
    pub list: Variable,
    pub index: Variable,
    // NOTE(review): semantics not visible here; presumably a line-size override
    // for the access (possibly with 0 meaning "same as `list`") — confirm.
    pub line_size: u32,
    // NOTE(review): presumably an adjustment factor used by bounds checks
    // when the access is unrolled — confirm.
    pub unroll_factor: u32,
}
222
/// Operands for storing `value` at `index`.
///
/// There is no list field; presumably the destination is the owning
/// instruction's output variable (`Instruction`'s `Display` prints index
/// assignments as `out[index] = value`) — confirm which `Operator` variants
/// carry this type.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, TypeHash, PartialEq, Eq, Hash, OperationArgs)]
#[allow(missing_docs)]
pub struct IndexAssignOperator {
    pub index: Variable,
    pub value: Variable,
    // NOTE(review): presumably a line-size override for the store — confirm.
    pub line_size: u32,
    // NOTE(review): presumably an adjustment factor used by bounds checks
    // when the store is unrolled — confirm.
    pub unroll_factor: u32,
}
233
/// Generic two-operand payload: a left-hand side and a right-hand side variable.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, TypeHash, PartialEq, Eq, Hash, OperationArgs)]
#[allow(missing_docs)]
pub struct BinaryOperator {
    pub lhs: Variable,
    pub rhs: Variable,
}
241
/// Generic single-operand payload.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, TypeHash, PartialEq, Eq, Hash, OperationArgs)]
#[allow(missing_docs)]
pub struct UnaryOperator {
    pub input: Variable,
}
248
249impl From<Branch> for Instruction {
250 fn from(value: Branch) -> Self {
251 Instruction::no_out(value)
252 }
253}
254
255impl From<Synchronization> for Instruction {
256 fn from(value: Synchronization) -> Self {
257 Instruction::no_out(value)
258 }
259}
260
261impl From<NonSemantic> for Instruction {
262 fn from(value: NonSemantic) -> Self {
263 Instruction::no_out(value)
264 }
265}