Skip to main content

prolog2/predicate_modules/
maths.rs

1use std::sync::Arc;
2
3use super::PredReturn;
4use crate::{Config, heap::{
5    heap::{Cell, Heap, Tag},
6    query_heap::QueryHeap,
7    symbol_db::known_symbol_id,
8}, program::predicate_table::PredicateTable};
9use crate::program::hypothesis::Hypothesis;
10
11use fsize::fsize;
12
13type MathFn = fn(usize, &QueryHeap) -> Number;
14
15// Minus symbol ID for distinguishing unary negation from binary subtraction
16const MINUS_SYMBOL: usize = known_symbol_id(3);
17
18// Math functions array using compile-time known symbol IDs
19// Indices match KNOWN_SYMBOLS in symbol_db.rs:
20// 0: false, 1: true, 2: +, 3: -, 4: *, 5: /, 6: **
21// 7: cos, 8: sin, 9: tan, 10: acos, 11: asin, 12: atan
22// 13: log, 14: abs, 15: round, 16: sqrt, 17: to_degrees, 18: to_radians
23const FUNCTIONS: [(usize, MathFn); 17] = [
24    (known_symbol_id(2), add),    // +
25    (known_symbol_id(3), sub),    // -
26    (known_symbol_id(4), mul),    // *
27    (known_symbol_id(5), div),    // /
28    (known_symbol_id(6), pow),    // **
29    (known_symbol_id(7), cos),    // cos
30    (known_symbol_id(8), sin),    // sin
31    (known_symbol_id(9), tan),    // tan
32    (known_symbol_id(10), acos),  // acos
33    (known_symbol_id(11), asin),  // asin
34    (known_symbol_id(12), atan),  // atan
35    (known_symbol_id(13), log),   // log
36    (known_symbol_id(14), abs),   // abs
37    (known_symbol_id(15), round), // round
38    (known_symbol_id(16), sqrt),  // sqrt
39    (known_symbol_id(17), to_degrees),  // to_degrees
40    (known_symbol_id(18), to_radians),  // to_radians
41];
42
43#[derive(Debug, Clone, Copy)]
44enum Number {
45    Flt(fsize),
46    Int(isize),
47}
48
49impl Number {
50    fn float(&self) -> fsize {
51        match self {
52            Number::Flt(v) => *v,
53            Number::Int(v) => *v as fsize,
54        }
55    }
56
57    fn to_cell(&self) -> Cell {
58        match self {
59            Number::Flt(value) => (Tag::Flt, f64::to_bits(*value) as usize ),
60            Number::Int(value) => (Tag::Int, isize::cast_unsigned(*value) ),
61        }
62    }
63
64    pub fn power(self, rhs: Self) -> Number {
65        match (self, rhs) {
66            (Number::Int(v1), Number::Int(v2)) if v2 > 0 => {
67                Number::Int(v1.pow(v2.try_into().unwrap()))
68            }
69            (lhs, rhs) => Number::Flt(lhs.float().powf(rhs.float())),
70        }
71    }
72
73    pub fn abs(self) -> Number {
74        match self {
75            Number::Flt(value) => Number::Flt(value.abs()),
76            Number::Int(value) => Number::Int(value.abs()),
77        }
78    }
79
80    pub fn round(self) -> Number {
81        match self {
82            Number::Flt(value) => Number::Int(value.round() as isize),
83            Number::Int(value) => Number::Int(value),
84        }
85    }
86}
87
88impl std::ops::Add for Number {
89    type Output = Number;
90    fn add(self, rhs: Self) -> Self::Output {
91        match (self, rhs) {
92            (Number::Int(v1), Number::Int(v2)) => {
93                match v1.checked_add(v2) {
94                    Some(result) => Number::Int(result),
95                    None => Number::Flt(v1 as f64 + v2 as f64),
96                }
97            }
98            (lhs, rhs) => Number::Flt(lhs.float() + rhs.float()),
99        }
100    }
101}
102
103impl std::ops::Sub for Number {
104    type Output = Number;
105    fn sub(self, rhs: Self) -> Self::Output {
106        match (self, rhs) {
107            (Number::Int(v1), Number::Int(v2)) => {
108                match v1.checked_sub(v2) {
109                    Some(result) => Number::Int(result),
110                    None => Number::Flt(v1 as f64 - v2 as f64),
111                }
112            }
113            (lhs, rhs) => Number::Flt(lhs.float() - rhs.float()),
114        }
115    }
116}
117
118impl std::ops::Mul for Number {
119    type Output = Number;
120    fn mul(self, rhs: Self) -> Self::Output {
121        match (self, rhs) {
122            (Number::Int(v1), Number::Int(v2)) => {
123                // Use checked multiplication to avoid overflow panic
124                match v1.checked_mul(v2) {
125                    Some(result) => Number::Int(result),
126                    None => Number::Flt(v1 as f64 * v2 as f64), // Fallback to float on overflow
127                }
128            }
129            (lhs, rhs) => Number::Flt(lhs.float() * rhs.float()),
130        }
131    }
132}
133
134impl std::ops::Div for Number {
135    type Output = Number;
136    fn div(self, rhs: Self) -> Self::Output {
137        match (self, rhs) {
138            (Number::Int(v1), Number::Int(v2)) => {
139                if v2 == 0 {
140                    Number::Flt(f64::NAN) // Return NaN for division by zero
141                } else {
142                    Number::Int(v1 / v2)
143                }
144            }
145            (lhs, rhs) => Number::Flt(lhs.float() / rhs.float()),
146        }
147    }
148}
149
150impl PartialEq for Number {
151    fn eq(&self, other: &Self) -> bool {
152        match (self, other) {
153            (Number::Int(v1), Number::Int(v2)) => v1 == v2,
154            (lhs, rhs) => lhs.float() == rhs.float(),
155        }
156    }
157}
158
159impl PartialOrd for Number {
160    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
161        match (self, other) {
162            (Number::Int(v1), Number::Int(v2)) => Some(v1.cmp(v2)),
163            _ => self.float().partial_cmp(&other.float()),
164        }
165    }
166}
167
168// Math operation functions
169fn add(addr: usize, heap: &QueryHeap) -> Number {
170    evaluate_term(addr + 2, heap) + evaluate_term(addr + 3, heap)
171}
172
173fn sub(addr: usize, heap: &QueryHeap) -> Number {
174    evaluate_term(addr + 2, heap) - evaluate_term(addr + 3, heap)
175}
176
177fn mul(addr: usize, heap: &QueryHeap) -> Number {
178    evaluate_term(addr + 2, heap) * evaluate_term(addr + 3, heap)
179}
180
181fn div(addr: usize, heap: &QueryHeap) -> Number {
182    evaluate_term(addr + 2, heap) / evaluate_term(addr + 3, heap)
183}
184
185fn pow(addr: usize, heap: &QueryHeap) -> Number {
186    evaluate_term(addr + 2, heap).power(evaluate_term(addr + 3, heap))
187}
188
189fn cos(addr: usize, heap: &QueryHeap) -> Number {
190    Number::Flt(evaluate_term(addr + 2, heap).float().cos())
191}
192
193fn sin(addr: usize, heap: &QueryHeap) -> Number {
194    Number::Flt(evaluate_term(addr + 2, heap).float().sin())
195}
196
197fn tan(addr: usize, heap: &QueryHeap) -> Number {
198    Number::Flt(evaluate_term(addr + 2, heap).float().tan())
199}
200
201fn acos(addr: usize, heap: &QueryHeap) -> Number {
202    Number::Flt(evaluate_term(addr + 2, heap).float().acos())
203}
204
205fn asin(addr: usize, heap: &QueryHeap) -> Number {
206    Number::Flt(evaluate_term(addr + 2, heap).float().asin())
207}
208
209fn atan(addr: usize, heap: &QueryHeap) -> Number {
210    Number::Flt(evaluate_term(addr + 2, heap).float().atan())
211}
212
213fn log(addr: usize, heap: &QueryHeap) -> Number {
214    Number::Flt(
215        evaluate_term(addr + 2, heap)
216            .float()
217            .log(evaluate_term(addr + 3, heap).float()),
218    )
219}
220
221fn abs(addr: usize, heap: &QueryHeap) -> Number {
222    evaluate_term(addr + 2, heap).abs()
223}
224
225fn round(addr: usize, heap: &QueryHeap) -> Number {
226    evaluate_term(addr + 2, heap).round()
227}
228
229fn to_radians(addr: usize, heap: &QueryHeap) -> Number {
230    Number::Flt(evaluate_term(addr + 2, heap).float().to_radians())
231}
232
233fn to_degrees(addr: usize, heap: &QueryHeap) -> Number {
234    Number::Flt(evaluate_term(addr + 2, heap).float().to_degrees())
235}
236
237fn neg(addr: usize, heap: &QueryHeap) -> Number {
238    match evaluate_term(addr + 2, heap) {
239        Number::Int(v) => Number::Int(-v),
240        Number::Flt(v) => Number::Flt(-v),
241    }
242}
243
244fn sqrt(addr: usize, heap: &QueryHeap) -> Number {
245    Number::Flt(evaluate_term(addr + 2, heap).float().sqrt())
246}
247
248fn evaluate_str(addr: usize, heap: &QueryHeap) -> Number {
249    let symbol = heap[addr + 1].1;
250    let arity = heap[addr].1;
251    
252    // Handle unary minus: -(X) has arity 2 (functor + 1 arg)
253    if symbol == MINUS_SYMBOL && arity == 2 {
254        return neg(addr, heap);
255    }
256    
257    for (id, funct) in FUNCTIONS.iter() {
258        if *id == symbol {
259            return funct(addr, heap);
260        }
261    }
262    panic!("Unknown function {}", heap.term_string(addr));
263}
264
265fn evaluate_term(addr: usize, heap: &QueryHeap) -> Number {
266    let addr = heap.deref_addr(addr);
267    match heap[addr] {
268        (Tag::Func, _) => evaluate_str(addr, heap),
269        (Tag::Str, ptr) => evaluate_str(ptr, heap),
270        (Tag::Int, value) => Number::Int(usize::cast_signed(value)),
271        (Tag::Flt, value) => {
272            #[cfg(target_pointer_width = "32")]
273            let float_value = fsize::from_bits(value as u32);
274
275            #[cfg(target_pointer_width = "64")]
276            let float_value = fsize::from_bits(value as u64);
277
278            Number::Flt(float_value)
279        },
280        _ => panic!(
281            "{:?} : {} not a valid mathematical expression",
282            heap[addr],
283            heap.term_string(addr),
284        ),
285    }
286}
287
288/// is/2 predicate: evaluates RHS and unifies with LHS
289pub fn is_pred(heap: &mut QueryHeap, _hypothesis: &mut Hypothesis, goal: usize, _pred_table: Arc<PredicateTable>, _config: Config) -> PredReturn {
290    // Goal structure: Func(3) | Con("is") | LHS | RHS
291    let goal_addr = heap.deref_addr(goal);
292    let func_addr = match heap[goal_addr] {
293        (Tag::Str, ptr) => ptr,
294        (Tag::Func, _) => goal_addr,
295        _ => panic!("is/2: expected structure, got {:?}", heap[goal_addr]),
296    };
297
298    let rhs = evaluate_term(func_addr + 3, heap);
299    let lhs_addr = heap.deref_addr(func_addr + 2);
300
301    match heap[lhs_addr] {
302        (Tag::Ref, _) => {
303            // LHS is unbound - create binding
304            let result_addr = heap.heap_push(rhs.to_cell());
305            PredReturn::Binding(vec![(lhs_addr, result_addr)])
306        }
307        _ => {
308            // LHS is bound - check equality
309            let lhs = evaluate_term(lhs_addr, heap);
310            PredReturn::bool(lhs == rhs)
311        }
312    }
313}