Skip to main content

ptx_parser/parser/
module.rs

1use crate::{
2    alt, err, func,
3    lexer::PtxToken,
4    mapc, ok,
5    parser::{
6        ParseErrorKind, PtxParseError, PtxParser, PtxTokenStream, Span,
7        util::{
8            comma_p, directive_exact_p, identifier_p, many, optional, parse_u32_literal, sep_by,
9            seq, skip_first, skip_semicolon, string_literal_p, try_map, u32_p, u64_p,
10        },
11    },
12    seq_n,
13    r#type::{
14        AliasFunctionDirective, CodeLinkage, DataLinkage, DwarfDirective, EntryFunctionDirective,
15        FuncFunctionDirective, SectionDirective, module::*, variable::ModuleVariableDirective,
16    },
17};
18
19impl PtxParser for Module {
20    fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
21        mapc!(many(ModuleDirective::parse()), Module { directives })
22    }
23}
24
/// Parses a single top-level item of a PTX module.
///
/// The candidates are handed to `alt!` in this order: module variable, entry function,
/// func function, alias, module-info directive, module-debug directive.
/// NOTE(review): `alt!` presumably tries them first-to-last, so the order may matter where
/// prefixes overlap (e.g. a linkage keyword that both `DataLinkage` and `CodeLinkage`
/// could accept) — confirm before reordering.
impl PtxParser for ModuleDirective {
    fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
        alt!(
            parse_module_variable(),
            parse_entry_function(),
            parse_func_function(),
            parse_alias_function(),
            parse_module_info(),
            parse_module_debug()
        )
    }
}
37
/// Module-scope variable declaration, optionally preceded by a [`DataLinkage`].
///
/// Yields `ModuleDirective::ModuleVariable { linkage, directive }`; `linkage` comes from
/// `optional(...)`, so presumably it is `None` when no linkage keyword is present.
fn parse_module_variable()
-> impl Fn(&mut PtxTokenStream) -> Result<(ModuleDirective, Span), PtxParseError> {
    mapc!(
        seq(
            optional(DataLinkage::parse()),
            ModuleVariableDirective::parse(),
        ),
        ModuleDirective::ModuleVariable { linkage, directive }
    )
}
48
/// Entry-function directive ([`EntryFunctionDirective`]), optionally preceded by a
/// [`CodeLinkage`].
///
/// Yields `ModuleDirective::EntryFunction { linkage, directive }`; `linkage` comes from
/// `optional(...)`, so presumably it is `None` when no linkage keyword is present.
fn parse_entry_function()
-> impl Fn(&mut PtxTokenStream) -> Result<(ModuleDirective, Span), PtxParseError> {
    mapc!(
        seq(
            optional(CodeLinkage::parse()),
            EntryFunctionDirective::parse(),
        ),
        ModuleDirective::EntryFunction { linkage, directive }
    )
}
59
/// Func-function directive ([`FuncFunctionDirective`]), optionally preceded by a
/// [`CodeLinkage`].
///
/// Yields `ModuleDirective::FuncFunction { linkage, directive }`; `linkage` comes from
/// `optional(...)`, so presumably it is `None` when no linkage keyword is present.
fn parse_func_function()
-> impl Fn(&mut PtxTokenStream) -> Result<(ModuleDirective, Span), PtxParseError> {
    mapc!(
        seq(
            optional(CodeLinkage::parse()),
            FuncFunctionDirective::parse(),
        ),
        ModuleDirective::FuncFunction { linkage, directive }
    )
}
70
/// Alias-function directive, wrapped as `ModuleDirective::AliasFunction`.
///
/// Unlike the entry/func parsers, no linkage prefix is parsed here.
fn parse_alias_function()
-> impl Fn(&mut PtxTokenStream) -> Result<(ModuleDirective, Span), PtxParseError> {
    mapc!(
        AliasFunctionDirective::parse(),
        ModuleDirective::AliasFunction { directive }
    )
}
78
/// Module-info directive (version / target / address size), wrapped as
/// `ModuleDirective::ModuleInfo`.
fn parse_module_info()
-> impl Fn(&mut PtxTokenStream) -> Result<(ModuleDirective, Span), PtxParseError> {
    mapc!(
        ModuleInfoDirectiveKind::parse(),
        ModuleDirective::ModuleInfo { directive }
    )
}
86
/// Module-level debug directive (file / section / DWARF), wrapped as
/// `ModuleDirective::Debug`.
fn parse_module_debug()
-> impl Fn(&mut PtxTokenStream) -> Result<(ModuleDirective, Span), PtxParseError> {
    mapc!(
        ModuleDebugDirective::parse(),
        ModuleDirective::Debug { directive }
    )
}
94
95impl PtxParser for ModuleInfoDirectiveKind {
96    fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
97        alt!(
98            mapc!(
99                VersionDirective::parse(),
100                ModuleInfoDirectiveKind::Version { directive }
101            ),
102            mapc!(
103                TargetDirective::parse(),
104                ModuleInfoDirectiveKind::Target { directive }
105            ),
106            mapc!(
107                AddressSizeDirective::parse(),
108                ModuleInfoDirectiveKind::AddressSize { directive }
109            )
110        )
111    }
112}
113
114impl PtxParser for VersionDirective {
115    fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
116        try_map(
117            skip_first(directive_exact_p("version"), version_number_p()),
118            func!(|(major, minor)| { ok!(VersionDirective { major, minor }) }),
119        )
120    }
121}
122
123impl PtxParser for TargetDirective {
124    fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
125        mapc!(
126            skip_first(
127                directive_exact_p("target"),
128                sep_by(TargetString::parse(), comma_p()),
129            ),
130            TargetDirective { entries }
131        )
132    }
133}
134
135impl PtxParser for TargetString {
136    fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
137        // Parse target specifiers like "sm_80", "texmode_unified", etc.
138        try_map(
139            identifier_p(),
140            func!(|name| match name.as_str() {
141                "sm_120a" => ok!(TargetString::Sm120a),
142                "sm_120f" => ok!(TargetString::Sm120f),
143                "sm_120" => ok!(TargetString::Sm120),
144                "sm_121a" => ok!(TargetString::Sm121a),
145                "sm_121f" => ok!(TargetString::Sm121f),
146                "sm_121" => ok!(TargetString::Sm121),
147                "sm_110a" => ok!(TargetString::Sm110a),
148                "sm_110f" => ok!(TargetString::Sm110f),
149                "sm_110" => ok!(TargetString::Sm110),
150                "sm_100a" => ok!(TargetString::Sm100a),
151                "sm_100f" => ok!(TargetString::Sm100f),
152                "sm_100" => ok!(TargetString::Sm100),
153                "sm_101a" => ok!(TargetString::Sm101a),
154                "sm_101f" => ok!(TargetString::Sm101f),
155                "sm_101" => ok!(TargetString::Sm101),
156                "sm_103a" => ok!(TargetString::Sm103a),
157                "sm_103f" => ok!(TargetString::Sm103f),
158                "sm_103" => ok!(TargetString::Sm103),
159                "sm_90a" => ok!(TargetString::Sm90a),
160                "sm_90" => ok!(TargetString::Sm90),
161                "sm_80" => ok!(TargetString::Sm80),
162                "sm_86" => ok!(TargetString::Sm86),
163                "sm_87" => ok!(TargetString::Sm87),
164                "sm_88" => ok!(TargetString::Sm88),
165                "sm_89" => ok!(TargetString::Sm89),
166                "sm_70" => ok!(TargetString::Sm70),
167                "sm_72" => ok!(TargetString::Sm72),
168                "sm_75" => ok!(TargetString::Sm75),
169                "sm_60" => ok!(TargetString::Sm60),
170                "sm_61" => ok!(TargetString::Sm61),
171                "sm_62" => ok!(TargetString::Sm62),
172                "sm_50" => ok!(TargetString::Sm50),
173                "sm_52" => ok!(TargetString::Sm52),
174                "sm_53" => ok!(TargetString::Sm53),
175                "sm_30" => ok!(TargetString::Sm30),
176                "sm_32" => ok!(TargetString::Sm32),
177                "sm_35" => ok!(TargetString::Sm35),
178                "sm_37" => ok!(TargetString::Sm37),
179                "sm_20" => ok!(TargetString::Sm20),
180                "sm_10" => ok!(TargetString::Sm10),
181                "sm_11" => ok!(TargetString::Sm11),
182                "sm_12" => ok!(TargetString::Sm12),
183                "sm_13" => ok!(TargetString::Sm13),
184                "texmode_unified" => ok!(TargetString::TexmodeUnified),
185                "texmode_independent" => ok!(TargetString::TexmodeIndependent),
186                "debug" => ok!(TargetString::Debug),
187                "map_f64_to_f32" => ok!(TargetString::MapF64ToF32),
188                _ => err!(ParseErrorKind::InvalidLiteral(format!(
189                    "unknown target specifier: {}",
190                    name
191                ))),
192            }),
193        )
194    }
195}
196
197impl PtxParser for AddressSizeDirective {
198    fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
199        mapc!(
200            skip_first(directive_exact_p("address_size"), AddressSize::parse()),
201            AddressSizeDirective { size }
202        )
203    }
204}
205
206impl PtxParser for ModuleDebugDirective {
207    fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
208        alt!(
209            mapc!(
210                FileDirective::parse(),
211                ModuleDebugDirective::File { directive }
212            ),
213            mapc!(
214                SectionDirective::parse(),
215                ModuleDebugDirective::Section { directive }
216            ),
217            mapc!(
218                skip_semicolon(DwarfDirective::parse()),
219                ModuleDebugDirective::Dwarf { directive }
220            )
221        )
222    }
223}
224
225impl PtxParser for FileDirective {
226    fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
227        try_map(
228            skip_first(
229                directive_exact_p("file"),
230                seq_n!(
231                    u32_p(),
232                    string_literal_p(),
233                    optional(skip_first(
234                        comma_p(),
235                        seq(u64_p(), skip_first(comma_p(), u64_p())),
236                    )),
237                ),
238            ),
239            |(index, path, maybe_timestamps), span| {
240                let (timestamp, file_size) = if let Some((ts, size)) = maybe_timestamps {
241                    (Some(ts), Some(size))
242                } else {
243                    (None, None)
244                };
245                ok!(FileDirective {
246                    index,
247                    path,
248                    timestamp,
249                    file_size,
250                })
251            },
252        )
253    }
254}
255
256impl PtxParser for AddressSize {
257    fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
258        try_map(
259            u32_p(),
260            func!(|value| match value {
261                32 => ok!(AddressSize::Size32),
262                64 => ok!(AddressSize::Size64),
263                other => err!(ParseErrorKind::InvalidLiteral(format!(
264                    "invalid address size: {} (expected 32 or 64)",
265                    other
266                ))),
267            }),
268        )
269    }
270}
271
272/// Parser for version numbers - handles both Float("8.5") and separate tokens (8 . 5)
273fn version_number_p() -> impl Fn(&mut PtxTokenStream) -> Result<((u32, u32), Span), PtxParseError> {
274    |stream| {
275        let start_pos = stream.position().0;
276
277        // Try to parse as float first
278        if let Ok((token, span)) = stream.peek() {
279            if let PtxToken::Float(f) = token {
280                let version_str = f.clone();
281                stream.consume()?;
282                let end_pos = stream.position().0;
283                let full_span = Span::new(start_pos, end_pos);
284                let parts: Vec<&str> = version_str.split('.').collect();
285                let span = span.clone();
286                if parts.len() != 2 {
287                    return err!(ParseErrorKind::InvalidLiteral(format!(
288                        "expected version in format X.Y, got {}",
289                        version_str
290                    )));
291                }
292                let major = parse_u32_literal(parts[0], span)?;
293                let minor = parse_u32_literal(parts[1], span)?;
294                return Ok(((major, minor), full_span));
295            }
296        }
297
298        // Otherwise parse as integer.integer
299        let (major, _) = u32_p()(stream)?;
300        stream.expect(&PtxToken::Dot)?;
301        let (minor, _) = u32_p()(stream)?;
302
303        let end_pos = stream.position().0;
304        let span = Span::new(start_pos, end_pos);
305        Ok(((major, minor), span))
306    }
307}