1use crate::{
2 lexer::PtxToken,
3 parser::{
4 PtxParseError, PtxParser, PtxTokenStream, Span, common::parse_u64_literal, invalid_literal,
5 peek_directive, unexpected_value,
6 },
7 r#type::{
8 common::{AddressSpace, AttributeDirective, DataLinkage, DataType},
9 variable::{
10 GlobalInitializer, InitializerValue, ModuleVariableDirective, NumericLiteral,
11 VariableDirective, VariableModifier,
12 },
13 },
14};
15
/// Directive names (without the leading `.`) that are recognised as a variable's data type.
const DATA_TYPE_NAMES: &[&str] = &[
    // Unsigned integers.
    "u8", "u16", "u32", "u64",
    // Signed integers.
    "s8", "s16", "s32", "s64",
    // Floating point (`f16x2` is a packed pair of halves).
    "f16", "f16x2", "f32", "f64",
    // Untyped bit containers.
    "b8", "b16", "b32", "b64", "b128",
    // Predicate.
    "pred",
];
20
/// Classification of a parsed variable declaration by the qualifier that decides which
/// module-level variant it maps to.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum VariableDirectiveKind {
    /// Declared with `.tex`.
    Tex,
    /// Declared in the `.shared` state space.
    Shared,
    /// Declared in the `.global` state space.
    Global,
    /// Declared in the `.const` state space.
    Const,
    /// Any other state space, or none at all.
    Other,
}
29
30struct ParsedVariableDirective {
31 directive: VariableDirective,
32 kind: VariableDirectiveKind,
33 leading_span: Option<Span>,
34}
35
36fn is_data_type_directive(name: &str) -> bool {
37 DATA_TYPE_NAMES.iter().any(|candidate| candidate == &name)
38}
39
/// Returns `true` for vector-width directive names of the form `v<digits>` (e.g. `v2`, `v4`).
///
/// At least one ASCII digit must follow the `v`, and nothing else may appear after it.
fn is_vector_modifier(name: &str) -> bool {
    match name.strip_prefix('v') {
        Some(width) => !width.is_empty() && width.bytes().all(|byte| byte.is_ascii_digit()),
        None => false,
    }
}
47
48fn parse_alignment_value(stream: &mut PtxTokenStream) -> Result<u32, PtxParseError> {
49 let (value, value_span) = parse_u64_literal(stream)?;
50 if value > u32::MAX as u64 {
51 return Err(invalid_literal(
52 value_span,
53 "alignment value exceeds u32 range",
54 ));
55 }
56 Ok(value as u32)
57}
58
59fn parse_numeric_string(text: &str, span: Span) -> Result<u128, PtxParseError> {
60 text.parse::<u128>()
61 .map_err(|_| invalid_literal(span, "invalid integer literal"))
62}
63
64impl PtxParser for NumericLiteral {
65 fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
66 let negative = stream
67 .consume_if(|token| matches!(token, PtxToken::Minus))
68 .is_some();
69 let positive = stream
70 .consume_if(|token| matches!(token, PtxToken::Plus))
71 .is_some();
72
73 if negative && positive {
74 let (_, span) = stream.peek()?;
75 return Err(invalid_literal(
76 span.clone(),
77 "cannot have both '+' and '-' signs",
78 ));
79 }
80
81 let (token, span_ref) = stream.consume()?;
82 let span = span_ref.clone();
83 match token {
84 PtxToken::DecimalInteger(text) => {
85 let value = parse_numeric_string(text.as_str(), span.clone())?;
86 if negative {
87 if value > (i64::MAX as u128) + 1 {
88 return Err(invalid_literal(span.clone(), "signed integer underflow"));
89 }
90 let signed = -(value as i128);
91 Ok(NumericLiteral::Signed(signed as i64))
92 } else {
93 if value > u64::MAX as u128 {
94 return Err(invalid_literal(span.clone(), "unsigned integer overflow"));
95 }
96 Ok(NumericLiteral::Unsigned(value as u64))
97 }
98 }
99 PtxToken::HexInteger(text) => {
100 let stripped = text
101 .strip_prefix("0x")
102 .or_else(|| text.strip_prefix("0X"))
103 .unwrap_or(text.as_str());
104 let value = u128::from_str_radix(stripped, 16)
105 .map_err(|_| invalid_literal(span.clone(), "invalid hex literal"))?;
106 if negative {
107 if value > (i64::MAX as u128) + 1 {
108 return Err(invalid_literal(span.clone(), "signed integer underflow"));
109 }
110 let signed = -(value as i128);
111 Ok(NumericLiteral::Signed(signed as i64))
112 } else {
113 if value > u64::MAX as u128 {
114 return Err(invalid_literal(span.clone(), "unsigned integer overflow"));
115 }
116 Ok(NumericLiteral::Unsigned(value as u64))
117 }
118 }
119 PtxToken::BinaryInteger(text) => {
120 let stripped = text
121 .strip_prefix("0b")
122 .or_else(|| text.strip_prefix("0B"))
123 .unwrap_or(text.as_str());
124 let value = u128::from_str_radix(stripped, 2)
125 .map_err(|_| invalid_literal(span.clone(), "invalid binary literal"))?;
126 if negative {
127 if value > (i64::MAX as u128) + 1 {
128 return Err(invalid_literal(span.clone(), "signed integer underflow"));
129 }
130 let signed = -(value as i128);
131 Ok(NumericLiteral::Signed(signed as i64))
132 } else {
133 if value > u64::MAX as u128 {
134 return Err(invalid_literal(span.clone(), "unsigned integer overflow"));
135 }
136 Ok(NumericLiteral::Unsigned(value as u64))
137 }
138 }
139 PtxToken::OctalInteger(text) => {
140 let stripped = &text.as_str()[1..];
141 let value = u128::from_str_radix(stripped, 8)
142 .map_err(|_| invalid_literal(span.clone(), "invalid octal literal"))?;
143 if negative {
144 if value > (i64::MAX as u128) + 1 {
145 return Err(invalid_literal(span.clone(), "signed integer underflow"));
146 }
147 let signed = -(value as i128);
148 Ok(NumericLiteral::Signed(signed as i64))
149 } else {
150 if value > u64::MAX as u128 {
151 return Err(invalid_literal(span.clone(), "unsigned integer overflow"));
152 }
153 Ok(NumericLiteral::Unsigned(value as u64))
154 }
155 }
156 PtxToken::Float(text) | PtxToken::FloatExponent(text) => {
157 let mut value = text
158 .parse::<f64>()
159 .map_err(|_| invalid_literal(span.clone(), "invalid floating-point literal"))?;
160 if negative {
161 value = -value;
162 }
163 Ok(NumericLiteral::Float64(value.to_bits()))
164 }
165 PtxToken::HexFloat(text) => {
166 if text.len() < 3 {
167 return Err(invalid_literal(
168 span.clone(),
169 "invalid hexadecimal float literal",
170 ));
171 }
172 let (prefix, digits) = text.split_at(2);
173 match prefix.to_ascii_lowercase().as_str() {
174 "0f" => {
175 let mut bits = u32::from_str_radix(digits, 16)
176 .map_err(|_| invalid_literal(span.clone(), "invalid float literal"))?;
177 if negative {
178 bits ^= 0x8000_0000;
179 }
180 Ok(NumericLiteral::Float32(bits))
181 }
182 "0d" => {
183 let mut bits = u64::from_str_radix(digits, 16)
184 .map_err(|_| invalid_literal(span.clone(), "invalid float literal"))?;
185 if negative {
186 bits ^= 0x8000_0000_0000_0000;
187 }
188 Ok(NumericLiteral::Float64(bits))
189 }
190 _ => Err(invalid_literal(
191 span.clone(),
192 "hexadecimal float must start with 0f or 0d",
193 )),
194 }
195 }
196 _ => Err(unexpected_value(
197 span.clone(),
198 &["numeric literal"],
199 format!("{token:?}"),
200 )),
201 }
202 }
203}
204
205impl PtxParser for InitializerValue {
206 fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
207 if let Some((token, span)) = stream.peek().ok() {
208 match token {
209 PtxToken::StringLiteral(value) => {
210 let value = value.clone();
211 stream.consume()?;
212 return Ok(InitializerValue::StringLiteral(value));
213 }
214 PtxToken::Identifier(_) => {
215 let (symbol, _) = stream.expect_identifier()?;
216 return Ok(InitializerValue::Symbol(symbol));
217 }
218 PtxToken::Plus | PtxToken::Minus => {
219 let literal = NumericLiteral::parse(stream)?;
220 return Ok(InitializerValue::Numeric(literal));
221 }
222 PtxToken::DecimalInteger(_)
223 | PtxToken::HexInteger(_)
224 | PtxToken::BinaryInteger(_)
225 | PtxToken::OctalInteger(_)
226 | PtxToken::Float(_)
227 | PtxToken::FloatExponent(_)
228 | PtxToken::HexFloat(_) => {
229 let literal = NumericLiteral::parse(stream)?;
230 return Ok(InitializerValue::Numeric(literal));
231 }
232 _ => {
233 return Err(unexpected_value(
234 span.clone(),
235 &["numeric literal", "symbol", "string literal"],
236 format!("{token:?}"),
237 ));
238 }
239 }
240 }
241 let span = stream.peek()?.1.clone();
242 Err(unexpected_value(
243 span,
244 &["numeric literal", "symbol", "string literal"],
245 "end of input".to_string(),
246 ))
247 }
248}
249
250impl PtxParser for GlobalInitializer {
251 fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
252 if stream
253 .consume_if(|token| matches!(token, PtxToken::LBrace))
254 .is_some()
255 {
256 let mut children = Vec::new();
257 if !stream.check(|token| matches!(token, PtxToken::RBrace)) {
258 loop {
259 let initializer = GlobalInitializer::parse(stream)?;
260 children.push(initializer);
261 if !(stream
262 .consume_if(|token| matches!(token, PtxToken::Comma))
263 .is_some())
264 {
265 break;
266 }
267 }
268 }
269 stream.expect(&PtxToken::RBrace)?;
270 Ok(GlobalInitializer::Aggregate(children))
271 } else {
272 let value = InitializerValue::parse(stream)?;
273 Ok(GlobalInitializer::Scalar(value))
274 }
275 }
276}
277
278impl PtxParser for VariableModifier {
279 fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
280 let (directive, span_ref) = stream.expect_directive()?;
281 let span = span_ref.clone();
282 match directive.as_str() {
283 "align" => {
284 let value = parse_alignment_value(stream)?;
285 Ok(VariableModifier::Alignment(value))
286 }
287 "ptr" => Ok(VariableModifier::Ptr),
288 "visible" => Ok(VariableModifier::Linkage(DataLinkage::Visible)),
289 "extern" => Ok(VariableModifier::Linkage(DataLinkage::Extern)),
290 "weak" => Ok(VariableModifier::Linkage(DataLinkage::Weak)),
291 "common" => Ok(VariableModifier::Linkage(DataLinkage::Common)),
292 other if is_vector_modifier(other) => {
293 let digits = &other[1..];
294 let value = digits
295 .parse::<u32>()
296 .map_err(|_| invalid_literal(span.clone(), "invalid vector width"))?;
297 Ok(VariableModifier::Vector(value))
298 }
299 other => Err(unexpected_value(
300 span.clone(),
301 &[
302 ".align", ".ptr", ".visible", ".extern", ".weak", ".common", ".vN",
303 ],
304 format!(".{other}"),
305 )),
306 }
307 }
308}
309
310fn parse_variable_directive_internal(
311 stream: &mut PtxTokenStream,
312) -> Result<ParsedVariableDirective, PtxParseError> {
313 let first_span = stream.peek().ok().map(|(_, span)| span.clone());
314
315 let mut address_space: Option<AddressSpace> = None;
316 let mut attributes = Vec::new();
317 let mut modifiers = Vec::new();
318 let mut ty: Option<DataType> = None;
319 let mut array = Vec::new();
320 let mut initializer = None;
321 let mut seen_tex = false;
322 let mut kind = VariableDirectiveKind::Other;
323 let mut kind_span = None;
324
325 loop {
326 let Some((directive, directive_span)) = peek_directive(stream)? else {
327 break;
328 };
329 match directive.as_str() {
330 "tex" => {
331 stream.expect_directive()?;
332 if !seen_tex {
333 seen_tex = true;
334 kind = VariableDirectiveKind::Tex;
335 kind_span = Some(directive_span);
336 }
337 }
338 "global" | "const" | "shared" | "local" | "param" | "reg" => {
339 if address_space.is_some() {
340 return Err(unexpected_value(
341 directive_span.clone(),
342 &["single address space qualifier"],
343 format!(".{directive}"),
344 ));
345 }
346 let space = AddressSpace::parse(stream)?;
347 address_space = Some(space);
348 match space {
349 AddressSpace::Global => {
350 kind = VariableDirectiveKind::Global;
351 kind_span = Some(directive_span);
352 }
353 AddressSpace::Const => {
354 kind = VariableDirectiveKind::Const;
355 kind_span = Some(directive_span);
356 }
357 AddressSpace::Shared => {
358 kind = VariableDirectiveKind::Shared;
359 kind_span = Some(directive_span);
360 }
361 _ => {}
362 }
363 }
364 "managed" | "unified" => {
365 attributes.push(AttributeDirective::parse(stream)?);
366 }
367 "align" | "ptr" | "visible" | "extern" | "weak" | "common" => {
368 modifiers.push(VariableModifier::parse(stream)?);
369 }
370 other if is_vector_modifier(other) => {
371 modifiers.push(VariableModifier::parse(stream)?);
372 }
373 other if is_data_type_directive(other) => {
374 if ty.is_some() {
375 return Err(unexpected_value(
376 directive_span.clone(),
377 &["single data type qualifier"],
378 format!(".{other}"),
379 ));
380 }
381 ty = Some(DataType::parse(stream)?);
382 }
383 _ => break,
384 }
385 }
386
387 let (name, _) = stream.expect_identifier()?;
388
389 loop {
390 if stream
391 .consume_if(|token| matches!(token, PtxToken::LBracket))
392 .is_none()
393 {
394 break;
395 }
396
397 if stream
398 .consume_if(|token| matches!(token, PtxToken::RBracket))
399 .is_some()
400 {
401 array.push(None);
402 continue;
403 }
404
405 let size_span = stream.peek()?.1.clone();
406 let literal = NumericLiteral::parse(stream)?;
407 let size = match literal {
408 NumericLiteral::Unsigned(value) => value,
409 NumericLiteral::Signed(value) if value >= 0 => value as u64,
410 _ => {
411 return Err(invalid_literal(
412 size_span.clone(),
413 "array size must be a non-negative integer",
414 ));
415 }
416 };
417
418 stream.expect(&PtxToken::RBracket)?;
419 array.push(Some(size));
420 }
421
422 if stream
423 .consume_if(|token| matches!(token, PtxToken::Equals))
424 .is_some()
425 {
426 initializer = Some(GlobalInitializer::parse(stream)?);
427 }
428
429 stream.expect(&PtxToken::Semicolon)?;
430
431 let mut final_kind = kind;
432 if seen_tex {
433 final_kind = VariableDirectiveKind::Tex;
434 } else if matches!(final_kind, VariableDirectiveKind::Other) {
435 final_kind = match address_space {
436 Some(AddressSpace::Shared) => VariableDirectiveKind::Shared,
437 Some(AddressSpace::Global) => VariableDirectiveKind::Global,
438 Some(AddressSpace::Const) => VariableDirectiveKind::Const,
439 _ => VariableDirectiveKind::Other,
440 };
441 }
442
443 let directive = VariableDirective {
444 address_space,
445 attributes,
446 ty,
447 modifiers,
448 name,
449 array,
450 initializer,
451 raw: String::new(),
452 };
453
454 Ok(ParsedVariableDirective {
455 directive,
456 kind: final_kind,
457 leading_span: kind_span.or(first_span),
458 })
459}
460
461impl VariableDirective {
462 fn parse_with_kind(
463 stream: &mut PtxTokenStream,
464 ) -> Result<(VariableDirective, VariableDirectiveKind, Option<Span>), PtxParseError> {
465 let parsed = parse_variable_directive_internal(stream)?;
466 Ok((parsed.directive, parsed.kind, parsed.leading_span))
467 }
468}
469
470impl PtxParser for VariableDirective {
471 fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
472 let parsed = parse_variable_directive_internal(stream)?;
473 Ok(parsed.directive)
474 }
475}
476
477impl PtxParser for ModuleVariableDirective {
478 fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
479 let (directive, kind, span) = VariableDirective::parse_with_kind(stream)?;
480 match kind {
481 VariableDirectiveKind::Tex => Ok(ModuleVariableDirective::Tex(directive)),
482 VariableDirectiveKind::Shared => Ok(ModuleVariableDirective::Shared(directive)),
483 VariableDirectiveKind::Global => Ok(ModuleVariableDirective::Global(directive)),
484 VariableDirectiveKind::Const => Ok(ModuleVariableDirective::Const(directive)),
485 VariableDirectiveKind::Other => Err(unexpected_value(
486 span.unwrap_or(0..0),
487 &[".tex", ".shared", ".global", ".const"],
488 "variable directive".to_string(),
489 )),
490 }
491 }
492}