1use std::collections::BTreeMap;
11use std::sync::Arc;
12
13use num_bigint::BigInt;
14use num_traits::One;
15use parking_lot::Mutex;
16
17use crate::algebraic::AlgebraicNumber;
18use crate::rational::RnsRational;
19use crate::rns::Channels;
20
/// A lazily evaluated real-number computation node.
///
/// `evaluate(precision)` returns a rational approximation of the value. The
/// implementations in this module size their internal tolerances from
/// `10^-(precision + k)` (see `eps`), i.e. they aim for an absolute error below
/// `10^-precision`.
pub trait Computable: Send + Sync {
    /// Returns a rational approximation of the value at `precision` decimal digits.
    fn evaluate(&self, precision: u64) -> RnsRational;
}
26
/// Shareable handle to a [`Computable`] node plus a memo of its evaluations.
///
/// Both the node and the cache sit behind `Arc`, so clones of a `ComputableReal`
/// share the same computation and the same cached results.
#[derive(Clone)]
pub struct ComputableReal {
    // The underlying computation node.
    inner: Arc<dyn Computable>,
    // Previously computed approximations, keyed by the requested precision.
    cache: Arc<Mutex<BTreeMap<u64, RnsRational>>>,
    // Channel set (from `crate::rns`) attached to values derived from this one.
    channels: Channels,
}
34
35impl ComputableReal {
36 fn wrap(inner: Arc<dyn Computable>, channels: Channels) -> Self {
37 ComputableReal {
38 inner,
39 cache: Arc::new(Mutex::new(BTreeMap::new())),
40 channels,
41 }
42 }
43
44 pub fn evaluate(&self, precision: u64) -> RnsRational {
46 if let Some(r) = self.cache.lock().get(&precision) {
47 return r.clone();
48 }
49 let r = self.inner.evaluate(precision);
50 self.cache.lock().insert(precision, r.clone());
51 r
52 }
53
54 pub fn evaluate_f64(&self) -> f64 {
56 self.evaluate(20).to_f64()
57 }
58
59 pub fn channels(&self) -> Channels {
61 self.channels.clone()
62 }
63
64 pub fn from_rational(r: RnsRational) -> Self {
68 let channels = r.channels.clone();
69 Self::wrap(Arc::new(RationalC { r }), channels)
70 }
71
72 pub fn from_algebraic(a: AlgebraicNumber) -> Self {
74 let channels = a.channels.clone();
75 Self::wrap(Arc::new(AlgebraicC { a }), channels)
76 }
77
78 pub fn pi(channels: Channels) -> Self {
80 Self::wrap(Arc::new(PiC { channels: channels.clone() }), channels)
81 }
82
83 pub fn e(channels: Channels) -> Self {
85 Self::wrap(Arc::new(EulerC { channels: channels.clone() }), channels)
86 }
87
88 pub fn sqrt(r: RnsRational) -> Self {
90 let channels = r.channels.clone();
91 Self::wrap(Arc::new(SqrtC { r }), channels)
92 }
93
94 pub fn exp(r: RnsRational) -> Self {
96 let channels = r.channels.clone();
97 Self::wrap(Arc::new(ExpC { r }), channels)
98 }
99
100 pub fn ln(r: RnsRational) -> Self {
102 let channels = r.channels.clone();
103 Self::wrap(Arc::new(LnC { r }), channels)
104 }
105
106 pub fn add(&self, other: &Self) -> Self {
110 Self::wrap(
111 Arc::new(BinOp {
112 a: self.clone(),
113 b: other.clone(),
114 kind: BinKind::Add,
115 }),
116 self.channels.clone(),
117 )
118 }
119
120 pub fn sub(&self, other: &Self) -> Self {
122 self.add(&other.neg())
123 }
124
125 pub fn mul(&self, other: &Self) -> Self {
127 Self::wrap(
128 Arc::new(BinOp {
129 a: self.clone(),
130 b: other.clone(),
131 kind: BinKind::Mul,
132 }),
133 self.channels.clone(),
134 )
135 }
136
137 pub fn neg(&self) -> Self {
139 Self::wrap(Arc::new(NegC { a: self.clone() }), self.channels.clone())
140 }
141
142 pub fn recip(&self) -> Self {
144 Self::wrap(Arc::new(RecipC { a: self.clone() }), self.channels.clone())
145 }
146}
147
148fn eps(prec: u64, channels: &Channels) -> RnsRational {
152 RnsRational::new(BigInt::one(), pow10(prec), channels.clone())
153}
154
155fn pow10(p: u64) -> BigInt {
156 BigInt::from(10u8).pow(p as u32)
157}
158
159fn magnitude_digits(cr: &ComputableReal) -> u64 {
161 let v = cr.evaluate(4).to_f64().abs();
162 if v < 1.0 {
163 1
164 } else {
165 v.log10().floor() as u64 + 1
166 }
167}
168
169struct RationalC {
172 r: RnsRational,
173}
174impl Computable for RationalC {
175 fn evaluate(&self, _precision: u64) -> RnsRational {
176 self.r.clone()
177 }
178}
179
180struct AlgebraicC {
181 a: AlgebraicNumber,
182}
183impl Computable for AlgebraicC {
184 fn evaluate(&self, precision: u64) -> RnsRational {
185 let mut clone = self.a.clone();
186 let target = eps(precision + 1, &self.a.channels);
187 clone.refine_interval(&target);
188 clone.interval.0.midpoint(&clone.interval.1)
189 }
190}
191
192fn atan_inv(x: i64, target: &RnsRational, channels: &Channels) -> RnsRational {
194 let mut acc = RnsRational::zero(channels.clone());
195 let mut n: i64 = 0;
196 loop {
197 let exp = (2 * n + 1) as u32;
198 let denom = BigInt::from(2 * n + 1) * BigInt::from(x).pow(exp);
199 let sign = if n % 2 == 0 { 1 } else { -1 };
200 let term = RnsRational::new(BigInt::from(sign), denom, channels.clone());
201 acc = acc.add(&term);
202 if term.abs() < *target {
203 break;
204 }
205 n += 1;
206 }
207 acc
208}
209
210struct PiC {
211 channels: Channels,
212}
213impl Computable for PiC {
214 fn evaluate(&self, precision: u64) -> RnsRational {
215 let target = eps(precision + 5, &self.channels);
216 let a = atan_inv(5, &target, &self.channels)
217 .mul(&RnsRational::from_int(16, self.channels.clone()));
218 let b = atan_inv(239, &target, &self.channels)
219 .mul(&RnsRational::from_int(4, self.channels.clone()));
220 a.sub(&b)
221 }
222}
223
224struct EulerC {
225 channels: Channels,
226}
227impl Computable for EulerC {
228 fn evaluate(&self, precision: u64) -> RnsRational {
229 let target = eps(precision + 3, &self.channels);
230 let mut acc = RnsRational::zero(self.channels.clone());
231 let mut fact = BigInt::one();
232 let mut k: u64 = 0;
233 loop {
234 if k > 0 {
235 fact *= BigInt::from(k);
236 }
237 let term = RnsRational::new(BigInt::one(), fact.clone(), self.channels.clone());
238 acc = acc.add(&term);
239 if term < target {
240 break;
241 }
242 k += 1;
243 }
244 acc
245 }
246}
247
248struct SqrtC {
249 r: RnsRational,
250}
251impl Computable for SqrtC {
252 fn evaluate(&self, precision: u64) -> RnsRational {
253 let channels = self.r.channels.clone();
254 let target = eps(precision + 2, &channels);
255 let guess = self.r.to_f64().max(0.0).sqrt();
257 let mut x = if guess > 0.0 {
258 RnsRational::from_f64(guess, channels.clone())
259 } else {
260 RnsRational::from_int(1, channels.clone())
261 };
262 let two = RnsRational::from_int(2, channels.clone());
263 for _ in 0..200 {
265 if x.is_zero() {
266 break;
267 }
268 let next = x.add(&self.r.div(&x)).div(&two);
269 let err = next.mul(&next).sub(&self.r).abs();
270 x = next;
271 if err < target {
272 break;
273 }
274 }
275 x
276 }
277}
278
279struct ExpC {
280 r: RnsRational,
281}
282impl Computable for ExpC {
283 fn evaluate(&self, precision: u64) -> RnsRational {
284 let channels = self.r.channels.clone();
285 let target = eps(precision + 3, &channels);
286 let mut acc = RnsRational::zero(channels.clone());
287 let mut term = RnsRational::from_int(1, channels.clone()); let mut k: u64 = 0;
289 loop {
290 acc = acc.add(&term);
291 if k > 0 && term.abs() < target {
292 break;
293 }
294 k += 1;
295 term = term.mul(&self.r).div(&RnsRational::from_int(k as i64, channels.clone()));
297 if k > 5000 {
298 break;
299 }
300 }
301 acc
302 }
303}
304
305struct LnC {
306 r: RnsRational,
307}
308impl Computable for LnC {
309 fn evaluate(&self, precision: u64) -> RnsRational {
310 let channels = self.r.channels.clone();
311 let target = eps(precision + 3, &channels);
312 let one = RnsRational::from_int(1, channels.clone());
314 let t = self.r.sub(&one).div(&self.r.add(&one));
315 let t2 = t.mul(&t);
316 let mut acc = RnsRational::zero(channels.clone());
317 let mut power = t.clone();
318 let mut k: u64 = 0;
319 loop {
320 let term = power.div(&RnsRational::from_int((2 * k + 1) as i64, channels.clone()));
321 acc = acc.add(&term);
322 if term.abs() < target {
323 break;
324 }
325 power = power.mul(&t2);
326 k += 1;
327 if k > 100_000 {
328 break;
329 }
330 }
331 acc.mul(&RnsRational::from_int(2, channels.clone()))
332 }
333}
334
/// Which binary operation a `BinOp` node performs.
enum BinKind {
    /// `a + b`
    Add,
    /// `a * b`
    Mul,
}
339
340struct BinOp {
341 a: ComputableReal,
342 b: ComputableReal,
343 kind: BinKind,
344}
345impl Computable for BinOp {
346 fn evaluate(&self, precision: u64) -> RnsRational {
347 match self.kind {
348 BinKind::Add => {
349 let pa = self.a.evaluate(precision + 1);
350 let pb = self.b.evaluate(precision + 1);
351 pa.add(&pb)
352 }
353 BinKind::Mul => {
354 let guard = magnitude_digits(&self.a) + magnitude_digits(&self.b) + 2;
355 let pa = self.a.evaluate(precision + guard);
356 let pb = self.b.evaluate(precision + guard);
357 pa.mul(&pb)
358 }
359 }
360 }
361}
362
363struct NegC {
364 a: ComputableReal,
365}
366impl Computable for NegC {
367 fn evaluate(&self, precision: u64) -> RnsRational {
368 self.a.evaluate(precision).neg()
369 }
370}
371
372struct RecipC {
373 a: ComputableReal,
374}
375impl Computable for RecipC {
376 fn evaluate(&self, precision: u64) -> RnsRational {
377 let v = self.a.evaluate(4).to_f64().abs();
379 let extra = if v > 0.0 && v < 1.0 {
380 (-v.log10()).ceil() as u64 * 2 + 2
381 } else {
382 2
383 };
384 self.a.evaluate(precision + extra).recip()
385 }
386}
387
388impl AlgebraicNumber {
390 pub fn to_computable(&self) -> ComputableReal {
392 ComputableReal::from_algebraic(self.clone())
393 }
394}
395
#[cfg(test)]
mod tests {
    use super::*;

    /// Channel set shared by every test.
    fn ch() -> Channels {
        Channels::standard(32)
    }

    /// Fails unless `actual` lies strictly within `tol` of `expected`.
    #[track_caller]
    fn assert_close(actual: f64, expected: f64, tol: f64) {
        let diff = (actual - expected).abs();
        assert!(
            diff < tol,
            "expected {} within {}, got {} (|diff| = {})",
            expected,
            tol,
            actual,
            diff
        );
    }

    #[test]
    fn pi_to_ten_places() {
        let pi = ComputableReal::pi(ch());
        assert_close(pi.evaluate(10).to_f64(), std::f64::consts::PI, 1e-10);
    }

    #[test]
    fn e_to_fifteen_places() {
        let e = ComputableReal::e(ch());
        assert_close(e.evaluate(15).to_f64(), std::f64::consts::E, 1e-14);
    }

    #[test]
    fn sqrt_two() {
        let root = ComputableReal::sqrt(RnsRational::from_int(2, ch()));
        assert_close(root.evaluate(20).to_f64(), 2f64.sqrt(), 1e-12);
    }

    #[test]
    fn rational_passes_through() {
        let third = RnsRational::from_fraction(1, 3, ch());
        let lazy = ComputableReal::from_rational(third.clone());
        assert_eq!(lazy.evaluate(100), third);
    }

    #[test]
    fn precision_contract() {
        let pi = ComputableReal::pi(ch());
        let coarse = pi.evaluate(5).to_f64();
        let fine = pi.evaluate(50).to_f64();
        assert_close(coarse, fine, 1e-5);
    }

    #[test]
    fn lazy_sum_of_pi_and_one() {
        let pi = ComputableReal::pi(ch());
        let one = ComputableReal::from_rational(RnsRational::from_int(1, ch()));
        let total = pi.add(&one);
        assert_close(
            total.evaluate(20).to_f64(),
            std::f64::consts::PI + 1.0,
            1e-12,
        );
    }

    #[test]
    fn exp_and_ln() {
        let e = ComputableReal::exp(RnsRational::from_int(1, ch()));
        assert_close(e.evaluate(15).to_f64(), std::f64::consts::E, 1e-13);
        let ln2 = ComputableReal::ln(RnsRational::from_int(2, ch()));
        assert_close(ln2.evaluate(15).to_f64(), 2f64.ln(), 1e-13);
    }

    #[test]
    fn algebraic_to_computable() {
        let root2 = AlgebraicNumber::sqrt(2, ch()).to_computable();
        assert_close(root2.evaluate(15).to_f64(), 2f64.sqrt(), 1e-13);
    }
}