1extern crate ndarray;
2extern crate ndarray_linalg;
3
4use crate::type_id::*;
5use crate::interpreter_state::*;
6use crate::term_reference::*;
7use crate::term_application::*;
8
9use ndarray::*;
10use noisy_float::prelude::*;
11
12use std::cmp::*;
13use std::fmt::*;
14use std::hash::*;
15use crate::params::*;
16use crate::newly_evaluated_terms::*;
17
18pub trait HasFuncSignature {
24 fn get_name(&self) -> String;
26 fn ret_type(&self) -> TypeId;
28 fn required_arg_types(&self) -> Vec::<TypeId>;
30
31 fn ready_to_evaluate(&self, args : &Vec::<TermReference>) -> bool {
34 let expected_num : usize = self.required_arg_types().len();
35 expected_num == args.len()
36 }
37
38 fn func_type(&self, type_info_directory : &TypeInfoDirectory) -> TypeId {
41 let mut reverse_arg_types : Vec<TypeId> = self.required_arg_types();
42 reverse_arg_types.reverse();
43
44 let mut result : TypeId = self.ret_type();
45 for arg_type_id in reverse_arg_types.drain(..) {
46 result = type_info_directory.get_func_type_id(arg_type_id, result);
47 }
48 result
49 }
50}
51
52pub trait FuncImpl : HasFuncSignature {
54 fn evaluate(&self, state : &mut InterpreterState, args : Vec::<TermReference>)
60 -> (TermReference, NewlyEvaluatedTerms);
61}
62
63impl PartialEq for dyn FuncImpl + '_ {
64 fn eq(&self, other : &Self) -> bool {
65 self.get_name() == other.get_name() &&
66 self.required_arg_types() == other.required_arg_types() &&
67 self.ret_type() == other.ret_type()
68 }
69}
70
71impl Eq for dyn FuncImpl + '_ {}
72
73impl Hash for dyn FuncImpl + '_ {
74 fn hash<H : Hasher>(&self, state : &mut H) {
75 self.get_name().hash(state);
76 self.required_arg_types().hash(state);
77 self.ret_type().hash(state);
78 }
79}
80
81pub trait BinaryArrayOperator {
84 fn act(&self, arg_one : ArrayView1::<R32>, arg_two : ArrayView1::<R32>) -> Array1::<R32>;
86 fn get_name(&self) -> String;
88}
89
90impl PartialEq for dyn BinaryArrayOperator + '_ {
91 fn eq(&self, other : &Self) -> bool {
92 self.get_name() == other.get_name()
93 }
94}
95
96impl Eq for dyn BinaryArrayOperator + '_ {}
97
98impl Hash for dyn BinaryArrayOperator + '_ {
99 fn hash<H : Hasher>(&self, state : &mut H) {
100 self.get_name().hash(state);
101 }
102}
103
104pub struct AddOperator {
106}
107
108impl BinaryArrayOperator for AddOperator {
109 fn act(&self, arg_one : ArrayView1::<R32>, arg_two : ArrayView1::<R32>) -> Array1::<R32> {
110 &arg_one + &arg_two
111 }
112 fn get_name(&self) -> String {
113 String::from("+")
114 }
115}
116
117pub struct SubOperator {
119}
120
121impl BinaryArrayOperator for SubOperator {
122 fn act(&self, arg_one : ArrayView1::<R32>, arg_two : ArrayView1::<R32>) -> Array1::<R32> {
123 &arg_one - &arg_two
124 }
125 fn get_name(&self) -> String {
126 String::from("-")
127 }
128}
129
130pub struct MulOperator {
132}
133
134impl BinaryArrayOperator for MulOperator {
135 fn act(&self, arg_one : ArrayView1::<R32>, arg_two : ArrayView1::<R32>) -> Array1::<R32> {
136 &arg_one * &arg_two
137 }
138 fn get_name(&self) -> String {
139 String::from("*")
140 }
141}
142
143pub struct BinaryFuncImpl {
146 pub elem_type : TypeId,
147 pub f : Box<dyn BinaryArrayOperator>
148}
149
150impl HasFuncSignature for BinaryFuncImpl {
151 fn get_name(&self) -> String {
152 self.f.get_name()
153 }
154 fn required_arg_types(&self) -> Vec<TypeId> {
155 vec![self.elem_type, self.elem_type]
156 }
157 fn ret_type(&self) -> TypeId {
158 self.elem_type
159 }
160}
161
162impl FuncImpl for BinaryFuncImpl {
163 fn evaluate(&self, _state : &mut InterpreterState, args : Vec::<TermReference>) -> (TermReference, NewlyEvaluatedTerms) {
164 if let TermReference::VecRef(_, arg_one_vec) = &args[0] {
165 if let TermReference::VecRef(_, arg_two_vec) = &args[1] {
166 let result_vec = self.f.act(arg_one_vec.view(), arg_two_vec.view());
167 let result_ref = TermReference::VecRef(self.elem_type, result_vec);
168 (result_ref, NewlyEvaluatedTerms::new())
169 } else {
170 panic!();
171 }
172 } else {
173 panic!();
174 }
175 }
176}
177
178#[derive(Clone)]
181pub struct RotateImpl {
182 pub vector_type : TypeId
183}
184
185impl HasFuncSignature for RotateImpl {
186 fn get_name(&self) -> String {
187 String::from("rotate")
188 }
189 fn required_arg_types(&self) -> Vec<TypeId> {
190 vec![self.vector_type]
191 }
192 fn ret_type(&self) -> TypeId {
193 self.vector_type
194 }
195}
196
197impl FuncImpl for RotateImpl {
198 fn evaluate(&self, _state : &mut InterpreterState, args : Vec<TermReference>) -> (TermReference, NewlyEvaluatedTerms) {
199 if let TermReference::VecRef(vector_type, arg_vec) = &args[0] {
200 let n = arg_vec.len();
201 let arg_vec_head : R32 = arg_vec[[0,]];
202 let mut result_vec : Array1::<R32> = Array::from_elem((n,), arg_vec_head);
203 for i in 1..n {
204 result_vec[[i-1,]] = arg_vec[[i,]];
205 }
206 let result : TermReference = TermReference::VecRef(*vector_type, result_vec);
207 (result, NewlyEvaluatedTerms::new())
208 } else {
209 panic!();
210 }
211 }
212}
213
214#[derive(Clone)]
217pub struct SetHeadImpl {
218 pub vector_type : TypeId,
219 pub scalar_type : TypeId
220}
221
222impl HasFuncSignature for SetHeadImpl {
223 fn get_name(&self) -> String {
224 String::from("setHead")
225 }
226 fn required_arg_types(&self) -> Vec<TypeId> {
227 vec![self.vector_type, self.scalar_type]
228 }
229 fn ret_type(&self) -> TypeId {
230 self.vector_type
231 }
232}
233impl FuncImpl for SetHeadImpl {
234 fn evaluate(&self, _state : &mut InterpreterState, args : Vec<TermReference>) -> (TermReference, NewlyEvaluatedTerms) {
235 if let TermReference::VecRef(vector_type, arg_vec) = &args[0] {
236 if let TermReference::VecRef(_, val_vec) = &args[1] {
237 let val : R32 = val_vec[[0,]];
238 let mut result_vec : Array1<R32> = arg_vec.clone();
239 result_vec[[0,]] = val;
240 let result = TermReference::VecRef(*vector_type, result_vec);
241 (result, NewlyEvaluatedTerms::new())
242 } else {
243 panic!();
244 }
245 } else {
246 panic!();
247 }
248 }
249}
250
251#[derive(Clone)]
254pub struct HeadImpl {
255 pub vector_type : TypeId,
256 pub scalar_type : TypeId
257}
258
259impl HasFuncSignature for HeadImpl {
260 fn get_name(&self) -> String {
261 String::from("head")
262 }
263 fn required_arg_types(&self) -> Vec<TypeId> {
264 vec![self.vector_type]
265 }
266 fn ret_type(&self) -> TypeId {
267 self.scalar_type
268 }
269}
270impl FuncImpl for HeadImpl {
271 fn evaluate(&self, _state : &mut InterpreterState, args : Vec<TermReference>) -> (TermReference, NewlyEvaluatedTerms) {
272 if let TermReference::VecRef(_, arg_vec) = &args[0] {
273 let ret_val : R32 = arg_vec[[0,]];
274 let result_array : Array1::<R32> = Array::from_elem((1,), ret_val);
275
276 let result = TermReference::VecRef(self.scalar_type, result_array);
277 (result, NewlyEvaluatedTerms::new())
278 } else {
279 panic!();
280 }
281 }
282}
283
284#[derive(Clone)]
286pub struct ComposeImpl {
287 in_type : TypeId,
288 middle_type : TypeId,
289 func_one : TypeId,
290 func_two : TypeId,
291 ret_type : TypeId
292}
293
294impl ComposeImpl {
295 pub fn new(type_info_directory : &TypeInfoDirectory,
298 in_type : TypeId, middle_type : TypeId, ret_type : TypeId) -> ComposeImpl {
299 let func_one : TypeId = type_info_directory.get_func_type_id(middle_type, ret_type);
300 let func_two : TypeId = type_info_directory.get_func_type_id(in_type, middle_type);
301 ComposeImpl {
302 in_type,
303 middle_type,
304 func_one,
305 func_two,
306 ret_type
307 }
308 }
309}
310
311impl HasFuncSignature for ComposeImpl {
312 fn get_name(&self) -> String {
313 String::from("compose")
314 }
315 fn required_arg_types(&self) -> Vec<TypeId> {
316 vec![self.func_one, self.func_two, self.in_type]
317 }
318 fn ret_type(&self) -> TypeId {
319 self.ret_type
320 }
321}
322
323impl FuncImpl for ComposeImpl {
324 fn evaluate(&self, state : &mut InterpreterState, args : Vec<TermReference>) -> (TermReference, NewlyEvaluatedTerms) {
325 if let TermReference::FuncRef(func_one) = &args[0] {
326 if let TermReference::FuncRef(func_two) = &args[1] {
327 let arg : TermReference = args[2].clone();
328 let application_one = TermApplication {
329 func_ptr : func_two.clone(),
330 arg_ref : arg
331 };
332 let (middle_ref, mut newly_evaluated_terms) = state.evaluate(&application_one);
333 let application_two = TermApplication {
334 func_ptr : func_one.clone(),
335 arg_ref : middle_ref
336 };
337 let (final_ref, more_evaluated_terms) = state.evaluate(&application_two);
338 newly_evaluated_terms.merge(more_evaluated_terms);
339 (final_ref, newly_evaluated_terms)
340 } else {
341 panic!();
342 }
343 } else {
344 panic!();
345 }
346 }
347}
348
349#[derive(Clone)]
352pub struct FillImpl {
353 pub scalar_type : TypeId,
354 pub vector_type : TypeId
355}
356
357impl HasFuncSignature for FillImpl {
358 fn get_name(&self) -> String {
359 String::from("fill")
360 }
361 fn required_arg_types(&self) -> Vec<TypeId> {
362 vec![self.scalar_type]
363 }
364 fn ret_type(&self) -> TypeId {
365 self.vector_type
366 }
367}
368impl FuncImpl for FillImpl {
369 fn evaluate(&self, state : &mut InterpreterState, args : Vec<TermReference>) -> (TermReference, NewlyEvaluatedTerms) {
370 let dim = state.get_context().get_dimension(self.vector_type);
371
372 if let TermReference::VecRef(_, arg_vec) = &args[0] {
373 let arg_val : R32 = arg_vec[[0,]];
374 let ret_val : Array1::<R32> = Array::from_elem((dim,), arg_val);
375
376 let result = TermReference::VecRef(self.vector_type, ret_val);
377 (result, NewlyEvaluatedTerms::new())
378 } else {
379 panic!();
380 }
381 }
382}
383
384#[derive(Clone)]
387pub struct ConstImpl {
388 pub ret_type : TypeId,
389 pub ignored_type : TypeId
390}
391
392impl HasFuncSignature for ConstImpl {
393 fn get_name(&self) -> String {
394 String::from("const")
395 }
396 fn required_arg_types(&self) -> Vec<TypeId> {
397 vec![self.ret_type.clone(), self.ignored_type.clone()]
398 }
399 fn ret_type(&self) -> TypeId {
400 self.ret_type.clone()
401 }
402}
403impl FuncImpl for ConstImpl {
404 fn evaluate(&self, _state : &mut InterpreterState, args : Vec::<TermReference>) -> (TermReference, NewlyEvaluatedTerms) {
405 let result_ptr : TermReference = args[0].clone();
406 (result_ptr, NewlyEvaluatedTerms::new())
407 }
408}
409
410#[derive(Clone)]
414pub struct ReduceImpl {
415 pub binary_scalar_func_type : TypeId,
416 pub scalar_type : TypeId,
417 pub vector_type : TypeId
418}
419
420impl HasFuncSignature for ReduceImpl {
421 fn get_name(&self) -> String {
422 String::from("reduce")
423 }
424 fn required_arg_types(&self) -> Vec<TypeId> {
425 vec![self.binary_scalar_func_type, self.scalar_type, self.vector_type]
426 }
427 fn ret_type(&self) -> TypeId {
428 self.scalar_type
429 }
430}
431
432impl FuncImpl for ReduceImpl {
433 fn evaluate(&self, state : &mut InterpreterState, args : Vec<TermReference>) -> (TermReference, NewlyEvaluatedTerms) {
434 let dim = state.get_context().get_dimension(self.vector_type);
435
436 let mut newly_evaluated_terms = NewlyEvaluatedTerms::new();
437 let mut accum_ref : TermReference = args[1].clone();
438 if let TermReference::FuncRef(func_ptr) = &args[0] {
439 if let TermReference::VecRef(_, vec) = &args[2] {
440 for i in 0..dim {
441 let val : R32 = vec[[i,]];
443 let val_vec : Array1::<R32> = Array::from_elem((1,), val);
444 let val_ref = TermReference::VecRef(self.scalar_type, val_vec);
445
446 let term_app_one = TermApplication {
447 func_ptr : func_ptr.clone(),
448 arg_ref : val_ref
449 };
450 let (curry_ref, more_evaluated_terms) = state.evaluate(&term_app_one);
451 newly_evaluated_terms.merge(more_evaluated_terms);
452
453 if let TermReference::FuncRef(curry_ptr) = curry_ref {
454 let term_app_two = TermApplication {
455 func_ptr : curry_ptr,
456 arg_ref : accum_ref
457 };
458 let (result_ref, more_evaluated_terms) = state.evaluate(&term_app_two);
459 newly_evaluated_terms.merge(more_evaluated_terms);
460 accum_ref = result_ref;
461 } else {
462 panic!();
463 }
464 }
465
466 (accum_ref, newly_evaluated_terms)
467 } else {
468 panic!();
469 }
470 } else {
471 panic!();
472 }
473 }
474}
475
476#[derive(Clone)]
479pub struct MapImpl {
480 pub scalar_type : TypeId,
481 pub unary_scalar_func_type : TypeId,
482 pub vector_type : TypeId
483}
484
485impl HasFuncSignature for MapImpl {
486 fn get_name(&self) -> String {
487 String::from("map")
488 }
489 fn required_arg_types(&self) -> Vec<TypeId> {
490 vec![self.unary_scalar_func_type, self.vector_type]
491 }
492 fn ret_type(&self) -> TypeId {
493 self.vector_type
494 }
495}
496
497impl FuncImpl for MapImpl {
498 fn evaluate(&self, state : &mut InterpreterState, args : Vec::<TermReference>) -> (TermReference, NewlyEvaluatedTerms) {
499 if let TermReference::FuncRef(func_ptr) = &args[0] {
500 if let TermReference::VecRef(_, arg_vec) = &args[1] {
501 let n = arg_vec.len();
502 let mut newly_evaluated_terms = NewlyEvaluatedTerms::new();
503 let mut result : Array1<R32> = Array::from_elem((n,), R32::new(0.0));
504 for i in 0..n {
505 let boxed_scalar : Array1<R32> = Array::from_elem((1,), arg_vec[i]);
506 let arg_ref = TermReference::VecRef(self.scalar_type, boxed_scalar);
507
508 let term_app = TermApplication {
509 func_ptr : func_ptr.clone(),
510 arg_ref : arg_ref
511 };
512 let (result_ref, more_evaluated_terms) = state.evaluate(&term_app);
513 newly_evaluated_terms.merge(more_evaluated_terms);
514 if let TermReference::VecRef(_, result_scalar_vec) = result_ref {
515 result[[i,]] = result_scalar_vec[[0,]];
516 }
517 }
518 let result_ref = TermReference::VecRef(self.vector_type, result);
519 (result_ref, newly_evaluated_terms)
520 } else {
521 panic!();
522 }
523 } else {
524 panic!();
525 }
526
527 }
528}
529
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::*;
    use crate::array_utils::*;

    /// Wraps `in_array` as a term. Length-1 arrays are tagged with the test
    /// scalar type; everything else gets the test vector type.
    fn term_ref(in_array: Array1<f32>) -> TermReference {
        let noisy_array = to_noisy(in_array.view());
        if in_array.len() == 1 {
            TermReference::VecRef(TEST_SCALAR_T, noisy_array)
        } else {
            TermReference::VecRef(TEST_VECTOR_T, noisy_array)
        }
    }

    #[test]
    fn test_addition() {
        let ctxt = get_test_vector_only_context();
        let mut state = InterpreterState::new(&ctxt);
        let args = vec![term_ref(array![1.0f32, 2.0f32]), term_ref(array![3.0f32, 4.0f32])];

        let addition_func = BinaryFuncImpl {
            elem_type: TEST_VECTOR_T,
            f: Box::new(AddOperator {}),
        };

        let (result, _) = addition_func.evaluate(&mut state, args);
        assert_equal_vector_term(result, array![4.0f32, 6.0f32].view());
    }

    #[test]
    fn test_subtraction() {
        let ctxt = get_test_vector_only_context();
        let mut state = InterpreterState::new(&ctxt);
        let args = vec![term_ref(array![1.0f32, 2.0f32]), term_ref(array![3.0f32, 4.0f32])];

        let subtraction_func = BinaryFuncImpl {
            elem_type: TEST_VECTOR_T,
            f: Box::new(SubOperator {}),
        };

        let (result, _) = subtraction_func.evaluate(&mut state, args);
        assert_equal_vector_term(result, array![-2.0f32, -2.0f32].view());
    }

    #[test]
    fn test_multiplication() {
        let ctxt = get_test_vector_only_context();
        let mut state = InterpreterState::new(&ctxt);
        let args = vec![term_ref(array![1.0f32, 2.0f32]), term_ref(array![3.0f32, 4.0f32])];

        let multiplication_func = BinaryFuncImpl {
            elem_type: TEST_VECTOR_T,
            f: Box::new(MulOperator {}),
        };

        let (result, _) = multiplication_func.evaluate(&mut state, args);
        assert_equal_vector_term(result, array![3.0f32, 8.0f32].view());
    }

    #[test]
    fn test_rotate() {
        let ctxt = get_test_vector_only_context();
        let mut state = InterpreterState::new(&ctxt);
        let args = vec![term_ref(array![5.0f32, 10.0f32])];

        let rotate_func = RotateImpl {
            vector_type: TEST_VECTOR_T,
        };

        let (result, _) = rotate_func.evaluate(&mut state, args);
        assert_equal_vector_term(result, array![10.0f32, 5.0f32].view());
    }

    /// With two elements, left and right rotation coincide; three elements
    /// pin down the direction (left by one).
    #[test]
    fn test_rotate_three_elements() {
        let ctxt = get_test_vector_only_context();
        let mut state = InterpreterState::new(&ctxt);
        let args = vec![term_ref(array![1.0f32, 2.0f32, 3.0f32])];

        let rotate_func = RotateImpl {
            vector_type: TEST_VECTOR_T,
        };

        let (result, _) = rotate_func.evaluate(&mut state, args);
        assert_equal_vector_term(result, array![2.0f32, 3.0f32, 1.0f32].view());
    }

    #[test]
    fn test_set_head() {
        let ctxt = get_test_vector_only_context();
        let mut state = InterpreterState::new(&ctxt);
        let args = vec![term_ref(array![1.0f32, 2.0f32]), term_ref(array![9.0f32])];

        let set_head_func = SetHeadImpl {
            vector_type: TEST_VECTOR_T,
            scalar_type: TEST_SCALAR_T,
        };

        let (result, _) = set_head_func.evaluate(&mut state, args);
        assert_equal_vector_term(result, array![9.0f32, 2.0f32].view());
    }

    #[test]
    fn test_head() {
        let ctxt = get_test_vector_only_context();
        let mut state = InterpreterState::new(&ctxt);

        let args = vec![term_ref(array![1.0f32, 2.0f32])];

        let head_func = HeadImpl {
            vector_type: TEST_VECTOR_T,
            scalar_type: TEST_SCALAR_T,
        };

        let (result, _) = head_func.evaluate(&mut state, args);
        assert_equal_vector_term(result, array![1.0f32].view());
    }

    #[test]
    fn test_fill() {
        let ctxt = get_test_vector_only_context();
        let mut state = InterpreterState::new(&ctxt);
        let args = vec![term_ref(array![3.0f32])];

        let fill_func = FillImpl {
            vector_type: TEST_VECTOR_T,
            scalar_type: TEST_SCALAR_T,
        };

        let (result, _) = fill_func.evaluate(&mut state, args);
        assert_equal_vector_term(result, array![3.0f32, 3.0f32].view());
    }

    #[test]
    fn test_const() {
        let ctxt = get_test_vector_only_context();
        let mut state = InterpreterState::new(&ctxt);

        let args = vec![term_ref(array![1.0f32, 2.0f32]), term_ref(array![3.0f32])];

        let const_func = ConstImpl {
            ret_type: TEST_VECTOR_T,
            ignored_type: TEST_SCALAR_T,
        };

        let (result, _) = const_func.evaluate(&mut state, args);
        assert_equal_vector_term(result, array![1.0f32, 2.0f32].view());
    }
}