1use std::collections::HashSet;
19use std::ops::Range;
20
21use blueprint_allocative::Allocative;
22use blueprint_dupe::Dupe;
23
24use crate::codemap::CodeMap;
25use crate::codemap::Span;
26use crate::codemap::Spanned;
27use crate::eval_exception::EvalException;
28use crate::syntax::ast::AstAssignIdentP;
29use crate::syntax::ast::AstExprP;
30use crate::syntax::ast::AstParameterP;
31use crate::syntax::ast::AstPayload;
32use crate::syntax::ast::AstTypeExprP;
33use crate::syntax::ast::ParameterP;
34
35#[derive(Debug, Clone, Copy, Dupe, PartialEq, Eq)]
36pub enum DefRegularParamMode {
37 PosOnly,
38 PosOrName,
39 NameOnly,
40}
41
42pub enum DefParamKind<'a, P: AstPayload> {
43 Regular(
44 DefRegularParamMode,
45 Option<&'a AstExprP<P>>,
47 ),
48 Args,
49 Kwargs,
50}
51
52pub struct DefParam<'a, P: AstPayload> {
54 pub ident: &'a AstAssignIdentP<P>,
56 pub kind: DefParamKind<'a, P>,
59 pub ty: Option<&'a AstTypeExprP<P>>,
61}
62
63#[derive(
67 Copy, Clone, Dupe, Debug, Eq, PartialEq, Hash, Ord, PartialOrd, Allocative
68)]
69pub struct DefParamIndices {
70 pub num_positional: u32,
73 pub num_positional_only: u32,
76 pub args: Option<u32>,
79 pub kwargs: Option<u32>,
82}
83
84impl DefParamIndices {
85 pub fn pos_only(&self) -> Range<usize> {
86 0..self.num_positional_only as usize
87 }
88
89 pub fn pos_or_named(&self) -> Range<usize> {
90 self.num_positional_only as usize..self.num_positional as usize
91 }
92
93 pub fn named_only(&self, param_count: usize) -> Range<usize> {
94 self.args
95 .map(|a| a as usize + 1)
96 .unwrap_or(self.num_positional as usize)
97 ..self.kwargs.unwrap_or(param_count as u32) as usize
98 }
99}
100
101pub struct DefParams<'a, P: AstPayload> {
106 pub params: Vec<Spanned<DefParam<'a, P>>>,
107 pub indices: DefParamIndices,
108}
109
110fn check_param_name<'a, P: AstPayload, T>(
111 argset: &mut HashSet<&'a str>,
112 n: &'a AstAssignIdentP<P>,
113 arg: &Spanned<T>,
114 codemap: &CodeMap,
115) -> Result<(), EvalException> {
116 if !argset.insert(n.node.ident.as_str()) {
117 return Err(EvalException::parser_error(
118 "duplicated parameter name",
119 arg.span,
120 codemap,
121 ));
122 }
123 Ok(())
124}
125
126impl<'a, P: AstPayload> DefParams<'a, P> {
127 pub fn unpack(
128 ast_params: &'a [AstParameterP<P>],
129 codemap: &CodeMap,
130 ) -> Result<DefParams<'a, P>, EvalException> {
131 #[derive(Ord, PartialOrd, Eq, PartialEq)]
132 enum State {
133 Normal,
134 SeenSlash,
136 SeenStar,
138 SeenStarStar,
140 }
141
142 let mut argset = HashSet::new();
144 let mut seen_optional = false;
148
149 let mut params = Vec::with_capacity(ast_params.len());
150 let mut num_positional = 0;
151 let mut args = None;
152 let mut kwargs = None;
153
154 let mut index_of_star = None;
156
157 let num_positional_only = match ast_params
158 .iter()
159 .position(|p| matches!(p.node, ParameterP::Slash))
160 {
161 None => 0,
162 Some(0) => {
163 return Err(EvalException::parser_error(
164 "`/` cannot be first parameter",
165 ast_params[0].span,
166 codemap,
167 ));
168 }
169 Some(n) => match n.try_into() {
170 Ok(n) => n,
171 Err(_) => {
172 return Err(EvalException::parser_error(
173 format_args!("Too many parameters: {}", ast_params.len()),
174 Span::merge_all(ast_params.iter().map(|p| p.span)),
175 codemap,
176 ));
177 }
178 },
179 };
180
181 let mut state = if num_positional_only == 0 {
182 State::SeenSlash
183 } else {
184 State::Normal
185 };
186
187 for (i, param) in ast_params.iter().enumerate() {
188 let span = param.span;
189
190 if let Some(name) = param.ident() {
191 check_param_name(&mut argset, name, param, codemap)?;
192 }
193
194 match ¶m.node {
195 ParameterP::Normal(n, ty, default_value) => {
196 if state >= State::SeenStarStar {
197 return Err(EvalException::parser_error(
198 "Parameter after kwargs",
199 param.span,
200 codemap,
201 ));
202 }
203 match default_value {
204 None => {
205 if seen_optional && state < State::SeenStar {
206 return Err(EvalException::parser_error(
207 "positional parameter after non positional",
208 param.span,
209 codemap,
210 ));
211 }
212 }
213 Some(_default_value) => {
214 seen_optional = true;
215 }
216 }
217 if state < State::SeenStar {
218 num_positional += 1;
219 }
220 let mode = if state < State::SeenSlash {
221 DefRegularParamMode::PosOnly
222 } else if state < State::SeenStar {
223 DefRegularParamMode::PosOrName
224 } else {
225 DefRegularParamMode::NameOnly
226 };
227 params.push(Spanned {
228 span,
229 node: DefParam {
230 ident: n,
231 kind: DefParamKind::Regular(mode, default_value.as_deref()),
232 ty: ty.as_deref(),
233 },
234 });
235 }
236 ParameterP::NoArgs => {
237 if state >= State::SeenStar {
238 return Err(EvalException::parser_error(
239 "Args parameter after another args or kwargs parameter",
240 param.span,
241 codemap,
242 ));
243 }
244 state = State::SeenStar;
245 if index_of_star.is_some() {
246 return Err(EvalException::internal_error(
247 "Multiple `*` in parameters, must have been caught earlier",
248 param.span,
249 codemap,
250 ));
251 }
252 index_of_star = Some(i);
253 }
254 ParameterP::Slash => {
255 if state >= State::SeenSlash {
256 return Err(EvalException::parser_error(
257 "Multiple `/` in parameters",
258 param.span,
259 codemap,
260 ));
261 }
262 state = State::SeenSlash;
263 }
264 ParameterP::Args(n, ty) => {
265 if state >= State::SeenStar {
266 return Err(EvalException::parser_error(
267 "Args parameter after another args or kwargs parameter",
268 param.span,
269 codemap,
270 ));
271 }
272 state = State::SeenStar;
273 if args.is_some() {
274 return Err(EvalException::internal_error(
275 "Multiple *args",
276 param.span,
277 codemap,
278 ));
279 }
280 args = Some(params.len().try_into().unwrap());
281 params.push(Spanned {
282 span,
283 node: DefParam {
284 ident: n,
285 kind: DefParamKind::Args,
286 ty: ty.as_deref(),
287 },
288 });
289 }
290 ParameterP::KwArgs(n, ty) => {
291 if state >= State::SeenStarStar {
292 return Err(EvalException::parser_error(
293 "Multiple kwargs dictionary in parameters",
294 param.span,
295 codemap,
296 ));
297 }
298 if kwargs.is_some() {
299 return Err(EvalException::internal_error(
300 "Multiple **kwargs",
301 param.span,
302 codemap,
303 ));
304 }
305 kwargs = Some(params.len().try_into().unwrap());
306 state = State::SeenStarStar;
307 params.push(Spanned {
308 span,
309 node: DefParam {
310 ident: n,
311 kind: DefParamKind::Kwargs,
312 ty: ty.as_deref(),
313 },
314 });
315 }
316 }
317 }
318
319 if let Some(index_of_star) = index_of_star {
320 let Some(next) = ast_params.get(index_of_star + 1) else {
321 return Err(EvalException::parser_error(
322 "`*` parameter must not be last",
323 ast_params[index_of_star].span,
324 codemap,
325 ));
326 };
327 match &next.node {
328 ParameterP::Normal(..) => {}
329 ParameterP::KwArgs(_, _)
330 | ParameterP::Args(_, _)
331 | ParameterP::NoArgs
332 | ParameterP::Slash => {
333 return Err(EvalException::parser_error(
335 "`*` must be followed by named parameter",
336 next.span,
337 codemap,
338 ));
339 }
340 }
341 }
342
343 Ok(DefParams {
344 params,
345 indices: DefParamIndices {
346 num_positional: u32::try_from(num_positional).unwrap(),
347 num_positional_only,
348 args,
349 kwargs,
350 },
351 })
352 }
353}
354
#[cfg(test)]
mod tests {
    use crate::golden_test_template::golden_test_template;
    use crate::syntax::AstModule;
    use crate::syntax::Dialect;

    /// Parses `program` under `dialect`, expects a parse error, and compares the
    /// program plus error text with `src/syntax/def_tests/{test_name}.golden`.
    fn fails_dialect(test_name: &str, program: &str, dialect: &Dialect) {
        let e = AstModule::parse("test.star", program.to_owned(), dialect).unwrap_err();
        let text = format!("Program:\n{program}\n\nError: {e}\n");
        golden_test_template(&format!("src/syntax/def_tests/{test_name}.golden"), &text);
    }

    /// Like `fails_dialect`, using `Dialect::AllOptionsInternal`.
    fn fails(test_name: &str, program: &str) {
        fails_dialect(test_name, program, &Dialect::AllOptionsInternal);
    }

    /// Asserts that `program` parses under `Dialect::AllOptionsInternal`.
    fn passes(program: &str) {
        AstModule::parse(
            "test.star",
            program.to_owned(),
            &Dialect::AllOptionsInternal,
        )
        .unwrap();
    }

    #[test]
    fn test_params_unpack() {
        fails("dup_name", "def test(x, y, x): pass");
        fails("pos_after_default", "def test(x=1, y): pass");
        fails("default_after_kwargs", "def test(**kwargs, y=1): pass");
        fails("args_args", "def test(*x, *y): pass");
        fails("kwargs_args", "def test(**x, *y): pass");
        fails("kwargs_kwargs", "def test(**x, **y): pass");

        passes("def test(x, y, z=1, *args, **kwargs): pass");
    }

    #[test]
    fn test_params_noargs() {
        fails("star_star", "def test(*, *): pass");
        fails("normal_after_default", "def test(x, y=1, z): pass");

        passes("def test(*args, x): pass");
        passes("def test(*args, x=1): pass");
        passes("def test(*args, x, y=1): pass");
        passes("def test(x=1, *args, y): pass");
        passes("def test(*args, x, y=1, z): pass");
        passes("def test(*, x, y=1, z): pass");
    }

    #[test]
    fn test_star_cannot_be_last() {
        fails("star_cannot_be_last", "def test(x, *): pass");
    }

    #[test]
    fn test_star_then_args() {
        fails("star_then_args", "def test(x, *, *args): pass");
    }

    #[test]
    fn test_star_then_kwargs() {
        fails("star_then_kwargs", "def test(x, *, **kwargs): pass");
    }

    #[test]
    fn test_positional_only() {
        passes("def test(x, /): pass");
    }

    /// `/` combined with defaults, a bare `*`, `*args` and `**kwargs`.
    #[test]
    fn test_positional_only_combined() {
        passes("def test(x, /, y): pass");
        passes("def test(x, y=1, /, z=2): pass");
        passes("def test(x, /, *, y): pass");
        passes("def test(x, /, y, *args, z, **kwargs): pass");
    }

    #[test]
    fn test_positional_only_cannot_be_first() {
        fails("positional_only_cannot_be_first", "def test(/, x): pass");
    }

    #[test]
    fn test_slash_slash() {
        fails("slash_slash", "def test(x, /, y, /): pass");
    }

    #[test]
    fn test_named_only_in_standard_dialect_def() {
        fails_dialect(
            "named_only_in_standard_dialect_def",
            "def test(*, x): pass",
            &Dialect::Standard,
        );
    }

    #[test]
    fn test_named_only_in_standard_dialect_lambda() {
        fails_dialect(
            "named_only_in_standard_dialect_lambda",
            "lambda *, x: 17",
            &Dialect::Standard,
        );
    }

    #[test]
    fn test_positional_only_in_standard_dialect_def() {
        fails_dialect(
            "positional_only_in_standard_dialect_def",
            "def test(/, x): pass",
            &Dialect::Standard,
        );
    }

    #[test]
    fn test_positional_only_in_standard_dialect_lambda() {
        fails_dialect(
            "positional_only_in_standard_dialect_lambda",
            "lambda /, x: 17",
            &Dialect::Standard,
        );
    }
}