ptx_parser/parser/module.rs

1use crate::{
2    lexer::PtxToken,
3    parser::{PtxParseError, PtxParser, PtxTokenStream, unexpected_value},
4    r#type::{
5        common::CodeOrDataLinkage,
6        function::FunctionKernelDirective,
7        module::{
8            AddressSizeDirective, FileDirective, LinkingDirective, Module, ModuleDebugDirective,
9            ModuleDirective, ModuleInfoDirectiveKind, SectionDirective, TargetDirective,
10            VersionDirective,
11        },
12        variable::ModuleVariableDirective,
13    },
14};
15
16fn is_module_directive_start(token: &PtxToken) -> bool {
17    matches!(token, PtxToken::Dot)
18}
19
20fn parse_decimal_u32(
21    stream: &mut PtxTokenStream,
22) -> Result<(u32, std::ops::Range<usize>), PtxParseError> {
23    let (token, span) = stream.consume()?;
24    match token {
25        PtxToken::DecimalInteger(text) => text
26            .parse::<u32>()
27            .map(|value| (value, span.clone()))
28            .map_err(|_| unexpected_value(span.clone(), &["decimal literal"], text.clone())),
29        _ => Err(unexpected_value(
30            span.clone(),
31            &["decimal literal"],
32            format!("{token:?}"),
33        )),
34    }
35}
36
37fn token_to_string(token: &PtxToken) -> String {
38    match token {
39        PtxToken::Dot => ".".into(),
40        PtxToken::Identifier(name) => name.clone(),
41        PtxToken::DecimalInteger(value) => value.clone(),
42        PtxToken::StringLiteral(value) => format!("\"{value}\""),
43        PtxToken::LBrace => "{".into(),
44        PtxToken::RBrace => "}".into(),
45        PtxToken::Comma => ",".into(),
46        PtxToken::Colon => ":".into(),
47        PtxToken::Star => "*".into(),
48        PtxToken::Plus => "+".into(),
49        PtxToken::Minus => "-".into(),
50        PtxToken::Slash => "/".into(),
51        PtxToken::Percent => "%".into(),
52        PtxToken::Equals => "=".into(),
53        other => format!("{other:?}"),
54    }
55}
56
57impl PtxParser for VersionDirective {
58    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
59        let start_pos = stream.position().index;
60        let (token, span) = stream.consume()?;
61        match token {
62            PtxToken::DecimalInteger(text) => {
63                let major = text.parse::<u32>().map_err(|_| {
64                    unexpected_value(span.clone(), &["decimal literal"], text.clone())
65                })?;
66                stream.expect(&PtxToken::Dot)?;
67                let (minor_token, minor_span) = stream.consume()?;
68                let minor = match minor_token {
69                    PtxToken::DecimalInteger(value) => value.parse::<u32>().map_err(|_| {
70                        unexpected_value(minor_span.clone(), &["decimal literal"], value.clone())
71                    })?,
72                    _ => {
73                        return Err(unexpected_value(
74                            minor_span.clone(),
75                            &["decimal literal"],
76                            format!("{minor_token:?}"),
77                        ));
78                    }
79                };
80                let end_pos = stream.position().index;
81                Ok(VersionDirective { major, minor, span: start_pos..end_pos })
82            }
83            PtxToken::Float(value) | PtxToken::FloatExponent(value) => {
84                let mut parts = value.split('.');
85                let major_str = parts.next().unwrap_or("");
86                let minor_part = parts.next().unwrap_or("");
87                if parts.next().is_some() || major_str.is_empty() || minor_part.is_empty() {
88                    return Err(unexpected_value(
89                        span.clone(),
90                        &["major.minor"],
91                        value.clone(),
92                    ));
93                }
94                let major = major_str.parse::<u32>().map_err(|_| {
95                    unexpected_value(span.clone(), &["decimal literal"], value.clone())
96                })?;
97                let minor = minor_part.parse::<u32>().map_err(|_| {
98                    unexpected_value(span.clone(), &["decimal literal"], value.clone())
99                })?;
100                let end_pos = stream.position().index;
101                Ok(VersionDirective { major, minor, span: start_pos..end_pos })
102            }
103            _ => Err(unexpected_value(
104                span.clone(),
105                &["decimal literal"],
106                format!("{token:?}"),
107            )),
108        }
109    }
110}
111
112impl PtxParser for TargetDirective {
113    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
114        let start_pos = stream.position().index;
115        let mut entries = Vec::new();
116        loop {
117            let next = stream.peek();
118            let Ok((token, _span)) = next else {
119                break;
120            };
121            match token {
122                PtxToken::Identifier(name) => {
123                    entries.push(name.clone());
124                    stream.consume()?;
125                }
126                PtxToken::Dot => {
127                    stream.consume()?;
128                    let (name, _) = stream.expect_identifier()?;
129                    entries.push(format!(".{name}"));
130                }
131                _ => break,
132            }
133            if stream
134                .consume_if(|token| matches!(token, PtxToken::Comma))
135                .is_none()
136            {
137                break;
138            }
139        }
140        if entries.is_empty() {
141            let span = stream.peek().map(|(_, span)| span.clone()).unwrap_or(0..0);
142            return Err(unexpected_value(
143                span,
144                &["sm arch or target modifier"],
145                "".to_string(),
146            ));
147        }
148        let end_pos = stream.position().index;
149        Ok(TargetDirective {
150            entries,
151            span: start_pos..end_pos,
152        })
153    }
154}
155
156impl PtxParser for AddressSizeDirective {
157    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
158        let start_pos = stream.position().index;
159        let (size, _) = parse_decimal_u32(stream)?;
160        let end_pos = stream.position().index;
161        Ok(AddressSizeDirective { size, span: start_pos..end_pos })
162    }
163}
164
165impl PtxParser for ModuleInfoDirectiveKind {
166    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
167        let start_pos = stream.position().index;
168        let (directive, span) = stream.expect_directive()?;
169        match directive.as_str() {
170            "version" => {
171                let directive = VersionDirective::parse(stream)?;
172                let end_pos = stream.position().index;
173                Ok(ModuleInfoDirectiveKind::Version { directive, span: start_pos..end_pos })
174            }
175            "target" => {
176                let directive = TargetDirective::parse(stream)?;
177                let end_pos = stream.position().index;
178                Ok(ModuleInfoDirectiveKind::Target { directive, span: start_pos..end_pos })
179            }
180            "address_size" => {
181                let directive = AddressSizeDirective::parse(stream)?;
182                let end_pos = stream.position().index;
183                Ok(ModuleInfoDirectiveKind::AddressSize { directive, span: start_pos..end_pos })
184            }
185            other => Err(unexpected_value(
186                span,
187                &[".version", ".target", ".address_size"],
188                format!(".{other}"),
189            )),
190        }
191    }
192}
193
194impl PtxParser for FileDirective {
195    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
196        let start_pos = stream.position().index;
197        let (index, _) = parse_decimal_u32(stream)?;
198        let (token, span) = stream.consume()?;
199        let path = match token {
200            PtxToken::StringLiteral(content) => content.clone(),
201            _ => {
202                return Err(unexpected_value(
203                    span.clone(),
204                    &["string literal"],
205                    format!("{token:?}"),
206                ));
207            }
208        };
209        let end_pos = stream.position().index;
210        Ok(FileDirective { index, path, span: start_pos..end_pos })
211    }
212}
213
214impl PtxParser for SectionDirective {
215    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
216        let start_pos = stream.position().index;
217        let (token, span) = stream.consume()?;
218        let name = match token {
219            PtxToken::Identifier(value) => value.clone(),
220            PtxToken::Dot => {
221                let (value, _) = stream.expect_identifier()?;
222                format!(".{value}")
223            }
224            _ => {
225                return Err(unexpected_value(
226                    span.clone(),
227                    &["section name"],
228                    format!("{token:?}"),
229                ));
230            }
231        };
232
233        let mut attributes = Vec::new();
234        loop {
235            let next = stream.peek();
236            let Ok((token, _)) = next else { break };
237            if is_module_directive_start(token) || matches!(token, PtxToken::Semicolon) {
238                break;
239            }
240            let (tok, _) = stream.consume()?;
241            attributes.push(token_to_string(tok));
242        }
243
244        let end_pos = stream.position().index;
245        Ok(SectionDirective { name, attributes, span: start_pos..end_pos })
246    }
247}
248
249impl PtxParser for ModuleDebugDirective {
250    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
251        let start_pos = stream.position().index;
252        let (directive, span) = stream.expect_directive()?;
253        match directive.as_str() {
254            "file" => {
255                let directive = FileDirective::parse(stream)?;
256                let end_pos = stream.position().index;
257                Ok(ModuleDebugDirective::File { directive, span: start_pos..end_pos })
258            }
259            "section" => {
260                let directive = SectionDirective::parse(stream)?;
261                let end_pos = stream.position().index;
262                Ok(ModuleDebugDirective::Section { directive, span: start_pos..end_pos })
263            }
264            other => Err(unexpected_value(
265                span,
266                &[".file", ".section"],
267                format!(".{other}"),
268            )),
269        }
270    }
271}
272
273impl PtxParser for LinkingDirective {
274    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
275        let start_pos = stream.position().index;
276        let linkage = CodeOrDataLinkage::parse(stream)?;
277        let mut prototype = String::new();
278        loop {
279            let next = stream.peek();
280            let Ok((token, _span)) = next else { break };
281            if is_module_directive_start(token) {
282                break;
283            }
284            match token {
285                PtxToken::Semicolon => {
286                    stream.consume()?;
287                    break;
288                }
289                _ => {
290                    let (tok, _) = stream.consume()?;
291                    if !prototype.is_empty() {
292                        prototype.push(' ');
293                    }
294                    prototype.push_str(&token_to_string(tok));
295                }
296            }
297        }
298        let end_pos = stream.position().index;
299        Ok(LinkingDirective {
300            kind: linkage,
301            prototype,
302            span: start_pos..end_pos,
303        })
304    }
305}
306
307impl PtxParser for ModuleDirective {
308    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
309        let position = stream.position();
310        let start_pos = position.index;
311
312        if let Ok(info) = ModuleInfoDirectiveKind::parse(stream) {
313            let end_pos = stream.position().index;
314            return Ok(ModuleDirective::ModuleInfo { directive: info, span: start_pos..end_pos });
315        }
316        stream.set_position(position);
317
318        if let Ok(debug) = ModuleDebugDirective::parse(stream) {
319            let end_pos = stream.position().index;
320            return Ok(ModuleDirective::Debug { directive: debug, span: start_pos..end_pos });
321        }
322        stream.set_position(position);
323
324        if let Ok(function) = FunctionKernelDirective::parse(stream) {
325            let end_pos = stream.position().index;
326            return Ok(ModuleDirective::FunctionKernel { directive: function, span: start_pos..end_pos });
327        }
328        stream.set_position(position);
329
330        if let Ok(variable) = ModuleVariableDirective::parse(stream) {
331            let end_pos = stream.position().index;
332            return Ok(ModuleDirective::ModuleVariable { directive: variable, span: start_pos..end_pos });
333        }
334        stream.set_position(position);
335
336        if let Ok(linking) = LinkingDirective::parse(stream) {
337            let end_pos = stream.position().index;
338            return Ok(ModuleDirective::Linking { directive: linking, span: start_pos..end_pos });
339        }
340        stream.set_position(position);
341
342        let span = stream
343            .peek()
344            .map(|(_, span)| span.clone())
345            .unwrap_or(position.index..position.index);
346        Err(unexpected_value(
347            span,
348            &["module directive"],
349            "unrecognised directive".to_string(),
350        ))
351    }
352}
353
354impl PtxParser for Module {
355    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
356        let start_pos = stream.position().index;
357        let mut directives = Vec::new();
358        while !stream.is_at_end() {
359            if stream.is_at_end() {
360                break;
361            }
362            let directive = ModuleDirective::parse(stream)?;
363            directives.push(directive);
364        }
365        let end_pos = stream.position().index;
366        Ok(Module { directives, span: start_pos..end_pos })
367    }
368}