1use core::fmt::Display;
2
3use super::{Branch, CoopMma, Item, NonSemantic, Plane, Synchronization, Variable};
4use crate::{
5 Arithmetic, AtomicOp, Bitwise, Metadata, OperationArgs, OperationReflect, Operator, TmaOps,
6 comparison::Comparison,
7};
8use crate::{BarrierOps, SourceLoc, TypeHash};
9use alloc::{
10 format,
11 string::{String, ToString},
12 vec::Vec,
13};
14use derive_more::derive::From;
15
/// An operation carried by an [`Instruction`].
///
/// Apart from `Copy`, every variant wraps a category-specific operation type
/// (arithmetic, comparison, bitwise, atomic, branch, ...). The `Display` impl for
/// this enum forwards to the wrapped value.
///
/// Derive notes:
/// - `From` (from `derive_more`) generates a `From<Payload> for Operation`
///   conversion for the single-field variants. `Copy` opts out with
///   `#[from(ignore)]`, so a bare `Variable` does not convert into an `Operation`.
/// - `OperationReflect` is a project derive macro. `opcode_name = OpCode`
///   presumably names the opcode type it generates. `pure` / `nested` are its
///   per-variant options. TODO: confirm the exact semantics against the macro.
///
/// NOTE(review): variant order is part of what the derives see (e.g. the `Hash`
/// discriminant and serde variant indices). Avoid reordering variants casually.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, TypeHash, PartialEq, Eq, Hash, From, OperationReflect)]
#[operation(opcode_name = OpCode)]
// NOTE(review): `large_enum_variant` is silenced rather than boxing payloads; the
// payload sizes are not visible here, so the trade-off is assumed deliberate.
#[allow(dead_code, missing_docs, clippy::large_enum_variant)]
pub enum Operation {
    /// Plain copy of a variable. Marked `pure` for `OperationReflect` and
    /// excluded from the derived `From` conversions.
    #[operation(pure)]
    #[from(ignore)]
    Copy(Variable),
    /// Arithmetic operations.
    #[operation(nested)]
    Arithmetic(Arithmetic),
    /// Comparison operations.
    #[operation(nested)]
    Comparison(Comparison),
    /// Bitwise operations.
    #[operation(nested)]
    Bitwise(Bitwise),
    /// General operators (indexing, memory copies, casts, ...).
    #[operation(nested)]
    Operator(Operator),
    /// Atomic operations.
    #[operation(nested)]
    Atomic(AtomicOp),
    /// Metadata queries.
    #[operation(nested)]
    Metadata(Metadata),
    /// Control-flow operations.
    #[operation(nested)]
    Branch(Branch),
    /// Synchronization operations.
    #[operation(nested)]
    Synchronization(Synchronization),
    /// Plane operations.
    #[operation(nested)]
    Plane(Plane),
    /// Cooperative matrix-multiply operations.
    #[operation(nested)]
    CoopMma(CoopMma),
    /// Barrier operations.
    #[operation(nested)]
    Barrier(BarrierOps),
    /// TMA operations.
    #[operation(nested)]
    Tma(TmaOps),
    /// Non-semantic operations.
    #[operation(nested)]
    NonSemantic(NonSemantic),
}
59
/// A single [`Operation`] together with its optional output variable and
/// optional source location.
///
/// NOTE(review): field order is visible to the derived `Hash`/`Debug`/serde
/// impls; keep it stable.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, PartialEq, Eq, Hash, TypeHash)]
pub struct Instruction {
    /// Variable receiving the operation's result. `None` for instructions built
    /// with `Instruction::no_out`.
    pub out: Option<Variable>,
    /// Source location this instruction originates from, if known. Both
    /// `Instruction::new` and `Instruction::no_out` leave it unset.
    pub source_loc: Option<SourceLoc>,
    /// The operation performed by this instruction.
    pub operation: Operation,
}
68
69impl Instruction {
70 pub fn new(operation: impl Into<Operation>, out: Variable) -> Self {
71 Instruction {
72 out: Some(out),
73 operation: operation.into(),
74 source_loc: None,
75 }
76 }
77
78 pub fn no_out(operation: impl Into<Operation>) -> Self {
79 Instruction {
80 out: None,
81 operation: operation.into(),
82 source_loc: None,
83 }
84 }
85
86 pub fn out(&self) -> Variable {
87 self.out.unwrap()
88 }
89
90 pub fn item(&self) -> Item {
91 self.out().item
92 }
93}
94
95impl Display for Instruction {
96 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
97 match &self.operation {
98 Operation::Operator(Operator::CopyMemory(op)) => write!(
99 f,
100 "copy_mem({}[{}], {}[{}])",
101 self.out(),
102 op.out_index,
103 op.input,
104 op.in_index
105 ),
106 Operation::Operator(Operator::CopyMemoryBulk(op)) => write!(
107 f,
108 "copy_mem_bulk({}[{}], {}[{}], {})",
109 self.out(),
110 op.out_index,
111 op.input,
112 op.in_index,
113 op.len
114 ),
115 Operation::Operator(Operator::IndexAssign(op)) => {
116 write!(f, "{}[{}] = {}", self.out(), op.index, op.value)
117 }
118 Operation::Operator(Operator::UncheckedIndexAssign(op)) => {
119 write!(f, "unchecked {}[{}] = {}", self.out(), op.index, op.value)
120 }
121 Operation::Operator(Operator::Cast(op)) => {
122 write!(f, "{} = cast<{}>({})", self.out(), self.item(), op.input)
123 }
124 Operation::Operator(Operator::Reinterpret(op)) => {
125 write!(f, "{} = bitcast<{}>({})", self.out(), self.item(), op.input)
126 }
127 _ => {
128 if let Some(out) = self.out {
129 write!(f, "{out} = {}", self.operation)
130 } else {
131 write!(f, "{}", self.operation)
132 }
133 }
134 }
135 }
136}
137
138impl Display for Operation {
139 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
140 match self {
141 Operation::Arithmetic(arithmetic) => write!(f, "{arithmetic}"),
142 Operation::Comparison(comparison) => write!(f, "{comparison}"),
143 Operation::Bitwise(bitwise) => write!(f, "{bitwise}"),
144 Operation::Operator(operator) => write!(f, "{operator}"),
145 Operation::Atomic(atomic) => write!(f, "{atomic}"),
146 Operation::Metadata(metadata) => write!(f, "{metadata}"),
147 Operation::Branch(branch) => write!(f, "{branch}"),
148 Operation::Synchronization(synchronization) => write!(f, "{synchronization}"),
149 Operation::Plane(plane) => write!(f, "{plane}"),
150 Operation::CoopMma(coop_mma) => write!(f, "{coop_mma}"),
151 Operation::Copy(variable) => write!(f, "{variable}"),
152 Operation::NonSemantic(non_semantic) => write!(f, "{non_semantic}"),
153 Operation::Barrier(barrier_ops) => write!(f, "{barrier_ops}"),
154 Operation::Tma(tma_ops) => write!(f, "{tma_ops}"),
155 }
156 }
157}
158
/// Formats `args` as a trailing argument list.
///
/// Each argument is rendered with `Display` and preceded by `", "`, so the result
/// can be appended directly after a leading argument. Returns an empty string
/// when `args` is empty, e.g. `[1, 2]` becomes `", 1, 2"`.
pub fn fmt_vararg(args: &[impl Display]) -> String {
    args.iter().map(|arg| format!(", {arg}")).collect()
}
171
/// Operands of an operation that indexes into `list` at `index`.
///
/// NOTE(review): `OperationArgs` is a project derive; presumably it exposes these
/// fields as the operation's arguments. Confirm against the macro.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, TypeHash, PartialEq, Eq, Hash, OperationArgs)]
#[allow(missing_docs)]
pub struct IndexOperator {
    /// Variable being indexed.
    pub list: Variable,
    /// Index into `list`.
    pub index: Variable,
    /// Line size of the access.
    /// NOTE(review): the meaning of special values (e.g. 0) is not visible here;
    /// confirm with the consumers of this operator.
    pub line_size: u32,
}
180
/// Operands of an indexed assignment: store `value` at `index`.
///
/// NOTE(review): there is no container field here. Presumably the container is
/// the owning instruction's `out` variable, because `Instruction`'s `Display` impl
/// renders index assignments as `out[index] = value`. Confirm before relying on it.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, TypeHash, PartialEq, Eq, Hash, OperationArgs)]
#[allow(missing_docs)]
pub struct IndexAssignOperator {
    /// Index to write at.
    pub index: Variable,
    /// Value to write.
    pub value: Variable,
    /// Line size of the access.
    /// NOTE(review): the meaning of special values (e.g. 0) is not visible here;
    /// confirm with the consumers of this operator.
    pub line_size: u32,
}
190
/// Operands of a two-operand operation.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, TypeHash, PartialEq, Eq, Hash, OperationArgs)]
#[allow(missing_docs)]
pub struct BinaryOperator {
    /// Left-hand operand.
    pub lhs: Variable,
    /// Right-hand operand.
    pub rhs: Variable,
}
198
/// Operand of a single-operand operation.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, TypeHash, PartialEq, Eq, Hash, OperationArgs)]
#[allow(missing_docs)]
pub struct UnaryOperator {
    /// The single input operand.
    pub input: Variable,
}
205
206impl From<Branch> for Instruction {
207 fn from(value: Branch) -> Self {
208 Instruction::no_out(value)
209 }
210}
211
212impl From<Synchronization> for Instruction {
213 fn from(value: Synchronization) -> Self {
214 Instruction::no_out(value)
215 }
216}
217
218impl From<NonSemantic> for Instruction {
219 fn from(value: NonSemantic) -> Self {
220 Instruction::no_out(value)
221 }
222}