phop_core/config.rs
1//! Discovery configuration.
2
3use serde::{Deserialize, Serialize};
4
5/// Temperature annealing schedule for the Gumbel-Softmax topology relaxation.
6#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
7pub enum TempSchedule {
8 /// Linear interpolation from `tau_start` to `tau_end`.
9 Linear,
10 /// Cosine interpolation from `tau_start` to `tau_end` (default).
11 #[default]
12 Cosine,
13}
14
15/// Compute backend for the expensive numeric inner loops (constant fitting).
16#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
17pub enum Backend {
18 /// CPU (always available; exact `f64`).
19 #[default]
20 Cpu,
21 /// NVIDIA CUDA GPU (requires the `gpu-cuda` feature and a device at runtime; falls back to
22 /// CPU otherwise). Single-precision coarse fit, refined to `f64` by the CPU LM polish.
23 Cuda,
24 /// Apple Metal GPU (requires the `gpu-metal` feature and a Metal device at runtime; falls back
25 /// to CPU otherwise, macOS only). Single-precision coarse forward; exact `f64` stays on the CPU.
26 Metal,
27}
28
29/// Configuration for a [`crate::Discoverer`] run.
30///
31/// Construct with [`Config::default`] and adjust via the builder methods, e.g.
32/// ```
33/// use phop_core::Config;
34/// let cfg = Config::default().population(256).max_depth(10).max_epochs(2_000);
35/// assert_eq!(cfg.population, 256);
36/// ```
37#[derive(Debug, Clone, Serialize, Deserialize)]
38pub struct Config {
39 /// Number of candidate trees evaluated jointly.
40 pub population: usize,
41 /// Maximum depth of candidate EML trees.
42 pub max_depth: usize,
43 /// Maximum number of optimization epochs.
44 pub max_epochs: usize,
45 /// Adam learning rate.
46 pub learning_rate: f64,
47 /// Weight on the complexity penalty in the multi-objective loss.
48 pub lambda_complexity: f64,
49 /// Weight on the sparsity penalty (pressure toward the constant `1`).
50 pub lambda_sparsity: f64,
51 /// Weight on the parsimony (depth) penalty.
52 pub lambda_parsimony: f64,
53 /// Initial Gumbel-Softmax temperature.
54 pub tau_start: f64,
55 /// Final Gumbel-Softmax temperature.
56 pub tau_end: f64,
57 /// Temperature annealing schedule.
58 pub temp_schedule: TempSchedule,
59 /// RNG seed for reproducibility.
60 pub seed: u64,
61 /// Number of solutions to keep on the Pareto front.
62 pub top_k: usize,
63 /// Compute backend for constant fitting (CPU by default; CUDA when built and available).
64 pub backend: Backend,
65}
66
67impl Default for Config {
68 fn default() -> Self {
69 Self {
70 population: 256,
71 max_depth: 10,
72 max_epochs: 2_000,
73 learning_rate: 0.05,
74 lambda_complexity: 1e-3,
75 lambda_sparsity: 1e-3,
76 lambda_parsimony: 1e-3,
77 tau_start: 2.0,
78 tau_end: 0.1,
79 temp_schedule: TempSchedule::Cosine,
80 seed: 0,
81 top_k: 5,
82 backend: Backend::Cpu,
83 }
84 }
85}
86
87impl Config {
88 /// Set the population size.
89 #[must_use]
90 pub fn population(mut self, p: usize) -> Self {
91 self.population = p;
92 self
93 }
94
95 /// Set the maximum tree depth.
96 #[must_use]
97 pub fn max_depth(mut self, d: usize) -> Self {
98 self.max_depth = d;
99 self
100 }
101
102 /// Set the maximum number of epochs.
103 #[must_use]
104 pub fn max_epochs(mut self, n: usize) -> Self {
105 self.max_epochs = n;
106 self
107 }
108
109 /// Set the Adam learning rate.
110 #[must_use]
111 pub fn learning_rate(mut self, lr: f64) -> Self {
112 self.learning_rate = lr;
113 self
114 }
115
116 /// Set the RNG seed.
117 #[must_use]
118 pub fn seed(mut self, s: u64) -> Self {
119 self.seed = s;
120 self
121 }
122
123 /// Set the number of Pareto solutions to keep.
124 #[must_use]
125 pub fn top_k(mut self, k: usize) -> Self {
126 self.top_k = k;
127 self
128 }
129
130 /// Set the compute backend for constant fitting.
131 #[must_use]
132 pub fn backend(mut self, backend: Backend) -> Self {
133 self.backend = backend;
134 self
135 }
136
137 /// Temperature at training progress `t in [0, 1]` under the configured schedule.
138 #[must_use]
139 pub fn temperature(&self, t: f64) -> f64 {
140 let t = t.clamp(0.0, 1.0);
141 match self.temp_schedule {
142 TempSchedule::Linear => self.tau_start + (self.tau_end - self.tau_start) * t,
143 TempSchedule::Cosine => {
144 let c = 0.5 * (1.0 + (std::f64::consts::PI * t).cos());
145 self.tau_end + (self.tau_start - self.tau_end) * c
146 }
147 }
148 }
149}
150
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for schedule values expected to land on a known point.
    const EPS: f64 = 1e-12;

    /// Both annealing schedules, for tests that must hold regardless of the choice.
    const SCHEDULES: [TempSchedule; 2] = [TempSchedule::Linear, TempSchedule::Cosine];

    /// Default configuration with the given annealing schedule.
    fn config_with(schedule: TempSchedule) -> Config {
        Config {
            temp_schedule: schedule,
            ..Config::default()
        }
    }

    #[test]
    fn backend_defaults_to_cpu_and_builder_sets_cuda() {
        assert_eq!(Config::default().backend, Backend::Cpu);
        assert_eq!(
            Config::default().backend(Backend::Cuda).backend,
            Backend::Cuda
        );
    }

    #[test]
    fn builder_sets_metal() {
        assert_eq!(
            Config::default().backend(Backend::Metal).backend,
            Backend::Metal
        );
    }

    #[test]
    fn temperature_hits_endpoints_for_both_schedules() {
        for schedule in SCHEDULES {
            let cfg = config_with(schedule);
            assert!((cfg.temperature(0.0) - cfg.tau_start).abs() < EPS, "{schedule:?} start");
            assert!((cfg.temperature(1.0) - cfg.tau_end).abs() < EPS, "{schedule:?} end");
        }
    }

    #[test]
    fn temperature_clamps_progress_outside_unit_interval() {
        for schedule in SCHEDULES {
            let cfg = config_with(schedule);
            assert_eq!(cfg.temperature(-3.0), cfg.temperature(0.0), "{schedule:?} below");
            assert_eq!(cfg.temperature(7.5), cfg.temperature(1.0), "{schedule:?} above");
        }
    }

    #[test]
    fn temperature_midpoint_is_mean_of_endpoints() {
        // Both the linear and the half-cosine schedule pass through the mean at t = 0.5.
        for schedule in SCHEDULES {
            let cfg = config_with(schedule);
            let mean = 0.5 * (cfg.tau_start + cfg.tau_end);
            assert!((cfg.temperature(0.5) - mean).abs() < EPS, "{schedule:?} midpoint");
        }
    }

    #[test]
    fn temperature_decreases_strictly_when_annealing_down() {
        for schedule in SCHEDULES {
            let cfg = config_with(schedule);
            assert!(cfg.tau_start > cfg.tau_end);
            let mut prev = f64::INFINITY;
            for i in 0..=100_u32 {
                let tau = cfg.temperature(f64::from(i) / 100.0);
                assert!(tau < prev, "{schedule:?} not decreasing at step {i}");
                prev = tau;
            }
        }
    }
}