Skip to main content

rlx_bbo/
lib.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15//! Black-box optimization + FMQ/QGBS search (domain-agnostic `f64` objectives).
16//!
17//! For compiled flow-map **policies** and RLX-graph FMQ training, use [`rlx-rl`](../rlx-rl/).
18
19pub mod acquisition;
20pub mod bo;
21mod cmaes;
22mod flow_map;
23pub mod gp;
24mod gradcheck;
25mod gradient_descent;
26mod graph_opt;
27mod q_guidance;
28pub mod sampling;
29mod surrogate;
30pub mod tpe;
31mod trajectory;
32mod twin;
33
34pub use bo::{Acquisition, BoConfig, bo};
35pub use cmaes::{CmaesConfig, cmaes};
36pub use flow_map::{
37    LinearFlowMap, fmq_surrogate_step, load_flow_map, save_flow_map, train_from_jsonl,
38};
39pub use gp::{GpPosterior, Kernel, cholesky};
40pub use gradcheck::gradcheck_graph;
41pub use gradient_descent::{AdamOptConfig, AdamOptResult, adam_opt_1d, adam_opt_nd};
42pub use graph_opt::{
43    GraphOptConfig, GraphOptError, GraphOptResult, GraphOptSpec, adam_opt_graph, find_param_node,
44    find_param_nodes,
45};
46pub use q_guidance::{
47    DEFAULT_KAPPA, QSteerConfig, QgbsConfig, eta_eff_twin, finite_diff_grad, q_guided_beam_search,
48    q_steered_search, q_steered_search_with_grad, search_by_method, trust_region_q_step,
49};
50pub use surrogate::{
51    LinearSurrogate, fit_from_trajectory_jsonl, fit_linear_surrogate, load_surrogate,
52    save_surrogate,
53};
54pub use trajectory::{TrajectoryRecord, append_jsonl, diagonal_flow_pairs, load_jsonl};
55pub use twin::q_steered_search_twin;
56
57use rand::distributions::Distribution;
58use rand::rngs::StdRng;
59use rand::{Rng, SeedableRng};
60use rand_distr::Normal;
61
62#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
63pub struct BboSolution {
64    pub x: Vec<f64>,
65    pub value: f64,
66    pub trace: Vec<f64>,
67    pub n_evals: usize,
68}
69
70#[derive(Clone, Debug)]
71pub struct Bbox {
72    pub bounds: Vec<(f64, f64)>,
73}
74
75impl Bbox {
76    pub fn new(bounds: Vec<(f64, f64)>) -> Self {
77        Self { bounds }
78    }
79    pub fn dim(&self) -> usize {
80        self.bounds.len()
81    }
82    pub fn sample(&self, rng: &mut StdRng) -> Vec<f64> {
83        self.bounds
84            .iter()
85            .map(|&(lo, hi)| rng.gen_range(lo..=hi))
86            .collect()
87    }
88    pub fn clip(&self, x: &mut [f64]) {
89        for (xi, &(lo, hi)) in x.iter_mut().zip(self.bounds.iter()) {
90            if *xi < lo {
91                *xi = lo;
92            }
93            if *xi > hi {
94                *xi = hi;
95            }
96        }
97    }
98    pub fn width(&self, i: usize) -> f64 {
99        self.bounds[i].1 - self.bounds[i].0
100    }
101}
102
103pub fn random_search<F>(bbox: &Bbox, n_evals: usize, seed: u64, mut f: F) -> BboSolution
104where
105    F: FnMut(&[f64]) -> f64,
106{
107    let mut rng = StdRng::seed_from_u64(seed);
108    let mut best_x = bbox.sample(&mut rng);
109    let mut best_v = f(&best_x);
110    let mut trace = Vec::with_capacity(n_evals);
111    trace.push(best_v);
112    for _ in 1..n_evals {
113        let x = bbox.sample(&mut rng);
114        let v = f(&x);
115        if v < best_v {
116            best_v = v;
117            best_x = x;
118        }
119        trace.push(best_v);
120    }
121    BboSolution {
122        x: best_x,
123        value: best_v,
124        trace,
125        n_evals,
126    }
127}
128
/// Hyper-parameters for [`pso`].
#[derive(Clone, Debug)]
pub struct PsoConfig {
    /// Swarm size; each iteration evaluates every particle once.
    pub n_particles: usize,
    /// Number of update sweeps over the whole swarm.
    pub n_iters: usize,
    /// Inertia weight applied to the previous velocity.
    pub w: f64,
    /// Cognitive coefficient: pull toward the particle's personal best.
    pub c1: f64,
    /// Social coefficient: pull toward the swarm's global best.
    pub c2: f64,
}

impl Default for PsoConfig {
    /// 30 particles, 100 iterations, `w = 0.729`, `c1 = c2 = 1.494`.
    fn default() -> Self {
        // NOTE(review): these look like the commonly cited constriction-factor
        // settings for PSO; confirm that is the intended source.
        let acceleration = 1.494;
        PsoConfig {
            n_particles: 30,
            n_iters: 100,
            w: 0.729,
            c1: acceleration,
            c2: acceleration,
        }
    }
}
149
150pub fn pso<F>(bbox: &Bbox, cfg: &PsoConfig, seed: u64, mut f: F) -> BboSolution
151where
152    F: FnMut(&[f64]) -> f64,
153{
154    let n = bbox.dim();
155    let mut rng = StdRng::seed_from_u64(seed);
156    let mut positions: Vec<Vec<f64>> = (0..cfg.n_particles)
157        .map(|_| bbox.sample(&mut rng))
158        .collect();
159    let mut velocities: Vec<Vec<f64>> = (0..cfg.n_particles)
160        .map(|_| {
161            (0..n)
162                .map(|i| rng.gen_range(-bbox.width(i) / 4.0..=bbox.width(i) / 4.0))
163                .collect()
164        })
165        .collect();
166    let mut pbests = positions.clone();
167    let mut pbest_vals: Vec<f64> = positions.iter().map(|p| f(p)).collect();
168    let (gbest_i, gbest_v) = argmin_with_value(&pbest_vals).expect("pso");
169    let mut gbest = pbests[gbest_i].clone();
170    let mut gbest_v = *gbest_v;
171    let mut n_evals = cfg.n_particles;
172    let mut trace = vec![gbest_v];
173    for _ in 0..cfg.n_iters {
174        for p_idx in 0..cfg.n_particles {
175            for d in 0..n {
176                let r1: f64 = rng.gen_range(0.0..1.0);
177                let r2: f64 = rng.gen_range(0.0..1.0);
178                velocities[p_idx][d] = cfg.w * velocities[p_idx][d]
179                    + cfg.c1 * r1 * (pbests[p_idx][d] - positions[p_idx][d])
180                    + cfg.c2 * r2 * (gbest[d] - positions[p_idx][d]);
181                positions[p_idx][d] += velocities[p_idx][d];
182            }
183            bbox.clip(&mut positions[p_idx]);
184            let v = f(&positions[p_idx]);
185            n_evals += 1;
186            if v < pbest_vals[p_idx] {
187                pbest_vals[p_idx] = v;
188                pbests[p_idx] = positions[p_idx].clone();
189                if v < gbest_v {
190                    gbest_v = v;
191                    gbest = positions[p_idx].clone();
192                }
193            }
194        }
195        trace.push(gbest_v);
196    }
197    BboSolution {
198        x: gbest,
199        value: gbest_v,
200        trace,
201        n_evals,
202    }
203}
204
/// Returns the index of and a reference to the smallest element of `v`.
///
/// Returns `None` if `v` is empty, and ties resolve to the earliest index.
/// `NaN` ranks after every non-`NaN` value, so it is returned only when every
/// element is `NaN` (in which case index 0 is returned).
fn argmin_with_value(v: &[f64]) -> Option<(usize, &f64)> {
    let mut it = v.iter().enumerate();
    let (mut bi, mut bv) = it.next()?;
    for (i, val) in it {
        // Plain `<` is false whenever either side is NaN, which would let a
        // leading NaN win; explicitly prefer a real number over a NaN incumbent.
        if val < bv || (bv.is_nan() && !val.is_nan()) {
            bi = i;
            bv = val;
        }
    }
    Some((bi, bv))
}
216
/// Hyper-parameters for [`one_plus_one_es`].
#[derive(Clone, Debug)]
pub struct EsConfig {
    /// Number of mutate-and-select steps; each costs one objective evaluation.
    pub n_iters: usize,
    /// Initial per-dimension step size, as a fraction of that dimension's width.
    pub sigma0_frac: f64,
    /// Number of steps per success-rate window between step-size adaptations.
    pub adapt_window: usize,
}

impl Default for EsConfig {
    /// 200 steps, initial step size at 10% of each width, adapting every 10 steps.
    fn default() -> Self {
        EsConfig {
            adapt_window: 10,
            sigma0_frac: 0.1,
            n_iters: 200,
        }
    }
}
233
234pub fn one_plus_one_es<F>(bbox: &Bbox, cfg: &EsConfig, seed: u64, mut f: F) -> BboSolution
235where
236    F: FnMut(&[f64]) -> f64,
237{
238    let n = bbox.dim();
239    let mut rng = StdRng::seed_from_u64(seed);
240    let mut x = bbox.sample(&mut rng);
241    let mut best_v = f(&x);
242    let mut trace = vec![best_v];
243    let mut sigmas: Vec<f64> = (0..n).map(|i| bbox.width(i) * cfg.sigma0_frac).collect();
244    let mut window_successes = 0usize;
245    let mut n_evals = 1usize;
246    for k in 0..cfg.n_iters {
247        let mut candidate = x.clone();
248        for d in 0..n {
249            let normal = Normal::new(0.0, sigmas[d]).unwrap();
250            candidate[d] += normal.sample(&mut rng);
251        }
252        bbox.clip(&mut candidate);
253        let v = f(&candidate);
254        n_evals += 1;
255        if v < best_v {
256            best_v = v;
257            x = candidate;
258            window_successes += 1;
259        }
260        trace.push(best_v);
261        if (k + 1) % cfg.adapt_window == 0 {
262            let success_rate = window_successes as f64 / cfg.adapt_window as f64;
263            let scale = if success_rate > 0.2 { 1.22 } else { 1.0 / 1.22 };
264            for s in sigmas.iter_mut() {
265                *s *= scale;
266            }
267            window_successes = 0;
268        }
269    }
270    BboSolution {
271        x,
272        value: best_v,
273        trace,
274        n_evals,
275    }
276}