fetish_lib/
func_impl.rs

1extern crate ndarray;
2extern crate ndarray_linalg;
3
4use crate::type_id::*;
5use crate::interpreter_state::*;
6use crate::term_reference::*;
7use crate::term_application::*;
8
9use ndarray::*;
10use noisy_float::prelude::*;
11
12use std::cmp::*;
13use std::fmt::*;
14use std::hash::*;
15use crate::params::*;
16use crate::newly_evaluated_terms::*;
17
18///Trait which gives a "signature" for
19///functions to be included in a [`crate::primitive_directory::PrimitiveDirectory`].
20///This consists of a name, a collection of required argument types,
21///and a return type. Primitive functions are assumed to be 
22///uniquely identifiable from their particular implementation of this trait.
23pub trait HasFuncSignature {
24    ///Gets the name for the implemented function.
25    fn get_name(&self) -> String;
26    ///Gets the return type for the implemented function.
27    fn ret_type(&self) -> TypeId;
28    ///Gets the list of required argument types for the implemented function.
29    fn required_arg_types(&self) -> Vec::<TypeId>;
30
31    ///Given a collection of [`TermReference`]s, determines whether
32    ///we have sufficiently-many arguments to fully evaluate this function.
33    fn ready_to_evaluate(&self, args : &Vec::<TermReference>) -> bool {
34        let expected_num : usize =  self.required_arg_types().len();
35        expected_num == args.len()
36    }
37
38    ///Given a [`TypeInfoDirectory`], obtains the [`TypeId`] which
39    ///corresponds to the type of this implemented function.
40    fn func_type(&self, type_info_directory : &TypeInfoDirectory) -> TypeId {
41        let mut reverse_arg_types : Vec<TypeId> = self.required_arg_types();
42        reverse_arg_types.reverse();
43
44        let mut result : TypeId = self.ret_type();
45        for arg_type_id in reverse_arg_types.drain(..) {
46            result = type_info_directory.get_func_type_id(arg_type_id, result);
47        }
48        result
49    }
50}
51
52///Trait for primitive function implementations.
53pub trait FuncImpl : HasFuncSignature {
54    ///Given a handle on the current [`InterpreterState`] (primarily useful if additional terms
55    ///need to be evaluated / looked up) and the collection of [`TermReference`] arguments to
56    ///apply this function implementation to, yields a [`TermReference`] to the result, along
57    ///with a collection of any `NewlyEvaluatedTerms` which may have arisen as part of the
58    ///implementation of this method. See `func_impl.rs` in the source for sample implementations.
59    fn evaluate(&self, state : &mut InterpreterState, args : Vec::<TermReference>)
60                -> (TermReference, NewlyEvaluatedTerms);
61}
62
63impl PartialEq for dyn FuncImpl + '_ {
64    fn eq(&self, other : &Self) -> bool {
65        self.get_name() == other.get_name() &&
66        self.required_arg_types() == other.required_arg_types() &&
67        self.ret_type() == other.ret_type()
68    }
69}
70
71impl Eq for dyn FuncImpl + '_ {}
72
73impl Hash for dyn FuncImpl + '_ {
74    fn hash<H : Hasher>(&self, state : &mut H) {
75        self.get_name().hash(state);
76        self.required_arg_types().hash(state);
77        self.ret_type().hash(state);
78    }
79}
80
81///Trait to ease implementation of primitive binary operators which have identical argument types
82///and return type. To be used in tandem with [`BinaryFuncImpl`].
83pub trait BinaryArrayOperator {
84    ///Given two arrays of equal dimension, act to yield an array of the same number of dimensions.
85    fn act(&self, arg_one : ArrayView1::<R32>, arg_two : ArrayView1::<R32>) -> Array1::<R32>;
86    ///Gets the name of this binary operator
87    fn get_name(&self) -> String;
88}
89
90impl PartialEq for dyn BinaryArrayOperator + '_ {
91    fn eq(&self, other : &Self) -> bool {
92        self.get_name() == other.get_name()
93    }
94}
95
96impl Eq for dyn BinaryArrayOperator + '_ {}
97
98impl Hash for dyn BinaryArrayOperator + '_ {
99    fn hash<H : Hasher>(&self, state : &mut H) {
100        self.get_name().hash(state);
101    }
102}
103
104///[`BinaryArrayOperator`] for vector addition.
105pub struct AddOperator {
106}
107
108impl BinaryArrayOperator for AddOperator {
109    fn act(&self, arg_one : ArrayView1::<R32>, arg_two : ArrayView1::<R32>) -> Array1::<R32> {
110        &arg_one + &arg_two
111    }
112    fn get_name(&self) -> String {
113        String::from("+")
114    }
115}
116
117///[`BinaryArrayOperator`] for vector subtraction.
118pub struct SubOperator {
119}
120
121impl BinaryArrayOperator for SubOperator {
122    fn act(&self, arg_one : ArrayView1::<R32>, arg_two : ArrayView1::<R32>) -> Array1::<R32> {
123        &arg_one - &arg_two
124    }
125    fn get_name(&self) -> String {
126        String::from("-")
127    }
128}
129
130///[`BinaryArrayOperator`] for elementwise vector multiplication.
131pub struct MulOperator {
132}
133
134impl BinaryArrayOperator for MulOperator {
135    fn act(&self, arg_one : ArrayView1::<R32>, arg_two : ArrayView1::<R32>) -> Array1::<R32> {
136        &arg_one * &arg_two
137    }
138    fn get_name(&self) -> String {
139        String::from("*")
140    }
141}
142
143///Wrapper around a [`BinaryArrayOperator`] to conveniently lift it to a [`FuncImpl`]
144///given the [`TypeId`] of the argument/return type.
145pub struct BinaryFuncImpl {
146    pub elem_type : TypeId,
147    pub f : Box<dyn BinaryArrayOperator>
148}
149
150impl HasFuncSignature for BinaryFuncImpl {
151    fn get_name(&self) -> String {
152        self.f.get_name()
153    }
154    fn required_arg_types(&self) -> Vec<TypeId> {
155        vec![self.elem_type, self.elem_type]
156    }
157    fn ret_type(&self) -> TypeId {
158        self.elem_type
159    }
160}
161
162impl FuncImpl for BinaryFuncImpl {
163    fn evaluate(&self, _state : &mut InterpreterState, args : Vec::<TermReference>) -> (TermReference, NewlyEvaluatedTerms) {
164        if let TermReference::VecRef(_, arg_one_vec) = &args[0] {
165            if let TermReference::VecRef(_, arg_two_vec) = &args[1] {
166                let result_vec = self.f.act(arg_one_vec.view(), arg_two_vec.view());
167                let result_ref = TermReference::VecRef(self.elem_type, result_vec);
168                (result_ref, NewlyEvaluatedTerms::new())
169            } else {
170                panic!();
171            }
172        } else {
173            panic!();
174        }
175    }
176}
177
178///Implementation of a "rotate left one index" [`FuncImpl`] for a given vector [`TypeId`].
179///(That is, given `[x_1, x_2, ...]`, rotates to `[x_2, x_2, ... x_1]`.
180#[derive(Clone)]
181pub struct RotateImpl {
182    pub vector_type : TypeId
183}
184
185impl HasFuncSignature for RotateImpl {
186    fn get_name(&self) -> String {
187        String::from("rotate")
188    }
189    fn required_arg_types(&self) -> Vec<TypeId> {
190        vec![self.vector_type]
191    }
192    fn ret_type(&self) -> TypeId {
193        self.vector_type
194    }
195}
196
197impl FuncImpl for RotateImpl {
198    fn evaluate(&self, _state : &mut InterpreterState, args : Vec<TermReference>) -> (TermReference, NewlyEvaluatedTerms) {
199        if let TermReference::VecRef(vector_type, arg_vec) = &args[0] {
200            let n = arg_vec.len();
201            let arg_vec_head : R32 = arg_vec[[0,]];
202            let mut result_vec : Array1::<R32> = Array::from_elem((n,), arg_vec_head);
203            for i in 1..n {
204                result_vec[[i-1,]] = arg_vec[[i,]];
205            }
206            let result : TermReference = TermReference::VecRef(*vector_type, result_vec);
207            (result, NewlyEvaluatedTerms::new())
208        } else {
209            panic!();
210        }
211    }
212}
213
214///Implementation of a "set the first element of a vector to the given one" [`FuncImpl`] for the given
215///vector and scalar types. 
216#[derive(Clone)]
217pub struct SetHeadImpl {
218    pub vector_type : TypeId,
219    pub scalar_type : TypeId
220}
221
222impl HasFuncSignature for SetHeadImpl {
223    fn get_name(&self) -> String {
224        String::from("setHead")
225    }
226    fn required_arg_types(&self) -> Vec<TypeId> {
227        vec![self.vector_type, self.scalar_type]
228    }
229    fn ret_type(&self) -> TypeId {
230        self.vector_type
231    }
232}
233impl FuncImpl for SetHeadImpl {
234    fn evaluate(&self, _state : &mut InterpreterState, args : Vec<TermReference>) -> (TermReference, NewlyEvaluatedTerms) {
235        if let TermReference::VecRef(vector_type, arg_vec) = &args[0] {
236            if let TermReference::VecRef(_, val_vec) = &args[1] {
237                let val : R32 = val_vec[[0,]];
238                let mut result_vec : Array1<R32> = arg_vec.clone();
239                result_vec[[0,]] = val;
240                let result = TermReference::VecRef(*vector_type, result_vec);
241                (result, NewlyEvaluatedTerms::new())
242            } else {
243                panic!();
244            }
245        } else {
246            panic!();
247        }
248    }
249}
250
251///Implementation of a "get the first element of a vector" [`FuncImpl`] for the given vector
252///and scalar types.
253#[derive(Clone)]
254pub struct HeadImpl {
255    pub vector_type : TypeId,
256    pub scalar_type : TypeId
257}
258
259impl HasFuncSignature for HeadImpl {
260    fn get_name(&self) -> String {
261        String::from("head")
262    }
263    fn required_arg_types(&self) -> Vec<TypeId> {
264        vec![self.vector_type]
265    }
266    fn ret_type(&self) -> TypeId {
267        self.scalar_type
268    }
269}
270impl FuncImpl for HeadImpl {
271    fn evaluate(&self, _state : &mut InterpreterState, args : Vec<TermReference>) -> (TermReference, NewlyEvaluatedTerms) {
272        if let TermReference::VecRef(_, arg_vec) = &args[0] {
273            let ret_val : R32 = arg_vec[[0,]];
274            let result_array : Array1::<R32> = Array::from_elem((1,), ret_val);
275
276            let result = TermReference::VecRef(self.scalar_type, result_array);
277            (result, NewlyEvaluatedTerms::new())
278        } else {
279            panic!();
280        }
281    }
282}
283
284///Implementation of a "function composition" [`FuncImpl`]
285#[derive(Clone)]
286pub struct ComposeImpl {
287    in_type : TypeId,
288    middle_type : TypeId,
289    func_one : TypeId,
290    func_two : TypeId,
291    ret_type : TypeId
292}
293
294impl ComposeImpl {
295    ///Given a [`TypeInfoDirectory`], the input type, a middle type, and a return type,
296    ///yields a [`ComposeImpl`] of type `(in -> middle) -> (middle -> return) -> (in -> return)`.
297    pub fn new(type_info_directory : &TypeInfoDirectory, 
298               in_type : TypeId, middle_type : TypeId, ret_type : TypeId) -> ComposeImpl {
299        let func_one : TypeId = type_info_directory.get_func_type_id(middle_type, ret_type);
300        let func_two : TypeId = type_info_directory.get_func_type_id(in_type, middle_type);
301        ComposeImpl {
302            in_type,
303            middle_type,
304            func_one,
305            func_two,
306            ret_type
307        }
308    }
309}
310
311impl HasFuncSignature for ComposeImpl {
312    fn get_name(&self) -> String {
313        String::from("compose")
314    }
315    fn required_arg_types(&self) -> Vec<TypeId> {
316        vec![self.func_one, self.func_two, self.in_type]
317    }
318    fn ret_type(&self) -> TypeId {
319        self.ret_type
320    }
321}
322
323impl FuncImpl for ComposeImpl {
324    fn evaluate(&self, state : &mut InterpreterState, args : Vec<TermReference>) -> (TermReference, NewlyEvaluatedTerms) {
325        if let TermReference::FuncRef(func_one) = &args[0] {
326            if let TermReference::FuncRef(func_two) = &args[1] {
327                let arg : TermReference = args[2].clone();
328                let application_one = TermApplication {
329                    func_ptr : func_two.clone(),
330                    arg_ref : arg
331                };
332                let (middle_ref, mut newly_evaluated_terms) = state.evaluate(&application_one);
333                let application_two = TermApplication {
334                    func_ptr : func_one.clone(),
335                    arg_ref : middle_ref
336                };
337                let (final_ref, more_evaluated_terms) = state.evaluate(&application_two);
338                newly_evaluated_terms.merge(more_evaluated_terms);
339                (final_ref, newly_evaluated_terms)
340            } else {
341                panic!();
342            }
343        } else {
344            panic!();
345        }
346    }
347}
348
349///Implementation of a "Fill a vector with the given scalar" [`FuncImpl`] for the given
350///scalar and vector [`TypeId`]s.
351#[derive(Clone)]
352pub struct FillImpl {
353    pub scalar_type : TypeId,
354    pub vector_type : TypeId
355}
356
357impl HasFuncSignature for FillImpl {
358    fn get_name(&self) -> String {
359        String::from("fill")
360    }
361    fn required_arg_types(&self) -> Vec<TypeId> {
362        vec![self.scalar_type]
363    }
364    fn ret_type(&self) -> TypeId {
365        self.vector_type
366    }
367}
368impl FuncImpl for FillImpl {
369    fn evaluate(&self, state : &mut InterpreterState, args : Vec<TermReference>) -> (TermReference, NewlyEvaluatedTerms) {
370        let dim = state.get_context().get_dimension(self.vector_type);
371
372        if let TermReference::VecRef(_, arg_vec) = &args[0] {
373            let arg_val : R32 = arg_vec[[0,]];
374            let ret_val : Array1::<R32> = Array::from_elem((dim,), arg_val);
375
376            let result = TermReference::VecRef(self.vector_type, ret_val);
377            (result, NewlyEvaluatedTerms::new())
378        } else {
379            panic!();
380        }
381    }
382}
383
384///Implementation of the constant function for the given "return" and "ignored" types.
385///The result is of type `return -> ignored -> return`.
386#[derive(Clone)]
387pub struct ConstImpl {
388    pub ret_type : TypeId,
389    pub ignored_type : TypeId
390}
391
392impl HasFuncSignature for ConstImpl {
393    fn get_name(&self) -> String {
394        String::from("const")
395    }
396    fn required_arg_types(&self) -> Vec<TypeId> {
397        vec![self.ret_type.clone(), self.ignored_type.clone()]
398    }
399    fn ret_type(&self) -> TypeId {
400        self.ret_type.clone()
401    }
402}
403impl FuncImpl for ConstImpl {
404    fn evaluate(&self, _state : &mut InterpreterState, args : Vec::<TermReference>) -> (TermReference, NewlyEvaluatedTerms) {
405        let result_ptr : TermReference = args[0].clone();
406        (result_ptr, NewlyEvaluatedTerms::new())
407    }
408}
409
410///Implementation of a "reduce this vector by this binary operator to yield a scalar"
411///[`FuncImpl`] for the given [`TypeId`]s of the binary scalar operator, the scalar type,
412///and the vector type.
413#[derive(Clone)]
414pub struct ReduceImpl {
415    pub binary_scalar_func_type : TypeId,
416    pub scalar_type : TypeId,
417    pub vector_type : TypeId
418}
419
420impl HasFuncSignature for ReduceImpl {
421    fn get_name(&self) -> String {
422        String::from("reduce")
423    }
424    fn required_arg_types(&self) -> Vec<TypeId> {
425        vec![self.binary_scalar_func_type, self.scalar_type, self.vector_type]
426    }
427    fn ret_type(&self) -> TypeId {
428        self.scalar_type
429    }
430}
431
432impl FuncImpl for ReduceImpl {
433    fn evaluate(&self, state : &mut InterpreterState, args : Vec<TermReference>) -> (TermReference, NewlyEvaluatedTerms) {
434        let dim = state.get_context().get_dimension(self.vector_type);
435
436        let mut newly_evaluated_terms = NewlyEvaluatedTerms::new();
437        let mut accum_ref : TermReference = args[1].clone();
438        if let TermReference::FuncRef(func_ptr) = &args[0] {
439            if let TermReference::VecRef(_, vec) = &args[2] {
440                for i in 0..dim {
441                    //First, put the scalar term at this position into a term ref
442                    let val : R32 = vec[[i,]];
443                    let val_vec : Array1::<R32> = Array::from_elem((1,), val);
444                    let val_ref = TermReference::VecRef(self.scalar_type, val_vec);
445                     
446                    let term_app_one = TermApplication {
447                        func_ptr : func_ptr.clone(),
448                        arg_ref : val_ref
449                    };
450                    let (curry_ref, more_evaluated_terms) = state.evaluate(&term_app_one);
451                    newly_evaluated_terms.merge(more_evaluated_terms);
452
453                    if let TermReference::FuncRef(curry_ptr) = curry_ref {
454                        let term_app_two = TermApplication {
455                            func_ptr : curry_ptr,
456                            arg_ref : accum_ref
457                        };
458                        let (result_ref, more_evaluated_terms) = state.evaluate(&term_app_two);
459                        newly_evaluated_terms.merge(more_evaluated_terms);
460                        accum_ref = result_ref;
461                    } else {
462                        panic!();
463                    }
464                }
465
466                (accum_ref, newly_evaluated_terms)
467            } else {
468                panic!();
469            }
470        } else {
471            panic!();
472        }
473    }
474}
475
476///Implementation of a "map this scalar function over every element of a vector"
477///[`FuncImpl`] for the given scalar type, scalar function type, and vector type.
478#[derive(Clone)]
479pub struct MapImpl {
480    pub scalar_type : TypeId,
481    pub unary_scalar_func_type : TypeId,
482    pub vector_type : TypeId
483}
484
485impl HasFuncSignature for MapImpl {
486    fn get_name(&self) -> String {
487        String::from("map")
488    }
489    fn required_arg_types(&self) -> Vec<TypeId> {
490        vec![self.unary_scalar_func_type, self.vector_type]
491    }
492    fn ret_type(&self) -> TypeId {
493        self.vector_type
494    }
495}
496
497impl FuncImpl for MapImpl {
498    fn evaluate(&self, state : &mut InterpreterState, args : Vec::<TermReference>) -> (TermReference, NewlyEvaluatedTerms) {
499        if let TermReference::FuncRef(func_ptr) = &args[0] {
500            if let TermReference::VecRef(_, arg_vec) = &args[1] {
501                let n = arg_vec.len();
502                let mut newly_evaluated_terms = NewlyEvaluatedTerms::new();
503                let mut result : Array1<R32> = Array::from_elem((n,), R32::new(0.0)); 
504                for i in 0..n {
505                    let boxed_scalar : Array1<R32> = Array::from_elem((1,), arg_vec[i]);
506                    let arg_ref = TermReference::VecRef(self.scalar_type, boxed_scalar);
507
508                    let term_app = TermApplication {
509                        func_ptr : func_ptr.clone(),
510                        arg_ref : arg_ref
511                    };
512                    let (result_ref, more_evaluated_terms) = state.evaluate(&term_app);
513                    newly_evaluated_terms.merge(more_evaluated_terms);
514                    if let TermReference::VecRef(_, result_scalar_vec) = result_ref {
515                        result[[i,]] = result_scalar_vec[[0,]];
516                    }
517                }
518                let result_ref = TermReference::VecRef(self.vector_type, result);
519                (result_ref, newly_evaluated_terms)
520            } else {
521                panic!();
522            }
523        } else {
524            panic!();
525        }
526
527    }
528}
529
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::*;
    use crate::array_utils::*;

    ///Wraps a plain `f32` array as a vector [`TermReference`]: arrays of length one
    ///are tagged `TEST_SCALAR_T`, all others `TEST_VECTOR_T`.
    fn term_ref(in_array : Array1<f32>) -> TermReference {
        let noisy_array = to_noisy(in_array.view());
        let type_id = if in_array.shape()[0] == 1 {
            TEST_SCALAR_T
        } else {
            TEST_VECTOR_T
        };
        TermReference::VecRef(type_id, noisy_array)
    }

    ///`[1, 2] + [3, 4]` is `[4, 6]`.
    #[test]
    fn test_addition() {
        let ctxt = get_test_vector_only_context();
        let mut state = InterpreterState::new(&ctxt);

        let adder = BinaryFuncImpl {
            elem_type : TEST_VECTOR_T,
            f : Box::new(AddOperator {})
        };
        let operands = vec![
            term_ref(array![1.0f32, 2.0f32]),
            term_ref(array![3.0f32, 4.0f32]),
        ];

        let (sum, _) = adder.evaluate(&mut state, operands);
        assert_equal_vector_term(sum, array![4.0f32, 6.0f32].view());
    }

    ///Rotating `[5, 10]` yields `[10, 5]`.
    #[test]
    fn test_rotate() {
        let ctxt = get_test_vector_only_context();
        let mut state = InterpreterState::new(&ctxt);

        let rotator = RotateImpl {
            vector_type : TEST_VECTOR_T
        };
        let operands = vec![term_ref(array![5.0f32, 10.0f32])];

        let (rotated, _) = rotator.evaluate(&mut state, operands);
        assert_equal_vector_term(rotated, array![10.0f32, 5.0f32].view());
    }

    ///Setting the head of `[1, 2]` to `9` yields `[9, 2]`.
    #[test]
    fn test_set_head() {
        let ctxt = get_test_vector_only_context();
        let mut state = InterpreterState::new(&ctxt);

        let head_setter = SetHeadImpl {
            vector_type : TEST_VECTOR_T,
            scalar_type : TEST_SCALAR_T
        };
        let operands = vec![
            term_ref(array![1.0f32, 2.0f32]),
            term_ref(array![9.0f32]),
        ];

        let (updated, _) = head_setter.evaluate(&mut state, operands);
        assert_equal_vector_term(updated, array![9.0f32, 2.0f32].view());
    }

    ///The head of `[1, 2]` is `[1]`.
    #[test]
    fn test_head() {
        let ctxt = get_test_vector_only_context();
        let mut state = InterpreterState::new(&ctxt);

        let head_getter = HeadImpl {
            vector_type : TEST_VECTOR_T,
            scalar_type : TEST_SCALAR_T
        };
        let operands = vec![term_ref(array![1.0f32, 2.0f32])];

        let (head, _) = head_getter.evaluate(&mut state, operands);
        assert_equal_vector_term(head, array![1.0f32].view());
    }

    ///Filling with `3` yields `[3, 3]` in the test context.
    #[test]
    fn test_fill() {
        let ctxt = get_test_vector_only_context();
        let mut state = InterpreterState::new(&ctxt);

        let filler = FillImpl {
            vector_type : TEST_VECTOR_T,
            scalar_type : TEST_SCALAR_T
        };
        let operands = vec![term_ref(array![3.0f32])];

        let (filled, _) = filler.evaluate(&mut state, operands);
        assert_equal_vector_term(filled, array![3.0f32, 3.0f32].view());
    }

    ///`const [1, 2] [3]` returns its first argument unchanged.
    #[test]
    fn test_const() {
        let ctxt = get_test_vector_only_context();
        let mut state = InterpreterState::new(&ctxt);

        let constant = ConstImpl {
            ret_type : TEST_VECTOR_T,
            ignored_type : TEST_SCALAR_T
        };
        let operands = vec![
            term_ref(array![1.0f32, 2.0f32]),
            term_ref(array![3.0f32]),
        ];

        let (kept, _) = constant.evaluate(&mut state, operands);
        assert_equal_vector_term(kept, array![1.0f32, 2.0f32].view());
    }

}
635}