1use std::sync::Arc;
2
3use super::PredReturn;
4use crate::{Config, heap::{
5 heap::{Cell, Heap, Tag},
6 query_heap::QueryHeap,
7 symbol_db::known_symbol_id,
8}, program::predicate_table::PredicateTable};
9use crate::program::hypothesis::Hypothesis;
10
11use fsize::fsize;
12
13type MathFn = fn(usize, &QueryHeap) -> Number;
14
15const MINUS_SYMBOL: usize = known_symbol_id(3);
17
18const FUNCTIONS: [(usize, MathFn); 17] = [
24 (known_symbol_id(2), add), (known_symbol_id(3), sub), (known_symbol_id(4), mul), (known_symbol_id(5), div), (known_symbol_id(6), pow), (known_symbol_id(7), cos), (known_symbol_id(8), sin), (known_symbol_id(9), tan), (known_symbol_id(10), acos), (known_symbol_id(11), asin), (known_symbol_id(12), atan), (known_symbol_id(13), log), (known_symbol_id(14), abs), (known_symbol_id(15), round), (known_symbol_id(16), sqrt), (known_symbol_id(17), to_degrees), (known_symbol_id(18), to_radians), ];
42
43#[derive(Debug, Clone, Copy)]
44enum Number {
45 Flt(fsize),
46 Int(isize),
47}
48
49impl Number {
50 fn float(&self) -> fsize {
51 match self {
52 Number::Flt(v) => *v,
53 Number::Int(v) => *v as fsize,
54 }
55 }
56
57 fn to_cell(&self) -> Cell {
58 match self {
59 Number::Flt(value) => (Tag::Flt, f64::to_bits(*value) as usize ),
60 Number::Int(value) => (Tag::Int, isize::cast_unsigned(*value) ),
61 }
62 }
63
64 pub fn power(self, rhs: Self) -> Number {
65 match (self, rhs) {
66 (Number::Int(v1), Number::Int(v2)) if v2 > 0 => {
67 Number::Int(v1.pow(v2.try_into().unwrap()))
68 }
69 (lhs, rhs) => Number::Flt(lhs.float().powf(rhs.float())),
70 }
71 }
72
73 pub fn abs(self) -> Number {
74 match self {
75 Number::Flt(value) => Number::Flt(value.abs()),
76 Number::Int(value) => Number::Int(value.abs()),
77 }
78 }
79
80 pub fn round(self) -> Number {
81 match self {
82 Number::Flt(value) => Number::Int(value.round() as isize),
83 Number::Int(value) => Number::Int(value),
84 }
85 }
86}
87
88impl std::ops::Add for Number {
89 type Output = Number;
90 fn add(self, rhs: Self) -> Self::Output {
91 match (self, rhs) {
92 (Number::Int(v1), Number::Int(v2)) => {
93 match v1.checked_add(v2) {
94 Some(result) => Number::Int(result),
95 None => Number::Flt(v1 as f64 + v2 as f64),
96 }
97 }
98 (lhs, rhs) => Number::Flt(lhs.float() + rhs.float()),
99 }
100 }
101}
102
103impl std::ops::Sub for Number {
104 type Output = Number;
105 fn sub(self, rhs: Self) -> Self::Output {
106 match (self, rhs) {
107 (Number::Int(v1), Number::Int(v2)) => {
108 match v1.checked_sub(v2) {
109 Some(result) => Number::Int(result),
110 None => Number::Flt(v1 as f64 - v2 as f64),
111 }
112 }
113 (lhs, rhs) => Number::Flt(lhs.float() - rhs.float()),
114 }
115 }
116}
117
118impl std::ops::Mul for Number {
119 type Output = Number;
120 fn mul(self, rhs: Self) -> Self::Output {
121 match (self, rhs) {
122 (Number::Int(v1), Number::Int(v2)) => {
123 match v1.checked_mul(v2) {
125 Some(result) => Number::Int(result),
126 None => Number::Flt(v1 as f64 * v2 as f64), }
128 }
129 (lhs, rhs) => Number::Flt(lhs.float() * rhs.float()),
130 }
131 }
132}
133
134impl std::ops::Div for Number {
135 type Output = Number;
136 fn div(self, rhs: Self) -> Self::Output {
137 match (self, rhs) {
138 (Number::Int(v1), Number::Int(v2)) => {
139 if v2 == 0 {
140 Number::Flt(f64::NAN) } else {
142 Number::Int(v1 / v2)
143 }
144 }
145 (lhs, rhs) => Number::Flt(lhs.float() / rhs.float()),
146 }
147 }
148}
149
150impl PartialEq for Number {
151 fn eq(&self, other: &Self) -> bool {
152 match (self, other) {
153 (Number::Int(v1), Number::Int(v2)) => v1 == v2,
154 (lhs, rhs) => lhs.float() == rhs.float(),
155 }
156 }
157}
158
159impl PartialOrd for Number {
160 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
161 match (self, other) {
162 (Number::Int(v1), Number::Int(v2)) => Some(v1.cmp(v2)),
163 _ => self.float().partial_cmp(&other.float()),
164 }
165 }
166}
167
168fn add(addr: usize, heap: &QueryHeap) -> Number {
170 evaluate_term(addr + 2, heap) + evaluate_term(addr + 3, heap)
171}
172
173fn sub(addr: usize, heap: &QueryHeap) -> Number {
174 evaluate_term(addr + 2, heap) - evaluate_term(addr + 3, heap)
175}
176
177fn mul(addr: usize, heap: &QueryHeap) -> Number {
178 evaluate_term(addr + 2, heap) * evaluate_term(addr + 3, heap)
179}
180
181fn div(addr: usize, heap: &QueryHeap) -> Number {
182 evaluate_term(addr + 2, heap) / evaluate_term(addr + 3, heap)
183}
184
185fn pow(addr: usize, heap: &QueryHeap) -> Number {
186 evaluate_term(addr + 2, heap).power(evaluate_term(addr + 3, heap))
187}
188
189fn cos(addr: usize, heap: &QueryHeap) -> Number {
190 Number::Flt(evaluate_term(addr + 2, heap).float().cos())
191}
192
193fn sin(addr: usize, heap: &QueryHeap) -> Number {
194 Number::Flt(evaluate_term(addr + 2, heap).float().sin())
195}
196
197fn tan(addr: usize, heap: &QueryHeap) -> Number {
198 Number::Flt(evaluate_term(addr + 2, heap).float().tan())
199}
200
201fn acos(addr: usize, heap: &QueryHeap) -> Number {
202 Number::Flt(evaluate_term(addr + 2, heap).float().acos())
203}
204
205fn asin(addr: usize, heap: &QueryHeap) -> Number {
206 Number::Flt(evaluate_term(addr + 2, heap).float().asin())
207}
208
209fn atan(addr: usize, heap: &QueryHeap) -> Number {
210 Number::Flt(evaluate_term(addr + 2, heap).float().atan())
211}
212
213fn log(addr: usize, heap: &QueryHeap) -> Number {
214 Number::Flt(
215 evaluate_term(addr + 2, heap)
216 .float()
217 .log(evaluate_term(addr + 3, heap).float()),
218 )
219}
220
221fn abs(addr: usize, heap: &QueryHeap) -> Number {
222 evaluate_term(addr + 2, heap).abs()
223}
224
225fn round(addr: usize, heap: &QueryHeap) -> Number {
226 evaluate_term(addr + 2, heap).round()
227}
228
229fn to_radians(addr: usize, heap: &QueryHeap) -> Number {
230 Number::Flt(evaluate_term(addr + 2, heap).float().to_radians())
231}
232
233fn to_degrees(addr: usize, heap: &QueryHeap) -> Number {
234 Number::Flt(evaluate_term(addr + 2, heap).float().to_degrees())
235}
236
237fn neg(addr: usize, heap: &QueryHeap) -> Number {
238 match evaluate_term(addr + 2, heap) {
239 Number::Int(v) => Number::Int(-v),
240 Number::Flt(v) => Number::Flt(-v),
241 }
242}
243
244fn sqrt(addr: usize, heap: &QueryHeap) -> Number {
245 Number::Flt(evaluate_term(addr + 2, heap).float().sqrt())
246}
247
248fn evaluate_str(addr: usize, heap: &QueryHeap) -> Number {
249 let symbol = heap[addr + 1].1;
250 let arity = heap[addr].1;
251
252 if symbol == MINUS_SYMBOL && arity == 2 {
254 return neg(addr, heap);
255 }
256
257 for (id, funct) in FUNCTIONS.iter() {
258 if *id == symbol {
259 return funct(addr, heap);
260 }
261 }
262 panic!("Unknown function {}", heap.term_string(addr));
263}
264
265fn evaluate_term(addr: usize, heap: &QueryHeap) -> Number {
266 let addr = heap.deref_addr(addr);
267 match heap[addr] {
268 (Tag::Func, _) => evaluate_str(addr, heap),
269 (Tag::Str, ptr) => evaluate_str(ptr, heap),
270 (Tag::Int, value) => Number::Int(usize::cast_signed(value)),
271 (Tag::Flt, value) => {
272 #[cfg(target_pointer_width = "32")]
273 let float_value = fsize::from_bits(value as u32);
274
275 #[cfg(target_pointer_width = "64")]
276 let float_value = fsize::from_bits(value as u64);
277
278 Number::Flt(float_value)
279 },
280 _ => panic!(
281 "{:?} : {} not a valid mathematical expression",
282 heap[addr],
283 heap.term_string(addr),
284 ),
285 }
286}
287
288pub fn is_pred(heap: &mut QueryHeap, _hypothesis: &mut Hypothesis, goal: usize, _pred_table: Arc<PredicateTable>, _config: Config) -> PredReturn {
290 let goal_addr = heap.deref_addr(goal);
292 let func_addr = match heap[goal_addr] {
293 (Tag::Str, ptr) => ptr,
294 (Tag::Func, _) => goal_addr,
295 _ => panic!("is/2: expected structure, got {:?}", heap[goal_addr]),
296 };
297
298 let rhs = evaluate_term(func_addr + 3, heap);
299 let lhs_addr = heap.deref_addr(func_addr + 2);
300
301 match heap[lhs_addr] {
302 (Tag::Ref, _) => {
303 let result_addr = heap.heap_push(rhs.to_cell());
305 PredReturn::Binding(vec![(lhs_addr, result_addr)])
306 }
307 _ => {
308 let lhs = evaluate_term(lhs_addr, heap);
310 PredReturn::bool(lhs == rhs)
311 }
312 }
313}