// rlx_bbo/graph_opt.rs

// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
//! Adam on compiled RLX loss graphs via reverse-mode AD (`grad_with_loss`).
//!
//! Replaces hand-rolled Adam loops in domain crates (RF passives, LNA match,
//! placement demos). The scalar finite-difference optimizer [`adam_opt_nd`]
//! remains the tool for black-box objectives.
20
21use std::collections::HashMap;
22
23use rlx_ir::{
24    Graph, NodeId, find_param_node as ir_find_param_node, find_param_nodes as ir_find_param_nodes,
25};
26use rlx_opt::rlx_autodiff::grad_with_loss;
27use rlx_optim::{Adam, Optimizer};
28use rlx_runtime::{CompiledGraph, Device, Session};
29use serde::{Deserialize, Serialize};
30
31/// Resolve a single `Op::Param` node by name.
32pub fn find_param_node(g: &Graph, name: &str) -> Option<NodeId> {
33    ir_find_param_node(g, name)
34}
35
36/// Resolve param nodes in the same order as `names`.
37pub fn find_param_nodes(g: &Graph, names: &[&str]) -> Result<Vec<NodeId>, GraphOptError> {
38    ir_find_param_nodes(g, names).map_err(GraphOptError::ParamNotFound)
39}
40
41#[derive(Clone, Debug, Serialize, Deserialize)]
42pub struct GraphOptConfig {
43    pub steps: u32,
44    /// Base Adam learning rate (see [`relative_lr`]).
45    pub lr: f32,
46    /// When true, each optimized coordinate is scaled by `max(|x|, lr_floor)`
47    /// before the Adam update — useful when params span orders of magnitude
48    /// (e.g. Lg ≈ 17 nH vs gm ≈ 50 mS).
49    pub relative_lr: bool,
50    pub lr_floor: f32,
51    pub beta1: f32,
52    pub beta2: f32,
53}
54
55impl Default for GraphOptConfig {
56    fn default() -> Self {
57        Self {
58            steps: 128,
59            lr: 0.02,
60            relative_lr: true,
61            lr_floor: 1e-12,
62            beta1: 0.9,
63            beta2: 0.999,
64        }
65    }
66}
67
68impl GraphOptConfig {
69    #[must_use]
70    pub fn from_steps(steps: u32) -> Self {
71        Self {
72            steps,
73            ..Self::default()
74        }
75    }
76}
77
78#[derive(Clone, Debug, Serialize, Deserialize)]
79pub struct GraphOptResult {
80    /// Final value for every param present in the spec (optimized + frozen).
81    pub params: HashMap<String, f32>,
82    pub final_loss: f32,
83    pub history: Vec<f32>,
84    pub final_grads: HashMap<String, f32>,
85}
86
/// Failures reported by [`find_param_nodes`] and [`adam_opt_graph`].
#[derive(Clone, Debug, PartialEq)]
pub enum GraphOptError {
    /// A requested name has no matching `Op::Param` in the graph, or has no
    /// initial value in the spec.
    ParamNotFound(String),
    /// [`GraphOptSpec::optimize`] contained no names.
    OptimizeEmpty,
    /// The AD and FD gradients for `param` disagree.
    GradcheckMismatch { param: String, ad: f32, fd: f32 },
}

impl std::fmt::Display for GraphOptError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::OptimizeEmpty => f.write_str("optimize list is empty"),
            Self::ParamNotFound(missing) => write!(f, "param not found in graph: {missing}"),
            Self::GradcheckMismatch { param, ad, fd } => write!(
                f,
                "gradcheck mismatch at {param}: AD={ad:.6e} FD={fd:.6e}"
            ),
        }
    }
}

impl std::error::Error for GraphOptError {}
107
/// Parameter bundle for [`adam_opt_graph`].
#[derive(Clone, Debug)]
pub struct GraphOptSpec<'a> {
    /// Names optimized by Adam, each of which must be an `Op::Param` in `fwd`.
    ///
    /// This order is the order of the gradients and of the Adam state.
    pub optimize: &'a [&'a str],
    /// Initial values for **all** params referenced by the graph (optimized and
    /// frozen). Every optimized name must have an entry here.
    pub values: HashMap<String, f32>,
    /// Optional inclusive `(lo, hi)` bounds, looked up by optimized name.
    ///
    /// Names without an entry are left unbounded. Each pair must satisfy
    /// `lo <= hi`, because values are clamped with [`f32::clamp`].
    pub bounds: HashMap<String, (f32, f32)>,
    /// Forward inputs (e.g. `("freq_hz", &[2.4e9])`). `d_output` is injected
    /// automatically.
    pub inputs: &'a [(&'a str, &'a [f32])],
}
119
120/// Compile `grad_with_loss` on `fwd` and run Adam.
121pub fn adam_opt_graph(
122    fwd: &Graph,
123    spec: &GraphOptSpec<'_>,
124    cfg: &GraphOptConfig,
125    device: Device,
126) -> Result<GraphOptResult, GraphOptError> {
127    if spec.optimize.is_empty() {
128        return Err(GraphOptError::OptimizeEmpty);
129    }
130
131    let param_ids = find_param_nodes(fwd, spec.optimize)?;
132    let bwd = grad_with_loss(fwd, &param_ids);
133    let session = Session::new(device);
134    let mut compiled = session.compile(bwd);
135
136    let mut opt_values: Vec<f32> = spec
137        .optimize
138        .iter()
139        .map(|n| {
140            spec.values
141                .get(*n)
142                .copied()
143                .ok_or_else(|| GraphOptError::ParamNotFound((*n).into()))
144        })
145        .collect::<Result<_, _>>()?;
146
147    let mut opt = Adam::new(cfg.lr).with_betas(cfg.beta1, cfg.beta2);
148    let mut history = Vec::with_capacity(cfg.steps as usize);
149    let mut last_grads: HashMap<String, f32> = HashMap::new();
150    let mut last_loss = f32::MAX;
151
152    for _ in 0..cfg.steps {
153        apply_all_params(&mut compiled, &spec.values, spec.optimize, &opt_values);
154
155        let mut run_in: Vec<(&str, &[f32])> = spec.inputs.to_vec();
156        run_in.push(("d_output", &[1.0]));
157        let outs = compiled.run(&run_in);
158        last_loss = outs[0][0];
159        history.push(last_loss);
160
161        let mut scaled_grads = Vec::with_capacity(opt_values.len());
162        for (i, gout) in outs[1..].iter().enumerate() {
163            let g = gout[0];
164            let name = spec.optimize[i];
165            last_grads.insert(name.to_string(), g);
166            let scale = if cfg.relative_lr {
167                opt_values[i].abs().max(cfg.lr_floor)
168            } else {
169                1.0
170            };
171            scaled_grads.push(g * scale);
172        }
173
174        opt.lr = cfg.lr;
175        opt.step(
176            "params",
177            &[opt_values.len()],
178            &mut opt_values,
179            &scaled_grads,
180        );
181        opt.end_iteration();
182
183        for (i, name) in spec.optimize.iter().enumerate() {
184            if let Some(&(lo, hi)) = spec.bounds.get(*name) {
185                opt_values[i] = opt_values[i].clamp(lo, hi);
186            }
187        }
188    }
189
190    let mut params = spec.values.clone();
191    for (name, val) in spec.optimize.iter().zip(opt_values.iter()) {
192        params.insert((*name).to_string(), *val);
193    }
194
195    Ok(GraphOptResult {
196        params,
197        final_loss: last_loss,
198        history,
199        final_grads: last_grads,
200    })
201}
202
203pub(crate) fn apply_all_params(
204    compiled: &mut CompiledGraph,
205    all: &HashMap<String, f32>,
206    optimize: &[&str],
207    opt_values: &[f32],
208) {
209    for (name, val) in all {
210        if !optimize.contains(&name.as_str()) {
211            compiled.set_param(name, &[*val]);
212        }
213    }
214    for (name, val) in optimize.iter().zip(opt_values.iter()) {
215        compiled.set_param(name, &[*val]);
216    }
217}
218
#[cfg(test)]
mod tests {
    use rlx_ir::{DType, Graph, Op, Shape, op::BinaryOp};

    use super::*;

    /// Build `loss = (x - 2)^2` over a single scalar param `x`.
    fn quadratic_loss_graph() -> (Graph, &'static str) {
        let mut g = Graph::new("quad");
        let s = Shape::new(&[1], DType::F32);
        let x = g.param("x", s.clone());
        let target = g.add_node(
            Op::Constant {
                data: 2.0f32.to_le_bytes().to_vec(),
            },
            vec![],
            s.clone(),
        );
        let err = g.binary(BinaryOp::Sub, x, target, s.clone());
        let loss = g.binary(BinaryOp::Mul, err, err, s);
        g.set_outputs(vec![loss]);
        (g, "x")
    }

    #[test]
    fn parabolic_1d_converges() {
        let (fwd, pname) = quadratic_loss_graph();
        let values = HashMap::from([(pname.to_string(), 0.0f32)]);
        let bounds = HashMap::from([(pname.to_string(), (-10.0, 10.0))]);
        let spec = GraphOptSpec {
            optimize: &[pname],
            values,
            bounds,
            inputs: &[],
        };
        let cfg = GraphOptConfig {
            steps: 96,
            lr: 0.15,
            relative_lr: false,
            ..Default::default()
        };
        let r = adam_opt_graph(&fwd, &spec, &cfg, Device::Cpu).unwrap();
        assert!(
            r.final_loss < 0.01,
            "loss={} x={}",
            r.final_loss,
            r.params[pname]
        );
        assert!((r.params[pname] - 2.0).abs() < 0.08);
    }

    #[test]
    fn empty_optimize_list_is_rejected() {
        let (fwd, _) = quadratic_loss_graph();
        let spec = GraphOptSpec {
            optimize: &[],
            values: HashMap::new(),
            bounds: HashMap::new(),
            inputs: &[],
        };
        let err = adam_opt_graph(&fwd, &spec, &GraphOptConfig::from_steps(1), Device::Cpu)
            .unwrap_err();
        assert_eq!(err, GraphOptError::OptimizeEmpty);
    }

    #[test]
    fn missing_initial_value_is_reported() {
        let (fwd, pname) = quadratic_loss_graph();
        let spec = GraphOptSpec {
            optimize: &[pname],
            values: HashMap::new(),
            bounds: HashMap::new(),
            inputs: &[],
        };
        let err = adam_opt_graph(&fwd, &spec, &GraphOptConfig::from_steps(1), Device::Cpu)
            .unwrap_err();
        assert_eq!(err, GraphOptError::ParamNotFound(pname.to_string()));
    }

    #[test]
    fn unknown_param_name_is_reported() {
        let (fwd, _) = quadratic_loss_graph();
        let spec = GraphOptSpec {
            optimize: &["not_a_param"],
            values: HashMap::from([("not_a_param".to_string(), 1.0f32)]),
            bounds: HashMap::new(),
            inputs: &[],
        };
        let err = adam_opt_graph(&fwd, &spec, &GraphOptConfig::from_steps(1), Device::Cpu)
            .unwrap_err();
        assert!(
            matches!(err, GraphOptError::ParamNotFound(_)),
            "unexpected error: {err}"
        );
    }
}