// ptx_parser/parser/variable.rs

1use crate::{
2    lexer::PtxToken,
3    parser::{
4        PtxParseError, PtxParser, PtxTokenStream, Span, common::parse_u64_literal, invalid_literal,
5        peek_directive, unexpected_value,
6    },
7    r#type::{
8        common::{AddressSpace, AttributeDirective, DataLinkage, DataType},
9        variable::{
10            GlobalInitializer, InitializerValue, ModuleVariableDirective, NumericLiteral,
11            VariableDirective, VariableModifier,
12        },
13    },
14};
15
16const DATA_TYPE_NAMES: &[&str] = &[
17    "u8", "u16", "u32", "u64", "s8", "s16", "s32", "s64", "f16", "f16x2", "f32", "f64", "b8",
18    "b16", "b32", "b64", "b128", "pred",
19];
20
21#[derive(Clone, Copy, Debug, PartialEq, Eq)]
22enum VariableDirectiveKind {
23    Tex,
24    Shared,
25    Global,
26    Const,
27    Other,
28}
29
30fn is_data_type_directive(name: &str) -> bool {
31    DATA_TYPE_NAMES.iter().any(|candidate| candidate == &name)
32}
33
34fn is_vector_modifier(name: &str) -> bool {
35    let mut chars = name.chars();
36    match (chars.next(), chars.next()) {
37        (Some('v'), Some(digit)) if digit.is_ascii_digit() => chars.all(|ch| ch.is_ascii_digit()),
38        _ => false,
39    }
40}
41
42fn parse_alignment_value(stream: &mut PtxTokenStream) -> Result<u32, PtxParseError> {
43    let (value, value_span) = parse_u64_literal(stream)?;
44    if value > u32::MAX as u64 {
45        return Err(invalid_literal(
46            value_span,
47            "alignment value exceeds u32 range",
48        ));
49    }
50    Ok(value as u32)
51}
52
53fn parse_numeric_string(text: &str, span: Span) -> Result<u128, PtxParseError> {
54    text.parse::<u128>()
55        .map_err(|_| invalid_literal(span, "invalid integer literal"))
56}
57
58impl PtxParser for NumericLiteral {
59    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
60        let negative = stream
61            .consume_if(|token| matches!(token, PtxToken::Minus))
62            .is_some();
63        let positive = stream
64            .consume_if(|token| matches!(token, PtxToken::Plus))
65            .is_some();
66
67        if negative && positive {
68            let (_, span) = stream.peek()?;
69            return Err(invalid_literal(
70                span.clone(),
71                "cannot have both '+' and '-' signs",
72            ));
73        }
74
75        let (token, span_ref) = stream.consume()?;
76        let span = span_ref.clone();
77        match token {
78            PtxToken::DecimalInteger(text) => {
79                let value = parse_numeric_string(text.as_str(), span.clone())?;
80                if negative {
81                    if value > (i64::MAX as u128) + 1 {
82                        return Err(invalid_literal(span.clone(), "signed integer underflow"));
83                    }
84                    let signed = -(value as i128);
85                    Ok(NumericLiteral::Signed { value: signed as i64, span })
86                } else {
87                    if value > u64::MAX as u128 {
88                        return Err(invalid_literal(span.clone(), "unsigned integer overflow"));
89                    }
90                    Ok(NumericLiteral::Unsigned { value: value as u64, span })
91                }
92            }
93            PtxToken::HexInteger(text) => {
94                let stripped = text
95                    .strip_prefix("0x")
96                    .or_else(|| text.strip_prefix("0X"))
97                    .unwrap_or(text.as_str());
98                let value = u128::from_str_radix(stripped, 16)
99                    .map_err(|_| invalid_literal(span.clone(), "invalid hex literal"))?;
100                if negative {
101                    if value > (i64::MAX as u128) + 1 {
102                        return Err(invalid_literal(span.clone(), "signed integer underflow"));
103                    }
104                    let signed = -(value as i128);
105                    Ok(NumericLiteral::Signed { value: signed as i64, span })
106                } else {
107                    if value > u64::MAX as u128 {
108                        return Err(invalid_literal(span.clone(), "unsigned integer overflow"));
109                    }
110                    Ok(NumericLiteral::Unsigned { value: value as u64, span })
111                }
112            }
113            PtxToken::BinaryInteger(text) => {
114                let stripped = text
115                    .strip_prefix("0b")
116                    .or_else(|| text.strip_prefix("0B"))
117                    .unwrap_or(text.as_str());
118                let value = u128::from_str_radix(stripped, 2)
119                    .map_err(|_| invalid_literal(span.clone(), "invalid binary literal"))?;
120                if negative {
121                    if value > (i64::MAX as u128) + 1 {
122                        return Err(invalid_literal(span.clone(), "signed integer underflow"));
123                    }
124                    let signed = -(value as i128);
125                    Ok(NumericLiteral::Signed { value: signed as i64, span })
126                } else {
127                    if value > u64::MAX as u128 {
128                        return Err(invalid_literal(span.clone(), "unsigned integer overflow"));
129                    }
130                    Ok(NumericLiteral::Unsigned { value: value as u64, span })
131                }
132            }
133            PtxToken::OctalInteger(text) => {
134                let stripped = &text.as_str()[1..];
135                let value = u128::from_str_radix(stripped, 8)
136                    .map_err(|_| invalid_literal(span.clone(), "invalid octal literal"))?;
137                if negative {
138                    if value > (i64::MAX as u128) + 1 {
139                        return Err(invalid_literal(span.clone(), "signed integer underflow"));
140                    }
141                    let signed = -(value as i128);
142                    Ok(NumericLiteral::Signed { value: signed as i64, span })
143                } else {
144                    if value > u64::MAX as u128 {
145                        return Err(invalid_literal(span.clone(), "unsigned integer overflow"));
146                    }
147                    Ok(NumericLiteral::Unsigned { value: value as u64, span })
148                }
149            }
150            PtxToken::Float(text) | PtxToken::FloatExponent(text) => {
151                let mut value = text
152                    .parse::<f64>()
153                    .map_err(|_| invalid_literal(span.clone(), "invalid floating-point literal"))?;
154                if negative {
155                    value = -value;
156                }
157                Ok(NumericLiteral::Float64 { value: value.to_bits(), span })
158            }
159            PtxToken::HexFloat(text) => {
160                if text.len() < 3 {
161                    return Err(invalid_literal(
162                        span.clone(),
163                        "invalid hexadecimal float literal",
164                    ));
165                }
166                let (prefix, digits) = text.split_at(2);
167                match prefix.to_ascii_lowercase().as_str() {
168                    "0f" => {
169                        let mut bits = u32::from_str_radix(digits, 16)
170                            .map_err(|_| invalid_literal(span.clone(), "invalid float literal"))?;
171                        if negative {
172                            bits ^= 0x8000_0000;
173                        }
174                        Ok(NumericLiteral::Float32 { value: bits, span })
175                    }
176                    "0d" => {
177                        let mut bits = u64::from_str_radix(digits, 16)
178                            .map_err(|_| invalid_literal(span.clone(), "invalid float literal"))?;
179                        if negative {
180                            bits ^= 0x8000_0000_0000_0000;
181                        }
182                        Ok(NumericLiteral::Float64 { value: bits, span })
183                    }
184                    _ => Err(invalid_literal(
185                        span.clone(),
186                        "hexadecimal float must start with 0f or 0d",
187                    )),
188                }
189            }
190            _ => Err(unexpected_value(
191                span.clone(),
192                &["numeric literal"],
193                format!("{token:?}"),
194            )),
195        }
196    }
197}
198
199impl PtxParser for InitializerValue {
200    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
201        if let Some((token, span)) = stream.peek().ok() {
202            let span = span.clone();
203            match token {
204                PtxToken::StringLiteral(value) => {
205                    let value = value.clone();
206                    stream.consume()?;
207                    return Ok(InitializerValue::StringLiteral { value, span });
208                }
209                PtxToken::Identifier(_) => {
210                    let (name, span) = stream.expect_identifier()?;
211                    return Ok(InitializerValue::Symbol { name, span: span.clone() });
212                }
213                PtxToken::Plus | PtxToken::Minus => {
214                    let literal = NumericLiteral::parse(stream)?;
215                    let span = literal.span();
216                    return Ok(InitializerValue::Numeric { value: literal, span });
217                }
218                PtxToken::DecimalInteger(_)
219                | PtxToken::HexInteger(_)
220                | PtxToken::BinaryInteger(_)
221                | PtxToken::OctalInteger(_)
222                | PtxToken::Float(_)
223                | PtxToken::FloatExponent(_)
224                | PtxToken::HexFloat(_) => {
225                    let literal = NumericLiteral::parse(stream)?;
226                    let span = literal.span();
227                    return Ok(InitializerValue::Numeric { value: literal, span });
228                }
229                _ => {
230                    return Err(unexpected_value(
231                        span.clone(),
232                        &["numeric literal", "symbol", "string literal"],
233                        format!("{token:?}"),
234                    ));
235                }
236            }
237        }
238        let span = stream.peek()?.1.clone();
239        Err(unexpected_value(
240            span,
241            &["numeric literal", "symbol", "string literal"],
242            "end of input".to_string(),
243        ))
244    }
245}
246
247impl PtxParser for GlobalInitializer {
248    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
249        let start_span = stream.peek()?.1.clone();
250        if stream
251            .consume_if(|token| matches!(token, PtxToken::LBrace))
252            .is_some()
253        {
254            let mut children = Vec::new();
255            if !stream.check(|token| matches!(token, PtxToken::RBrace)) {
256                loop {
257                    let initializer = GlobalInitializer::parse(stream)?;
258                    children.push(initializer);
259                    if !(stream
260                        .consume_if(|token| matches!(token, PtxToken::Comma))
261                        .is_some())
262                    {
263                        break;
264                    }
265                }
266            }
267            let (_, end_span) = stream.expect(&PtxToken::RBrace)?;
268            let span = start_span.start..end_span.end;
269            Ok(GlobalInitializer::Aggregate { values: children, span })
270        } else {
271            let value = InitializerValue::parse(stream)?;
272            let span = value.span();
273            Ok(GlobalInitializer::Scalar { value, span })
274        }
275    }
276}
277
278impl PtxParser for VariableModifier {
279    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
280        let (directive, span_ref) = stream.expect_directive()?;
281        let span = span_ref.clone();
282        match directive.as_str() {
283            "align" => {
284                let value = parse_alignment_value(stream)?;
285                Ok(VariableModifier::Alignment { value, span })
286            }
287            "ptr" => Ok(VariableModifier::Ptr { span }),
288            "visible" => Ok(VariableModifier::Linkage {
289                linkage: DataLinkage::Visible { span: span.clone() },
290                span
291            }),
292            "extern" => Ok(VariableModifier::Linkage {
293                linkage: DataLinkage::Extern { span: span.clone() },
294                span
295            }),
296            "weak" => Ok(VariableModifier::Linkage {
297                linkage: DataLinkage::Weak { span: span.clone() },
298                span
299            }),
300            "common" => Ok(VariableModifier::Linkage {
301                linkage: DataLinkage::Common { span: span.clone() },
302                span
303            }),
304            other if is_vector_modifier(other) => {
305                let digits = &other[1..];
306                let value = digits
307                    .parse::<u32>()
308                    .map_err(|_| invalid_literal(span.clone(), "invalid vector width"))?;
309                Ok(VariableModifier::Vector { value, span })
310            }
311            other => Err(unexpected_value(
312                span.clone(),
313                &[
314                    ".align", ".ptr", ".visible", ".extern", ".weak", ".common", ".vN",
315                ],
316                format!(".{other}"),
317            )),
318        }
319    }
320}
321
322impl VariableDirective {
323    fn parse_with_kind(
324        stream: &mut PtxTokenStream,
325    ) -> Result<(VariableDirective, VariableDirectiveKind, Option<Span>), PtxParseError> {
326        let first_span = stream.peek().ok().map(|(_, span)| span.clone());
327
328        let mut address_space: Option<AddressSpace> = None;
329        let mut attributes = Vec::new();
330        let mut modifiers = Vec::new();
331        let mut ty: Option<DataType> = None;
332        let mut array = Vec::new();
333        let mut initializer = None;
334        let mut seen_tex = false;
335        let mut kind = VariableDirectiveKind::Other;
336        let mut kind_span = None;
337
338        loop {
339            let Some((directive, directive_span)) = peek_directive(stream)? else {
340                break;
341            };
342            match directive.as_str() {
343                "tex" => {
344                    stream.expect_directive()?;
345                    if !seen_tex {
346                        seen_tex = true;
347                        kind = VariableDirectiveKind::Tex;
348                        kind_span = Some(directive_span);
349                    }
350                }
351                "global" | "const" | "shared" | "local" | "param" | "reg" => {
352                    if address_space.is_some() {
353                        return Err(unexpected_value(
354                            directive_span.clone(),
355                            &["single address space qualifier"],
356                            format!(".{directive}"),
357                        ));
358                    }
359                    let space = AddressSpace::parse(stream)?;
360                    match space {
361                        AddressSpace::Global { .. } => {
362                            kind = VariableDirectiveKind::Global;
363                            kind_span = Some(directive_span.clone());
364                        }
365                        AddressSpace::Const { .. } => {
366                            kind = VariableDirectiveKind::Const;
367                            kind_span = Some(directive_span.clone());
368                        }
369                        AddressSpace::Shared { .. } => {
370                            kind = VariableDirectiveKind::Shared;
371                            kind_span = Some(directive_span.clone());
372                        }
373                        _ => {}
374                    }
375                    address_space = Some(space);
376                }
377                "managed" | "unified" => {
378                    attributes.push(AttributeDirective::parse(stream)?);
379                }
380                "align" | "ptr" | "visible" | "extern" | "weak" | "common" => {
381                    modifiers.push(VariableModifier::parse(stream)?);
382                }
383                other if is_vector_modifier(other) => {
384                    modifiers.push(VariableModifier::parse(stream)?);
385                }
386                other if is_data_type_directive(other) => {
387                    if ty.is_some() {
388                        return Err(unexpected_value(
389                            directive_span.clone(),
390                            &["single data type qualifier"],
391                            format!(".{other}"),
392                        ));
393                    }
394                    ty = Some(DataType::parse(stream)?);
395                }
396                _ => break,
397            }
398        }
399
400        let (name, _) = stream.expect_identifier()?;
401
402        loop {
403            if stream
404                .consume_if(|token| matches!(token, PtxToken::LBracket))
405                .is_none()
406            {
407                break;
408            }
409
410            if stream
411                .consume_if(|token| matches!(token, PtxToken::RBracket))
412                .is_some()
413            {
414                array.push(None);
415                continue;
416            }
417
418            let size_span = stream.peek()?.1.clone();
419            let literal = NumericLiteral::parse(stream)?;
420            let size = match literal {
421                NumericLiteral::Unsigned { value, .. } => value,
422                NumericLiteral::Signed { value, .. } if value >= 0 => value as u64,
423                _ => {
424                    return Err(invalid_literal(
425                        size_span.clone(),
426                        "array size must be a non-negative integer",
427                    ));
428                }
429            };
430
431            stream.expect(&PtxToken::RBracket)?;
432            array.push(Some(size));
433        }
434
435        if stream
436            .consume_if(|token| matches!(token, PtxToken::Equals))
437            .is_some()
438        {
439            initializer = Some(GlobalInitializer::parse(stream)?);
440        }
441
442        stream.expect(&PtxToken::Semicolon)?;
443
444        let mut final_kind = kind;
445        if seen_tex {
446            final_kind = VariableDirectiveKind::Tex;
447        } else if matches!(final_kind, VariableDirectiveKind::Other) {
448            final_kind = match address_space {
449                Some(AddressSpace::Shared { .. }) => VariableDirectiveKind::Shared,
450                Some(AddressSpace::Global { .. }) => VariableDirectiveKind::Global,
451                Some(AddressSpace::Const { .. }) => VariableDirectiveKind::Const,
452                _ => VariableDirectiveKind::Other,
453            };
454        }
455
456        let end_span = stream.peek().ok().map(|(_, s)| s.clone()).unwrap_or(0..0);
457        let span = first_span.map(|s| s.start..end_span.end).unwrap_or(0..0);
458
459        let directive = VariableDirective {
460            address_space,
461            attributes,
462            ty,
463            modifiers,
464            name,
465            array,
466            initializer,
467            span: span.clone(),
468        };
469
470        Ok((directive, final_kind, kind_span.or(Some(span))))
471    }
472}
473
474impl PtxParser for VariableDirective {
475    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
476        let (directive, _, _) = VariableDirective::parse_with_kind(stream)?;
477        Ok(directive)
478    }
479}
480
481impl PtxParser for ModuleVariableDirective {
482    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
483        let (directive, kind, span_opt) = VariableDirective::parse_with_kind(stream)?;
484        let span = span_opt.unwrap_or(0..0);
485        match kind {
486            VariableDirectiveKind::Tex => Ok(ModuleVariableDirective::Tex {
487                directive,
488                span
489            }),
490            VariableDirectiveKind::Shared => Ok(ModuleVariableDirective::Shared {
491                directive,
492                span
493            }),
494            VariableDirectiveKind::Global => Ok(ModuleVariableDirective::Global {
495                directive,
496                span
497            }),
498            VariableDirectiveKind::Const => Ok(ModuleVariableDirective::Const {
499                directive,
500                span: span.clone()
501            }),
502            VariableDirectiveKind::Other => Err(unexpected_value(
503                span,
504                &[".tex", ".shared", ".global", ".const"],
505                "variable directive".to_string(),
506            )),
507        }
508    }
509}