1use std::borrow::Cow;
2
3use crate::{
4 lexer::PtxToken,
5 parser::{ParseErrorKind, PtxParseError, PtxParser, PtxTokenStream, Span},
6 r#type::common::*,
7 r#type::instruction::Inst,
8};
9
10pub(crate) fn unexpected_value(
11 span: Span,
12 expected: &[&str],
13 found: impl Into<Cow<'static, str>>,
14) -> PtxParseError {
15 PtxParseError {
16 kind: ParseErrorKind::UnexpectedToken {
17 expected: expected.iter().map(|s| s.to_string()).collect(),
18 found: found.into().to_string(),
19 },
20 span,
21 }
22}
23
24pub(crate) fn invalid_literal(span: Span, literal: impl Into<Cow<'static, str>>) -> PtxParseError {
25 PtxParseError {
26 kind: ParseErrorKind::InvalidLiteral(literal.into().to_string()),
27 span,
28 }
29}
30
31pub(crate) fn parse_register_name(
32 stream: &mut PtxTokenStream,
33) -> Result<(String, Span), PtxParseError> {
34 let (mut name, mut span) = stream.expect_register()?;
35
36 loop {
37 let next = match stream.peek() {
39 Ok((token, _)) => token,
40 Err(_) => break,
41 };
42
43 match next {
44 PtxToken::Dot => {
45 if let Some((PtxToken::Identifier(component_name), _)) =
47 stream.tokens.get(stream.index + 1)
48 {
49 if matches!(
52 component_name.as_str(),
53 "x" | "y" | "z" | "w" | "r" | "g" | "b" | "a"
54 ) {
55 stream.consume()?;
57 let (component, component_span) = stream.expect_identifier()?;
58
59 name.push('.');
60 name.push_str(&component);
61
62 span.end = component_span.end;
63 } else {
64 break;
66 }
67 } else {
68 break;
69 }
70 }
71 _ => break,
72 }
73 }
74
75 Ok((name, span))
76}
77
78pub(crate) fn numeric_literal(token: &PtxToken) -> Option<&String> {
79 match token {
80 PtxToken::DecimalInteger(value)
81 | PtxToken::HexInteger(value)
82 | PtxToken::BinaryInteger(value)
83 | PtxToken::OctalInteger(value)
84 | PtxToken::FloatExponent(value)
85 | PtxToken::Float(value)
86 | PtxToken::HexFloat(value) => Some(value),
87 _ => None,
88 }
89}
90
91pub(crate) fn is_numeric_token(token: &PtxToken) -> bool {
92 numeric_literal(token).is_some()
93}
94
95pub(crate) fn parse_u64_literal(stream: &mut PtxTokenStream) -> Result<(u64, Span), PtxParseError> {
96 let (token, span) = stream.consume()?;
97 let span = span.clone();
98
99 let value = match token {
100 PtxToken::DecimalInteger(text) => text
101 .parse::<u64>()
102 .map_err(|_| invalid_literal(span.clone(), text.clone()))?,
103 PtxToken::HexInteger(text) => {
104 let stripped = text
105 .strip_prefix("0x")
106 .or_else(|| text.strip_prefix("0X"))
107 .ok_or_else(|| invalid_literal(span.clone(), text.clone()))?;
108 u64::from_str_radix(stripped, 16)
109 .map_err(|_| invalid_literal(span.clone(), text.clone()))?
110 }
111 PtxToken::BinaryInteger(text) => {
112 let stripped = text
113 .strip_prefix("0b")
114 .or_else(|| text.strip_prefix("0B"))
115 .ok_or_else(|| invalid_literal(span.clone(), text.clone()))?;
116 u64::from_str_radix(stripped, 2)
117 .map_err(|_| invalid_literal(span.clone(), text.clone()))?
118 }
119 PtxToken::OctalInteger(text) => {
120 let stripped = &text[1..];
121 u64::from_str_radix(stripped, 8)
122 .map_err(|_| invalid_literal(span.clone(), text.clone()))?
123 }
124 _ => {
125 return Err(unexpected_value(
126 span,
127 &["unsigned integer literal"],
128 format!("{token:?}"),
129 ));
130 }
131 };
132
133 Ok((value, span))
134}
135
136impl PtxParser for CodeLinkage {
137 fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
138 let (directive, span) = stream.expect_directive()?;
139 match directive.as_str() {
140 "visible" => Ok(CodeLinkage::Visible { span }),
141 "extern" => Ok(CodeLinkage::Extern { span }),
142 "weak" => Ok(CodeLinkage::Weak { span }),
143 other => Err(unexpected_value(
144 span,
145 &[".visible", ".extern", ".weak"],
146 format!(".{other}"),
147 )),
148 }
149 }
150}
151
152impl PtxParser for DataLinkage {
153 fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
154 let (directive, span) = stream.expect_directive()?;
155 match directive.as_str() {
156 "visible" => Ok(DataLinkage::Visible { span }),
157 "extern" => Ok(DataLinkage::Extern { span }),
158 "weak" => Ok(DataLinkage::Weak { span }),
159 "common" => Ok(DataLinkage::Common { span }),
160 other => Err(unexpected_value(
161 span,
162 &[".visible", ".extern", ".weak", ".common"],
163 format!(".{other}"),
164 )),
165 }
166 }
167}
168
169impl PtxParser for CodeOrDataLinkage {
170 fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
171 let (directive, span) = stream.expect_directive()?;
172 match directive.as_str() {
173 "visible" => Ok(CodeOrDataLinkage::Visible { span }),
174 "extern" => Ok(CodeOrDataLinkage::Extern { span }),
175 "weak" => Ok(CodeOrDataLinkage::Weak { span }),
176 "common" => Ok(CodeOrDataLinkage::Common { span }),
177 other => Err(unexpected_value(
178 span,
179 &[".visible", ".extern", ".weak", ".common"],
180 format!(".{other}"),
181 )),
182 }
183 }
184}
185
186impl PtxParser for TexType {
187 fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
188 let (directive, span) = stream.expect_directive()?;
189 match directive.as_str() {
190 "texref" => Ok(TexType::TexRef { span }),
191 "samplerref" => Ok(TexType::SamplerRef { span }),
192 "surfref" => Ok(TexType::SurfRef { span }),
193 other => Err(unexpected_value(
194 span,
195 &[".texref", ".samplerref", ".surfref"],
196 format!(".{other}"),
197 )),
198 }
199 }
200}
201
202impl PtxParser for AddressSpace {
203 fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
204 let (directive, span) = stream.expect_directive()?;
205 match directive.as_str() {
206 "global" => Ok(AddressSpace::Global { span }),
207 "const" => Ok(AddressSpace::Const { span }),
208 "shared" => Ok(AddressSpace::Shared { span }),
209 "local" => Ok(AddressSpace::Local { span }),
210 "param" => Ok(AddressSpace::Param { span }),
211 "reg" => Ok(AddressSpace::Reg { span }),
212 other => Err(unexpected_value(
213 span,
214 &[".global", ".const", ".shared", ".local", ".param", ".reg"],
215 format!(".{other}"),
216 )),
217 }
218 }
219}
220
221impl PtxParser for AttributeDirective {
222 fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
223 let (directive, span) = stream.expect_directive()?;
224 match directive.as_str() {
225 "unified" => {
226 stream.expect(&PtxToken::LParen)?;
227 let (uuid1, _) = parse_u64_literal(stream)?;
228 stream.expect(&PtxToken::Comma)?;
229 let (uuid2, _) = parse_u64_literal(stream)?;
230 stream.expect(&PtxToken::RParen)?;
231 Ok(AttributeDirective::Unified { uuid1, uuid2, span })
232 }
233 "managed" => Ok(AttributeDirective::Managed { span }),
234 other => Err(unexpected_value(
235 span,
236 &[".unified", ".managed"],
237 format!(".{other}"),
238 )),
239 }
240 }
241}
242
243impl PtxParser for DataType {
244 fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
245 let (directive, span) = stream.expect_directive()?;
246 match directive.as_str() {
247 "u8" => Ok(DataType::U8 { span }),
248 "u16" => Ok(DataType::U16 { span }),
249 "u32" => Ok(DataType::U32 { span }),
250 "u64" => Ok(DataType::U64 { span }),
251 "s8" => Ok(DataType::S8 { span }),
252 "s16" => Ok(DataType::S16 { span }),
253 "s32" => Ok(DataType::S32 { span }),
254 "s64" => Ok(DataType::S64 { span }),
255 "f16" => Ok(DataType::F16 { span }),
256 "f16x2" => Ok(DataType::F16x2 { span }),
257 "f32" => Ok(DataType::F32 { span }),
258 "f64" => Ok(DataType::F64 { span }),
259 "b8" => Ok(DataType::B8 { span }),
260 "b16" => Ok(DataType::B16 { span }),
261 "b32" => Ok(DataType::B32 { span }),
262 "b64" => Ok(DataType::B64 { span }),
263 "b128" => Ok(DataType::B128 { span }),
264 "pred" => Ok(DataType::Pred { span }),
265 other => Err(unexpected_value(
266 span,
267 &[
268 ".u8", ".u16", ".u32", ".u64", ".s8", ".s16", ".s32", ".s64", ".f16", ".f16x2",
269 ".f32", ".f64", ".b8", ".b16", ".b32", ".b64", ".b128", ".pred",
270 ],
271 format!(".{other}"),
272 )),
273 }
274 }
275}
276
277impl PtxParser for Sign {
278 fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
279 if let Some((_, span)) = stream
280 .consume_if(|token| matches!(token, PtxToken::Plus))
281 {
282 return Ok(Sign::Positive { span: span.clone() });
283 }
284 if let Some((_, span)) = stream
285 .consume_if(|token| matches!(token, PtxToken::Minus))
286 {
287 return Ok(Sign::Negative { span: span.clone() });
288 }
289
290 let (token, span) = stream.peek()?;
291 Err(unexpected_value(
292 span.clone(),
293 &["+", "-"],
294 format!("{token:?}"),
295 ))
296 }
297}
298
299impl PtxParser for Immediate {
300 fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
301 let minus_span = stream
303 .consume_if(|token| matches!(token, PtxToken::Minus))
304 .map(|(_, span)| span.clone());
305
306 let (token, span) = stream.peek()?;
307 let value = numeric_literal(token).cloned();
308 match value {
309 Some(value) => {
310 let literal = if minus_span.is_some() {
311 format!("-{}", value)
312 } else {
313 value.clone()
314 };
315 let (_, value_span) = stream.consume()?;
316 let full_span = if let Some(ref ms) = minus_span {
317 Span { start: ms.start, end: value_span.end }
318 } else {
319 value_span.clone()
320 };
321 Ok(Immediate { value: literal, span: full_span })
322 }
323 None => {
324 if minus_span.is_some() {
326 let mut current_pos = stream.position();
327 if current_pos.index > 0 {
328 current_pos.index -= 1;
329 current_pos.char_offset = 0;
330 stream.set_position(current_pos);
331 }
332 }
333 Err(unexpected_value(
334 span.clone(),
335 &["numeric literal"],
336 format!("{token:?}"),
337 ))
338 }
339 }
340 }
341}
342
343impl PtxParser for RegisterOperand {
344 fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
345 if !stream.check(|token| matches!(token, PtxToken::Register(_))) {
346 let (token, span) = stream.peek()?;
347 return Err(unexpected_value(
348 span.clone(),
349 &["register"],
350 format!("{token:?}"),
351 ));
352 }
353 let (name, span) = parse_register_name(stream)?;
354 Ok(RegisterOperand { name, span })
355 }
356}
357
358impl PtxParser for PredicateRegister {
359 fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
360 let (name, span) = parse_register_name(stream)?;
361 if name.starts_with("%p") {
362 Ok(PredicateRegister { name, span })
363 } else {
364 Err(invalid_literal(
365 span,
366 format!("expected predicate register starting with %p, found {name}"),
367 ))
368 }
369 }
370}
371
372impl PtxParser for Label {
373 fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
374 let (name, span) = stream.expect_identifier()?;
375 Ok(Label { name, span })
376 }
377}
378
379impl PtxParser for SpecialRegister {
380 fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
381 let (name, span) = parse_register_name(stream)?;
382 let name_str = name.as_str();
386 if let Some(rest) = name_str.strip_prefix("%cluster_ctaid") {
387 if rest.is_empty() {
388 return Ok(SpecialRegister::ClusterCtaid { axis: Axis::None { span: span.clone() }, span });
389 } else if rest == ".x" {
390 return Ok(SpecialRegister::ClusterCtaid { axis: Axis::X { span: span.clone() }, span });
391 } else if rest == ".y" {
392 return Ok(SpecialRegister::ClusterCtaid { axis: Axis::Y { span: span.clone() }, span });
393 } else if rest == ".z" {
394 return Ok(SpecialRegister::ClusterCtaid { axis: Axis::Z { span: span.clone() }, span });
395 }
396 }
397 if let Some(rest) = name_str.strip_prefix("%cluster_ctarank") {
398 if rest.is_empty() {
399 return Ok(SpecialRegister::ClusterCtarank { axis: Axis::None { span: span.clone() }, span });
400 } else if rest == ".x" {
401 return Ok(SpecialRegister::ClusterCtarank { axis: Axis::X { span: span.clone() }, span });
402 } else if rest == ".y" {
403 return Ok(SpecialRegister::ClusterCtarank { axis: Axis::Y { span: span.clone() }, span });
404 } else if rest == ".z" {
405 return Ok(SpecialRegister::ClusterCtarank { axis: Axis::Z { span: span.clone() }, span });
406 }
407 }
408 if let Some(rest) = name_str.strip_prefix("%nctaid") {
409 if rest.is_empty() {
410 return Ok(SpecialRegister::Nctaid { axis: Axis::None { span: span.clone() }, span });
411 } else if rest == ".x" {
412 return Ok(SpecialRegister::Nctaid { axis: Axis::X { span: span.clone() }, span });
413 } else if rest == ".y" {
414 return Ok(SpecialRegister::Nctaid { axis: Axis::Y { span: span.clone() }, span });
415 } else if rest == ".z" {
416 return Ok(SpecialRegister::Nctaid { axis: Axis::Z { span: span.clone() }, span });
417 }
418 }
419 if let Some(rest) = name_str.strip_prefix("%tid") {
420 if rest.is_empty() {
421 return Ok(SpecialRegister::Tid { axis: Axis::None { span: span.clone() }, span });
422 } else if rest == ".x" {
423 return Ok(SpecialRegister::Tid { axis: Axis::X { span: span.clone() }, span });
424 } else if rest == ".y" {
425 return Ok(SpecialRegister::Tid { axis: Axis::Y { span: span.clone() }, span });
426 } else if rest == ".z" {
427 return Ok(SpecialRegister::Tid { axis: Axis::Z { span: span.clone() }, span });
428 }
429 }
430 if let Some(rest) = name_str.strip_prefix("%cluster_nctaid") {
431 if rest.is_empty() {
432 return Ok(SpecialRegister::ClusterNctaid { axis: Axis::None { span: span.clone() }, span });
433 } else if rest == ".x" {
434 return Ok(SpecialRegister::ClusterNctaid { axis: Axis::X { span: span.clone() }, span });
435 } else if rest == ".y" {
436 return Ok(SpecialRegister::ClusterNctaid { axis: Axis::Y { span: span.clone() }, span });
437 } else if rest == ".z" {
438 return Ok(SpecialRegister::ClusterNctaid { axis: Axis::Z { span: span.clone() }, span });
439 }
440 }
441 if let Some(rest) = name_str.strip_prefix("%cluster_nctarank") {
442 if rest.is_empty() {
443 return Ok(SpecialRegister::ClusterNctarank { axis: Axis::None { span: span.clone() }, span });
444 } else if rest == ".x" {
445 return Ok(SpecialRegister::ClusterNctarank { axis: Axis::X { span: span.clone() }, span });
446 } else if rest == ".y" {
447 return Ok(SpecialRegister::ClusterNctarank { axis: Axis::Y { span: span.clone() }, span });
448 } else if rest == ".z" {
449 return Ok(SpecialRegister::ClusterNctarank { axis: Axis::Z { span: span.clone() }, span });
450 }
451 }
452 if let Some(rest) = name_str.strip_prefix("%ntid") {
453 if rest.is_empty() {
454 return Ok(SpecialRegister::Ntid { axis: Axis::None { span: span.clone() }, span });
455 } else if rest == ".x" {
456 return Ok(SpecialRegister::Ntid { axis: Axis::X { span: span.clone() }, span });
457 } else if rest == ".y" {
458 return Ok(SpecialRegister::Ntid { axis: Axis::Y { span: span.clone() }, span });
459 } else if rest == ".z" {
460 return Ok(SpecialRegister::Ntid { axis: Axis::Z { span: span.clone() }, span });
461 }
462 }
463 if let Some(rest) = name_str.strip_prefix("%ctaid") {
464 if rest.is_empty() {
465 return Ok(SpecialRegister::Ctaid { axis: Axis::None { span: span.clone() }, span });
466 } else if rest == ".x" {
467 return Ok(SpecialRegister::Ctaid { axis: Axis::X { span: span.clone() }, span });
468 } else if rest == ".y" {
469 return Ok(SpecialRegister::Ctaid { axis: Axis::Y { span: span.clone() }, span });
470 } else if rest == ".z" {
471 return Ok(SpecialRegister::Ctaid { axis: Axis::Z { span: span.clone() }, span });
472 }
473 }
474
475 match name.as_str() {
476 "%aggr_smem_size" => Ok(SpecialRegister::AggrSmemSize { span }),
477 "%dynamic_smem_size" => Ok(SpecialRegister::DynamicSmemSize { span }),
478 "%lanemask_gt" => Ok(SpecialRegister::LanemaskGt { span }),
479 "%reserved_smem_offset_begin" => Ok(SpecialRegister::ReservedSmemOffsetBegin { span }),
480 "%clock" => Ok(SpecialRegister::Clock { span }),
481 "%lanemask_le" => Ok(SpecialRegister::LanemaskLe { span }),
482 "%reserved_smem_offset_cap" => Ok(SpecialRegister::ReservedSmemOffsetCap { span }),
483 "%clock64" => Ok(SpecialRegister::Clock64 { span }),
484 "%globaltimer" => Ok(SpecialRegister::Globaltimer { span }),
485 "%lanemask_lt" => Ok(SpecialRegister::LanemaskLt { span }),
486 "%reserved_smem_offset_end" => Ok(SpecialRegister::ReservedSmemOffsetEnd { span }),
487 "%cluster_ctaid" | "%cluster_ctaid.x" | "%cluster_ctaid.y" | "%cluster_ctaid.z" => {
488 Ok(SpecialRegister::ClusterCtaid { axis: Axis::None { span: span.clone() }, span })
489 }
490 "%globaltimer_hi" => Ok(SpecialRegister::GlobaltimerHi { span }),
491 "%nclusterid" => Ok(SpecialRegister::Nclusterid { span }),
492 "%smid" => Ok(SpecialRegister::Smid { span }),
493 "%cluster_ctarank" | "%cluster_ctarank.x" | "%cluster_ctarank.y"
494 | "%cluster_ctarank.z" => Ok(SpecialRegister::ClusterCtarank { axis: Axis::None { span: span.clone() }, span }),
495 "%globaltimer_lo" => Ok(SpecialRegister::GlobaltimerLo { span }),
496 "%nctaid" | "%nctaid.x" | "%nctaid.y" | "%nctaid.z" => {
497 Ok(SpecialRegister::Nctaid { axis: Axis::None { span: span.clone() }, span })
498 }
499 "%tid" | "%tid.x" | "%tid.y" | "%tid.z" => Ok(SpecialRegister::Tid { axis: Axis::None { span: span.clone() }, span }),
500 "%cluster_nctaid" | "%cluster_nctaid.x" | "%cluster_nctaid.y" | "%cluster_nctaid.z" => {
501 Ok(SpecialRegister::ClusterNctaid { axis: Axis::None { span: span.clone() }, span })
502 }
503 "%gridid" => Ok(SpecialRegister::Gridid { span }),
504 "%nsmid" => Ok(SpecialRegister::Nsmid { span }),
505 "%total_smem_size" => Ok(SpecialRegister::TotalSmemSize { span }),
506 "%cluster_nctarank"
507 | "%cluster_nctarank.x"
508 | "%cluster_nctarank.y"
509 | "%cluster_nctarank.z" => Ok(SpecialRegister::ClusterNctarank { axis: Axis::None { span: span.clone() }, span }),
510 "%is_explicit_cluster" => Ok(SpecialRegister::IsExplicitCluster { span }),
511 "%ntid" | "%ntid.x" | "%ntid.y" | "%ntid.z" => Ok(SpecialRegister::Ntid { axis: Axis::None { span: span.clone() }, span }),
512 "%warpid" => Ok(SpecialRegister::Warpid { span }),
513 "%clusterid" => Ok(SpecialRegister::Clusterid { span }),
514 "%laneid" => Ok(SpecialRegister::Laneid { span }),
515 "%nwarpid" => Ok(SpecialRegister::Nwarpid { span }),
516 "%WARPSZ" => Ok(SpecialRegister::WARPSZ { span }),
517 "%ctaid" | "%ctaid.x" | "%ctaid.y" | "%ctaid.z" => {
518 Ok(SpecialRegister::Ctaid { axis: Axis::None { span: span.clone() }, span })
519 }
520 "%lanemask_eq" => Ok(SpecialRegister::LanemaskEq { span }),
521 "%current_graph_exec" => Ok(SpecialRegister::CurrentGraphExec { span }),
522 "%lanemask_ge" => Ok(SpecialRegister::LanemaskGe { span }),
523 other => {
524 if let Some(num) = other.strip_prefix("%envreg") {
525 let value = num
526 .parse::<u8>()
527 .map_err(|_| invalid_literal(span.clone(), name.clone()))?;
528 if value <= 31 {
529 return Ok(SpecialRegister::Envreg { index: value, span });
530 }
531 return Err(invalid_literal(
532 span,
533 format!("envreg index out of range: {value}"),
534 ));
535 }
536
537 if let Some(num) = other.strip_prefix("%pm") {
538 if let Some(rest) = num.strip_suffix("_64") {
539 let value = rest
540 .parse::<u8>()
541 .map_err(|_| invalid_literal(span.clone(), name.clone()))?;
542 if value <= 7 {
543 return Ok(SpecialRegister::Pm64 { index: value, span });
544 }
545 return Err(invalid_literal(
546 span,
547 format!("pm index out of range: {value}"),
548 ));
549 }
550
551 let value = num
552 .parse::<u8>()
553 .map_err(|_| invalid_literal(span.clone(), name.clone()))?;
554 if value <= 7 {
555 return Ok(SpecialRegister::Pm { index: value, span });
556 }
557 return Err(invalid_literal(
558 span,
559 format!("pm index out of range: {value}"),
560 ));
561 }
562
563 if let Some(num) = other.strip_prefix("%reserved_smem_offset_") {
564 let value = num
565 .parse::<u8>()
566 .map_err(|_| invalid_literal(span.clone(), name.clone()))?;
567 if value <= 1 {
568 return Ok(SpecialRegister::ReservedSmemOffset { index: value, span });
569 }
570 return Err(invalid_literal(
571 span,
572 format!("reserved_smem_offset index out of range: {value}"),
573 ));
574 }
575
576 Err(invalid_literal(
577 span,
578 format!("unknown special register {name}"),
579 ))
580 }
581 }
582 }
583}
584
585impl PtxParser for Operand {
586 fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
587 let saved_pos = stream.position();
588 if let Ok(immediate) = Immediate::parse(stream) {
589 let span = immediate.span.clone();
590 return Ok(Operand::Immediate { operand: immediate, span });
591 }
592 stream.set_position(saved_pos);
593
594 if stream.check(|token| matches!(token, PtxToken::Register(_))) {
595 let register = RegisterOperand::parse(stream)?;
596 let span = register.span.clone();
597 return Ok(Operand::Register { operand: register, span });
598 }
599
600 if stream.check(|token| matches!(token, PtxToken::Identifier(_))) {
601 let (identifier, ident_span) = stream.expect_identifier()?;
602
603 let saved_pos_after_ident = stream.position();
605 if stream.expect(&PtxToken::Plus).is_ok() {
606 if let Ok(offset) = Immediate::parse(stream) {
607 let span = Span { start: ident_span.start, end: offset.span.end };
608 return Ok(Operand::SymbolOffset { symbol: identifier, offset, span });
609 }
610 stream.set_position(saved_pos_after_ident);
612 }
613
614 return Ok(Operand::Symbol { name: identifier, span: ident_span });
615 }
616
617 let (token, span) = stream.peek()?;
618 Err(unexpected_value(
619 span.clone(),
620 &["operand"],
621 format!("{token:?}"),
622 ))
623 }
624}
625
626impl PtxParser for VectorOperand {
627 fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
628 let (_, brace_span) = stream.expect(&PtxToken::LBrace)?;
629 let mut operands = Vec::new();
630
631 loop {
632 operands.push(Operand::parse(stream)?);
633 if stream
634 .consume_if(|token| matches!(token, PtxToken::Comma))
635 .is_some()
636 {
637 continue;
638 }
639 break;
640 }
641
642 let (_, end_span) = stream.expect(&PtxToken::RBrace)?;
643 let span = Span { start: brace_span.start, end: end_span.end };
644
645 match operands.len() {
646 1 => Ok(VectorOperand::Vector1 { operand: operands.remove(0), span }),
647 2 => Ok(VectorOperand::Vector2 { operands: [
648 operands.remove(0),
649 operands.remove(0),
650 ], span }),
651 3 => Ok(VectorOperand::Vector3 { operands: [
652 operands.remove(0),
653 operands.remove(0),
654 operands.remove(0),
655 ], span }),
656 4 => Ok(VectorOperand::Vector4 { operands: [
657 operands.remove(0),
658 operands.remove(0),
659 operands.remove(0),
660 operands.remove(0),
661 ], span }),
662 8 => Ok(VectorOperand::Vector8 { operands: [
663 operands.remove(0),
664 operands.remove(0),
665 operands.remove(0),
666 operands.remove(0),
667 operands.remove(0),
668 operands.remove(0),
669 operands.remove(0),
670 operands.remove(0),
671 ], span }),
672 other => Err(invalid_literal(
673 brace_span.clone(),
674 format!("expected operand vector of length 1..=4 or 8, found {other}"),
675 )),
676 }
677 }
678}
679
680impl PtxParser for GeneralOperand {
681 fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
682 if stream.check(|token| matches!(token, PtxToken::LBrace)) {
683 let vec_operand = VectorOperand::parse(stream)?;
684 let span = vec_operand.span();
685 Ok(GeneralOperand::Vec { operand: vec_operand, span })
686 } else {
687 let operand = Operand::parse(stream)?;
688 let span = operand.span();
689 Ok(GeneralOperand::Single { operand, span })
690 }
691 }
692}
693
694impl PtxParser for TexHandler2 {
695 fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
696 let (_, start_span) = stream.expect(&PtxToken::LBracket)?;
697 let first = GeneralOperand::parse(stream)?;
698 stream.expect(&PtxToken::Comma)?;
699 let second = GeneralOperand::parse(stream)?;
700 let (_, end_span) = stream.expect(&PtxToken::RBracket)?;
701 let span = Span { start: start_span.start, end: end_span.end };
702 Ok(TexHandler2 { operands: [first, second], span })
703 }
704}
705
706impl PtxParser for TexHandler3 {
707 fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
708 let (_, start_span) = stream.expect(&PtxToken::LBracket)?;
709 let handle = GeneralOperand::parse(stream)?;
710 stream.expect(&PtxToken::Comma)?;
711 let sampler = GeneralOperand::parse(stream)?;
712 stream.expect(&PtxToken::Comma)?;
713 let coords = GeneralOperand::parse(stream)?;
714 let (_, end_span) = stream.expect(&PtxToken::RBracket)?;
715 let span = Span { start: start_span.start, end: end_span.end };
716
717 Ok(TexHandler3 {
718 handle,
719 sampler,
720 coords,
721 span,
722 })
723 }
724}
725
726impl PtxParser for TexHandler3Optional {
727 fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
728 let (_, start_span) = stream.expect(&PtxToken::LBracket)?;
729 let handle = GeneralOperand::parse(stream)?;
730 stream.expect(&PtxToken::Comma)?;
731 let second = GeneralOperand::parse(stream)?;
732
733 let (sampler, coords) = if stream
734 .consume_if(|token| matches!(token, PtxToken::Comma))
735 .is_some()
736 {
737 let coords = GeneralOperand::parse(stream)?;
738 (Some(second), coords)
739 } else {
740 (None, second)
741 };
742
743 let (_, end_span) = stream.expect(&PtxToken::RBracket)?;
744 let span = Span { start: start_span.start, end: end_span.end };
745
746 Ok(TexHandler3Optional {
747 handle,
748 sampler,
749 coords,
750 span,
751 })
752 }
753}
754
755impl PtxParser for AddressBase {
756 fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
757 if stream.check(|token| matches!(token, PtxToken::Register(_))) {
758 let register = RegisterOperand::parse(stream)?;
759 let span = register.span.clone();
760 Ok(AddressBase::Register { operand: register, span })
761 } else if stream.check(|token| matches!(token, PtxToken::Identifier(_))) {
762 let variable = VariableSymbol::parse(stream)?;
763 let span = variable.span.clone();
764 Ok(AddressBase::Variable { symbol: variable, span })
765 } else {
766 let (token, span) = stream.peek()?;
767 Err(unexpected_value(
768 span.clone(),
769 &["register", "identifier"],
770 format!("{token:?}"),
771 ))
772 }
773 }
774}
775
776impl PtxParser for AddressOffset {
777 fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
778 if let Some((_, plus_span)) = stream
779 .consume_if(|token| matches!(token, PtxToken::Plus))
780 {
781 if stream.check(|token| matches!(token, PtxToken::Register(_))) {
782 let register = RegisterOperand::parse(stream)?;
783 let span = Span { start: plus_span.start, end: register.span.end };
784 Ok(AddressOffset::Register { operand: register, span })
785 } else {
786 let sign = Sign::Positive { span: plus_span.clone() };
787 let value = Immediate::parse(stream)?;
788 let span = Span { start: plus_span.start, end: value.span.end };
789 Ok(AddressOffset::Immediate { sign, value, span })
790 }
791 } else if let Some((_, minus_span)) = stream
792 .consume_if(|token| matches!(token, PtxToken::Minus))
793 {
794 let sign = Sign::Negative { span: minus_span.clone() };
795 let value = Immediate::parse(stream)?;
796 let span = Span { start: minus_span.start, end: value.span.end };
797 Ok(AddressOffset::Immediate { sign, value, span })
798 } else {
799 let (token, span) = stream.peek()?;
800 Err(unexpected_value(
801 span.clone(),
802 &["+", "-"],
803 format!("{token:?}"),
804 ))
805 }
806 }
807}
808
809impl PtxParser for AddressOperand {
810 fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
811 if stream.check(|token| matches!(token, PtxToken::Identifier(_))) {
812 let saved = stream.position();
813 let (identifier, ident_span) = stream.expect_identifier()?;
814 if stream
815 .consume_if(|token| matches!(token, PtxToken::LBracket))
816 .is_some()
817 {
818 let immediate = Immediate::parse(stream)?;
819 let (_, end_span) = stream.expect(&PtxToken::RBracket)?;
820 let span = Span { start: ident_span.start, end: end_span.end };
821 return Ok(AddressOperand::Array { base: VariableSymbol { name: identifier, span: ident_span }, index: immediate, span });
822 } else {
823 stream.set_position(saved);
824 }
825 }
826
827 let (_, start_span) = stream.expect(&PtxToken::LBracket)?;
828
829 if stream.check(|token| matches!(token, PtxToken::Minus)) {
830 let pos = stream.position();
831 stream.consume()?;
832 if stream.check(|token| is_numeric_token(token)) {
833 let mut immediate = Immediate::parse(stream)?;
834 immediate.value.insert(0, '-');
835 let (_, end_span) = stream.expect(&PtxToken::RBracket)?;
836 let span = Span { start: start_span.start, end: end_span.end };
837 return Ok(AddressOperand::ImmediateAddress { addr: immediate, span });
838 } else {
839 stream.set_position(pos);
840 }
841 }
842
843 if stream.check(|token| is_numeric_token(token)) {
844 let immediate = Immediate::parse(stream)?;
845 let (_, end_span) = stream.expect(&PtxToken::RBracket)?;
846 let span = Span { start: start_span.start, end: end_span.end };
847 return Ok(AddressOperand::ImmediateAddress { addr: immediate, span });
848 }
849
850 let base = AddressBase::parse(stream)?;
851 let offset = if stream.check(|token| matches!(token, PtxToken::Plus | PtxToken::Minus)) {
852 Some(AddressOffset::parse(stream)?)
853 } else {
854 None
855 };
856 let (_, end_span) = stream.expect(&PtxToken::RBracket)?;
857 let span = Span { start: start_span.start, end: end_span.end };
858
859 Ok(AddressOperand::Offset { base, offset, span })
860 }
861}
862
863impl PtxParser for FunctionSymbol {
864 fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
865 let (name, span) = stream.expect_identifier()?;
866 Ok(FunctionSymbol { name, span })
867 }
868}
869
870impl PtxParser for VariableSymbol {
871 fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
872 let (name, span) = stream.expect_identifier()?;
873 Ok(VariableSymbol { name, span })
874 }
875}
876
877pub(crate) fn try_parse_label(
881 stream: &mut PtxTokenStream,
882) -> Result<Option<String>, PtxParseError> {
883 if !stream.check(|token| matches!(token, PtxToken::Identifier(_))) {
884 return Ok(None);
885 }
886
887 let position = stream.position();
888 let (name, _) = stream.expect_identifier()?;
889 if stream
890 .consume_if(|token| matches!(token, PtxToken::Colon))
891 .is_some()
892 {
893 Ok(Some(name))
894 } else {
895 stream.set_position(position);
896 Ok(None)
897 }
898}
899
900impl PtxParser for Instruction {
901 fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
903 let start_pos = stream.position();
904
905 let predicate = if stream.check(|t| matches!(t, PtxToken::At)) {
907 let (_, at_span) = stream.consume()?; let negated = stream
911 .consume_if(|t| matches!(t, PtxToken::Exclaim))
912 .is_some();
913
914 let operand = Operand::parse(stream)?;
916 let pred_span = Span { start: at_span.start, end: operand.span().end };
917
918 Some(Predicate { negated, operand, span: pred_span })
919 } else {
920 None
921 };
922
923 let inst = crate::parser::instruction::parse_instruction_inner(stream)?;
925
926 let end_pos = stream.position();
928 let span = if let Some(ref pred) = predicate {
929 Span { start: pred.span.start, end: end_pos.char_offset as usize }
930 } else {
931 Span { start: start_pos.char_offset as usize, end: end_pos.char_offset as usize }
932 };
933
934 Ok(Instruction { predicate, inst, span })
935 }
936}
937
938impl PtxParser for Inst {
940 fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
941 Ok(Instruction::parse(stream)?.inst)
942 }
943}