1use core::fmt::Display;
2
3use super::{Branch, CoopMma, NonSemantic, Plane, Synchronization, Type, Variable};
4use crate::{
5    Arithmetic, AtomicOp, Bitwise, Metadata, OperationArgs, OperationReflect, Operator, TmaOps,
6    comparison::Comparison,
7};
8use crate::{BarrierOps, SourceLoc, TypeHash};
9use alloc::{
10    format,
11    string::{String, ToString},
12    vec::Vec,
13};
14use derive_more::derive::From;
15
/// An operation in the IR, grouped by operation family.
///
/// Most variants wrap a family-specific operation type and are tagged
/// `#[operation(nested)]` for the `OperationReflect` derive (presumably so reflection is
/// forwarded to the wrapped type — confirm against the derive). `Copy` is tagged
/// `#[operation(pure)]` instead, and `Free` carries no `#[operation]` attribute.
///
/// NOTE(review): variant order may feed the derived `Hash`/`TypeHash` and the generated
/// `OpCode` numbering; confirm before reordering variants.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, TypeHash, PartialEq, Eq, Hash, From, OperationReflect)]
#[operation(opcode_name = OpCode)]
// NOTE(review): `large_enum_variant` is allowed rather than boxing large variants;
// presumably a deliberate size vs. indirection trade-off — confirm.
#[allow(dead_code, missing_docs, clippy::large_enum_variant)]
pub enum Operation {
    /// Plain copy of a variable; its `Display` output is just the source variable.
    /// Excluded from the derived `From` conversions via `#[from(ignore)]`.
    #[operation(pure)]
    #[from(ignore)]
    Copy(Variable),
    #[operation(nested)]
    Arithmetic(Arithmetic),
    #[operation(nested)]
    Comparison(Comparison),
    #[operation(nested)]
    Bitwise(Bitwise),
    #[operation(nested)]
    Operator(Operator),
    #[operation(nested)]
    Atomic(AtomicOp),
    #[operation(nested)]
    Metadata(Metadata),
    #[operation(nested)]
    Branch(Branch),
    #[operation(nested)]
    Synchronization(Synchronization),
    #[operation(nested)]
    Plane(Plane),
    #[operation(nested)]
    CoopMma(CoopMma),
    #[operation(nested)]
    Barrier(BarrierOps),
    #[operation(nested)]
    Tma(TmaOps),
    #[operation(nested)]
    NonSemantic(NonSemantic),
    /// Rendered as `free(<var>)`; presumably releases the given variable — confirm
    /// against the backends that consume it.
    Free(Variable),
}
62
/// A single IR instruction: an [`Operation`] plus its optional output variable and
/// optional source location.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, PartialEq, Eq, Hash, TypeHash)]
pub struct Instruction {
    /// Output variable of the instruction; `None` for instructions built with `no_out`.
    /// For memory copies and index assignments this is the destination container
    /// (that is how `Display` renders them).
    pub out: Option<Variable>,
    /// Optional source location attached to the instruction (presumably the user-code
    /// location it was generated from); `new` and `no_out` leave it `None`.
    pub source_loc: Option<SourceLoc>,
    /// The operation this instruction performs.
    pub operation: Operation,
}
71
72impl Instruction {
73    pub fn new(operation: impl Into<Operation>, out: Variable) -> Self {
74        Instruction {
75            out: Some(out),
76            operation: operation.into(),
77            source_loc: None,
78        }
79    }
80
81    pub fn no_out(operation: impl Into<Operation>) -> Self {
82        Instruction {
83            out: None,
84            operation: operation.into(),
85            source_loc: None,
86        }
87    }
88
89    pub fn out(&self) -> Variable {
90        self.out.unwrap()
91    }
92
93    pub fn ty(&self) -> Type {
94        self.out().ty
95    }
96}
97
98impl Display for Instruction {
99    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
100        match &self.operation {
101            Operation::Operator(Operator::CopyMemory(op)) => write!(
102                f,
103                "copy_mem({}[{}], {}[{}])",
104                self.out(),
105                op.out_index,
106                op.input,
107                op.in_index
108            ),
109            Operation::Operator(Operator::CopyMemoryBulk(op)) => write!(
110                f,
111                "copy_mem_bulk({}[{}], {}[{}], {})",
112                self.out(),
113                op.out_index,
114                op.input,
115                op.in_index,
116                op.len
117            ),
118            Operation::Operator(Operator::IndexAssign(op)) => {
119                write!(
120                    f,
121                    "{}[{}] = {}  : ({}, {}) -> ({})",
122                    self.out(),
123                    op.index,
124                    op.value,
125                    op.index.ty,
126                    op.value.ty,
127                    self.out().ty,
128                )
129            }
130            Operation::Operator(Operator::UncheckedIndexAssign(op)) => {
131                write!(
132                    f,
133                    "unchecked {}[{}] = {} : ({}, {}) -> ({})",
134                    self.out(),
135                    op.index,
136                    op.value,
137                    op.index.ty,
138                    op.value.ty,
139                    self.out().ty,
140                )
141            }
142            Operation::Operator(Operator::Cast(op)) => {
143                write!(
144                    f,
145                    "{} = cast<{}>({}) : ({}) -> ({})",
146                    self.out(),
147                    self.ty(),
148                    op.input,
149                    op.input.ty,
150                    self.out().ty,
151                )
152            }
153            Operation::Operator(Operator::Reinterpret(op)) => {
154                write!(f, "{} = bitcast<{}>({})", self.out(), self.ty(), op.input)
155            }
156            _ => {
157                if let Some(out) = self.out {
158                    let mut vars_str = String::new();
159                    for (i, var) in self.operation.args().unwrap_or_default().iter().enumerate() {
160                        if i != 0 {
161                            vars_str.push_str(", ");
162                        }
163                        vars_str.push_str(&var.ty.to_string());
164                    }
165                    write!(
166                        f,
167                        "{out} = {} : ({}) -> ({})",
168                        self.operation, vars_str, out.ty
169                    )
170                } else {
171                    write!(f, "{}", self.operation)
172                }
173            }
174        }
175    }
176}
177
178impl Display for Operation {
179    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
180        match self {
181            Operation::Arithmetic(arithmetic) => write!(f, "{arithmetic}"),
182            Operation::Comparison(comparison) => write!(f, "{comparison}"),
183            Operation::Bitwise(bitwise) => write!(f, "{bitwise}"),
184            Operation::Operator(operator) => write!(f, "{operator}"),
185            Operation::Atomic(atomic) => write!(f, "{atomic}"),
186            Operation::Metadata(metadata) => write!(f, "{metadata}"),
187            Operation::Branch(branch) => write!(f, "{branch}"),
188            Operation::Synchronization(synchronization) => write!(f, "{synchronization}"),
189            Operation::Plane(plane) => write!(f, "{plane}"),
190            Operation::CoopMma(coop_mma) => write!(f, "{coop_mma}"),
191            Operation::Copy(variable) => write!(f, "{variable}"),
192            Operation::NonSemantic(non_semantic) => write!(f, "{non_semantic}"),
193            Operation::Barrier(barrier_ops) => write!(f, "{barrier_ops}"),
194            Operation::Tma(tma_ops) => write!(f, "{tma_ops}"),
195            Operation::Free(var) => write!(f, "free({var})"),
196        }
197    }
198}
199
/// Formats trailing arguments as `", a, b, c"`, rendering each with `Display`.
///
/// Returns an empty string when `args` is empty. Otherwise the result starts with `", "`,
/// so it can be appended directly after a leading argument.
pub fn fmt_vararg(args: &[impl Display]) -> String {
    let rendered: Vec<String> = args.iter().map(ToString::to_string).collect();
    match rendered.as_slice() {
        [] => String::new(),
        parts => format!(", {}", parts.join(", ")),
    }
}
212
/// Operands of an indexing operation (presumably a read of `list[index]` — confirm
/// against the `Operator` variants that carry it).
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, TypeHash, PartialEq, Eq, Hash, OperationArgs)]
#[allow(missing_docs)]
pub struct IndexOperator {
    /// Container being indexed.
    pub list: Variable,
    /// Index into `list`.
    pub index: Variable,
    /// Line (vector) width of the access. NOTE(review): the meaning of special values
    /// such as 0 is not visible here — confirm against its users.
    pub line_size: u32,
    /// NOTE(review): presumably an adjustment factor used when bounds-checking unrolled
    /// accesses — confirm against its users.
    pub unroll_factor: u32,
}
222
/// Operands of an index assignment.
///
/// The container being written is not a field. Index assignments are rendered by
/// `Display for Instruction` as `out[index] = value`, so the target appears to be the
/// owning instruction's `out` variable (presumably this struct is the payload of
/// `Operator::IndexAssign` — confirm).
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, TypeHash, PartialEq, Eq, Hash, OperationArgs)]
#[allow(missing_docs)]
pub struct IndexAssignOperator {
    /// Index being written.
    pub index: Variable,
    /// Value stored at `index`.
    pub value: Variable,
    /// Line (vector) width of the access. NOTE(review): the meaning of special values
    /// such as 0 is not visible here — confirm against its users.
    pub line_size: u32,
    /// NOTE(review): presumably an adjustment factor used when bounds-checking unrolled
    /// accesses — confirm against its users.
    pub unroll_factor: u32,
}
233
/// Operands of a two-input operation.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, TypeHash, PartialEq, Eq, Hash, OperationArgs)]
#[allow(missing_docs)]
pub struct BinaryOperator {
    /// Left-hand operand.
    pub lhs: Variable,
    /// Right-hand operand.
    pub rhs: Variable,
}
241
/// Operand of a single-input operation.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, TypeHash, PartialEq, Eq, Hash, OperationArgs)]
#[allow(missing_docs)]
pub struct UnaryOperator {
    /// The single input variable.
    pub input: Variable,
}
248
249impl From<Branch> for Instruction {
250    fn from(value: Branch) -> Self {
251        Instruction::no_out(value)
252    }
253}
254
255impl From<Synchronization> for Instruction {
256    fn from(value: Synchronization) -> Self {
257        Instruction::no_out(value)
258    }
259}
260
261impl From<NonSemantic> for Instruction {
262    fn from(value: NonSemantic) -> Self {
263        Instruction::no_out(value)
264    }
265}