ptx_parser/parser/instruction/
wmma_store.rs

1//! Original PTX specification:
2//!
3//! wmma.store.d.sync.aligned.layout.shape{.ss}.type [p], r {, stride};
4//! .layout = {.row, .col};
5//! .shape  = {.m16n16k16, .m8n32k16, .m32n8k16};
6//! .ss     = {.global, .shared, .shared::cta};
7//! .type   = {.f16, .f32, .s32};
8//! ----------------------------------------------------------------
9//! wmma.store.d.sync.aligned.layout.shape{.ss}.type [p], r {, stride};
10//! .layout = {.row, .col};
11//! .shape  = {.m8n8k32, .m8n8k128};
12//! .ss     = {.global, .shared, .shared::cta};
13//! .type   = {.s32};
14//! ----------------------------------------------------------------
15//! wmma.store.d.sync.aligned.layout.shape{.ss}.type [p], r {, stride};
16//! .layout = {.row, .col};
17//! .shape  = {.m16n16k8};
18//! .ss     = {.global, .shared, .shared::cta};
19//! .type   = {.f32};
20//! ----------------------------------------------------------------
21//! wmma.store.d.sync.aligned.layout.shape{.ss}.type [p], r {, stride};
22//! .layout = {.row, .col};
23//! .shape  = {.m8n8k4 };
24//! .ss     = {.global, .shared, .shared::cta};
25//! .type   = {.f64};
26
27#![allow(unused)]
28
29use crate::parser::{
30    PtxParseError, PtxParser, PtxTokenStream, Span,
31    util::{
32        between, comma_p, directive_p, exclamation_p, lbracket_p, lparen_p, map, minus_p, optional,
33        pipe_p, rbracket_p, rparen_p, semicolon_p, sep_by, string_p, try_map,
34    },
35};
36use crate::r#type::common::*;
37use crate::{alt, ok, seq_n};
38
39pub mod section_0 {
40    use super::*;
41    use crate::r#type::instruction::wmma_store::section_0::*;
42
43    // ============================================================================
44    // Generated enum parsers
45    // ============================================================================
46
47    impl PtxParser for Layout {
48        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
49            alt!(
50                map(string_p(".row"), |_, _span| Layout::Row),
51                map(string_p(".col"), |_, _span| Layout::Col)
52            )
53        }
54    }
55
56    impl PtxParser for Shape {
57        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
58            alt!(
59                map(string_p(".m16n16k16"), |_, _span| Shape::M16n16k16),
60                map(string_p(".m8n32k16"), |_, _span| Shape::M8n32k16),
61                map(string_p(".m32n8k16"), |_, _span| Shape::M32n8k16)
62            )
63        }
64    }
65
66    impl PtxParser for Ss {
67        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
68            alt!(
69                map(string_p(".shared::cta"), |_, _span| Ss::SharedCta),
70                map(string_p(".global"), |_, _span| Ss::Global),
71                map(string_p(".shared"), |_, _span| Ss::Shared)
72            )
73        }
74    }
75
76    impl PtxParser for Type {
77        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
78            alt!(
79                map(string_p(".f16"), |_, _span| Type::F16),
80                map(string_p(".f32"), |_, _span| Type::F32),
81                map(string_p(".s32"), |_, _span| Type::S32)
82            )
83        }
84    }
85
86    impl PtxParser for WmmaStoreDSyncAlignedLayoutShapeSsType {
87        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
88            try_map(
89                seq_n!(
90                    string_p("wmma"),
91                    string_p(".store"),
92                    string_p(".d"),
93                    string_p(".sync"),
94                    string_p(".aligned"),
95                    Layout::parse(),
96                    Shape::parse(),
97                    optional(Ss::parse()),
98                    Type::parse(),
99                    AddressOperand::parse(),
100                    comma_p(),
101                    GeneralOperand::parse(),
102                    map(
103                        optional(seq_n!(comma_p(), GeneralOperand::parse())),
104                        |value, _| value.map(|(_, operand)| operand)
105                    ),
106                    semicolon_p()
107                ),
108                |(_, store, d, sync, aligned, layout, shape, ss, type_, p, _, r, stride, _),
109                 span| {
110                    ok!(WmmaStoreDSyncAlignedLayoutShapeSsType {
111                        store = store,
112                        d = d,
113                        sync = sync,
114                        aligned = aligned,
115                        layout = layout,
116                        shape = shape,
117                        ss = ss,
118                        type_ = type_,
119                        p = p,
120                        r = r,
121                        stride = stride,
122
123                    })
124                },
125            )
126        }
127    }
128}
129
130pub mod section_1 {
131    use super::*;
132    use crate::r#type::instruction::wmma_store::section_1::*;
133
134    // ============================================================================
135    // Generated enum parsers
136    // ============================================================================
137
138    impl PtxParser for Layout {
139        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
140            alt!(
141                map(string_p(".row"), |_, _span| Layout::Row),
142                map(string_p(".col"), |_, _span| Layout::Col)
143            )
144        }
145    }
146
147    impl PtxParser for Shape {
148        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
149            alt!(
150                map(string_p(".m8n8k128"), |_, _span| Shape::M8n8k128),
151                map(string_p(".m8n8k32"), |_, _span| Shape::M8n8k32)
152            )
153        }
154    }
155
156    impl PtxParser for Ss {
157        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
158            alt!(
159                map(string_p(".shared::cta"), |_, _span| Ss::SharedCta),
160                map(string_p(".global"), |_, _span| Ss::Global),
161                map(string_p(".shared"), |_, _span| Ss::Shared)
162            )
163        }
164    }
165
166    impl PtxParser for Type {
167        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
168            alt!(map(string_p(".s32"), |_, _span| Type::S32))
169        }
170    }
171
172    impl PtxParser for WmmaStoreDSyncAlignedLayoutShapeSsType1 {
173        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
174            try_map(
175                seq_n!(
176                    string_p("wmma"),
177                    string_p(".store"),
178                    string_p(".d"),
179                    string_p(".sync"),
180                    string_p(".aligned"),
181                    Layout::parse(),
182                    Shape::parse(),
183                    optional(Ss::parse()),
184                    Type::parse(),
185                    AddressOperand::parse(),
186                    comma_p(),
187                    GeneralOperand::parse(),
188                    map(
189                        optional(seq_n!(comma_p(), GeneralOperand::parse())),
190                        |value, _| value.map(|(_, operand)| operand)
191                    ),
192                    semicolon_p()
193                ),
194                |(_, store, d, sync, aligned, layout, shape, ss, type_, p, _, r, stride, _),
195                 span| {
196                    ok!(WmmaStoreDSyncAlignedLayoutShapeSsType1 {
197                        store = store,
198                        d = d,
199                        sync = sync,
200                        aligned = aligned,
201                        layout = layout,
202                        shape = shape,
203                        ss = ss,
204                        type_ = type_,
205                        p = p,
206                        r = r,
207                        stride = stride,
208
209                    })
210                },
211            )
212        }
213    }
214}
215
216pub mod section_2 {
217    use super::*;
218    use crate::r#type::instruction::wmma_store::section_2::*;
219
220    // ============================================================================
221    // Generated enum parsers
222    // ============================================================================
223
224    impl PtxParser for Layout {
225        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
226            alt!(
227                map(string_p(".row"), |_, _span| Layout::Row),
228                map(string_p(".col"), |_, _span| Layout::Col)
229            )
230        }
231    }
232
233    impl PtxParser for Shape {
234        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
235            alt!(map(string_p(".m16n16k8"), |_, _span| Shape::M16n16k8))
236        }
237    }
238
239    impl PtxParser for Ss {
240        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
241            alt!(
242                map(string_p(".shared::cta"), |_, _span| Ss::SharedCta),
243                map(string_p(".global"), |_, _span| Ss::Global),
244                map(string_p(".shared"), |_, _span| Ss::Shared)
245            )
246        }
247    }
248
249    impl PtxParser for Type {
250        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
251            alt!(map(string_p(".f32"), |_, _span| Type::F32))
252        }
253    }
254
255    impl PtxParser for WmmaStoreDSyncAlignedLayoutShapeSsType2 {
256        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
257            try_map(
258                seq_n!(
259                    string_p("wmma"),
260                    string_p(".store"),
261                    string_p(".d"),
262                    string_p(".sync"),
263                    string_p(".aligned"),
264                    Layout::parse(),
265                    Shape::parse(),
266                    optional(Ss::parse()),
267                    Type::parse(),
268                    AddressOperand::parse(),
269                    comma_p(),
270                    GeneralOperand::parse(),
271                    map(
272                        optional(seq_n!(comma_p(), GeneralOperand::parse())),
273                        |value, _| value.map(|(_, operand)| operand)
274                    ),
275                    semicolon_p()
276                ),
277                |(_, store, d, sync, aligned, layout, shape, ss, type_, p, _, r, stride, _),
278                 span| {
279                    ok!(WmmaStoreDSyncAlignedLayoutShapeSsType2 {
280                        store = store,
281                        d = d,
282                        sync = sync,
283                        aligned = aligned,
284                        layout = layout,
285                        shape = shape,
286                        ss = ss,
287                        type_ = type_,
288                        p = p,
289                        r = r,
290                        stride = stride,
291
292                    })
293                },
294            )
295        }
296    }
297}
298
299pub mod section_3 {
300    use super::*;
301    use crate::r#type::instruction::wmma_store::section_3::*;
302
303    // ============================================================================
304    // Generated enum parsers
305    // ============================================================================
306
307    impl PtxParser for Layout {
308        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
309            alt!(
310                map(string_p(".row"), |_, _span| Layout::Row),
311                map(string_p(".col"), |_, _span| Layout::Col)
312            )
313        }
314    }
315
316    impl PtxParser for Shape {
317        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
318            alt!(map(string_p(".m8n8k4"), |_, _span| Shape::M8n8k4))
319        }
320    }
321
322    impl PtxParser for Ss {
323        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
324            alt!(
325                map(string_p(".shared::cta"), |_, _span| Ss::SharedCta),
326                map(string_p(".global"), |_, _span| Ss::Global),
327                map(string_p(".shared"), |_, _span| Ss::Shared)
328            )
329        }
330    }
331
332    impl PtxParser for Type {
333        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
334            alt!(map(string_p(".f64"), |_, _span| Type::F64))
335        }
336    }
337
338    impl PtxParser for WmmaStoreDSyncAlignedLayoutShapeSsType3 {
339        fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
340            try_map(
341                seq_n!(
342                    string_p("wmma"),
343                    string_p(".store"),
344                    string_p(".d"),
345                    string_p(".sync"),
346                    string_p(".aligned"),
347                    Layout::parse(),
348                    Shape::parse(),
349                    optional(Ss::parse()),
350                    Type::parse(),
351                    AddressOperand::parse(),
352                    comma_p(),
353                    GeneralOperand::parse(),
354                    map(
355                        optional(seq_n!(comma_p(), GeneralOperand::parse())),
356                        |value, _| value.map(|(_, operand)| operand)
357                    ),
358                    semicolon_p()
359                ),
360                |(_, store, d, sync, aligned, layout, shape, ss, type_, p, _, r, stride, _),
361                 span| {
362                    ok!(WmmaStoreDSyncAlignedLayoutShapeSsType3 {
363                        store = store,
364                        d = d,
365                        sync = sync,
366                        aligned = aligned,
367                        layout = layout,
368                        shape = shape,
369                        ss = ss,
370                        type_ = type_,
371                        p = p,
372                        r = r,
373                        stride = stride,
374
375                    })
376                },
377            )
378        }
379    }
380}