1use std::collections::HashMap;
22
23use rlx_ir::{
24 Graph, NodeId, find_param_node as ir_find_param_node, find_param_nodes as ir_find_param_nodes,
25};
26use rlx_opt::rlx_autodiff::grad_with_loss;
27use rlx_optim::{Adam, Optimizer};
28use rlx_runtime::{CompiledGraph, Device, Session};
29use serde::{Deserialize, Serialize};
30
31pub fn find_param_node(g: &Graph, name: &str) -> Option<NodeId> {
33 ir_find_param_node(g, name)
34}
35
36pub fn find_param_nodes(g: &Graph, names: &[&str]) -> Result<Vec<NodeId>, GraphOptError> {
38 ir_find_param_nodes(g, names).map_err(GraphOptError::ParamNotFound)
39}
40
41#[derive(Clone, Debug, Serialize, Deserialize)]
42pub struct GraphOptConfig {
43 pub steps: u32,
44 pub lr: f32,
46 pub relative_lr: bool,
50 pub lr_floor: f32,
51 pub beta1: f32,
52 pub beta2: f32,
53}
54
55impl Default for GraphOptConfig {
56 fn default() -> Self {
57 Self {
58 steps: 128,
59 lr: 0.02,
60 relative_lr: true,
61 lr_floor: 1e-12,
62 beta1: 0.9,
63 beta2: 0.999,
64 }
65 }
66}
67
68impl GraphOptConfig {
69 #[must_use]
70 pub fn from_steps(steps: u32) -> Self {
71 Self {
72 steps,
73 ..Self::default()
74 }
75 }
76}
77
78#[derive(Clone, Debug, Serialize, Deserialize)]
79pub struct GraphOptResult {
80 pub params: HashMap<String, f32>,
82 pub final_loss: f32,
83 pub history: Vec<f32>,
84 pub final_grads: HashMap<String, f32>,
85}
86
/// Failures reported by the graph-optimisation helpers in this module.
#[derive(Debug, Clone, PartialEq)]
pub enum GraphOptError {
    /// A named parameter could not be resolved; the payload is the name (or the
    /// message produced by the underlying lookup).
    ParamNotFound(String),
    /// The list of parameters to optimize was empty.
    OptimizeEmpty,
    /// Two gradient estimates for `param` disagree. The `Display` output labels
    /// them `AD` and `FD`.
    // NOTE(review): nothing in the visible code constructs this variant;
    // presumably a gradient check comparing autodiff against finite
    // differences lives elsewhere — confirm.
    GradcheckMismatch { param: String, ad: f32, fd: f32 },
}
93
94impl std::fmt::Display for GraphOptError {
95 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
96 match self {
97 Self::ParamNotFound(n) => write!(f, "param not found in graph: {n}"),
98 Self::OptimizeEmpty => write!(f, "optimize list is empty"),
99 Self::GradcheckMismatch { param, ad, fd } => {
100 write!(f, "gradcheck mismatch at {param}: AD={ad:.6e} FD={fd:.6e}")
101 }
102 }
103 }
104}
105
106impl std::error::Error for GraphOptError {}
107
/// Describes one optimisation problem for [`adam_opt_graph`].
pub struct GraphOptSpec<'a> {
    /// Names of the parameters to optimize; must not be empty.
    pub optimize: &'a [&'a str],
    /// Scalar value for every parameter that should be written into the
    /// compiled graph. Entries named in `optimize` act as starting points (and
    /// are required); all other entries are held fixed.
    pub values: HashMap<String, f32>,
    /// Optional `(lo, hi)` range per parameter name; optimized parameters with
    /// an entry are clamped into it after every step.
    pub bounds: HashMap<String, (f32, f32)>,
    /// Named input buffers passed unchanged to every evaluation of the graph.
    pub inputs: &'a [(&'a str, &'a [f32])],
}
119
120pub fn adam_opt_graph(
122 fwd: &Graph,
123 spec: &GraphOptSpec<'_>,
124 cfg: &GraphOptConfig,
125 device: Device,
126) -> Result<GraphOptResult, GraphOptError> {
127 if spec.optimize.is_empty() {
128 return Err(GraphOptError::OptimizeEmpty);
129 }
130
131 let param_ids = find_param_nodes(fwd, spec.optimize)?;
132 let bwd = grad_with_loss(fwd, ¶m_ids);
133 let session = Session::new(device);
134 let mut compiled = session.compile(bwd);
135
136 let mut opt_values: Vec<f32> = spec
137 .optimize
138 .iter()
139 .map(|n| {
140 spec.values
141 .get(*n)
142 .copied()
143 .ok_or_else(|| GraphOptError::ParamNotFound((*n).into()))
144 })
145 .collect::<Result<_, _>>()?;
146
147 let mut opt = Adam::new(cfg.lr).with_betas(cfg.beta1, cfg.beta2);
148 let mut history = Vec::with_capacity(cfg.steps as usize);
149 let mut last_grads: HashMap<String, f32> = HashMap::new();
150 let mut last_loss = f32::MAX;
151
152 for _ in 0..cfg.steps {
153 apply_all_params(&mut compiled, &spec.values, spec.optimize, &opt_values);
154
155 let mut run_in: Vec<(&str, &[f32])> = spec.inputs.to_vec();
156 run_in.push(("d_output", &[1.0]));
157 let outs = compiled.run(&run_in);
158 last_loss = outs[0][0];
159 history.push(last_loss);
160
161 let mut scaled_grads = Vec::with_capacity(opt_values.len());
162 for (i, gout) in outs[1..].iter().enumerate() {
163 let g = gout[0];
164 let name = spec.optimize[i];
165 last_grads.insert(name.to_string(), g);
166 let scale = if cfg.relative_lr {
167 opt_values[i].abs().max(cfg.lr_floor)
168 } else {
169 1.0
170 };
171 scaled_grads.push(g * scale);
172 }
173
174 opt.lr = cfg.lr;
175 opt.step(
176 "params",
177 &[opt_values.len()],
178 &mut opt_values,
179 &scaled_grads,
180 );
181 opt.end_iteration();
182
183 for (i, name) in spec.optimize.iter().enumerate() {
184 if let Some(&(lo, hi)) = spec.bounds.get(*name) {
185 opt_values[i] = opt_values[i].clamp(lo, hi);
186 }
187 }
188 }
189
190 let mut params = spec.values.clone();
191 for (name, val) in spec.optimize.iter().zip(opt_values.iter()) {
192 params.insert((*name).to_string(), *val);
193 }
194
195 Ok(GraphOptResult {
196 params,
197 final_loss: last_loss,
198 history,
199 final_grads: last_grads,
200 })
201}
202
203pub(crate) fn apply_all_params(
204 compiled: &mut CompiledGraph,
205 all: &HashMap<String, f32>,
206 optimize: &[&str],
207 opt_values: &[f32],
208) {
209 for (name, val) in all {
210 if !optimize.contains(&name.as_str()) {
211 compiled.set_param(name, &[*val]);
212 }
213 }
214 for (name, val) in optimize.iter().zip(opt_values.iter()) {
215 compiled.set_param(name, &[*val]);
216 }
217}
218
#[cfg(test)]
mod tests {
    use rlx_ir::{DType, Graph, Op, Shape, op::BinaryOp};

    use super::*;

    /// Builds a graph computing `(x - 2)^2` for a single-element parameter `x`
    /// and returns it together with the parameter's name.
    fn quadratic_loss_graph() -> (Graph, &'static str) {
        let mut g = Graph::new("quad");
        let scalar = Shape::new(&[1], DType::F32);

        let x = g.param("x", scalar.clone());
        // The constant 2.0, stored as little-endian f32 bytes.
        let target = g.add_node(
            Op::Constant {
                data: 2.0f32.to_le_bytes().to_vec(),
            },
            vec![],
            scalar.clone(),
        );

        let err = g.binary(BinaryOp::Sub, x, target, scalar.clone());
        let loss = g.binary(BinaryOp::Mul, err, err, scalar);
        g.set_outputs(vec![loss]);
        (g, "x")
    }

    /// Starting from `x = 0`, plain (non-relative) Adam should drive the
    /// quadratic close to its minimum at `x = 2` within 96 steps.
    #[test]
    fn parabolic_1d_converges() {
        let (fwd, pname) = quadratic_loss_graph();

        let spec = GraphOptSpec {
            optimize: &[pname],
            values: HashMap::from([(pname.to_string(), 0.0f32)]),
            bounds: HashMap::from([(pname.to_string(), (-10.0, 10.0))]),
            inputs: &[],
        };
        let cfg = GraphOptConfig {
            steps: 96,
            lr: 0.15,
            relative_lr: false,
            ..Default::default()
        };

        let result = adam_opt_graph(&fwd, &spec, &cfg, Device::Cpu).unwrap();

        assert!(
            result.final_loss < 0.01,
            "loss={} x={}",
            result.final_loss,
            result.params[pname]
        );
        assert!((result.params[pname] - 2.0).abs() < 0.08);
    }
}
268}