arithmetic_typing/types/
fn_type.rs

1//! Functional type (`Function`) and closely related types.
2
3use std::{
4    collections::{HashMap, HashSet},
5    fmt,
6    sync::Arc,
7};
8
9use crate::{
10    arith::{CompleteConstraints, Constraint, ConstraintSet, Num},
11    types::ParamQuantifier,
12    LengthVar, PrimitiveType, Tuple, TupleLen, Type, TypeVar,
13};
14
/// Constraints placed on the params of a function: constraints on individual type params
/// plus the set of length params marked as static.
///
/// The `Display` implementation renders the constraints in the form placed inside
/// `for<...>`, e.g. `len! N; 'T: Lin`.
#[derive(Debug, Clone)]
pub(crate) struct ParamConstraints<Prim: PrimitiveType> {
    /// Constraints keyed by the index of the type param they apply to.
    pub type_params: HashMap<usize, CompleteConstraints<Prim>>,
    /// Indexes of the length params that are marked as static.
    pub static_lengths: HashSet<usize>,
}
20
21impl<Prim: PrimitiveType> Default for ParamConstraints<Prim> {
22    fn default() -> Self {
23        Self {
24            type_params: HashMap::new(),
25            static_lengths: HashSet::new(),
26        }
27    }
28}
29
30impl<Prim: PrimitiveType> fmt::Display for ParamConstraints<Prim> {
31    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
32        if !self.static_lengths.is_empty() {
33            formatter.write_str("len! ")?;
34            for (i, len) in self.static_lengths.iter().enumerate() {
35                write!(formatter, "{}", LengthVar::param_str(*len))?;
36                if i + 1 < self.static_lengths.len() {
37                    formatter.write_str(", ")?;
38                }
39            }
40
41            if !self.type_params.is_empty() {
42                formatter.write_str("; ")?;
43            }
44        }
45
46        let type_param_count = self.type_params.len();
47        for (i, (idx, constraints)) in self.type_params().enumerate() {
48            write!(formatter, "'{}: {}", TypeVar::param_str(idx), constraints)?;
49            if i + 1 < type_param_count {
50                formatter.write_str(", ")?;
51            }
52        }
53
54        Ok(())
55    }
56}
57
58impl<Prim: PrimitiveType> ParamConstraints<Prim> {
59    fn is_empty(&self) -> bool {
60        self.type_params.is_empty() && self.static_lengths.is_empty()
61    }
62
63    fn type_params(&self) -> impl Iterator<Item = (usize, &CompleteConstraints<Prim>)> + '_ {
64        let mut type_params: Vec<_> = self.type_params.iter().map(|(&idx, c)| (idx, c)).collect();
65        type_params.sort_unstable_by_key(|(idx, _)| *idx);
66        type_params.into_iter()
67    }
68}
69
/// Params of a function. Stored in `Function::params` behind an `Arc`.
///
/// Equality (see the `PartialEq` impl) compares `type_params` and `len_params` only;
/// `constraints` does not take part in the comparison.
// NOTE(review): no `FnQuantifier` is visible in this file (only `ParamQuantifier` is
// imported); the field docs below may refer to it under an old name -- confirm.
#[derive(Debug)]
pub(crate) struct FnParams<Prim: PrimitiveType> {
    /// Type params associated with this function. Filled in by `FnQuantifier`.
    pub type_params: Vec<(usize, CompleteConstraints<Prim>)>,
    /// Length params associated with this function. Filled in by `FnQuantifier`.
    // NOTE(review): the `bool` presumably flags whether the length is static -- confirm.
    pub len_params: Vec<(usize, bool)>,
    /// Constraints for params of this function and child functions.
    pub constraints: Option<ParamConstraints<Prim>>,
}
79
80impl<Prim: PrimitiveType> Default for FnParams<Prim> {
81    fn default() -> Self {
82        Self {
83            type_params: vec![],
84            len_params: vec![],
85            constraints: None,
86        }
87    }
88}
89
90impl<Prim: PrimitiveType> PartialEq for FnParams<Prim> {
91    fn eq(&self, other: &Self) -> bool {
92        self.type_params == other.type_params && self.len_params == other.len_params
93    }
94}
95
96impl<Prim: PrimitiveType> FnParams<Prim> {
97    fn is_empty(&self) -> bool {
98        self.len_params.is_empty() && self.type_params.is_empty()
99    }
100}
101
/// Functional type.
///
/// # Notation
///
/// Functional types are denoted as follows:
///
/// ```text
/// for<len! M; 'T: Lin> (['T; N], 'T) -> ['T; M]
/// ```
///
/// Here:
///
/// - `len! M` and `'T: Lin` are constraints on [length params] and [type params], respectively.
///   Length and/or type params constraints may be empty. Unconstrained type / length params
///   (such as length `N` in the example) do not need to be mentioned.
/// - `len! M` means that `M` is a [static length](TupleLen#static-lengths).
/// - `Lin` is a [constraint] on the type param.
/// - `N`, `M` and `'T` are parameter names. The args and the return type may reference these
///   parameters.
/// - `['T; N]` and `'T` are types of the function arguments.
/// - `['T; M]` is the return type.
///
/// The `for` constraints can only be present on top-level functions, but not in functions
/// mentioned in args / return types of other functions.
///
/// The `-> _` part is mandatory, even if the function returns [`Type::void()`].
///
/// A function may accept a variable number of arguments of the same type along
/// with other args. (This construction is known as *varargs*.) This is denoted similarly
/// to middles in [`Tuple`]s. For example, `(...[Num; N]) -> Num` denotes a function
/// that accepts any number of `Num` args and returns a `Num` value.
///
/// [length params]: crate::LengthVar
/// [type params]: crate::TypeVar
/// [constraint]: crate::arith::Constraint
///
/// # Construction
///
/// Functional types can be constructed via [`Self::builder()`] or parsed from a string.
///
/// With [`Self::builder()`], type / length params are *implicit*; they are computed automatically
/// when a function or [`FnWithConstraints`] is supplied to a [`TypeEnvironment`]. Computations
/// include both the function itself, and any child functions.
///
/// [`TypeEnvironment`]: crate::TypeEnvironment
///
/// # Examples
///
/// ```
/// # use arithmetic_typing::{ast::FunctionAst, Function, Slice, Type};
/// # use std::convert::TryFrom;
/// # use assert_matches::assert_matches;
/// # fn main() -> anyhow::Result<()> {
/// let fn_type: Function = FunctionAst::try_from("([Num; N]) -> Num")?
///     .try_convert()?;
/// assert_eq!(*fn_type.return_type(), Type::NUM);
/// assert_matches!(
///     fn_type.args().parts(),
///     ([Type::Tuple(t)], None, [])
///         if t.as_slice().map(Slice::element) == Some(&Type::NUM)
/// );
/// # Ok(())
/// # }
/// ```
#[derive(Debug, Clone, PartialEq)]
pub struct Function<Prim: PrimitiveType = Num> {
    /// Type of function arguments.
    pub(crate) args: Tuple<Prim>,
    /// Type of the value returned by the function.
    pub(crate) return_type: Type<Prim>,
    /// Cache for function params. `None` until the params are set via `set_params()`.
    pub(crate) params: Option<Arc<FnParams<Prim>>>,
}
176
177impl<Prim: PrimitiveType> fmt::Display for Function<Prim> {
178    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
179        let constraints = self
180            .params
181            .as_ref()
182            .and_then(|params| params.constraints.as_ref());
183        if let Some(constraints) = constraints {
184            if !constraints.is_empty() {
185                write!(formatter, "for<{}> ", constraints)?;
186            }
187        }
188
189        self.args.format_as_tuple(formatter)?;
190        write!(formatter, " -> {}", self.return_type)?;
191        Ok(())
192    }
193}
194
195impl<Prim: PrimitiveType> Function<Prim> {
196    pub(crate) fn new(args: Tuple<Prim>, return_type: Type<Prim>) -> Self {
197        Self {
198            args,
199            return_type,
200            params: None,
201        }
202    }
203
204    /// Returns a builder for `Function`s.
205    pub fn builder() -> FunctionBuilder<Prim> {
206        FunctionBuilder::default()
207    }
208
209    /// Gets the argument types of this function.
210    pub fn args(&self) -> &Tuple<Prim> {
211        &self.args
212    }
213
214    /// Gets the return type of this function.
215    pub fn return_type(&self) -> &Type<Prim> {
216        &self.return_type
217    }
218
219    pub(crate) fn set_params(&mut self, params: FnParams<Prim>) {
220        self.params = Some(Arc::new(params));
221    }
222
223    pub(crate) fn is_parametric(&self) -> bool {
224        self.params
225            .as_ref()
226            .map_or(false, |params| !params.is_empty())
227    }
228
229    /// Returns `true` iff this type does not contain type / length variables.
230    ///
231    /// See [`TypeEnvironment`](crate::TypeEnvironment) for caveats of dealing with
232    /// non-concrete types.
233    pub fn is_concrete(&self) -> bool {
234        self.args.is_concrete() && self.return_type.is_concrete()
235    }
236
237    /// Marks type params with the specified `indexes` to have `constraints`.
238    ///
239    /// # Panics
240    ///
241    /// - Panics if parameters were already computed for the function.
242    pub fn with_constraints<C: Constraint<Prim>>(
243        self,
244        indexes: &[usize],
245        constraint: C,
246    ) -> FnWithConstraints<Prim> {
247        assert!(
248            self.params.is_none(),
249            "Cannot attach constraints to a function with computed params: `{}`",
250            self
251        );
252
253        let constraints = CompleteConstraints::from(ConstraintSet::just(constraint));
254        let type_params = indexes
255            .iter()
256            .map(|&idx| (idx, constraints.clone()))
257            .collect();
258
259        FnWithConstraints {
260            function: self,
261            constraints: ParamConstraints {
262                type_params,
263                static_lengths: HashSet::new(),
264            },
265        }
266    }
267
268    /// Marks lengths with the specified `indexes` as static.
269    ///
270    /// # Panics
271    ///
272    /// - Panics if parameters were already computed for the function.
273    pub fn with_static_lengths(self, indexes: &[usize]) -> FnWithConstraints<Prim> {
274        assert!(
275            self.params.is_none(),
276            "Cannot attach constraints to a function with computed params: `{}`",
277            self
278        );
279
280        FnWithConstraints {
281            function: self,
282            constraints: ParamConstraints {
283                type_params: HashMap::new(),
284                static_lengths: indexes.iter().copied().collect(),
285            },
286        }
287    }
288}
289
/// Function together with constraints on type variables contained either in the function itself
/// or any of the child functions.
///
/// Constructed via [`Function::with_constraints()`] or [`Function::with_static_lengths()`].
/// Can be converted into a [`Function`] or a [`Type`] via the `From` implementations.
#[derive(Debug)]
pub struct FnWithConstraints<Prim: PrimitiveType> {
    /// The function the constraints are attached to.
    function: Function<Prim>,
    /// Constraints on type params and static length markers collected so far.
    constraints: ParamConstraints<Prim>,
}
299
300impl<Prim: PrimitiveType> fmt::Display for FnWithConstraints<Prim> {
301    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
302        if self.constraints.is_empty() {
303            fmt::Display::fmt(&self.function, formatter)
304        } else {
305            write!(formatter, "for<{}> {}", self.constraints, self.function)
306        }
307    }
308}
309
310impl<Prim: PrimitiveType> FnWithConstraints<Prim> {
311    /// Marks type params with the specified `indexes` to have `constraints`. If some constraints
312    /// are already present for some of the types, they are overwritten.
313    pub fn with_constraint<C>(mut self, indexes: &[usize], constraint: &C) -> Self
314    where
315        C: Constraint<Prim> + Clone,
316    {
317        for &i in indexes {
318            let constraints = self.constraints.type_params.entry(i).or_default();
319            constraints.simple.insert(constraint.clone());
320        }
321        self
322    }
323
324    /// Marks lengths with the specified `indexes` as static.
325    pub fn with_static_lengths(mut self, indexes: &[usize]) -> FnWithConstraints<Prim> {
326        let indexes = indexes.iter().copied();
327        self.constraints.static_lengths.extend(indexes);
328        self
329    }
330}
331
332impl<Prim: PrimitiveType> From<FnWithConstraints<Prim>> for Function<Prim> {
333    fn from(value: FnWithConstraints<Prim>) -> Self {
334        let mut function = value.function;
335        ParamQuantifier::set_params(&mut function, value.constraints);
336        function
337    }
338}
339
340impl<Prim: PrimitiveType> From<FnWithConstraints<Prim>> for Type<Prim> {
341    fn from(value: FnWithConstraints<Prim>) -> Self {
342        Function::from(value).into()
343    }
344}
345
/// Builder for functional types.
///
/// **Tip.** You may also use the [`FromStr`](core::str::FromStr) implementation to parse
/// functional types.
///
/// # Examples
///
/// Signature for a function summing a slice of numbers:
///
/// ```
/// # use arithmetic_typing::{Function, UnknownLen, Type, TypeEnvironment};
/// let sum_fn_type = Function::builder()
///     .with_arg(Type::NUM.repeat(UnknownLen::param(0)))
///     .returning(Type::NUM);
/// assert_eq!(sum_fn_type.to_string(), "([Num; N]) -> Num");
/// ```
///
/// Signature for a slice mapping function:
///
/// ```
/// # use arithmetic_typing::{arith::Linearity, Function, UnknownLen, Type};
/// // Definition of the mapping arg.
/// let map_fn_arg = <Function>::builder()
///     .with_arg(Type::param(0))
///     .returning(Type::param(1));
///
/// let map_fn_type = <Function>::builder()
///     .with_arg(Type::param(0).repeat(UnknownLen::param(0)))
///     .with_arg(map_fn_arg)
///     .returning(Type::param(1).repeat(UnknownLen::Dynamic))
///     .with_constraints(&[1], Linearity);
/// assert_eq!(
///     map_fn_type.to_string(),
///     "for<'U: Lin> (['T; N], ('T) -> 'U) -> ['U]"
/// );
/// ```
///
/// Signature of a function with varargs:
///
/// ```
/// # use arithmetic_typing::{Function, UnknownLen, Type};
/// let fn_type = <Function>::builder()
///     .with_varargs(Type::param(0), UnknownLen::param(0))
///     .with_arg(Type::BOOL)
///     .returning(Type::param(0));
/// assert_eq!(fn_type.to_string(), "(...['T; N], Bool) -> 'T");
/// ```
#[derive(Debug, Clone)]
pub struct FunctionBuilder<Prim: PrimitiveType = Num> {
    /// Function args accumulated so far; starts out as an empty tuple.
    args: Tuple<Prim>,
}
397
398impl<Prim: PrimitiveType> Default for FunctionBuilder<Prim> {
399    fn default() -> Self {
400        Self {
401            args: Tuple::empty(),
402        }
403    }
404}
405
406impl<Prim: PrimitiveType> FunctionBuilder<Prim> {
407    /// Adds a new argument to the function definition.
408    pub fn with_arg(mut self, arg: impl Into<Type<Prim>>) -> Self {
409        self.args.push(arg.into());
410        self
411    }
412
413    /// Adds or sets varargs in the function definition.
414    pub fn with_varargs(
415        mut self,
416        element: impl Into<Type<Prim>>,
417        len: impl Into<TupleLen>,
418    ) -> Self {
419        self.args.set_middle(element.into(), len.into());
420        self
421    }
422
423    /// Declares the return type of the function and builds it.
424    pub fn returning(self, return_type: impl Into<Type<Prim>>) -> Function<Prim> {
425        Function::new(self.args, return_type.into())
426    }
427}
428
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{arith::Linearity, UnknownLen};

    /// Checks how constraints are rendered with and without static lengths.
    #[test]
    fn constraints_display() {
        let linear = CompleteConstraints::from(ConstraintSet::<Num>::just(Linearity));

        let mut type_params = HashMap::new();
        type_params.insert(0, linear);

        let types_only = ParamConstraints {
            type_params: type_params.clone(),
            static_lengths: HashSet::new(),
        };
        assert_eq!(types_only.to_string(), "'T: Lin");

        let mut static_lengths = HashSet::new();
        static_lengths.insert(0);
        let with_static_len: ParamConstraints<Num> = ParamConstraints {
            type_params,
            static_lengths,
        };
        assert_eq!(with_static_len.to_string(), "len! N; 'T: Lin");
    }

    /// Checks that a function with constraints is rendered with the `for<...>` prefix.
    #[test]
    fn fn_with_constraints_display() {
        let slice_arg = Type::param(0).repeat(UnknownLen::param(0));
        let constrained_fn = <Function>::builder()
            .with_arg(slice_arg)
            .returning(Type::param(0))
            .with_constraints(&[0], Linearity);

        assert_eq!(constrained_fn.to_string(), "for<'T: Lin> (['T; N]) -> 'T");
    }

    /// Checks rendering of functions that take an already quantified function as an arg.
    #[test]
    fn fn_builder_with_quantified_arg() {
        let no_indexes: &[usize] = &[];

        let sum_fn: Function = Function::builder()
            .with_arg(Type::NUM.repeat(UnknownLen::param(0)))
            .returning(Type::NUM)
            .with_constraints(no_indexes, Linearity)
            .into();
        assert_eq!(sum_fn.to_string(), "([Num; N]) -> Num");

        let fn_with_plain_arg: Function = Function::builder()
            .with_arg(Type::NUM)
            .with_arg(sum_fn.clone())
            .returning(Type::NUM)
            .with_constraints(no_indexes, Linearity)
            .into();
        assert_eq!(
            fn_with_plain_arg.to_string(),
            "(Num, ([Num; N]) -> Num) -> Num"
        );

        let fn_with_varargs: Function = Function::builder()
            .with_varargs(Type::NUM, UnknownLen::param(0))
            .with_arg(sum_fn)
            .returning(Type::NUM)
            .with_constraints(no_indexes, Linearity)
            .into();
        assert_eq!(
            fn_with_varargs.to_string(),
            "(...[Num; N], ([Num; N]) -> Num) -> Num"
        );
    }
}