Skip to main content

shape_ast/parser/
functions.rs

1//! Function and annotation parsing for Shape
2//!
3//! This module handles parsing of:
4//! - Function definitions with parameters and return types
5//! - Function parameters with default values
6//! - Annotations (@warmup, @strategy, etc.)
7
8use crate::ast::{
9    Annotation, BuiltinFunctionDecl, ForeignFunctionDef, FunctionDef, FunctionParameter,
10    NativeAbiBinding,
11};
12use crate::error::Result;
13use pest::iterators::Pair;
14
15use super::expressions;
16use super::statements;
17use super::string_literals::parse_string_literal;
18use super::types;
19use super::types::parse_type_annotation;
20use super::{Rule, pair_span};
21
22/// Parse annotations
23pub fn parse_annotations(pair: Pair<Rule>) -> Result<Vec<Annotation>> {
24    let mut annotations = vec![];
25
26    for annotation_pair in pair.into_inner() {
27        if annotation_pair.as_rule() == Rule::annotation {
28            annotations.push(parse_annotation(annotation_pair)?);
29        }
30    }
31
32    Ok(annotations)
33}
34
35/// Parse a single annotation
36pub fn parse_annotation(pair: Pair<Rule>) -> Result<Annotation> {
37    let span = pair_span(&pair);
38    let mut name = String::new();
39    let mut args = Vec::new();
40
41    for inner_pair in pair.into_inner() {
42        match inner_pair.as_rule() {
43            Rule::annotation_name | Rule::ident => {
44                name = inner_pair.as_str().to_string();
45            }
46            Rule::annotation_args => {
47                for arg_pair in inner_pair.into_inner() {
48                    if arg_pair.as_rule() == Rule::expression {
49                        args.push(expressions::parse_expression(arg_pair)?);
50                    }
51                }
52            }
53            Rule::expression => {
54                args.push(expressions::parse_expression(inner_pair)?);
55            }
56            _ => {}
57        }
58    }
59
60    Ok(Annotation { name, args, span })
61}
62
63/// Parse a function parameter
64pub fn parse_function_param(pair: Pair<Rule>) -> Result<FunctionParameter> {
65    let mut pattern = None;
66    let mut is_const = false;
67    let mut is_reference = false;
68    let mut is_mut_reference = false;
69    let mut type_annotation = None;
70    let mut default_value = None;
71
72    for inner_pair in pair.into_inner() {
73        match inner_pair.as_rule() {
74            Rule::param_const_keyword => {
75                is_const = true;
76            }
77            Rule::param_ref_keyword => {
78                is_reference = true;
79                // Check for &mut: param_ref_keyword contains optional param_mut_keyword
80                for child in inner_pair.into_inner() {
81                    if child.as_rule() == Rule::param_mut_keyword {
82                        is_mut_reference = true;
83                    }
84                }
85            }
86            Rule::destructure_pattern => {
87                pattern = Some(super::items::parse_pattern(inner_pair)?);
88            }
89            Rule::type_annotation => {
90                type_annotation = Some(parse_type_annotation(inner_pair)?);
91            }
92            Rule::expression => {
93                default_value = Some(expressions::parse_expression(inner_pair)?);
94            }
95            _ => {}
96        }
97    }
98
99    let pattern = pattern.ok_or_else(|| crate::error::ShapeError::ParseError {
100        message: "expected pattern in function parameter".to_string(),
101        location: None,
102    })?;
103
104    Ok(FunctionParameter {
105        pattern,
106        is_const,
107        is_reference,
108        is_mut_reference,
109        type_annotation,
110        default_value,
111    })
112}
113
114/// Parse a function definition
115pub fn parse_function_def(pair: Pair<Rule>) -> Result<FunctionDef> {
116    let mut name = String::new();
117    let mut name_span = crate::ast::Span::DUMMY;
118    let mut type_params = None;
119    let mut params = vec![];
120    let mut return_type = None;
121    let mut where_clause = None;
122    let mut body = vec![];
123    let mut annotations = vec![];
124    let mut is_async = false;
125    let mut is_comptime = false;
126
127    // Parse all parts sequentially (can't use find() as it consumes the iterator)
128    for inner_pair in pair.into_inner() {
129        match inner_pair.as_rule() {
130            Rule::annotations => {
131                annotations = parse_annotations(inner_pair)?;
132            }
133            Rule::async_keyword => {
134                is_async = true;
135            }
136            Rule::comptime_keyword => {
137                is_comptime = true;
138            }
139            Rule::ident => {
140                if name.is_empty() {
141                    name = inner_pair.as_str().to_string();
142                    name_span = pair_span(&inner_pair);
143                }
144            }
145            Rule::type_params => {
146                type_params = Some(types::parse_type_params(inner_pair)?);
147            }
148            Rule::function_params => {
149                for param_pair in inner_pair.into_inner() {
150                    if param_pair.as_rule() == Rule::function_param {
151                        params.push(parse_function_param(param_pair)?);
152                    }
153                }
154            }
155            Rule::return_type => {
156                // Skip the "->" and get the type annotation
157                if let Some(type_pair) = inner_pair.into_inner().next() {
158                    return_type = Some(parse_type_annotation(type_pair)?);
159                }
160            }
161            Rule::where_clause => {
162                where_clause = Some(parse_where_clause(inner_pair)?);
163            }
164            Rule::function_body => {
165                // Parse all statements in the function body
166                body = statements::parse_statements(inner_pair.into_inner())?;
167            }
168            _ => {}
169        }
170    }
171
172    Ok(FunctionDef {
173        name,
174        name_span,
175        type_params,
176        params,
177        return_type,
178        where_clause,
179        body,
180        annotations,
181        is_async,
182        is_comptime,
183    })
184}
185
186/// Parse a declaration-only builtin function definition.
187///
188/// Grammar:
189/// `builtin fn name<T>(params...) -> ReturnType;`
190pub fn parse_builtin_function_decl(pair: Pair<Rule>) -> Result<BuiltinFunctionDecl> {
191    let mut name = String::new();
192    let mut name_span = crate::ast::Span::DUMMY;
193    let mut type_params = None;
194    let mut params = vec![];
195    let mut return_type = None;
196
197    for inner_pair in pair.into_inner() {
198        match inner_pair.as_rule() {
199            Rule::ident => {
200                if name.is_empty() {
201                    name = inner_pair.as_str().to_string();
202                    name_span = pair_span(&inner_pair);
203                }
204            }
205            Rule::type_params => {
206                type_params = Some(types::parse_type_params(inner_pair)?);
207            }
208            Rule::function_params => {
209                for param_pair in inner_pair.into_inner() {
210                    if param_pair.as_rule() == Rule::function_param {
211                        params.push(parse_function_param(param_pair)?);
212                    }
213                }
214            }
215            Rule::return_type => {
216                if let Some(type_pair) = inner_pair.into_inner().next() {
217                    return_type = Some(parse_type_annotation(type_pair)?);
218                }
219            }
220            _ => {}
221        }
222    }
223
224    let return_type = return_type.ok_or_else(|| crate::error::ShapeError::ParseError {
225        message: "builtin function declaration requires an explicit return type".to_string(),
226        location: None,
227    })?;
228
229    Ok(BuiltinFunctionDecl {
230        name,
231        name_span,
232        type_params,
233        params,
234        return_type,
235    })
236}
237
238/// Parse a foreign function definition: `fn python analyze(data: DataTable) -> number { ... }`
239pub fn parse_foreign_function_def(pair: Pair<Rule>) -> Result<ForeignFunctionDef> {
240    let mut language = String::new();
241    let mut language_span = crate::ast::Span::DUMMY;
242    let mut name = String::new();
243    let mut name_span = crate::ast::Span::DUMMY;
244    let mut type_params = None;
245    let mut params = vec![];
246    let mut return_type = None;
247    let mut body_text = String::new();
248    let mut body_span = crate::ast::Span::DUMMY;
249    let mut annotations = vec![];
250    let mut is_async = false;
251
252    for inner_pair in pair.into_inner() {
253        match inner_pair.as_rule() {
254            Rule::annotations => {
255                annotations = parse_annotations(inner_pair)?;
256            }
257            Rule::async_keyword => {
258                is_async = true;
259            }
260            Rule::function_keyword => {}
261            Rule::foreign_language_id => {
262                language = inner_pair.as_str().to_string();
263                language_span = pair_span(&inner_pair);
264            }
265            Rule::ident => {
266                if name.is_empty() {
267                    name = inner_pair.as_str().to_string();
268                    name_span = pair_span(&inner_pair);
269                }
270            }
271            Rule::type_params => {
272                type_params = Some(types::parse_type_params(inner_pair)?);
273            }
274            Rule::function_params => {
275                for param_pair in inner_pair.into_inner() {
276                    if param_pair.as_rule() == Rule::function_param {
277                        params.push(parse_function_param(param_pair)?);
278                    }
279                }
280            }
281            Rule::return_type => {
282                if let Some(type_pair) = inner_pair.into_inner().next() {
283                    return_type = Some(parse_type_annotation(type_pair)?);
284                }
285            }
286            Rule::foreign_body => {
287                body_span = pair_span(&inner_pair);
288                body_text = dedent_foreign_body(inner_pair.as_str());
289            }
290            _ => {}
291        }
292    }
293
294    Ok(ForeignFunctionDef {
295        language,
296        language_span,
297        name,
298        name_span,
299        type_params,
300        params,
301        return_type,
302        body_text,
303        body_span,
304        annotations,
305        is_async,
306        native_abi: None,
307    })
308}
309
310/// Parse a native ABI declaration:
311/// `extern "C" fn name(args...) -> Ret from "library" [as "symbol"];`
312pub fn parse_extern_native_function_def(pair: Pair<Rule>) -> Result<ForeignFunctionDef> {
313    let mut abi = String::new();
314    let mut abi_span = crate::ast::Span::DUMMY;
315    let mut name = String::new();
316    let mut name_span = crate::ast::Span::DUMMY;
317    let mut type_params = None;
318    let mut params = Vec::new();
319    let mut return_type = None;
320    let mut library: Option<String> = None;
321    let mut symbol: Option<String> = None;
322    let mut annotations = Vec::new();
323    let mut is_async = false;
324
325    for inner_pair in pair.into_inner() {
326        match inner_pair.as_rule() {
327            Rule::annotations => {
328                annotations = parse_annotations(inner_pair)?;
329            }
330            Rule::async_keyword => {
331                is_async = true;
332            }
333            Rule::extern_abi => {
334                abi_span = pair_span(&inner_pair);
335                abi = parse_extern_abi(inner_pair)?;
336            }
337            Rule::function_keyword => {}
338            Rule::ident => {
339                if name.is_empty() {
340                    name = inner_pair.as_str().to_string();
341                    name_span = pair_span(&inner_pair);
342                }
343            }
344            Rule::type_params => {
345                type_params = Some(types::parse_type_params(inner_pair)?);
346            }
347            Rule::function_params => {
348                for param_pair in inner_pair.into_inner() {
349                    if param_pair.as_rule() == Rule::function_param {
350                        params.push(parse_function_param(param_pair)?);
351                    }
352                }
353            }
354            Rule::return_type => {
355                if let Some(type_pair) = inner_pair.into_inner().next() {
356                    return_type = Some(parse_type_annotation(type_pair)?);
357                }
358            }
359            Rule::extern_native_link => {
360                for link_part in inner_pair.into_inner() {
361                    match link_part.as_rule() {
362                        Rule::extern_native_library => {
363                            library = Some(parse_string_literal(link_part.as_str())?);
364                        }
365                        Rule::extern_native_symbol => {
366                            symbol = Some(parse_string_literal(link_part.as_str())?);
367                        }
368                        _ => {}
369                    }
370                }
371            }
372            _ => {}
373        }
374    }
375
376    let library = library.ok_or_else(|| crate::error::ShapeError::ParseError {
377        message: "extern native declaration requires `from \"library\"`".to_string(),
378        location: None,
379    })?;
380
381    if abi.trim() != "C" {
382        return Err(crate::error::ShapeError::ParseError {
383            message: format!(
384                "unsupported extern ABI '{}': only \"C\" is currently supported",
385                abi
386            ),
387            location: None,
388        });
389    }
390
391    let symbol = symbol.unwrap_or_else(|| name.clone());
392
393    Ok(ForeignFunctionDef {
394        // Keep foreign-language compatibility for downstream compilation/runtime
395        // while carrying explicit native ABI metadata.
396        language: "native".to_string(),
397        language_span: abi_span,
398        name,
399        name_span,
400        type_params,
401        params,
402        return_type,
403        body_text: String::new(),
404        body_span: crate::ast::Span::DUMMY,
405        annotations,
406        is_async,
407        native_abi: Some(NativeAbiBinding {
408            abi,
409            library,
410            symbol,
411        }),
412    })
413}
414
415pub(crate) fn parse_extern_abi(pair: Pair<Rule>) -> Result<String> {
416    let inner = pair
417        .into_inner()
418        .next()
419        .ok_or_else(|| crate::error::ShapeError::ParseError {
420            message: "extern declaration is missing ABI name".to_string(),
421            location: None,
422        })?;
423
424    match inner.as_rule() {
425        Rule::string => parse_string_literal(inner.as_str()),
426        Rule::ident => Ok(inner.as_str().to_string()),
427        _ => Err(crate::error::ShapeError::ParseError {
428            message: format!("unsupported extern ABI token: {:?}", inner.as_rule()),
429            location: None,
430        }),
431    }
432}
433
/// Strip common leading whitespace from foreign body text.
///
/// Similar to Python's `textwrap.dedent`. This is critical for Python blocks
/// since the body is indented inside Shape code but needs to be dedented
/// for the foreign language runtime.
///
/// Note: The Pest parser's implicit WHITESPACE rule consumes the newline and
/// leading whitespace between `{` and the first token of `foreign_body`. This
/// means the first line has its leading whitespace eaten by the parser, while
/// subsequent lines retain their original indentation. We compute `min_indent`
/// from lines after the first, then strip that amount only from those lines.
/// The first line is kept as-is.
///
/// Behaviour summary:
/// - Empty input yields an empty string.
/// - A single line is returned with its leading whitespace removed.
/// - Otherwise `min_indent` is the smallest leading-whitespace width (in
///   bytes) among the non-blank lines after the first; that many bytes are
///   removed from each following line, and blank lines shorter than
///   `min_indent` become empty.
/// - Lines are re-joined with `\n`.
///
/// The indent is measured in bytes, but lines may be indented with
/// multi-byte Unicode whitespace (for example U+00A0). The cut position is
/// therefore moved back to the nearest character boundary so that slicing
/// never panics; such a line may keep part of its indentation.
fn dedent_foreign_body(text: &str) -> String {
    let lines: Vec<&str> = text.lines().collect();
    let (first, rest) = match lines.split_first() {
        None => return String::new(),
        Some((only, [])) => return only.trim_start().to_string(),
        Some(split) => split,
    };

    // Compute min_indent from lines after the first, since the parser already
    // consumed the first line's leading whitespace. Blank lines do not count.
    let min_indent = rest
        .iter()
        .filter(|line| !line.trim().is_empty())
        .map(|line| line.len() - line.trim_start().len())
        .min()
        .unwrap_or(0);

    let mut result = Vec::with_capacity(lines.len());
    // First line: keep as-is (parser already stripped its whitespace).
    result.push(*first);
    for &line in rest {
        if line.len() < min_indent {
            // Only whitespace-only lines can be shorter than the common indent.
            result.push(line.trim());
            continue;
        }
        // `min_indent` was measured on a different line, so it may fall inside
        // a multi-byte whitespace character here. Index 0 is always a
        // boundary, so this loop terminates without underflow.
        let mut cut = min_indent;
        while !line.is_char_boundary(cut) {
            cut -= 1;
        }
        result.push(&line[cut..]);
    }
    result.join("\n")
}
478
479/// Parse a where clause: `where T: Bound1 + Bound2, U: Bound3`
480pub fn parse_where_clause(pair: Pair<Rule>) -> Result<Vec<crate::ast::types::WherePredicate>> {
481    let mut predicates = Vec::new();
482    for child in pair.into_inner() {
483        if child.as_rule() == Rule::where_predicate {
484            predicates.push(parse_where_predicate(child)?);
485        }
486    }
487    Ok(predicates)
488}
489
490fn parse_where_predicate(pair: Pair<Rule>) -> Result<crate::ast::types::WherePredicate> {
491    let mut inner = pair.into_inner();
492
493    let name_pair = inner
494        .next()
495        .ok_or_else(|| crate::error::ShapeError::ParseError {
496            message: "expected type parameter name in where predicate".to_string(),
497            location: None,
498        })?;
499    let type_name = name_pair.as_str().to_string();
500
501    let mut bounds = Vec::new();
502    for remaining in inner {
503        if remaining.as_rule() == Rule::trait_bound_list {
504            for bound_ident in remaining.into_inner() {
505                if bound_ident.as_rule() == Rule::ident {
506                    bounds.push(bound_ident.as_str().to_string());
507                }
508            }
509        }
510    }
511
512    Ok(crate::ast::types::WherePredicate { type_name, bounds })
513}