arithmetic_typing/
defs.rs

1//! Type definitions for the standard types from the [`arithmetic-eval`] crate.
2//!
3//! [`arithmetic-eval`]: https://docs.rs/arithmetic-eval/
4
5use crate::{arith::WithBoolean, Function, PrimitiveType, Type, UnknownLen};
6
/// Type definitions for all variables from `Prelude` in the eval crate,
/// except for the `loop` function.
///
/// # Contents
///
/// - `true` and `false` Boolean constants
/// - `if`, `while`, `map`, `filter`, `fold`, `push` and `merge` functions
///
/// The `merge` function has somewhat imprecise typing; its return value is
/// a dynamically-sized slice.
///
/// The `array` function is available separately via [`Self::array()`].
/// Use [`Self::iter()`] to obtain `(name, type)` pairs for all definitions
/// (e.g., to populate a type environment).
///
/// # Examples
///
/// Function counting the number of zeros in a slice:
///
/// ```
/// use arithmetic_parser::grammars::{F32Grammar, Parse};
/// use arithmetic_typing::{defs::Prelude, Annotated, TypeEnvironment, Type};
///
/// # fn main() -> anyhow::Result<()> {
/// let code = "|xs| xs.fold(0, |acc, x| if(x == 0, acc + 1, acc))";
/// let ast = Annotated::<F32Grammar>::parse_statements(code)?;
///
/// let mut env: TypeEnvironment = Prelude::iter().collect();
/// let count_zeros_fn = env.process_statements(&ast)?;
/// assert_eq!(count_zeros_fn.to_string(), "([Num; N]) -> Num");
/// # Ok(())
/// # }
/// ```
///
/// Limitations of `merge`:
///
/// ```
/// # use arithmetic_parser::grammars::{F32Grammar, Parse};
/// # use arithmetic_typing::{defs::Prelude, error::ErrorKind, Annotated, TypeEnvironment, Type};
/// # use assert_matches::assert_matches;
/// # fn main() -> anyhow::Result<()> {
/// let code = r#"
///     len = |xs| xs.fold(0, |acc, _| acc + 1);
///     slice = (1, 2).merge((3, 4));
///     slice.len(); // methods working on slices are applicable
///     (_, _, _, z) = slice; // but destructuring is not
/// "#;
/// let ast = Annotated::<F32Grammar>::parse_statements(code)?;
///
/// let mut env: TypeEnvironment = Prelude::iter().collect();
/// let errors = env.process_statements(&ast).unwrap_err();
/// assert_eq!(errors.len(), 1);
/// let err = errors.iter().next().unwrap();
/// assert_eq!(*err.main_span().fragment(), "(_, _, _, z)");
/// # assert_matches!(err.kind(), ErrorKind::TupleLenMismatch { .. });
/// # Ok(())
/// # }
/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum Prelude {
    /// Type of the `false` constant (Boolean).
    False,
    /// Type of the `true` constant (Boolean).
    True,
    /// Type of the `if` function.
    If,
    /// Type of the `while` function.
    While,
    /// Type of the `map` function.
    Map,
    /// Type of the `filter` function.
    Filter,
    /// Type of the `fold` function.
    Fold,
    /// Type of the `push` function.
    Push,
    /// Type of the `merge` function.
    Merge,
}
85
86impl<Prim: WithBoolean> From<Prelude> for Type<Prim> {
87    fn from(value: Prelude) -> Self {
88        match value {
89            Prelude::True | Prelude::False => Type::BOOL,
90
91            Prelude::If => Function::builder()
92                .with_arg(Type::BOOL)
93                .with_arg(Type::param(0))
94                .with_arg(Type::param(0))
95                .returning(Type::param(0))
96                .into(),
97
98            Prelude::While => {
99                let condition_fn = Function::builder()
100                    .with_arg(Type::param(0))
101                    .returning(Type::BOOL);
102                let iter_fn = Function::builder()
103                    .with_arg(Type::param(0))
104                    .returning(Type::param(0));
105
106                Function::builder()
107                    .with_arg(Type::param(0)) // state
108                    .with_arg(condition_fn)
109                    .with_arg(iter_fn)
110                    .returning(Type::param(0))
111                    .into()
112            }
113
114            Prelude::Map => {
115                let map_arg = Function::builder()
116                    .with_arg(Type::param(0))
117                    .returning(Type::param(1));
118
119                Function::builder()
120                    .with_arg(Type::param(0).repeat(UnknownLen::param(0)))
121                    .with_arg(map_arg)
122                    .returning(Type::param(1).repeat(UnknownLen::param(0)))
123                    .into()
124            }
125
126            Prelude::Filter => {
127                let predicate_arg = Function::builder()
128                    .with_arg(Type::param(0))
129                    .returning(Type::BOOL);
130
131                Function::builder()
132                    .with_arg(Type::param(0).repeat(UnknownLen::Dynamic))
133                    .with_arg(predicate_arg)
134                    .returning(Type::param(0).repeat(UnknownLen::Dynamic))
135                    .into()
136            }
137
138            Prelude::Fold => {
139                // 0th type param is slice element, 1st is accumulator
140                let fold_arg = Function::builder()
141                    .with_arg(Type::param(1))
142                    .with_arg(Type::param(0))
143                    .returning(Type::param(1));
144
145                Function::builder()
146                    .with_arg(Type::param(0).repeat(UnknownLen::Dynamic))
147                    .with_arg(Type::param(1))
148                    .with_arg(fold_arg)
149                    .returning(Type::param(1))
150                    .into()
151            }
152
153            Prelude::Push => Function::builder()
154                .with_arg(Type::param(0).repeat(UnknownLen::param(0)))
155                .with_arg(Type::param(0))
156                .returning(Type::param(0).repeat(UnknownLen::param(0) + 1))
157                .into(),
158
159            Prelude::Merge => Function::builder()
160                .with_arg(Type::param(0).repeat(UnknownLen::Dynamic))
161                .with_arg(Type::param(0).repeat(UnknownLen::Dynamic))
162                .returning(Type::param(0).repeat(UnknownLen::Dynamic))
163                .into(),
164        }
165    }
166}
167
168impl Prelude {
169    const VALUES: &'static [Self] = &[
170        Self::True,
171        Self::False,
172        Self::If,
173        Self::While,
174        Self::Map,
175        Self::Filter,
176        Self::Fold,
177        Self::Push,
178        Self::Merge,
179    ];
180
181    fn as_str(self) -> &'static str {
182        match self {
183            Self::True => "true",
184            Self::False => "false",
185            Self::If => "if",
186            Self::While => "while",
187            Self::Map => "map",
188            Self::Filter => "filter",
189            Self::Fold => "fold",
190            Self::Push => "push",
191            Self::Merge => "merge",
192        }
193    }
194
195    /// Returns the type of the `array` generation function from the eval crate.
196    ///
197    /// The `array` function is not included into [`Self::iter()`] because in the general case
198    /// we don't know the type of indexes.
199    pub fn array<T: PrimitiveType>(index_type: T) -> Function<T> {
200        Function::builder()
201            .with_arg(Type::Prim(index_type.clone()))
202            .with_arg(
203                Function::builder()
204                    .with_arg(Type::Prim(index_type))
205                    .returning(Type::param(0)),
206            )
207            .returning(Type::param(0).repeat(UnknownLen::Dynamic))
208    }
209
210    /// Returns an iterator over all type definitions in the `Prelude`.
211    pub fn iter<Prim: WithBoolean>() -> impl Iterator<Item = (&'static str, Type<Prim>)> {
212        Self::VALUES
213            .iter()
214            .map(|&value| (value.as_str(), value.into()))
215    }
216}
217
/// Type definitions for the `assert` and `assert_eq` functions from the eval crate.
///
/// Use [`Self::iter()`] to obtain `(name, type)` pairs for both functions.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Assertions {
    /// Type of the `assert` function: takes a Boolean condition and returns void.
    Assert,
    /// Type of the `assert_eq` function: takes two values of the same type and returns void.
    AssertEq,
}
227
228impl<Prim: WithBoolean> From<Assertions> for Type<Prim> {
229    fn from(value: Assertions) -> Self {
230        match value {
231            Assertions::Assert => Function::builder()
232                .with_arg(Type::BOOL)
233                .returning(Type::void())
234                .into(),
235            Assertions::AssertEq => Function::builder()
236                .with_arg(Type::param(0))
237                .with_arg(Type::param(0))
238                .returning(Type::void())
239                .into(),
240        }
241    }
242}
243
244impl Assertions {
245    const VALUES: &'static [Self] = &[Self::Assert, Self::AssertEq];
246
247    fn as_str(self) -> &'static str {
248        match self {
249            Self::Assert => "assert",
250            Self::AssertEq => "assert_eq",
251        }
252    }
253
254    /// Returns an iterator over all type definitions in `Assertions`.
255    pub fn iter<Prim: WithBoolean>() -> impl Iterator<Item = (&'static str, Type<Prim>)> {
256        Self::VALUES.iter().map(|&val| (val.as_str(), val.into()))
257    }
258}
259
#[cfg(test)]
mod tests {
    use super::*;
    use crate::arith::Num;

    use std::collections::{HashMap, HashSet};

    const EXPECTED_PRELUDE_TYPES: &[(&str, &str)] = &[
        ("false", "Bool"),
        ("true", "Bool"),
        ("if", "(Bool, 'T, 'T) -> 'T"),
        ("while", "('T, ('T) -> Bool, ('T) -> 'T) -> 'T"),
        ("map", "(['T; N], ('T) -> 'U) -> ['U; N]"),
        ("filter", "(['T], ('T) -> Bool) -> ['T]"),
        ("fold", "(['T], 'U, ('U, 'T) -> 'U) -> 'U"),
        ("push", "(['T; N], 'T) -> ['T; N + 1]"),
        ("merge", "(['T], ['T]) -> ['T]"),
    ];

    #[test]
    fn string_presentations_of_prelude_types() {
        let expected_types: HashMap<_, _> = EXPECTED_PRELUDE_TYPES.iter().copied().collect();

        // Check names first so that a missing / extra item is reported directly
        // rather than as an opaque map-indexing panic below.
        let actual_names: Vec<_> = Prelude::iter::<Num>().map(|(name, _)| name).collect();
        // A set comparison alone would hide duplicate entries in `Prelude::VALUES`.
        assert_eq!(
            actual_names.len(),
            EXPECTED_PRELUDE_TYPES.len(),
            "unexpected Prelude items: {:?}",
            actual_names
        );
        assert_eq!(
            actual_names.iter().copied().collect::<HashSet<_>>(),
            expected_types.keys().copied().collect::<HashSet<_>>()
        );

        for (name, ty) in Prelude::iter::<Num>() {
            assert_eq!(
                ty.to_string(),
                expected_types[name],
                "unexpected type of `{}`",
                name
            );
        }
    }

    #[test]
    fn names_of_assertions() {
        let names: Vec<_> = Assertions::iter::<Num>().map(|(name, _)| name).collect();
        assert_eq!(names, ["assert", "assert_eq"]);
    }

    #[test]
    fn string_presentation_of_array_type() {
        let array_fn = Prelude::array(Num::Num);
        assert_eq!(array_fn.to_string(), "(Num, (Num) -> 'T) -> ['T]");
    }
}
300}