1use crate::{
2 lexer::PtxToken,
3 parser::{
4 PtxParseError, PtxParser, PtxTokenStream, Span, common::parse_u64_literal, invalid_literal,
5 peek_directive, unexpected_value,
6 },
7 r#type::{
8 common::{AddressSpace, AttributeDirective, DataLinkage, DataType},
9 variable::{
10 GlobalInitializer, InitializerValue, ModuleVariableDirective, NumericLiteral,
11 VariableDirective, VariableModifier,
12 },
13 },
14};
15
/// Directive names (written without their leading `.`) that select the data type
/// of a variable declaration.
const DATA_TYPE_NAMES: &[&str] = &[
    // unsigned integers
    "u8", "u16", "u32", "u64",
    // signed integers
    "s8", "s16", "s32", "s64",
    // floating point
    "f16", "f16x2", "f32", "f64",
    // untyped bit containers
    "b8", "b16", "b32", "b64", "b128",
    // predicate
    "pred",
];
20
/// Classification of a parsed variable declaration; decides which
/// `ModuleVariableDirective` variant a module-level declaration becomes.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum VariableDirectiveKind {
    /// Declared with `.tex`.
    Tex,
    /// Declared in the `.shared` state space.
    Shared,
    /// Declared in the `.global` state space.
    Global,
    /// Declared in the `.const` state space.
    Const,
    /// Anything else (another state space, or none at all).
    Other,
}
29
30fn is_data_type_directive(name: &str) -> bool {
31 DATA_TYPE_NAMES.iter().any(|candidate| candidate == &name)
32}
33
/// Reports whether `name` is a vector-width directive name: a lowercase `v`
/// followed by one or more ASCII digits (e.g. `v2`, `v4`).
fn is_vector_modifier(name: &str) -> bool {
    match name.strip_prefix('v') {
        Some(width) => !width.is_empty() && width.bytes().all(|byte| byte.is_ascii_digit()),
        None => false,
    }
}
41
42fn parse_alignment_value(stream: &mut PtxTokenStream) -> Result<u32, PtxParseError> {
43 let (value, value_span) = parse_u64_literal(stream)?;
44 if value > u32::MAX as u64 {
45 return Err(invalid_literal(
46 value_span,
47 "alignment value exceeds u32 range",
48 ));
49 }
50 Ok(value as u32)
51}
52
53fn parse_numeric_string(text: &str, span: Span) -> Result<u128, PtxParseError> {
54 text.parse::<u128>()
55 .map_err(|_| invalid_literal(span, "invalid integer literal"))
56}
57
58impl PtxParser for NumericLiteral {
59 fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
60 let negative = stream
61 .consume_if(|token| matches!(token, PtxToken::Minus))
62 .is_some();
63 let positive = stream
64 .consume_if(|token| matches!(token, PtxToken::Plus))
65 .is_some();
66
67 if negative && positive {
68 let (_, span) = stream.peek()?;
69 return Err(invalid_literal(
70 span.clone(),
71 "cannot have both '+' and '-' signs",
72 ));
73 }
74
75 let (token, span_ref) = stream.consume()?;
76 let span = span_ref.clone();
77 match token {
78 PtxToken::DecimalInteger(text) => {
79 let value = parse_numeric_string(text.as_str(), span.clone())?;
80 if negative {
81 if value > (i64::MAX as u128) + 1 {
82 return Err(invalid_literal(span.clone(), "signed integer underflow"));
83 }
84 let signed = -(value as i128);
85 Ok(NumericLiteral::Signed { value: signed as i64, span })
86 } else {
87 if value > u64::MAX as u128 {
88 return Err(invalid_literal(span.clone(), "unsigned integer overflow"));
89 }
90 Ok(NumericLiteral::Unsigned { value: value as u64, span })
91 }
92 }
93 PtxToken::HexInteger(text) => {
94 let stripped = text
95 .strip_prefix("0x")
96 .or_else(|| text.strip_prefix("0X"))
97 .unwrap_or(text.as_str());
98 let value = u128::from_str_radix(stripped, 16)
99 .map_err(|_| invalid_literal(span.clone(), "invalid hex literal"))?;
100 if negative {
101 if value > (i64::MAX as u128) + 1 {
102 return Err(invalid_literal(span.clone(), "signed integer underflow"));
103 }
104 let signed = -(value as i128);
105 Ok(NumericLiteral::Signed { value: signed as i64, span })
106 } else {
107 if value > u64::MAX as u128 {
108 return Err(invalid_literal(span.clone(), "unsigned integer overflow"));
109 }
110 Ok(NumericLiteral::Unsigned { value: value as u64, span })
111 }
112 }
113 PtxToken::BinaryInteger(text) => {
114 let stripped = text
115 .strip_prefix("0b")
116 .or_else(|| text.strip_prefix("0B"))
117 .unwrap_or(text.as_str());
118 let value = u128::from_str_radix(stripped, 2)
119 .map_err(|_| invalid_literal(span.clone(), "invalid binary literal"))?;
120 if negative {
121 if value > (i64::MAX as u128) + 1 {
122 return Err(invalid_literal(span.clone(), "signed integer underflow"));
123 }
124 let signed = -(value as i128);
125 Ok(NumericLiteral::Signed { value: signed as i64, span })
126 } else {
127 if value > u64::MAX as u128 {
128 return Err(invalid_literal(span.clone(), "unsigned integer overflow"));
129 }
130 Ok(NumericLiteral::Unsigned { value: value as u64, span })
131 }
132 }
133 PtxToken::OctalInteger(text) => {
134 let stripped = &text.as_str()[1..];
135 let value = u128::from_str_radix(stripped, 8)
136 .map_err(|_| invalid_literal(span.clone(), "invalid octal literal"))?;
137 if negative {
138 if value > (i64::MAX as u128) + 1 {
139 return Err(invalid_literal(span.clone(), "signed integer underflow"));
140 }
141 let signed = -(value as i128);
142 Ok(NumericLiteral::Signed { value: signed as i64, span })
143 } else {
144 if value > u64::MAX as u128 {
145 return Err(invalid_literal(span.clone(), "unsigned integer overflow"));
146 }
147 Ok(NumericLiteral::Unsigned { value: value as u64, span })
148 }
149 }
150 PtxToken::Float(text) | PtxToken::FloatExponent(text) => {
151 let mut value = text
152 .parse::<f64>()
153 .map_err(|_| invalid_literal(span.clone(), "invalid floating-point literal"))?;
154 if negative {
155 value = -value;
156 }
157 Ok(NumericLiteral::Float64 { value: value.to_bits(), span })
158 }
159 PtxToken::HexFloat(text) => {
160 if text.len() < 3 {
161 return Err(invalid_literal(
162 span.clone(),
163 "invalid hexadecimal float literal",
164 ));
165 }
166 let (prefix, digits) = text.split_at(2);
167 match prefix.to_ascii_lowercase().as_str() {
168 "0f" => {
169 let mut bits = u32::from_str_radix(digits, 16)
170 .map_err(|_| invalid_literal(span.clone(), "invalid float literal"))?;
171 if negative {
172 bits ^= 0x8000_0000;
173 }
174 Ok(NumericLiteral::Float32 { value: bits, span })
175 }
176 "0d" => {
177 let mut bits = u64::from_str_radix(digits, 16)
178 .map_err(|_| invalid_literal(span.clone(), "invalid float literal"))?;
179 if negative {
180 bits ^= 0x8000_0000_0000_0000;
181 }
182 Ok(NumericLiteral::Float64 { value: bits, span })
183 }
184 _ => Err(invalid_literal(
185 span.clone(),
186 "hexadecimal float must start with 0f or 0d",
187 )),
188 }
189 }
190 _ => Err(unexpected_value(
191 span.clone(),
192 &["numeric literal"],
193 format!("{token:?}"),
194 )),
195 }
196 }
197}
198
199impl PtxParser for InitializerValue {
200 fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
201 if let Some((token, span)) = stream.peek().ok() {
202 let span = span.clone();
203 match token {
204 PtxToken::StringLiteral(value) => {
205 let value = value.clone();
206 stream.consume()?;
207 return Ok(InitializerValue::StringLiteral { value, span });
208 }
209 PtxToken::Identifier(_) => {
210 let (name, span) = stream.expect_identifier()?;
211 return Ok(InitializerValue::Symbol { name, span: span.clone() });
212 }
213 PtxToken::Plus | PtxToken::Minus => {
214 let literal = NumericLiteral::parse(stream)?;
215 let span = literal.span();
216 return Ok(InitializerValue::Numeric { value: literal, span });
217 }
218 PtxToken::DecimalInteger(_)
219 | PtxToken::HexInteger(_)
220 | PtxToken::BinaryInteger(_)
221 | PtxToken::OctalInteger(_)
222 | PtxToken::Float(_)
223 | PtxToken::FloatExponent(_)
224 | PtxToken::HexFloat(_) => {
225 let literal = NumericLiteral::parse(stream)?;
226 let span = literal.span();
227 return Ok(InitializerValue::Numeric { value: literal, span });
228 }
229 _ => {
230 return Err(unexpected_value(
231 span.clone(),
232 &["numeric literal", "symbol", "string literal"],
233 format!("{token:?}"),
234 ));
235 }
236 }
237 }
238 let span = stream.peek()?.1.clone();
239 Err(unexpected_value(
240 span,
241 &["numeric literal", "symbol", "string literal"],
242 "end of input".to_string(),
243 ))
244 }
245}
246
247impl PtxParser for GlobalInitializer {
248 fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
249 let start_span = stream.peek()?.1.clone();
250 if stream
251 .consume_if(|token| matches!(token, PtxToken::LBrace))
252 .is_some()
253 {
254 let mut children = Vec::new();
255 if !stream.check(|token| matches!(token, PtxToken::RBrace)) {
256 loop {
257 let initializer = GlobalInitializer::parse(stream)?;
258 children.push(initializer);
259 if !(stream
260 .consume_if(|token| matches!(token, PtxToken::Comma))
261 .is_some())
262 {
263 break;
264 }
265 }
266 }
267 let (_, end_span) = stream.expect(&PtxToken::RBrace)?;
268 let span = start_span.start..end_span.end;
269 Ok(GlobalInitializer::Aggregate { values: children, span })
270 } else {
271 let value = InitializerValue::parse(stream)?;
272 let span = value.span();
273 Ok(GlobalInitializer::Scalar { value, span })
274 }
275 }
276}
277
278impl PtxParser for VariableModifier {
279 fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
280 let (directive, span_ref) = stream.expect_directive()?;
281 let span = span_ref.clone();
282 match directive.as_str() {
283 "align" => {
284 let value = parse_alignment_value(stream)?;
285 Ok(VariableModifier::Alignment { value, span })
286 }
287 "ptr" => Ok(VariableModifier::Ptr { span }),
288 "visible" => Ok(VariableModifier::Linkage {
289 linkage: DataLinkage::Visible { span: span.clone() },
290 span
291 }),
292 "extern" => Ok(VariableModifier::Linkage {
293 linkage: DataLinkage::Extern { span: span.clone() },
294 span
295 }),
296 "weak" => Ok(VariableModifier::Linkage {
297 linkage: DataLinkage::Weak { span: span.clone() },
298 span
299 }),
300 "common" => Ok(VariableModifier::Linkage {
301 linkage: DataLinkage::Common { span: span.clone() },
302 span
303 }),
304 other if is_vector_modifier(other) => {
305 let digits = &other[1..];
306 let value = digits
307 .parse::<u32>()
308 .map_err(|_| invalid_literal(span.clone(), "invalid vector width"))?;
309 Ok(VariableModifier::Vector { value, span })
310 }
311 other => Err(unexpected_value(
312 span.clone(),
313 &[
314 ".align", ".ptr", ".visible", ".extern", ".weak", ".common", ".vN",
315 ],
316 format!(".{other}"),
317 )),
318 }
319 }
320}
321
322impl VariableDirective {
323 fn parse_with_kind(
324 stream: &mut PtxTokenStream,
325 ) -> Result<(VariableDirective, VariableDirectiveKind, Option<Span>), PtxParseError> {
326 let first_span = stream.peek().ok().map(|(_, span)| span.clone());
327
328 let mut address_space: Option<AddressSpace> = None;
329 let mut attributes = Vec::new();
330 let mut modifiers = Vec::new();
331 let mut ty: Option<DataType> = None;
332 let mut array = Vec::new();
333 let mut initializer = None;
334 let mut seen_tex = false;
335 let mut kind = VariableDirectiveKind::Other;
336 let mut kind_span = None;
337
338 loop {
339 let Some((directive, directive_span)) = peek_directive(stream)? else {
340 break;
341 };
342 match directive.as_str() {
343 "tex" => {
344 stream.expect_directive()?;
345 if !seen_tex {
346 seen_tex = true;
347 kind = VariableDirectiveKind::Tex;
348 kind_span = Some(directive_span);
349 }
350 }
351 "global" | "const" | "shared" | "local" | "param" | "reg" => {
352 if address_space.is_some() {
353 return Err(unexpected_value(
354 directive_span.clone(),
355 &["single address space qualifier"],
356 format!(".{directive}"),
357 ));
358 }
359 let space = AddressSpace::parse(stream)?;
360 match space {
361 AddressSpace::Global { .. } => {
362 kind = VariableDirectiveKind::Global;
363 kind_span = Some(directive_span.clone());
364 }
365 AddressSpace::Const { .. } => {
366 kind = VariableDirectiveKind::Const;
367 kind_span = Some(directive_span.clone());
368 }
369 AddressSpace::Shared { .. } => {
370 kind = VariableDirectiveKind::Shared;
371 kind_span = Some(directive_span.clone());
372 }
373 _ => {}
374 }
375 address_space = Some(space);
376 }
377 "managed" | "unified" => {
378 attributes.push(AttributeDirective::parse(stream)?);
379 }
380 "align" | "ptr" | "visible" | "extern" | "weak" | "common" => {
381 modifiers.push(VariableModifier::parse(stream)?);
382 }
383 other if is_vector_modifier(other) => {
384 modifiers.push(VariableModifier::parse(stream)?);
385 }
386 other if is_data_type_directive(other) => {
387 if ty.is_some() {
388 return Err(unexpected_value(
389 directive_span.clone(),
390 &["single data type qualifier"],
391 format!(".{other}"),
392 ));
393 }
394 ty = Some(DataType::parse(stream)?);
395 }
396 _ => break,
397 }
398 }
399
400 let (name, _) = stream.expect_identifier()?;
401
402 loop {
403 if stream
404 .consume_if(|token| matches!(token, PtxToken::LBracket))
405 .is_none()
406 {
407 break;
408 }
409
410 if stream
411 .consume_if(|token| matches!(token, PtxToken::RBracket))
412 .is_some()
413 {
414 array.push(None);
415 continue;
416 }
417
418 let size_span = stream.peek()?.1.clone();
419 let literal = NumericLiteral::parse(stream)?;
420 let size = match literal {
421 NumericLiteral::Unsigned { value, .. } => value,
422 NumericLiteral::Signed { value, .. } if value >= 0 => value as u64,
423 _ => {
424 return Err(invalid_literal(
425 size_span.clone(),
426 "array size must be a non-negative integer",
427 ));
428 }
429 };
430
431 stream.expect(&PtxToken::RBracket)?;
432 array.push(Some(size));
433 }
434
435 if stream
436 .consume_if(|token| matches!(token, PtxToken::Equals))
437 .is_some()
438 {
439 initializer = Some(GlobalInitializer::parse(stream)?);
440 }
441
442 stream.expect(&PtxToken::Semicolon)?;
443
444 let mut final_kind = kind;
445 if seen_tex {
446 final_kind = VariableDirectiveKind::Tex;
447 } else if matches!(final_kind, VariableDirectiveKind::Other) {
448 final_kind = match address_space {
449 Some(AddressSpace::Shared { .. }) => VariableDirectiveKind::Shared,
450 Some(AddressSpace::Global { .. }) => VariableDirectiveKind::Global,
451 Some(AddressSpace::Const { .. }) => VariableDirectiveKind::Const,
452 _ => VariableDirectiveKind::Other,
453 };
454 }
455
456 let end_span = stream.peek().ok().map(|(_, s)| s.clone()).unwrap_or(0..0);
457 let span = first_span.map(|s| s.start..end_span.end).unwrap_or(0..0);
458
459 let directive = VariableDirective {
460 address_space,
461 attributes,
462 ty,
463 modifiers,
464 name,
465 array,
466 initializer,
467 span: span.clone(),
468 };
469
470 Ok((directive, final_kind, kind_span.or(Some(span))))
471 }
472}
473
474impl PtxParser for VariableDirective {
475 fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
476 let (directive, _, _) = VariableDirective::parse_with_kind(stream)?;
477 Ok(directive)
478 }
479}
480
481impl PtxParser for ModuleVariableDirective {
482 fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
483 let (directive, kind, span_opt) = VariableDirective::parse_with_kind(stream)?;
484 let span = span_opt.unwrap_or(0..0);
485 match kind {
486 VariableDirectiveKind::Tex => Ok(ModuleVariableDirective::Tex {
487 directive,
488 span
489 }),
490 VariableDirectiveKind::Shared => Ok(ModuleVariableDirective::Shared {
491 directive,
492 span
493 }),
494 VariableDirectiveKind::Global => Ok(ModuleVariableDirective::Global {
495 directive,
496 span
497 }),
498 VariableDirectiveKind::Const => Ok(ModuleVariableDirective::Const {
499 directive,
500 span: span.clone()
501 }),
502 VariableDirectiveKind::Other => Err(unexpected_value(
503 span,
504 &[".tex", ".shared", ".global", ".const"],
505 "variable directive".to_string(),
506 )),
507 }
508 }
509}