lambdas/domains/
simple.rs

//! This is an example domain, heavily commented to explain how to implement your own!
2
3use crate::*;
4use std::collections::HashMap;
5
6/// A simple domain with ints and polymorphic lists (allows nested lists).
7/// Generally it's good to be able to imagine the hindley milner type system
8/// for your domain so that it's compatible when we add that later. In this case the types
9/// would look like `T := (T -> T) | Int | List(T)` where functions are handled
10/// by dreamegg::domain::Val so they don't appear here.
11#[derive(Clone,Debug, PartialEq, Eq, Hash)]
12pub enum SimpleVal {
13    Int(i32),
14    List(Vec<Val>),
15}
16
/// Coarse type tags for `SimpleVal` values: plain ints or (untyped) lists.
///
/// NOTE(review): not used by the code visible in this file; type information is
/// produced as the generic `Type` by `type_of_dom_val` instead.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum SimpleType {
    /// Type of `SimpleVal::Int`.
    TInt,
    /// Type of `SimpleVal::List` (element type not tracked).
    TList,
}
22
23// aliases of various typed specialized to our SimpleVal
24type Val = crate::eval::Val<SimpleVal>;
25type LazyVal = crate::eval::LazyVal<SimpleVal>;
26type Evaluator<'a> = crate::eval::Evaluator<'a,SimpleVal>;
27type VResult = crate::eval::VResult<SimpleVal>;
28type DSLFn = crate::dsl::DSLFn<SimpleVal>;
29
30// to more concisely refer to the variants
31use SimpleVal::*;
32
33use crate::eval::Val::*;
34// use domain::Type::*;
35
// `define_semantics!` registers this domain's primitives. Each entry maps a DSL symbol
// (left of `=`) to either:
//   - `(rust_fn, "type")` — a primitive function implemented further down in this file, or
//   - `"type"`            — a constant with no Rust function attached.
// The string is the primitive's type in the DSL's type syntax (e.g. "int -> int -> int"),
// not an arity; any arity is presumably derived from it by the macro.
// NOTE(review): the macro is said to generate global lazy_static lookup tables that the
// `Domain` impl's lookups consume; see the macro definition for the exact names it emits.
// Constant entries such as "0" and "[]" carry no value here — presumably their values come
// from `val_of_prim_fallback`, which parses integer and list literals. Confirm against the macro.
define_semantics! {
    SimpleVal;
    "+" = (add, "int -> int -> int"),
    "*" = (mul, "int -> int -> int"),
    "map" = (map, "(t0 -> t1) -> (list t0) -> (list t1)"),
    "sum" = (sum, "list int -> int"),
    "0" = "int",
    "1" = "int",
    "2" = "int",
    "[]" = "(list t0)",
    // TODO: support constants with an explicit value, e.g. `const "0" = Dom(Int(0))`.
}
51
52
53// From<Val> impls are needed for unwrapping values. We can assume the program
54// has been type checked so it's okay to panic if the type is wrong. Each val variant
55// must map to exactly one unwrapped type (though it doesnt need to be one to one in the
56// other direction)
57impl FromVal<SimpleVal> for i32 {
58    fn from_val(v: Val) -> Result<Self, VError> {
59        match v {
60            Dom(Int(i)) => Ok(i),
61            _ => Err("from_val_to_i32: not an int".into())
62        }
63    }
64}
65impl<T: FromVal<SimpleVal>> FromVal<SimpleVal> for Vec<T> {
66    fn from_val(v: Val) -> Result<Self, VError> {
67        match v {
68            Dom(List(v)) => v.into_iter().map(|v| T::from_val(v)).collect(),
69            _ => Err("from_val_to_vec: not a list".into())
70        }
71    }
72}
73
74// These Into<Val>s are convenience functions. It's okay if theres not a one to one mapping
75// like this in all domains - it just makes .into() save us a lot of work if there is.
76impl From<i32> for Val {
77    fn from(i: i32) -> Val {
78        Dom(Int(i))
79    }
80}
81impl<T: Into<Val>> From<Vec<T>> for Val {
82    fn from(vec: Vec<T>) -> Val {
83        Dom(List(vec.into_iter().map(|v| v.into()).collect()))
84    }
85}
86
87// here we actually implement Domain for our domain. 
88impl Domain for SimpleVal {
89    // we dont use Data here
90    type Data = ();
91    // type Type = SimpleType;
92
93    // val_of_prim takes a symbol like "+" or "0" and returns the corresponding Val.
94    // Note that it can largely just be a call to the global hashmap PRIMS that define_semantics generated
95    // however you're also free to do any sort of generic parsing you want, allowing for domains with
96    // infinite sets of values or dynamically generated values. For example here we support all integers
97    // and all integer lists.
98    fn val_of_prim_fallback(p: Symbol) -> Option<Val> {
99        // starts with digit -> Int
100        if p.as_str().chars().next().unwrap().is_ascii_digit() {
101            let i: i32 = p.as_str().parse().ok()?;
102            Some(Int(i).into())
103        }
104        // starts with `[` -> List (must be all ints)
105        else if p.as_str().starts_with('[') {
106            let intvec: Vec<i32> = serde_json::from_str(p.as_str()).ok()?;
107            let valvec: Vec<Val> = intvec.into_iter().map(|v|Dom(Int(v))).collect();
108            Some(List(valvec).into())
109        } else {
110            None
111        }
112    }
113
114    dsl_entries_lookup_gen!();
115
116    fn type_of_dom_val(&self) -> Type {
117        match self {
118            Int(_) => Type::base("int".into()),
119            List(xs) => {
120                let elem_tp = if xs.is_empty() {
121                    Type::Var(0) // (list t0)
122                } else {
123                    // todo here we just use the type of the first entry as the type
124                    Self::type_of_dom_val(&xs.first().unwrap().clone().dom().unwrap())
125                    // assert!(xs.iter().all(|v| Self::type_of_dom_val(v.clone().dom().unwrap())))
126                };
127                Type::Term("list".into(),vec![elem_tp])
128            },
129        }
130    }
131
132
133}
134
135
136// *** DSL FUNCTIONS ***
137// See comments throughout pointing out useful aspects
138
139fn add(mut args: Vec<LazyVal>, handle: &Evaluator) -> VResult {
140    // load_args! macro is used to extract the arguments from the args vector. This uses
141    // .into() to convert the Val into the appropriate type. For example an int list, which is written
142    // as  Dom(List(Vec<Dom(Int)>)), can be .into()'d into a Vec<i32> or a Vec<Val> or a Val.
143    load_args!(handle, args, x:i32, y:i32); 
144    // ok() is a convenience function that does Ok(v.into()) for you. It relies on your internal primitive types having a one
145    // to one mapping to Val variants like `Int <-> i32`. For any domain, the forward mapping `Int -> i32` is guaranteed, however
146    // depending on your implementation the reverse mapping `i32 -> Int` may not be. If that's the case you can manually construct
147    // the Val from the primitive type like Ok(Dom(Int(v))) for example. Alternatively you can get the .into() property by wrapping
148    // all your primitive types eg Int1 = struct(i32) and Int2 = struct(i32) etc for the various types that are i32 under the hood.
149    ok(x+y)
150}
151
152fn mul(mut args: Vec<LazyVal>, handle: &Evaluator) -> VResult {
153    load_args!(handle, args, x:i32, y:i32);
154    ok(x*y)
155}
156
157fn map(mut args: Vec<LazyVal>, handle: &Evaluator) -> VResult {
158    load_args!(handle, args, fn_val: Val, xs: Vec<Val>);
159    ok(xs.into_iter()
160        // sometimes you might want to apply a value that you know is a function to something else. In that
161        // case handle.apply(f: &Val, x: Val) is the way to go. `handle` mainly exists to allow for this, as well
162        // as to access handle.data (generic global data) which may be needed for implementation details of certain very complex domains
163        // but should largely be avoided.
164        .map(|x| handle.apply(&fn_val, x))  
165        // here we just turn a Vec<Result> into a Result<Vec> via .collect()'s casting - a handy trick that collapses
166        // all the results together into one (which is an Err if any of them was an Err).
167        .collect::<Result<Vec<Val>,_>>()?)
168}
169
170fn sum(mut args: Vec<LazyVal>, handle: &Evaluator) -> VResult {
171    load_args!(handle, args, xs: Vec<i32>);
172    ok(xs.iter().sum::<i32>())
173}
174
175
176
#[cfg(test)]
mod tests {
    use super::*;

    /// Unification (via both `Context` and `TypeSet`) and type inference for this domain.
    #[test]
    fn test_types_simple() {
        use domains::simple::SimpleVal;

        // Unify `lhs` with `rhs` twice — once in a fresh `Context`, once in a fresh `TypeSet`
        // after instantiating both sides — and require both results to equal `expected`.
        fn assert_unify(lhs: &str, rhs: &str, expected: UnifyResult) {
            let parse = |s: &str| s.parse::<Type>().unwrap();

            let mut ctx = Context::empty();
            assert_eq!(ctx.unify(&parse(lhs), &parse(rhs)), expected);

            let mut typeset = TypeSet::empty();
            let lhs_tp = typeset.add_tp(&parse(lhs)).instantiate(&mut typeset);
            let rhs_tp = typeset.add_tp(&parse(rhs)).instantiate(&mut typeset);
            assert_eq!(typeset.unify(&lhs_tp, &rhs_tp), expected);
        }

        // Infer the type of `program` in an empty context and compare with `expected`.
        fn assert_infer(program: &str, expected: Result<&str, UnifyErr>) {
            let inferred = program
                .parse::<Expr>()
                .unwrap()
                .infer::<SimpleVal>(None, &mut Context::empty(), &mut Default::default());
            let expected = expected.map(|tp| tp.parse::<Type>().unwrap());
            assert_eq!(inferred, expected);
        }

        let unifiable = [
            ("int", "int"),
            ("int", "t0"),
            ("int", "t1"),
            ("(list int)", "(list t1)"),
            ("(int -> bool)", "(int -> t0)"),
            ("t0", "t1"),
        ];
        for &(lhs, rhs) in unifiable.iter() {
            assert_unify(lhs, rhs, Ok(()));
        }

        let inferences = [
            ("3", "int"),
            ("[1,2,3]", "list int"),
            ("(+ 2 3)", "int"),
            ("(lam $0)", "t0 -> t0"),
            ("(lam (+ $0 1))", "int -> int"),
            ("map", "((t0 -> t1) -> (list t0) -> (list t1))"),
            ("(map (lam (+ $0 1)))", "list int -> list int"),
        ];
        for &(program, tp) in inferences.iter() {
            assert_infer(program, Ok(tp));
        }
    }

    /// End-to-end evaluation of small programs.
    #[test]
    fn test_eval_simple() {
        assert_execution::<domains::simple::SimpleVal, i32>("(+ 1 2)", &[], 3);
        assert_execution::<domains::simple::SimpleVal, i32>("(sum (map (lam $0) []))", &[], 0);

        // Every remaining program takes the list [1,2,3] as its single input `$0`.
        let one_two_three = || SimpleVal::val_of_prim("[1,2,3]".into()).unwrap();

        assert_execution("(map (lam (+ 1 $0)) $0)", &[one_two_three()], vec![2, 3, 4]);
        assert_execution("(sum (map (lam (+ 1 $0)) $0))", &[one_two_three()], 9);
        assert_execution(
            "(map (lam (* $0 $0)) (map (lam (+ 1 $0)) $0))",
            &[one_two_three()],
            vec![4, 9, 16],
        );
        assert_execution(
            "(map (lam (* $0 $0)) (map (lam (+ (sum $1) $0)) $0))",
            &[one_two_three()],
            vec![49, 64, 81],
        );
    }
}