Skip to main content

phop_core/
config.rs

1//! Discovery configuration.
2
3use serde::{Deserialize, Serialize};
4
5/// Temperature annealing schedule for the Gumbel-Softmax topology relaxation.
6#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
7pub enum TempSchedule {
8    /// Linear interpolation from `tau_start` to `tau_end`.
9    Linear,
10    /// Cosine interpolation from `tau_start` to `tau_end` (default).
11    #[default]
12    Cosine,
13}
14
15/// Compute backend for the expensive numeric inner loops (constant fitting).
16#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
17pub enum Backend {
18    /// CPU (always available; exact `f64`).
19    #[default]
20    Cpu,
21    /// NVIDIA CUDA GPU (requires the `gpu-cuda` feature and a device at runtime; falls back to
22    /// CPU otherwise). Single-precision coarse fit, refined to `f64` by the CPU LM polish.
23    Cuda,
24    /// Apple Metal GPU (requires the `gpu-metal` feature and a Metal device at runtime; falls back
25    /// to CPU otherwise, macOS only). Single-precision coarse forward; exact `f64` stays on the CPU.
26    Metal,
27}
28
29/// Configuration for a [`crate::Discoverer`] run.
30///
31/// Construct with [`Config::default`] and adjust via the builder methods, e.g.
32/// ```
33/// use phop_core::Config;
34/// let cfg = Config::default().population(256).max_depth(10).max_epochs(2_000);
35/// assert_eq!(cfg.population, 256);
36/// ```
37#[derive(Debug, Clone, Serialize, Deserialize)]
38pub struct Config {
39    /// Number of candidate trees evaluated jointly.
40    pub population: usize,
41    /// Maximum depth of candidate EML trees.
42    pub max_depth: usize,
43    /// Maximum number of optimization epochs.
44    pub max_epochs: usize,
45    /// Adam learning rate.
46    pub learning_rate: f64,
47    /// Weight on the complexity penalty in the multi-objective loss.
48    pub lambda_complexity: f64,
49    /// Weight on the sparsity penalty (pressure toward the constant `1`).
50    pub lambda_sparsity: f64,
51    /// Weight on the parsimony (depth) penalty.
52    pub lambda_parsimony: f64,
53    /// Initial Gumbel-Softmax temperature.
54    pub tau_start: f64,
55    /// Final Gumbel-Softmax temperature.
56    pub tau_end: f64,
57    /// Temperature annealing schedule.
58    pub temp_schedule: TempSchedule,
59    /// RNG seed for reproducibility.
60    pub seed: u64,
61    /// Number of solutions to keep on the Pareto front.
62    pub top_k: usize,
63    /// Compute backend for constant fitting (CPU by default; CUDA when built and available).
64    pub backend: Backend,
65}
66
67impl Default for Config {
68    fn default() -> Self {
69        Self {
70            population: 256,
71            max_depth: 10,
72            max_epochs: 2_000,
73            learning_rate: 0.05,
74            lambda_complexity: 1e-3,
75            lambda_sparsity: 1e-3,
76            lambda_parsimony: 1e-3,
77            tau_start: 2.0,
78            tau_end: 0.1,
79            temp_schedule: TempSchedule::Cosine,
80            seed: 0,
81            top_k: 5,
82            backend: Backend::Cpu,
83        }
84    }
85}
86
87impl Config {
88    /// Set the population size.
89    #[must_use]
90    pub fn population(mut self, p: usize) -> Self {
91        self.population = p;
92        self
93    }
94
95    /// Set the maximum tree depth.
96    #[must_use]
97    pub fn max_depth(mut self, d: usize) -> Self {
98        self.max_depth = d;
99        self
100    }
101
102    /// Set the maximum number of epochs.
103    #[must_use]
104    pub fn max_epochs(mut self, n: usize) -> Self {
105        self.max_epochs = n;
106        self
107    }
108
109    /// Set the Adam learning rate.
110    #[must_use]
111    pub fn learning_rate(mut self, lr: f64) -> Self {
112        self.learning_rate = lr;
113        self
114    }
115
116    /// Set the RNG seed.
117    #[must_use]
118    pub fn seed(mut self, s: u64) -> Self {
119        self.seed = s;
120        self
121    }
122
123    /// Set the number of Pareto solutions to keep.
124    #[must_use]
125    pub fn top_k(mut self, k: usize) -> Self {
126        self.top_k = k;
127        self
128    }
129
130    /// Set the compute backend for constant fitting.
131    #[must_use]
132    pub fn backend(mut self, backend: Backend) -> Self {
133        self.backend = backend;
134        self
135    }
136
137    /// Temperature at training progress `t in [0, 1]` under the configured schedule.
138    #[must_use]
139    pub fn temperature(&self, t: f64) -> f64 {
140        let t = t.clamp(0.0, 1.0);
141        match self.temp_schedule {
142            TempSchedule::Linear => self.tau_start + (self.tau_end - self.tau_start) * t,
143            TempSchedule::Cosine => {
144                let c = 0.5 * (1.0 + (std::f64::consts::PI * t).cos());
145                self.tau_end + (self.tau_start - self.tau_end) * c
146            }
147        }
148    }
149}
150
#[cfg(test)]
mod tests {
    use super::*;

    /// An untouched config selects the CPU; the `backend` builder overrides that choice.
    #[test]
    fn backend_defaults_to_cpu_and_builder_sets_cuda() {
        let untouched = Config::default();
        assert_eq!(untouched.backend, Backend::Cpu);

        let on_cuda = Config::default().backend(Backend::Cuda);
        assert_eq!(on_cuda.backend, Backend::Cuda);
    }
}