Skip to main content

rlx_fft/
second_order.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
//! Second-order and adaptive twiddle optimizers (Adam, diagonal preconditioning,
//! diagonal Gauss–Newton, finite-difference HVP).
//!
//! Adaptive steps use an **angle parameterization** on the unit circle, so the real and
//! imaginary parts are never updated independently (independent re/im updates break the
//! twiddle magnitude, even when followed by projection).
20
21use crate::twiddle_stability::{
22    apply_twiddle_update, clip_twiddle_grad, project_twiddles_unit_circle,
23};
24
25#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
26pub enum TwiddleOptimizer {
27    Sgd,
28    /// Adam on twiddle angles (unit-circle safe).
29    Adam,
30    /// Diagonal preconditioner on twiddle angles.
31    DiagPrecond,
32}
33
34impl TwiddleOptimizer {
35    pub fn label(self) -> &'static str {
36        match self {
37            Self::Sgd => "sgd",
38            Self::Adam => "adam",
39            Self::DiagPrecond => "diag_precond",
40        }
41    }
42
43    pub fn parse(s: &str) -> anyhow::Result<Self> {
44        match s.to_ascii_lowercase().as_str() {
45            "sgd" => Ok(Self::Sgd),
46            "adam" => Ok(Self::Adam),
47            "diag" | "diag_precond" | "precond" => Ok(Self::DiagPrecond),
48            other => anyhow::bail!("unknown optimizer {other} (sgd, adam, diag_precond)"),
49        }
50    }
51}
52
53#[derive(Debug, Clone)]
54pub struct TwiddleOptState {
55    pub optimizer: TwiddleOptimizer,
56    pub beta1: f32,
57    pub beta2: f32,
58    pub eps: f32,
59    /// Moment buffers — length `n_complex_twiddles` for adaptive opts.
60    m_enc: Vec<f32>,
61    m_dec: Vec<f32>,
62    v_enc: Vec<f32>,
63    v_dec: Vec<f32>,
64    step: usize,
65}
66
67impl TwiddleOptState {
68    pub fn new(optimizer: TwiddleOptimizer, enc_len: usize, dec_len: usize) -> Self {
69        let enc_state = state_len(optimizer, enc_len);
70        let dec_state = state_len(optimizer, dec_len);
71        Self {
72            optimizer,
73            beta1: 0.9,
74            beta2: 0.999,
75            eps: 1e-8,
76            m_enc: vec![0f32; enc_state],
77            m_dec: vec![0f32; dec_state],
78            v_enc: vec![0f32; enc_state],
79            v_dec: vec![0f32; dec_state],
80            step: 0,
81        }
82    }
83
84    pub fn step_pair(
85        &mut self,
86        encoder: &mut [f32],
87        decoder: &mut [f32],
88        enc_grad: &[f32],
89        dec_grad: &[f32],
90        lr: f32,
91        grad_clip: f32,
92        project: bool,
93    ) {
94        self.step += 1;
95        match self.optimizer {
96            TwiddleOptimizer::Sgd => {
97                apply_twiddle_update(encoder, enc_grad, lr, grad_clip, project);
98                apply_twiddle_update(decoder, dec_grad, lr, grad_clip, project);
99            }
100            TwiddleOptimizer::Adam => {
101                adam_angle_update(
102                    encoder,
103                    enc_grad,
104                    &mut self.m_enc,
105                    &mut self.v_enc,
106                    self.step,
107                    lr,
108                    grad_clip,
109                    self.beta1,
110                    self.beta2,
111                    self.eps,
112                );
113                adam_angle_update(
114                    decoder,
115                    dec_grad,
116                    &mut self.m_dec,
117                    &mut self.v_dec,
118                    self.step,
119                    lr,
120                    grad_clip,
121                    self.beta1,
122                    self.beta2,
123                    self.eps,
124                );
125                let _ = project;
126            }
127            TwiddleOptimizer::DiagPrecond => {
128                diag_precond_angle_update(
129                    encoder,
130                    enc_grad,
131                    &mut self.v_enc,
132                    lr,
133                    grad_clip,
134                    self.beta2,
135                    self.eps,
136                );
137                diag_precond_angle_update(
138                    decoder,
139                    dec_grad,
140                    &mut self.v_dec,
141                    lr,
142                    grad_clip,
143                    self.beta2,
144                    self.eps,
145                );
146                let _ = project;
147            }
148        }
149    }
150}
151
152fn state_len(optimizer: TwiddleOptimizer, flat_len: usize) -> usize {
153    match optimizer {
154        TwiddleOptimizer::Sgd => flat_len,
155        TwiddleOptimizer::Adam | TwiddleOptimizer::DiagPrecond => flat_len / 2,
156    }
157}
158
/// Cartesian (re, im) gradient → scalar dL/dθ for w = e^{iθ} on the unit circle.
///
/// `tw` and `grad` are flat interleaved `[re, im, re, im, ...]` buffers. For each pair,
/// the chain rule with `dw/dθ = (-sin θ, cos θ)` gives
/// `dL/dθ = -g_re * sin θ + g_im * cos θ`. Here `(cos θ, sin θ)` is the twiddle's direction,
/// normalized with a `1e-12` floor on the magnitude.
///
/// Only complete pairs are used. A trailing unpaired float is ignored, matching
/// `state_len` (`flat_len / 2`) and `apply_angle_deltas`, instead of panicking on
/// an out-of-bounds `w[1]`. The output has one entry per complete pair present in
/// both buffers.
fn cartesian_grad_to_angle(tw: &[f32], grad: &[f32]) -> Vec<f32> {
    debug_assert_eq!(tw.len(), grad.len());
    tw.chunks_exact(2)
        .zip(grad.chunks_exact(2))
        .map(|(w, g)| {
            let (re, im) = (w[0], w[1]);
            let mag = (re * re + im * im).sqrt().max(1e-12);
            let ur = re / mag;
            let ui = im / mag;
            -g[0] * ui + g[1] * ur
        })
        .collect()
}
173
/// Rotates each complex twiddle `(re, im)` in `tw` by its angle delta, i.e. `θ ← θ − δ`.
///
/// Each pair is first normalized (magnitude floored at `1e-12`) and then multiplied by
/// `e^{-iδ}`, so it lands on the unit circle unless the pair was exactly zero. Pairs are
/// matched with `deltas` in order, and processing stops at whichever runs out first.
fn apply_angle_deltas(tw: &mut [f32], deltas: &[f32]) {
    tw.chunks_mut(2)
        .zip(deltas.iter().copied())
        .for_each(|(pair, delta)| {
            let norm = (pair[0] * pair[0] + pair[1] * pair[1]).sqrt().max(1e-12);
            let unit_re = pair[0] / norm;
            let unit_im = pair[1] / norm;
            let (sin_d, cos_d) = delta.sin_cos();
            // (unit_re + i·unit_im) · (cos δ − i·sin δ)
            pair[0] = unit_re * cos_d + unit_im * sin_d;
            pair[1] = unit_im * cos_d - unit_re * sin_d;
        });
}
187
188fn adam_angle_update(
189    tw: &mut [f32],
190    grad: &[f32],
191    m: &mut [f32],
192    v: &mut [f32],
193    step: usize,
194    lr: f32,
195    grad_clip: f32,
196    beta1: f32,
197    beta2: f32,
198    eps: f32,
199) {
200    let mut angle_grad = cartesian_grad_to_angle(tw, grad);
201    clip_twiddle_grad(&mut angle_grad, grad_clip);
202    let bc1 = 1.0 - beta1.powi(step as i32);
203    let bc2 = 1.0 - beta2.powi(step as i32);
204    let mut deltas = vec![0f32; angle_grad.len()];
205    for i in 0..angle_grad.len() {
206        m[i] = beta1 * m[i] + (1.0 - beta1) * angle_grad[i];
207        v[i] = beta2 * v[i] + (1.0 - beta2) * angle_grad[i] * angle_grad[i];
208        let m_hat = m[i] / bc1;
209        let v_hat = v[i] / bc2;
210        deltas[i] = lr * m_hat / (v_hat.sqrt() + eps);
211    }
212    apply_angle_deltas(tw, &deltas);
213}
214
215fn diag_precond_angle_update(
216    tw: &mut [f32],
217    grad: &[f32],
218    v: &mut [f32],
219    lr: f32,
220    grad_clip: f32,
221    beta2: f32,
222    eps: f32,
223) {
224    let mut angle_grad = cartesian_grad_to_angle(tw, grad);
225    clip_twiddle_grad(&mut angle_grad, grad_clip);
226    let mut deltas = vec![0f32; angle_grad.len()];
227    for i in 0..angle_grad.len() {
228        v[i] = beta2 * v[i] + (1.0 - beta2) * angle_grad[i] * angle_grad[i];
229        deltas[i] = lr * angle_grad[i] / (v[i].sqrt() + eps);
230    }
231    apply_angle_deltas(tw, &deltas);
232}
233
234/// Finite-difference Hessian–vector product on twiddle flat buffer (diagnostic / small n).
235pub fn hvp_twiddles_finite_diff<F>(
236    tw: &[f32],
237    direction: &[f32],
238    mut loss_and_grad: F,
239    eps: f32,
240) -> anyhow::Result<Vec<f32>>
241where
242    F: FnMut(&[f32]) -> anyhow::Result<(f32, Vec<f32>)>,
243{
244    anyhow::ensure!(tw.len() == direction.len());
245    let mut plus = tw.to_vec();
246    let mut minus = tw.to_vec();
247    for i in 0..tw.len() {
248        plus[i] += eps * direction[i];
249        minus[i] -= eps * direction[i];
250    }
251    let (_, g_plus) = loss_and_grad(&plus)?;
252    let (_, g_minus) = loss_and_grad(&minus)?;
253    Ok(g_plus
254        .iter()
255        .zip(g_minus.iter())
256        .map(|(a, b)| (a - b) / (2.0 * eps))
257        .collect())
258}
259
260/// One diagonal Gauss–Newton style step using grad² as curvature proxy (angle space).
261pub fn diag_gn_step(tw: &mut [f32], grad: &[f32], lr: f32, damping: f32, project: bool) {
262    let angle_grad = cartesian_grad_to_angle(tw, grad);
263    let deltas: Vec<f32> = angle_grad
264        .iter()
265        .map(|g| lr * g / (g * g + damping))
266        .collect();
267    apply_angle_deltas(tw, &deltas);
268    if project {
269        project_twiddles_unit_circle(tw);
270    }
271}
272
#[cfg(test)]
mod tests {
    use super::*;

    /// Angle of a single complex twiddle stored as `[re, im]`.
    fn angle_of(tw: &[f32]) -> f32 {
        tw[1].atan2(tw[0])
    }

    #[test]
    fn angle_delta_descent_on_quadratic() {
        // Plain angle-space descent on L(θ) = θ² (gradient 2θ), step size 0.05.
        let mut tw = vec![0.6f32, 0.8];
        for _ in 0..100 {
            let grad = 2.0 * angle_of(&tw);
            apply_angle_deltas(&mut tw, &[0.05 * grad]);
        }
        let theta = angle_of(&tw);
        assert!(theta.abs() < 0.2, "theta={theta}");
    }

    #[test]
    fn angle_adam_reduces_quadratic_angle() {
        // Hand-rolled Adam (β1 = 0.9, β2 = 0.999, lr = 0.15, eps = 1e-8) on L(θ) = θ².
        let mut tw = vec![0.6f32, 0.8];
        let (mut m, mut v) = (0.0f32, 0.0f32);
        for step in 1..=400 {
            let mut clipped = vec![2.0 * angle_of(&tw)];
            clip_twiddle_grad(&mut clipped, 0.0);
            let g = clipped[0];
            m = 0.9 * m + 0.1 * g;
            v = 0.999 * v + 0.001 * g * g;
            let m_hat = m / (1.0 - 0.9f32.powi(step));
            let v_hat = v / (1.0 - 0.999f32.powi(step));
            apply_angle_deltas(&mut tw, &[0.15 * m_hat / (v_hat.sqrt() + 1e-8)]);
        }
        let theta = angle_of(&tw);
        assert!(theta.abs() < 0.2, "theta={theta}");
    }

    #[test]
    fn cartesian_to_angle_chain_rule() {
        // At w = 1 (θ = 0), dw/dθ = i, so a purely imaginary gradient gives dL/dθ = 1.
        let dtheta = cartesian_grad_to_angle(&[1.0, 0.0], &[0.0, 1.0])[0];
        assert!((dtheta - 1.0).abs() < 1e-5);
    }
}