ptx_parser/parser/
module.rs

1use crate::{
2    lexer::PtxToken,
3    parser::{PtxParseError, PtxParser, PtxTokenStream, unexpected_value},
4    r#type::{
5        common::CodeOrDataLinkage,
6        function::FunctionKernelDirective,
7        module::{
8            AddressSizeDirective, FileDirective, LinkingDirective, Module, ModuleDebugDirective,
9            ModuleDirective, ModuleInfoDirectiveKind, SectionDirective, TargetDirective,
10            VersionDirective,
11        },
12        variable::ModuleVariableDirective,
13    },
14};
15
16fn is_module_directive_start(token: &PtxToken) -> bool {
17    matches!(token, PtxToken::Dot)
18}
19
20fn parse_decimal_u32(
21    stream: &mut PtxTokenStream,
22) -> Result<(u32, std::ops::Range<usize>), PtxParseError> {
23    let (token, span) = stream.consume()?;
24    match token {
25        PtxToken::DecimalInteger(text) => text
26            .parse::<u32>()
27            .map(|value| (value, span.clone()))
28            .map_err(|_| unexpected_value(span.clone(), &["decimal literal"], text.clone())),
29        _ => Err(unexpected_value(
30            span.clone(),
31            &["decimal literal"],
32            format!("{token:?}"),
33        )),
34    }
35}
36
37fn token_to_string(token: &PtxToken) -> String {
38    match token {
39        PtxToken::Dot => ".".into(),
40        PtxToken::Identifier(name) => name.clone(),
41        PtxToken::DecimalInteger(value) => value.clone(),
42        PtxToken::StringLiteral(value) => format!("\"{value}\""),
43        PtxToken::LBrace => "{".into(),
44        PtxToken::RBrace => "}".into(),
45        PtxToken::Comma => ",".into(),
46        PtxToken::Colon => ":".into(),
47        PtxToken::Star => "*".into(),
48        PtxToken::Plus => "+".into(),
49        PtxToken::Minus => "-".into(),
50        PtxToken::Slash => "/".into(),
51        PtxToken::Percent => "%".into(),
52        PtxToken::Equals => "=".into(),
53        other => format!("{other:?}"),
54    }
55}
56
57fn classify_next_directive(stream: &PtxTokenStream) -> Option<String> {
58    let mut iter = stream.remaining().iter();
59    while let Some((token, _)) = iter.next() {
60        match token {
61            PtxToken::Dot => {
62                // Check if next token is an identifier
63                if let Some((PtxToken::Identifier(name), _)) = iter.next() {
64                    match name.as_str() {
65                        "visible" | "extern" | "weak" | "common" => continue,
66                        other => return Some(other.to_string()),
67                    }
68                } else {
69                    return None;
70                }
71            }
72            _ => return None,
73        }
74    }
75    None
76}
77
78impl PtxParser for VersionDirective {
79    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
80        let (token, span) = stream.consume()?;
81        match token {
82            PtxToken::DecimalInteger(text) => {
83                let major = text.parse::<u32>().map_err(|_| {
84                    unexpected_value(span.clone(), &["decimal literal"], text.clone())
85                })?;
86                stream.expect(&PtxToken::Dot)?;
87                let (minor_token, minor_span) = stream.consume()?;
88                let minor = match minor_token {
89                    PtxToken::DecimalInteger(value) => value.parse::<u32>().map_err(|_| {
90                        unexpected_value(minor_span.clone(), &["decimal literal"], value.clone())
91                    })?,
92                    _ => {
93                        return Err(unexpected_value(
94                            minor_span.clone(),
95                            &["decimal literal"],
96                            format!("{minor_token:?}"),
97                        ));
98                    }
99                };
100                Ok(VersionDirective { major, minor })
101            }
102            PtxToken::Float(value) | PtxToken::FloatExponent(value) => {
103                let mut parts = value.split('.');
104                let major_str = parts.next().unwrap_or("");
105                let minor_part = parts.next().unwrap_or("");
106                if parts.next().is_some() || major_str.is_empty() || minor_part.is_empty() {
107                    return Err(unexpected_value(
108                        span.clone(),
109                        &["major.minor"],
110                        value.clone(),
111                    ));
112                }
113                let major = major_str.parse::<u32>().map_err(|_| {
114                    unexpected_value(span.clone(), &["decimal literal"], value.clone())
115                })?;
116                let minor = minor_part.parse::<u32>().map_err(|_| {
117                    unexpected_value(span.clone(), &["decimal literal"], value.clone())
118                })?;
119                Ok(VersionDirective { major, minor })
120            }
121            _ => Err(unexpected_value(
122                span.clone(),
123                &["decimal literal"],
124                format!("{token:?}"),
125            )),
126        }
127    }
128}
129
130impl PtxParser for TargetDirective {
131    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
132        let mut entries = Vec::new();
133        loop {
134            let next = stream.peek();
135            let Ok((token, _span)) = next else {
136                break;
137            };
138            match token {
139                PtxToken::Identifier(name) => {
140                    entries.push(name.clone());
141                    stream.consume()?;
142                }
143                PtxToken::Dot => {
144                    stream.consume()?;
145                    let (name, _) = stream.expect_identifier()?;
146                    entries.push(format!(".{name}"));
147                }
148                _ => break,
149            }
150            if stream
151                .consume_if(|token| matches!(token, PtxToken::Comma))
152                .is_none()
153            {
154                break;
155            }
156        }
157        if entries.is_empty() {
158            let span = stream.peek().map(|(_, span)| span.clone()).unwrap_or(0..0);
159            return Err(unexpected_value(
160                span,
161                &["sm arch or target modifier"],
162                "".to_string(),
163            ));
164        }
165        Ok(TargetDirective {
166            entries: entries.clone(),
167            raw: entries.join(", "),
168        })
169    }
170}
171
172impl PtxParser for AddressSizeDirective {
173    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
174        let (size, _) = parse_decimal_u32(stream)?;
175        Ok(AddressSizeDirective { size })
176    }
177}
178
179impl PtxParser for ModuleInfoDirectiveKind {
180    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
181        let (directive, span) = stream.expect_directive()?;
182        match directive.as_str() {
183            "version" => Ok(ModuleInfoDirectiveKind::Version(VersionDirective::parse(
184                stream,
185            )?)),
186            "target" => Ok(ModuleInfoDirectiveKind::Target(TargetDirective::parse(
187                stream,
188            )?)),
189            "address_size" => Ok(ModuleInfoDirectiveKind::AddressSize(
190                AddressSizeDirective::parse(stream)?,
191            )),
192            other => Err(unexpected_value(
193                span,
194                &[".version", ".target", ".address_size"],
195                format!(".{other}"),
196            )),
197        }
198    }
199}
200
201impl PtxParser for FileDirective {
202    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
203        let (index, _) = parse_decimal_u32(stream)?;
204        let (token, span) = stream.consume()?;
205        let path = match token {
206            PtxToken::StringLiteral(content) => content.clone(),
207            _ => {
208                return Err(unexpected_value(
209                    span.clone(),
210                    &["string literal"],
211                    format!("{token:?}"),
212                ));
213            }
214        };
215        Ok(FileDirective { index, path })
216    }
217}
218
219impl PtxParser for SectionDirective {
220    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
221        let (token, span) = stream.consume()?;
222        let name = match token {
223            PtxToken::Identifier(value) => value.clone(),
224            PtxToken::Dot => {
225                let (value, _) = stream.expect_identifier()?;
226                format!(".{value}")
227            }
228            _ => {
229                return Err(unexpected_value(
230                    span.clone(),
231                    &["section name"],
232                    format!("{token:?}"),
233                ));
234            }
235        };
236
237        let mut attributes = Vec::new();
238        loop {
239            let next = stream.peek();
240            let Ok((token, _)) = next else { break };
241            if is_module_directive_start(token) || matches!(token, PtxToken::Semicolon) {
242                break;
243            }
244            let (tok, _) = stream.consume()?;
245            attributes.push(token_to_string(tok));
246        }
247
248        Ok(SectionDirective { name, attributes })
249    }
250}
251
252impl PtxParser for ModuleDebugDirective {
253    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
254        let (directive, span) = stream.expect_directive()?;
255        match directive.as_str() {
256            "file" => Ok(ModuleDebugDirective::File(FileDirective::parse(stream)?)),
257            "section" => Ok(ModuleDebugDirective::Section(SectionDirective::parse(
258                stream,
259            )?)),
260            other => Err(unexpected_value(
261                span,
262                &[".file", ".section"],
263                format!(".{other}"),
264            )),
265        }
266    }
267}
268
269impl PtxParser for LinkingDirective {
270    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
271        let (linkage, _) = parse_code_or_data_linkage(stream)?;
272        let mut prototype = String::new();
273        loop {
274            let next = stream.peek();
275            let Ok((token, _span)) = next else { break };
276            if is_module_directive_start(token) {
277                break;
278            }
279            match token {
280                PtxToken::Semicolon => {
281                    stream.consume()?;
282                    break;
283                }
284                _ => {
285                    let (tok, _) = stream.consume()?;
286                    if !prototype.is_empty() {
287                        prototype.push(' ');
288                    }
289                    prototype.push_str(&token_to_string(tok));
290                }
291            }
292        }
293        Ok(LinkingDirective {
294            kind: linkage,
295            prototype: prototype.clone(),
296            raw: prototype,
297        })
298    }
299}
300
301fn parse_code_or_data_linkage(
302    stream: &mut PtxTokenStream,
303) -> Result<(CodeOrDataLinkage, std::ops::Range<usize>), PtxParseError> {
304    let (directive, span) = stream.expect_directive()?;
305    match directive.as_str() {
306        "visible" => Ok((CodeOrDataLinkage::Visible, span.clone())),
307        "extern" => Ok((CodeOrDataLinkage::Extern, span.clone())),
308        "weak" => Ok((CodeOrDataLinkage::Weak, span.clone())),
309        "common" => Ok((CodeOrDataLinkage::Common, span.clone())),
310        _ => Err(unexpected_value(
311            span.clone(),
312            &[".visible", ".extern", ".weak", ".common"],
313            format!(".{directive}"),
314        )),
315    }
316}
317
318impl PtxParser for ModuleDirective {
319    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
320        let Some(name) = classify_next_directive(stream) else {
321            let span = stream.peek()?.1.clone();
322            return Err(unexpected_value(span, &["directive"], "".to_string()));
323        };
324
325        match name.as_str() {
326            "version" | "target" | "address_size" => {
327                let info = ModuleInfoDirectiveKind::parse(stream)?;
328                Ok(ModuleDirective::ModuleInfo(info))
329            }
330            "file" | "section" => {
331                let debug = ModuleDebugDirective::parse(stream)?;
332                Ok(ModuleDirective::Debug(debug))
333            }
334            "entry" | "func" | "alias" => {
335                let function = FunctionKernelDirective::parse(stream)?;
336                Ok(ModuleDirective::FunctionKernel(function))
337            }
338            "global" | "const" | "shared" | "tex" => {
339                let variable = ModuleVariableDirective::parse(stream)?;
340                Ok(ModuleDirective::ModuleVariable(variable))
341            }
342            _ => {
343                let position = stream.position();
344                if let Ok(variable) = ModuleVariableDirective::parse(stream) {
345                    return Ok(ModuleDirective::ModuleVariable(variable));
346                }
347                stream.set_position(position);
348                if let Ok(function) = FunctionKernelDirective::parse(stream) {
349                    return Ok(ModuleDirective::FunctionKernel(function));
350                }
351                stream.set_position(position);
352                let span = stream
353                    .peek()
354                    .map(|(_, span)| span.clone())
355                    .unwrap_or(position.index..position.index);
356                Err(unexpected_value(
357                    span,
358                    &["module directive"],
359                    "unrecognised directive".to_string(),
360                ))
361            }
362        }
363    }
364}
365
366impl PtxParser for Module {
367    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
368        let mut directives = Vec::new();
369        while !stream.is_at_end() {
370            if stream.is_at_end() {
371                break;
372            }
373            let directive = ModuleDirective::parse(stream)?;
374            directives.push(directive);
375        }
376        Ok(Module { directives })
377    }
378}