// ptx_parser/parser/variable.rs

1use crate::{
2    alt, cclosure, err, mapc, ok,
3    parser::{
4        ParseErrorKind, PtxParseError, PtxParser, PtxTokenStream, Span,
5        util::{
6            between, comma_p, directive_exact_p, directive_p, equals_p, lbrace_p, lbracket_p, many,
7            map, optional, rbrace_p, rbracket_p, semicolon_p, sep_by, seq, skip_first,
8            string_literal_p, try_map, u32_p, u64_p,
9        },
10    },
11    seq_n,
12    r#type::{
13        AttributeDirective, DataType, FunctionSymbol, GlobalInitializer, Immediate,
14        InitializerValue, ModuleVariableDirective, ParamStateSpace, ParameterDirective,
15        VariableDirective, VariableModifier, VariableSymbol,
16    },
17};
18
19impl PtxParser for ModuleVariableDirective {
20    fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
21        let tex = mapc!(
22            skip_first(directive_exact_p("tex"), VariableDirective::parse()),
23            ModuleVariableDirective::Tex { directive }
24        );
25        let shared = mapc!(
26            skip_first(directive_exact_p("shared"), VariableDirective::parse()),
27            ModuleVariableDirective::Shared { directive }
28        );
29        let global = mapc!(
30            skip_first(directive_exact_p("global"), VariableDirective::parse()),
31            ModuleVariableDirective::Global { directive }
32        );
33        let konst = mapc!(
34            skip_first(directive_exact_p("const"), VariableDirective::parse()),
35            ModuleVariableDirective::Const { directive }
36        );
37        alt!(tex, shared, global, konst)
38    }
39}
40
41impl PtxParser for VariableDirective {
42    fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
43        mapc!(
44            seq_n!(
45                many(AttributeDirective::parse()),
46                many(VariableModifier::parse()),
47                DataType::parse(),
48                VariableSymbol::parse(),
49                array_dimensions_parser(),
50                optional(initializer_assignment()),
51                semicolon_p()
52            ),
53            VariableDirective {
54                attributes,
55                modifiers,
56                ty,
57                name,
58                array_dims,
59                initializer,
60                _
61            }
62        )
63    }
64}
65
66impl PtxParser for VariableModifier {
67    fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
68        let alignment = try_map(
69            skip_first(directive_exact_p("align"), u32_p()),
70            |value, span| ok!(VariableModifier::Alignment { value }),
71        );
72        let ptr = map(
73            directive_exact_p("ptr"),
74            cclosure!(VariableModifier::Ptr {}),
75        );
76        let vector = try_map(directive_p(), |name, span| {
77            if let Some(width) = name.strip_prefix('v') {
78                if width.is_empty() {
79                    return err!(ParseErrorKind::InvalidLiteral(
80                        "vector modifier requires width (e.g. .v4)".into(),
81                    ));
82                }
83                let value = width.parse::<u32>().map_err(|_| PtxParseError {
84                    kind: ParseErrorKind::InvalidLiteral(format!("invalid vector width: {width}")),
85                    span,
86                })?;
87                ok!(VariableModifier::Vector { value })
88            } else {
89                err!(ParseErrorKind::InvalidLiteral(format!(
90                    "unknown variable modifier: .{name}"
91                )))
92            }
93        });
94
95        alt!(alignment, ptr, vector)
96    }
97}
98
99impl PtxParser for ParameterDirective {
100    fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
101        let register = mapc!(
102            seq_n!(
103                directive_exact_p("reg"),
104                DataType::parse(),
105                VariableSymbol::parse()
106            ),
107            ParameterDirective::Register { _, ty, name }
108        );
109        let param = mapc!(
110            skip_first(directive_exact_p("param"), parameter_spec_parser()),
111            ParameterDirective::Parameter {
112                align,
113                ptr,
114                space,
115                ty,
116                name,
117                array,
118            }
119        );
120
121        alt!(register, param)
122    }
123}
124
125impl PtxParser for InitializerValue {
126    fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
127        alt!(
128            mapc!(
129                Immediate::parse(),
130                InitializerValue::NumericLiteral { value }
131            ),
132            mapc!(
133                FunctionSymbol::parse(),
134                InitializerValue::FunctionSymbol { name }
135            ),
136            mapc!(
137                string_literal_p(),
138                InitializerValue::StringLiteral { value }
139            ),
140        )
141    }
142}
143
144impl PtxParser for GlobalInitializer {
145    fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
146        let aggregate = move |stream: &mut PtxTokenStream| {
147            let inner = GlobalInitializer::parse();
148            between(lbrace_p(), rbrace_p(), sep_by(inner, comma_p()))(stream)
149        };
150        alt!(
151            mapc!(aggregate, GlobalInitializer::Aggregate { values }),
152            mapc!(
153                InitializerValue::parse(),
154                GlobalInitializer::Scalar { value }
155            ),
156        )
157    }
158}
159
160fn initializer_assignment()
161-> impl Fn(&mut PtxTokenStream) -> Result<(GlobalInitializer, Span), PtxParseError> {
162    skip_first(equals_p(), GlobalInitializer::parse())
163}
164
165fn array_dimensions_parser()
166-> impl Fn(&mut PtxTokenStream) -> Result<(Vec<Option<u64>>, Span), PtxParseError> {
167    many(array_dimension_parser())
168}
169
170fn array_dimension_parser()
171-> impl Fn(&mut PtxTokenStream) -> Result<(Option<u64>, Span), PtxParseError> {
172    try_map(
173        between(lbracket_p(), rbracket_p(), optional(u64_p())),
174        |maybe_value, _span| Ok(maybe_value),
175    )
176}
177
178fn parse_alignment_modifier() -> impl Fn(&mut PtxTokenStream) -> Result<(u32, Span), PtxParseError>
179{
180    try_map(
181        seq(directive_exact_p("align"), u32_p()),
182        |(_, value), _span| Ok(value),
183    )
184}
185
186fn param_modifier() -> impl Fn(
187    &mut PtxTokenStream,
188) -> Result<
189    ((Option<u32>, bool, Option<ParamStateSpace>), Span),
190    PtxParseError,
191> {
192    alt!(
193        map(parse_alignment_modifier(), |value, _| (
194            Some(value),
195            false,
196            None
197        )),
198        map(directive_exact_p("ptr"), |_, _| (None, true, None)),
199        map(ParamStateSpace::parse(), |space, _| (
200            None,
201            false,
202            Some(space)
203        ))
204    )
205}
206
207fn apply_param_mods(
208    mods: impl IntoIterator<Item = (Option<u32>, bool, Option<ParamStateSpace>)>,
209    align: &mut Option<u32>,
210    ptr: &mut bool,
211    space: &mut Option<ParamStateSpace>,
212) {
213    for (a, p, s) in mods {
214        if let Some(v) = a {
215            *align = Some(v);
216        }
217        if p {
218            *ptr = true;
219        }
220        if let Some(ss) = s {
221            *space = Some(ss);
222        }
223    }
224}
225
226fn parameter_spec_parser() -> impl Fn(
227    &mut PtxTokenStream,
228) -> Result<
229    (
230        (
231            Option<u32>,
232            bool,
233            Option<ParamStateSpace>,
234            DataType,
235            VariableSymbol,
236            Vec<Option<u64>>,
237        ),
238        Span,
239    ),
240    PtxParseError,
241> {
242    move |stream| {
243        stream.try_with_span(|stream| {
244            let mut align = None;
245            let mut ptr = false;
246            let mut space = None;
247
248            let (mods_before, _) = many(param_modifier())(stream)?;
249            apply_param_mods(mods_before, &mut align, &mut ptr, &mut space);
250
251            let (ty, _) = DataType::parse()(stream)?;
252
253            let (mods_after, _) = many(param_modifier())(stream)?;
254            apply_param_mods(mods_after, &mut align, &mut ptr, &mut space);
255
256            let (name, _) = VariableSymbol::parse()(stream)?;
257            let (array_dims, _) = map(array_dimensions_parser(), |dims, _| dims)(stream)?;
258
259            Ok((align, ptr, space, ty, name, array_dims))
260        })
261    }
262}