ptx_parser/parser/
common.rs

1use std::borrow::Cow;
2
3use crate::{
4    lexer::PtxToken,
5    parser::{ParseErrorKind, PtxParseError, PtxParser, PtxTokenStream, Span},
6    r#type::common::*,
7    r#type::instruction::Inst,
8};
9
10pub(crate) fn unexpected_value(
11    span: Span,
12    expected: &[&str],
13    found: impl Into<Cow<'static, str>>,
14) -> PtxParseError {
15    PtxParseError {
16        kind: ParseErrorKind::UnexpectedToken {
17            expected: expected.iter().map(|s| s.to_string()).collect(),
18            found: found.into().to_string(),
19        },
20        span,
21    }
22}
23
24pub(crate) fn invalid_literal(span: Span, literal: impl Into<Cow<'static, str>>) -> PtxParseError {
25    PtxParseError {
26        kind: ParseErrorKind::InvalidLiteral(literal.into().to_string()),
27        span,
28    }
29}
30
31pub(crate) fn parse_register_name(
32    stream: &mut PtxTokenStream,
33) -> Result<(String, Span), PtxParseError> {
34    let (mut name, mut span) = stream.expect_register()?;
35
36    loop {
37        // Peek to decide whether the next token should be treated as a component.
38        let next = match stream.peek() {
39            Ok((token, _)) => token,
40            Err(_) => break,
41        };
42
43        match next {
44            PtxToken::Dot => {
45                // Peek ahead to see if this is a valid register component
46                if let Some((PtxToken::Identifier(component_name), _)) =
47                    stream.tokens.get(stream.index + 1)
48                {
49                    // Only consume if it's a valid single-character register component
50                    // Exclude multi-character .b* patterns (e.g., .b0, .b3210) which are instruction-specific modifiers
51                    if matches!(
52                        component_name.as_str(),
53                        "x" | "y" | "z" | "w" | "r" | "g" | "b" | "a"
54                    ) {
55                        // consume the dot and identifier
56                        stream.consume()?;
57                        let (component, component_span) = stream.expect_identifier()?;
58
59                        name.push('.');
60                        name.push_str(&component);
61
62                        span.end = component_span.end;
63                    } else {
64                        // Not a valid register component, stop parsing
65                        break;
66                    }
67                } else {
68                    break;
69                }
70            }
71            _ => break,
72        }
73    }
74
75    Ok((name, span))
76}
77
78pub(crate) fn numeric_literal(token: &PtxToken) -> Option<&String> {
79    match token {
80        PtxToken::DecimalInteger(value)
81        | PtxToken::HexInteger(value)
82        | PtxToken::BinaryInteger(value)
83        | PtxToken::OctalInteger(value)
84        | PtxToken::FloatExponent(value)
85        | PtxToken::Float(value)
86        | PtxToken::HexFloat(value) => Some(value),
87        _ => None,
88    }
89}
90
91pub(crate) fn is_numeric_token(token: &PtxToken) -> bool {
92    numeric_literal(token).is_some()
93}
94
95pub(crate) fn parse_u64_literal(stream: &mut PtxTokenStream) -> Result<(u64, Span), PtxParseError> {
96    let (token, span) = stream.consume()?;
97    let span = span.clone();
98
99    let value = match token {
100        PtxToken::DecimalInteger(text) => text
101            .parse::<u64>()
102            .map_err(|_| invalid_literal(span.clone(), text.clone()))?,
103        PtxToken::HexInteger(text) => {
104            let stripped = text
105                .strip_prefix("0x")
106                .or_else(|| text.strip_prefix("0X"))
107                .ok_or_else(|| invalid_literal(span.clone(), text.clone()))?;
108            u64::from_str_radix(stripped, 16)
109                .map_err(|_| invalid_literal(span.clone(), text.clone()))?
110        }
111        PtxToken::BinaryInteger(text) => {
112            let stripped = text
113                .strip_prefix("0b")
114                .or_else(|| text.strip_prefix("0B"))
115                .ok_or_else(|| invalid_literal(span.clone(), text.clone()))?;
116            u64::from_str_radix(stripped, 2)
117                .map_err(|_| invalid_literal(span.clone(), text.clone()))?
118        }
119        PtxToken::OctalInteger(text) => {
120            let stripped = &text[1..];
121            u64::from_str_radix(stripped, 8)
122                .map_err(|_| invalid_literal(span.clone(), text.clone()))?
123        }
124        _ => {
125            return Err(unexpected_value(
126                span,
127                &["unsigned integer literal"],
128                format!("{token:?}"),
129            ));
130        }
131    };
132
133    Ok((value, span))
134}
135
136impl PtxParser for CodeLinkage {
137    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
138        let (directive, span) = stream.expect_directive()?;
139        match directive.as_str() {
140            "visible" => Ok(CodeLinkage::Visible),
141            "extern" => Ok(CodeLinkage::Extern),
142            "weak" => Ok(CodeLinkage::Weak),
143            other => Err(unexpected_value(
144                span,
145                &[".visible", ".extern", ".weak"],
146                format!(".{other}"),
147            )),
148        }
149    }
150}
151
152impl PtxParser for DataLinkage {
153    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
154        let (directive, span) = stream.expect_directive()?;
155        match directive.as_str() {
156            "visible" => Ok(DataLinkage::Visible),
157            "extern" => Ok(DataLinkage::Extern),
158            "weak" => Ok(DataLinkage::Weak),
159            "common" => Ok(DataLinkage::Common),
160            other => Err(unexpected_value(
161                span,
162                &[".visible", ".extern", ".weak", ".common"],
163                format!(".{other}"),
164            )),
165        }
166    }
167}
168
169impl PtxParser for CodeOrDataLinkage {
170    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
171        let (directive, span) = stream.expect_directive()?;
172        match directive.as_str() {
173            "visible" => Ok(CodeOrDataLinkage::Visible),
174            "extern" => Ok(CodeOrDataLinkage::Extern),
175            "weak" => Ok(CodeOrDataLinkage::Weak),
176            "common" => Ok(CodeOrDataLinkage::Common),
177            other => Err(unexpected_value(
178                span,
179                &[".visible", ".extern", ".weak", ".common"],
180                format!(".{other}"),
181            )),
182        }
183    }
184}
185
186impl PtxParser for TexType {
187    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
188        let (directive, span) = stream.expect_directive()?;
189        match directive.as_str() {
190            "texref" => Ok(TexType::TexRef),
191            "samplerref" => Ok(TexType::SamplerRef),
192            "surfref" => Ok(TexType::SurfRef),
193            other => Err(unexpected_value(
194                span,
195                &[".texref", ".samplerref", ".surfref"],
196                format!(".{other}"),
197            )),
198        }
199    }
200}
201
202impl PtxParser for AddressSpace {
203    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
204        let (directive, span) = stream.expect_directive()?;
205        match directive.as_str() {
206            "global" => Ok(AddressSpace::Global),
207            "const" => Ok(AddressSpace::Const),
208            "shared" => Ok(AddressSpace::Shared),
209            "local" => Ok(AddressSpace::Local),
210            "param" => Ok(AddressSpace::Param),
211            "reg" => Ok(AddressSpace::Reg),
212            other => Err(unexpected_value(
213                span,
214                &[".global", ".const", ".shared", ".local", ".param", ".reg"],
215                format!(".{other}"),
216            )),
217        }
218    }
219}
220
221impl PtxParser for AttributeDirective {
222    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
223        let (directive, span) = stream.expect_directive()?;
224        match directive.as_str() {
225            "unified" => {
226                stream.expect(&PtxToken::LParen)?;
227                let (uuid1, _) = parse_u64_literal(stream)?;
228                stream.expect(&PtxToken::Comma)?;
229                let (uuid2, _) = parse_u64_literal(stream)?;
230                stream.expect(&PtxToken::RParen)?;
231                Ok(AttributeDirective::Unified(uuid1, uuid2))
232            }
233            "managed" => Ok(AttributeDirective::Managed),
234            other => Err(unexpected_value(
235                span,
236                &[".unified", ".managed"],
237                format!(".{other}"),
238            )),
239        }
240    }
241}
242
243impl PtxParser for DataType {
244    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
245        let (directive, span) = stream.expect_directive()?;
246        match directive.as_str() {
247            "u8" => Ok(DataType::U8),
248            "u16" => Ok(DataType::U16),
249            "u32" => Ok(DataType::U32),
250            "u64" => Ok(DataType::U64),
251            "s8" => Ok(DataType::S8),
252            "s16" => Ok(DataType::S16),
253            "s32" => Ok(DataType::S32),
254            "s64" => Ok(DataType::S64),
255            "f16" => Ok(DataType::F16),
256            "f16x2" => Ok(DataType::F16x2),
257            "f32" => Ok(DataType::F32),
258            "f64" => Ok(DataType::F64),
259            "b8" => Ok(DataType::B8),
260            "b16" => Ok(DataType::B16),
261            "b32" => Ok(DataType::B32),
262            "b64" => Ok(DataType::B64),
263            "b128" => Ok(DataType::B128),
264            "pred" => Ok(DataType::Pred),
265            other => Err(unexpected_value(
266                span,
267                &[
268                    ".u8", ".u16", ".u32", ".u64", ".s8", ".s16", ".s32", ".s64", ".f16", ".f16x2",
269                    ".f32", ".f64", ".b8", ".b16", ".b32", ".b64", ".b128", ".pred",
270                ],
271                format!(".{other}"),
272            )),
273        }
274    }
275}
276
277impl PtxParser for Sign {
278    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
279        if stream
280            .consume_if(|token| matches!(token, PtxToken::Plus))
281            .is_some()
282        {
283            return Ok(Sign::Positive);
284        }
285        if stream
286            .consume_if(|token| matches!(token, PtxToken::Minus))
287            .is_some()
288        {
289            return Ok(Sign::Negative);
290        }
291
292        let (token, span) = stream.peek()?;
293        Err(unexpected_value(
294            span.clone(),
295            &["+", "-"],
296            format!("{token:?}"),
297        ))
298    }
299}
300
301impl PtxParser for Immediate {
302    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
303        // Check for optional minus sign
304        let has_minus = stream
305            .consume_if(|token| matches!(token, PtxToken::Minus))
306            .is_some();
307
308        let (token, span) = stream.peek()?;
309        let value = numeric_literal(token).cloned();
310        match value {
311            Some(value) => {
312                let literal = if has_minus {
313                    format!("-{}", value)
314                } else {
315                    value.clone()
316                };
317                stream.consume()?;
318                Ok(Immediate(literal))
319            }
320            None => {
321                // If we consumed a minus, we need to restore position by going back one token
322                if has_minus {
323                    let mut current_pos = stream.position();
324                    if current_pos.index > 0 {
325                        current_pos.index -= 1;
326                        current_pos.char_offset = 0;
327                        stream.set_position(current_pos);
328                    }
329                }
330                Err(unexpected_value(
331                    span.clone(),
332                    &["numeric literal"],
333                    format!("{token:?}"),
334                ))
335            }
336        }
337    }
338}
339
340impl PtxParser for RegisterOperand {
341    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
342        if !stream.check(|token| matches!(token, PtxToken::Register(_))) {
343            let (token, span) = stream.peek()?;
344            return Err(unexpected_value(
345                span.clone(),
346                &["register"],
347                format!("{token:?}"),
348            ));
349        }
350        let (name, _) = parse_register_name(stream)?;
351        Ok(RegisterOperand(name))
352    }
353}
354
355impl PtxParser for PredicateRegister {
356    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
357        let (name, span) = parse_register_name(stream)?;
358        if name.starts_with("%p") {
359            Ok(PredicateRegister(name))
360        } else {
361            Err(invalid_literal(
362                span,
363                format!("expected predicate register starting with %p, found {name}"),
364            ))
365        }
366    }
367}
368
369impl PtxParser for Label {
370    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
371        let (name, _) = stream.expect_identifier()?;
372        Ok(Label(name))
373    }
374}
375
376impl PtxParser for SpecialRegister {
377    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
378        let (name, span) = parse_register_name(stream)?;
379        // Preserve component information (.x/.y/.z) for certain special registers.
380        // If a component is present, return the axis-aware variant; otherwise fall through
381        // to the general match below.
382        let name_str = name.as_str();
383        if let Some(rest) = name_str.strip_prefix("%cluster_ctaid") {
384            if rest.is_empty() {
385                return Ok(SpecialRegister::ClusterCtaid(Axis::None));
386            } else if rest == ".x" {
387                return Ok(SpecialRegister::ClusterCtaid(Axis::X));
388            } else if rest == ".y" {
389                return Ok(SpecialRegister::ClusterCtaid(Axis::Y));
390            } else if rest == ".z" {
391                return Ok(SpecialRegister::ClusterCtaid(Axis::Z));
392            }
393        }
394        if let Some(rest) = name_str.strip_prefix("%cluster_ctarank") {
395            if rest.is_empty() {
396                return Ok(SpecialRegister::ClusterCtarank(Axis::None));
397            } else if rest == ".x" {
398                return Ok(SpecialRegister::ClusterCtarank(Axis::X));
399            } else if rest == ".y" {
400                return Ok(SpecialRegister::ClusterCtarank(Axis::Y));
401            } else if rest == ".z" {
402                return Ok(SpecialRegister::ClusterCtarank(Axis::Z));
403            }
404        }
405        if let Some(rest) = name_str.strip_prefix("%nctaid") {
406            if rest.is_empty() {
407                return Ok(SpecialRegister::Nctaid(Axis::None));
408            } else if rest == ".x" {
409                return Ok(SpecialRegister::Nctaid(Axis::X));
410            } else if rest == ".y" {
411                return Ok(SpecialRegister::Nctaid(Axis::Y));
412            } else if rest == ".z" {
413                return Ok(SpecialRegister::Nctaid(Axis::Z));
414            }
415        }
416        if let Some(rest) = name_str.strip_prefix("%tid") {
417            if rest.is_empty() {
418                return Ok(SpecialRegister::Tid(Axis::None));
419            } else if rest == ".x" {
420                return Ok(SpecialRegister::Tid(Axis::X));
421            } else if rest == ".y" {
422                return Ok(SpecialRegister::Tid(Axis::Y));
423            } else if rest == ".z" {
424                return Ok(SpecialRegister::Tid(Axis::Z));
425            }
426        }
427        if let Some(rest) = name_str.strip_prefix("%cluster_nctaid") {
428            if rest.is_empty() {
429                return Ok(SpecialRegister::ClusterNctaid(Axis::None));
430            } else if rest == ".x" {
431                return Ok(SpecialRegister::ClusterNctaid(Axis::X));
432            } else if rest == ".y" {
433                return Ok(SpecialRegister::ClusterNctaid(Axis::Y));
434            } else if rest == ".z" {
435                return Ok(SpecialRegister::ClusterNctaid(Axis::Z));
436            }
437        }
438        if let Some(rest) = name_str.strip_prefix("%cluster_nctarank") {
439            if rest.is_empty() {
440                return Ok(SpecialRegister::ClusterNctarank(Axis::None));
441            } else if rest == ".x" {
442                return Ok(SpecialRegister::ClusterNctarank(Axis::X));
443            } else if rest == ".y" {
444                return Ok(SpecialRegister::ClusterNctarank(Axis::Y));
445            } else if rest == ".z" {
446                return Ok(SpecialRegister::ClusterNctarank(Axis::Z));
447            }
448        }
449        if let Some(rest) = name_str.strip_prefix("%ntid") {
450            if rest.is_empty() {
451                return Ok(SpecialRegister::Ntid(Axis::None));
452            } else if rest == ".x" {
453                return Ok(SpecialRegister::Ntid(Axis::X));
454            } else if rest == ".y" {
455                return Ok(SpecialRegister::Ntid(Axis::Y));
456            } else if rest == ".z" {
457                return Ok(SpecialRegister::Ntid(Axis::Z));
458            }
459        }
460        if let Some(rest) = name_str.strip_prefix("%ctaid") {
461            if rest.is_empty() {
462                return Ok(SpecialRegister::Ctaid(Axis::None));
463            } else if rest == ".x" {
464                return Ok(SpecialRegister::Ctaid(Axis::X));
465            } else if rest == ".y" {
466                return Ok(SpecialRegister::Ctaid(Axis::Y));
467            } else if rest == ".z" {
468                return Ok(SpecialRegister::Ctaid(Axis::Z));
469            }
470        }
471
472        match name.as_str() {
473            "%aggr_smem_size" => Ok(SpecialRegister::AggrSmemSize),
474            "%dynamic_smem_size" => Ok(SpecialRegister::DynamicSmemSize),
475            "%lanemask_gt" => Ok(SpecialRegister::LanemaskGt),
476            "%reserved_smem_offset_begin" => Ok(SpecialRegister::ReservedSmemOffsetBegin),
477            "%clock" => Ok(SpecialRegister::Clock),
478            "%lanemask_le" => Ok(SpecialRegister::LanemaskLe),
479            "%reserved_smem_offset_cap" => Ok(SpecialRegister::ReservedSmemOffsetCap),
480            "%clock64" => Ok(SpecialRegister::Clock64),
481            "%globaltimer" => Ok(SpecialRegister::Globaltimer),
482            "%lanemask_lt" => Ok(SpecialRegister::LanemaskLt),
483            "%reserved_smem_offset_end" => Ok(SpecialRegister::ReservedSmemOffsetEnd),
484            "%cluster_ctaid" | "%cluster_ctaid.x" | "%cluster_ctaid.y" | "%cluster_ctaid.z" => {
485                Ok(SpecialRegister::ClusterCtaid(Axis::None))
486            }
487            "%globaltimer_hi" => Ok(SpecialRegister::GlobaltimerHi),
488            "%nclusterid" => Ok(SpecialRegister::Nclusterid),
489            "%smid" => Ok(SpecialRegister::Smid),
490            "%cluster_ctarank" | "%cluster_ctarank.x" | "%cluster_ctarank.y"
491            | "%cluster_ctarank.z" => Ok(SpecialRegister::ClusterCtarank(Axis::None)),
492            "%globaltimer_lo" => Ok(SpecialRegister::GlobaltimerLo),
493            "%nctaid" | "%nctaid.x" | "%nctaid.y" | "%nctaid.z" => {
494                Ok(SpecialRegister::Nctaid(Axis::None))
495            }
496            "%tid" | "%tid.x" | "%tid.y" | "%tid.z" => Ok(SpecialRegister::Tid(Axis::None)),
497            "%cluster_nctaid" | "%cluster_nctaid.x" | "%cluster_nctaid.y" | "%cluster_nctaid.z" => {
498                Ok(SpecialRegister::ClusterNctaid(Axis::None))
499            }
500            "%gridid" => Ok(SpecialRegister::Gridid),
501            "%nsmid" => Ok(SpecialRegister::Nsmid),
502            "%total_smem_size" => Ok(SpecialRegister::TotalSmemSize),
503            "%cluster_nctarank"
504            | "%cluster_nctarank.x"
505            | "%cluster_nctarank.y"
506            | "%cluster_nctarank.z" => Ok(SpecialRegister::ClusterNctarank(Axis::None)),
507            "%is_explicit_cluster" => Ok(SpecialRegister::IsExplicitCluster),
508            "%ntid" | "%ntid.x" | "%ntid.y" | "%ntid.z" => Ok(SpecialRegister::Ntid(Axis::None)),
509            "%warpid" => Ok(SpecialRegister::Warpid),
510            "%clusterid" => Ok(SpecialRegister::Clusterid),
511            "%laneid" => Ok(SpecialRegister::Laneid),
512            "%nwarpid" => Ok(SpecialRegister::Nwarpid),
513            "%WARPSZ" => Ok(SpecialRegister::WARPSZ),
514            "%ctaid" | "%ctaid.x" | "%ctaid.y" | "%ctaid.z" => {
515                Ok(SpecialRegister::Ctaid(Axis::None))
516            }
517            "%lanemask_eq" => Ok(SpecialRegister::LanemaskEq),
518            "%current_graph_exec" => Ok(SpecialRegister::CurrentGraphExec),
519            "%lanemask_ge" => Ok(SpecialRegister::LanemaskGe),
520            other => {
521                if let Some(num) = other.strip_prefix("%envreg") {
522                    let value = num
523                        .parse::<u8>()
524                        .map_err(|_| invalid_literal(span.clone(), name.clone()))?;
525                    if value <= 31 {
526                        return Ok(SpecialRegister::Envreg(value));
527                    }
528                    return Err(invalid_literal(
529                        span,
530                        format!("envreg index out of range: {value}"),
531                    ));
532                }
533
534                if let Some(num) = other.strip_prefix("%pm") {
535                    if let Some(rest) = num.strip_suffix("_64") {
536                        let value = rest
537                            .parse::<u8>()
538                            .map_err(|_| invalid_literal(span.clone(), name.clone()))?;
539                        if value <= 7 {
540                            return Ok(SpecialRegister::Pm64(value));
541                        }
542                        return Err(invalid_literal(
543                            span,
544                            format!("pm index out of range: {value}"),
545                        ));
546                    }
547
548                    let value = num
549                        .parse::<u8>()
550                        .map_err(|_| invalid_literal(span.clone(), name.clone()))?;
551                    if value <= 7 {
552                        return Ok(SpecialRegister::Pm(value));
553                    }
554                    return Err(invalid_literal(
555                        span,
556                        format!("pm index out of range: {value}"),
557                    ));
558                }
559
560                if let Some(num) = other.strip_prefix("%reserved_smem_offset_") {
561                    let value = num
562                        .parse::<u8>()
563                        .map_err(|_| invalid_literal(span.clone(), name.clone()))?;
564                    if value <= 1 {
565                        return Ok(SpecialRegister::ReservedSmemOffset(value));
566                    }
567                    return Err(invalid_literal(
568                        span,
569                        format!("reserved_smem_offset index out of range: {value}"),
570                    ));
571                }
572
573                Err(invalid_literal(
574                    span,
575                    format!("unknown special register {name}"),
576                ))
577            }
578        }
579    }
580}
581
582impl PtxParser for Operand {
583    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
584        let saved_pos = stream.position();
585        if let Ok(immediate) = Immediate::parse(stream) {
586            return Ok(Operand::Immediate(immediate));
587        }
588        stream.set_position(saved_pos);
589
590        if stream.check(|token| matches!(token, PtxToken::Register(_))) {
591            return Ok(Operand::Register(RegisterOperand::parse(stream)?));
592        }
593
594        if stream.check(|token| matches!(token, PtxToken::Identifier(_))) {
595            let (identifier, _) = stream.expect_identifier()?;
596
597            // Check for arithmetic expression: identifier + immediate
598            let saved_pos_after_ident = stream.position();
599            if stream.expect(&PtxToken::Plus).is_ok() {
600                if let Ok(offset) = Immediate::parse(stream) {
601                    return Ok(Operand::SymbolOffset(identifier, offset));
602                }
603                // If parsing offset failed, backtrack
604                stream.set_position(saved_pos_after_ident);
605            }
606
607            return Ok(Operand::Symbol(identifier));
608        }
609
610        let (token, span) = stream.peek()?;
611        Err(unexpected_value(
612            span.clone(),
613            &["operand"],
614            format!("{token:?}"),
615        ))
616    }
617}
618
619impl PtxParser for VectorOperand {
620    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
621        let (_, brace_span) = stream.expect(&PtxToken::LBrace)?;
622        let mut operands = Vec::new();
623
624        loop {
625            operands.push(Operand::parse(stream)?);
626            if stream
627                .consume_if(|token| matches!(token, PtxToken::Comma))
628                .is_some()
629            {
630                continue;
631            }
632            break;
633        }
634
635        stream.expect(&PtxToken::RBrace)?;
636
637        match operands.len() {
638            1 => Ok(VectorOperand::Vector1(operands.remove(0))),
639            2 => Ok(VectorOperand::Vector2([
640                operands.remove(0),
641                operands.remove(0),
642            ])),
643            3 => Ok(VectorOperand::Vector3([
644                operands.remove(0),
645                operands.remove(0),
646                operands.remove(0),
647            ])),
648            4 => Ok(VectorOperand::Vector4([
649                operands.remove(0),
650                operands.remove(0),
651                operands.remove(0),
652                operands.remove(0),
653            ])),
654            8 => Ok(VectorOperand::Vector8([
655                operands.remove(0),
656                operands.remove(0),
657                operands.remove(0),
658                operands.remove(0),
659                operands.remove(0),
660                operands.remove(0),
661                operands.remove(0),
662                operands.remove(0),
663            ])),
664            other => Err(invalid_literal(
665                brace_span.clone(),
666                format!("expected operand vector of length 1..=4 or 8, found {other}"),
667            )),
668        }
669    }
670}
671
672impl PtxParser for GeneralOperand {
673    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
674        if stream.check(|token| matches!(token, PtxToken::LBrace)) {
675            Ok(GeneralOperand::Vec(VectorOperand::parse(stream)?))
676        } else {
677            Ok(GeneralOperand::Single(Operand::parse(stream)?))
678        }
679    }
680}
681
682impl PtxParser for TexHandler2 {
683    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
684        stream.expect(&PtxToken::LBracket)?;
685        let first = GeneralOperand::parse(stream)?;
686        stream.expect(&PtxToken::Comma)?;
687        let second = GeneralOperand::parse(stream)?;
688        stream.expect(&PtxToken::RBracket)?;
689        Ok(TexHandler2([first, second]))
690    }
691}
692
693impl PtxParser for TexHandler3 {
694    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
695        stream.expect(&PtxToken::LBracket)?;
696        let handle = GeneralOperand::parse(stream)?;
697        stream.expect(&PtxToken::Comma)?;
698        let sampler = GeneralOperand::parse(stream)?;
699        stream.expect(&PtxToken::Comma)?;
700        let coords = GeneralOperand::parse(stream)?;
701        stream.expect(&PtxToken::RBracket)?;
702
703        Ok(TexHandler3 {
704            handle,
705            sampler,
706            coords,
707        })
708    }
709}
710
711impl PtxParser for TexHandler3Optional {
712    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
713        stream.expect(&PtxToken::LBracket)?;
714        let handle = GeneralOperand::parse(stream)?;
715        stream.expect(&PtxToken::Comma)?;
716        let second = GeneralOperand::parse(stream)?;
717
718        let (sampler, coords) = if stream
719            .consume_if(|token| matches!(token, PtxToken::Comma))
720            .is_some()
721        {
722            let coords = GeneralOperand::parse(stream)?;
723            (Some(second), coords)
724        } else {
725            (None, second)
726        };
727
728        stream.expect(&PtxToken::RBracket)?;
729
730        Ok(TexHandler3Optional {
731            handle,
732            sampler,
733            coords,
734        })
735    }
736}
737
738impl PtxParser for AddressBase {
739    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
740        if stream.check(|token| matches!(token, PtxToken::Register(_))) {
741            Ok(AddressBase::Register(RegisterOperand::parse(stream)?))
742        } else if stream.check(|token| matches!(token, PtxToken::Identifier(_))) {
743            Ok(AddressBase::Variable(VariableSymbol::parse(stream)?))
744        } else {
745            let (token, span) = stream.peek()?;
746            Err(unexpected_value(
747                span.clone(),
748                &["register", "identifier"],
749                format!("{token:?}"),
750            ))
751        }
752    }
753}
754
755impl PtxParser for AddressOffset {
756    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
757        if stream
758            .consume_if(|token| matches!(token, PtxToken::Plus))
759            .is_some()
760        {
761            if stream.check(|token| matches!(token, PtxToken::Register(_))) {
762                Ok(AddressOffset::Register(RegisterOperand::parse(stream)?))
763            } else {
764                Ok(AddressOffset::Immediate(
765                    Sign::Positive,
766                    Immediate::parse(stream)?,
767                ))
768            }
769        } else if stream
770            .consume_if(|token| matches!(token, PtxToken::Minus))
771            .is_some()
772        {
773            Ok(AddressOffset::Immediate(
774                Sign::Negative,
775                Immediate::parse(stream)?,
776            ))
777        } else {
778            let (token, span) = stream.peek()?;
779            Err(unexpected_value(
780                span.clone(),
781                &["+", "-"],
782                format!("{token:?}"),
783            ))
784        }
785    }
786}
787
788impl PtxParser for AddressOperand {
789    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
790        if stream.check(|token| matches!(token, PtxToken::Identifier(_))) {
791            let saved = stream.position();
792            let (identifier, _) = stream.expect_identifier()?;
793            if stream
794                .consume_if(|token| matches!(token, PtxToken::LBracket))
795                .is_some()
796            {
797                let immediate = Immediate::parse(stream)?;
798                stream.expect(&PtxToken::RBracket)?;
799                return Ok(AddressOperand::Array(VariableSymbol(identifier), immediate));
800            } else {
801                stream.set_position(saved);
802            }
803        }
804
805        stream.expect(&PtxToken::LBracket)?;
806
807        if stream.check(|token| matches!(token, PtxToken::Minus)) {
808            let pos = stream.position();
809            stream.consume()?;
810            if stream.check(|token| is_numeric_token(token)) {
811                let mut immediate = Immediate::parse(stream)?;
812                immediate.0.insert(0, '-');
813                stream.expect(&PtxToken::RBracket)?;
814                return Ok(AddressOperand::ImmediateAddress(immediate));
815            } else {
816                stream.set_position(pos);
817            }
818        }
819
820        if stream.check(|token| is_numeric_token(token)) {
821            let immediate = Immediate::parse(stream)?;
822            stream.expect(&PtxToken::RBracket)?;
823            return Ok(AddressOperand::ImmediateAddress(immediate));
824        }
825
826        let base = AddressBase::parse(stream)?;
827        let offset = if stream.check(|token| matches!(token, PtxToken::Plus | PtxToken::Minus)) {
828            Some(AddressOffset::parse(stream)?)
829        } else {
830            None
831        };
832        stream.expect(&PtxToken::RBracket)?;
833
834        Ok(AddressOperand::Offset(base, offset))
835    }
836}
837
838impl PtxParser for FunctionSymbol {
839    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
840        let (name, _) = stream.expect_identifier()?;
841        Ok(FunctionSymbol(name))
842    }
843}
844
845impl PtxParser for VariableSymbol {
846    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
847        let (name, _) = stream.expect_identifier()?;
848        Ok(VariableSymbol(name))
849    }
850}
851
852/// Try to parse an optional label (identifier followed by colon).
853/// Returns `Ok(Some(label))` if a label is found, `Ok(None)` if not,
854/// or `Err` if parsing fails.
855pub(crate) fn try_parse_label(
856    stream: &mut PtxTokenStream,
857) -> Result<Option<String>, PtxParseError> {
858    if !stream.check(|token| matches!(token, PtxToken::Identifier(_))) {
859        return Ok(None);
860    }
861
862    let position = stream.position();
863    let (name, _) = stream.expect_identifier()?;
864    if stream
865        .consume_if(|token| matches!(token, PtxToken::Colon))
866        .is_some()
867    {
868        Ok(Some(name))
869    } else {
870        stream.set_position(position);
871        Ok(None)
872    }
873}
874
875impl PtxParser for Instruction {
876    /// Parse a PTX instruction with optional label and predicate
877    ///
878    /// Format: [label:] [@{!}pred] instruction
879    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
880        // Optional label (ends with colon)
881        let label = try_parse_label(stream)?;
882
883        // Optional predicate: @{!}pred or @!pred
884        let predicate = if stream.check(|t| matches!(t, PtxToken::At)) {
885            stream.consume()?; // consume @
886
887            // Optional negation
888            let negated = stream
889                .consume_if(|t| matches!(t, PtxToken::Exclaim))
890                .is_some();
891
892            // Predicate operand (can be register %p1 or identifier p)
893            let operand = Operand::parse(stream)?;
894
895            Some(Predicate { negated, operand })
896        } else {
897            None
898        };
899
900        // Parse the actual instruction using the module-level dispatcher
901        let inst = crate::parser::instruction::parse_instruction_inner(stream)?;
902
903        Ok(Instruction {
904            label,
905            predicate,
906            inst,
907        })
908    }
909}
910
911// Backwards compatibility: Inst can still be parsed directly
912impl PtxParser for Inst {
913    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
914        Ok(Instruction::parse(stream)?.inst)
915    }
916}