1use std::fmt::Display;
2
3use super::{Branch, CoopMma, Item, NonSemantic, Plane, Scope, Select, Synchronization, Variable};
4use crate::{
5 cpa,
6 ir::{Elem, UIntKind},
7 prelude::AtomicOp,
8};
9use serde::{Deserialize, Serialize};
10
11#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
19#[allow(dead_code, missing_docs, clippy::large_enum_variant)] pub enum Operation {
21 Copy(Variable),
22 Operator(Operator),
23 Atomic(AtomicOp),
24 Metadata(Metadata),
25 Branch(Branch),
26 Synchronization(Synchronization),
27 Plane(Plane),
28 CoopMma(CoopMma),
29 NonSemantic(NonSemantic),
31}
32
33#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
35pub struct Instruction {
36 pub out: Option<Variable>,
37 pub operation: Operation,
38}
39
40impl Instruction {
41 pub fn new(operation: impl Into<Operation>, out: Variable) -> Self {
42 Instruction {
43 out: Some(out),
44 operation: operation.into(),
45 }
46 }
47
48 pub fn out(&self) -> Variable {
49 self.out.unwrap()
50 }
51
52 pub fn item(&self) -> Item {
53 self.out().item
54 }
55}
56
57impl Operation {
58 pub fn is_pure(&self) -> bool {
64 match self {
65 Operation::Copy(_) => true,
66 Operation::Operator(_) => true,
67 Operation::Atomic(_) => false,
68 Operation::Metadata(_) => true,
69 Operation::Branch(_) => false,
70 Operation::Synchronization(_) => false,
71 Operation::Plane(_) => false,
72 Operation::CoopMma(_) => false,
73 Operation::NonSemantic(_) => false,
74 }
75 }
76}
77
78impl Display for Instruction {
79 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
80 match &self.operation {
81 Operation::Operator(Operator::CopyMemory(op)) => write!(
82 f,
83 "copy_mem({}[{}], {}[{}])",
84 self.out(),
85 op.out_index,
86 op.input,
87 op.in_index
88 ),
89 Operation::Operator(Operator::CopyMemoryBulk(op)) => write!(
90 f,
91 "copy_mem_bulk({}[{}], {}[{}], {})",
92 self.out(),
93 op.out_index,
94 op.input,
95 op.in_index,
96 op.len
97 ),
98 Operation::Operator(Operator::IndexAssign(op)) => {
99 write!(f, "{}[{}] = {}", self.out(), op.lhs, op.rhs)
100 }
101 Operation::Operator(Operator::UncheckedIndexAssign(op)) => {
102 write!(f, "unchecked {}[{}] = {}", self.out(), op.lhs, op.rhs)
103 }
104 Operation::Operator(Operator::Cast(op)) => {
105 write!(f, "{} = cast<{}>({})", self.out(), self.item(), op.input)
106 }
107 Operation::Operator(Operator::Bitcast(op)) => {
108 write!(f, "{} = bitcast<{}>({})", self.out(), self.item(), op.input)
109 }
110 _ => {
111 if let Some(out) = self.out {
112 write!(f, "{out} = {}", self.operation)
113 } else {
114 write!(f, "{}", self.operation)
115 }
116 }
117 }
118 }
119}
120
121impl Display for Operation {
122 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
123 match self {
124 Operation::Operator(operator) => write!(f, "{operator}"),
125 Operation::Atomic(atomic) => write!(f, "{atomic}"),
126 Operation::Metadata(metadata) => write!(f, "{metadata}"),
127 Operation::Branch(branch) => write!(f, "{branch}"),
128 Operation::Synchronization(synchronization) => write!(f, "{synchronization}"),
129 Operation::Plane(plane) => write!(f, "{plane}"),
130 Operation::CoopMma(coop_mma) => write!(f, "{coop_mma}"),
131 Operation::Copy(variable) => write!(f, "{variable}"),
132 Operation::NonSemantic(non_semantic) => write!(f, "{non_semantic}"),
133 }
134 }
135}
136
/// Formats `args` as a trailing argument list.
///
/// Each argument is rendered with `Display` and preceded by `", "`, so the result can be
/// appended directly after a leading argument. An empty slice yields an empty string.
pub fn fmt_vararg(args: &[impl Display]) -> String {
    args.iter().fold(String::new(), |mut rendered, arg| {
        rendered.push_str(", ");
        rendered.push_str(&arg.to_string());
        rendered
    })
}
149
150#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
152#[allow(dead_code, missing_docs)] pub enum Operator {
154 Add(BinaryOperator),
155 Fma(FmaOperator),
156 Sub(BinaryOperator),
157 Mul(BinaryOperator),
158 Div(BinaryOperator),
159 Abs(UnaryOperator),
160 Exp(UnaryOperator),
161 Log(UnaryOperator),
162 Log1p(UnaryOperator),
163 Cos(UnaryOperator),
164 Sin(UnaryOperator),
165 Tanh(UnaryOperator),
166 Powf(BinaryOperator),
167 Sqrt(UnaryOperator),
168 Round(UnaryOperator),
169 Floor(UnaryOperator),
170 Ceil(UnaryOperator),
171 Erf(UnaryOperator),
172 Recip(UnaryOperator),
173 Equal(BinaryOperator),
174 NotEqual(BinaryOperator),
175 Lower(BinaryOperator),
176 Clamp(ClampOperator),
177 Greater(BinaryOperator),
178 LowerEqual(BinaryOperator),
179 GreaterEqual(BinaryOperator),
180 Cast(UnaryOperator),
181 Modulo(BinaryOperator),
182 Index(BinaryOperator),
183 CopyMemory(CopyMemoryOperator),
184 CopyMemoryBulk(CopyMemoryBulkOperator),
185 Slice(SliceOperator),
186 UncheckedIndex(BinaryOperator),
187 IndexAssign(BinaryOperator),
188 InitLine(LineInitOperator),
189 UncheckedIndexAssign(BinaryOperator),
190 And(BinaryOperator),
191 Or(BinaryOperator),
192 Not(UnaryOperator),
193 Neg(UnaryOperator),
194 Max(BinaryOperator),
195 Min(BinaryOperator),
196 BitwiseAnd(BinaryOperator),
197 BitwiseOr(BinaryOperator),
198 BitwiseXor(BinaryOperator),
199 ShiftLeft(BinaryOperator),
200 ShiftRight(BinaryOperator),
201 CountOnes(UnaryOperator),
202 ReverseBits(UnaryOperator),
203 Remainder(BinaryOperator),
204 Bitcast(UnaryOperator),
205 Magnitude(UnaryOperator),
206 Normalize(UnaryOperator),
207 Dot(BinaryOperator),
208 Select(Select),
210}
211
212impl Display for Operator {
213 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
214 match self {
215 Operator::Add(op) => write!(f, "{} + {}", op.lhs, op.rhs),
216 Operator::Fma(op) => write!(f, "{} * {} + {}", op.a, op.b, op.c),
217 Operator::Sub(op) => write!(f, "{} - {}", op.lhs, op.rhs),
218 Operator::Mul(op) => write!(f, "{} * {}", op.lhs, op.rhs),
219 Operator::Div(op) => write!(f, "{} / {}", op.lhs, op.rhs),
220 Operator::Abs(op) => write!(f, "{}.abs()", op.input),
221 Operator::Exp(op) => write!(f, "{}.exp()", op.input),
222 Operator::Log(op) => write!(f, "{}.log()", op.input),
223 Operator::Log1p(op) => write!(f, "{}.log_1p()", op.input),
224 Operator::Cos(op) => write!(f, "{}.cos()", op.input),
225 Operator::Sin(op) => write!(f, "{}.sin()", op.input),
226 Operator::Tanh(op) => write!(f, "{}.tanh()", op.input),
227 Operator::Powf(op) => write!(f, "{}.pow({})", op.lhs, op.rhs),
228 Operator::Sqrt(op) => write!(f, "{}.sqrt()", op.input),
229 Operator::Round(op) => write!(f, "{}.round()", op.input),
230 Operator::Floor(op) => write!(f, "{}.floor()", op.input),
231 Operator::Ceil(op) => write!(f, "{}.ceil()", op.input),
232 Operator::Erf(op) => write!(f, "{}.erf()", op.input),
233 Operator::Recip(op) => write!(f, "{}.recip()", op.input),
234 Operator::Equal(op) => write!(f, "{} == {}", op.lhs, op.rhs),
235 Operator::NotEqual(op) => write!(f, "{} != {}", op.lhs, op.rhs),
236 Operator::Lower(op) => write!(f, "{} < {}", op.lhs, op.rhs),
237 Operator::Clamp(op) => {
238 write!(f, "{}.clamp({}, {})", op.input, op.min_value, op.max_value)
239 }
240 Operator::Greater(op) => write!(f, "{} > {}", op.lhs, op.rhs),
241 Operator::LowerEqual(op) => write!(f, "{} <= {}", op.lhs, op.rhs),
242 Operator::GreaterEqual(op) => write!(f, "{} >= {}", op.lhs, op.rhs),
243 Operator::Modulo(op) => write!(f, "{} % {}", op.lhs, op.rhs),
244 Operator::Index(op) => write!(f, "{}[{}]", op.lhs, op.rhs),
245 Operator::CopyMemory(op) => {
246 write!(f, "[{}] = {}[{}]", op.out_index, op.input, op.in_index)
247 }
248 Operator::CopyMemoryBulk(op) => write!(
249 f,
250 "memcpy([{}], {}[{}], {})",
251 op.input, op.in_index, op.out_index, op.len
252 ),
253 Operator::Slice(op) => write!(f, "{}[{}..{}]", op.input, op.start, op.end),
254 Operator::UncheckedIndex(op) => {
255 write!(f, "unchecked {}[{}]", op.lhs, op.rhs)
256 }
257 Operator::IndexAssign(op) => write!(f, "[{}] = {}", op.lhs, op.rhs),
258 Operator::UncheckedIndexAssign(op) => {
259 write!(f, "unchecked [{}] = {}", op.lhs, op.rhs)
260 }
261 Operator::And(op) => write!(f, "{} && {}", op.lhs, op.rhs),
262 Operator::Or(op) => write!(f, "{} || {}", op.lhs, op.rhs),
263 Operator::Not(op) => write!(f, "!{}", op.input),
264 Operator::Neg(op) => write!(f, "-{}", op.input),
265 Operator::Max(op) => write!(f, "{}.max({})", op.lhs, op.rhs),
266 Operator::Min(op) => write!(f, "{}.min({})", op.lhs, op.rhs),
267 Operator::BitwiseAnd(op) => write!(f, "{} & {}", op.lhs, op.rhs),
268 Operator::BitwiseOr(op) => write!(f, "{} | {}", op.lhs, op.rhs),
269 Operator::BitwiseXor(op) => write!(f, "{} ^ {}", op.lhs, op.rhs),
270 Operator::CountOnes(op) => write!(f, "{}.count_bits()", op.input),
271 Operator::ReverseBits(op) => write!(f, "{}.reverse_bits()", op.input),
272 Operator::ShiftLeft(op) => write!(f, "{} << {}", op.lhs, op.rhs),
273 Operator::ShiftRight(op) => write!(f, "{} >> {}", op.lhs, op.rhs),
274 Operator::Remainder(op) => write!(f, "{} rem {}", op.lhs, op.rhs),
275 Operator::Magnitude(op) => write!(f, "{}.length()", op.input),
276 Operator::Normalize(op) => write!(f, "{}.normalize()", op.input),
277 Operator::Dot(op) => write!(f, "{}.dot({})", op.lhs, op.rhs),
278 Operator::InitLine(init) => {
279 let inits = init
280 .inputs
281 .iter()
282 .map(|input| format!("{input}"))
283 .collect::<Vec<_>>();
284 write!(f, "vec({})", inits.join(", "))
285 }
286 Operator::Select(op) => {
287 write!(f, "{} ? {} : {}", op.cond, op.then, op.or_else)
288 }
289 Operator::Cast(op) => write!(f, "cast({})", op.input),
290 Operator::Bitcast(op) => write!(f, "bitcast({})", op.input),
291 }
292 }
293}
294
295#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
297#[allow(missing_docs)]
298pub enum Metadata {
299 Rank { var: Variable },
301 Stride { dim: Variable, var: Variable },
303 Shape { dim: Variable, var: Variable },
305 Length { var: Variable },
307 BufferLength { var: Variable },
309}
310
311impl Display for Metadata {
312 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
313 match self {
314 Metadata::Rank { var } => write!(f, "rank({})", var),
315 Metadata::Stride { dim, var } => write!(f, "{}.strides[{}]", var, dim),
316 Metadata::Shape { dim, var } => write!(f, "{}.shape[{}]", var, dim),
317 Metadata::Length { var } => write!(f, "{}.len()", var),
318 Metadata::BufferLength { var } => write!(f, "buffer_len({})", var),
319 }
320 }
321}
322
323#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
324#[allow(missing_docs)]
325pub struct BinaryOperator {
326 pub lhs: Variable,
327 pub rhs: Variable,
328}
329
330#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
331#[allow(missing_docs)]
332pub struct UnaryOperator {
333 pub input: Variable,
334}
335
336#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
337#[allow(missing_docs)]
338pub struct LineInitOperator {
339 pub inputs: Vec<Variable>,
340}
341
342#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
343#[allow(missing_docs)]
344pub struct CopyMemoryOperator {
345 pub out_index: Variable,
346 pub input: Variable,
347 pub in_index: Variable,
348}
349
350#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
351#[allow(missing_docs)]
352pub struct CopyMemoryBulkOperator {
353 pub out_index: Variable,
354 pub input: Variable,
355 pub in_index: Variable,
356 pub len: u32,
357}
358
359#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
360#[allow(missing_docs)]
361pub struct ClampOperator {
362 pub input: Variable,
363 pub min_value: Variable,
364 pub max_value: Variable,
365}
366
367#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
368#[allow(missing_docs)]
369pub struct SliceOperator {
370 pub input: Variable,
371 pub start: Variable,
372 pub end: Variable,
373}
374
375#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
376#[allow(missing_docs)]
377pub struct CompareAndSwapOperator {
378 pub input: Variable,
379 pub cmp: Variable,
380 pub val: Variable,
381}
382
383#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
384#[allow(missing_docs)]
385pub struct ReadGlobalOperator {
386 pub variable: Variable,
387}
388
389#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
390#[allow(missing_docs)]
391pub struct ReadGlobalWithLayoutOperator {
392 pub variable: Variable,
393 pub tensor_read_pos: usize,
394 pub tensor_layout_pos: usize,
395}
396
397#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
398#[allow(missing_docs)]
399pub struct FmaOperator {
400 pub a: Variable,
401 pub b: Variable,
402 pub c: Variable,
403}
404
405#[allow(missing_docs)]
406pub fn expand_checked_index_assign(scope: &mut Scope, lhs: Variable, rhs: Variable, out: Variable) {
407 let array_len = scope.create_local(Item::new(Elem::UInt(UIntKind::U32)));
408 let inside_bound = scope.create_local(Item::new(Elem::Bool));
409
410 if out.has_buffer_length() {
411 cpa!(scope, array_len = buffer_len(out));
412 } else {
413 cpa!(scope, array_len = len(out));
414 }
415
416 cpa!(scope, inside_bound = lhs < array_len);
417 cpa!(scope, if(inside_bound).then(|scope| {
418 cpa!(scope, unchecked(out[lhs]) = rhs);
419 }));
420}
421
422impl From<Operator> for Operation {
423 fn from(val: Operator) -> Self {
424 Operation::Operator(val)
425 }
426}
427
428impl From<Branch> for Operation {
429 fn from(value: Branch) -> Self {
430 Self::Branch(value)
431 }
432}
433
434impl From<Branch> for Instruction {
435 fn from(value: Branch) -> Self {
436 Instruction {
437 out: None,
438 operation: value.into(),
439 }
440 }
441}
442
443impl From<Synchronization> for Operation {
444 fn from(value: Synchronization) -> Self {
445 Self::Synchronization(value)
446 }
447}
448
449impl From<Synchronization> for Instruction {
450 fn from(value: Synchronization) -> Self {
451 Instruction {
452 out: None,
453 operation: value.into(),
454 }
455 }
456}
457
458impl From<Metadata> for Operation {
459 fn from(val: Metadata) -> Self {
460 Operation::Metadata(val)
461 }
462}
463
464impl From<NonSemantic> for Operation {
465 fn from(val: NonSemantic) -> Self {
466 Operation::NonSemantic(val)
467 }
468}
469
470impl From<NonSemantic> for Instruction {
471 fn from(value: NonSemantic) -> Self {
472 Instruction {
473 out: None,
474 operation: value.into(),
475 }
476 }
477}