1use std::{
4 collections::{HashMap, HashSet},
5 fmt,
6 sync::Arc,
7};
8
9use crate::{
10 arith::{CompleteConstraints, Constraint, ConstraintSet, Num},
11 types::ParamQuantifier,
12 LengthVar, PrimitiveType, Tuple, TupleLen, Type, TypeVar,
13};
14
15#[derive(Debug, Clone)]
16pub(crate) struct ParamConstraints<Prim: PrimitiveType> {
17 pub type_params: HashMap<usize, CompleteConstraints<Prim>>,
18 pub static_lengths: HashSet<usize>,
19}
20
21impl<Prim: PrimitiveType> Default for ParamConstraints<Prim> {
22 fn default() -> Self {
23 Self {
24 type_params: HashMap::new(),
25 static_lengths: HashSet::new(),
26 }
27 }
28}
29
30impl<Prim: PrimitiveType> fmt::Display for ParamConstraints<Prim> {
31 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
32 if !self.static_lengths.is_empty() {
33 formatter.write_str("len! ")?;
34 for (i, len) in self.static_lengths.iter().enumerate() {
35 write!(formatter, "{}", LengthVar::param_str(*len))?;
36 if i + 1 < self.static_lengths.len() {
37 formatter.write_str(", ")?;
38 }
39 }
40
41 if !self.type_params.is_empty() {
42 formatter.write_str("; ")?;
43 }
44 }
45
46 let type_param_count = self.type_params.len();
47 for (i, (idx, constraints)) in self.type_params().enumerate() {
48 write!(formatter, "'{}: {}", TypeVar::param_str(idx), constraints)?;
49 if i + 1 < type_param_count {
50 formatter.write_str(", ")?;
51 }
52 }
53
54 Ok(())
55 }
56}
57
58impl<Prim: PrimitiveType> ParamConstraints<Prim> {
59 fn is_empty(&self) -> bool {
60 self.type_params.is_empty() && self.static_lengths.is_empty()
61 }
62
63 fn type_params(&self) -> impl Iterator<Item = (usize, &CompleteConstraints<Prim>)> + '_ {
64 let mut type_params: Vec<_> = self.type_params.iter().map(|(&idx, c)| (idx, c)).collect();
65 type_params.sort_unstable_by_key(|(idx, _)| *idx);
66 type_params.into_iter()
67 }
68}
69
70#[derive(Debug)]
71pub(crate) struct FnParams<Prim: PrimitiveType> {
72 pub type_params: Vec<(usize, CompleteConstraints<Prim>)>,
74 pub len_params: Vec<(usize, bool)>,
76 pub constraints: Option<ParamConstraints<Prim>>,
78}
79
80impl<Prim: PrimitiveType> Default for FnParams<Prim> {
81 fn default() -> Self {
82 Self {
83 type_params: vec![],
84 len_params: vec![],
85 constraints: None,
86 }
87 }
88}
89
90impl<Prim: PrimitiveType> PartialEq for FnParams<Prim> {
91 fn eq(&self, other: &Self) -> bool {
92 self.type_params == other.type_params && self.len_params == other.len_params
93 }
94}
95
96impl<Prim: PrimitiveType> FnParams<Prim> {
97 fn is_empty(&self) -> bool {
98 self.len_params.is_empty() && self.type_params.is_empty()
99 }
100}
101
102#[derive(Debug, Clone, PartialEq)]
168pub struct Function<Prim: PrimitiveType = Num> {
169 pub(crate) args: Tuple<Prim>,
171 pub(crate) return_type: Type<Prim>,
173 pub(crate) params: Option<Arc<FnParams<Prim>>>,
175}
176
177impl<Prim: PrimitiveType> fmt::Display for Function<Prim> {
178 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
179 let constraints = self
180 .params
181 .as_ref()
182 .and_then(|params| params.constraints.as_ref());
183 if let Some(constraints) = constraints {
184 if !constraints.is_empty() {
185 write!(formatter, "for<{}> ", constraints)?;
186 }
187 }
188
189 self.args.format_as_tuple(formatter)?;
190 write!(formatter, " -> {}", self.return_type)?;
191 Ok(())
192 }
193}
194
195impl<Prim: PrimitiveType> Function<Prim> {
196 pub(crate) fn new(args: Tuple<Prim>, return_type: Type<Prim>) -> Self {
197 Self {
198 args,
199 return_type,
200 params: None,
201 }
202 }
203
204 pub fn builder() -> FunctionBuilder<Prim> {
206 FunctionBuilder::default()
207 }
208
209 pub fn args(&self) -> &Tuple<Prim> {
211 &self.args
212 }
213
214 pub fn return_type(&self) -> &Type<Prim> {
216 &self.return_type
217 }
218
219 pub(crate) fn set_params(&mut self, params: FnParams<Prim>) {
220 self.params = Some(Arc::new(params));
221 }
222
223 pub(crate) fn is_parametric(&self) -> bool {
224 self.params
225 .as_ref()
226 .map_or(false, |params| !params.is_empty())
227 }
228
229 pub fn is_concrete(&self) -> bool {
234 self.args.is_concrete() && self.return_type.is_concrete()
235 }
236
237 pub fn with_constraints<C: Constraint<Prim>>(
243 self,
244 indexes: &[usize],
245 constraint: C,
246 ) -> FnWithConstraints<Prim> {
247 assert!(
248 self.params.is_none(),
249 "Cannot attach constraints to a function with computed params: `{}`",
250 self
251 );
252
253 let constraints = CompleteConstraints::from(ConstraintSet::just(constraint));
254 let type_params = indexes
255 .iter()
256 .map(|&idx| (idx, constraints.clone()))
257 .collect();
258
259 FnWithConstraints {
260 function: self,
261 constraints: ParamConstraints {
262 type_params,
263 static_lengths: HashSet::new(),
264 },
265 }
266 }
267
268 pub fn with_static_lengths(self, indexes: &[usize]) -> FnWithConstraints<Prim> {
274 assert!(
275 self.params.is_none(),
276 "Cannot attach constraints to a function with computed params: `{}`",
277 self
278 );
279
280 FnWithConstraints {
281 function: self,
282 constraints: ParamConstraints {
283 type_params: HashMap::new(),
284 static_lengths: indexes.iter().copied().collect(),
285 },
286 }
287 }
288}
289
290#[derive(Debug)]
295pub struct FnWithConstraints<Prim: PrimitiveType> {
296 function: Function<Prim>,
297 constraints: ParamConstraints<Prim>,
298}
299
300impl<Prim: PrimitiveType> fmt::Display for FnWithConstraints<Prim> {
301 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
302 if self.constraints.is_empty() {
303 fmt::Display::fmt(&self.function, formatter)
304 } else {
305 write!(formatter, "for<{}> {}", self.constraints, self.function)
306 }
307 }
308}
309
310impl<Prim: PrimitiveType> FnWithConstraints<Prim> {
311 pub fn with_constraint<C>(mut self, indexes: &[usize], constraint: &C) -> Self
314 where
315 C: Constraint<Prim> + Clone,
316 {
317 for &i in indexes {
318 let constraints = self.constraints.type_params.entry(i).or_default();
319 constraints.simple.insert(constraint.clone());
320 }
321 self
322 }
323
324 pub fn with_static_lengths(mut self, indexes: &[usize]) -> FnWithConstraints<Prim> {
326 let indexes = indexes.iter().copied();
327 self.constraints.static_lengths.extend(indexes);
328 self
329 }
330}
331
332impl<Prim: PrimitiveType> From<FnWithConstraints<Prim>> for Function<Prim> {
333 fn from(value: FnWithConstraints<Prim>) -> Self {
334 let mut function = value.function;
335 ParamQuantifier::set_params(&mut function, value.constraints);
336 function
337 }
338}
339
340impl<Prim: PrimitiveType> From<FnWithConstraints<Prim>> for Type<Prim> {
341 fn from(value: FnWithConstraints<Prim>) -> Self {
342 Function::from(value).into()
343 }
344}
345
346#[derive(Debug, Clone)]
394pub struct FunctionBuilder<Prim: PrimitiveType = Num> {
395 args: Tuple<Prim>,
396}
397
398impl<Prim: PrimitiveType> Default for FunctionBuilder<Prim> {
399 fn default() -> Self {
400 Self {
401 args: Tuple::empty(),
402 }
403 }
404}
405
406impl<Prim: PrimitiveType> FunctionBuilder<Prim> {
407 pub fn with_arg(mut self, arg: impl Into<Type<Prim>>) -> Self {
409 self.args.push(arg.into());
410 self
411 }
412
413 pub fn with_varargs(
415 mut self,
416 element: impl Into<Type<Prim>>,
417 len: impl Into<TupleLen>,
418 ) -> Self {
419 self.args.set_middle(element.into(), len.into());
420 self
421 }
422
423 pub fn returning(self, return_type: impl Into<Type<Prim>>) -> Function<Prim> {
425 Function::new(self.args, return_type.into())
426 }
427}
428
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{arith::Linearity, UnknownLen};

    #[test]
    fn constraints_display() {
        let lin = CompleteConstraints::from(ConstraintSet::<Num>::just(Linearity));

        // Only type param constraints.
        let type_only = ParamConstraints {
            type_params: std::iter::once((0, lin.clone())).collect(),
            static_lengths: HashSet::new(),
        };
        assert_eq!(type_only.to_string(), "'T: Lin");

        // Static lengths come first and are separated from type params by `; `.
        let with_len: ParamConstraints<Num> = ParamConstraints {
            type_params: std::iter::once((0, lin)).collect(),
            static_lengths: std::iter::once(0).collect(),
        };
        assert_eq!(with_len.to_string(), "len! N; 'T: Lin");
    }

    #[test]
    fn fn_with_constraints_display() {
        let constrained = <Function>::builder()
            .with_arg(Type::param(0).repeat(UnknownLen::param(0)))
            .returning(Type::param(0))
            .with_constraints(&[0], Linearity);
        assert_eq!(constrained.to_string(), "for<'T: Lin> (['T; N]) -> 'T");
    }

    #[test]
    fn fn_builder_with_quantified_arg() {
        let summation: Function = Function::builder()
            .with_arg(Type::NUM.repeat(UnknownLen::param(0)))
            .returning(Type::NUM)
            .with_constraints(&[], Linearity)
            .into();
        assert_eq!(summation.to_string(), "([Num; N]) -> Num");

        // A function taking another function as an argument.
        let higher_order: Function = Function::builder()
            .with_arg(Type::NUM)
            .with_arg(summation.clone())
            .returning(Type::NUM)
            .with_constraints(&[], Linearity)
            .into();
        assert_eq!(higher_order.to_string(), "(Num, ([Num; N]) -> Num) -> Num");

        // Varargs followed by a function argument.
        let variadic: Function = Function::builder()
            .with_varargs(Type::NUM, UnknownLen::param(0))
            .with_arg(summation)
            .returning(Type::NUM)
            .with_constraints(&[], Linearity)
            .into();
        assert_eq!(
            variadic.to_string(),
            "(...[Num; N], ([Num; N]) -> Num) -> Num"
        );
    }
}
489}