// computation_types/named_args.rs

1use core::fmt;
2use std::{any::type_name, collections::BTreeMap};
3
4use downcast_rs::{impl_downcast, Downcast};
5use paste::paste;
6
7use crate::{Name, Names};
8
9pub trait AnyArg: Downcast + fmt::Debug {
10    fn boxed_clone(&self) -> Box<dyn AnyArg>;
11}
12impl_downcast!(AnyArg);
13impl<T> AnyArg for T
14where
15    T: 'static + Clone + fmt::Debug,
16{
17    fn boxed_clone(&self) -> Box<dyn AnyArg> {
18        Box::new(self.clone())
19    }
20}
21impl Clone for Box<dyn AnyArg> {
22    fn clone(&self) -> Self {
23        // Calling `boxed_clone` without `as_ref`
24        // will result in a stack overflow.
25        self.as_ref().boxed_clone()
26    }
27}
28
/// A map from argument names to type-erased argument values.
///
/// Each value is a `Box<dyn AnyArg>`,
/// so one map can hold arguments of different types;
/// `pop` recovers an argument as a concrete type.
/// Cloning the map clones every argument
/// through `AnyArg::boxed_clone`.
#[derive(Clone, Debug)]
pub struct NamedArgs(BTreeMap<Name, Box<dyn AnyArg>>);
31
32impl Default for NamedArgs {
33    fn default() -> Self {
34        Self::new()
35    }
36}
37
/// Build a `NamedArgs` from `(name, arg)` pairs.
///
/// `named_args![("foo", 1), ("bar", 2.0)]`
/// produces a map holding `foo` and `bar`.
/// Arguments are evaluated left to right.
/// If a name appears more than once,
/// the last argument given for it is kept.
/// A trailing comma is accepted.
#[macro_export]
macro_rules! named_args {
    ( ) => {
        $crate::NamedArgs::new()
    };
    ( $( ($name:expr, $arg:expr) ),+ $(,)? ) => {
        // Only `$crate::` paths, no recursive `named_args!` call,
        // so this works when invoked as `some_crate::named_args!`
        // without the macro being imported by name.
        // Each `union` lets the newer singleton overwrite duplicates,
        // matching "last one wins".
        $crate::NamedArgs::new()
            $( .union($crate::NamedArgs::singleton($name, $arg)) )+
    };
}
50
51impl NamedArgs {
52    pub fn new() -> Self {
53        Self(BTreeMap::new())
54    }
55
56    pub fn singleton<T>(name: Name, arg: T) -> Self
57    where
58        T: AnyArg,
59    {
60        let arg = Box::new(arg) as Box<dyn AnyArg>;
61        Self(std::iter::once((name, arg)).collect())
62    }
63
64    pub fn union(mut self, mut other: Self) -> Self {
65        self.0.append(&mut other.0);
66        self
67    }
68
69    pub fn contains_args(&self, args: &Names) -> bool {
70        args.iter().all(|arg| self.0.contains_key(arg))
71    }
72
73    /// Return a (map with `fst_names`, map with `snd_names`) tuple if all arguments are present.
74    pub fn partition(
75        self,
76        fst_names: &Names,
77        snd_names: &Names,
78    ) -> Result<(Self, Self), PartitionErr> {
79        // We want to avoid making a new map
80        // if we can simply return this map
81        // as the partition with all arguments.
82        if snd_names.is_empty() && fst_names.len() == self.0.len() {
83            if self.contains_args(fst_names) {
84                Ok((self, Self::new()))
85            } else {
86                Err(PartitionErr::Missing(
87                    fst_names
88                        .iter()
89                        .find(|arg| !self.0.contains_key(arg))
90                        .unwrap(),
91                ))
92            }
93        } else if fst_names.is_empty() && snd_names.len() == self.0.len() {
94            if self.contains_args(snd_names) {
95                Ok((Self::new(), self))
96            } else {
97                Err(PartitionErr::Missing(
98                    snd_names
99                        .iter()
100                        .find(|arg| !self.0.contains_key(arg))
101                        .unwrap(),
102                ))
103            }
104        } else {
105            let mut fst_map = Self::new();
106            let mut snd_map = Self::new();
107
108            for (name, arg) in self.0.into_iter() {
109                match (fst_names.contains(name), snd_names.contains(name)) {
110                    (true, true) => {
111                        fst_map.0.insert(name, arg.clone());
112                        snd_map.0.insert(name, arg);
113                    }
114                    (true, false) => {
115                        fst_map.0.insert(name, arg);
116                    }
117                    (false, true) => {
118                        snd_map.0.insert(name, arg);
119                    }
120                    (false, false) => {}
121                }
122            }
123
124            if fst_map.contains_args(fst_names) && snd_map.contains_args(snd_names) {
125                Ok((fst_map, snd_map))
126            } else {
127                Err(PartitionErr::Missing(
128                    fst_names
129                        .iter()
130                        .find(|arg| !fst_map.0.contains_key(arg))
131                        .unwrap_or_else(|| {
132                            snd_names
133                                .iter()
134                                .find(|arg| !snd_map.0.contains_key(arg))
135                                .unwrap()
136                        }),
137                ))
138            }
139        }
140    }
141
142    /// Return the given argument
143    /// if it is present
144    /// and has the right type.
145    ///
146    /// Note,
147    /// if argument is present
148    /// but has the wrong type,
149    /// it will still be removed.
150    pub fn pop<T>(&mut self, name: Name) -> Result<T, PopErr>
151    where
152        T: 'static + AnyArg,
153    {
154        self.0
155            .remove(name)
156            .map_or_else(
157                || Err(PopErr::Missing(name)),
158                |x| {
159                    x.downcast().map_err(|x| PopErr::WrongType {
160                        name,
161                        arg: x,
162                        ty: type_name::<T>(),
163                    })
164                },
165            )
166            .map(|x: Box<T>| *x)
167    }
168
169    /// This function is unstable and likely to change.
170    ///
171    /// Use at your own risk.
172    pub fn insert_raw(&mut self, name: Name, arg: Box<dyn AnyArg>) {
173        self.0.insert(name, arg);
174    }
175
176    #[cfg(test)]
177    fn into_map<T>(self) -> BTreeMap<Name, T>
178    where
179        T: 'static + AnyArg,
180    {
181        self.0
182            .into_iter()
183            .map(|(k, v)| (k, *v.downcast().unwrap_or_else(|_| panic!())))
184            .collect()
185    }
186}
187
// Generate `NamedArgs::partition$n`,
// which splits arguments into `$n` sets by name.
// Takes the arity followed by each zero-based partition index,
// e.g. `impl_partition_n!(3, 0, 1, 2)`.
macro_rules! impl_partition_n {
    ( $n:expr, $( $i:expr ),* ) => {
        paste! {
            impl NamedArgs {
                /// Return a tuple with each requested set of arguments
                /// if all arguments are present.
                ///
                /// An argument requested by more than one set
                /// is cloned into each of them.
                ///
                /// # Errors
                ///
                /// Returns `PartitionErr::Missing`
                /// naming an argument that is not present.
                // One parameter per partition is inherent to this API.
                #[allow(clippy::too_many_arguments)]
                pub fn [<partition $n>](mut self, $( [<names_ $i>]: &Names ),* ) -> Result<( $( impl_partition_n!(@as_self $i) ),* ), PartitionErr> {
                    let all_names = [ $( [<names_ $i>] ),* ];

                    // Peel off one partition at a time:
                    // partition `i` receives `names_i`,
                    // and the remainder keeps everything later partitions need.
                    let mut partitions = Vec::with_capacity($n);
                    for i in 0..($n - 1) {
                        let (next, rest) = self.partition(all_names[i], &Names::union_many(all_names.into_iter().skip(i + 1)))?;
                        partitions.push(next);
                        self = rest;
                    }
                    partitions.push(self);

                    // Move each partition out in order.
                    // Indexing and cloning would deep-copy every argument
                    // through dynamic dispatch,
                    // which the compiler cannot optimize away.
                    let mut partitions = partitions.into_iter();
                    Ok(( $( impl_partition_n!(@take partitions $i) ),* ))
                }
            }
        }
    };
    ( @as_self $i:expr ) => {
        Self
    };
    // Take the next partition; `$i` only drives the repetition.
    ( @take $partitions:ident $i:expr ) => {
        $partitions
            .next()
            .expect("one partition is pushed per set of names")
    };
}
218
// Generate `partition3` through `partition16`.
// Each call passes the arity followed by the zero-based index of each partition.
impl_partition_n!(3, 0, 1, 2);
impl_partition_n!(4, 0, 1, 2, 3);
impl_partition_n!(5, 0, 1, 2, 3, 4);
impl_partition_n!(6, 0, 1, 2, 3, 4, 5);
impl_partition_n!(7, 0, 1, 2, 3, 4, 5, 6);
impl_partition_n!(8, 0, 1, 2, 3, 4, 5, 6, 7);
impl_partition_n!(9, 0, 1, 2, 3, 4, 5, 6, 7, 8);
impl_partition_n!(10, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9);
impl_partition_n!(11, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10);
impl_partition_n!(12, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11);
impl_partition_n!(13, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12);
impl_partition_n!(14, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13);
impl_partition_n!(15, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14);
impl_partition_n!(16, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15);
233
/// Error returned by `NamedArgs::partition`
/// and the generated `partitionN` methods.
#[derive(Clone, Debug, thiserror::Error)]
pub enum PartitionErr {
    /// No argument was present for the contained name.
    #[error("`NamedArgs` is missing `{0}`.")]
    Missing(&'static str),
}
239
/// Error returned by `NamedArgs::pop`.
// NOTE(review): `dyn AnyArg` carries no `Send`/`Sync` bound,
// so `PopErr` is neither `Send` nor `Sync`;
// confirm callers never need to send it across threads.
#[derive(Clone, Debug, thiserror::Error)]
pub enum PopErr {
    /// No argument was present for the contained name.
    #[error("`NamedArgs` is missing `{0}`.")]
    Missing(&'static str),
    /// An argument was present
    /// but was not of the requested type.
    ///
    /// `pop` has already removed it from the map;
    /// it is handed back in `arg`.
    #[error("Expected type `{ty}` for `{name}`, found arg `{arg:?}`.")]
    WrongType {
        /// Name of the argument.
        name: &'static str,
        /// The removed argument, still type-erased.
        arg: Box<dyn AnyArg>,
        /// Name of the requested type, from `std::any::type_name`.
        ty: &'static str,
    },
}
251
252impl<Names, Args> From<(Names, Args)> for NamedArgs
253where
254    NamedArgs: FromNamesArgs<Names, Args>,
255{
256    fn from(value: (Names, Args)) -> Self {
257        Self::from_names_args(value.0, value.1)
258    }
259}
260
/// Conversion of names plus matching arguments into `NamedArgs`.
///
/// Implemented in this module for a single `Name` with any `AnyArg`,
/// and for same-length tuples of names and arguments.
// Parameters are `N`/`A` so they do not shadow the imported `Names` type.
trait FromNamesArgs<N, A> {
    /// Pair each name with its argument.
    fn from_names_args(names: N, args: A) -> Self;
}
264
265impl<T> FromNamesArgs<Name, T> for NamedArgs
266where
267    T: AnyArg,
268{
269    default fn from_names_args(names: Name, args: T) -> Self {
270        Self::singleton(names, args)
271    }
272}
273
274impl<Names0, T0> FromNamesArgs<(Names0,), (T0,)> for NamedArgs
275where
276    NamedArgs: FromNamesArgs<Names0, T0>,
277{
278    fn from_names_args(names: (Names0,), args: (T0,)) -> Self {
279        Self::from_names_args(names.0, args.0)
280    }
281}
282
// Implement `FromNamesArgs` for a tuple of names
// paired with a same-length tuple of arguments.
// Takes the tuple indices, e.g. `impl_from_names_args!(0, 1)` for pairs.
// Each element pair converts through its own `FromNamesArgs` impl,
// so elements may themselves be nested tuples.
macro_rules! impl_from_names_args {
    ( $( $i:expr ),* ) => {
        paste! {
            impl< $( [<Names $i>] ),* , $( [<T $i>] ),* > FromNamesArgs<( $( [<Names $i>] ),* ), ( $( [<T $i>] ),* )> for NamedArgs
            where
                $( NamedArgs: FromNamesArgs<[<Names $i>], [<T $i>]> ),*
            {
                fn from_names_args(names: ( $( [<Names $i>] ),* ), args: ( $( [<T $i>] ),* )) -> Self {
                    // Union the per-element results left to right;
                    // for a repeated name the later element's argument is kept.
                    let mut out = NamedArgs::new();
                    $(
                        out = out.union(Self::from_names_args(names.$i, args.$i));
                    )*
                    out
                }
            }
        }
    };
}
304
// Implement `FromNamesArgs` for tuples of 2 through 16 names and arguments.
// Each call lists the tuple indices to convert;
// the 1-tuple case has its own hand-written impl.
impl_from_names_args!(0, 1);
impl_from_names_args!(0, 1, 2);
impl_from_names_args!(0, 1, 2, 3);
impl_from_names_args!(0, 1, 2, 3, 4);
impl_from_names_args!(0, 1, 2, 3, 4, 5);
impl_from_names_args!(0, 1, 2, 3, 4, 5, 6);
impl_from_names_args!(0, 1, 2, 3, 4, 5, 6, 7);
impl_from_names_args!(0, 1, 2, 3, 4, 5, 6, 7, 8);
impl_from_names_args!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9);
impl_from_names_args!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10);
impl_from_names_args!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11);
impl_from_names_args!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12);
impl_from_names_args!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13);
impl_from_names_args!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14);
impl_from_names_args!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15);
320
#[cfg(test)]
mod tests {
    use super::{PartitionErr, PopErr};
    use crate::names;

    #[test]
    fn partition_should_return_full_map_for_fst_when_all_args() {
        let named_args = named_args![("foo", 1), ("bar", 2)];
        assert_eq!(
            named_args
                .clone()
                .partition(&names!["foo", "bar"], &names![])
                .unwrap()
                .0
                .into_map::<i32>(),
            named_args.into_map()
        );
    }

    #[test]
    fn partition_should_return_full_map_for_snd_when_all_args() {
        let named_args = named_args![("foo", 1), ("bar", 2)];
        assert_eq!(
            named_args
                .clone()
                .partition(&names![], &names!["foo", "bar"])
                .unwrap()
                .1
                .into_map::<i32>(),
            named_args.into_map()
        );
    }

    #[test]
    fn partition_should_return_maps_with_given_args_when_some_args() {
        let named_args = named_args![("foo", 1), ("bar", 2), ("baz", 3)];
        let (left, right) = named_args
            .partition(&names!["baz"], &names!["foo"])
            .unwrap();
        assert_eq!(left.into_map::<i32>(), named_args![("baz", 3)].into_map(),);
        assert_eq!(right.into_map::<i32>(), named_args![("foo", 1)].into_map(),);
    }

    #[test]
    fn partition_should_duplicate_args_required_by_both() {
        let named_args = named_args![("foo", 1), ("bar", 2), ("baz", 3), ("biz", 4)];
        let (left, right) = named_args
            .partition(&names!["foo", "bar", "baz"], &names!["foo", "bar", "biz"])
            .unwrap();
        assert_eq!(
            left.into_map::<i32>(),
            named_args![("foo", 1), ("bar", 2), ("baz", 3)].into_map(),
        );
        assert_eq!(
            right.into_map::<i32>(),
            named_args![("foo", 1), ("bar", 2), ("biz", 4)].into_map(),
        );
    }

    #[test]
    fn partition_should_err_if_arg_missing() {
        let named_args = named_args![("foo", 1), ("bar", 2)];
        assert!(matches!(
            named_args
                .clone()
                .partition(&names![], &names!["foo", "baz"]),
            Err(PartitionErr::Missing("baz"))
        ));
        assert!(matches!(
            named_args.partition(&names!["foo", "baz"], &names![]),
            Err(PartitionErr::Missing("baz"))
        ));
    }

    #[test]
    fn partition_should_err_if_arg_missing_when_some_args() {
        let named_args = named_args![("foo", 1), ("bar", 2), ("baz", 3)];
        assert!(matches!(
            named_args.partition(&names!["foo"], &names!["bin"]),
            Err(PartitionErr::Missing("bin"))
        ));
    }

    #[test]
    fn partition3_should_return_full_map_for_fst_when_all_args() {
        let named_args = named_args![("foo", 1), ("bar", 2)];
        assert_eq!(
            named_args
                .clone()
                .partition3(&names!["foo", "bar"], &names![], &names![])
                .unwrap()
                .0
                .into_map::<i32>(),
            named_args.into_map()
        );
    }

    #[test]
    fn partition3_should_return_full_map_for_snd_when_all_args() {
        let named_args = named_args![("foo", 1), ("bar", 2)];
        assert_eq!(
            named_args
                .clone()
                .partition3(&names![], &names!["foo", "bar"], &names![])
                .unwrap()
                .1
                .into_map::<i32>(),
            named_args.into_map()
        );
    }

    #[test]
    fn partition3_should_return_full_map_for_third_when_all_args() {
        let named_args = named_args![("foo", 1), ("bar", 2)];
        assert_eq!(
            named_args
                .clone()
                .partition3(&names![], &names![], &names!["foo", "bar"])
                .unwrap()
                .2
                .into_map::<i32>(),
            named_args.into_map()
        );
    }

    #[test]
    fn partition3_should_return_maps_with_given_args_when_some_args() {
        let named_args = named_args![("foo", 1), ("bar", 2), ("baz", 3), ("bin", 4)];
        let (fst, snd, third) = named_args
            .partition3(&names!["baz"], &names!["foo"], &names!["bar"])
            .unwrap();
        assert_eq!(fst.into_map::<i32>(), named_args![("baz", 3)].into_map(),);
        assert_eq!(snd.into_map::<i32>(), named_args![("foo", 1)].into_map(),);
        assert_eq!(third.into_map::<i32>(), named_args![("bar", 2)].into_map(),);
    }

    #[test]
    fn partition3_should_duplicate_args_required_by_more_than_one() {
        let named_args = named_args![("foo", 1), ("bar", 2), ("baz", 3), ("biz", 4), ("bop", 5)];
        let (fst, snd, third) = named_args
            .partition3(
                &names!["foo", "bar", "baz"],
                &names!["foo", "bar", "biz"],
                &names!["bar", "biz", "bop"],
            )
            .unwrap();
        assert_eq!(
            fst.into_map::<i32>(),
            named_args![("foo", 1), ("bar", 2), ("baz", 3)].into_map(),
        );
        assert_eq!(
            snd.into_map::<i32>(),
            named_args![("foo", 1), ("bar", 2), ("biz", 4)].into_map(),
        );
        assert_eq!(
            third.into_map::<i32>(),
            named_args![("bar", 2), ("biz", 4), ("bop", 5)].into_map(),
        );
    }

    #[test]
    fn partition3_should_err_if_arg_missing() {
        let named_args = named_args![("foo", 1), ("bar", 2)];
        assert!(matches!(
            named_args
                .clone()
                .partition3(&names!["foo", "baz"], &names![], &names![]),
            Err(PartitionErr::Missing("baz"))
        ));
        assert!(matches!(
            named_args
                .clone()
                .partition3(&names![], &names!["foo", "baz"], &names![]),
            Err(PartitionErr::Missing("baz"))
        ));
        assert!(matches!(
            named_args.partition3(&names![], &names![], &names!["foo", "baz"]),
            Err(PartitionErr::Missing("baz"))
        ));
    }

    #[test]
    fn pop_should_return_arg_if_available_and_correct_type() {
        let mut named_args = named_args![("foo", 1), ("bar", 2.0)];
        assert_eq!(named_args.pop::<i32>("foo").ok(), Some(1));
        assert_eq!(named_args.pop::<f64>("bar").ok(), Some(2.0));
    }

    #[test]
    fn pop_should_err_if_not_available() {
        let mut named_args = named_args![("foo", 1), ("bar", 2.0)];
        assert!(matches!(
            named_args.pop::<i32>("baz"),
            Err(PopErr::Missing("baz"))
        ));
    }

    #[test]
    fn pop_should_err_if_wrong_type() {
        let mut named_args = named_args![("foo", 1), ("bar", 2.0)];
        assert!(matches!(
            named_args.pop::<f64>("foo"),
            Err(PopErr::WrongType { name: "foo", .. })
        ));
    }

    #[test]
    fn pop_should_remove_arg_even_if_wrong_type() {
        let mut named_args = named_args![("foo", 1), ("bar", 2.0)];
        assert!(named_args.pop::<f64>("foo").is_err());
        assert!(matches!(
            named_args.pop::<i32>("foo"),
            Err(PopErr::Missing("foo"))
        ));
    }
}
502}