1use crate::{
2 lexer::PtxToken,
3 parser::{PtxParseError, PtxParser, PtxTokenStream, unexpected_value},
4 r#type::{
5 common::CodeOrDataLinkage,
6 function::FunctionKernelDirective,
7 module::{
8 AddressSizeDirective, FileDirective, LinkingDirective, Module, ModuleDebugDirective,
9 ModuleDirective, ModuleInfoDirectiveKind, SectionDirective, TargetDirective,
10 VersionDirective,
11 },
12 variable::ModuleVariableDirective,
13 },
14};
15
16fn is_module_directive_start(token: &PtxToken) -> bool {
17 matches!(token, PtxToken::Dot)
18}
19
20fn parse_decimal_u32(
21 stream: &mut PtxTokenStream,
22) -> Result<(u32, std::ops::Range<usize>), PtxParseError> {
23 let (token, span) = stream.consume()?;
24 match token {
25 PtxToken::DecimalInteger(text) => text
26 .parse::<u32>()
27 .map(|value| (value, span.clone()))
28 .map_err(|_| unexpected_value(span.clone(), &["decimal literal"], text.clone())),
29 _ => Err(unexpected_value(
30 span.clone(),
31 &["decimal literal"],
32 format!("{token:?}"),
33 )),
34 }
35}
36
37fn token_to_string(token: &PtxToken) -> String {
38 match token {
39 PtxToken::Dot => ".".into(),
40 PtxToken::Identifier(name) => name.clone(),
41 PtxToken::DecimalInteger(value) => value.clone(),
42 PtxToken::StringLiteral(value) => format!("\"{value}\""),
43 PtxToken::LBrace => "{".into(),
44 PtxToken::RBrace => "}".into(),
45 PtxToken::Comma => ",".into(),
46 PtxToken::Colon => ":".into(),
47 PtxToken::Star => "*".into(),
48 PtxToken::Plus => "+".into(),
49 PtxToken::Minus => "-".into(),
50 PtxToken::Slash => "/".into(),
51 PtxToken::Percent => "%".into(),
52 PtxToken::Equals => "=".into(),
53 other => format!("{other:?}"),
54 }
55}
56
57impl PtxParser for VersionDirective {
58 fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
59 let (token, span) = stream.consume()?;
60 match token {
61 PtxToken::DecimalInteger(text) => {
62 let major = text.parse::<u32>().map_err(|_| {
63 unexpected_value(span.clone(), &["decimal literal"], text.clone())
64 })?;
65 stream.expect(&PtxToken::Dot)?;
66 let (minor_token, minor_span) = stream.consume()?;
67 let minor = match minor_token {
68 PtxToken::DecimalInteger(value) => value.parse::<u32>().map_err(|_| {
69 unexpected_value(minor_span.clone(), &["decimal literal"], value.clone())
70 })?,
71 _ => {
72 return Err(unexpected_value(
73 minor_span.clone(),
74 &["decimal literal"],
75 format!("{minor_token:?}"),
76 ));
77 }
78 };
79 Ok(VersionDirective { major, minor })
80 }
81 PtxToken::Float(value) | PtxToken::FloatExponent(value) => {
82 let mut parts = value.split('.');
83 let major_str = parts.next().unwrap_or("");
84 let minor_part = parts.next().unwrap_or("");
85 if parts.next().is_some() || major_str.is_empty() || minor_part.is_empty() {
86 return Err(unexpected_value(
87 span.clone(),
88 &["major.minor"],
89 value.clone(),
90 ));
91 }
92 let major = major_str.parse::<u32>().map_err(|_| {
93 unexpected_value(span.clone(), &["decimal literal"], value.clone())
94 })?;
95 let minor = minor_part.parse::<u32>().map_err(|_| {
96 unexpected_value(span.clone(), &["decimal literal"], value.clone())
97 })?;
98 Ok(VersionDirective { major, minor })
99 }
100 _ => Err(unexpected_value(
101 span.clone(),
102 &["decimal literal"],
103 format!("{token:?}"),
104 )),
105 }
106 }
107}
108
109impl PtxParser for TargetDirective {
110 fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
111 let mut entries = Vec::new();
112 loop {
113 let next = stream.peek();
114 let Ok((token, _span)) = next else {
115 break;
116 };
117 match token {
118 PtxToken::Identifier(name) => {
119 entries.push(name.clone());
120 stream.consume()?;
121 }
122 PtxToken::Dot => {
123 stream.consume()?;
124 let (name, _) = stream.expect_identifier()?;
125 entries.push(format!(".{name}"));
126 }
127 _ => break,
128 }
129 if stream
130 .consume_if(|token| matches!(token, PtxToken::Comma))
131 .is_none()
132 {
133 break;
134 }
135 }
136 if entries.is_empty() {
137 let span = stream.peek().map(|(_, span)| span.clone()).unwrap_or(0..0);
138 return Err(unexpected_value(
139 span,
140 &["sm arch or target modifier"],
141 "".to_string(),
142 ));
143 }
144 Ok(TargetDirective {
145 entries: entries.clone(),
146 raw: entries.join(", "),
147 })
148 }
149}
150
151impl PtxParser for AddressSizeDirective {
152 fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
153 let (size, _) = parse_decimal_u32(stream)?;
154 Ok(AddressSizeDirective { size })
155 }
156}
157
158impl PtxParser for ModuleInfoDirectiveKind {
159 fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
160 let (directive, span) = stream.expect_directive()?;
161 match directive.as_str() {
162 "version" => Ok(ModuleInfoDirectiveKind::Version(VersionDirective::parse(
163 stream,
164 )?)),
165 "target" => Ok(ModuleInfoDirectiveKind::Target(TargetDirective::parse(
166 stream,
167 )?)),
168 "address_size" => Ok(ModuleInfoDirectiveKind::AddressSize(
169 AddressSizeDirective::parse(stream)?,
170 )),
171 other => Err(unexpected_value(
172 span,
173 &[".version", ".target", ".address_size"],
174 format!(".{other}"),
175 )),
176 }
177 }
178}
179
180impl PtxParser for FileDirective {
181 fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
182 let (index, _) = parse_decimal_u32(stream)?;
183 let (token, span) = stream.consume()?;
184 let path = match token {
185 PtxToken::StringLiteral(content) => content.clone(),
186 _ => {
187 return Err(unexpected_value(
188 span.clone(),
189 &["string literal"],
190 format!("{token:?}"),
191 ));
192 }
193 };
194 Ok(FileDirective { index, path })
195 }
196}
197
198impl PtxParser for SectionDirective {
199 fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
200 let (token, span) = stream.consume()?;
201 let name = match token {
202 PtxToken::Identifier(value) => value.clone(),
203 PtxToken::Dot => {
204 let (value, _) = stream.expect_identifier()?;
205 format!(".{value}")
206 }
207 _ => {
208 return Err(unexpected_value(
209 span.clone(),
210 &["section name"],
211 format!("{token:?}"),
212 ));
213 }
214 };
215
216 let mut attributes = Vec::new();
217 loop {
218 let next = stream.peek();
219 let Ok((token, _)) = next else { break };
220 if is_module_directive_start(token) || matches!(token, PtxToken::Semicolon) {
221 break;
222 }
223 let (tok, _) = stream.consume()?;
224 attributes.push(token_to_string(tok));
225 }
226
227 Ok(SectionDirective { name, attributes })
228 }
229}
230
231impl PtxParser for ModuleDebugDirective {
232 fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
233 let (directive, span) = stream.expect_directive()?;
234 match directive.as_str() {
235 "file" => Ok(ModuleDebugDirective::File(FileDirective::parse(stream)?)),
236 "section" => Ok(ModuleDebugDirective::Section(SectionDirective::parse(
237 stream,
238 )?)),
239 other => Err(unexpected_value(
240 span,
241 &[".file", ".section"],
242 format!(".{other}"),
243 )),
244 }
245 }
246}
247
248impl PtxParser for LinkingDirective {
249 fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
250 let linkage = CodeOrDataLinkage::parse(stream)?;
251 let mut prototype = String::new();
252 loop {
253 let next = stream.peek();
254 let Ok((token, _span)) = next else { break };
255 if is_module_directive_start(token) {
256 break;
257 }
258 match token {
259 PtxToken::Semicolon => {
260 stream.consume()?;
261 break;
262 }
263 _ => {
264 let (tok, _) = stream.consume()?;
265 if !prototype.is_empty() {
266 prototype.push(' ');
267 }
268 prototype.push_str(&token_to_string(tok));
269 }
270 }
271 }
272 Ok(LinkingDirective {
273 kind: linkage,
274 prototype: prototype.clone(),
275 raw: prototype,
276 })
277 }
278}
279
280impl PtxParser for ModuleDirective {
281 fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
282 let position = stream.position();
283
284 if let Ok(info) = ModuleInfoDirectiveKind::parse(stream) {
285 return Ok(ModuleDirective::ModuleInfo(info));
286 }
287 stream.set_position(position);
288
289 if let Ok(debug) = ModuleDebugDirective::parse(stream) {
290 return Ok(ModuleDirective::Debug(debug));
291 }
292 stream.set_position(position);
293
294 if let Ok(function) = FunctionKernelDirective::parse(stream) {
295 return Ok(ModuleDirective::FunctionKernel(function));
296 }
297 stream.set_position(position);
298
299 if let Ok(variable) = ModuleVariableDirective::parse(stream) {
300 return Ok(ModuleDirective::ModuleVariable(variable));
301 }
302 stream.set_position(position);
303
304 if let Ok(linking) = LinkingDirective::parse(stream) {
305 return Ok(ModuleDirective::Linking(linking));
306 }
307 stream.set_position(position);
308
309 let span = stream
310 .peek()
311 .map(|(_, span)| span.clone())
312 .unwrap_or(position.index..position.index);
313 Err(unexpected_value(
314 span,
315 &["module directive"],
316 "unrecognised directive".to_string(),
317 ))
318 }
319}
320
321impl PtxParser for Module {
322 fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
323 let mut directives = Vec::new();
324 while !stream.is_at_end() {
325 if stream.is_at_end() {
326 break;
327 }
328 let directive = ModuleDirective::parse(stream)?;
329 directives.push(directive);
330 }
331 Ok(Module { directives })
332 }
333}