Skip to main content

alkahest_cas/primitive/
mod.rs

1//! Primitive registry — central table mapping function names to their
2//! full capability bundles.
3//!
4//! # Motivation
5//!
6//! Before this registry existed, adding a new primitive (e.g. `erf`) required
7//! editing match arms in `simplify/rules.rs`, `diff/forward.rs`,
8//! `diff/reverse.rs`, `jit/mod.rs`, `ball/mod.rs`, and `horner.rs`
9//! independently.  The registry collapses that to one call:
10//! `registry.register(Box::new(ErfPrimitive))`.
11//!
12//! # Design
13//!
14//! Each primitive implements the [`Primitive`] trait.  Every method except
15//! [`Primitive::name`] and [`Primitive::pretty`] is optional — returning
16//! `None` means "not implemented yet".  Callers fall back gracefully
17//! (e.g. `diff_forward` returns a `Derivative(...)` placeholder if the
18//! registry returns `None`).
19//!
20//! The [`Capabilities`] bitfield lets tooling and agents ask
21//! "can I JIT this expression?" without attempting the operation.
22//!
23//! # Example
24//!
25//! ```rust
26//! use alkahest_cas::primitive::{PrimitiveRegistry, Capabilities};
27//!
28//! let reg = PrimitiveRegistry::default_registry();
29//! let caps = reg.capabilities("sin");
30//! assert!(caps.contains(Capabilities::NUMERIC_F64));
31//! assert!(caps.contains(Capabilities::DIFF_FORWARD));
32//!
33//! let report = reg.coverage_report();
34//! // Every built-in has at least NUMERIC_F64 coverage.
35//! for row in &report.rows {
36//!     assert!(row.caps.contains(Capabilities::NUMERIC_F64));
37//! }
38//! ```
39
40use crate::ball::ArbBall;
41use crate::kernel::{ExprId, ExprPool};
42use std::collections::HashMap;
43use std::fmt;
44
45// ---------------------------------------------------------------------------
46// Capability flags
47// ---------------------------------------------------------------------------
48
49bitflags::bitflags! {
50    /// Bit-field recording which capability bundle slots a primitive has filled.
51    #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
52    pub struct Capabilities: u32 {
53        const SIMPLIFY      = 1 << 0;
54        const DIFF_FORWARD  = 1 << 1;
55        const DIFF_REVERSE  = 1 << 2;
56        const NUMERIC_F64   = 1 << 3;
57        const NUMERIC_BALL  = 1 << 4;
58        const LOWER_LLVM    = 1 << 5;
59        const LEAN_THEOREM  = 1 << 6;
60    }
61}
62
63impl fmt::Display for Capabilities {
64    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
65        let names = [
66            (Capabilities::SIMPLIFY, "simplify"),
67            (Capabilities::DIFF_FORWARD, "diff_fwd"),
68            (Capabilities::DIFF_REVERSE, "diff_rev"),
69            (Capabilities::NUMERIC_F64, "numeric_f64"),
70            (Capabilities::NUMERIC_BALL, "numeric_ball"),
71            (Capabilities::LOWER_LLVM, "lower_llvm"),
72            (Capabilities::LEAN_THEOREM, "lean"),
73        ];
74        let present: Vec<&str> = names
75            .iter()
76            .filter(|(flag, _)| self.contains(*flag))
77            .map(|(_, name)| *name)
78            .collect();
79        write!(f, "[{}]", present.join(", "))
80    }
81}
82
83// ---------------------------------------------------------------------------
84// Primitive trait
85// ---------------------------------------------------------------------------
86
87/// A primitive function that can be registered in [`PrimitiveRegistry`].
88///
89/// Only [`name`](Primitive::name) and [`pretty`](Primitive::pretty) are
90/// required; every other method is optional and defaults to returning `None`.
91pub trait Primitive: 'static + Send + Sync {
92    /// The canonical name used in `ExprData::Func { name, .. }`.
93    fn name(&self) -> &'static str;
94
95    /// Human-readable display.  Called by `ExprDisplay` for `Func` nodes.
96    fn pretty(&self, args: &[ExprId], pool: &ExprPool) -> String;
97
98    // ── Optional capability bundle ──────────────────────────────────────────
99
100    /// Algebraic simplification: try to reduce `self(args)` to a simpler form.
101    fn simplify(&self, _args: &[ExprId], _pool: &ExprPool) -> Option<ExprId> {
102        None
103    }
104
105    /// Forward-mode differentiation: `d/d_wrt (self(args))`.
106    /// Returns `None` if not implemented (caller should return a placeholder).
107    fn diff_forward(&self, _args: &[ExprId], _wrt: ExprId, _pool: &ExprPool) -> Option<ExprId> {
108        None
109    }
110
111    /// Reverse-mode differentiation: cotangent propagation.
112    /// Returns one cotangent per argument, or `None` if not implemented.
113    fn diff_reverse(
114        &self,
115        _args: &[ExprId],
116        _cotan: ExprId,
117        _pool: &ExprPool,
118    ) -> Option<Vec<ExprId>> {
119        None
120    }
121
122    /// Numerical evaluation at `f64` precision.
123    fn numeric_f64(&self, _args: &[f64]) -> Option<f64> {
124        None
125    }
126
127    /// Rigorous ball-arithmetic evaluation.
128    fn numeric_ball(&self, _args: &[ArbBall]) -> Option<ArbBall> {
129        None
130    }
131
132    /// Name of the Lean 4 / Mathlib theorem that certifies this primitive.
133    fn lean_theorem(&self) -> Option<&'static str> {
134        None
135    }
136}
137
138// ---------------------------------------------------------------------------
139// Coverage report
140// ---------------------------------------------------------------------------
141
/// One row in the coverage report table.
#[derive(Debug, Clone)]
pub struct CoverageRow {
    /// Name of the registered primitive (its `Primitive::name`).
    pub name: String,
    /// Capabilities recorded for that primitive when it was registered.
    pub caps: Capabilities,
}
148
/// Human-readable and machine-readable coverage table for all registered
/// primitives.  Returned by [`PrimitiveRegistry::coverage_report`].
#[derive(Debug, Clone)]
pub struct CoverageReport {
    /// One row per registered primitive; `coverage_report` sorts them by name.
    pub rows: Vec<CoverageRow>,
}
155
156impl CoverageReport {
157    /// Render as a Markdown table (suitable for CI PR comments or docs).
158    pub fn to_markdown(&self) -> String {
159        let header = "| Primitive | simplify | diff_fwd | diff_rev | numeric_f64 | numeric_ball | lower_llvm | lean |\n\
160                      |---|---|---|---|---|---|---|---|";
161        let rows: Vec<String> = self
162            .rows
163            .iter()
164            .map(|r| {
165                let tick = |flag: Capabilities| {
166                    if r.caps.contains(flag) {
167                        "✓"
168                    } else {
169                        "✗"
170                    }
171                };
172                format!(
173                    "| {} | {} | {} | {} | {} | {} | {} | {} |",
174                    r.name,
175                    tick(Capabilities::SIMPLIFY),
176                    tick(Capabilities::DIFF_FORWARD),
177                    tick(Capabilities::DIFF_REVERSE),
178                    tick(Capabilities::NUMERIC_F64),
179                    tick(Capabilities::NUMERIC_BALL),
180                    tick(Capabilities::LOWER_LLVM),
181                    tick(Capabilities::LEAN_THEOREM),
182                )
183            })
184            .collect();
185        format!("{}\n{}", header, rows.join("\n"))
186    }
187}
188
189// ---------------------------------------------------------------------------
190// Registry entry (stores the primitive + probed capabilities)
191// ---------------------------------------------------------------------------
192
/// Registry slot: a boxed primitive together with the capability set that
/// was probed for it when it was registered.
struct Entry {
    /// The registered implementation.
    primitive: Box<dyn Primitive>,
    /// Result of `probe_caps` on `primitive`, computed once in `register`.
    caps: Capabilities,
}
197
198// ---------------------------------------------------------------------------
199// PrimitiveRegistry
200// ---------------------------------------------------------------------------
201
/// Central registry mapping function names to their [`Primitive`]
/// implementations.
///
/// Use [`default_registry`](PrimitiveRegistry::default_registry) to get a
/// registry pre-populated with Alkahest's built-in functions.
pub struct PrimitiveRegistry {
    // Keyed by `Primitive::name()`; a `HashMap`, so at most one entry per
    // name and iteration order is unspecified.
    map: HashMap<&'static str, Entry>,
}
210
211impl PrimitiveRegistry {
212    /// Create an empty registry.
213    pub fn new() -> Self {
214        PrimitiveRegistry {
215            map: HashMap::new(),
216        }
217    }
218
219    /// Register a primitive.  Probes capabilities by calling each optional
220    /// method with a canonical zero-element input and recording which ones
221    /// return `Some`.
222    pub fn register(&mut self, p: Box<dyn Primitive>) {
223        let caps = probe_caps(&*p);
224        let name = p.name();
225        self.map.insert(name, Entry { primitive: p, caps });
226    }
227
228    /// Look up a primitive by name.
229    pub fn get(&self, name: &str) -> Option<&dyn Primitive> {
230        self.map.get(name).map(|e| &*e.primitive)
231    }
232
233    /// Return the [`Capabilities`] bitfield for a named primitive.
234    /// Returns `Capabilities::empty()` if the primitive is not registered.
235    pub fn capabilities(&self, name: &str) -> Capabilities {
236        self.map
237            .get(name)
238            .map(|e| e.caps)
239            .unwrap_or(Capabilities::empty())
240    }
241
242    /// Generate a coverage table for all registered primitives, sorted by
243    /// name.
244    pub fn coverage_report(&self) -> CoverageReport {
245        let mut rows: Vec<CoverageRow> = self
246            .map
247            .iter()
248            .map(|(name, e)| CoverageRow {
249                name: name.to_string(),
250                caps: e.caps,
251            })
252            .collect();
253        rows.sort_by(|a, b| a.name.cmp(&b.name));
254        CoverageReport { rows }
255    }
256
257    /// Call `diff_forward` on a registered primitive.
258    /// Returns `None` if the primitive is unknown or lacks `DIFF_FORWARD`.
259    pub fn diff_forward(
260        &self,
261        name: &str,
262        args: &[ExprId],
263        wrt: ExprId,
264        pool: &ExprPool,
265    ) -> Option<ExprId> {
266        let entry = self.map.get(name)?;
267        entry.primitive.diff_forward(args, wrt, pool)
268    }
269
270    /// Call `diff_reverse` on a registered primitive.
271    pub fn diff_reverse(
272        &self,
273        name: &str,
274        args: &[ExprId],
275        cotan: ExprId,
276        pool: &ExprPool,
277    ) -> Option<Vec<ExprId>> {
278        let entry = self.map.get(name)?;
279        entry.primitive.diff_reverse(args, cotan, pool)
280    }
281
282    /// Call `numeric_f64` on a registered primitive.
283    pub fn numeric_f64(&self, name: &str, args: &[f64]) -> Option<f64> {
284        let entry = self.map.get(name)?;
285        entry.primitive.numeric_f64(args)
286    }
287
288    /// Call `numeric_ball` on a registered primitive.
289    pub fn numeric_ball(&self, name: &str, args: &[ArbBall]) -> Option<ArbBall> {
290        let entry = self.map.get(name)?;
291        entry.primitive.numeric_ball(args)
292    }
293
294    /// Return a registry pre-populated with Alkahest's built-in primitives.
295    pub fn default_registry() -> Self {
296        let mut reg = Self::new();
297        reg.register(Box::new(builtins::SinPrimitive));
298        reg.register(Box::new(builtins::CosPrimitive));
299        reg.register(Box::new(builtins::ExpPrimitive));
300        reg.register(Box::new(builtins::LogPrimitive));
301        reg.register(Box::new(builtins::SqrtPrimitive));
302        // V1-12: expanded registry
303        reg.register(Box::new(builtins::TanPrimitive));
304        reg.register(Box::new(builtins::SinhPrimitive));
305        reg.register(Box::new(builtins::CoshPrimitive));
306        reg.register(Box::new(builtins::TanhPrimitive));
307        reg.register(Box::new(builtins::AsinPrimitive));
308        reg.register(Box::new(builtins::AcosPrimitive));
309        reg.register(Box::new(builtins::AtanPrimitive));
310        reg.register(Box::new(builtins::ErfPrimitive));
311        reg.register(Box::new(builtins::ErfcPrimitive));
312        reg.register(Box::new(builtins::AbsPrimitive));
313        reg.register(Box::new(builtins::SignPrimitive));
314        reg.register(Box::new(builtins::FloorPrimitive));
315        reg.register(Box::new(builtins::CeilPrimitive));
316        reg.register(Box::new(builtins::RoundPrimitive));
317        reg.register(Box::new(builtins::Atan2Primitive));
318        reg.register(Box::new(builtins::GammaPrimitive));
319        reg.register(Box::new(builtins::MinPrimitive));
320        reg.register(Box::new(builtins::MaxPrimitive));
321        reg
322    }
323
324    /// Returns true if a primitive with this name is registered.
325    pub fn is_registered(&self, name: &str) -> bool {
326        self.map.contains_key(name)
327    }
328
329    /// Iterate over all registered (name, capabilities) pairs.
330    pub fn iter(&self) -> impl Iterator<Item = (&str, Capabilities)> {
331        self.map.iter().map(|(k, e)| (*k, e.caps))
332    }
333}
334
335impl Default for PrimitiveRegistry {
336    fn default() -> Self {
337        Self::default_registry()
338    }
339}
340
341// ---------------------------------------------------------------------------
342// Capability probing
343// ---------------------------------------------------------------------------
344
345/// Probe which optional methods a primitive has implemented by calling them
346/// with a single-element `f64 = 1.0` / `ExprId` argument.  We use an
347/// independent `ExprPool` so that probing is side-effect free.
348fn probe_caps(p: &dyn Primitive) -> Capabilities {
349    let mut caps = Capabilities::empty();
350
351    // Probe with both unary and binary argument lists so n-ary primitives
352    // (e.g. atan2, min, max) register their capabilities.
353    let probe_f64_sets: [&[f64]; 2] = [&[1.0], &[1.0, 2.0]];
354    for args in probe_f64_sets {
355        if p.numeric_f64(args).is_some() {
356            caps |= Capabilities::NUMERIC_F64;
357            break;
358        }
359    }
360
361    let ball1 = [ArbBall::from_f64(1.0, 128)];
362    let ball2 = [ArbBall::from_f64(1.0, 128), ArbBall::from_f64(2.0, 128)];
363    if p.numeric_ball(&ball1).is_some() || p.numeric_ball(&ball2).is_some() {
364        caps |= Capabilities::NUMERIC_BALL;
365    }
366
367    // diff_forward / diff_reverse / simplify: probe with a fresh pool
368    let pool = ExprPool::new();
369    let x = pool.symbol("_probe", crate::kernel::Domain::Real);
370    let y = pool.symbol("_probe_y", crate::kernel::Domain::Real);
371    let probe_id_sets: [Vec<ExprId>; 2] = [vec![x], vec![x, y]];
372
373    for args in &probe_id_sets {
374        if p.diff_forward(args, x, &pool).is_some() {
375            caps |= Capabilities::DIFF_FORWARD;
376            break;
377        }
378    }
379    for args in &probe_id_sets {
380        if p.diff_reverse(args, x, &pool).is_some() {
381            caps |= Capabilities::DIFF_REVERSE;
382            break;
383        }
384    }
385    for args in &probe_id_sets {
386        if p.simplify(args, &pool).is_some() {
387            caps |= Capabilities::SIMPLIFY;
388            break;
389        }
390    }
391    if p.lean_theorem().is_some() {
392        caps |= Capabilities::LEAN_THEOREM;
393    }
394    caps
395}
396
397// ---------------------------------------------------------------------------
398// Built-in primitives
399// ---------------------------------------------------------------------------
400
401pub mod builtins {
402    use super::Primitive;
403    use crate::ball::ArbBall;
404    use crate::kernel::{ExprId, ExprPool};
405
406    // ── sin ──────────────────────────────────────────────────────────────────
407
408    pub struct SinPrimitive;
409
410    impl Primitive for SinPrimitive {
411        fn name(&self) -> &'static str {
412            "sin"
413        }
414
415        fn pretty(&self, args: &[ExprId], pool: &ExprPool) -> String {
416            format!("sin({})", pool.display(args[0]))
417        }
418
419        fn diff_forward(&self, args: &[ExprId], wrt: ExprId, pool: &ExprPool) -> Option<ExprId> {
420            let x = args[0];
421            let dx = crate::diff::diff(x, wrt, pool).ok()?.value;
422            let cos_x = pool.func("cos", vec![x]);
423            Some(pool.mul(vec![cos_x, dx]))
424        }
425
426        fn diff_reverse(
427            &self,
428            args: &[ExprId],
429            cotan: ExprId,
430            pool: &ExprPool,
431        ) -> Option<Vec<ExprId>> {
432            let x = args[0];
433            let cos_x = pool.func("cos", vec![x]);
434            Some(vec![pool.mul(vec![cotan, cos_x])])
435        }
436
437        fn numeric_f64(&self, args: &[f64]) -> Option<f64> {
438            Some(args[0].sin())
439        }
440
441        fn numeric_ball(&self, args: &[ArbBall]) -> Option<ArbBall> {
442            Some(args[0].sin())
443        }
444
445        fn lean_theorem(&self) -> Option<&'static str> {
446            Some("Real.sin_deriv")
447        }
448    }
449
450    // ── cos ──────────────────────────────────────────────────────────────────
451
452    pub struct CosPrimitive;
453
454    impl Primitive for CosPrimitive {
455        fn name(&self) -> &'static str {
456            "cos"
457        }
458
459        fn pretty(&self, args: &[ExprId], pool: &ExprPool) -> String {
460            format!("cos({})", pool.display(args[0]))
461        }
462
463        fn diff_forward(&self, args: &[ExprId], wrt: ExprId, pool: &ExprPool) -> Option<ExprId> {
464            let x = args[0];
465            let dx = crate::diff::diff(x, wrt, pool).ok()?.value;
466            let neg_one = pool.integer(-1_i32);
467            let sin_x = pool.func("sin", vec![x]);
468            Some(pool.mul(vec![neg_one, sin_x, dx]))
469        }
470
471        fn diff_reverse(
472            &self,
473            args: &[ExprId],
474            cotan: ExprId,
475            pool: &ExprPool,
476        ) -> Option<Vec<ExprId>> {
477            let x = args[0];
478            let neg_one = pool.integer(-1_i32);
479            let sin_x = pool.func("sin", vec![x]);
480            Some(vec![pool.mul(vec![cotan, neg_one, sin_x])])
481        }
482
483        fn numeric_f64(&self, args: &[f64]) -> Option<f64> {
484            Some(args[0].cos())
485        }
486
487        fn numeric_ball(&self, args: &[ArbBall]) -> Option<ArbBall> {
488            Some(args[0].cos())
489        }
490
491        fn lean_theorem(&self) -> Option<&'static str> {
492            Some("Real.cos_deriv")
493        }
494    }
495
496    // ── exp ──────────────────────────────────────────────────────────────────
497
498    pub struct ExpPrimitive;
499
500    impl Primitive for ExpPrimitive {
501        fn name(&self) -> &'static str {
502            "exp"
503        }
504
505        fn pretty(&self, args: &[ExprId], pool: &ExprPool) -> String {
506            format!("exp({})", pool.display(args[0]))
507        }
508
509        fn diff_forward(&self, args: &[ExprId], wrt: ExprId, pool: &ExprPool) -> Option<ExprId> {
510            let x = args[0];
511            let dx = crate::diff::diff(x, wrt, pool).ok()?.value;
512            let exp_x = pool.func("exp", vec![x]);
513            Some(pool.mul(vec![exp_x, dx]))
514        }
515
516        fn diff_reverse(
517            &self,
518            args: &[ExprId],
519            cotan: ExprId,
520            pool: &ExprPool,
521        ) -> Option<Vec<ExprId>> {
522            let x = args[0];
523            let exp_x = pool.func("exp", vec![x]);
524            Some(vec![pool.mul(vec![cotan, exp_x])])
525        }
526
527        fn numeric_f64(&self, args: &[f64]) -> Option<f64> {
528            Some(args[0].exp())
529        }
530
531        fn numeric_ball(&self, args: &[ArbBall]) -> Option<ArbBall> {
532            Some(args[0].exp())
533        }
534
535        fn lean_theorem(&self) -> Option<&'static str> {
536            Some("Real.exp_deriv")
537        }
538    }
539
540    // ── log ──────────────────────────────────────────────────────────────────
541
542    pub struct LogPrimitive;
543
544    impl Primitive for LogPrimitive {
545        fn name(&self) -> &'static str {
546            "log"
547        }
548
549        fn pretty(&self, args: &[ExprId], pool: &ExprPool) -> String {
550            format!("log({})", pool.display(args[0]))
551        }
552
553        fn diff_forward(&self, args: &[ExprId], wrt: ExprId, pool: &ExprPool) -> Option<ExprId> {
554            let x = args[0];
555            let dx = crate::diff::diff(x, wrt, pool).ok()?.value;
556            // d/dwrt log(x) = dx / x
557            let x_inv = pool.pow(x, pool.integer(-1_i32));
558            Some(pool.mul(vec![x_inv, dx]))
559        }
560
561        fn diff_reverse(
562            &self,
563            args: &[ExprId],
564            cotan: ExprId,
565            pool: &ExprPool,
566        ) -> Option<Vec<ExprId>> {
567            let x = args[0];
568            let x_inv = pool.pow(x, pool.integer(-1_i32));
569            Some(vec![pool.mul(vec![cotan, x_inv])])
570        }
571
572        fn numeric_f64(&self, args: &[f64]) -> Option<f64> {
573            Some(args[0].ln())
574        }
575
576        fn numeric_ball(&self, args: &[ArbBall]) -> Option<ArbBall> {
577            args[0].log()
578        }
579
580        fn lean_theorem(&self) -> Option<&'static str> {
581            Some("Real.log_deriv")
582        }
583    }
584
585    // ── sqrt ─────────────────────────────────────────────────────────────────
586
587    pub struct SqrtPrimitive;
588
589    impl Primitive for SqrtPrimitive {
590        fn name(&self) -> &'static str {
591            "sqrt"
592        }
593
594        fn pretty(&self, args: &[ExprId], pool: &ExprPool) -> String {
595            format!("sqrt({})", pool.display(args[0]))
596        }
597
598        fn diff_forward(&self, args: &[ExprId], wrt: ExprId, pool: &ExprPool) -> Option<ExprId> {
599            let x = args[0];
600            let dx = crate::diff::diff(x, wrt, pool).ok()?.value;
601            // d/dwrt sqrt(x) = dx / (2 * sqrt(x))
602            let sqrt_x = pool.func("sqrt", vec![x]);
603            let two = pool.integer(2_i32);
604            let denom = pool.mul(vec![two, sqrt_x]);
605            let denom_inv = pool.pow(denom, pool.integer(-1_i32));
606            Some(pool.mul(vec![dx, denom_inv]))
607        }
608
609        fn diff_reverse(
610            &self,
611            args: &[ExprId],
612            cotan: ExprId,
613            pool: &ExprPool,
614        ) -> Option<Vec<ExprId>> {
615            let x = args[0];
616            let sqrt_x = pool.func("sqrt", vec![x]);
617            let two = pool.integer(2_i32);
618            let denom = pool.mul(vec![two, sqrt_x]);
619            let denom_inv = pool.pow(denom, pool.integer(-1_i32));
620            Some(vec![pool.mul(vec![cotan, denom_inv])])
621        }
622
623        fn numeric_f64(&self, args: &[f64]) -> Option<f64> {
624            Some(args[0].sqrt())
625        }
626
627        fn numeric_ball(&self, args: &[ArbBall]) -> Option<ArbBall> {
628            args[0].sqrt()
629        }
630
631        fn lean_theorem(&self) -> Option<&'static str> {
632            Some("Real.sqrt_deriv")
633        }
634    }
635
636    // ── tan ──────────────────────────────────────────────────────────────────
637
638    pub struct TanPrimitive;
639
640    impl Primitive for TanPrimitive {
641        fn name(&self) -> &'static str {
642            "tan"
643        }
644
645        fn pretty(&self, args: &[ExprId], pool: &ExprPool) -> String {
646            format!("tan({})", pool.display(args[0]))
647        }
648
649        fn diff_forward(&self, args: &[ExprId], wrt: ExprId, pool: &ExprPool) -> Option<ExprId> {
650            // d/dx tan(x) = dx / cos²(x) = dx * (1 + tan²(x))
651            let x = args[0];
652            let dx = crate::diff::diff(x, wrt, pool).ok()?.value;
653            let tan_x = pool.func("tan", vec![x]);
654            let tan2 = pool.pow(tan_x, pool.integer(2_i32));
655            let one = pool.integer(1_i32);
656            let sec2 = pool.add(vec![one, tan2]);
657            Some(pool.mul(vec![sec2, dx]))
658        }
659
660        fn diff_reverse(
661            &self,
662            args: &[ExprId],
663            cotan: ExprId,
664            pool: &ExprPool,
665        ) -> Option<Vec<ExprId>> {
666            let x = args[0];
667            let tan_x = pool.func("tan", vec![x]);
668            let tan2 = pool.pow(tan_x, pool.integer(2_i32));
669            let one = pool.integer(1_i32);
670            let sec2 = pool.add(vec![one, tan2]);
671            Some(vec![pool.mul(vec![cotan, sec2])])
672        }
673
674        fn numeric_f64(&self, args: &[f64]) -> Option<f64> {
675            Some(args[0].tan())
676        }
677
678        fn numeric_ball(&self, args: &[ArbBall]) -> Option<ArbBall> {
679            args[0].tan()
680        }
681
682        fn lean_theorem(&self) -> Option<&'static str> {
683            Some("Real.tan_deriv")
684        }
685    }
686
687    // ── sinh ─────────────────────────────────────────────────────────────────
688
689    pub struct SinhPrimitive;
690
691    impl Primitive for SinhPrimitive {
692        fn name(&self) -> &'static str {
693            "sinh"
694        }
695
696        fn pretty(&self, args: &[ExprId], pool: &ExprPool) -> String {
697            format!("sinh({})", pool.display(args[0]))
698        }
699
700        fn diff_forward(&self, args: &[ExprId], wrt: ExprId, pool: &ExprPool) -> Option<ExprId> {
701            let x = args[0];
702            let dx = crate::diff::diff(x, wrt, pool).ok()?.value;
703            let cosh_x = pool.func("cosh", vec![x]);
704            Some(pool.mul(vec![cosh_x, dx]))
705        }
706
707        fn diff_reverse(
708            &self,
709            args: &[ExprId],
710            cotan: ExprId,
711            pool: &ExprPool,
712        ) -> Option<Vec<ExprId>> {
713            let x = args[0];
714            let cosh_x = pool.func("cosh", vec![x]);
715            Some(vec![pool.mul(vec![cotan, cosh_x])])
716        }
717
718        fn numeric_f64(&self, args: &[f64]) -> Option<f64> {
719            Some(args[0].sinh())
720        }
721
722        fn numeric_ball(&self, args: &[ArbBall]) -> Option<ArbBall> {
723            Some(args[0].sinh())
724        }
725
726        fn lean_theorem(&self) -> Option<&'static str> {
727            Some("Real.sinh_deriv")
728        }
729    }
730
731    // ── cosh ─────────────────────────────────────────────────────────────────
732
733    pub struct CoshPrimitive;
734
735    impl Primitive for CoshPrimitive {
736        fn name(&self) -> &'static str {
737            "cosh"
738        }
739
740        fn pretty(&self, args: &[ExprId], pool: &ExprPool) -> String {
741            format!("cosh({})", pool.display(args[0]))
742        }
743
744        fn diff_forward(&self, args: &[ExprId], wrt: ExprId, pool: &ExprPool) -> Option<ExprId> {
745            let x = args[0];
746            let dx = crate::diff::diff(x, wrt, pool).ok()?.value;
747            let sinh_x = pool.func("sinh", vec![x]);
748            Some(pool.mul(vec![sinh_x, dx]))
749        }
750
751        fn diff_reverse(
752            &self,
753            args: &[ExprId],
754            cotan: ExprId,
755            pool: &ExprPool,
756        ) -> Option<Vec<ExprId>> {
757            let x = args[0];
758            let sinh_x = pool.func("sinh", vec![x]);
759            Some(vec![pool.mul(vec![cotan, sinh_x])])
760        }
761
762        fn numeric_f64(&self, args: &[f64]) -> Option<f64> {
763            Some(args[0].cosh())
764        }
765
766        fn numeric_ball(&self, args: &[ArbBall]) -> Option<ArbBall> {
767            Some(args[0].cosh())
768        }
769
770        fn lean_theorem(&self) -> Option<&'static str> {
771            Some("Real.cosh_deriv")
772        }
773    }
774
775    // ── tanh ─────────────────────────────────────────────────────────────────
776
777    pub struct TanhPrimitive;
778
779    impl Primitive for TanhPrimitive {
780        fn name(&self) -> &'static str {
781            "tanh"
782        }
783
784        fn pretty(&self, args: &[ExprId], pool: &ExprPool) -> String {
785            format!("tanh({})", pool.display(args[0]))
786        }
787
788        fn diff_forward(&self, args: &[ExprId], wrt: ExprId, pool: &ExprPool) -> Option<ExprId> {
789            // d/dx tanh(x) = dx * (1 - tanh²(x))
790            let x = args[0];
791            let dx = crate::diff::diff(x, wrt, pool).ok()?.value;
792            let tanh_x = pool.func("tanh", vec![x]);
793            let tanh2 = pool.pow(tanh_x, pool.integer(2_i32));
794            let one = pool.integer(1_i32);
795            let neg_one = pool.integer(-1_i32);
796            let sech2 = pool.add(vec![one, pool.mul(vec![neg_one, tanh2])]);
797            Some(pool.mul(vec![sech2, dx]))
798        }
799
800        fn diff_reverse(
801            &self,
802            args: &[ExprId],
803            cotan: ExprId,
804            pool: &ExprPool,
805        ) -> Option<Vec<ExprId>> {
806            let x = args[0];
807            let tanh_x = pool.func("tanh", vec![x]);
808            let tanh2 = pool.pow(tanh_x, pool.integer(2_i32));
809            let one = pool.integer(1_i32);
810            let neg_one = pool.integer(-1_i32);
811            let sech2 = pool.add(vec![one, pool.mul(vec![neg_one, tanh2])]);
812            Some(vec![pool.mul(vec![cotan, sech2])])
813        }
814
815        fn numeric_f64(&self, args: &[f64]) -> Option<f64> {
816            Some(args[0].tanh())
817        }
818
819        fn numeric_ball(&self, args: &[ArbBall]) -> Option<ArbBall> {
820            Some(args[0].tanh())
821        }
822
823        fn lean_theorem(&self) -> Option<&'static str> {
824            Some("Real.tanh_deriv")
825        }
826    }
827
828    // ── asin ─────────────────────────────────────────────────────────────────
829
830    pub struct AsinPrimitive;
831
832    impl Primitive for AsinPrimitive {
833        fn name(&self) -> &'static str {
834            "asin"
835        }
836
837        fn pretty(&self, args: &[ExprId], pool: &ExprPool) -> String {
838            format!("asin({})", pool.display(args[0]))
839        }
840
841        fn diff_forward(&self, args: &[ExprId], wrt: ExprId, pool: &ExprPool) -> Option<ExprId> {
842            // d/dx asin(x) = dx / sqrt(1 - x²)
843            let x = args[0];
844            let dx = crate::diff::diff(x, wrt, pool).ok()?.value;
845            let x2 = pool.pow(x, pool.integer(2_i32));
846            let one = pool.integer(1_i32);
847            let neg_one = pool.integer(-1_i32);
848            let one_minus_x2 = pool.add(vec![one, pool.mul(vec![neg_one, x2])]);
849            let denom = pool.func("sqrt", vec![one_minus_x2]);
850            Some(pool.mul(vec![dx, pool.pow(denom, pool.integer(-1_i32))]))
851        }
852
853        fn diff_reverse(
854            &self,
855            args: &[ExprId],
856            cotan: ExprId,
857            pool: &ExprPool,
858        ) -> Option<Vec<ExprId>> {
859            let x = args[0];
860            let x2 = pool.pow(x, pool.integer(2_i32));
861            let one = pool.integer(1_i32);
862            let neg_one = pool.integer(-1_i32);
863            let one_minus_x2 = pool.add(vec![one, pool.mul(vec![neg_one, x2])]);
864            let denom = pool.func("sqrt", vec![one_minus_x2]);
865            Some(vec![
866                pool.mul(vec![cotan, pool.pow(denom, pool.integer(-1_i32))])
867            ])
868        }
869
870        fn numeric_f64(&self, args: &[f64]) -> Option<f64> {
871            Some(args[0].asin())
872        }
873
874        fn numeric_ball(&self, args: &[ArbBall]) -> Option<ArbBall> {
875            args[0].asin()
876        }
877    }
878
879    // ── acos ─────────────────────────────────────────────────────────────────
880
881    pub struct AcosPrimitive;
882
883    impl Primitive for AcosPrimitive {
884        fn name(&self) -> &'static str {
885            "acos"
886        }
887
888        fn pretty(&self, args: &[ExprId], pool: &ExprPool) -> String {
889            format!("acos({})", pool.display(args[0]))
890        }
891
892        fn diff_forward(&self, args: &[ExprId], wrt: ExprId, pool: &ExprPool) -> Option<ExprId> {
893            // d/dx acos(x) = -dx / sqrt(1 - x²)
894            let x = args[0];
895            let dx = crate::diff::diff(x, wrt, pool).ok()?.value;
896            let x2 = pool.pow(x, pool.integer(2_i32));
897            let one = pool.integer(1_i32);
898            let neg_one = pool.integer(-1_i32);
899            let one_minus_x2 = pool.add(vec![one, pool.mul(vec![neg_one, x2])]);
900            let denom = pool.func("sqrt", vec![one_minus_x2]);
901            Some(pool.mul(vec![neg_one, dx, pool.pow(denom, pool.integer(-1_i32))]))
902        }
903
904        fn diff_reverse(
905            &self,
906            args: &[ExprId],
907            cotan: ExprId,
908            pool: &ExprPool,
909        ) -> Option<Vec<ExprId>> {
910            let x = args[0];
911            let x2 = pool.pow(x, pool.integer(2_i32));
912            let one = pool.integer(1_i32);
913            let neg_one = pool.integer(-1_i32);
914            let one_minus_x2 = pool.add(vec![one, pool.mul(vec![neg_one, x2])]);
915            let denom = pool.func("sqrt", vec![one_minus_x2]);
916            Some(vec![pool.mul(vec![
917                cotan,
918                neg_one,
919                pool.pow(denom, pool.integer(-1_i32)),
920            ])])
921        }
922
923        fn numeric_f64(&self, args: &[f64]) -> Option<f64> {
924            Some(args[0].acos())
925        }
926
927        fn numeric_ball(&self, args: &[ArbBall]) -> Option<ArbBall> {
928            args[0].acos()
929        }
930    }
931
932    // ── atan ─────────────────────────────────────────────────────────────────
933
934    pub struct AtanPrimitive;
935
936    impl Primitive for AtanPrimitive {
937        fn name(&self) -> &'static str {
938            "atan"
939        }
940
941        fn pretty(&self, args: &[ExprId], pool: &ExprPool) -> String {
942            format!("atan({})", pool.display(args[0]))
943        }
944
945        fn diff_forward(&self, args: &[ExprId], wrt: ExprId, pool: &ExprPool) -> Option<ExprId> {
946            // d/dx atan(x) = dx / (1 + x²)
947            let x = args[0];
948            let dx = crate::diff::diff(x, wrt, pool).ok()?.value;
949            let x2 = pool.pow(x, pool.integer(2_i32));
950            let one = pool.integer(1_i32);
951            let denom = pool.add(vec![one, x2]);
952            Some(pool.mul(vec![dx, pool.pow(denom, pool.integer(-1_i32))]))
953        }
954
955        fn diff_reverse(
956            &self,
957            args: &[ExprId],
958            cotan: ExprId,
959            pool: &ExprPool,
960        ) -> Option<Vec<ExprId>> {
961            let x = args[0];
962            let x2 = pool.pow(x, pool.integer(2_i32));
963            let one = pool.integer(1_i32);
964            let denom = pool.add(vec![one, x2]);
965            Some(vec![
966                pool.mul(vec![cotan, pool.pow(denom, pool.integer(-1_i32))])
967            ])
968        }
969
970        fn numeric_f64(&self, args: &[f64]) -> Option<f64> {
971            Some(args[0].atan())
972        }
973
974        fn numeric_ball(&self, args: &[ArbBall]) -> Option<ArbBall> {
975            Some(args[0].atan())
976        }
977
978        fn lean_theorem(&self) -> Option<&'static str> {
979            Some("Real.arctan_deriv")
980        }
981    }
982
983    // ── erf ──────────────────────────────────────────────────────────────────
984
985    pub struct ErfPrimitive;
986
987    impl Primitive for ErfPrimitive {
988        fn name(&self) -> &'static str {
989            "erf"
990        }
991
992        fn pretty(&self, args: &[ExprId], pool: &ExprPool) -> String {
993            format!("erf({})", pool.display(args[0]))
994        }
995
996        fn diff_forward(&self, args: &[ExprId], wrt: ExprId, pool: &ExprPool) -> Option<ExprId> {
997            // d/dx erf(x) = (2/sqrt(π)) * exp(-x²) * dx
998            let x = args[0];
999            let dx = crate::diff::diff(x, wrt, pool).ok()?.value;
1000            let x2 = pool.pow(x, pool.integer(2_i32));
1001            let neg_x2 = pool.mul(vec![pool.integer(-1_i32), x2]);
1002            let exp_neg_x2 = pool.func("exp", vec![neg_x2]);
1003            let coeff = pool.float(2.0 / std::f64::consts::PI.sqrt(), 53);
1004            Some(pool.mul(vec![coeff, exp_neg_x2, dx]))
1005        }
1006
1007        fn diff_reverse(
1008            &self,
1009            args: &[ExprId],
1010            cotan: ExprId,
1011            pool: &ExprPool,
1012        ) -> Option<Vec<ExprId>> {
1013            let x = args[0];
1014            let x2 = pool.pow(x, pool.integer(2_i32));
1015            let neg_x2 = pool.mul(vec![pool.integer(-1_i32), x2]);
1016            let exp_neg_x2 = pool.func("exp", vec![neg_x2]);
1017            let coeff = pool.float(2.0 / std::f64::consts::PI.sqrt(), 53);
1018            Some(vec![pool.mul(vec![cotan, coeff, exp_neg_x2])])
1019        }
1020
1021        fn numeric_f64(&self, args: &[f64]) -> Option<f64> {
1022            Some(libm_erf(args[0]))
1023        }
1024
1025        fn numeric_ball(&self, args: &[ArbBall]) -> Option<ArbBall> {
1026            Some(args[0].erf())
1027        }
1028    }
1029
1030    // ── erfc ─────────────────────────────────────────────────────────────────
1031
1032    pub struct ErfcPrimitive;
1033
1034    impl Primitive for ErfcPrimitive {
1035        fn name(&self) -> &'static str {
1036            "erfc"
1037        }
1038
1039        fn pretty(&self, args: &[ExprId], pool: &ExprPool) -> String {
1040            format!("erfc({})", pool.display(args[0]))
1041        }
1042
1043        fn diff_forward(&self, args: &[ExprId], wrt: ExprId, pool: &ExprPool) -> Option<ExprId> {
1044            let x = args[0];
1045            let dx = crate::diff::diff(x, wrt, pool).ok()?.value;
1046            let x2 = pool.pow(x, pool.integer(2_i32));
1047            let neg_x2 = pool.mul(vec![pool.integer(-1_i32), x2]);
1048            let exp_neg_x2 = pool.func("exp", vec![neg_x2]);
1049            let coeff = pool.float(-2.0 / std::f64::consts::PI.sqrt(), 53);
1050            Some(pool.mul(vec![coeff, exp_neg_x2, dx]))
1051        }
1052
1053        fn diff_reverse(
1054            &self,
1055            args: &[ExprId],
1056            cotan: ExprId,
1057            pool: &ExprPool,
1058        ) -> Option<Vec<ExprId>> {
1059            let x = args[0];
1060            let x2 = pool.pow(x, pool.integer(2_i32));
1061            let neg_x2 = pool.mul(vec![pool.integer(-1_i32), x2]);
1062            let exp_neg_x2 = pool.func("exp", vec![neg_x2]);
1063            let coeff = pool.float(-2.0 / std::f64::consts::PI.sqrt(), 53);
1064            Some(vec![pool.mul(vec![cotan, coeff, exp_neg_x2])])
1065        }
1066
1067        fn numeric_f64(&self, args: &[f64]) -> Option<f64> {
1068            Some(1.0 - libm_erf(args[0]))
1069        }
1070
1071        fn numeric_ball(&self, args: &[ArbBall]) -> Option<ArbBall> {
1072            Some(args[0].erfc())
1073        }
1074    }
1075
1076    // ── abs ──────────────────────────────────────────────────────────────────
1077
1078    pub struct AbsPrimitive;
1079
1080    impl Primitive for AbsPrimitive {
1081        fn name(&self) -> &'static str {
1082            "abs"
1083        }
1084
1085        fn pretty(&self, args: &[ExprId], pool: &ExprPool) -> String {
1086            format!("|{}|", pool.display(args[0]))
1087        }
1088
1089        fn numeric_f64(&self, args: &[f64]) -> Option<f64> {
1090            Some(args[0].abs())
1091        }
1092
1093        fn numeric_ball(&self, args: &[ArbBall]) -> Option<ArbBall> {
1094            Some(args[0].abs_ball())
1095        }
1096    }
1097
1098    // ── sign ─────────────────────────────────────────────────────────────────
1099
1100    pub struct SignPrimitive;
1101
1102    impl Primitive for SignPrimitive {
1103        fn name(&self) -> &'static str {
1104            "sign"
1105        }
1106
1107        fn pretty(&self, args: &[ExprId], pool: &ExprPool) -> String {
1108            format!("sign({})", pool.display(args[0]))
1109        }
1110
1111        fn numeric_f64(&self, args: &[f64]) -> Option<f64> {
1112            Some(if args[0] > 0.0 {
1113                1.0
1114            } else if args[0] < 0.0 {
1115                -1.0
1116            } else {
1117                0.0
1118            })
1119        }
1120    }
1121
1122    // ── floor ────────────────────────────────────────────────────────────────
1123
1124    pub struct FloorPrimitive;
1125
1126    impl Primitive for FloorPrimitive {
1127        fn name(&self) -> &'static str {
1128            "floor"
1129        }
1130
1131        fn pretty(&self, args: &[ExprId], pool: &ExprPool) -> String {
1132            format!("floor({})", pool.display(args[0]))
1133        }
1134
1135        fn numeric_f64(&self, args: &[f64]) -> Option<f64> {
1136            Some(args[0].floor())
1137        }
1138
1139        fn numeric_ball(&self, args: &[ArbBall]) -> Option<ArbBall> {
1140            Some(args[0].floor_ball())
1141        }
1142    }
1143
1144    // ── ceil ─────────────────────────────────────────────────────────────────
1145
1146    pub struct CeilPrimitive;
1147
1148    impl Primitive for CeilPrimitive {
1149        fn name(&self) -> &'static str {
1150            "ceil"
1151        }
1152
1153        fn pretty(&self, args: &[ExprId], pool: &ExprPool) -> String {
1154            format!("ceil({})", pool.display(args[0]))
1155        }
1156
1157        fn numeric_f64(&self, args: &[f64]) -> Option<f64> {
1158            Some(args[0].ceil())
1159        }
1160
1161        fn numeric_ball(&self, args: &[ArbBall]) -> Option<ArbBall> {
1162            Some(args[0].ceil_ball())
1163        }
1164    }
1165
1166    // ── round ────────────────────────────────────────────────────────────────
1167
1168    pub struct RoundPrimitive;
1169
1170    impl Primitive for RoundPrimitive {
1171        fn name(&self) -> &'static str {
1172            "round"
1173        }
1174
1175        fn pretty(&self, args: &[ExprId], pool: &ExprPool) -> String {
1176            format!("round({})", pool.display(args[0]))
1177        }
1178
1179        fn numeric_f64(&self, args: &[f64]) -> Option<f64> {
1180            Some(args[0].round())
1181        }
1182    }
1183
1184    // ── atan2 ────────────────────────────────────────────────────────────────
1185
1186    pub struct Atan2Primitive;
1187
1188    impl Primitive for Atan2Primitive {
1189        fn name(&self) -> &'static str {
1190            "atan2"
1191        }
1192
1193        fn pretty(&self, args: &[ExprId], pool: &ExprPool) -> String {
1194            format!(
1195                "atan2({}, {})",
1196                pool.display(args[0]),
1197                pool.display(args[1])
1198            )
1199        }
1200
1201        fn numeric_f64(&self, args: &[f64]) -> Option<f64> {
1202            if args.len() == 2 {
1203                Some(args[0].atan2(args[1]))
1204            } else {
1205                None
1206            }
1207        }
1208
1209        fn lean_theorem(&self) -> Option<&'static str> {
1210            Some("Real.arctan2")
1211        }
1212    }
1213
1214    // ── gamma ────────────────────────────────────────────────────────────────
1215
1216    pub struct GammaPrimitive;
1217
1218    impl Primitive for GammaPrimitive {
1219        fn name(&self) -> &'static str {
1220            "gamma"
1221        }
1222
1223        fn pretty(&self, args: &[ExprId], pool: &ExprPool) -> String {
1224            format!("Γ({})", pool.display(args[0]))
1225        }
1226
1227        fn numeric_f64(&self, args: &[f64]) -> Option<f64> {
1228            Some(libm_gamma(args[0]))
1229        }
1230
1231        fn lean_theorem(&self) -> Option<&'static str> {
1232            Some("Real.Gamma")
1233        }
1234    }
1235
1236    // ── min ──────────────────────────────────────────────────────────────────
1237
1238    pub struct MinPrimitive;
1239
1240    impl Primitive for MinPrimitive {
1241        fn name(&self) -> &'static str {
1242            "min"
1243        }
1244
1245        fn pretty(&self, args: &[ExprId], pool: &ExprPool) -> String {
1246            format!("min({}, {})", pool.display(args[0]), pool.display(args[1]))
1247        }
1248
1249        fn numeric_f64(&self, args: &[f64]) -> Option<f64> {
1250            if args.len() == 2 {
1251                Some(args[0].min(args[1]))
1252            } else {
1253                None
1254            }
1255        }
1256    }
1257
1258    // ── max ──────────────────────────────────────────────────────────────────
1259
1260    pub struct MaxPrimitive;
1261
1262    impl Primitive for MaxPrimitive {
1263        fn name(&self) -> &'static str {
1264            "max"
1265        }
1266
1267        fn pretty(&self, args: &[ExprId], pool: &ExprPool) -> String {
1268            format!("max({}, {})", pool.display(args[0]), pool.display(args[1]))
1269        }
1270
1271        fn numeric_f64(&self, args: &[f64]) -> Option<f64> {
1272            if args.len() == 2 {
1273                Some(args[0].max(args[1]))
1274            } else {
1275                None
1276            }
1277        }
1278    }
1279
1280    // ---------------------------------------------------------------------------
1281    // Helper: Lanczos approximation for Γ(x), x ∈ ℝ.  Accurate to ~15 digits.
1282    // Coefficients from Cephes (g = 7, n = 9).
1283    // ---------------------------------------------------------------------------
1284
1285    fn libm_gamma(x: f64) -> f64 {
1286        const G: f64 = 7.0;
1287        const P: [f64; 9] = [
1288            0.999_999_999_999_809_9,
1289            676.520_368_121_885_1,
1290            -1_259.139_216_722_402_8,
1291            771.323_428_777_653_1,
1292            -176.615_029_162_140_6,
1293            12.507_343_278_686_905,
1294            -0.138_571_095_265_720_12,
1295            9.984_369_578_019_572e-6,
1296            1.505_632_735_149_311_6e-7,
1297        ];
1298        if x < 0.5 {
1299            // Reflection: Γ(x)Γ(1-x) = π / sin(πx)
1300            std::f64::consts::PI / ((std::f64::consts::PI * x).sin() * libm_gamma(1.0 - x))
1301        } else {
1302            let xm = x - 1.0;
1303            let mut a = P[0];
1304            for (i, p) in P.iter().enumerate().skip(1) {
1305                a += p / (xm + i as f64);
1306            }
1307            let t = xm + G + 0.5;
1308            (2.0 * std::f64::consts::PI).sqrt() * t.powf(xm + 0.5) * (-t).exp() * a
1309        }
1310    }
1311
1312    // ---------------------------------------------------------------------------
1313    // Helper: erf via polynomial approximation (Horner-form, max error ≤ 1.5e-7)
1314    // Abramowitz & Stegun 7.1.26
1315    // ---------------------------------------------------------------------------
1316
1317    fn libm_erf(x: f64) -> f64 {
1318        let t = 1.0 / (1.0 + 0.3275911 * x.abs());
1319        let poly = t
1320            * (0.254829592
1321                + t * (-0.284496736 + t * (1.421413741 + t * (-1.453152027 + t * 1.061405429))));
1322        let sign = if x < 0.0 { -1.0 } else { 1.0 };
1323        sign * (1.0 - poly * (-x * x).exp())
1324    }
1325}
1326
1327// ---------------------------------------------------------------------------
1328// Tests
1329// ---------------------------------------------------------------------------
1330
#[cfg(test)]
mod tests {
    use super::*;
    use crate::kernel::{Domain, ExprPool};

    /// The core elementary functions ship with every capability slot filled.
    #[test]
    fn default_registry_has_builtins() {
        let reg = PrimitiveRegistry::default_registry();
        let required = [
            (Capabilities::NUMERIC_F64, "NUMERIC_F64"),
            (Capabilities::DIFF_FORWARD, "DIFF_FORWARD"),
            (Capabilities::DIFF_REVERSE, "DIFF_REVERSE"),
            (Capabilities::NUMERIC_BALL, "NUMERIC_BALL"),
        ];
        for name in ["sin", "cos", "exp", "log", "sqrt"] {
            assert!(reg.is_registered(name), "{name} not registered");
            let caps = reg.capabilities(name);
            for (flag, label) in required {
                assert!(caps.contains(flag), "{name} missing {label}");
            }
        }
    }

    /// Spot-checks `f64` evaluation at points with exact results.
    #[test]
    fn numeric_f64_correct() {
        let reg = PrimitiveRegistry::default_registry();
        let cases = [
            ("sin", 0.0, 0.0),
            ("cos", 0.0, 1.0),
            ("exp", 0.0, 1.0),
            ("log", 1.0, 0.0),
            ("sqrt", 4.0, 2.0),
        ];
        for (name, input, expected) in cases {
            let got = reg.numeric_f64(name, &[input]).unwrap();
            let error = (got - expected).abs();
            assert!(
                error < 1e-12,
                "{name}({input}) = {got}, expected {expected}"
            );
        }
    }

    /// Forward differentiation of `sin` produces an expression.
    #[test]
    fn diff_forward_sin() {
        let reg = PrimitiveRegistry::default_registry();
        let pool = ExprPool::new();
        let x = pool.symbol("x", Domain::Real);
        let derivative = reg.diff_forward("sin", &[x], x, &pool);
        assert!(derivative.is_some(), "sin diff_forward returned None");
    }

    /// The Markdown coverage table mentions built-ins and marks filled slots.
    #[test]
    fn coverage_report_markdown() {
        let reg = PrimitiveRegistry::default_registry();
        let report = reg.coverage_report();
        let markdown = report.to_markdown();
        assert!(markdown.contains("sin"), "coverage report missing sin");
        assert!(markdown.contains("✓"), "coverage report has no ticks");
    }

    /// A user-defined primitive exposes exactly the capabilities it implements.
    #[test]
    fn custom_primitive_registration() {
        struct Tanh;

        impl Primitive for Tanh {
            fn name(&self) -> &'static str {
                "tanh"
            }
            fn pretty(&self, args: &[ExprId], pool: &ExprPool) -> String {
                format!("tanh({})", pool.display(args[0]))
            }
            fn numeric_f64(&self, args: &[f64]) -> Option<f64> {
                Some(args[0].tanh())
            }
        }

        let mut reg = PrimitiveRegistry::new();
        reg.register(Box::new(Tanh));
        assert!(reg.is_registered("tanh"));

        let caps = reg.capabilities("tanh");
        assert!(caps.contains(Capabilities::NUMERIC_F64));
        assert!(!caps.contains(Capabilities::DIFF_FORWARD));

        let at_zero = reg.numeric_f64("tanh", &[0.0]).unwrap();
        assert!(at_zero.abs() < 1e-12);
    }
}