// ptx_parser/parser/module.rs

1use crate::{
2    lexer::PtxToken,
3    parser::{PtxParseError, PtxParser, PtxTokenStream, unexpected_value},
4    r#type::{
5        common::CodeOrDataLinkage,
6        function::FunctionKernelDirective,
7        module::{
8            AddressSizeDirective, FileDirective, LinkingDirective, Module, ModuleDebugDirective,
9            ModuleDirective, ModuleInfoDirectiveKind, SectionDirective, TargetDirective,
10            VersionDirective,
11        },
12        variable::ModuleVariableDirective,
13    },
14};
15
16fn is_module_directive_start(token: &PtxToken) -> bool {
17    matches!(token, PtxToken::Dot)
18}
19
20fn parse_decimal_u32(
21    stream: &mut PtxTokenStream,
22) -> Result<(u32, std::ops::Range<usize>), PtxParseError> {
23    let (token, span) = stream.consume()?;
24    match token {
25        PtxToken::DecimalInteger(text) => text
26            .parse::<u32>()
27            .map(|value| (value, span.clone()))
28            .map_err(|_| unexpected_value(span.clone(), &["decimal literal"], text.clone())),
29        _ => Err(unexpected_value(
30            span.clone(),
31            &["decimal literal"],
32            format!("{token:?}"),
33        )),
34    }
35}
36
37fn token_to_string(token: &PtxToken) -> String {
38    match token {
39        PtxToken::Dot => ".".into(),
40        PtxToken::Identifier(name) => name.clone(),
41        PtxToken::DecimalInteger(value) => value.clone(),
42        PtxToken::StringLiteral(value) => format!("\"{value}\""),
43        PtxToken::LBrace => "{".into(),
44        PtxToken::RBrace => "}".into(),
45        PtxToken::Comma => ",".into(),
46        PtxToken::Colon => ":".into(),
47        PtxToken::Star => "*".into(),
48        PtxToken::Plus => "+".into(),
49        PtxToken::Minus => "-".into(),
50        PtxToken::Slash => "/".into(),
51        PtxToken::Percent => "%".into(),
52        PtxToken::Equals => "=".into(),
53        other => format!("{other:?}"),
54    }
55}
56
57impl PtxParser for VersionDirective {
58    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
59        let (token, span) = stream.consume()?;
60        match token {
61            PtxToken::DecimalInteger(text) => {
62                let major = text.parse::<u32>().map_err(|_| {
63                    unexpected_value(span.clone(), &["decimal literal"], text.clone())
64                })?;
65                stream.expect(&PtxToken::Dot)?;
66                let (minor_token, minor_span) = stream.consume()?;
67                let minor = match minor_token {
68                    PtxToken::DecimalInteger(value) => value.parse::<u32>().map_err(|_| {
69                        unexpected_value(minor_span.clone(), &["decimal literal"], value.clone())
70                    })?,
71                    _ => {
72                        return Err(unexpected_value(
73                            minor_span.clone(),
74                            &["decimal literal"],
75                            format!("{minor_token:?}"),
76                        ));
77                    }
78                };
79                Ok(VersionDirective { major, minor })
80            }
81            PtxToken::Float(value) | PtxToken::FloatExponent(value) => {
82                let mut parts = value.split('.');
83                let major_str = parts.next().unwrap_or("");
84                let minor_part = parts.next().unwrap_or("");
85                if parts.next().is_some() || major_str.is_empty() || minor_part.is_empty() {
86                    return Err(unexpected_value(
87                        span.clone(),
88                        &["major.minor"],
89                        value.clone(),
90                    ));
91                }
92                let major = major_str.parse::<u32>().map_err(|_| {
93                    unexpected_value(span.clone(), &["decimal literal"], value.clone())
94                })?;
95                let minor = minor_part.parse::<u32>().map_err(|_| {
96                    unexpected_value(span.clone(), &["decimal literal"], value.clone())
97                })?;
98                Ok(VersionDirective { major, minor })
99            }
100            _ => Err(unexpected_value(
101                span.clone(),
102                &["decimal literal"],
103                format!("{token:?}"),
104            )),
105        }
106    }
107}
108
109impl PtxParser for TargetDirective {
110    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
111        let mut entries = Vec::new();
112        loop {
113            let next = stream.peek();
114            let Ok((token, _span)) = next else {
115                break;
116            };
117            match token {
118                PtxToken::Identifier(name) => {
119                    entries.push(name.clone());
120                    stream.consume()?;
121                }
122                PtxToken::Dot => {
123                    stream.consume()?;
124                    let (name, _) = stream.expect_identifier()?;
125                    entries.push(format!(".{name}"));
126                }
127                _ => break,
128            }
129            if stream
130                .consume_if(|token| matches!(token, PtxToken::Comma))
131                .is_none()
132            {
133                break;
134            }
135        }
136        if entries.is_empty() {
137            let span = stream.peek().map(|(_, span)| span.clone()).unwrap_or(0..0);
138            return Err(unexpected_value(
139                span,
140                &["sm arch or target modifier"],
141                "".to_string(),
142            ));
143        }
144        Ok(TargetDirective {
145            entries: entries.clone(),
146            raw: entries.join(", "),
147        })
148    }
149}
150
151impl PtxParser for AddressSizeDirective {
152    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
153        let (size, _) = parse_decimal_u32(stream)?;
154        Ok(AddressSizeDirective { size })
155    }
156}
157
158impl PtxParser for ModuleInfoDirectiveKind {
159    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
160        let (directive, span) = stream.expect_directive()?;
161        match directive.as_str() {
162            "version" => Ok(ModuleInfoDirectiveKind::Version(VersionDirective::parse(
163                stream,
164            )?)),
165            "target" => Ok(ModuleInfoDirectiveKind::Target(TargetDirective::parse(
166                stream,
167            )?)),
168            "address_size" => Ok(ModuleInfoDirectiveKind::AddressSize(
169                AddressSizeDirective::parse(stream)?,
170            )),
171            other => Err(unexpected_value(
172                span,
173                &[".version", ".target", ".address_size"],
174                format!(".{other}"),
175            )),
176        }
177    }
178}
179
180impl PtxParser for FileDirective {
181    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
182        let (index, _) = parse_decimal_u32(stream)?;
183        let (token, span) = stream.consume()?;
184        let path = match token {
185            PtxToken::StringLiteral(content) => content.clone(),
186            _ => {
187                return Err(unexpected_value(
188                    span.clone(),
189                    &["string literal"],
190                    format!("{token:?}"),
191                ));
192            }
193        };
194        Ok(FileDirective { index, path })
195    }
196}
197
198impl PtxParser for SectionDirective {
199    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
200        let (token, span) = stream.consume()?;
201        let name = match token {
202            PtxToken::Identifier(value) => value.clone(),
203            PtxToken::Dot => {
204                let (value, _) = stream.expect_identifier()?;
205                format!(".{value}")
206            }
207            _ => {
208                return Err(unexpected_value(
209                    span.clone(),
210                    &["section name"],
211                    format!("{token:?}"),
212                ));
213            }
214        };
215
216        let mut attributes = Vec::new();
217        loop {
218            let next = stream.peek();
219            let Ok((token, _)) = next else { break };
220            if is_module_directive_start(token) || matches!(token, PtxToken::Semicolon) {
221                break;
222            }
223            let (tok, _) = stream.consume()?;
224            attributes.push(token_to_string(tok));
225        }
226
227        Ok(SectionDirective { name, attributes })
228    }
229}
230
231impl PtxParser for ModuleDebugDirective {
232    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
233        let (directive, span) = stream.expect_directive()?;
234        match directive.as_str() {
235            "file" => Ok(ModuleDebugDirective::File(FileDirective::parse(stream)?)),
236            "section" => Ok(ModuleDebugDirective::Section(SectionDirective::parse(
237                stream,
238            )?)),
239            other => Err(unexpected_value(
240                span,
241                &[".file", ".section"],
242                format!(".{other}"),
243            )),
244        }
245    }
246}
247
248impl PtxParser for LinkingDirective {
249    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
250        let linkage = CodeOrDataLinkage::parse(stream)?;
251        let mut prototype = String::new();
252        loop {
253            let next = stream.peek();
254            let Ok((token, _span)) = next else { break };
255            if is_module_directive_start(token) {
256                break;
257            }
258            match token {
259                PtxToken::Semicolon => {
260                    stream.consume()?;
261                    break;
262                }
263                _ => {
264                    let (tok, _) = stream.consume()?;
265                    if !prototype.is_empty() {
266                        prototype.push(' ');
267                    }
268                    prototype.push_str(&token_to_string(tok));
269                }
270            }
271        }
272        Ok(LinkingDirective {
273            kind: linkage,
274            prototype: prototype.clone(),
275            raw: prototype,
276        })
277    }
278}
279
280impl PtxParser for ModuleDirective {
281    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
282        let position = stream.position();
283
284        if let Ok(info) = ModuleInfoDirectiveKind::parse(stream) {
285            return Ok(ModuleDirective::ModuleInfo(info));
286        }
287        stream.set_position(position);
288
289        if let Ok(debug) = ModuleDebugDirective::parse(stream) {
290            return Ok(ModuleDirective::Debug(debug));
291        }
292        stream.set_position(position);
293
294        if let Ok(function) = FunctionKernelDirective::parse(stream) {
295            return Ok(ModuleDirective::FunctionKernel(function));
296        }
297        stream.set_position(position);
298
299        if let Ok(variable) = ModuleVariableDirective::parse(stream) {
300            return Ok(ModuleDirective::ModuleVariable(variable));
301        }
302        stream.set_position(position);
303
304        if let Ok(linking) = LinkingDirective::parse(stream) {
305            return Ok(ModuleDirective::Linking(linking));
306        }
307        stream.set_position(position);
308
309        let span = stream
310            .peek()
311            .map(|(_, span)| span.clone())
312            .unwrap_or(position.index..position.index);
313        Err(unexpected_value(
314            span,
315            &["module directive"],
316            "unrecognised directive".to_string(),
317        ))
318    }
319}
320
321impl PtxParser for Module {
322    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
323        let mut directives = Vec::new();
324        while !stream.is_at_end() {
325            if stream.is_at_end() {
326                break;
327            }
328            let directive = ModuleDirective::parse(stream)?;
329            directives.push(directive);
330        }
331        Ok(Module { directives })
332    }
333}