1use core::fmt::Display;
2
3use super::{Branch, CoopMma, NonSemantic, Plane, Synchronization, Type, Variable};
4use crate::{
5 Arithmetic, AtomicOp, Bitwise, InstructionModes, Metadata, OperationArgs, OperationReflect,
6 Operator, TmaOps, comparison::Comparison, marker::Marker,
7};
8use crate::{BarrierOps, SourceLoc, TypeHash};
9use alloc::{
10 format,
11 string::{String, ToString},
12 vec::Vec,
13};
14use derive_more::derive::From;
15
/// The operation carried by an [`Instruction`], grouped into one variant per
/// operation family.
///
/// Every variant except [`Operation::Copy`] wraps a dedicated operation type
/// and is tagged `#[operation(nested)]`. Because `From` is derived, each of
/// those wrapped types converts into an `Operation`, which is what
/// [`Instruction::new`] and [`Instruction::no_out`] rely on through their
/// `impl Into<Operation>` parameter.
///
/// NOTE(review): `#[operation(opcode_name = OpCode)]` presumably makes the
/// `OperationReflect` derive generate an `OpCode` type — confirm against the
/// derive macro.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, TypeHash, PartialEq, Eq, Hash, From, OperationReflect)]
#[operation(opcode_name = OpCode)]
#[allow(dead_code, missing_docs, clippy::large_enum_variant)]
pub enum Operation {
    /// Wraps a bare [`Variable`]; displayed as the variable itself.
    ///
    /// Excluded from the derived `From` impls (`#[from(ignore)]`), so a
    /// `Variable` never converts into an `Operation` implicitly.
    #[operation(pure)]
    #[from(ignore)]
    Copy(Variable),
    /// Wraps an [`Arithmetic`] operation.
    #[operation(nested)]
    Arithmetic(Arithmetic),
    /// Wraps a [`Comparison`] operation.
    #[operation(nested)]
    Comparison(Comparison),
    /// Wraps a [`Bitwise`] operation.
    #[operation(nested)]
    Bitwise(Bitwise),
    /// Wraps an [`Operator`]; several of its variants get a dedicated textual
    /// form in the `Display` impl of [`Instruction`].
    #[operation(nested)]
    Operator(Operator),
    /// Wraps an [`AtomicOp`].
    #[operation(nested)]
    Atomic(AtomicOp),
    /// Wraps a [`Metadata`] operation.
    #[operation(nested)]
    Metadata(Metadata),
    /// Wraps a [`Branch`]; `From<Branch> for Instruction` builds it without an
    /// output variable.
    #[operation(nested)]
    Branch(Branch),
    /// Wraps a [`Synchronization`] operation; `From<Synchronization> for
    /// Instruction` builds it without an output variable.
    #[operation(nested)]
    Synchronization(Synchronization),
    /// Wraps a [`Plane`] operation.
    #[operation(nested)]
    Plane(Plane),
    /// Wraps a [`CoopMma`] operation.
    #[operation(nested)]
    CoopMma(CoopMma),
    /// Wraps a [`BarrierOps`] operation.
    #[operation(nested)]
    Barrier(BarrierOps),
    /// Wraps a [`TmaOps`] operation.
    #[operation(nested)]
    Tma(TmaOps),
    /// Wraps a [`NonSemantic`] operation; `From<NonSemantic> for Instruction`
    /// builds it without an output variable.
    #[operation(nested)]
    NonSemantic(NonSemantic),
    /// Wraps a [`Marker`].
    #[operation(nested)]
    Marker(Marker),
}
62
/// A single instruction: an [`Operation`] together with an optional output
/// variable, an optional source location and the instruction modes.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, PartialEq, Eq, Hash, TypeHash)]
pub struct Instruction {
    /// Output variable, if any. [`Instruction::new`] sets it to `Some`,
    /// [`Instruction::no_out`] to `None`.
    pub out: Option<Variable>,
    /// Optional source location; both constructors in this file leave it `None`.
    pub source_loc: Option<SourceLoc>,
    /// Instruction modes; both constructors initialise this with `Default::default()`.
    pub modes: InstructionModes,
    /// The operation this instruction performs.
    pub operation: Operation,
}
72
73impl Instruction {
74 pub fn new(operation: impl Into<Operation>, out: Variable) -> Self {
75 Instruction {
76 out: Some(out),
77 operation: operation.into(),
78 source_loc: None,
79 modes: Default::default(),
80 }
81 }
82
83 pub fn no_out(operation: impl Into<Operation>) -> Self {
84 Instruction {
85 out: None,
86 operation: operation.into(),
87 source_loc: None,
88 modes: Default::default(),
89 }
90 }
91
92 pub fn out(&self) -> Variable {
93 self.out.unwrap()
94 }
95
96 pub fn ty(&self) -> Type {
97 self.out().ty
98 }
99}
100
101impl Display for Instruction {
102 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
103 match &self.operation {
104 Operation::Operator(Operator::CopyMemory(op)) => write!(
105 f,
106 "copy_mem({}[{}], {}[{}])",
107 self.out(),
108 op.out_index,
109 op.input,
110 op.in_index
111 ),
112 Operation::Operator(Operator::CopyMemoryBulk(op)) => write!(
113 f,
114 "copy_mem_bulk({}[{}], {}[{}], {})",
115 self.out(),
116 op.out_index,
117 op.input,
118 op.in_index,
119 op.len
120 ),
121 Operation::Operator(Operator::IndexAssign(op)) => {
122 write!(
123 f,
124 "{}[{}] = {} : ({}, {}) -> ({})",
125 self.out(),
126 op.index,
127 op.value,
128 op.index.ty,
129 op.value.ty,
130 self.out().ty,
131 )
132 }
133 Operation::Operator(Operator::UncheckedIndexAssign(op)) => {
134 write!(
135 f,
136 "unchecked {}[{}] = {} : ({}, {}) -> ({})",
137 self.out(),
138 op.index,
139 op.value,
140 op.index.ty,
141 op.value.ty,
142 self.out().ty,
143 )
144 }
145 Operation::Operator(Operator::Cast(op)) => {
146 write!(
147 f,
148 "{} = cast<{}>({}) : ({}) -> ({})",
149 self.out(),
150 self.ty(),
151 op.input,
152 op.input.ty,
153 self.out().ty,
154 )
155 }
156 Operation::Operator(Operator::Reinterpret(op)) => {
157 write!(f, "{} = bitcast<{}>({})", self.out(), self.ty(), op.input)
158 }
159 _ => {
160 if let Some(out) = self.out {
161 let mut vars_str = String::new();
162 for (i, var) in self.operation.args().unwrap_or_default().iter().enumerate() {
163 if i != 0 {
164 vars_str.push_str(", ");
165 }
166 vars_str.push_str(&var.ty.to_string());
167 }
168 write!(
169 f,
170 "{out} = {} : ({}) -> ({})",
171 self.operation, vars_str, out.ty
172 )
173 } else {
174 write!(f, "{}", self.operation)
175 }
176 }
177 }
178 }
179}
180
181impl Display for Operation {
182 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
183 match self {
184 Operation::Arithmetic(arithmetic) => write!(f, "{arithmetic}"),
185 Operation::Comparison(comparison) => write!(f, "{comparison}"),
186 Operation::Bitwise(bitwise) => write!(f, "{bitwise}"),
187 Operation::Operator(operator) => write!(f, "{operator}"),
188 Operation::Atomic(atomic) => write!(f, "{atomic}"),
189 Operation::Metadata(metadata) => write!(f, "{metadata}"),
190 Operation::Branch(branch) => write!(f, "{branch}"),
191 Operation::Synchronization(synchronization) => write!(f, "{synchronization}"),
192 Operation::Plane(plane) => write!(f, "{plane}"),
193 Operation::CoopMma(coop_mma) => write!(f, "{coop_mma}"),
194 Operation::Copy(variable) => write!(f, "{variable}"),
195 Operation::NonSemantic(non_semantic) => write!(f, "{non_semantic}"),
196 Operation::Barrier(barrier_ops) => write!(f, "{barrier_ops}"),
197 Operation::Tma(tma_ops) => write!(f, "{tma_ops}"),
198 Operation::Marker(marker) => write!(f, "{marker}"),
199 }
200 }
201}
202
/// Formats `args` as the tail of an argument list.
///
/// Every element is rendered with its `Display` impl and preceded by `", "`,
/// so the result can be appended directly after a leading argument. An empty
/// slice yields an empty string.
pub fn fmt_vararg(args: &[impl Display]) -> String {
    let mut rendered = String::new();
    for arg in args {
        // Prefixing each element with the separator produces the leading
        // ", " and the ", " between elements in a single pass.
        rendered.push_str(", ");
        rendered.push_str(&arg.to_string());
    }
    rendered
}
215
/// Operands of an indexing operation: a variable and the index applied to it.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, TypeHash, PartialEq, Eq, Hash, OperationArgs)]
#[allow(missing_docs)]
pub struct IndexOperator {
    /// Variable being indexed.
    pub list: Variable,
    /// Index into `list`.
    pub index: Variable,
    // NOTE(review): not read anywhere in this file; presumably the line
    // (vector) width of the access — confirm the exact semantics with the
    // code that consumes it.
    pub line_size: u32,
    // NOTE(review): not read anywhere in this file; meaning of the factor is
    // not visible here — confirm with the code that consumes it.
    pub unroll_factor: u32,
}
225
/// Operands of an indexed assignment.
///
/// The variable being written to is not stored here: the `Display` impl of
/// [`Instruction`] renders these operations as `out[index] = value`, using the
/// instruction's output variable as the target.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, TypeHash, PartialEq, Eq, Hash, OperationArgs)]
#[allow(missing_docs)]
pub struct IndexAssignOperator {
    /// Index into the instruction's output variable.
    pub index: Variable,
    /// Value assigned at `index`.
    pub value: Variable,
    // NOTE(review): not read anywhere in this file; presumably the line
    // (vector) width of the access — confirm the exact semantics with the
    // code that consumes it.
    pub line_size: u32,
    // NOTE(review): not read anywhere in this file; meaning of the factor is
    // not visible here — confirm with the code that consumes it.
    pub unroll_factor: u32,
}
236
/// Operands of an operation that takes two input variables.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, TypeHash, PartialEq, Eq, Hash, OperationArgs)]
#[allow(missing_docs)]
pub struct BinaryOperator {
    /// Left-hand operand.
    pub lhs: Variable,
    /// Right-hand operand.
    pub rhs: Variable,
}
244
/// Operand of an operation that takes a single input variable.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, TypeHash, PartialEq, Eq, Hash, OperationArgs)]
#[allow(missing_docs)]
pub struct UnaryOperator {
    /// The single input operand.
    pub input: Variable,
}
251
252impl From<Branch> for Instruction {
253 fn from(value: Branch) -> Self {
254 Instruction::no_out(value)
255 }
256}
257
258impl From<Synchronization> for Instruction {
259 fn from(value: Synchronization) -> Self {
260 Instruction::no_out(value)
261 }
262}
263
264impl From<NonSemantic> for Instruction {
265 fn from(value: NonSemantic) -> Self {
266 Instruction::no_out(value)
267 }
268}