ptx_parser/parser/
variable.rs

1use crate::{
2    lexer::PtxToken,
3    parser::{
4        PtxParseError, PtxParser, PtxTokenStream, Span, common::parse_u64_literal, invalid_literal,
5        peek_directive, unexpected_value,
6    },
7    r#type::{
8        common::{AddressSpace, AttributeDirective, DataLinkage, DataType},
9        variable::{
10            GlobalInitializer, InitializerValue, ModuleVariableDirective, NumericLiteral,
11            VariableDirective, VariableModifier,
12        },
13    },
14};
15
/// Directive names (without the leading `.`) recognised as a variable's data type qualifier.
const DATA_TYPE_NAMES: &[&str] = &[
    // Unsigned integers.
    "u8", "u16", "u32", "u64",
    // Signed integers.
    "s8", "s16", "s32", "s64",
    // Floating point (including the packed half pair).
    "f16", "f16x2", "f32", "f64",
    // Untyped bit containers.
    "b8", "b16", "b32", "b64", "b128",
    // Predicate.
    "pred",
];
20
/// Classification of a parsed variable directive, used to pick the module-level variant.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum VariableDirectiveKind {
    /// Declared with `.tex`; this takes precedence over any state-space qualifier.
    Tex,
    /// Declared in the `.shared` state space.
    Shared,
    /// Declared in the `.global` state space.
    Global,
    /// Declared in the `.const` state space.
    Const,
    /// Anything else (e.g. `.local`, `.param`, `.reg`, or no state space at all).
    Other,
}
29
30struct ParsedVariableDirective {
31    directive: VariableDirective,
32    kind: VariableDirectiveKind,
33    leading_span: Option<Span>,
34}
35
36fn is_data_type_directive(name: &str) -> bool {
37    DATA_TYPE_NAMES.iter().any(|candidate| candidate == &name)
38}
39
/// Returns `true` when `name` has the shape of a vector modifier: a `v` followed by one or
/// more ASCII digits (e.g. `v2`, `v4`). The width value itself is not validated here.
fn is_vector_modifier(name: &str) -> bool {
    match name.strip_prefix('v') {
        Some(width) => !width.is_empty() && width.bytes().all(|byte| byte.is_ascii_digit()),
        None => false,
    }
}
47
48fn parse_alignment_value(stream: &mut PtxTokenStream) -> Result<u32, PtxParseError> {
49    let (value, value_span) = parse_u64_literal(stream)?;
50    if value > u32::MAX as u64 {
51        return Err(invalid_literal(
52            value_span,
53            "alignment value exceeds u32 range",
54        ));
55    }
56    Ok(value as u32)
57}
58
59fn parse_numeric_string(text: &str, span: Span) -> Result<u128, PtxParseError> {
60    text.parse::<u128>()
61        .map_err(|_| invalid_literal(span, "invalid integer literal"))
62}
63
64impl PtxParser for NumericLiteral {
65    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
66        let negative = stream
67            .consume_if(|token| matches!(token, PtxToken::Minus))
68            .is_some();
69        let positive = stream
70            .consume_if(|token| matches!(token, PtxToken::Plus))
71            .is_some();
72
73        if negative && positive {
74            let (_, span) = stream.peek()?;
75            return Err(invalid_literal(
76                span.clone(),
77                "cannot have both '+' and '-' signs",
78            ));
79        }
80
81        let (token, span_ref) = stream.consume()?;
82        let span = span_ref.clone();
83        match token {
84            PtxToken::DecimalInteger(text) => {
85                let value = parse_numeric_string(text.as_str(), span.clone())?;
86                if negative {
87                    if value > (i64::MAX as u128) + 1 {
88                        return Err(invalid_literal(span.clone(), "signed integer underflow"));
89                    }
90                    let signed = -(value as i128);
91                    Ok(NumericLiteral::Signed(signed as i64))
92                } else {
93                    if value > u64::MAX as u128 {
94                        return Err(invalid_literal(span.clone(), "unsigned integer overflow"));
95                    }
96                    Ok(NumericLiteral::Unsigned(value as u64))
97                }
98            }
99            PtxToken::HexInteger(text) => {
100                let stripped = text
101                    .strip_prefix("0x")
102                    .or_else(|| text.strip_prefix("0X"))
103                    .unwrap_or(text.as_str());
104                let value = u128::from_str_radix(stripped, 16)
105                    .map_err(|_| invalid_literal(span.clone(), "invalid hex literal"))?;
106                if negative {
107                    if value > (i64::MAX as u128) + 1 {
108                        return Err(invalid_literal(span.clone(), "signed integer underflow"));
109                    }
110                    let signed = -(value as i128);
111                    Ok(NumericLiteral::Signed(signed as i64))
112                } else {
113                    if value > u64::MAX as u128 {
114                        return Err(invalid_literal(span.clone(), "unsigned integer overflow"));
115                    }
116                    Ok(NumericLiteral::Unsigned(value as u64))
117                }
118            }
119            PtxToken::BinaryInteger(text) => {
120                let stripped = text
121                    .strip_prefix("0b")
122                    .or_else(|| text.strip_prefix("0B"))
123                    .unwrap_or(text.as_str());
124                let value = u128::from_str_radix(stripped, 2)
125                    .map_err(|_| invalid_literal(span.clone(), "invalid binary literal"))?;
126                if negative {
127                    if value > (i64::MAX as u128) + 1 {
128                        return Err(invalid_literal(span.clone(), "signed integer underflow"));
129                    }
130                    let signed = -(value as i128);
131                    Ok(NumericLiteral::Signed(signed as i64))
132                } else {
133                    if value > u64::MAX as u128 {
134                        return Err(invalid_literal(span.clone(), "unsigned integer overflow"));
135                    }
136                    Ok(NumericLiteral::Unsigned(value as u64))
137                }
138            }
139            PtxToken::OctalInteger(text) => {
140                let stripped = &text.as_str()[1..];
141                let value = u128::from_str_radix(stripped, 8)
142                    .map_err(|_| invalid_literal(span.clone(), "invalid octal literal"))?;
143                if negative {
144                    if value > (i64::MAX as u128) + 1 {
145                        return Err(invalid_literal(span.clone(), "signed integer underflow"));
146                    }
147                    let signed = -(value as i128);
148                    Ok(NumericLiteral::Signed(signed as i64))
149                } else {
150                    if value > u64::MAX as u128 {
151                        return Err(invalid_literal(span.clone(), "unsigned integer overflow"));
152                    }
153                    Ok(NumericLiteral::Unsigned(value as u64))
154                }
155            }
156            PtxToken::Float(text) | PtxToken::FloatExponent(text) => {
157                let mut value = text
158                    .parse::<f64>()
159                    .map_err(|_| invalid_literal(span.clone(), "invalid floating-point literal"))?;
160                if negative {
161                    value = -value;
162                }
163                Ok(NumericLiteral::Float64(value.to_bits()))
164            }
165            PtxToken::HexFloat(text) => {
166                if text.len() < 3 {
167                    return Err(invalid_literal(
168                        span.clone(),
169                        "invalid hexadecimal float literal",
170                    ));
171                }
172                let (prefix, digits) = text.split_at(2);
173                match prefix.to_ascii_lowercase().as_str() {
174                    "0f" => {
175                        let mut bits = u32::from_str_radix(digits, 16)
176                            .map_err(|_| invalid_literal(span.clone(), "invalid float literal"))?;
177                        if negative {
178                            bits ^= 0x8000_0000;
179                        }
180                        Ok(NumericLiteral::Float32(bits))
181                    }
182                    "0d" => {
183                        let mut bits = u64::from_str_radix(digits, 16)
184                            .map_err(|_| invalid_literal(span.clone(), "invalid float literal"))?;
185                        if negative {
186                            bits ^= 0x8000_0000_0000_0000;
187                        }
188                        Ok(NumericLiteral::Float64(bits))
189                    }
190                    _ => Err(invalid_literal(
191                        span.clone(),
192                        "hexadecimal float must start with 0f or 0d",
193                    )),
194                }
195            }
196            _ => Err(unexpected_value(
197                span.clone(),
198                &["numeric literal"],
199                format!("{token:?}"),
200            )),
201        }
202    }
203}
204
205impl PtxParser for InitializerValue {
206    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
207        if let Some((token, span)) = stream.peek().ok() {
208            match token {
209                PtxToken::StringLiteral(value) => {
210                    let value = value.clone();
211                    stream.consume()?;
212                    return Ok(InitializerValue::StringLiteral(value));
213                }
214                PtxToken::Identifier(_) => {
215                    let (symbol, _) = stream.expect_identifier()?;
216                    return Ok(InitializerValue::Symbol(symbol));
217                }
218                PtxToken::Plus | PtxToken::Minus => {
219                    let literal = NumericLiteral::parse(stream)?;
220                    return Ok(InitializerValue::Numeric(literal));
221                }
222                PtxToken::DecimalInteger(_)
223                | PtxToken::HexInteger(_)
224                | PtxToken::BinaryInteger(_)
225                | PtxToken::OctalInteger(_)
226                | PtxToken::Float(_)
227                | PtxToken::FloatExponent(_)
228                | PtxToken::HexFloat(_) => {
229                    let literal = NumericLiteral::parse(stream)?;
230                    return Ok(InitializerValue::Numeric(literal));
231                }
232                _ => {
233                    return Err(unexpected_value(
234                        span.clone(),
235                        &["numeric literal", "symbol", "string literal"],
236                        format!("{token:?}"),
237                    ));
238                }
239            }
240        }
241        let span = stream.peek()?.1.clone();
242        Err(unexpected_value(
243            span,
244            &["numeric literal", "symbol", "string literal"],
245            "end of input".to_string(),
246        ))
247    }
248}
249
250impl PtxParser for GlobalInitializer {
251    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
252        if stream
253            .consume_if(|token| matches!(token, PtxToken::LBrace))
254            .is_some()
255        {
256            let mut children = Vec::new();
257            if !stream.check(|token| matches!(token, PtxToken::RBrace)) {
258                loop {
259                    let initializer = GlobalInitializer::parse(stream)?;
260                    children.push(initializer);
261                    if !(stream
262                        .consume_if(|token| matches!(token, PtxToken::Comma))
263                        .is_some())
264                    {
265                        break;
266                    }
267                }
268            }
269            stream.expect(&PtxToken::RBrace)?;
270            Ok(GlobalInitializer::Aggregate(children))
271        } else {
272            let value = InitializerValue::parse(stream)?;
273            Ok(GlobalInitializer::Scalar(value))
274        }
275    }
276}
277
278impl PtxParser for VariableModifier {
279    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
280        let (directive, span_ref) = stream.expect_directive()?;
281        let span = span_ref.clone();
282        match directive.as_str() {
283            "align" => {
284                let value = parse_alignment_value(stream)?;
285                Ok(VariableModifier::Alignment(value))
286            }
287            "ptr" => Ok(VariableModifier::Ptr),
288            "visible" => Ok(VariableModifier::Linkage(DataLinkage::Visible)),
289            "extern" => Ok(VariableModifier::Linkage(DataLinkage::Extern)),
290            "weak" => Ok(VariableModifier::Linkage(DataLinkage::Weak)),
291            "common" => Ok(VariableModifier::Linkage(DataLinkage::Common)),
292            other if is_vector_modifier(other) => {
293                let digits = &other[1..];
294                let value = digits
295                    .parse::<u32>()
296                    .map_err(|_| invalid_literal(span.clone(), "invalid vector width"))?;
297                Ok(VariableModifier::Vector(value))
298            }
299            other => Err(unexpected_value(
300                span.clone(),
301                &[
302                    ".align", ".ptr", ".visible", ".extern", ".weak", ".common", ".vN",
303                ],
304                format!(".{other}"),
305            )),
306        }
307    }
308}
309
310fn parse_variable_directive_internal(
311    stream: &mut PtxTokenStream,
312) -> Result<ParsedVariableDirective, PtxParseError> {
313    let first_span = stream.peek().ok().map(|(_, span)| span.clone());
314
315    let mut address_space: Option<AddressSpace> = None;
316    let mut attributes = Vec::new();
317    let mut modifiers = Vec::new();
318    let mut ty: Option<DataType> = None;
319    let mut array = Vec::new();
320    let mut initializer = None;
321    let mut seen_tex = false;
322    let mut kind = VariableDirectiveKind::Other;
323    let mut kind_span = None;
324
325    loop {
326        let Some((directive, directive_span)) = peek_directive(stream)? else {
327            break;
328        };
329        match directive.as_str() {
330            "tex" => {
331                stream.expect_directive()?;
332                if !seen_tex {
333                    seen_tex = true;
334                    kind = VariableDirectiveKind::Tex;
335                    kind_span = Some(directive_span);
336                }
337            }
338            "global" | "const" | "shared" | "local" | "param" | "reg" => {
339                if address_space.is_some() {
340                    return Err(unexpected_value(
341                        directive_span.clone(),
342                        &["single address space qualifier"],
343                        format!(".{directive}"),
344                    ));
345                }
346                let space = AddressSpace::parse(stream)?;
347                address_space = Some(space);
348                match space {
349                    AddressSpace::Global => {
350                        kind = VariableDirectiveKind::Global;
351                        kind_span = Some(directive_span);
352                    }
353                    AddressSpace::Const => {
354                        kind = VariableDirectiveKind::Const;
355                        kind_span = Some(directive_span);
356                    }
357                    AddressSpace::Shared => {
358                        kind = VariableDirectiveKind::Shared;
359                        kind_span = Some(directive_span);
360                    }
361                    _ => {}
362                }
363            }
364            "managed" | "unified" => {
365                attributes.push(AttributeDirective::parse(stream)?);
366            }
367            "align" | "ptr" | "visible" | "extern" | "weak" | "common" => {
368                modifiers.push(VariableModifier::parse(stream)?);
369            }
370            other if is_vector_modifier(other) => {
371                modifiers.push(VariableModifier::parse(stream)?);
372            }
373            other if is_data_type_directive(other) => {
374                if ty.is_some() {
375                    return Err(unexpected_value(
376                        directive_span.clone(),
377                        &["single data type qualifier"],
378                        format!(".{other}"),
379                    ));
380                }
381                ty = Some(DataType::parse(stream)?);
382            }
383            _ => break,
384        }
385    }
386
387    let (name, _) = stream.expect_identifier()?;
388
389    loop {
390        if stream
391            .consume_if(|token| matches!(token, PtxToken::LBracket))
392            .is_none()
393        {
394            break;
395        }
396
397        if stream
398            .consume_if(|token| matches!(token, PtxToken::RBracket))
399            .is_some()
400        {
401            array.push(None);
402            continue;
403        }
404
405        let size_span = stream.peek()?.1.clone();
406        let literal = NumericLiteral::parse(stream)?;
407        let size = match literal {
408            NumericLiteral::Unsigned(value) => value,
409            NumericLiteral::Signed(value) if value >= 0 => value as u64,
410            _ => {
411                return Err(invalid_literal(
412                    size_span.clone(),
413                    "array size must be a non-negative integer",
414                ));
415            }
416        };
417
418        stream.expect(&PtxToken::RBracket)?;
419        array.push(Some(size));
420    }
421
422    if stream
423        .consume_if(|token| matches!(token, PtxToken::Equals))
424        .is_some()
425    {
426        initializer = Some(GlobalInitializer::parse(stream)?);
427    }
428
429    stream.expect(&PtxToken::Semicolon)?;
430
431    let mut final_kind = kind;
432    if seen_tex {
433        final_kind = VariableDirectiveKind::Tex;
434    } else if matches!(final_kind, VariableDirectiveKind::Other) {
435        final_kind = match address_space {
436            Some(AddressSpace::Shared) => VariableDirectiveKind::Shared,
437            Some(AddressSpace::Global) => VariableDirectiveKind::Global,
438            Some(AddressSpace::Const) => VariableDirectiveKind::Const,
439            _ => VariableDirectiveKind::Other,
440        };
441    }
442
443    let directive = VariableDirective {
444        address_space,
445        attributes,
446        ty,
447        modifiers,
448        name,
449        array,
450        initializer,
451        raw: String::new(),
452    };
453
454    Ok(ParsedVariableDirective {
455        directive,
456        kind: final_kind,
457        leading_span: kind_span.or(first_span),
458    })
459}
460
461impl VariableDirective {
462    fn parse_with_kind(
463        stream: &mut PtxTokenStream,
464    ) -> Result<(VariableDirective, VariableDirectiveKind, Option<Span>), PtxParseError> {
465        let parsed = parse_variable_directive_internal(stream)?;
466        Ok((parsed.directive, parsed.kind, parsed.leading_span))
467    }
468}
469
470impl PtxParser for VariableDirective {
471    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
472        let parsed = parse_variable_directive_internal(stream)?;
473        Ok(parsed.directive)
474    }
475}
476
477impl PtxParser for ModuleVariableDirective {
478    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
479        let (directive, kind, span) = VariableDirective::parse_with_kind(stream)?;
480        match kind {
481            VariableDirectiveKind::Tex => Ok(ModuleVariableDirective::Tex(directive)),
482            VariableDirectiveKind::Shared => Ok(ModuleVariableDirective::Shared(directive)),
483            VariableDirectiveKind::Global => Ok(ModuleVariableDirective::Global(directive)),
484            VariableDirectiveKind::Const => Ok(ModuleVariableDirective::Const(directive)),
485            VariableDirectiveKind::Other => Err(unexpected_value(
486                span.unwrap_or(0..0),
487                &[".tex", ".shared", ".global", ".const"],
488                "variable directive".to_string(),
489            )),
490        }
491    }
492}