1use crate::ball::ArbBall;
41use crate::kernel::{ExprId, ExprPool};
42use std::collections::HashMap;
43use std::fmt;
44
45bitflags::bitflags! {
50 #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
52 pub struct Capabilities: u32 {
53 const SIMPLIFY = 1 << 0;
54 const DIFF_FORWARD = 1 << 1;
55 const DIFF_REVERSE = 1 << 2;
56 const NUMERIC_F64 = 1 << 3;
57 const NUMERIC_BALL = 1 << 4;
58 const LOWER_LLVM = 1 << 5;
59 const LEAN_THEOREM = 1 << 6;
60 }
61}
62
63impl fmt::Display for Capabilities {
64 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
65 let names = [
66 (Capabilities::SIMPLIFY, "simplify"),
67 (Capabilities::DIFF_FORWARD, "diff_fwd"),
68 (Capabilities::DIFF_REVERSE, "diff_rev"),
69 (Capabilities::NUMERIC_F64, "numeric_f64"),
70 (Capabilities::NUMERIC_BALL, "numeric_ball"),
71 (Capabilities::LOWER_LLVM, "lower_llvm"),
72 (Capabilities::LEAN_THEOREM, "lean"),
73 ];
74 let present: Vec<&str> = names
75 .iter()
76 .filter(|(flag, _)| self.contains(*flag))
77 .map(|(_, name)| *name)
78 .collect();
79 write!(f, "[{}]", present.join(", "))
80 }
81}
82
/// A named operation (e.g. an elementary function) that appears in expressions as
/// `name(args…)`.
///
/// Only [`name`](Primitive::name) and [`pretty`](Primitive::pretty) are required. Every
/// other hook has a default body returning `None`, which means "not supported". When a
/// primitive is registered, the registry calls each hook on sample inputs and records
/// which ones return `Some` as its capabilities.
pub trait Primitive: 'static + Send + Sync {
    /// Identifier under which the primitive is registered and looked up.
    fn name(&self) -> &'static str;

    /// Human-readable rendering of this primitive applied to `args`.
    fn pretty(&self, args: &[ExprId], pool: &ExprPool) -> String;

    /// Optional symbolic rewrite of `name(args…)`. The default returns `None`
    /// (no simplification offered).
    fn simplify(&self, _args: &[ExprId], _pool: &ExprPool) -> Option<ExprId> {
        None
    }

    /// Optional derivative of `name(args…)` with respect to `wrt`. The built-ins apply
    /// the chain rule. The default returns `None` (not supported).
    fn diff_forward(&self, _args: &[ExprId], _wrt: ExprId, _pool: &ExprPool) -> Option<ExprId> {
        None
    }

    /// Optional reverse-mode rule: given the output cotangent `_cotan`, returns
    /// per-argument contributions. The built-in unary primitives return a
    /// one-element vector. The default returns `None` (not supported).
    fn diff_reverse(
        &self,
        _args: &[ExprId],
        _cotan: ExprId,
        _pool: &ExprPool,
    ) -> Option<Vec<ExprId>> {
        None
    }

    /// Optional `f64` evaluation. The default returns `None` (not supported).
    fn numeric_f64(&self, _args: &[f64]) -> Option<f64> {
        None
    }

    /// Optional ball-arithmetic evaluation. The default returns `None` (not supported).
    fn numeric_ball(&self, _args: &[ArbBall]) -> Option<ArbBall> {
        None
    }

    /// Optional name of an associated Lean theorem (e.g. `"Real.sin_deriv"` for `sin`).
    /// The default returns `None`.
    fn lean_theorem(&self) -> Option<&'static str> {
        None
    }
}
137
/// One line of a [`CoverageReport`]: a primitive's name and its probed capabilities.
#[derive(Debug, Clone)]
pub struct CoverageRow {
    /// Registered primitive name.
    pub name: String,
    /// Capability flags recorded at registration.
    pub caps: Capabilities,
}
148
/// Capability matrix for every registered primitive. `PrimitiveRegistry::coverage_report`
/// builds it with rows sorted by name.
#[derive(Debug, Clone)]
pub struct CoverageReport {
    /// One row per registered primitive.
    pub rows: Vec<CoverageRow>,
}
155
156impl CoverageReport {
157 pub fn to_markdown(&self) -> String {
159 let header = "| Primitive | simplify | diff_fwd | diff_rev | numeric_f64 | numeric_ball | lower_llvm | lean |\n\
160 |---|---|---|---|---|---|---|---|";
161 let rows: Vec<String> = self
162 .rows
163 .iter()
164 .map(|r| {
165 let tick = |flag: Capabilities| {
166 if r.caps.contains(flag) {
167 "✓"
168 } else {
169 "✗"
170 }
171 };
172 format!(
173 "| {} | {} | {} | {} | {} | {} | {} | {} |",
174 r.name,
175 tick(Capabilities::SIMPLIFY),
176 tick(Capabilities::DIFF_FORWARD),
177 tick(Capabilities::DIFF_REVERSE),
178 tick(Capabilities::NUMERIC_F64),
179 tick(Capabilities::NUMERIC_BALL),
180 tick(Capabilities::LOWER_LLVM),
181 tick(Capabilities::LEAN_THEOREM),
182 )
183 })
184 .collect();
185 format!("{}\n{}", header, rows.join("\n"))
186 }
187}
188
/// Registry slot: the primitive itself plus the capabilities probed when it was added.
struct Entry {
    primitive: Box<dyn Primitive>,
    // Computed once by `probe_caps` in `PrimitiveRegistry::register`.
    caps: Capabilities,
}
197
/// Name-indexed collection of [`Primitive`]s with their cached capability flags.
///
/// Iteration order (`iter`) follows the underlying `HashMap` and is unspecified;
/// `coverage_report` sorts by name.
pub struct PrimitiveRegistry {
    // Keyed by `Primitive::name()`.
    map: HashMap<&'static str, Entry>,
}
210
211impl PrimitiveRegistry {
212 pub fn new() -> Self {
214 PrimitiveRegistry {
215 map: HashMap::new(),
216 }
217 }
218
219 pub fn register(&mut self, p: Box<dyn Primitive>) {
223 let caps = probe_caps(&*p);
224 let name = p.name();
225 self.map.insert(name, Entry { primitive: p, caps });
226 }
227
228 pub fn get(&self, name: &str) -> Option<&dyn Primitive> {
230 self.map.get(name).map(|e| &*e.primitive)
231 }
232
233 pub fn capabilities(&self, name: &str) -> Capabilities {
236 self.map
237 .get(name)
238 .map(|e| e.caps)
239 .unwrap_or(Capabilities::empty())
240 }
241
242 pub fn coverage_report(&self) -> CoverageReport {
245 let mut rows: Vec<CoverageRow> = self
246 .map
247 .iter()
248 .map(|(name, e)| CoverageRow {
249 name: name.to_string(),
250 caps: e.caps,
251 })
252 .collect();
253 rows.sort_by(|a, b| a.name.cmp(&b.name));
254 CoverageReport { rows }
255 }
256
257 pub fn diff_forward(
260 &self,
261 name: &str,
262 args: &[ExprId],
263 wrt: ExprId,
264 pool: &ExprPool,
265 ) -> Option<ExprId> {
266 let entry = self.map.get(name)?;
267 entry.primitive.diff_forward(args, wrt, pool)
268 }
269
270 pub fn diff_reverse(
272 &self,
273 name: &str,
274 args: &[ExprId],
275 cotan: ExprId,
276 pool: &ExprPool,
277 ) -> Option<Vec<ExprId>> {
278 let entry = self.map.get(name)?;
279 entry.primitive.diff_reverse(args, cotan, pool)
280 }
281
282 pub fn numeric_f64(&self, name: &str, args: &[f64]) -> Option<f64> {
284 let entry = self.map.get(name)?;
285 entry.primitive.numeric_f64(args)
286 }
287
288 pub fn numeric_ball(&self, name: &str, args: &[ArbBall]) -> Option<ArbBall> {
290 let entry = self.map.get(name)?;
291 entry.primitive.numeric_ball(args)
292 }
293
294 pub fn default_registry() -> Self {
296 let mut reg = Self::new();
297 reg.register(Box::new(builtins::SinPrimitive));
298 reg.register(Box::new(builtins::CosPrimitive));
299 reg.register(Box::new(builtins::ExpPrimitive));
300 reg.register(Box::new(builtins::LogPrimitive));
301 reg.register(Box::new(builtins::SqrtPrimitive));
302 reg.register(Box::new(builtins::TanPrimitive));
304 reg.register(Box::new(builtins::SinhPrimitive));
305 reg.register(Box::new(builtins::CoshPrimitive));
306 reg.register(Box::new(builtins::TanhPrimitive));
307 reg.register(Box::new(builtins::AsinPrimitive));
308 reg.register(Box::new(builtins::AcosPrimitive));
309 reg.register(Box::new(builtins::AtanPrimitive));
310 reg.register(Box::new(builtins::ErfPrimitive));
311 reg.register(Box::new(builtins::ErfcPrimitive));
312 reg.register(Box::new(builtins::AbsPrimitive));
313 reg.register(Box::new(builtins::SignPrimitive));
314 reg.register(Box::new(builtins::FloorPrimitive));
315 reg.register(Box::new(builtins::CeilPrimitive));
316 reg.register(Box::new(builtins::RoundPrimitive));
317 reg.register(Box::new(builtins::Atan2Primitive));
318 reg.register(Box::new(builtins::GammaPrimitive));
319 reg.register(Box::new(builtins::MinPrimitive));
320 reg.register(Box::new(builtins::MaxPrimitive));
321 reg
322 }
323
324 pub fn is_registered(&self, name: &str) -> bool {
326 self.map.contains_key(name)
327 }
328
329 pub fn iter(&self) -> impl Iterator<Item = (&str, Capabilities)> {
331 self.map.iter().map(|(k, e)| (*k, e.caps))
332 }
333}
334
335impl Default for PrimitiveRegistry {
336 fn default() -> Self {
337 Self::default_registry()
338 }
339}
340
341fn probe_caps(p: &dyn Primitive) -> Capabilities {
349 let mut caps = Capabilities::empty();
350
351 let probe_f64_sets: [&[f64]; 2] = [&[1.0], &[1.0, 2.0]];
354 for args in probe_f64_sets {
355 if p.numeric_f64(args).is_some() {
356 caps |= Capabilities::NUMERIC_F64;
357 break;
358 }
359 }
360
361 let ball1 = [ArbBall::from_f64(1.0, 128)];
362 let ball2 = [ArbBall::from_f64(1.0, 128), ArbBall::from_f64(2.0, 128)];
363 if p.numeric_ball(&ball1).is_some() || p.numeric_ball(&ball2).is_some() {
364 caps |= Capabilities::NUMERIC_BALL;
365 }
366
367 let pool = ExprPool::new();
369 let x = pool.symbol("_probe", crate::kernel::Domain::Real);
370 let y = pool.symbol("_probe_y", crate::kernel::Domain::Real);
371 let probe_id_sets: [Vec<ExprId>; 2] = [vec![x], vec![x, y]];
372
373 for args in &probe_id_sets {
374 if p.diff_forward(args, x, &pool).is_some() {
375 caps |= Capabilities::DIFF_FORWARD;
376 break;
377 }
378 }
379 for args in &probe_id_sets {
380 if p.diff_reverse(args, x, &pool).is_some() {
381 caps |= Capabilities::DIFF_REVERSE;
382 break;
383 }
384 }
385 for args in &probe_id_sets {
386 if p.simplify(args, &pool).is_some() {
387 caps |= Capabilities::SIMPLIFY;
388 break;
389 }
390 }
391 if p.lean_theorem().is_some() {
392 caps |= Capabilities::LEAN_THEOREM;
393 }
394 caps
395}
396
397pub mod builtins {
402 use super::Primitive;
403 use crate::ball::ArbBall;
404 use crate::kernel::{ExprId, ExprPool};
405
406 pub struct SinPrimitive;
409
410 impl Primitive for SinPrimitive {
411 fn name(&self) -> &'static str {
412 "sin"
413 }
414
415 fn pretty(&self, args: &[ExprId], pool: &ExprPool) -> String {
416 format!("sin({})", pool.display(args[0]))
417 }
418
419 fn diff_forward(&self, args: &[ExprId], wrt: ExprId, pool: &ExprPool) -> Option<ExprId> {
420 let x = args[0];
421 let dx = crate::diff::diff(x, wrt, pool).ok()?.value;
422 let cos_x = pool.func("cos", vec![x]);
423 Some(pool.mul(vec![cos_x, dx]))
424 }
425
426 fn diff_reverse(
427 &self,
428 args: &[ExprId],
429 cotan: ExprId,
430 pool: &ExprPool,
431 ) -> Option<Vec<ExprId>> {
432 let x = args[0];
433 let cos_x = pool.func("cos", vec![x]);
434 Some(vec![pool.mul(vec![cotan, cos_x])])
435 }
436
437 fn numeric_f64(&self, args: &[f64]) -> Option<f64> {
438 Some(args[0].sin())
439 }
440
441 fn numeric_ball(&self, args: &[ArbBall]) -> Option<ArbBall> {
442 Some(args[0].sin())
443 }
444
445 fn lean_theorem(&self) -> Option<&'static str> {
446 Some("Real.sin_deriv")
447 }
448 }
449
450 pub struct CosPrimitive;
453
454 impl Primitive for CosPrimitive {
455 fn name(&self) -> &'static str {
456 "cos"
457 }
458
459 fn pretty(&self, args: &[ExprId], pool: &ExprPool) -> String {
460 format!("cos({})", pool.display(args[0]))
461 }
462
463 fn diff_forward(&self, args: &[ExprId], wrt: ExprId, pool: &ExprPool) -> Option<ExprId> {
464 let x = args[0];
465 let dx = crate::diff::diff(x, wrt, pool).ok()?.value;
466 let neg_one = pool.integer(-1_i32);
467 let sin_x = pool.func("sin", vec![x]);
468 Some(pool.mul(vec![neg_one, sin_x, dx]))
469 }
470
471 fn diff_reverse(
472 &self,
473 args: &[ExprId],
474 cotan: ExprId,
475 pool: &ExprPool,
476 ) -> Option<Vec<ExprId>> {
477 let x = args[0];
478 let neg_one = pool.integer(-1_i32);
479 let sin_x = pool.func("sin", vec![x]);
480 Some(vec![pool.mul(vec![cotan, neg_one, sin_x])])
481 }
482
483 fn numeric_f64(&self, args: &[f64]) -> Option<f64> {
484 Some(args[0].cos())
485 }
486
487 fn numeric_ball(&self, args: &[ArbBall]) -> Option<ArbBall> {
488 Some(args[0].cos())
489 }
490
491 fn lean_theorem(&self) -> Option<&'static str> {
492 Some("Real.cos_deriv")
493 }
494 }
495
496 pub struct ExpPrimitive;
499
500 impl Primitive for ExpPrimitive {
501 fn name(&self) -> &'static str {
502 "exp"
503 }
504
505 fn pretty(&self, args: &[ExprId], pool: &ExprPool) -> String {
506 format!("exp({})", pool.display(args[0]))
507 }
508
509 fn diff_forward(&self, args: &[ExprId], wrt: ExprId, pool: &ExprPool) -> Option<ExprId> {
510 let x = args[0];
511 let dx = crate::diff::diff(x, wrt, pool).ok()?.value;
512 let exp_x = pool.func("exp", vec![x]);
513 Some(pool.mul(vec![exp_x, dx]))
514 }
515
516 fn diff_reverse(
517 &self,
518 args: &[ExprId],
519 cotan: ExprId,
520 pool: &ExprPool,
521 ) -> Option<Vec<ExprId>> {
522 let x = args[0];
523 let exp_x = pool.func("exp", vec![x]);
524 Some(vec![pool.mul(vec![cotan, exp_x])])
525 }
526
527 fn numeric_f64(&self, args: &[f64]) -> Option<f64> {
528 Some(args[0].exp())
529 }
530
531 fn numeric_ball(&self, args: &[ArbBall]) -> Option<ArbBall> {
532 Some(args[0].exp())
533 }
534
535 fn lean_theorem(&self) -> Option<&'static str> {
536 Some("Real.exp_deriv")
537 }
538 }
539
540 pub struct LogPrimitive;
543
544 impl Primitive for LogPrimitive {
545 fn name(&self) -> &'static str {
546 "log"
547 }
548
549 fn pretty(&self, args: &[ExprId], pool: &ExprPool) -> String {
550 format!("log({})", pool.display(args[0]))
551 }
552
553 fn diff_forward(&self, args: &[ExprId], wrt: ExprId, pool: &ExprPool) -> Option<ExprId> {
554 let x = args[0];
555 let dx = crate::diff::diff(x, wrt, pool).ok()?.value;
556 let x_inv = pool.pow(x, pool.integer(-1_i32));
558 Some(pool.mul(vec![x_inv, dx]))
559 }
560
561 fn diff_reverse(
562 &self,
563 args: &[ExprId],
564 cotan: ExprId,
565 pool: &ExprPool,
566 ) -> Option<Vec<ExprId>> {
567 let x = args[0];
568 let x_inv = pool.pow(x, pool.integer(-1_i32));
569 Some(vec![pool.mul(vec![cotan, x_inv])])
570 }
571
572 fn numeric_f64(&self, args: &[f64]) -> Option<f64> {
573 Some(args[0].ln())
574 }
575
576 fn numeric_ball(&self, args: &[ArbBall]) -> Option<ArbBall> {
577 args[0].log()
578 }
579
580 fn lean_theorem(&self) -> Option<&'static str> {
581 Some("Real.log_deriv")
582 }
583 }
584
585 pub struct SqrtPrimitive;
588
589 impl Primitive for SqrtPrimitive {
590 fn name(&self) -> &'static str {
591 "sqrt"
592 }
593
594 fn pretty(&self, args: &[ExprId], pool: &ExprPool) -> String {
595 format!("sqrt({})", pool.display(args[0]))
596 }
597
598 fn diff_forward(&self, args: &[ExprId], wrt: ExprId, pool: &ExprPool) -> Option<ExprId> {
599 let x = args[0];
600 let dx = crate::diff::diff(x, wrt, pool).ok()?.value;
601 let sqrt_x = pool.func("sqrt", vec![x]);
603 let two = pool.integer(2_i32);
604 let denom = pool.mul(vec![two, sqrt_x]);
605 let denom_inv = pool.pow(denom, pool.integer(-1_i32));
606 Some(pool.mul(vec![dx, denom_inv]))
607 }
608
609 fn diff_reverse(
610 &self,
611 args: &[ExprId],
612 cotan: ExprId,
613 pool: &ExprPool,
614 ) -> Option<Vec<ExprId>> {
615 let x = args[0];
616 let sqrt_x = pool.func("sqrt", vec![x]);
617 let two = pool.integer(2_i32);
618 let denom = pool.mul(vec![two, sqrt_x]);
619 let denom_inv = pool.pow(denom, pool.integer(-1_i32));
620 Some(vec![pool.mul(vec![cotan, denom_inv])])
621 }
622
623 fn numeric_f64(&self, args: &[f64]) -> Option<f64> {
624 Some(args[0].sqrt())
625 }
626
627 fn numeric_ball(&self, args: &[ArbBall]) -> Option<ArbBall> {
628 args[0].sqrt()
629 }
630
631 fn lean_theorem(&self) -> Option<&'static str> {
632 Some("Real.sqrt_deriv")
633 }
634 }
635
636 pub struct TanPrimitive;
639
640 impl Primitive for TanPrimitive {
641 fn name(&self) -> &'static str {
642 "tan"
643 }
644
645 fn pretty(&self, args: &[ExprId], pool: &ExprPool) -> String {
646 format!("tan({})", pool.display(args[0]))
647 }
648
649 fn diff_forward(&self, args: &[ExprId], wrt: ExprId, pool: &ExprPool) -> Option<ExprId> {
650 let x = args[0];
652 let dx = crate::diff::diff(x, wrt, pool).ok()?.value;
653 let tan_x = pool.func("tan", vec![x]);
654 let tan2 = pool.pow(tan_x, pool.integer(2_i32));
655 let one = pool.integer(1_i32);
656 let sec2 = pool.add(vec![one, tan2]);
657 Some(pool.mul(vec![sec2, dx]))
658 }
659
660 fn diff_reverse(
661 &self,
662 args: &[ExprId],
663 cotan: ExprId,
664 pool: &ExprPool,
665 ) -> Option<Vec<ExprId>> {
666 let x = args[0];
667 let tan_x = pool.func("tan", vec![x]);
668 let tan2 = pool.pow(tan_x, pool.integer(2_i32));
669 let one = pool.integer(1_i32);
670 let sec2 = pool.add(vec![one, tan2]);
671 Some(vec![pool.mul(vec![cotan, sec2])])
672 }
673
674 fn numeric_f64(&self, args: &[f64]) -> Option<f64> {
675 Some(args[0].tan())
676 }
677
678 fn numeric_ball(&self, args: &[ArbBall]) -> Option<ArbBall> {
679 args[0].tan()
680 }
681
682 fn lean_theorem(&self) -> Option<&'static str> {
683 Some("Real.tan_deriv")
684 }
685 }
686
687 pub struct SinhPrimitive;
690
691 impl Primitive for SinhPrimitive {
692 fn name(&self) -> &'static str {
693 "sinh"
694 }
695
696 fn pretty(&self, args: &[ExprId], pool: &ExprPool) -> String {
697 format!("sinh({})", pool.display(args[0]))
698 }
699
700 fn diff_forward(&self, args: &[ExprId], wrt: ExprId, pool: &ExprPool) -> Option<ExprId> {
701 let x = args[0];
702 let dx = crate::diff::diff(x, wrt, pool).ok()?.value;
703 let cosh_x = pool.func("cosh", vec![x]);
704 Some(pool.mul(vec![cosh_x, dx]))
705 }
706
707 fn diff_reverse(
708 &self,
709 args: &[ExprId],
710 cotan: ExprId,
711 pool: &ExprPool,
712 ) -> Option<Vec<ExprId>> {
713 let x = args[0];
714 let cosh_x = pool.func("cosh", vec![x]);
715 Some(vec![pool.mul(vec![cotan, cosh_x])])
716 }
717
718 fn numeric_f64(&self, args: &[f64]) -> Option<f64> {
719 Some(args[0].sinh())
720 }
721
722 fn numeric_ball(&self, args: &[ArbBall]) -> Option<ArbBall> {
723 Some(args[0].sinh())
724 }
725
726 fn lean_theorem(&self) -> Option<&'static str> {
727 Some("Real.sinh_deriv")
728 }
729 }
730
731 pub struct CoshPrimitive;
734
735 impl Primitive for CoshPrimitive {
736 fn name(&self) -> &'static str {
737 "cosh"
738 }
739
740 fn pretty(&self, args: &[ExprId], pool: &ExprPool) -> String {
741 format!("cosh({})", pool.display(args[0]))
742 }
743
744 fn diff_forward(&self, args: &[ExprId], wrt: ExprId, pool: &ExprPool) -> Option<ExprId> {
745 let x = args[0];
746 let dx = crate::diff::diff(x, wrt, pool).ok()?.value;
747 let sinh_x = pool.func("sinh", vec![x]);
748 Some(pool.mul(vec![sinh_x, dx]))
749 }
750
751 fn diff_reverse(
752 &self,
753 args: &[ExprId],
754 cotan: ExprId,
755 pool: &ExprPool,
756 ) -> Option<Vec<ExprId>> {
757 let x = args[0];
758 let sinh_x = pool.func("sinh", vec![x]);
759 Some(vec![pool.mul(vec![cotan, sinh_x])])
760 }
761
762 fn numeric_f64(&self, args: &[f64]) -> Option<f64> {
763 Some(args[0].cosh())
764 }
765
766 fn numeric_ball(&self, args: &[ArbBall]) -> Option<ArbBall> {
767 Some(args[0].cosh())
768 }
769
770 fn lean_theorem(&self) -> Option<&'static str> {
771 Some("Real.cosh_deriv")
772 }
773 }
774
775 pub struct TanhPrimitive;
778
779 impl Primitive for TanhPrimitive {
780 fn name(&self) -> &'static str {
781 "tanh"
782 }
783
784 fn pretty(&self, args: &[ExprId], pool: &ExprPool) -> String {
785 format!("tanh({})", pool.display(args[0]))
786 }
787
788 fn diff_forward(&self, args: &[ExprId], wrt: ExprId, pool: &ExprPool) -> Option<ExprId> {
789 let x = args[0];
791 let dx = crate::diff::diff(x, wrt, pool).ok()?.value;
792 let tanh_x = pool.func("tanh", vec![x]);
793 let tanh2 = pool.pow(tanh_x, pool.integer(2_i32));
794 let one = pool.integer(1_i32);
795 let neg_one = pool.integer(-1_i32);
796 let sech2 = pool.add(vec![one, pool.mul(vec![neg_one, tanh2])]);
797 Some(pool.mul(vec![sech2, dx]))
798 }
799
800 fn diff_reverse(
801 &self,
802 args: &[ExprId],
803 cotan: ExprId,
804 pool: &ExprPool,
805 ) -> Option<Vec<ExprId>> {
806 let x = args[0];
807 let tanh_x = pool.func("tanh", vec![x]);
808 let tanh2 = pool.pow(tanh_x, pool.integer(2_i32));
809 let one = pool.integer(1_i32);
810 let neg_one = pool.integer(-1_i32);
811 let sech2 = pool.add(vec![one, pool.mul(vec![neg_one, tanh2])]);
812 Some(vec![pool.mul(vec![cotan, sech2])])
813 }
814
815 fn numeric_f64(&self, args: &[f64]) -> Option<f64> {
816 Some(args[0].tanh())
817 }
818
819 fn numeric_ball(&self, args: &[ArbBall]) -> Option<ArbBall> {
820 Some(args[0].tanh())
821 }
822
823 fn lean_theorem(&self) -> Option<&'static str> {
824 Some("Real.tanh_deriv")
825 }
826 }
827
828 pub struct AsinPrimitive;
831
832 impl Primitive for AsinPrimitive {
833 fn name(&self) -> &'static str {
834 "asin"
835 }
836
837 fn pretty(&self, args: &[ExprId], pool: &ExprPool) -> String {
838 format!("asin({})", pool.display(args[0]))
839 }
840
841 fn diff_forward(&self, args: &[ExprId], wrt: ExprId, pool: &ExprPool) -> Option<ExprId> {
842 let x = args[0];
844 let dx = crate::diff::diff(x, wrt, pool).ok()?.value;
845 let x2 = pool.pow(x, pool.integer(2_i32));
846 let one = pool.integer(1_i32);
847 let neg_one = pool.integer(-1_i32);
848 let one_minus_x2 = pool.add(vec![one, pool.mul(vec![neg_one, x2])]);
849 let denom = pool.func("sqrt", vec![one_minus_x2]);
850 Some(pool.mul(vec![dx, pool.pow(denom, pool.integer(-1_i32))]))
851 }
852
853 fn diff_reverse(
854 &self,
855 args: &[ExprId],
856 cotan: ExprId,
857 pool: &ExprPool,
858 ) -> Option<Vec<ExprId>> {
859 let x = args[0];
860 let x2 = pool.pow(x, pool.integer(2_i32));
861 let one = pool.integer(1_i32);
862 let neg_one = pool.integer(-1_i32);
863 let one_minus_x2 = pool.add(vec![one, pool.mul(vec![neg_one, x2])]);
864 let denom = pool.func("sqrt", vec![one_minus_x2]);
865 Some(vec![
866 pool.mul(vec![cotan, pool.pow(denom, pool.integer(-1_i32))])
867 ])
868 }
869
870 fn numeric_f64(&self, args: &[f64]) -> Option<f64> {
871 Some(args[0].asin())
872 }
873
874 fn numeric_ball(&self, args: &[ArbBall]) -> Option<ArbBall> {
875 args[0].asin()
876 }
877 }
878
879 pub struct AcosPrimitive;
882
883 impl Primitive for AcosPrimitive {
884 fn name(&self) -> &'static str {
885 "acos"
886 }
887
888 fn pretty(&self, args: &[ExprId], pool: &ExprPool) -> String {
889 format!("acos({})", pool.display(args[0]))
890 }
891
892 fn diff_forward(&self, args: &[ExprId], wrt: ExprId, pool: &ExprPool) -> Option<ExprId> {
893 let x = args[0];
895 let dx = crate::diff::diff(x, wrt, pool).ok()?.value;
896 let x2 = pool.pow(x, pool.integer(2_i32));
897 let one = pool.integer(1_i32);
898 let neg_one = pool.integer(-1_i32);
899 let one_minus_x2 = pool.add(vec![one, pool.mul(vec![neg_one, x2])]);
900 let denom = pool.func("sqrt", vec![one_minus_x2]);
901 Some(pool.mul(vec![neg_one, dx, pool.pow(denom, pool.integer(-1_i32))]))
902 }
903
904 fn diff_reverse(
905 &self,
906 args: &[ExprId],
907 cotan: ExprId,
908 pool: &ExprPool,
909 ) -> Option<Vec<ExprId>> {
910 let x = args[0];
911 let x2 = pool.pow(x, pool.integer(2_i32));
912 let one = pool.integer(1_i32);
913 let neg_one = pool.integer(-1_i32);
914 let one_minus_x2 = pool.add(vec![one, pool.mul(vec![neg_one, x2])]);
915 let denom = pool.func("sqrt", vec![one_minus_x2]);
916 Some(vec![pool.mul(vec![
917 cotan,
918 neg_one,
919 pool.pow(denom, pool.integer(-1_i32)),
920 ])])
921 }
922
923 fn numeric_f64(&self, args: &[f64]) -> Option<f64> {
924 Some(args[0].acos())
925 }
926
927 fn numeric_ball(&self, args: &[ArbBall]) -> Option<ArbBall> {
928 args[0].acos()
929 }
930 }
931
932 pub struct AtanPrimitive;
935
936 impl Primitive for AtanPrimitive {
937 fn name(&self) -> &'static str {
938 "atan"
939 }
940
941 fn pretty(&self, args: &[ExprId], pool: &ExprPool) -> String {
942 format!("atan({})", pool.display(args[0]))
943 }
944
945 fn diff_forward(&self, args: &[ExprId], wrt: ExprId, pool: &ExprPool) -> Option<ExprId> {
946 let x = args[0];
948 let dx = crate::diff::diff(x, wrt, pool).ok()?.value;
949 let x2 = pool.pow(x, pool.integer(2_i32));
950 let one = pool.integer(1_i32);
951 let denom = pool.add(vec![one, x2]);
952 Some(pool.mul(vec![dx, pool.pow(denom, pool.integer(-1_i32))]))
953 }
954
955 fn diff_reverse(
956 &self,
957 args: &[ExprId],
958 cotan: ExprId,
959 pool: &ExprPool,
960 ) -> Option<Vec<ExprId>> {
961 let x = args[0];
962 let x2 = pool.pow(x, pool.integer(2_i32));
963 let one = pool.integer(1_i32);
964 let denom = pool.add(vec![one, x2]);
965 Some(vec![
966 pool.mul(vec![cotan, pool.pow(denom, pool.integer(-1_i32))])
967 ])
968 }
969
970 fn numeric_f64(&self, args: &[f64]) -> Option<f64> {
971 Some(args[0].atan())
972 }
973
974 fn numeric_ball(&self, args: &[ArbBall]) -> Option<ArbBall> {
975 Some(args[0].atan())
976 }
977
978 fn lean_theorem(&self) -> Option<&'static str> {
979 Some("Real.arctan_deriv")
980 }
981 }
982
983 pub struct ErfPrimitive;
986
987 impl Primitive for ErfPrimitive {
988 fn name(&self) -> &'static str {
989 "erf"
990 }
991
992 fn pretty(&self, args: &[ExprId], pool: &ExprPool) -> String {
993 format!("erf({})", pool.display(args[0]))
994 }
995
996 fn diff_forward(&self, args: &[ExprId], wrt: ExprId, pool: &ExprPool) -> Option<ExprId> {
997 let x = args[0];
999 let dx = crate::diff::diff(x, wrt, pool).ok()?.value;
1000 let x2 = pool.pow(x, pool.integer(2_i32));
1001 let neg_x2 = pool.mul(vec![pool.integer(-1_i32), x2]);
1002 let exp_neg_x2 = pool.func("exp", vec![neg_x2]);
1003 let coeff = pool.float(2.0 / std::f64::consts::PI.sqrt(), 53);
1004 Some(pool.mul(vec![coeff, exp_neg_x2, dx]))
1005 }
1006
1007 fn diff_reverse(
1008 &self,
1009 args: &[ExprId],
1010 cotan: ExprId,
1011 pool: &ExprPool,
1012 ) -> Option<Vec<ExprId>> {
1013 let x = args[0];
1014 let x2 = pool.pow(x, pool.integer(2_i32));
1015 let neg_x2 = pool.mul(vec![pool.integer(-1_i32), x2]);
1016 let exp_neg_x2 = pool.func("exp", vec![neg_x2]);
1017 let coeff = pool.float(2.0 / std::f64::consts::PI.sqrt(), 53);
1018 Some(vec![pool.mul(vec![cotan, coeff, exp_neg_x2])])
1019 }
1020
1021 fn numeric_f64(&self, args: &[f64]) -> Option<f64> {
1022 Some(libm_erf(args[0]))
1023 }
1024
1025 fn numeric_ball(&self, args: &[ArbBall]) -> Option<ArbBall> {
1026 Some(args[0].erf())
1027 }
1028 }
1029
1030 pub struct ErfcPrimitive;
1033
1034 impl Primitive for ErfcPrimitive {
1035 fn name(&self) -> &'static str {
1036 "erfc"
1037 }
1038
1039 fn pretty(&self, args: &[ExprId], pool: &ExprPool) -> String {
1040 format!("erfc({})", pool.display(args[0]))
1041 }
1042
1043 fn diff_forward(&self, args: &[ExprId], wrt: ExprId, pool: &ExprPool) -> Option<ExprId> {
1044 let x = args[0];
1045 let dx = crate::diff::diff(x, wrt, pool).ok()?.value;
1046 let x2 = pool.pow(x, pool.integer(2_i32));
1047 let neg_x2 = pool.mul(vec![pool.integer(-1_i32), x2]);
1048 let exp_neg_x2 = pool.func("exp", vec![neg_x2]);
1049 let coeff = pool.float(-2.0 / std::f64::consts::PI.sqrt(), 53);
1050 Some(pool.mul(vec![coeff, exp_neg_x2, dx]))
1051 }
1052
1053 fn diff_reverse(
1054 &self,
1055 args: &[ExprId],
1056 cotan: ExprId,
1057 pool: &ExprPool,
1058 ) -> Option<Vec<ExprId>> {
1059 let x = args[0];
1060 let x2 = pool.pow(x, pool.integer(2_i32));
1061 let neg_x2 = pool.mul(vec![pool.integer(-1_i32), x2]);
1062 let exp_neg_x2 = pool.func("exp", vec![neg_x2]);
1063 let coeff = pool.float(-2.0 / std::f64::consts::PI.sqrt(), 53);
1064 Some(vec![pool.mul(vec![cotan, coeff, exp_neg_x2])])
1065 }
1066
1067 fn numeric_f64(&self, args: &[f64]) -> Option<f64> {
1068 Some(1.0 - libm_erf(args[0]))
1069 }
1070
1071 fn numeric_ball(&self, args: &[ArbBall]) -> Option<ArbBall> {
1072 Some(args[0].erfc())
1073 }
1074 }
1075
1076 pub struct AbsPrimitive;
1079
1080 impl Primitive for AbsPrimitive {
1081 fn name(&self) -> &'static str {
1082 "abs"
1083 }
1084
1085 fn pretty(&self, args: &[ExprId], pool: &ExprPool) -> String {
1086 format!("|{}|", pool.display(args[0]))
1087 }
1088
1089 fn numeric_f64(&self, args: &[f64]) -> Option<f64> {
1090 Some(args[0].abs())
1091 }
1092
1093 fn numeric_ball(&self, args: &[ArbBall]) -> Option<ArbBall> {
1094 Some(args[0].abs_ball())
1095 }
1096 }
1097
1098 pub struct SignPrimitive;
1101
1102 impl Primitive for SignPrimitive {
1103 fn name(&self) -> &'static str {
1104 "sign"
1105 }
1106
1107 fn pretty(&self, args: &[ExprId], pool: &ExprPool) -> String {
1108 format!("sign({})", pool.display(args[0]))
1109 }
1110
1111 fn numeric_f64(&self, args: &[f64]) -> Option<f64> {
1112 Some(if args[0] > 0.0 {
1113 1.0
1114 } else if args[0] < 0.0 {
1115 -1.0
1116 } else {
1117 0.0
1118 })
1119 }
1120 }
1121
1122 pub struct FloorPrimitive;
1125
1126 impl Primitive for FloorPrimitive {
1127 fn name(&self) -> &'static str {
1128 "floor"
1129 }
1130
1131 fn pretty(&self, args: &[ExprId], pool: &ExprPool) -> String {
1132 format!("floor({})", pool.display(args[0]))
1133 }
1134
1135 fn numeric_f64(&self, args: &[f64]) -> Option<f64> {
1136 Some(args[0].floor())
1137 }
1138
1139 fn numeric_ball(&self, args: &[ArbBall]) -> Option<ArbBall> {
1140 Some(args[0].floor_ball())
1141 }
1142 }
1143
1144 pub struct CeilPrimitive;
1147
1148 impl Primitive for CeilPrimitive {
1149 fn name(&self) -> &'static str {
1150 "ceil"
1151 }
1152
1153 fn pretty(&self, args: &[ExprId], pool: &ExprPool) -> String {
1154 format!("ceil({})", pool.display(args[0]))
1155 }
1156
1157 fn numeric_f64(&self, args: &[f64]) -> Option<f64> {
1158 Some(args[0].ceil())
1159 }
1160
1161 fn numeric_ball(&self, args: &[ArbBall]) -> Option<ArbBall> {
1162 Some(args[0].ceil_ball())
1163 }
1164 }
1165
1166 pub struct RoundPrimitive;
1169
1170 impl Primitive for RoundPrimitive {
1171 fn name(&self) -> &'static str {
1172 "round"
1173 }
1174
1175 fn pretty(&self, args: &[ExprId], pool: &ExprPool) -> String {
1176 format!("round({})", pool.display(args[0]))
1177 }
1178
1179 fn numeric_f64(&self, args: &[f64]) -> Option<f64> {
1180 Some(args[0].round())
1181 }
1182 }
1183
1184 pub struct Atan2Primitive;
1187
1188 impl Primitive for Atan2Primitive {
1189 fn name(&self) -> &'static str {
1190 "atan2"
1191 }
1192
1193 fn pretty(&self, args: &[ExprId], pool: &ExprPool) -> String {
1194 format!(
1195 "atan2({}, {})",
1196 pool.display(args[0]),
1197 pool.display(args[1])
1198 )
1199 }
1200
1201 fn numeric_f64(&self, args: &[f64]) -> Option<f64> {
1202 if args.len() == 2 {
1203 Some(args[0].atan2(args[1]))
1204 } else {
1205 None
1206 }
1207 }
1208
1209 fn lean_theorem(&self) -> Option<&'static str> {
1210 Some("Real.arctan2")
1211 }
1212 }
1213
1214 pub struct GammaPrimitive;
1217
1218 impl Primitive for GammaPrimitive {
1219 fn name(&self) -> &'static str {
1220 "gamma"
1221 }
1222
1223 fn pretty(&self, args: &[ExprId], pool: &ExprPool) -> String {
1224 format!("Γ({})", pool.display(args[0]))
1225 }
1226
1227 fn numeric_f64(&self, args: &[f64]) -> Option<f64> {
1228 Some(libm_gamma(args[0]))
1229 }
1230
1231 fn lean_theorem(&self) -> Option<&'static str> {
1232 Some("Real.Gamma")
1233 }
1234 }
1235
1236 pub struct MinPrimitive;
1239
1240 impl Primitive for MinPrimitive {
1241 fn name(&self) -> &'static str {
1242 "min"
1243 }
1244
1245 fn pretty(&self, args: &[ExprId], pool: &ExprPool) -> String {
1246 format!("min({}, {})", pool.display(args[0]), pool.display(args[1]))
1247 }
1248
1249 fn numeric_f64(&self, args: &[f64]) -> Option<f64> {
1250 if args.len() == 2 {
1251 Some(args[0].min(args[1]))
1252 } else {
1253 None
1254 }
1255 }
1256 }
1257
1258 pub struct MaxPrimitive;
1261
1262 impl Primitive for MaxPrimitive {
1263 fn name(&self) -> &'static str {
1264 "max"
1265 }
1266
1267 fn pretty(&self, args: &[ExprId], pool: &ExprPool) -> String {
1268 format!("max({}, {})", pool.display(args[0]), pool.display(args[1]))
1269 }
1270
1271 fn numeric_f64(&self, args: &[f64]) -> Option<f64> {
1272 if args.len() == 2 {
1273 Some(args[0].max(args[1]))
1274 } else {
1275 None
1276 }
1277 }
1278 }
1279
1280 fn libm_gamma(x: f64) -> f64 {
1286 const G: f64 = 7.0;
1287 const P: [f64; 9] = [
1288 0.999_999_999_999_809_9,
1289 676.520_368_121_885_1,
1290 -1_259.139_216_722_402_8,
1291 771.323_428_777_653_1,
1292 -176.615_029_162_140_6,
1293 12.507_343_278_686_905,
1294 -0.138_571_095_265_720_12,
1295 9.984_369_578_019_572e-6,
1296 1.505_632_735_149_311_6e-7,
1297 ];
1298 if x < 0.5 {
1299 std::f64::consts::PI / ((std::f64::consts::PI * x).sin() * libm_gamma(1.0 - x))
1301 } else {
1302 let xm = x - 1.0;
1303 let mut a = P[0];
1304 for (i, p) in P.iter().enumerate().skip(1) {
1305 a += p / (xm + i as f64);
1306 }
1307 let t = xm + G + 0.5;
1308 (2.0 * std::f64::consts::PI).sqrt() * t.powf(xm + 0.5) * (-t).exp() * a
1309 }
1310 }
1311
1312 fn libm_erf(x: f64) -> f64 {
1318 let t = 1.0 / (1.0 + 0.3275911 * x.abs());
1319 let poly = t
1320 * (0.254829592
1321 + t * (-0.284496736 + t * (1.421413741 + t * (-1.453152027 + t * 1.061405429))));
1322 let sign = if x < 0.0 { -1.0 } else { 1.0 };
1323 sign * (1.0 - poly * (-x * x).exp())
1324 }
1325}
1326
#[cfg(test)]
mod tests {
    use super::*;
    use crate::kernel::{Domain, ExprPool};

    /// The core elementary functions are present with full numeric and derivative support.
    #[test]
    fn default_registry_has_builtins() {
        let reg = PrimitiveRegistry::default_registry();
        let required = [
            (Capabilities::NUMERIC_F64, "NUMERIC_F64"),
            (Capabilities::DIFF_FORWARD, "DIFF_FORWARD"),
            (Capabilities::DIFF_REVERSE, "DIFF_REVERSE"),
            (Capabilities::NUMERIC_BALL, "NUMERIC_BALL"),
        ];
        for name in ["sin", "cos", "exp", "log", "sqrt"] {
            assert!(reg.is_registered(name), "{name} not registered");
            let caps = reg.capabilities(name);
            for (flag, label) in required {
                assert!(caps.contains(flag), "{name} missing {label}");
            }
        }
    }

    /// Spot-checks `numeric_f64` at points with exact closed-form values.
    #[test]
    fn numeric_f64_correct() {
        let reg = PrimitiveRegistry::default_registry();
        let cases = [
            ("sin", 0.0, 0.0),
            ("cos", 0.0, 1.0),
            ("exp", 0.0, 1.0),
            ("log", 1.0, 0.0),
            ("sqrt", 4.0, 2.0),
        ];
        for (name, input, expected) in cases {
            let got: f64 = reg.numeric_f64(name, &[input]).unwrap();
            assert!(
                (got - expected).abs() < 1e-12,
                "{name}({input}) = {got}, expected {expected}"
            );
        }
    }

    /// Forward differentiation of `sin` produces an expression.
    #[test]
    fn diff_forward_sin() {
        let reg = PrimitiveRegistry::default_registry();
        let pool = ExprPool::new();
        let x = pool.symbol("x", Domain::Real);
        assert!(
            reg.diff_forward("sin", &[x], x, &pool).is_some(),
            "sin diff_forward returned None"
        );
    }

    /// The Markdown coverage table mentions registered primitives and marks support.
    #[test]
    fn coverage_report_markdown() {
        let md = PrimitiveRegistry::default_registry()
            .coverage_report()
            .to_markdown();
        assert!(md.contains("sin"), "coverage report missing sin");
        assert!(md.contains("✓"), "coverage report has no ticks");
    }

    /// A user-defined primitive that only implements `numeric_f64` is probed correctly.
    #[test]
    fn custom_primitive_registration() {
        struct CustomTanh;

        impl Primitive for CustomTanh {
            fn name(&self) -> &'static str {
                "tanh"
            }
            fn pretty(&self, args: &[ExprId], pool: &ExprPool) -> String {
                format!("tanh({})", pool.display(args[0]))
            }
            fn numeric_f64(&self, args: &[f64]) -> Option<f64> {
                Some(args[0].tanh())
            }
        }

        let mut reg = PrimitiveRegistry::new();
        reg.register(Box::new(CustomTanh));
        assert!(reg.is_registered("tanh"));

        let caps = reg.capabilities("tanh");
        assert!(caps.contains(Capabilities::NUMERIC_F64));
        assert!(!caps.contains(Capabilities::DIFF_FORWARD));

        let got = reg.numeric_f64("tanh", &[0.0]).unwrap();
        assert!(got.abs() < 1e-12);
    }
}
1423}