Skip to main content

prolog2/predicate_modules/
maths.rs

1use super::PredReturn;
2use crate::program::hypothesis::Hypothesis;
3use crate::{
4    heap::{
5        heap::{Cell, Heap, Tag},
6        query_heap::QueryHeap,
7        symbol_db::known_symbol_id,
8    },
9    program::predicate_table::PredicateTable,
10    Config,
11};
12
13use fsize::fsize;
14
15type MathFn = fn(usize, &QueryHeap) -> Number;
16
17// Minus symbol ID for distinguishing unary negation from binary subtraction
18const MINUS_SYMBOL: usize = known_symbol_id(3);
19
20// Math functions array using compile-time known symbol IDs
21// Indices match KNOWN_SYMBOLS in symbol_db.rs:
22// 0: false, 1: true, 2: +, 3: -, 4: *, 5: /, 6: **
23// 7: cos, 8: sin, 9: tan, 10: acos, 11: asin, 12: atan
24// 13: log, 14: abs, 15: round, 16: sqrt, 17: to_degrees, 18: to_radians
25const FUNCTIONS: [(usize, MathFn); 17] = [
26    (known_symbol_id(2), add),         // +
27    (known_symbol_id(3), sub),         // -
28    (known_symbol_id(4), mul),         // *
29    (known_symbol_id(5), div),         // /
30    (known_symbol_id(6), pow),         // **
31    (known_symbol_id(7), cos),         // cos
32    (known_symbol_id(8), sin),         // sin
33    (known_symbol_id(9), tan),         // tan
34    (known_symbol_id(10), acos),       // acos
35    (known_symbol_id(11), asin),       // asin
36    (known_symbol_id(12), atan),       // atan
37    (known_symbol_id(13), log),        // log
38    (known_symbol_id(14), abs),        // abs
39    (known_symbol_id(15), round),      // round
40    (known_symbol_id(16), sqrt),       // sqrt
41    (known_symbol_id(17), to_degrees), // to_degrees
42    (known_symbol_id(18), to_radians), // to_radians
43];
44
45#[derive(Debug, Clone, Copy)]
46enum Number {
47    Flt(fsize),
48    Int(isize),
49}
50
51impl Number {
52    fn float(&self) -> fsize {
53        match self {
54            Number::Flt(v) => *v,
55            Number::Int(v) => *v as fsize,
56        }
57    }
58
59    fn to_cell(&self) -> Cell {
60        match self {
61            Number::Flt(value) => (Tag::Flt, f64::to_bits(*value) as usize),
62            Number::Int(value) => (Tag::Int, isize::cast_unsigned(*value)),
63        }
64    }
65
66    pub fn power(self, rhs: Self) -> Number {
67        match (self, rhs) {
68            (Number::Int(v1), Number::Int(v2)) if v2 > 0 => {
69                Number::Int(v1.pow(v2.try_into().unwrap()))
70            }
71            (lhs, rhs) => Number::Flt(lhs.float().powf(rhs.float())),
72        }
73    }
74
75    pub fn abs(self) -> Number {
76        match self {
77            Number::Flt(value) => Number::Flt(value.abs()),
78            Number::Int(value) => Number::Int(value.abs()),
79        }
80    }
81
82    pub fn round(self) -> Number {
83        match self {
84            Number::Flt(value) => Number::Int(value.round() as isize),
85            Number::Int(value) => Number::Int(value),
86        }
87    }
88}
89
90impl std::ops::Add for Number {
91    type Output = Number;
92    fn add(self, rhs: Self) -> Self::Output {
93        match (self, rhs) {
94            (Number::Int(v1), Number::Int(v2)) => match v1.checked_add(v2) {
95                Some(result) => Number::Int(result),
96                None => Number::Flt(v1 as f64 + v2 as f64),
97            },
98            (lhs, rhs) => Number::Flt(lhs.float() + rhs.float()),
99        }
100    }
101}
102
103impl std::ops::Sub for Number {
104    type Output = Number;
105    fn sub(self, rhs: Self) -> Self::Output {
106        match (self, rhs) {
107            (Number::Int(v1), Number::Int(v2)) => match v1.checked_sub(v2) {
108                Some(result) => Number::Int(result),
109                None => Number::Flt(v1 as f64 - v2 as f64),
110            },
111            (lhs, rhs) => Number::Flt(lhs.float() - rhs.float()),
112        }
113    }
114}
115
116impl std::ops::Mul for Number {
117    type Output = Number;
118    fn mul(self, rhs: Self) -> Self::Output {
119        match (self, rhs) {
120            (Number::Int(v1), Number::Int(v2)) => {
121                // Use checked multiplication to avoid overflow panic
122                match v1.checked_mul(v2) {
123                    Some(result) => Number::Int(result),
124                    None => Number::Flt(v1 as f64 * v2 as f64), // Fallback to float on overflow
125                }
126            }
127            (lhs, rhs) => Number::Flt(lhs.float() * rhs.float()),
128        }
129    }
130}
131
132impl std::ops::Div for Number {
133    type Output = Number;
134    fn div(self, rhs: Self) -> Self::Output {
135        match (self, rhs) {
136            (Number::Int(v1), Number::Int(v2)) => {
137                if v2 == 0 {
138                    Number::Flt(f64::NAN) // Return NaN for division by zero
139                } else {
140                    Number::Int(v1 / v2)
141                }
142            }
143            (lhs, rhs) => Number::Flt(lhs.float() / rhs.float()),
144        }
145    }
146}
147
148impl PartialEq for Number {
149    fn eq(&self, other: &Self) -> bool {
150        match (self, other) {
151            (Number::Int(v1), Number::Int(v2)) => v1 == v2,
152            (lhs, rhs) => lhs.float() == rhs.float(),
153        }
154    }
155}
156
157impl PartialOrd for Number {
158    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
159        match (self, other) {
160            (Number::Int(v1), Number::Int(v2)) => Some(v1.cmp(v2)),
161            _ => self.float().partial_cmp(&other.float()),
162        }
163    }
164}
165
166// Math operation functions
167fn add(addr: usize, heap: &QueryHeap) -> Number {
168    evaluate_term(addr + 2, heap) + evaluate_term(addr + 3, heap)
169}
170
171fn sub(addr: usize, heap: &QueryHeap) -> Number {
172    evaluate_term(addr + 2, heap) - evaluate_term(addr + 3, heap)
173}
174
175fn mul(addr: usize, heap: &QueryHeap) -> Number {
176    evaluate_term(addr + 2, heap) * evaluate_term(addr + 3, heap)
177}
178
179fn div(addr: usize, heap: &QueryHeap) -> Number {
180    evaluate_term(addr + 2, heap) / evaluate_term(addr + 3, heap)
181}
182
183fn pow(addr: usize, heap: &QueryHeap) -> Number {
184    evaluate_term(addr + 2, heap).power(evaluate_term(addr + 3, heap))
185}
186
187fn cos(addr: usize, heap: &QueryHeap) -> Number {
188    Number::Flt(evaluate_term(addr + 2, heap).float().cos())
189}
190
191fn sin(addr: usize, heap: &QueryHeap) -> Number {
192    Number::Flt(evaluate_term(addr + 2, heap).float().sin())
193}
194
195fn tan(addr: usize, heap: &QueryHeap) -> Number {
196    Number::Flt(evaluate_term(addr + 2, heap).float().tan())
197}
198
199fn acos(addr: usize, heap: &QueryHeap) -> Number {
200    Number::Flt(evaluate_term(addr + 2, heap).float().acos())
201}
202
203fn asin(addr: usize, heap: &QueryHeap) -> Number {
204    Number::Flt(evaluate_term(addr + 2, heap).float().asin())
205}
206
207fn atan(addr: usize, heap: &QueryHeap) -> Number {
208    Number::Flt(evaluate_term(addr + 2, heap).float().atan())
209}
210
211fn log(addr: usize, heap: &QueryHeap) -> Number {
212    Number::Flt(
213        evaluate_term(addr + 2, heap)
214            .float()
215            .log(evaluate_term(addr + 3, heap).float()),
216    )
217}
218
219fn abs(addr: usize, heap: &QueryHeap) -> Number {
220    evaluate_term(addr + 2, heap).abs()
221}
222
223fn round(addr: usize, heap: &QueryHeap) -> Number {
224    evaluate_term(addr + 2, heap).round()
225}
226
227fn to_radians(addr: usize, heap: &QueryHeap) -> Number {
228    Number::Flt(evaluate_term(addr + 2, heap).float().to_radians())
229}
230
231fn to_degrees(addr: usize, heap: &QueryHeap) -> Number {
232    Number::Flt(evaluate_term(addr + 2, heap).float().to_degrees())
233}
234
235fn neg(addr: usize, heap: &QueryHeap) -> Number {
236    match evaluate_term(addr + 2, heap) {
237        Number::Int(v) => Number::Int(-v),
238        Number::Flt(v) => Number::Flt(-v),
239    }
240}
241
242fn sqrt(addr: usize, heap: &QueryHeap) -> Number {
243    Number::Flt(evaluate_term(addr + 2, heap).float().sqrt())
244}
245
246fn evaluate_str(addr: usize, heap: &QueryHeap) -> Number {
247    let symbol = heap[addr + 1].1;
248    let arity = heap[addr].1;
249
250    // Handle unary minus: -(X) has arity 2 (functor + 1 arg)
251    if symbol == MINUS_SYMBOL && arity == 2 {
252        return neg(addr, heap);
253    }
254
255    for (id, funct) in FUNCTIONS.iter() {
256        if *id == symbol {
257            return funct(addr, heap);
258        }
259    }
260    panic!("Unknown function {}", heap.term_string(addr));
261}
262
263fn evaluate_term(addr: usize, heap: &QueryHeap) -> Number {
264    let addr = heap.deref_addr(addr);
265    match heap[addr] {
266        (Tag::Func, _) => evaluate_str(addr, heap),
267        (Tag::Str, ptr) => evaluate_str(ptr, heap),
268        (Tag::Int, value) => Number::Int(usize::cast_signed(value)),
269        (Tag::Flt, value) => {
270            #[cfg(target_pointer_width = "32")]
271            let float_value = fsize::from_bits(value as u32);
272
273            #[cfg(target_pointer_width = "64")]
274            let float_value = fsize::from_bits(value as u64);
275
276            Number::Flt(float_value)
277        }
278        _ => panic!(
279            "{:?} : {} not a valid mathematical expression",
280            heap[addr],
281            heap.term_string(addr),
282        ),
283    }
284}
285
286/// is/2 predicate: evaluates RHS and unifies with LHS
287pub fn is_pred(
288    heap: &mut QueryHeap,
289    _hypothesis: &mut Hypothesis,
290    goal: usize,
291    _pred_table: &PredicateTable,
292    _config: Config,
293) -> PredReturn {
294    // Goal structure: Func(3) | Con("is") | LHS | RHS
295    let goal_addr = heap.deref_addr(goal);
296    let func_addr = match heap[goal_addr] {
297        (Tag::Str, ptr) => ptr,
298        (Tag::Func, _) => goal_addr,
299        _ => panic!("is/2: expected structure, got {:?}", heap[goal_addr]),
300    };
301
302    let rhs = evaluate_term(func_addr + 3, heap);
303    let lhs_addr = heap.deref_addr(func_addr + 2);
304
305    match heap[lhs_addr] {
306        (Tag::Ref, _) => {
307            // LHS is unbound - create binding
308            let result_addr = heap.heap_push(rhs.to_cell());
309            PredReturn::Binding(vec![(lhs_addr, result_addr)])
310        }
311        _ => {
312            // LHS is bound - check equality
313            let lhs = evaluate_term(lhs_addr, heap);
314            PredReturn::bool(lhs == rhs)
315        }
316    }
317}