ptx_parser/parser/
common.rs

1use std::borrow::Cow;
2
3use crate::{
4    lexer::PtxToken,
5    parser::{ParseErrorKind, PtxParseError, PtxParser, PtxTokenStream, Span},
6    r#type::common::*,
7    r#type::instruction::Inst,
8};
9
10pub(crate) fn unexpected_value(
11    span: Span,
12    expected: &[&str],
13    found: impl Into<Cow<'static, str>>,
14) -> PtxParseError {
15    PtxParseError {
16        kind: ParseErrorKind::UnexpectedToken {
17            expected: expected.iter().map(|s| s.to_string()).collect(),
18            found: found.into().to_string(),
19        },
20        span,
21    }
22}
23
24pub(crate) fn invalid_literal(span: Span, literal: impl Into<Cow<'static, str>>) -> PtxParseError {
25    PtxParseError {
26        kind: ParseErrorKind::InvalidLiteral(literal.into().to_string()),
27        span,
28    }
29}
30
31pub(crate) fn parse_register_name(
32    stream: &mut PtxTokenStream,
33) -> Result<(String, Span), PtxParseError> {
34    let (mut name, mut span) = stream.expect_register()?;
35
36    loop {
37        // Peek to decide whether the next token should be treated as a component.
38        let next = match stream.peek() {
39            Ok((token, _)) => token,
40            Err(_) => break,
41        };
42
43        match next {
44            PtxToken::Dot => {
45                // Peek ahead to see if this is a valid register component
46                if let Some((PtxToken::Identifier(component_name), _)) =
47                    stream.tokens.get(stream.index + 1)
48                {
49                    // Only consume if it's a valid single-character register component
50                    // Exclude multi-character .b* patterns (e.g., .b0, .b3210) which are instruction-specific modifiers
51                    if matches!(
52                        component_name.as_str(),
53                        "x" | "y" | "z" | "w" | "r" | "g" | "b" | "a"
54                    ) {
55                        // consume the dot and identifier
56                        stream.consume()?;
57                        let (component, component_span) = stream.expect_identifier()?;
58
59                        name.push('.');
60                        name.push_str(&component);
61
62                        span.end = component_span.end;
63                    } else {
64                        // Not a valid register component, stop parsing
65                        break;
66                    }
67                } else {
68                    break;
69                }
70            }
71            _ => break,
72        }
73    }
74
75    Ok((name, span))
76}
77
78pub(crate) fn numeric_literal(token: &PtxToken) -> Option<&String> {
79    match token {
80        PtxToken::DecimalInteger(value)
81        | PtxToken::HexInteger(value)
82        | PtxToken::BinaryInteger(value)
83        | PtxToken::OctalInteger(value)
84        | PtxToken::FloatExponent(value)
85        | PtxToken::Float(value)
86        | PtxToken::HexFloat(value) => Some(value),
87        _ => None,
88    }
89}
90
91pub(crate) fn is_numeric_token(token: &PtxToken) -> bool {
92    numeric_literal(token).is_some()
93}
94
95pub(crate) fn parse_u64_literal(stream: &mut PtxTokenStream) -> Result<(u64, Span), PtxParseError> {
96    let (token, span) = stream.consume()?;
97    let span = span.clone();
98
99    let value = match token {
100        PtxToken::DecimalInteger(text) => text
101            .parse::<u64>()
102            .map_err(|_| invalid_literal(span.clone(), text.clone()))?,
103        PtxToken::HexInteger(text) => {
104            let stripped = text
105                .strip_prefix("0x")
106                .or_else(|| text.strip_prefix("0X"))
107                .ok_or_else(|| invalid_literal(span.clone(), text.clone()))?;
108            u64::from_str_radix(stripped, 16)
109                .map_err(|_| invalid_literal(span.clone(), text.clone()))?
110        }
111        PtxToken::BinaryInteger(text) => {
112            let stripped = text
113                .strip_prefix("0b")
114                .or_else(|| text.strip_prefix("0B"))
115                .ok_or_else(|| invalid_literal(span.clone(), text.clone()))?;
116            u64::from_str_radix(stripped, 2)
117                .map_err(|_| invalid_literal(span.clone(), text.clone()))?
118        }
119        PtxToken::OctalInteger(text) => {
120            let stripped = &text[1..];
121            u64::from_str_radix(stripped, 8)
122                .map_err(|_| invalid_literal(span.clone(), text.clone()))?
123        }
124        _ => {
125            return Err(unexpected_value(
126                span,
127                &["unsigned integer literal"],
128                format!("{token:?}"),
129            ));
130        }
131    };
132
133    Ok((value, span))
134}
135
136impl PtxParser for CodeLinkage {
137    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
138        let (directive, span) = stream.expect_directive()?;
139        match directive.as_str() {
140            "visible" => Ok(CodeLinkage::Visible { span }),
141            "extern" => Ok(CodeLinkage::Extern { span }),
142            "weak" => Ok(CodeLinkage::Weak { span }),
143            other => Err(unexpected_value(
144                span,
145                &[".visible", ".extern", ".weak"],
146                format!(".{other}"),
147            )),
148        }
149    }
150}
151
152impl PtxParser for DataLinkage {
153    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
154        let (directive, span) = stream.expect_directive()?;
155        match directive.as_str() {
156            "visible" => Ok(DataLinkage::Visible { span }),
157            "extern" => Ok(DataLinkage::Extern { span }),
158            "weak" => Ok(DataLinkage::Weak { span }),
159            "common" => Ok(DataLinkage::Common { span }),
160            other => Err(unexpected_value(
161                span,
162                &[".visible", ".extern", ".weak", ".common"],
163                format!(".{other}"),
164            )),
165        }
166    }
167}
168
169impl PtxParser for CodeOrDataLinkage {
170    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
171        let (directive, span) = stream.expect_directive()?;
172        match directive.as_str() {
173            "visible" => Ok(CodeOrDataLinkage::Visible { span }),
174            "extern" => Ok(CodeOrDataLinkage::Extern { span }),
175            "weak" => Ok(CodeOrDataLinkage::Weak { span }),
176            "common" => Ok(CodeOrDataLinkage::Common { span }),
177            other => Err(unexpected_value(
178                span,
179                &[".visible", ".extern", ".weak", ".common"],
180                format!(".{other}"),
181            )),
182        }
183    }
184}
185
186impl PtxParser for TexType {
187    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
188        let (directive, span) = stream.expect_directive()?;
189        match directive.as_str() {
190            "texref" => Ok(TexType::TexRef { span }),
191            "samplerref" => Ok(TexType::SamplerRef { span }),
192            "surfref" => Ok(TexType::SurfRef { span }),
193            other => Err(unexpected_value(
194                span,
195                &[".texref", ".samplerref", ".surfref"],
196                format!(".{other}"),
197            )),
198        }
199    }
200}
201
202impl PtxParser for AddressSpace {
203    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
204        let (directive, span) = stream.expect_directive()?;
205        match directive.as_str() {
206            "global" => Ok(AddressSpace::Global { span }),
207            "const" => Ok(AddressSpace::Const { span }),
208            "shared" => Ok(AddressSpace::Shared { span }),
209            "local" => Ok(AddressSpace::Local { span }),
210            "param" => Ok(AddressSpace::Param { span }),
211            "reg" => Ok(AddressSpace::Reg { span }),
212            other => Err(unexpected_value(
213                span,
214                &[".global", ".const", ".shared", ".local", ".param", ".reg"],
215                format!(".{other}"),
216            )),
217        }
218    }
219}
220
221impl PtxParser for AttributeDirective {
222    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
223        let (directive, span) = stream.expect_directive()?;
224        match directive.as_str() {
225            "unified" => {
226                stream.expect(&PtxToken::LParen)?;
227                let (uuid1, _) = parse_u64_literal(stream)?;
228                stream.expect(&PtxToken::Comma)?;
229                let (uuid2, _) = parse_u64_literal(stream)?;
230                stream.expect(&PtxToken::RParen)?;
231                Ok(AttributeDirective::Unified { uuid1, uuid2, span })
232            }
233            "managed" => Ok(AttributeDirective::Managed { span }),
234            other => Err(unexpected_value(
235                span,
236                &[".unified", ".managed"],
237                format!(".{other}"),
238            )),
239        }
240    }
241}
242
243impl PtxParser for DataType {
244    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
245        let (directive, span) = stream.expect_directive()?;
246        match directive.as_str() {
247            "u8" => Ok(DataType::U8 { span }),
248            "u16" => Ok(DataType::U16 { span }),
249            "u32" => Ok(DataType::U32 { span }),
250            "u64" => Ok(DataType::U64 { span }),
251            "s8" => Ok(DataType::S8 { span }),
252            "s16" => Ok(DataType::S16 { span }),
253            "s32" => Ok(DataType::S32 { span }),
254            "s64" => Ok(DataType::S64 { span }),
255            "f16" => Ok(DataType::F16 { span }),
256            "f16x2" => Ok(DataType::F16x2 { span }),
257            "f32" => Ok(DataType::F32 { span }),
258            "f64" => Ok(DataType::F64 { span }),
259            "b8" => Ok(DataType::B8 { span }),
260            "b16" => Ok(DataType::B16 { span }),
261            "b32" => Ok(DataType::B32 { span }),
262            "b64" => Ok(DataType::B64 { span }),
263            "b128" => Ok(DataType::B128 { span }),
264            "pred" => Ok(DataType::Pred { span }),
265            other => Err(unexpected_value(
266                span,
267                &[
268                    ".u8", ".u16", ".u32", ".u64", ".s8", ".s16", ".s32", ".s64", ".f16", ".f16x2",
269                    ".f32", ".f64", ".b8", ".b16", ".b32", ".b64", ".b128", ".pred",
270                ],
271                format!(".{other}"),
272            )),
273        }
274    }
275}
276
277impl PtxParser for Sign {
278    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
279        if let Some((_, span)) = stream
280            .consume_if(|token| matches!(token, PtxToken::Plus))
281        {
282            return Ok(Sign::Positive { span: span.clone() });
283        }
284        if let Some((_, span)) = stream
285            .consume_if(|token| matches!(token, PtxToken::Minus))
286        {
287            return Ok(Sign::Negative { span: span.clone() });
288        }
289
290        let (token, span) = stream.peek()?;
291        Err(unexpected_value(
292            span.clone(),
293            &["+", "-"],
294            format!("{token:?}"),
295        ))
296    }
297}
298
299impl PtxParser for Immediate {
300    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
301        // Check for optional minus sign
302        let minus_span = stream
303            .consume_if(|token| matches!(token, PtxToken::Minus))
304            .map(|(_, span)| span.clone());
305
306        let (token, span) = stream.peek()?;
307        let value = numeric_literal(token).cloned();
308        match value {
309            Some(value) => {
310                let literal = if minus_span.is_some() {
311                    format!("-{}", value)
312                } else {
313                    value.clone()
314                };
315                let (_, value_span) = stream.consume()?;
316                let full_span = if let Some(ref ms) = minus_span {
317                    Span { start: ms.start, end: value_span.end }
318                } else {
319                    value_span.clone()
320                };
321                Ok(Immediate { value: literal, span: full_span })
322            }
323            None => {
324                // If we consumed a minus, we need to restore position by going back one token
325                if minus_span.is_some() {
326                    let mut current_pos = stream.position();
327                    if current_pos.index > 0 {
328                        current_pos.index -= 1;
329                        current_pos.char_offset = 0;
330                        stream.set_position(current_pos);
331                    }
332                }
333                Err(unexpected_value(
334                    span.clone(),
335                    &["numeric literal"],
336                    format!("{token:?}"),
337                ))
338            }
339        }
340    }
341}
342
343impl PtxParser for RegisterOperand {
344    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
345        if !stream.check(|token| matches!(token, PtxToken::Register(_))) {
346            let (token, span) = stream.peek()?;
347            return Err(unexpected_value(
348                span.clone(),
349                &["register"],
350                format!("{token:?}"),
351            ));
352        }
353        let (name, span) = parse_register_name(stream)?;
354        Ok(RegisterOperand { name, span })
355    }
356}
357
358impl PtxParser for PredicateRegister {
359    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
360        let (name, span) = parse_register_name(stream)?;
361        if name.starts_with("%p") {
362            Ok(PredicateRegister { name, span })
363        } else {
364            Err(invalid_literal(
365                span,
366                format!("expected predicate register starting with %p, found {name}"),
367            ))
368        }
369    }
370}
371
372impl PtxParser for Label {
373    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
374        let (name, span) = stream.expect_identifier()?;
375        Ok(Label { name, span })
376    }
377}
378
379impl PtxParser for SpecialRegister {
380    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
381        let (name, span) = parse_register_name(stream)?;
382        // Preserve component information (.x/.y/.z) for certain special registers.
383        // If a component is present, return the axis-aware variant; otherwise fall through
384        // to the general match below.
385        let name_str = name.as_str();
386        if let Some(rest) = name_str.strip_prefix("%cluster_ctaid") {
387            if rest.is_empty() {
388                return Ok(SpecialRegister::ClusterCtaid { axis: Axis::None { span: span.clone() }, span });
389            } else if rest == ".x" {
390                return Ok(SpecialRegister::ClusterCtaid { axis: Axis::X { span: span.clone() }, span });
391            } else if rest == ".y" {
392                return Ok(SpecialRegister::ClusterCtaid { axis: Axis::Y { span: span.clone() }, span });
393            } else if rest == ".z" {
394                return Ok(SpecialRegister::ClusterCtaid { axis: Axis::Z { span: span.clone() }, span });
395            }
396        }
397        if let Some(rest) = name_str.strip_prefix("%cluster_ctarank") {
398            if rest.is_empty() {
399                return Ok(SpecialRegister::ClusterCtarank { axis: Axis::None { span: span.clone() }, span });
400            } else if rest == ".x" {
401                return Ok(SpecialRegister::ClusterCtarank { axis: Axis::X { span: span.clone() }, span });
402            } else if rest == ".y" {
403                return Ok(SpecialRegister::ClusterCtarank { axis: Axis::Y { span: span.clone() }, span });
404            } else if rest == ".z" {
405                return Ok(SpecialRegister::ClusterCtarank { axis: Axis::Z { span: span.clone() }, span });
406            }
407        }
408        if let Some(rest) = name_str.strip_prefix("%nctaid") {
409            if rest.is_empty() {
410                return Ok(SpecialRegister::Nctaid { axis: Axis::None { span: span.clone() }, span });
411            } else if rest == ".x" {
412                return Ok(SpecialRegister::Nctaid { axis: Axis::X { span: span.clone() }, span });
413            } else if rest == ".y" {
414                return Ok(SpecialRegister::Nctaid { axis: Axis::Y { span: span.clone() }, span });
415            } else if rest == ".z" {
416                return Ok(SpecialRegister::Nctaid { axis: Axis::Z { span: span.clone() }, span });
417            }
418        }
419        if let Some(rest) = name_str.strip_prefix("%tid") {
420            if rest.is_empty() {
421                return Ok(SpecialRegister::Tid { axis: Axis::None { span: span.clone() }, span });
422            } else if rest == ".x" {
423                return Ok(SpecialRegister::Tid { axis: Axis::X { span: span.clone() }, span });
424            } else if rest == ".y" {
425                return Ok(SpecialRegister::Tid { axis: Axis::Y { span: span.clone() }, span });
426            } else if rest == ".z" {
427                return Ok(SpecialRegister::Tid { axis: Axis::Z { span: span.clone() }, span });
428            }
429        }
430        if let Some(rest) = name_str.strip_prefix("%cluster_nctaid") {
431            if rest.is_empty() {
432                return Ok(SpecialRegister::ClusterNctaid { axis: Axis::None { span: span.clone() }, span });
433            } else if rest == ".x" {
434                return Ok(SpecialRegister::ClusterNctaid { axis: Axis::X { span: span.clone() }, span });
435            } else if rest == ".y" {
436                return Ok(SpecialRegister::ClusterNctaid { axis: Axis::Y { span: span.clone() }, span });
437            } else if rest == ".z" {
438                return Ok(SpecialRegister::ClusterNctaid { axis: Axis::Z { span: span.clone() }, span });
439            }
440        }
441        if let Some(rest) = name_str.strip_prefix("%cluster_nctarank") {
442            if rest.is_empty() {
443                return Ok(SpecialRegister::ClusterNctarank { axis: Axis::None { span: span.clone() }, span });
444            } else if rest == ".x" {
445                return Ok(SpecialRegister::ClusterNctarank { axis: Axis::X { span: span.clone() }, span });
446            } else if rest == ".y" {
447                return Ok(SpecialRegister::ClusterNctarank { axis: Axis::Y { span: span.clone() }, span });
448            } else if rest == ".z" {
449                return Ok(SpecialRegister::ClusterNctarank { axis: Axis::Z { span: span.clone() }, span });
450            }
451        }
452        if let Some(rest) = name_str.strip_prefix("%ntid") {
453            if rest.is_empty() {
454                return Ok(SpecialRegister::Ntid { axis: Axis::None { span: span.clone() }, span });
455            } else if rest == ".x" {
456                return Ok(SpecialRegister::Ntid { axis: Axis::X { span: span.clone() }, span });
457            } else if rest == ".y" {
458                return Ok(SpecialRegister::Ntid { axis: Axis::Y { span: span.clone() }, span });
459            } else if rest == ".z" {
460                return Ok(SpecialRegister::Ntid { axis: Axis::Z { span: span.clone() }, span });
461            }
462        }
463        if let Some(rest) = name_str.strip_prefix("%ctaid") {
464            if rest.is_empty() {
465                return Ok(SpecialRegister::Ctaid { axis: Axis::None { span: span.clone() }, span });
466            } else if rest == ".x" {
467                return Ok(SpecialRegister::Ctaid { axis: Axis::X { span: span.clone() }, span });
468            } else if rest == ".y" {
469                return Ok(SpecialRegister::Ctaid { axis: Axis::Y { span: span.clone() }, span });
470            } else if rest == ".z" {
471                return Ok(SpecialRegister::Ctaid { axis: Axis::Z { span: span.clone() }, span });
472            }
473        }
474
475        match name.as_str() {
476            "%aggr_smem_size" => Ok(SpecialRegister::AggrSmemSize { span }),
477            "%dynamic_smem_size" => Ok(SpecialRegister::DynamicSmemSize { span }),
478            "%lanemask_gt" => Ok(SpecialRegister::LanemaskGt { span }),
479            "%reserved_smem_offset_begin" => Ok(SpecialRegister::ReservedSmemOffsetBegin { span }),
480            "%clock" => Ok(SpecialRegister::Clock { span }),
481            "%lanemask_le" => Ok(SpecialRegister::LanemaskLe { span }),
482            "%reserved_smem_offset_cap" => Ok(SpecialRegister::ReservedSmemOffsetCap { span }),
483            "%clock64" => Ok(SpecialRegister::Clock64 { span }),
484            "%globaltimer" => Ok(SpecialRegister::Globaltimer { span }),
485            "%lanemask_lt" => Ok(SpecialRegister::LanemaskLt { span }),
486            "%reserved_smem_offset_end" => Ok(SpecialRegister::ReservedSmemOffsetEnd { span }),
487            "%cluster_ctaid" | "%cluster_ctaid.x" | "%cluster_ctaid.y" | "%cluster_ctaid.z" => {
488                Ok(SpecialRegister::ClusterCtaid { axis: Axis::None { span: span.clone() }, span })
489            }
490            "%globaltimer_hi" => Ok(SpecialRegister::GlobaltimerHi { span }),
491            "%nclusterid" => Ok(SpecialRegister::Nclusterid { span }),
492            "%smid" => Ok(SpecialRegister::Smid { span }),
493            "%cluster_ctarank" | "%cluster_ctarank.x" | "%cluster_ctarank.y"
494            | "%cluster_ctarank.z" => Ok(SpecialRegister::ClusterCtarank { axis: Axis::None { span: span.clone() }, span }),
495            "%globaltimer_lo" => Ok(SpecialRegister::GlobaltimerLo { span }),
496            "%nctaid" | "%nctaid.x" | "%nctaid.y" | "%nctaid.z" => {
497                Ok(SpecialRegister::Nctaid { axis: Axis::None { span: span.clone() }, span })
498            }
499            "%tid" | "%tid.x" | "%tid.y" | "%tid.z" => Ok(SpecialRegister::Tid { axis: Axis::None { span: span.clone() }, span }),
500            "%cluster_nctaid" | "%cluster_nctaid.x" | "%cluster_nctaid.y" | "%cluster_nctaid.z" => {
501                Ok(SpecialRegister::ClusterNctaid { axis: Axis::None { span: span.clone() }, span })
502            }
503            "%gridid" => Ok(SpecialRegister::Gridid { span }),
504            "%nsmid" => Ok(SpecialRegister::Nsmid { span }),
505            "%total_smem_size" => Ok(SpecialRegister::TotalSmemSize { span }),
506            "%cluster_nctarank"
507            | "%cluster_nctarank.x"
508            | "%cluster_nctarank.y"
509            | "%cluster_nctarank.z" => Ok(SpecialRegister::ClusterNctarank { axis: Axis::None { span: span.clone() }, span }),
510            "%is_explicit_cluster" => Ok(SpecialRegister::IsExplicitCluster { span }),
511            "%ntid" | "%ntid.x" | "%ntid.y" | "%ntid.z" => Ok(SpecialRegister::Ntid { axis: Axis::None { span: span.clone() }, span }),
512            "%warpid" => Ok(SpecialRegister::Warpid { span }),
513            "%clusterid" => Ok(SpecialRegister::Clusterid { span }),
514            "%laneid" => Ok(SpecialRegister::Laneid { span }),
515            "%nwarpid" => Ok(SpecialRegister::Nwarpid { span }),
516            "%WARPSZ" => Ok(SpecialRegister::WARPSZ { span }),
517            "%ctaid" | "%ctaid.x" | "%ctaid.y" | "%ctaid.z" => {
518                Ok(SpecialRegister::Ctaid { axis: Axis::None { span: span.clone() }, span })
519            }
520            "%lanemask_eq" => Ok(SpecialRegister::LanemaskEq { span }),
521            "%current_graph_exec" => Ok(SpecialRegister::CurrentGraphExec { span }),
522            "%lanemask_ge" => Ok(SpecialRegister::LanemaskGe { span }),
523            other => {
524                if let Some(num) = other.strip_prefix("%envreg") {
525                    let value = num
526                        .parse::<u8>()
527                        .map_err(|_| invalid_literal(span.clone(), name.clone()))?;
528                    if value <= 31 {
529                        return Ok(SpecialRegister::Envreg { index: value, span });
530                    }
531                    return Err(invalid_literal(
532                        span,
533                        format!("envreg index out of range: {value}"),
534                    ));
535                }
536
537                if let Some(num) = other.strip_prefix("%pm") {
538                    if let Some(rest) = num.strip_suffix("_64") {
539                        let value = rest
540                            .parse::<u8>()
541                            .map_err(|_| invalid_literal(span.clone(), name.clone()))?;
542                        if value <= 7 {
543                            return Ok(SpecialRegister::Pm64 { index: value, span });
544                        }
545                        return Err(invalid_literal(
546                            span,
547                            format!("pm index out of range: {value}"),
548                        ));
549                    }
550
551                    let value = num
552                        .parse::<u8>()
553                        .map_err(|_| invalid_literal(span.clone(), name.clone()))?;
554                    if value <= 7 {
555                        return Ok(SpecialRegister::Pm { index: value, span });
556                    }
557                    return Err(invalid_literal(
558                        span,
559                        format!("pm index out of range: {value}"),
560                    ));
561                }
562
563                if let Some(num) = other.strip_prefix("%reserved_smem_offset_") {
564                    let value = num
565                        .parse::<u8>()
566                        .map_err(|_| invalid_literal(span.clone(), name.clone()))?;
567                    if value <= 1 {
568                        return Ok(SpecialRegister::ReservedSmemOffset { index: value, span });
569                    }
570                    return Err(invalid_literal(
571                        span,
572                        format!("reserved_smem_offset index out of range: {value}"),
573                    ));
574                }
575
576                Err(invalid_literal(
577                    span,
578                    format!("unknown special register {name}"),
579                ))
580            }
581        }
582    }
583}
584
585impl PtxParser for Operand {
586    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
587        let saved_pos = stream.position();
588        if let Ok(immediate) = Immediate::parse(stream) {
589            let span = immediate.span.clone();
590            return Ok(Operand::Immediate { operand: immediate, span });
591        }
592        stream.set_position(saved_pos);
593
594        if stream.check(|token| matches!(token, PtxToken::Register(_))) {
595            let register = RegisterOperand::parse(stream)?;
596            let span = register.span.clone();
597            return Ok(Operand::Register { operand: register, span });
598        }
599
600        if stream.check(|token| matches!(token, PtxToken::Identifier(_))) {
601            let (identifier, ident_span) = stream.expect_identifier()?;
602
603            // Check for arithmetic expression: identifier + immediate
604            let saved_pos_after_ident = stream.position();
605            if stream.expect(&PtxToken::Plus).is_ok() {
606                if let Ok(offset) = Immediate::parse(stream) {
607                    let span = Span { start: ident_span.start, end: offset.span.end };
608                    return Ok(Operand::SymbolOffset { symbol: identifier, offset, span });
609                }
610                // If parsing offset failed, backtrack
611                stream.set_position(saved_pos_after_ident);
612            }
613
614            return Ok(Operand::Symbol { name: identifier, span: ident_span });
615        }
616
617        let (token, span) = stream.peek()?;
618        Err(unexpected_value(
619            span.clone(),
620            &["operand"],
621            format!("{token:?}"),
622        ))
623    }
624}
625
626impl PtxParser for VectorOperand {
627    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
628        let (_, brace_span) = stream.expect(&PtxToken::LBrace)?;
629        let mut operands = Vec::new();
630
631        loop {
632            operands.push(Operand::parse(stream)?);
633            if stream
634                .consume_if(|token| matches!(token, PtxToken::Comma))
635                .is_some()
636            {
637                continue;
638            }
639            break;
640        }
641
642        let (_, end_span) = stream.expect(&PtxToken::RBrace)?;
643        let span = Span { start: brace_span.start, end: end_span.end };
644
645        match operands.len() {
646            1 => Ok(VectorOperand::Vector1 { operand: operands.remove(0), span }),
647            2 => Ok(VectorOperand::Vector2 { operands: [
648                operands.remove(0),
649                operands.remove(0),
650            ], span }),
651            3 => Ok(VectorOperand::Vector3 { operands: [
652                operands.remove(0),
653                operands.remove(0),
654                operands.remove(0),
655            ], span }),
656            4 => Ok(VectorOperand::Vector4 { operands: [
657                operands.remove(0),
658                operands.remove(0),
659                operands.remove(0),
660                operands.remove(0),
661            ], span }),
662            8 => Ok(VectorOperand::Vector8 { operands: [
663                operands.remove(0),
664                operands.remove(0),
665                operands.remove(0),
666                operands.remove(0),
667                operands.remove(0),
668                operands.remove(0),
669                operands.remove(0),
670                operands.remove(0),
671            ], span }),
672            other => Err(invalid_literal(
673                brace_span.clone(),
674                format!("expected operand vector of length 1..=4 or 8, found {other}"),
675            )),
676        }
677    }
678}
679
680impl PtxParser for GeneralOperand {
681    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
682        if stream.check(|token| matches!(token, PtxToken::LBrace)) {
683            let vec_operand = VectorOperand::parse(stream)?;
684            let span = vec_operand.span();
685            Ok(GeneralOperand::Vec { operand: vec_operand, span })
686        } else {
687            let operand = Operand::parse(stream)?;
688            let span = operand.span();
689            Ok(GeneralOperand::Single { operand, span })
690        }
691    }
692}
693
694impl PtxParser for TexHandler2 {
695    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
696        let (_, start_span) = stream.expect(&PtxToken::LBracket)?;
697        let first = GeneralOperand::parse(stream)?;
698        stream.expect(&PtxToken::Comma)?;
699        let second = GeneralOperand::parse(stream)?;
700        let (_, end_span) = stream.expect(&PtxToken::RBracket)?;
701        let span = Span { start: start_span.start, end: end_span.end };
702        Ok(TexHandler2 { operands: [first, second], span })
703    }
704}
705
706impl PtxParser for TexHandler3 {
707    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
708        let (_, start_span) = stream.expect(&PtxToken::LBracket)?;
709        let handle = GeneralOperand::parse(stream)?;
710        stream.expect(&PtxToken::Comma)?;
711        let sampler = GeneralOperand::parse(stream)?;
712        stream.expect(&PtxToken::Comma)?;
713        let coords = GeneralOperand::parse(stream)?;
714        let (_, end_span) = stream.expect(&PtxToken::RBracket)?;
715        let span = Span { start: start_span.start, end: end_span.end };
716
717        Ok(TexHandler3 {
718            handle,
719            sampler,
720            coords,
721            span,
722        })
723    }
724}
725
726impl PtxParser for TexHandler3Optional {
727    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
728        let (_, start_span) = stream.expect(&PtxToken::LBracket)?;
729        let handle = GeneralOperand::parse(stream)?;
730        stream.expect(&PtxToken::Comma)?;
731        let second = GeneralOperand::parse(stream)?;
732
733        let (sampler, coords) = if stream
734            .consume_if(|token| matches!(token, PtxToken::Comma))
735            .is_some()
736        {
737            let coords = GeneralOperand::parse(stream)?;
738            (Some(second), coords)
739        } else {
740            (None, second)
741        };
742
743        let (_, end_span) = stream.expect(&PtxToken::RBracket)?;
744        let span = Span { start: start_span.start, end: end_span.end };
745
746        Ok(TexHandler3Optional {
747            handle,
748            sampler,
749            coords,
750            span,
751        })
752    }
753}
754
755impl PtxParser for AddressBase {
756    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
757        if stream.check(|token| matches!(token, PtxToken::Register(_))) {
758            let register = RegisterOperand::parse(stream)?;
759            let span = register.span.clone();
760            Ok(AddressBase::Register { operand: register, span })
761        } else if stream.check(|token| matches!(token, PtxToken::Identifier(_))) {
762            let variable = VariableSymbol::parse(stream)?;
763            let span = variable.span.clone();
764            Ok(AddressBase::Variable { symbol: variable, span })
765        } else {
766            let (token, span) = stream.peek()?;
767            Err(unexpected_value(
768                span.clone(),
769                &["register", "identifier"],
770                format!("{token:?}"),
771            ))
772        }
773    }
774}
775
776impl PtxParser for AddressOffset {
777    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
778        if let Some((_, plus_span)) = stream
779            .consume_if(|token| matches!(token, PtxToken::Plus))
780        {
781            if stream.check(|token| matches!(token, PtxToken::Register(_))) {
782                let register = RegisterOperand::parse(stream)?;
783                let span = Span { start: plus_span.start, end: register.span.end };
784                Ok(AddressOffset::Register { operand: register, span })
785            } else {
786                let sign = Sign::Positive { span: plus_span.clone() };
787                let value = Immediate::parse(stream)?;
788                let span = Span { start: plus_span.start, end: value.span.end };
789                Ok(AddressOffset::Immediate { sign, value, span })
790            }
791        } else if let Some((_, minus_span)) = stream
792            .consume_if(|token| matches!(token, PtxToken::Minus))
793        {
794            let sign = Sign::Negative { span: minus_span.clone() };
795            let value = Immediate::parse(stream)?;
796            let span = Span { start: minus_span.start, end: value.span.end };
797            Ok(AddressOffset::Immediate { sign, value, span })
798        } else {
799            let (token, span) = stream.peek()?;
800            Err(unexpected_value(
801                span.clone(),
802                &["+", "-"],
803                format!("{token:?}"),
804            ))
805        }
806    }
807}
808
809impl PtxParser for AddressOperand {
810    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
811        if stream.check(|token| matches!(token, PtxToken::Identifier(_))) {
812            let saved = stream.position();
813            let (identifier, ident_span) = stream.expect_identifier()?;
814            if stream
815                .consume_if(|token| matches!(token, PtxToken::LBracket))
816                .is_some()
817            {
818                let immediate = Immediate::parse(stream)?;
819                let (_, end_span) = stream.expect(&PtxToken::RBracket)?;
820                let span = Span { start: ident_span.start, end: end_span.end };
821                return Ok(AddressOperand::Array { base: VariableSymbol { name: identifier, span: ident_span }, index: immediate, span });
822            } else {
823                stream.set_position(saved);
824            }
825        }
826
827        let (_, start_span) = stream.expect(&PtxToken::LBracket)?;
828
829        if stream.check(|token| matches!(token, PtxToken::Minus)) {
830            let pos = stream.position();
831            stream.consume()?;
832            if stream.check(|token| is_numeric_token(token)) {
833                let mut immediate = Immediate::parse(stream)?;
834                immediate.value.insert(0, '-');
835                let (_, end_span) = stream.expect(&PtxToken::RBracket)?;
836                let span = Span { start: start_span.start, end: end_span.end };
837                return Ok(AddressOperand::ImmediateAddress { addr: immediate, span });
838            } else {
839                stream.set_position(pos);
840            }
841        }
842
843        if stream.check(|token| is_numeric_token(token)) {
844            let immediate = Immediate::parse(stream)?;
845            let (_, end_span) = stream.expect(&PtxToken::RBracket)?;
846            let span = Span { start: start_span.start, end: end_span.end };
847            return Ok(AddressOperand::ImmediateAddress { addr: immediate, span });
848        }
849
850        let base = AddressBase::parse(stream)?;
851        let offset = if stream.check(|token| matches!(token, PtxToken::Plus | PtxToken::Minus)) {
852            Some(AddressOffset::parse(stream)?)
853        } else {
854            None
855        };
856        let (_, end_span) = stream.expect(&PtxToken::RBracket)?;
857        let span = Span { start: start_span.start, end: end_span.end };
858
859        Ok(AddressOperand::Offset { base, offset, span })
860    }
861}
862
863impl PtxParser for FunctionSymbol {
864    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
865        let (name, span) = stream.expect_identifier()?;
866        Ok(FunctionSymbol { name, span })
867    }
868}
869
870impl PtxParser for VariableSymbol {
871    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
872        let (name, span) = stream.expect_identifier()?;
873        Ok(VariableSymbol { name, span })
874    }
875}
876
877/// Try to parse an optional label (identifier followed by colon).
878/// Returns `Ok(Some(label))` if a label is found, `Ok(None)` if not,
879/// or `Err` if parsing fails.
880pub(crate) fn try_parse_label(
881    stream: &mut PtxTokenStream,
882) -> Result<Option<String>, PtxParseError> {
883    if !stream.check(|token| matches!(token, PtxToken::Identifier(_))) {
884        return Ok(None);
885    }
886
887    let position = stream.position();
888    let (name, _) = stream.expect_identifier()?;
889    if stream
890        .consume_if(|token| matches!(token, PtxToken::Colon))
891        .is_some()
892    {
893        Ok(Some(name))
894    } else {
895        stream.set_position(position);
896        Ok(None)
897    }
898}
899
900impl PtxParser for Instruction {
901    /// Parse a PTX instruction with optional predicate
902    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
903        let start_pos = stream.position();
904
905        // Optional predicate: @{!}pred or @!pred
906        let predicate = if stream.check(|t| matches!(t, PtxToken::At)) {
907            let (_, at_span) = stream.consume()?; // consume @
908
909            // Optional negation
910            let negated = stream
911                .consume_if(|t| matches!(t, PtxToken::Exclaim))
912                .is_some();
913
914            // Predicate operand (can be register %p1 or identifier p)
915            let operand = Operand::parse(stream)?;
916            let pred_span = Span { start: at_span.start, end: operand.span().end };
917
918            Some(Predicate { negated, operand, span: pred_span })
919        } else {
920            None
921        };
922
923        // Parse the actual instruction using the module-level dispatcher
924        let inst = crate::parser::instruction::parse_instruction_inner(stream)?;
925
926        // Calculate span from the start to the end of the instruction
927        let end_pos = stream.position();
928        let span = if let Some(ref pred) = predicate {
929            Span { start: pred.span.start, end: end_pos.char_offset as usize }
930        } else {
931            Span { start: start_pos.char_offset as usize, end: end_pos.char_offset as usize }
932        };
933
934        Ok(Instruction { predicate, inst, span })
935    }
936}
937
938// Backwards compatibility: Inst can still be parsed directly
939impl PtxParser for Inst {
940    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
941        Ok(Instruction::parse(stream)?.inst)
942    }
943}