1use crate::{
2 lexer::PtxToken,
3 parser::{PtxParseError, PtxParser, PtxTokenStream, unexpected_value},
4 r#type::{
5 common::CodeOrDataLinkage,
6 function::FunctionKernelDirective,
7 module::{
8 AddressSizeDirective, FileDirective, LinkingDirective, Module, ModuleDebugDirective,
9 ModuleDirective, ModuleInfoDirectiveKind, SectionDirective, TargetDirective,
10 VersionDirective,
11 },
12 variable::ModuleVariableDirective,
13 },
14};
15
16fn is_module_directive_start(token: &PtxToken) -> bool {
17 matches!(token, PtxToken::Dot)
18}
19
20fn parse_decimal_u32(
21 stream: &mut PtxTokenStream,
22) -> Result<(u32, std::ops::Range<usize>), PtxParseError> {
23 let (token, span) = stream.consume()?;
24 match token {
25 PtxToken::DecimalInteger(text) => text
26 .parse::<u32>()
27 .map(|value| (value, span.clone()))
28 .map_err(|_| unexpected_value(span.clone(), &["decimal literal"], text.clone())),
29 _ => Err(unexpected_value(
30 span.clone(),
31 &["decimal literal"],
32 format!("{token:?}"),
33 )),
34 }
35}
36
37fn token_to_string(token: &PtxToken) -> String {
38 match token {
39 PtxToken::Dot => ".".into(),
40 PtxToken::Identifier(name) => name.clone(),
41 PtxToken::DecimalInteger(value) => value.clone(),
42 PtxToken::StringLiteral(value) => format!("\"{value}\""),
43 PtxToken::LBrace => "{".into(),
44 PtxToken::RBrace => "}".into(),
45 PtxToken::Comma => ",".into(),
46 PtxToken::Colon => ":".into(),
47 PtxToken::Star => "*".into(),
48 PtxToken::Plus => "+".into(),
49 PtxToken::Minus => "-".into(),
50 PtxToken::Slash => "/".into(),
51 PtxToken::Percent => "%".into(),
52 PtxToken::Equals => "=".into(),
53 other => format!("{other:?}"),
54 }
55}
56
57fn classify_next_directive(stream: &PtxTokenStream) -> Option<String> {
58 let mut iter = stream.remaining().iter();
59 while let Some((token, _)) = iter.next() {
60 match token {
61 PtxToken::Dot => {
62 if let Some((PtxToken::Identifier(name), _)) = iter.next() {
64 match name.as_str() {
65 "visible" | "extern" | "weak" | "common" => continue,
66 other => return Some(other.to_string()),
67 }
68 } else {
69 return None;
70 }
71 }
72 _ => return None,
73 }
74 }
75 None
76}
77
78impl PtxParser for VersionDirective {
79 fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
80 let (token, span) = stream.consume()?;
81 match token {
82 PtxToken::DecimalInteger(text) => {
83 let major = text.parse::<u32>().map_err(|_| {
84 unexpected_value(span.clone(), &["decimal literal"], text.clone())
85 })?;
86 stream.expect(&PtxToken::Dot)?;
87 let (minor_token, minor_span) = stream.consume()?;
88 let minor = match minor_token {
89 PtxToken::DecimalInteger(value) => value.parse::<u32>().map_err(|_| {
90 unexpected_value(minor_span.clone(), &["decimal literal"], value.clone())
91 })?,
92 _ => {
93 return Err(unexpected_value(
94 minor_span.clone(),
95 &["decimal literal"],
96 format!("{minor_token:?}"),
97 ));
98 }
99 };
100 Ok(VersionDirective { major, minor })
101 }
102 PtxToken::Float(value) | PtxToken::FloatExponent(value) => {
103 let mut parts = value.split('.');
104 let major_str = parts.next().unwrap_or("");
105 let minor_part = parts.next().unwrap_or("");
106 if parts.next().is_some() || major_str.is_empty() || minor_part.is_empty() {
107 return Err(unexpected_value(
108 span.clone(),
109 &["major.minor"],
110 value.clone(),
111 ));
112 }
113 let major = major_str.parse::<u32>().map_err(|_| {
114 unexpected_value(span.clone(), &["decimal literal"], value.clone())
115 })?;
116 let minor = minor_part.parse::<u32>().map_err(|_| {
117 unexpected_value(span.clone(), &["decimal literal"], value.clone())
118 })?;
119 Ok(VersionDirective { major, minor })
120 }
121 _ => Err(unexpected_value(
122 span.clone(),
123 &["decimal literal"],
124 format!("{token:?}"),
125 )),
126 }
127 }
128}
129
130impl PtxParser for TargetDirective {
131 fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
132 let mut entries = Vec::new();
133 loop {
134 let next = stream.peek();
135 let Ok((token, _span)) = next else {
136 break;
137 };
138 match token {
139 PtxToken::Identifier(name) => {
140 entries.push(name.clone());
141 stream.consume()?;
142 }
143 PtxToken::Dot => {
144 stream.consume()?;
145 let (name, _) = stream.expect_identifier()?;
146 entries.push(format!(".{name}"));
147 }
148 _ => break,
149 }
150 if stream
151 .consume_if(|token| matches!(token, PtxToken::Comma))
152 .is_none()
153 {
154 break;
155 }
156 }
157 if entries.is_empty() {
158 let span = stream.peek().map(|(_, span)| span.clone()).unwrap_or(0..0);
159 return Err(unexpected_value(
160 span,
161 &["sm arch or target modifier"],
162 "".to_string(),
163 ));
164 }
165 Ok(TargetDirective {
166 entries: entries.clone(),
167 raw: entries.join(", "),
168 })
169 }
170}
171
172impl PtxParser for AddressSizeDirective {
173 fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
174 let (size, _) = parse_decimal_u32(stream)?;
175 Ok(AddressSizeDirective { size })
176 }
177}
178
179impl PtxParser for ModuleInfoDirectiveKind {
180 fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
181 let (directive, span) = stream.expect_directive()?;
182 match directive.as_str() {
183 "version" => Ok(ModuleInfoDirectiveKind::Version(VersionDirective::parse(
184 stream,
185 )?)),
186 "target" => Ok(ModuleInfoDirectiveKind::Target(TargetDirective::parse(
187 stream,
188 )?)),
189 "address_size" => Ok(ModuleInfoDirectiveKind::AddressSize(
190 AddressSizeDirective::parse(stream)?,
191 )),
192 other => Err(unexpected_value(
193 span,
194 &[".version", ".target", ".address_size"],
195 format!(".{other}"),
196 )),
197 }
198 }
199}
200
201impl PtxParser for FileDirective {
202 fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
203 let (index, _) = parse_decimal_u32(stream)?;
204 let (token, span) = stream.consume()?;
205 let path = match token {
206 PtxToken::StringLiteral(content) => content.clone(),
207 _ => {
208 return Err(unexpected_value(
209 span.clone(),
210 &["string literal"],
211 format!("{token:?}"),
212 ));
213 }
214 };
215 Ok(FileDirective { index, path })
216 }
217}
218
219impl PtxParser for SectionDirective {
220 fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
221 let (token, span) = stream.consume()?;
222 let name = match token {
223 PtxToken::Identifier(value) => value.clone(),
224 PtxToken::Dot => {
225 let (value, _) = stream.expect_identifier()?;
226 format!(".{value}")
227 }
228 _ => {
229 return Err(unexpected_value(
230 span.clone(),
231 &["section name"],
232 format!("{token:?}"),
233 ));
234 }
235 };
236
237 let mut attributes = Vec::new();
238 loop {
239 let next = stream.peek();
240 let Ok((token, _)) = next else { break };
241 if is_module_directive_start(token) || matches!(token, PtxToken::Semicolon) {
242 break;
243 }
244 let (tok, _) = stream.consume()?;
245 attributes.push(token_to_string(tok));
246 }
247
248 Ok(SectionDirective { name, attributes })
249 }
250}
251
252impl PtxParser for ModuleDebugDirective {
253 fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
254 let (directive, span) = stream.expect_directive()?;
255 match directive.as_str() {
256 "file" => Ok(ModuleDebugDirective::File(FileDirective::parse(stream)?)),
257 "section" => Ok(ModuleDebugDirective::Section(SectionDirective::parse(
258 stream,
259 )?)),
260 other => Err(unexpected_value(
261 span,
262 &[".file", ".section"],
263 format!(".{other}"),
264 )),
265 }
266 }
267}
268
269impl PtxParser for LinkingDirective {
270 fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
271 let (linkage, _) = parse_code_or_data_linkage(stream)?;
272 let mut prototype = String::new();
273 loop {
274 let next = stream.peek();
275 let Ok((token, _span)) = next else { break };
276 if is_module_directive_start(token) {
277 break;
278 }
279 match token {
280 PtxToken::Semicolon => {
281 stream.consume()?;
282 break;
283 }
284 _ => {
285 let (tok, _) = stream.consume()?;
286 if !prototype.is_empty() {
287 prototype.push(' ');
288 }
289 prototype.push_str(&token_to_string(tok));
290 }
291 }
292 }
293 Ok(LinkingDirective {
294 kind: linkage,
295 prototype: prototype.clone(),
296 raw: prototype,
297 })
298 }
299}
300
301fn parse_code_or_data_linkage(
302 stream: &mut PtxTokenStream,
303) -> Result<(CodeOrDataLinkage, std::ops::Range<usize>), PtxParseError> {
304 let (directive, span) = stream.expect_directive()?;
305 match directive.as_str() {
306 "visible" => Ok((CodeOrDataLinkage::Visible, span.clone())),
307 "extern" => Ok((CodeOrDataLinkage::Extern, span.clone())),
308 "weak" => Ok((CodeOrDataLinkage::Weak, span.clone())),
309 "common" => Ok((CodeOrDataLinkage::Common, span.clone())),
310 _ => Err(unexpected_value(
311 span.clone(),
312 &[".visible", ".extern", ".weak", ".common"],
313 format!(".{directive}"),
314 )),
315 }
316}
317
318impl PtxParser for ModuleDirective {
319 fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
320 let Some(name) = classify_next_directive(stream) else {
321 let span = stream.peek()?.1.clone();
322 return Err(unexpected_value(span, &["directive"], "".to_string()));
323 };
324
325 match name.as_str() {
326 "version" | "target" | "address_size" => {
327 let info = ModuleInfoDirectiveKind::parse(stream)?;
328 Ok(ModuleDirective::ModuleInfo(info))
329 }
330 "file" | "section" => {
331 let debug = ModuleDebugDirective::parse(stream)?;
332 Ok(ModuleDirective::Debug(debug))
333 }
334 "entry" | "func" | "alias" => {
335 let function = FunctionKernelDirective::parse(stream)?;
336 Ok(ModuleDirective::FunctionKernel(function))
337 }
338 "global" | "const" | "shared" | "tex" => {
339 let variable = ModuleVariableDirective::parse(stream)?;
340 Ok(ModuleDirective::ModuleVariable(variable))
341 }
342 _ => {
343 let position = stream.position();
344 if let Ok(variable) = ModuleVariableDirective::parse(stream) {
345 return Ok(ModuleDirective::ModuleVariable(variable));
346 }
347 stream.set_position(position);
348 if let Ok(function) = FunctionKernelDirective::parse(stream) {
349 return Ok(ModuleDirective::FunctionKernel(function));
350 }
351 stream.set_position(position);
352 let span = stream
353 .peek()
354 .map(|(_, span)| span.clone())
355 .unwrap_or(position.index..position.index);
356 Err(unexpected_value(
357 span,
358 &["module directive"],
359 "unrecognised directive".to_string(),
360 ))
361 }
362 }
363 }
364}
365
366impl PtxParser for Module {
367 fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
368 let mut directives = Vec::new();
369 while !stream.is_at_end() {
370 if stream.is_at_end() {
371 break;
372 }
373 let directive = ModuleDirective::parse(stream)?;
374 directives.push(directive);
375 }
376 Ok(Module { directives })
377 }
378}