1use super::PredReturn;
2use crate::program::hypothesis::Hypothesis;
3use crate::{
4 heap::{
5 heap::{Cell, Heap, Tag},
6 query_heap::QueryHeap,
7 symbol_db::known_symbol_id,
8 },
9 program::predicate_table::PredicateTable,
10 Config,
11};
12
13use fsize::fsize;
14
15type MathFn = fn(usize, &QueryHeap) -> Number;
16
17const MINUS_SYMBOL: usize = known_symbol_id(3);
19
20const FUNCTIONS: [(usize, MathFn); 17] = [
26 (known_symbol_id(2), add), (known_symbol_id(3), sub), (known_symbol_id(4), mul), (known_symbol_id(5), div), (known_symbol_id(6), pow), (known_symbol_id(7), cos), (known_symbol_id(8), sin), (known_symbol_id(9), tan), (known_symbol_id(10), acos), (known_symbol_id(11), asin), (known_symbol_id(12), atan), (known_symbol_id(13), log), (known_symbol_id(14), abs), (known_symbol_id(15), round), (known_symbol_id(16), sqrt), (known_symbol_id(17), to_degrees), (known_symbol_id(18), to_radians), ];
44
45#[derive(Debug, Clone, Copy)]
46enum Number {
47 Flt(fsize),
48 Int(isize),
49}
50
51impl Number {
52 fn float(&self) -> fsize {
53 match self {
54 Number::Flt(v) => *v,
55 Number::Int(v) => *v as fsize,
56 }
57 }
58
59 fn to_cell(&self) -> Cell {
60 match self {
61 Number::Flt(value) => (Tag::Flt, f64::to_bits(*value) as usize),
62 Number::Int(value) => (Tag::Int, isize::cast_unsigned(*value)),
63 }
64 }
65
66 pub fn power(self, rhs: Self) -> Number {
67 match (self, rhs) {
68 (Number::Int(v1), Number::Int(v2)) if v2 > 0 => {
69 Number::Int(v1.pow(v2.try_into().unwrap()))
70 }
71 (lhs, rhs) => Number::Flt(lhs.float().powf(rhs.float())),
72 }
73 }
74
75 pub fn abs(self) -> Number {
76 match self {
77 Number::Flt(value) => Number::Flt(value.abs()),
78 Number::Int(value) => Number::Int(value.abs()),
79 }
80 }
81
82 pub fn round(self) -> Number {
83 match self {
84 Number::Flt(value) => Number::Int(value.round() as isize),
85 Number::Int(value) => Number::Int(value),
86 }
87 }
88}
89
90impl std::ops::Add for Number {
91 type Output = Number;
92 fn add(self, rhs: Self) -> Self::Output {
93 match (self, rhs) {
94 (Number::Int(v1), Number::Int(v2)) => match v1.checked_add(v2) {
95 Some(result) => Number::Int(result),
96 None => Number::Flt(v1 as f64 + v2 as f64),
97 },
98 (lhs, rhs) => Number::Flt(lhs.float() + rhs.float()),
99 }
100 }
101}
102
103impl std::ops::Sub for Number {
104 type Output = Number;
105 fn sub(self, rhs: Self) -> Self::Output {
106 match (self, rhs) {
107 (Number::Int(v1), Number::Int(v2)) => match v1.checked_sub(v2) {
108 Some(result) => Number::Int(result),
109 None => Number::Flt(v1 as f64 - v2 as f64),
110 },
111 (lhs, rhs) => Number::Flt(lhs.float() - rhs.float()),
112 }
113 }
114}
115
116impl std::ops::Mul for Number {
117 type Output = Number;
118 fn mul(self, rhs: Self) -> Self::Output {
119 match (self, rhs) {
120 (Number::Int(v1), Number::Int(v2)) => {
121 match v1.checked_mul(v2) {
123 Some(result) => Number::Int(result),
124 None => Number::Flt(v1 as f64 * v2 as f64), }
126 }
127 (lhs, rhs) => Number::Flt(lhs.float() * rhs.float()),
128 }
129 }
130}
131
132impl std::ops::Div for Number {
133 type Output = Number;
134 fn div(self, rhs: Self) -> Self::Output {
135 match (self, rhs) {
136 (Number::Int(v1), Number::Int(v2)) => {
137 if v2 == 0 {
138 Number::Flt(f64::NAN) } else {
140 Number::Int(v1 / v2)
141 }
142 }
143 (lhs, rhs) => Number::Flt(lhs.float() / rhs.float()),
144 }
145 }
146}
147
148impl PartialEq for Number {
149 fn eq(&self, other: &Self) -> bool {
150 match (self, other) {
151 (Number::Int(v1), Number::Int(v2)) => v1 == v2,
152 (lhs, rhs) => lhs.float() == rhs.float(),
153 }
154 }
155}
156
157impl PartialOrd for Number {
158 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
159 match (self, other) {
160 (Number::Int(v1), Number::Int(v2)) => Some(v1.cmp(v2)),
161 _ => self.float().partial_cmp(&other.float()),
162 }
163 }
164}
165
166fn add(addr: usize, heap: &QueryHeap) -> Number {
168 evaluate_term(addr + 2, heap) + evaluate_term(addr + 3, heap)
169}
170
171fn sub(addr: usize, heap: &QueryHeap) -> Number {
172 evaluate_term(addr + 2, heap) - evaluate_term(addr + 3, heap)
173}
174
175fn mul(addr: usize, heap: &QueryHeap) -> Number {
176 evaluate_term(addr + 2, heap) * evaluate_term(addr + 3, heap)
177}
178
179fn div(addr: usize, heap: &QueryHeap) -> Number {
180 evaluate_term(addr + 2, heap) / evaluate_term(addr + 3, heap)
181}
182
183fn pow(addr: usize, heap: &QueryHeap) -> Number {
184 evaluate_term(addr + 2, heap).power(evaluate_term(addr + 3, heap))
185}
186
187fn cos(addr: usize, heap: &QueryHeap) -> Number {
188 Number::Flt(evaluate_term(addr + 2, heap).float().cos())
189}
190
191fn sin(addr: usize, heap: &QueryHeap) -> Number {
192 Number::Flt(evaluate_term(addr + 2, heap).float().sin())
193}
194
195fn tan(addr: usize, heap: &QueryHeap) -> Number {
196 Number::Flt(evaluate_term(addr + 2, heap).float().tan())
197}
198
199fn acos(addr: usize, heap: &QueryHeap) -> Number {
200 Number::Flt(evaluate_term(addr + 2, heap).float().acos())
201}
202
203fn asin(addr: usize, heap: &QueryHeap) -> Number {
204 Number::Flt(evaluate_term(addr + 2, heap).float().asin())
205}
206
207fn atan(addr: usize, heap: &QueryHeap) -> Number {
208 Number::Flt(evaluate_term(addr + 2, heap).float().atan())
209}
210
211fn log(addr: usize, heap: &QueryHeap) -> Number {
212 Number::Flt(
213 evaluate_term(addr + 2, heap)
214 .float()
215 .log(evaluate_term(addr + 3, heap).float()),
216 )
217}
218
219fn abs(addr: usize, heap: &QueryHeap) -> Number {
220 evaluate_term(addr + 2, heap).abs()
221}
222
223fn round(addr: usize, heap: &QueryHeap) -> Number {
224 evaluate_term(addr + 2, heap).round()
225}
226
227fn to_radians(addr: usize, heap: &QueryHeap) -> Number {
228 Number::Flt(evaluate_term(addr + 2, heap).float().to_radians())
229}
230
231fn to_degrees(addr: usize, heap: &QueryHeap) -> Number {
232 Number::Flt(evaluate_term(addr + 2, heap).float().to_degrees())
233}
234
235fn neg(addr: usize, heap: &QueryHeap) -> Number {
236 match evaluate_term(addr + 2, heap) {
237 Number::Int(v) => Number::Int(-v),
238 Number::Flt(v) => Number::Flt(-v),
239 }
240}
241
242fn sqrt(addr: usize, heap: &QueryHeap) -> Number {
243 Number::Flt(evaluate_term(addr + 2, heap).float().sqrt())
244}
245
246fn evaluate_str(addr: usize, heap: &QueryHeap) -> Number {
247 let symbol = heap[addr + 1].1;
248 let arity = heap[addr].1;
249
250 if symbol == MINUS_SYMBOL && arity == 2 {
252 return neg(addr, heap);
253 }
254
255 for (id, funct) in FUNCTIONS.iter() {
256 if *id == symbol {
257 return funct(addr, heap);
258 }
259 }
260 panic!("Unknown function {}", heap.term_string(addr));
261}
262
263fn evaluate_term(addr: usize, heap: &QueryHeap) -> Number {
264 let addr = heap.deref_addr(addr);
265 match heap[addr] {
266 (Tag::Func, _) => evaluate_str(addr, heap),
267 (Tag::Str, ptr) => evaluate_str(ptr, heap),
268 (Tag::Int, value) => Number::Int(usize::cast_signed(value)),
269 (Tag::Flt, value) => {
270 #[cfg(target_pointer_width = "32")]
271 let float_value = fsize::from_bits(value as u32);
272
273 #[cfg(target_pointer_width = "64")]
274 let float_value = fsize::from_bits(value as u64);
275
276 Number::Flt(float_value)
277 }
278 _ => panic!(
279 "{:?} : {} not a valid mathematical expression",
280 heap[addr],
281 heap.term_string(addr),
282 ),
283 }
284}
285
286pub fn is_pred(
288 heap: &mut QueryHeap,
289 _hypothesis: &mut Hypothesis,
290 goal: usize,
291 _pred_table: &PredicateTable,
292 _config: Config,
293) -> PredReturn {
294 let goal_addr = heap.deref_addr(goal);
296 let func_addr = match heap[goal_addr] {
297 (Tag::Str, ptr) => ptr,
298 (Tag::Func, _) => goal_addr,
299 _ => panic!("is/2: expected structure, got {:?}", heap[goal_addr]),
300 };
301
302 let rhs = evaluate_term(func_addr + 3, heap);
303 let lhs_addr = heap.deref_addr(func_addr + 2);
304
305 match heap[lhs_addr] {
306 (Tag::Ref, _) => {
307 let result_addr = heap.heap_push(rhs.to_cell());
309 PredReturn::Binding(vec![(lhs_addr, result_addr)])
310 }
311 _ => {
312 let lhs = evaluate_term(lhs_addr, heap);
314 PredReturn::bool(lhs == rhs)
315 }
316 }
317}