blueprint_starlark_syntax/syntax/
def.rs

1/*
2 * Copyright 2018 The Starlark in Rust Authors.
3 * Copyright (c) Facebook, Inc. and its affiliates.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 *     https://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18use std::collections::HashSet;
19use std::ops::Range;
20
21use blueprint_allocative::Allocative;
22use blueprint_dupe::Dupe;
23
24use crate::codemap::CodeMap;
25use crate::codemap::Span;
26use crate::codemap::Spanned;
27use crate::eval_exception::EvalException;
28use crate::syntax::ast::AstAssignIdentP;
29use crate::syntax::ast::AstExprP;
30use crate::syntax::ast::AstParameterP;
31use crate::syntax::ast::AstPayload;
32use crate::syntax::ast::AstTypeExprP;
33use crate::syntax::ast::ParameterP;
34
35#[derive(Debug, Clone, Copy, Dupe, PartialEq, Eq)]
36pub enum DefRegularParamMode {
37    PosOnly,
38    PosOrName,
39    NameOnly,
40}
41
42pub enum DefParamKind<'a, P: AstPayload> {
43    Regular(
44        DefRegularParamMode,
45        /// Default value.
46        Option<&'a AstExprP<P>>,
47    ),
48    Args,
49    Kwargs,
50}
51
52/// One function parameter.
53pub struct DefParam<'a, P: AstPayload> {
54    /// Name of the parameter.
55    pub ident: &'a AstAssignIdentP<P>,
56    /// Whether this is a regular parameter (with optional default) or a varargs construct (*args,
57    /// **kwargs).
58    pub kind: DefParamKind<'a, P>,
59    /// Type of the parameter. This is None when a type is not specified.
60    pub ty: Option<&'a AstTypeExprP<P>>,
61}
62
63/// Parameters internally in starlark-rust are commonly represented as a flat list of parameters,
64/// with markers `/` and `*` omitted.
65/// This struct contains sizes and indices to split the list into parts.
66#[derive(
67    Copy, Clone, Dupe, Debug, Eq, PartialEq, Hash, Ord, PartialOrd, Allocative
68)]
69pub struct DefParamIndices {
70    /// Number of parameters which can be filled positionally.
71    /// That is, number of parameters before first `*`, `*args` or `**kwargs`.
72    pub num_positional: u32,
73    /// Number of parameters which can only be filled positionally.
74    /// Always less or equal to `num_positional`.
75    pub num_positional_only: u32,
76    /// Index of `*args` parameter, if any.
77    /// If present, equal to `num_positional`.
78    pub args: Option<u32>,
79    /// Index of `**kwargs` parameter, if any.
80    /// If present, equal to the number of parameters minus 1.
81    pub kwargs: Option<u32>,
82}
83
84impl DefParamIndices {
85    pub fn pos_only(&self) -> Range<usize> {
86        0..self.num_positional_only as usize
87    }
88
89    pub fn pos_or_named(&self) -> Range<usize> {
90        self.num_positional_only as usize..self.num_positional as usize
91    }
92
93    pub fn named_only(&self, param_count: usize) -> Range<usize> {
94        self.args
95            .map(|a| a as usize + 1)
96            .unwrap_or(self.num_positional as usize)
97            ..self.kwargs.unwrap_or(param_count as u32) as usize
98    }
99}
100
101/// Post-processed AST for function parameters.
102///
103/// * Validated
104/// * `*` parameter replaced with `num_positional` field
105pub struct DefParams<'a, P: AstPayload> {
106    pub params: Vec<Spanned<DefParam<'a, P>>>,
107    pub indices: DefParamIndices,
108}
109
110fn check_param_name<'a, P: AstPayload, T>(
111    argset: &mut HashSet<&'a str>,
112    n: &'a AstAssignIdentP<P>,
113    arg: &Spanned<T>,
114    codemap: &CodeMap,
115) -> Result<(), EvalException> {
116    if !argset.insert(n.node.ident.as_str()) {
117        return Err(EvalException::parser_error(
118            "duplicated parameter name",
119            arg.span,
120            codemap,
121        ));
122    }
123    Ok(())
124}
125
126impl<'a, P: AstPayload> DefParams<'a, P> {
127    pub fn unpack(
128        ast_params: &'a [AstParameterP<P>],
129        codemap: &CodeMap,
130    ) -> Result<DefParams<'a, P>, EvalException> {
131        #[derive(Ord, PartialOrd, Eq, PartialEq)]
132        enum State {
133            Normal,
134            /// After `/`.
135            SeenSlash,
136            /// After `*` or `*args`.
137            SeenStar,
138            /// After `**kwargs`.
139            SeenStarStar,
140        }
141
142        // you can't repeat argument names
143        let mut argset = HashSet::new();
144        // You can't have more than one *args/*, **kwargs
145        // **kwargs must be last
146        // You can't have a required `x` after an optional `y=1`
147        let mut seen_optional = false;
148
149        let mut params = Vec::with_capacity(ast_params.len());
150        let mut num_positional = 0;
151        let mut args = None;
152        let mut kwargs = None;
153
154        // Index of `*` parameter, if any.
155        let mut index_of_star = None;
156
157        let num_positional_only = match ast_params
158            .iter()
159            .position(|p| matches!(p.node, ParameterP::Slash))
160        {
161            None => 0,
162            Some(0) => {
163                return Err(EvalException::parser_error(
164                    "`/` cannot be first parameter",
165                    ast_params[0].span,
166                    codemap,
167                ));
168            }
169            Some(n) => match n.try_into() {
170                Ok(n) => n,
171                Err(_) => {
172                    return Err(EvalException::parser_error(
173                        format_args!("Too many parameters: {}", ast_params.len()),
174                        Span::merge_all(ast_params.iter().map(|p| p.span)),
175                        codemap,
176                    ));
177                }
178            },
179        };
180
181        let mut state = if num_positional_only == 0 {
182            State::SeenSlash
183        } else {
184            State::Normal
185        };
186
187        for (i, param) in ast_params.iter().enumerate() {
188            let span = param.span;
189
190            if let Some(name) = param.ident() {
191                check_param_name(&mut argset, name, param, codemap)?;
192            }
193
194            match &param.node {
195                ParameterP::Normal(n, ty, default_value) => {
196                    if state >= State::SeenStarStar {
197                        return Err(EvalException::parser_error(
198                            "Parameter after kwargs",
199                            param.span,
200                            codemap,
201                        ));
202                    }
203                    match default_value {
204                        None => {
205                            if seen_optional && state < State::SeenStar {
206                                return Err(EvalException::parser_error(
207                                    "positional parameter after non positional",
208                                    param.span,
209                                    codemap,
210                                ));
211                            }
212                        }
213                        Some(_default_value) => {
214                            seen_optional = true;
215                        }
216                    }
217                    if state < State::SeenStar {
218                        num_positional += 1;
219                    }
220                    let mode = if state < State::SeenSlash {
221                        DefRegularParamMode::PosOnly
222                    } else if state < State::SeenStar {
223                        DefRegularParamMode::PosOrName
224                    } else {
225                        DefRegularParamMode::NameOnly
226                    };
227                    params.push(Spanned {
228                        span,
229                        node: DefParam {
230                            ident: n,
231                            kind: DefParamKind::Regular(mode, default_value.as_deref()),
232                            ty: ty.as_deref(),
233                        },
234                    });
235                }
236                ParameterP::NoArgs => {
237                    if state >= State::SeenStar {
238                        return Err(EvalException::parser_error(
239                            "Args parameter after another args or kwargs parameter",
240                            param.span,
241                            codemap,
242                        ));
243                    }
244                    state = State::SeenStar;
245                    if index_of_star.is_some() {
246                        return Err(EvalException::internal_error(
247                            "Multiple `*` in parameters, must have been caught earlier",
248                            param.span,
249                            codemap,
250                        ));
251                    }
252                    index_of_star = Some(i);
253                }
254                ParameterP::Slash => {
255                    if state >= State::SeenSlash {
256                        return Err(EvalException::parser_error(
257                            "Multiple `/` in parameters",
258                            param.span,
259                            codemap,
260                        ));
261                    }
262                    state = State::SeenSlash;
263                }
264                ParameterP::Args(n, ty) => {
265                    if state >= State::SeenStar {
266                        return Err(EvalException::parser_error(
267                            "Args parameter after another args or kwargs parameter",
268                            param.span,
269                            codemap,
270                        ));
271                    }
272                    state = State::SeenStar;
273                    if args.is_some() {
274                        return Err(EvalException::internal_error(
275                            "Multiple *args",
276                            param.span,
277                            codemap,
278                        ));
279                    }
280                    args = Some(params.len().try_into().unwrap());
281                    params.push(Spanned {
282                        span,
283                        node: DefParam {
284                            ident: n,
285                            kind: DefParamKind::Args,
286                            ty: ty.as_deref(),
287                        },
288                    });
289                }
290                ParameterP::KwArgs(n, ty) => {
291                    if state >= State::SeenStarStar {
292                        return Err(EvalException::parser_error(
293                            "Multiple kwargs dictionary in parameters",
294                            param.span,
295                            codemap,
296                        ));
297                    }
298                    if kwargs.is_some() {
299                        return Err(EvalException::internal_error(
300                            "Multiple **kwargs",
301                            param.span,
302                            codemap,
303                        ));
304                    }
305                    kwargs = Some(params.len().try_into().unwrap());
306                    state = State::SeenStarStar;
307                    params.push(Spanned {
308                        span,
309                        node: DefParam {
310                            ident: n,
311                            kind: DefParamKind::Kwargs,
312                            ty: ty.as_deref(),
313                        },
314                    });
315                }
316            }
317        }
318
319        if let Some(index_of_star) = index_of_star {
320            let Some(next) = ast_params.get(index_of_star + 1) else {
321                return Err(EvalException::parser_error(
322                    "`*` parameter must not be last",
323                    ast_params[index_of_star].span,
324                    codemap,
325                ));
326            };
327            match &next.node {
328                ParameterP::Normal(..) => {}
329                ParameterP::KwArgs(_, _)
330                | ParameterP::Args(_, _)
331                | ParameterP::NoArgs
332                | ParameterP::Slash => {
333                    // We get here only for `**kwargs`, the rest is handled above.
334                    return Err(EvalException::parser_error(
335                        "`*` must be followed by named parameter",
336                        next.span,
337                        codemap,
338                    ));
339                }
340            }
341        }
342
343        Ok(DefParams {
344            params,
345            indices: DefParamIndices {
346                num_positional: u32::try_from(num_positional).unwrap(),
347                num_positional_only,
348                args,
349                kwargs,
350            },
351        })
352    }
353}
354
#[cfg(test)]
mod tests {
    use crate::golden_test_template::golden_test_template;
    use crate::syntax::AstModule;
    use crate::syntax::Dialect;

    /// Parse `program` with `dialect`, require that parsing fails,
    /// and compare the program together with the error against the golden file of the test.
    fn fails_dialect(test_name: &str, program: &str, dialect: &Dialect) {
        let error = AstModule::parse("test.star", program.to_owned(), dialect).unwrap_err();
        let golden_path = format!("src/syntax/def_tests/{test_name}.golden");
        let actual = format!("Program:\n{program}\n\nError: {error}\n");
        golden_test_template(&golden_path, &actual);
    }

    /// Same as `fails_dialect`, with the `AllOptionsInternal` dialect.
    fn fails(test_name: &str, program: &str) {
        let dialect = Dialect::AllOptionsInternal;
        fails_dialect(test_name, program, &dialect);
    }

    /// Require that `program` parses with the `AllOptionsInternal` dialect.
    fn passes(program: &str) {
        let dialect = Dialect::AllOptionsInternal;
        AstModule::parse("test.star", program.to_owned(), &dialect).unwrap();
    }

    #[test]
    fn test_params_unpack() {
        for (golden, program) in [
            ("dup_name", "def test(x, y, x): pass"),
            ("pos_after_default", "def test(x=1, y): pass"),
            ("default_after_kwargs", "def test(**kwargs, y=1): pass"),
            ("args_args", "def test(*x, *y): pass"),
            ("kwargs_args", "def test(**x, *y): pass"),
            ("kwargs_kwargs", "def test(**x, **y): pass"),
        ] {
            fails(golden, program);
        }

        passes("def test(x, y, z=1, *args, **kwargs): pass");
    }

    #[test]
    fn test_params_noargs() {
        for (golden, program) in [
            ("star_star", "def test(*, *): pass"),
            ("normal_after_default", "def test(x, y=1, z): pass"),
        ] {
            fails(golden, program);
        }

        for program in [
            "def test(*args, x): pass",
            "def test(*args, x=1): pass",
            "def test(*args, x, y=1): pass",
            "def test(x=1, *args, y): pass",
            "def test(*args, x, y=1, z): pass",
            "def test(*, x, y=1, z): pass",
        ] {
            passes(program);
        }
    }

    #[test]
    fn test_star_cannot_be_last() {
        let program = "def test(x, *): pass";
        fails("star_cannot_be_last", program);
    }

    #[test]
    fn test_star_then_args() {
        let program = "def test(x, *, *args): pass";
        fails("star_then_args", program);
    }

    #[test]
    fn test_star_then_kwargs() {
        let program = "def test(x, *, **kwargs): pass";
        fails("star_then_kwargs", program);
    }

    #[test]
    fn test_positional_only() {
        let program = "def test(x, /): pass";
        passes(program);
    }

    #[test]
    fn test_positional_only_cannot_be_first() {
        let program = "def test(/, x): pass";
        fails("positional_only_cannot_be_first", program);
    }

    #[test]
    fn test_slash_slash() {
        let program = "def test(x, /, y, /): pass";
        fails("slash_slash", program);
    }

    #[test]
    fn test_named_only_in_standard_dialect_def() {
        let program = "def test(*, x): pass";
        fails_dialect(
            "named_only_in_standard_dialect_def",
            program,
            &Dialect::Standard,
        );
    }

    #[test]
    fn test_named_only_in_standard_dialect_lambda() {
        let program = "lambda *, x: 17";
        fails_dialect(
            "named_only_in_standard_dialect_lambda",
            program,
            &Dialect::Standard,
        );
    }

    #[test]
    fn test_positional_only_in_standard_dialect_def() {
        let program = "def test(/, x): pass";
        fails_dialect(
            "positional_only_in_standard_dialect_def",
            program,
            &Dialect::Standard,
        );
    }

    #[test]
    fn test_positional_only_in_standard_dialect_lambda() {
        let program = "lambda /, x: 17";
        fails_dialect(
            "positional_only_in_standard_dialect_lambda",
            program,
            &Dialect::Standard,
        );
    }
}