ptx_parser/parser/variable.rs

1use crate::{
2    alt, cclosure, err, mapc, ok,
3    parser::{
4        ParseErrorKind, PtxParseError, PtxParser, PtxTokenStream, Span,
5        util::{
6            between, comma_p, directive_exact_p, directive_p, equals_p, integer_p, lbrace_p,
7            lbracket_p, many, map, optional, parse_u32_literal, parse_u64_literal, rbrace_p,
8            rbracket_p, semicolon_p, sep_by, seq, skip_first, string_literal_p,
9            try_map,
10        },
11    },
12    seq_n,
13    r#type::{
14        AttributeDirective, DataType, FunctionSymbol, GlobalInitializer, Immediate,
15        InitializerValue, ModuleVariableDirective, ParamStateSpace, ParameterDirective,
16        VariableDirective, VariableModifier, VariableSymbol,
17    },
18};
19
20impl PtxParser for ModuleVariableDirective {
21    fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
22        let tex = mapc!(
23            skip_first(directive_exact_p("tex"), VariableDirective::parse()),
24            ModuleVariableDirective::Tex { directive }
25        );
26        let shared = mapc!(
27            skip_first(directive_exact_p("shared"), VariableDirective::parse()),
28            ModuleVariableDirective::Shared { directive }
29        );
30        let global = mapc!(
31            skip_first(directive_exact_p("global"), VariableDirective::parse()),
32            ModuleVariableDirective::Global { directive }
33        );
34        let konst = mapc!(
35            skip_first(directive_exact_p("const"), VariableDirective::parse()),
36            ModuleVariableDirective::Const { directive }
37        );
38        alt!(tex, shared, global, konst)
39    }
40}
41
42impl PtxParser for VariableDirective {
43    fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
44        mapc!(
45            seq_n!(
46                many(AttributeDirective::parse()),
47                many(VariableModifier::parse()),
48                DataType::parse(),
49                VariableSymbol::parse(),
50                array_dimensions_parser(),
51                optional(initializer_assignment()),
52                semicolon_p()
53            ),
54            VariableDirective {
55                attributes,
56                modifiers,
57                ty,
58                name,
59                array_dims,
60                initializer,
61                _
62            }
63        )
64    }
65}
66
67impl PtxParser for VariableModifier {
68    fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
69        let alignment = try_map(
70            skip_first(directive_exact_p("align"), integer_p()),
71            |value, span| {
72                let value = parse_u32_literal(&value, span)?;
73                ok!(VariableModifier::Alignment { value })
74            },
75        );
76        let ptr = map(
77            directive_exact_p("ptr"),
78            cclosure!(VariableModifier::Ptr {}),
79        );
80        let vector = try_map(directive_p(), |name, span| {
81            if let Some(width) = name.strip_prefix('v') {
82                if width.is_empty() {
83                    return err!(ParseErrorKind::InvalidLiteral(
84                        "vector modifier requires width (e.g. .v4)".into(),
85                    ));
86                }
87                let value = width.parse::<u32>().map_err(|_| PtxParseError {
88                    kind: ParseErrorKind::InvalidLiteral(format!("invalid vector width: {width}")),
89                    span,
90                })?;
91                ok!(VariableModifier::Vector { value })
92            } else {
93                err!(ParseErrorKind::InvalidLiteral(format!(
94                    "unknown variable modifier: .{name}"
95                )))
96            }
97        });
98
99        alt!(alignment, ptr, vector)
100    }
101}
102
103impl PtxParser for ParameterDirective {
104    fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
105        let register = mapc!(
106            seq_n!(
107                directive_exact_p("reg"),
108                DataType::parse(),
109                VariableSymbol::parse()
110            ),
111            ParameterDirective::Register { _, ty, name }
112        );
113        let param = mapc!(
114            skip_first(directive_exact_p("param"), parameter_spec_parser()),
115            ParameterDirective::Parameter {
116                align,
117                ptr,
118                space,
119                ty,
120                name,
121                array,
122            }
123        );
124
125        alt!(register, param)
126    }
127}
128
129impl PtxParser for InitializerValue {
130    fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
131        alt!(
132            mapc!(
133                Immediate::parse(),
134                InitializerValue::NumericLiteral { value }
135            ),
136            mapc!(
137                FunctionSymbol::parse(),
138                InitializerValue::FunctionSymbol { name }
139            ),
140            mapc!(
141                string_literal_p(),
142                InitializerValue::StringLiteral { value }
143            ),
144        )
145    }
146}
147
148impl PtxParser for GlobalInitializer {
149    fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
150        let aggregate = move |stream: &mut PtxTokenStream| {
151            let inner = GlobalInitializer::parse();
152            between(lbrace_p(), rbrace_p(), sep_by(inner, comma_p()))(stream)
153        };
154        alt!(
155            mapc!(aggregate, GlobalInitializer::Aggregate { values }),
156            mapc!(
157                InitializerValue::parse(),
158                GlobalInitializer::Scalar { value }
159            ),
160        )
161    }
162}
163
164fn initializer_assignment()
165-> impl Fn(&mut PtxTokenStream) -> Result<(GlobalInitializer, Span), PtxParseError> {
166    skip_first(equals_p(), GlobalInitializer::parse())
167}
168
169fn array_dimensions_parser()
170-> impl Fn(&mut PtxTokenStream) -> Result<(Vec<Option<u64>>, Span), PtxParseError> {
171    many(array_dimension_parser())
172}
173
174fn array_dimension_parser()
175-> impl Fn(&mut PtxTokenStream) -> Result<(Option<u64>, Span), PtxParseError> {
176    try_map(
177        between(lbracket_p(), rbracket_p(), optional(integer_p())),
178        |maybe_value, span| {
179            if let Some(value) = maybe_value {
180                let parsed = parse_u64_literal(&value, span)?;
181                Ok(Some(parsed))
182            } else {
183                Ok(None)
184            }
185        },
186    )
187}
188
189fn parse_alignment_modifier() -> impl Fn(&mut PtxTokenStream) -> Result<(u32, Span), PtxParseError>
190{
191    try_map(
192        seq(directive_exact_p("align"), integer_p()),
193        |(_, value), span| {
194            let amount = parse_u32_literal(&value, span)?;
195            Ok(amount)
196        },
197    )
198}
199
200fn param_modifier() -> impl Fn(
201    &mut PtxTokenStream,
202) -> Result<
203    ((Option<u32>, bool, Option<ParamStateSpace>), Span),
204    PtxParseError,
205> {
206    alt!(
207        map(parse_alignment_modifier(), |value, _| (
208            Some(value),
209            false,
210            None
211        )),
212        map(directive_exact_p("ptr"), |_, _| (None, true, None)),
213        map(ParamStateSpace::parse(), |space, _| (
214            None,
215            false,
216            Some(space)
217        ))
218    )
219}
220
221fn apply_param_mods(
222    mods: impl IntoIterator<Item = (Option<u32>, bool, Option<ParamStateSpace>)>,
223    align: &mut Option<u32>,
224    ptr: &mut bool,
225    space: &mut Option<ParamStateSpace>,
226) {
227    for (a, p, s) in mods {
228        if let Some(v) = a {
229            *align = Some(v);
230        }
231        if p {
232            *ptr = true;
233        }
234        if let Some(ss) = s {
235            *space = Some(ss);
236        }
237    }
238}
239
240fn parameter_spec_parser() -> impl Fn(
241    &mut PtxTokenStream,
242) -> Result<
243    (
244        (
245            Option<u32>,
246            bool,
247            Option<ParamStateSpace>,
248            DataType,
249            VariableSymbol,
250            Vec<Option<u64>>,
251        ),
252        Span,
253    ),
254    PtxParseError,
255> {
256    move |stream| {
257        stream.try_with_span(|stream| {
258            let mut align = None;
259            let mut ptr = false;
260            let mut space = None;
261
262            let (mods_before, _) = many(param_modifier())(stream)?;
263            apply_param_mods(mods_before, &mut align, &mut ptr, &mut space);
264
265            let (ty, _) = DataType::parse()(stream)?;
266
267            let (mods_after, _) = many(param_modifier())(stream)?;
268            apply_param_mods(mods_after, &mut align, &mut ptr, &mut space);
269
270            let (name, _) = VariableSymbol::parse()(stream)?;
271            let (array_dims, _) = map(array_dimensions_parser(), |dims, _| dims)(stream)?;
272
273            Ok((align, ptr, space, ty, name, array_dims))
274        })
275    }
276}