ptx_parser/parser/
common.rs

1use std::borrow::Cow;
2
3use crate::{
4    lexer::PtxToken,
5    parser::{ParseErrorKind, PtxParseError, PtxParser, PtxTokenStream, Span},
6    r#type::common::*,
7    r#type::instruction::Inst,
8};
9
10pub(crate) fn unexpected_value(
11    span: Span,
12    expected: &[&str],
13    found: impl Into<Cow<'static, str>>,
14) -> PtxParseError {
15    PtxParseError {
16        kind: ParseErrorKind::UnexpectedToken {
17            expected: expected.iter().map(|s| s.to_string()).collect(),
18            found: found.into().to_string(),
19        },
20        span,
21    }
22}
23
24pub(crate) fn invalid_literal(span: Span, literal: impl Into<Cow<'static, str>>) -> PtxParseError {
25    PtxParseError {
26        kind: ParseErrorKind::InvalidLiteral(literal.into().to_string()),
27        span,
28    }
29}
30
31pub(crate) fn parse_register_name(
32    stream: &mut PtxTokenStream,
33) -> Result<(String, Span), PtxParseError> {
34    let (mut name, mut span) = stream.expect_register()?;
35
36    loop {
37        // Peek to decide whether the next token should be treated as a component.
38        let next = match stream.peek() {
39            Ok((token, _)) => token,
40            Err(_) => break,
41        };
42
43        match next {
44            PtxToken::Dot => {
45                // Peek ahead to see if this is a valid register component
46                if let Some((PtxToken::Identifier(component_name), _)) =
47                    stream.tokens.get(stream.index + 1)
48                {
49                    // Only consume if it's a valid single-character register component
50                    // Exclude multi-character .b* patterns (e.g., .b0, .b3210) which are instruction-specific modifiers
51                    if matches!(component_name.as_str(), "x" | "y" | "z" | "w" | "r" | "g" | "b" | "a") {
52                        // consume the dot and identifier
53                        stream.consume()?;
54                        let (component, component_span) = stream.expect_identifier()?;
55
56                        name.push('.');
57                        name.push_str(&component);
58
59                        span.end = component_span.end;
60                    } else {
61                        // Not a valid register component, stop parsing
62                        break;
63                    }
64                } else {
65                    break;
66                }
67            }
68            _ => break,
69        }
70    }
71
72    Ok((name, span))
73}
74
75pub(crate) fn numeric_literal(token: &PtxToken) -> Option<&String> {
76    match token {
77        PtxToken::DecimalInteger(value)
78        | PtxToken::HexInteger(value)
79        | PtxToken::BinaryInteger(value)
80        | PtxToken::OctalInteger(value)
81        | PtxToken::FloatExponent(value)
82        | PtxToken::Float(value)
83        | PtxToken::HexFloat(value) => Some(value),
84        _ => None,
85    }
86}
87
88pub(crate) fn is_numeric_token(token: &PtxToken) -> bool {
89    numeric_literal(token).is_some()
90}
91
92pub(crate) fn parse_u64_literal(stream: &mut PtxTokenStream) -> Result<(u64, Span), PtxParseError> {
93    let (token, span) = stream.consume()?;
94    let span = span.clone();
95
96    let value = match token {
97        PtxToken::DecimalInteger(text) => text
98            .parse::<u64>()
99            .map_err(|_| invalid_literal(span.clone(), text.clone()))?,
100        PtxToken::HexInteger(text) => {
101            let stripped = text
102                .strip_prefix("0x")
103                .or_else(|| text.strip_prefix("0X"))
104                .ok_or_else(|| invalid_literal(span.clone(), text.clone()))?;
105            u64::from_str_radix(stripped, 16)
106                .map_err(|_| invalid_literal(span.clone(), text.clone()))?
107        }
108        PtxToken::BinaryInteger(text) => {
109            let stripped = text
110                .strip_prefix("0b")
111                .or_else(|| text.strip_prefix("0B"))
112                .ok_or_else(|| invalid_literal(span.clone(), text.clone()))?;
113            u64::from_str_radix(stripped, 2)
114                .map_err(|_| invalid_literal(span.clone(), text.clone()))?
115        }
116        PtxToken::OctalInteger(text) => {
117            let stripped = &text[1..];
118            u64::from_str_radix(stripped, 8)
119                .map_err(|_| invalid_literal(span.clone(), text.clone()))?
120        }
121        _ => {
122            return Err(unexpected_value(
123                span,
124                &["unsigned integer literal"],
125                format!("{token:?}"),
126            ));
127        }
128    };
129
130    Ok((value, span))
131}
132
133impl PtxParser for CodeLinkage {
134    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
135        let (directive, span) = stream.expect_directive()?;
136        match directive.as_str() {
137            "visible" => Ok(CodeLinkage::Visible),
138            "extern" => Ok(CodeLinkage::Extern),
139            "weak" => Ok(CodeLinkage::Weak),
140            other => Err(unexpected_value(
141                span,
142                &[".visible", ".extern", ".weak"],
143                format!(".{other}"),
144            )),
145        }
146    }
147}
148
149impl PtxParser for DataLinkage {
150    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
151        let (directive, span) = stream.expect_directive()?;
152        match directive.as_str() {
153            "visible" => Ok(DataLinkage::Visible),
154            "extern" => Ok(DataLinkage::Extern),
155            "weak" => Ok(DataLinkage::Weak),
156            "common" => Ok(DataLinkage::Common),
157            other => Err(unexpected_value(
158                span,
159                &[".visible", ".extern", ".weak", ".common"],
160                format!(".{other}"),
161            )),
162        }
163    }
164}
165
166impl PtxParser for CodeOrDataLinkage {
167    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
168        let (directive, span) = stream.expect_directive()?;
169        match directive.as_str() {
170            "visible" => Ok(CodeOrDataLinkage::Visible),
171            "extern" => Ok(CodeOrDataLinkage::Extern),
172            "weak" => Ok(CodeOrDataLinkage::Weak),
173            "common" => Ok(CodeOrDataLinkage::Common),
174            other => Err(unexpected_value(
175                span,
176                &[".visible", ".extern", ".weak", ".common"],
177                format!(".{other}"),
178            )),
179        }
180    }
181}
182
183impl PtxParser for TexType {
184    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
185        let (directive, span) = stream.expect_directive()?;
186        match directive.as_str() {
187            "texref" => Ok(TexType::TexRef),
188            "samplerref" => Ok(TexType::SamplerRef),
189            "surfref" => Ok(TexType::SurfRef),
190            other => Err(unexpected_value(
191                span,
192                &[".texref", ".samplerref", ".surfref"],
193                format!(".{other}"),
194            )),
195        }
196    }
197}
198
199impl PtxParser for AddressSpace {
200    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
201        let (directive, span) = stream.expect_directive()?;
202        match directive.as_str() {
203            "global" => Ok(AddressSpace::Global),
204            "const" => Ok(AddressSpace::Const),
205            "shared" => Ok(AddressSpace::Shared),
206            "local" => Ok(AddressSpace::Local),
207            "param" => Ok(AddressSpace::Param),
208            "reg" => Ok(AddressSpace::Reg),
209            other => Err(unexpected_value(
210                span,
211                &[".global", ".const", ".shared", ".local", ".param", ".reg"],
212                format!(".{other}"),
213            )),
214        }
215    }
216}
217
218impl PtxParser for AttributeDirective {
219    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
220        let (directive, span) = stream.expect_directive()?;
221        match directive.as_str() {
222            "unified" => {
223                stream.expect(&PtxToken::LParen)?;
224                let (uuid1, _) = parse_u64_literal(stream)?;
225                stream.expect(&PtxToken::Comma)?;
226                let (uuid2, _) = parse_u64_literal(stream)?;
227                stream.expect(&PtxToken::RParen)?;
228                Ok(AttributeDirective::Unified(uuid1, uuid2))
229            }
230            "managed" => Ok(AttributeDirective::Managed),
231            other => Err(unexpected_value(
232                span,
233                &[".unified", ".managed"],
234                format!(".{other}"),
235            )),
236        }
237    }
238}
239
240impl PtxParser for DataType {
241    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
242        let (directive, span) = stream.expect_directive()?;
243        match directive.as_str() {
244            "u8" => Ok(DataType::U8),
245            "u16" => Ok(DataType::U16),
246            "u32" => Ok(DataType::U32),
247            "u64" => Ok(DataType::U64),
248            "s8" => Ok(DataType::S8),
249            "s16" => Ok(DataType::S16),
250            "s32" => Ok(DataType::S32),
251            "s64" => Ok(DataType::S64),
252            "f16" => Ok(DataType::F16),
253            "f16x2" => Ok(DataType::F16x2),
254            "f32" => Ok(DataType::F32),
255            "f64" => Ok(DataType::F64),
256            "b8" => Ok(DataType::B8),
257            "b16" => Ok(DataType::B16),
258            "b32" => Ok(DataType::B32),
259            "b64" => Ok(DataType::B64),
260            "b128" => Ok(DataType::B128),
261            "pred" => Ok(DataType::Pred),
262            other => Err(unexpected_value(
263                span,
264                &[
265                    ".u8", ".u16", ".u32", ".u64", ".s8", ".s16", ".s32", ".s64", ".f16", ".f16x2",
266                    ".f32", ".f64", ".b8", ".b16", ".b32", ".b64", ".b128", ".pred",
267                ],
268                format!(".{other}"),
269            )),
270        }
271    }
272}
273
274impl PtxParser for Sign {
275    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
276        if stream
277            .consume_if(|token| matches!(token, PtxToken::Plus))
278            .is_some()
279        {
280            return Ok(Sign::Positive);
281        }
282        if stream
283            .consume_if(|token| matches!(token, PtxToken::Minus))
284            .is_some()
285        {
286            return Ok(Sign::Negative);
287        }
288
289        let (token, span) = stream.peek()?;
290        Err(unexpected_value(
291            span.clone(),
292            &["+", "-"],
293            format!("{token:?}"),
294        ))
295    }
296}
297
298impl PtxParser for Immediate {
299    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
300        // Check for optional minus sign
301        let has_minus = stream
302            .consume_if(|token| matches!(token, PtxToken::Minus))
303            .is_some();
304
305        let (token, span) = stream.peek()?;
306        let value = numeric_literal(token).cloned();
307        match value {
308            Some(value) => {
309                let literal = if has_minus {
310                    format!("-{}", value)
311                } else {
312                    value.clone()
313                };
314                stream.consume()?;
315                Ok(Immediate(literal))
316            }
317            None => {
318                // If we consumed a minus, we need to restore position by going back one token
319                if has_minus {
320                    let mut current_pos = stream.position();
321                    if current_pos.index > 0 {
322                        current_pos.index -= 1;
323                        current_pos.char_offset = 0;
324                        stream.set_position(current_pos);
325                    }
326                }
327                Err(unexpected_value(
328                    span.clone(),
329                    &["numeric literal"],
330                    format!("{token:?}"),
331                ))
332            }
333        }
334    }
335}
336
337impl PtxParser for RegisterOperand {
338    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
339        if !stream.check(|token| matches!(token, PtxToken::Register(_))) {
340            let (token, span) = stream.peek()?;
341            return Err(unexpected_value(
342                span.clone(),
343                &["register"],
344                format!("{token:?}"),
345            ));
346        }
347        let (name, _) = parse_register_name(stream)?;
348        Ok(RegisterOperand(name))
349    }
350}
351
352impl PtxParser for PredicateRegister {
353    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
354        let (name, span) = parse_register_name(stream)?;
355        if name.starts_with("%p") {
356            Ok(PredicateRegister(name))
357        } else {
358            Err(invalid_literal(
359                span,
360                format!("expected predicate register starting with %p, found {name}"),
361            ))
362        }
363    }
364}
365
366impl PtxParser for Label {
367    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
368        let (name, _) = stream.expect_identifier()?;
369        Ok(Label(name))
370    }
371}
372
373impl PtxParser for SpecialRegister {
374    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
375        let (name, span) = parse_register_name(stream)?;
376        // Preserve component information (.x/.y/.z) for certain special registers.
377        // If a component is present, return the axis-aware variant; otherwise fall through
378        // to the general match below.
379        let name_str = name.as_str();
380        if let Some(rest) = name_str.strip_prefix("%cluster_ctaid") {
381            if rest.is_empty() {
382                return Ok(SpecialRegister::ClusterCtaid(Axis::None));
383            } else if rest == ".x" {
384                return Ok(SpecialRegister::ClusterCtaid(Axis::X));
385            } else if rest == ".y" {
386                return Ok(SpecialRegister::ClusterCtaid(Axis::Y));
387            } else if rest == ".z" {
388                return Ok(SpecialRegister::ClusterCtaid(Axis::Z));
389            }
390        }
391        if let Some(rest) = name_str.strip_prefix("%cluster_ctarank") {
392            if rest.is_empty() {
393                return Ok(SpecialRegister::ClusterCtarank(Axis::None));
394            } else if rest == ".x" {
395                return Ok(SpecialRegister::ClusterCtarank(Axis::X));
396            } else if rest == ".y" {
397                return Ok(SpecialRegister::ClusterCtarank(Axis::Y));
398            } else if rest == ".z" {
399                return Ok(SpecialRegister::ClusterCtarank(Axis::Z));
400            }
401        }
402        if let Some(rest) = name_str.strip_prefix("%nctaid") {
403            if rest.is_empty() {
404                return Ok(SpecialRegister::Nctaid(Axis::None));
405            } else if rest == ".x" {
406                return Ok(SpecialRegister::Nctaid(Axis::X));
407            } else if rest == ".y" {
408                return Ok(SpecialRegister::Nctaid(Axis::Y));
409            } else if rest == ".z" {
410                return Ok(SpecialRegister::Nctaid(Axis::Z));
411            }
412        }
413        if let Some(rest) = name_str.strip_prefix("%tid") {
414            if rest.is_empty() {
415                return Ok(SpecialRegister::Tid(Axis::None));
416            } else if rest == ".x" {
417                return Ok(SpecialRegister::Tid(Axis::X));
418            } else if rest == ".y" {
419                return Ok(SpecialRegister::Tid(Axis::Y));
420            } else if rest == ".z" {
421                return Ok(SpecialRegister::Tid(Axis::Z));
422            }
423        }
424        if let Some(rest) = name_str.strip_prefix("%cluster_nctaid") {
425            if rest.is_empty() {
426                return Ok(SpecialRegister::ClusterNctaid(Axis::None));
427            } else if rest == ".x" {
428                return Ok(SpecialRegister::ClusterNctaid(Axis::X));
429            } else if rest == ".y" {
430                return Ok(SpecialRegister::ClusterNctaid(Axis::Y));
431            } else if rest == ".z" {
432                return Ok(SpecialRegister::ClusterNctaid(Axis::Z));
433            }
434        }
435        if let Some(rest) = name_str.strip_prefix("%cluster_nctarank") {
436            if rest.is_empty() {
437                return Ok(SpecialRegister::ClusterNctarank(Axis::None));
438            } else if rest == ".x" {
439                return Ok(SpecialRegister::ClusterNctarank(Axis::X));
440            } else if rest == ".y" {
441                return Ok(SpecialRegister::ClusterNctarank(Axis::Y));
442            } else if rest == ".z" {
443                return Ok(SpecialRegister::ClusterNctarank(Axis::Z));
444            }
445        }
446        if let Some(rest) = name_str.strip_prefix("%ntid") {
447            if rest.is_empty() {
448                return Ok(SpecialRegister::Ntid(Axis::None));
449            } else if rest == ".x" {
450                return Ok(SpecialRegister::Ntid(Axis::X));
451            } else if rest == ".y" {
452                return Ok(SpecialRegister::Ntid(Axis::Y));
453            } else if rest == ".z" {
454                return Ok(SpecialRegister::Ntid(Axis::Z));
455            }
456        }
457        if let Some(rest) = name_str.strip_prefix("%ctaid") {
458            if rest.is_empty() {
459                return Ok(SpecialRegister::Ctaid(Axis::None));
460            } else if rest == ".x" {
461                return Ok(SpecialRegister::Ctaid(Axis::X));
462            } else if rest == ".y" {
463                return Ok(SpecialRegister::Ctaid(Axis::Y));
464            } else if rest == ".z" {
465                return Ok(SpecialRegister::Ctaid(Axis::Z));
466            }
467        }
468
469        match name.as_str() {
470            "%aggr_smem_size" => Ok(SpecialRegister::AggrSmemSize),
471            "%dynamic_smem_size" => Ok(SpecialRegister::DynamicSmemSize),
472            "%lanemask_gt" => Ok(SpecialRegister::LanemaskGt),
473            "%reserved_smem_offset_begin" => Ok(SpecialRegister::ReservedSmemOffsetBegin),
474            "%clock" => Ok(SpecialRegister::Clock),
475            "%lanemask_le" => Ok(SpecialRegister::LanemaskLe),
476            "%reserved_smem_offset_cap" => Ok(SpecialRegister::ReservedSmemOffsetCap),
477            "%clock64" => Ok(SpecialRegister::Clock64),
478            "%globaltimer" => Ok(SpecialRegister::Globaltimer),
479            "%lanemask_lt" => Ok(SpecialRegister::LanemaskLt),
480            "%reserved_smem_offset_end" => Ok(SpecialRegister::ReservedSmemOffsetEnd),
481            "%cluster_ctaid" | "%cluster_ctaid.x" | "%cluster_ctaid.y" | "%cluster_ctaid.z" => {
482                Ok(SpecialRegister::ClusterCtaid(Axis::None))
483            }
484            "%globaltimer_hi" => Ok(SpecialRegister::GlobaltimerHi),
485            "%nclusterid" => Ok(SpecialRegister::Nclusterid),
486            "%smid" => Ok(SpecialRegister::Smid),
487            "%cluster_ctarank" | "%cluster_ctarank.x" | "%cluster_ctarank.y"
488            | "%cluster_ctarank.z" => Ok(SpecialRegister::ClusterCtarank(Axis::None)),
489            "%globaltimer_lo" => Ok(SpecialRegister::GlobaltimerLo),
490            "%nctaid" | "%nctaid.x" | "%nctaid.y" | "%nctaid.z" => {
491                Ok(SpecialRegister::Nctaid(Axis::None))
492            }
493            "%tid" | "%tid.x" | "%tid.y" | "%tid.z" => Ok(SpecialRegister::Tid(Axis::None)),
494            "%cluster_nctaid" | "%cluster_nctaid.x" | "%cluster_nctaid.y" | "%cluster_nctaid.z" => {
495                Ok(SpecialRegister::ClusterNctaid(Axis::None))
496            }
497            "%gridid" => Ok(SpecialRegister::Gridid),
498            "%nsmid" => Ok(SpecialRegister::Nsmid),
499            "%total_smem_size" => Ok(SpecialRegister::TotalSmemSize),
500            "%cluster_nctarank"
501            | "%cluster_nctarank.x"
502            | "%cluster_nctarank.y"
503            | "%cluster_nctarank.z" => Ok(SpecialRegister::ClusterNctarank(Axis::None)),
504            "%is_explicit_cluster" => Ok(SpecialRegister::IsExplicitCluster),
505            "%ntid" | "%ntid.x" | "%ntid.y" | "%ntid.z" => Ok(SpecialRegister::Ntid(Axis::None)),
506            "%warpid" => Ok(SpecialRegister::Warpid),
507            "%clusterid" => Ok(SpecialRegister::Clusterid),
508            "%laneid" => Ok(SpecialRegister::Laneid),
509            "%nwarpid" => Ok(SpecialRegister::Nwarpid),
510            "%WARPSZ" => Ok(SpecialRegister::WARPSZ),
511            "%ctaid" | "%ctaid.x" | "%ctaid.y" | "%ctaid.z" => {
512                Ok(SpecialRegister::Ctaid(Axis::None))
513            }
514            "%lanemask_eq" => Ok(SpecialRegister::LanemaskEq),
515            "%current_graph_exec" => Ok(SpecialRegister::CurrentGraphExec),
516            "%lanemask_ge" => Ok(SpecialRegister::LanemaskGe),
517            other => {
518                if let Some(num) = other.strip_prefix("%envreg") {
519                    let value = num
520                        .parse::<u8>()
521                        .map_err(|_| invalid_literal(span.clone(), name.clone()))?;
522                    if value <= 31 {
523                        return Ok(SpecialRegister::Envreg(value));
524                    }
525                    return Err(invalid_literal(
526                        span,
527                        format!("envreg index out of range: {value}"),
528                    ));
529                }
530
531                if let Some(num) = other.strip_prefix("%pm") {
532                    if let Some(rest) = num.strip_suffix("_64") {
533                        let value = rest
534                            .parse::<u8>()
535                            .map_err(|_| invalid_literal(span.clone(), name.clone()))?;
536                        if value <= 7 {
537                            return Ok(SpecialRegister::Pm64(value));
538                        }
539                        return Err(invalid_literal(
540                            span,
541                            format!("pm index out of range: {value}"),
542                        ));
543                    }
544
545                    let value = num
546                        .parse::<u8>()
547                        .map_err(|_| invalid_literal(span.clone(), name.clone()))?;
548                    if value <= 7 {
549                        return Ok(SpecialRegister::Pm(value));
550                    }
551                    return Err(invalid_literal(
552                        span,
553                        format!("pm index out of range: {value}"),
554                    ));
555                }
556
557                if let Some(num) = other.strip_prefix("%reserved_smem_offset_") {
558                    let value = num
559                        .parse::<u8>()
560                        .map_err(|_| invalid_literal(span.clone(), name.clone()))?;
561                    if value <= 1 {
562                        return Ok(SpecialRegister::ReservedSmemOffset(value));
563                    }
564                    return Err(invalid_literal(
565                        span,
566                        format!("reserved_smem_offset index out of range: {value}"),
567                    ));
568                }
569
570                Err(invalid_literal(
571                    span,
572                    format!("unknown special register {name}"),
573                ))
574            }
575        }
576    }
577}
578
579impl PtxParser for Operand {
580    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
581        let saved_pos = stream.position();
582        if let Ok(immediate) = Immediate::parse(stream) {
583            return Ok(Operand::Immediate(immediate));
584        }
585        stream.set_position(saved_pos);
586
587        if stream.check(|token| matches!(token, PtxToken::Register(_))) {
588            return Ok(Operand::Register(RegisterOperand::parse(stream)?));
589        }
590
591        if stream.check(|token| matches!(token, PtxToken::Identifier(_))) {
592            let (identifier, _) = stream.expect_identifier()?;
593            
594            // Check for arithmetic expression: identifier + immediate
595            let saved_pos_after_ident = stream.position();
596            if stream.expect(&PtxToken::Plus).is_ok() {
597                if let Ok(offset) = Immediate::parse(stream) {
598                    return Ok(Operand::SymbolOffset(identifier, offset));
599                }
600                // If parsing offset failed, backtrack
601                stream.set_position(saved_pos_after_ident);
602            }
603            
604            return Ok(Operand::Symbol(identifier));
605        }
606
607        let (token, span) = stream.peek()?;
608        Err(unexpected_value(
609            span.clone(),
610            &["operand"],
611            format!("{token:?}"),
612        ))
613    }
614}
615
616impl PtxParser for VectorOperand {
617    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
618        let (_, brace_span) = stream.expect(&PtxToken::LBrace)?;
619        let mut operands = Vec::new();
620
621        loop {
622            operands.push(Operand::parse(stream)?);
623            if stream
624                .consume_if(|token| matches!(token, PtxToken::Comma))
625                .is_some()
626            {
627                continue;
628            }
629            break;
630        }
631
632        stream.expect(&PtxToken::RBrace)?;
633
634        match operands.len() {
635            1 => Ok(VectorOperand::Vector1(operands.remove(0))),
636            2 => Ok(VectorOperand::Vector2([
637                operands.remove(0),
638                operands.remove(0),
639            ])),
640            3 => Ok(VectorOperand::Vector3([
641                operands.remove(0),
642                operands.remove(0),
643                operands.remove(0),
644            ])),
645            4 => Ok(VectorOperand::Vector4([
646                operands.remove(0),
647                operands.remove(0),
648                operands.remove(0),
649                operands.remove(0),
650            ])),
651            8 => Ok(VectorOperand::Vector8([
652                operands.remove(0),
653                operands.remove(0),
654                operands.remove(0),
655                operands.remove(0),
656                operands.remove(0),
657                operands.remove(0),
658                operands.remove(0),
659                operands.remove(0),
660            ])),
661            other => Err(invalid_literal(
662                brace_span.clone(),
663                format!("expected operand vector of length 1..=4 or 8, found {other}"),
664            )),
665        }
666    }
667}
668
669impl PtxParser for GeneralOperand {
670    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
671        if stream.check(|token| matches!(token, PtxToken::LBrace)) {
672            Ok(GeneralOperand::Vec(VectorOperand::parse(stream)?))
673        } else {
674            Ok(GeneralOperand::Single(Operand::parse(stream)?))
675        }
676    }
677}
678
679impl PtxParser for TexHandler2 {
680    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
681        stream.expect(&PtxToken::LBracket)?;
682        let first = GeneralOperand::parse(stream)?;
683        stream.expect(&PtxToken::Comma)?;
684        let second = GeneralOperand::parse(stream)?;
685        stream.expect(&PtxToken::RBracket)?;
686        Ok(TexHandler2([first, second]))
687    }
688}
689
690impl PtxParser for TexHandler3 {
691    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
692        stream.expect(&PtxToken::LBracket)?;
693        let handle = GeneralOperand::parse(stream)?;
694        stream.expect(&PtxToken::Comma)?;
695        let sampler = GeneralOperand::parse(stream)?;
696        stream.expect(&PtxToken::Comma)?;
697        let coords = GeneralOperand::parse(stream)?;
698        stream.expect(&PtxToken::RBracket)?;
699
700        Ok(TexHandler3 {
701            handle,
702            sampler,
703            coords,
704        })
705    }
706}
707
708impl PtxParser for TexHandler3Optional {
709    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
710        stream.expect(&PtxToken::LBracket)?;
711        let handle = GeneralOperand::parse(stream)?;
712        stream.expect(&PtxToken::Comma)?;
713        let second = GeneralOperand::parse(stream)?;
714
715        let (sampler, coords) = if stream
716            .consume_if(|token| matches!(token, PtxToken::Comma))
717            .is_some()
718        {
719            let coords = GeneralOperand::parse(stream)?;
720            (Some(second), coords)
721        } else {
722            (None, second)
723        };
724
725        stream.expect(&PtxToken::RBracket)?;
726
727        Ok(TexHandler3Optional {
728            handle,
729            sampler,
730            coords,
731        })
732    }
733}
734
735impl PtxParser for AddressBase {
736    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
737        if stream.check(|token| matches!(token, PtxToken::Register(_))) {
738            Ok(AddressBase::Register(RegisterOperand::parse(stream)?))
739        } else if stream.check(|token| matches!(token, PtxToken::Identifier(_))) {
740            Ok(AddressBase::Variable(VariableSymbol::parse(stream)?))
741        } else {
742            let (token, span) = stream.peek()?;
743            Err(unexpected_value(
744                span.clone(),
745                &["register", "identifier"],
746                format!("{token:?}"),
747            ))
748        }
749    }
750}
751
752impl PtxParser for AddressOffset {
753    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
754        if stream
755            .consume_if(|token| matches!(token, PtxToken::Plus))
756            .is_some()
757        {
758            if stream.check(|token| matches!(token, PtxToken::Register(_))) {
759                Ok(AddressOffset::Register(RegisterOperand::parse(stream)?))
760            } else {
761                Ok(AddressOffset::Immediate(
762                    Sign::Positive,
763                    Immediate::parse(stream)?,
764                ))
765            }
766        } else if stream
767            .consume_if(|token| matches!(token, PtxToken::Minus))
768            .is_some()
769        {
770            Ok(AddressOffset::Immediate(
771                Sign::Negative,
772                Immediate::parse(stream)?,
773            ))
774        } else {
775            let (token, span) = stream.peek()?;
776            Err(unexpected_value(
777                span.clone(),
778                &["+", "-"],
779                format!("{token:?}"),
780            ))
781        }
782    }
783}
784
785impl PtxParser for AddressOperand {
786    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
787        if stream.check(|token| matches!(token, PtxToken::Identifier(_))) {
788            let saved = stream.position();
789            let (identifier, _) = stream.expect_identifier()?;
790            if stream
791                .consume_if(|token| matches!(token, PtxToken::LBracket))
792                .is_some()
793            {
794                let immediate = Immediate::parse(stream)?;
795                stream.expect(&PtxToken::RBracket)?;
796                return Ok(AddressOperand::Array(VariableSymbol(identifier), immediate));
797            } else {
798                stream.set_position(saved);
799            }
800        }
801
802        stream.expect(&PtxToken::LBracket)?;
803
804        if stream.check(|token| matches!(token, PtxToken::Minus)) {
805            let pos = stream.position();
806            stream.consume()?;
807            if stream.check(|token| is_numeric_token(token)) {
808                let mut immediate = Immediate::parse(stream)?;
809                immediate.0.insert(0, '-');
810                stream.expect(&PtxToken::RBracket)?;
811                return Ok(AddressOperand::ImmediateAddress(immediate));
812            } else {
813                stream.set_position(pos);
814            }
815        }
816
817        if stream.check(|token| is_numeric_token(token)) {
818            let immediate = Immediate::parse(stream)?;
819            stream.expect(&PtxToken::RBracket)?;
820            return Ok(AddressOperand::ImmediateAddress(immediate));
821        }
822
823        let base = AddressBase::parse(stream)?;
824        let offset = if stream.check(|token| matches!(token, PtxToken::Plus | PtxToken::Minus)) {
825            Some(AddressOffset::parse(stream)?)
826        } else {
827            None
828        };
829        stream.expect(&PtxToken::RBracket)?;
830
831        Ok(AddressOperand::Offset(base, offset))
832    }
833}
834
835impl PtxParser for FunctionSymbol {
836    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
837        let (name, _) = stream.expect_identifier()?;
838        Ok(FunctionSymbol(name))
839    }
840}
841
842impl PtxParser for VariableSymbol {
843    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
844        let (name, _) = stream.expect_identifier()?;
845        Ok(VariableSymbol(name))
846    }
847}
848
849/// Try to parse an optional label (identifier followed by colon).
850/// Returns `Ok(Some(label))` if a label is found, `Ok(None)` if not,
851/// or `Err` if parsing fails.
852pub(crate) fn try_parse_label(
853    stream: &mut PtxTokenStream,
854) -> Result<Option<String>, PtxParseError> {
855    if !stream.check(|token| matches!(token, PtxToken::Identifier(_))) {
856        return Ok(None);
857    }
858
859    let position = stream.position();
860    let (name, _) = stream.expect_identifier()?;
861    if stream
862        .consume_if(|token| matches!(token, PtxToken::Colon))
863        .is_some()
864    {
865        Ok(Some(name))
866    } else {
867        stream.set_position(position);
868        Ok(None)
869    }
870}
871
872impl PtxParser for Instruction {
873    /// Parse a PTX instruction with optional label and predicate
874    ///
875    /// Format: [label:] [@{!}pred] instruction
876    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
877        // Optional label (ends with colon)
878        let label = try_parse_label(stream)?;
879        
880        // Optional predicate: @{!}pred or @!pred
881        let predicate = if stream.check(|t| matches!(t, PtxToken::At)) {
882            stream.consume()?; // consume @
883            
884            // Optional negation
885            let negated = stream.consume_if(|t| matches!(t, PtxToken::Exclaim)).is_some();
886
887            // Predicate operand (can be register %p1 or identifier p)
888            let operand = Operand::parse(stream)?;
889
890            Some(Predicate { negated, operand })
891        } else {
892            None
893        };
894        
895        // Parse the actual instruction using the module-level dispatcher
896        let inst = crate::parser::instruction::parse_instruction_inner(stream)?;
897        
898        Ok(Instruction { label, predicate, inst })
899    }
900}
901
902// Backwards compatibility: Inst can still be parsed directly
903impl PtxParser for Inst {
904    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
905        Ok(Instruction::parse(stream)?.inst)
906    }
907}