1use crate::twiddle_stability::{
22 apply_twiddle_update, clip_twiddle_grad, project_twiddles_unit_circle,
23};
24
25#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
26pub enum TwiddleOptimizer {
27 Sgd,
28 Adam,
30 DiagPrecond,
32}
33
34impl TwiddleOptimizer {
35 pub fn label(self) -> &'static str {
36 match self {
37 Self::Sgd => "sgd",
38 Self::Adam => "adam",
39 Self::DiagPrecond => "diag_precond",
40 }
41 }
42
43 pub fn parse(s: &str) -> anyhow::Result<Self> {
44 match s.to_ascii_lowercase().as_str() {
45 "sgd" => Ok(Self::Sgd),
46 "adam" => Ok(Self::Adam),
47 "diag" | "diag_precond" | "precond" => Ok(Self::DiagPrecond),
48 other => anyhow::bail!("unknown optimizer {other} (sgd, adam, diag_precond)"),
49 }
50 }
51}
52
53#[derive(Debug, Clone)]
54pub struct TwiddleOptState {
55 pub optimizer: TwiddleOptimizer,
56 pub beta1: f32,
57 pub beta2: f32,
58 pub eps: f32,
59 m_enc: Vec<f32>,
61 m_dec: Vec<f32>,
62 v_enc: Vec<f32>,
63 v_dec: Vec<f32>,
64 step: usize,
65}
66
67impl TwiddleOptState {
68 pub fn new(optimizer: TwiddleOptimizer, enc_len: usize, dec_len: usize) -> Self {
69 let enc_state = state_len(optimizer, enc_len);
70 let dec_state = state_len(optimizer, dec_len);
71 Self {
72 optimizer,
73 beta1: 0.9,
74 beta2: 0.999,
75 eps: 1e-8,
76 m_enc: vec![0f32; enc_state],
77 m_dec: vec![0f32; dec_state],
78 v_enc: vec![0f32; enc_state],
79 v_dec: vec![0f32; dec_state],
80 step: 0,
81 }
82 }
83
84 pub fn step_pair(
85 &mut self,
86 encoder: &mut [f32],
87 decoder: &mut [f32],
88 enc_grad: &[f32],
89 dec_grad: &[f32],
90 lr: f32,
91 grad_clip: f32,
92 project: bool,
93 ) {
94 self.step += 1;
95 match self.optimizer {
96 TwiddleOptimizer::Sgd => {
97 apply_twiddle_update(encoder, enc_grad, lr, grad_clip, project);
98 apply_twiddle_update(decoder, dec_grad, lr, grad_clip, project);
99 }
100 TwiddleOptimizer::Adam => {
101 adam_angle_update(
102 encoder,
103 enc_grad,
104 &mut self.m_enc,
105 &mut self.v_enc,
106 self.step,
107 lr,
108 grad_clip,
109 self.beta1,
110 self.beta2,
111 self.eps,
112 );
113 adam_angle_update(
114 decoder,
115 dec_grad,
116 &mut self.m_dec,
117 &mut self.v_dec,
118 self.step,
119 lr,
120 grad_clip,
121 self.beta1,
122 self.beta2,
123 self.eps,
124 );
125 let _ = project;
126 }
127 TwiddleOptimizer::DiagPrecond => {
128 diag_precond_angle_update(
129 encoder,
130 enc_grad,
131 &mut self.v_enc,
132 lr,
133 grad_clip,
134 self.beta2,
135 self.eps,
136 );
137 diag_precond_angle_update(
138 decoder,
139 dec_grad,
140 &mut self.v_dec,
141 lr,
142 grad_clip,
143 self.beta2,
144 self.eps,
145 );
146 let _ = project;
147 }
148 }
149 }
150}
151
152fn state_len(optimizer: TwiddleOptimizer, flat_len: usize) -> usize {
153 match optimizer {
154 TwiddleOptimizer::Sgd => flat_len,
155 TwiddleOptimizer::Adam | TwiddleOptimizer::DiagPrecond => flat_len / 2,
156 }
157}
158
/// Converts an interleaved Cartesian gradient into one angle gradient per
/// `(re, im)` twiddle pair.
///
/// With unit direction `(u_r, u_i) = (re, im) / |(re, im)|`, the angle
/// derivative is `dL/dθ = -g_re * u_i + g_im * u_r`. The magnitude is floored
/// at `1e-12`, so an all-zero pair produces a zero angle gradient instead of NaN.
///
/// Lengths are compared only in debug builds. Equal odd lengths panic when the
/// trailing half-pair is indexed.
fn cartesian_grad_to_angle(tw: &[f32], grad: &[f32]) -> Vec<f32> {
    debug_assert_eq!(tw.len(), grad.len());
    tw.chunks(2)
        .zip(grad.chunks(2))
        .map(|(pair, g)| {
            let (re, im) = (pair[0], pair[1]);
            let norm = (re * re + im * im).sqrt().max(1e-12);
            let (dir_re, dir_im) = (re / norm, im / norm);
            -g[0] * dir_im + g[1] * dir_re
        })
        .collect()
}
173
/// Rotates each `(re, im)` pair of `tw` by `-delta`, using the matching entry
/// of `deltas`, and writes it back with unit magnitude.
///
/// A positive `delta` decreases the pair's angle. This makes
/// `delta = lr * dL/dθ` a descent step. Pairs without a matching delta are
/// left untouched. An all-zero pair has no direction and stays `(0, 0)`.
fn apply_angle_deltas(tw: &mut [f32], deltas: &[f32]) {
    for (pair, delta) in tw.chunks_mut(2).zip(deltas.iter().copied()) {
        let norm = (pair[0] * pair[0] + pair[1] * pair[1]).sqrt().max(1e-12);
        let cos_t = pair[0] / norm;
        let sin_t = pair[1] / norm;
        let (sin_d, cos_d) = delta.sin_cos();
        // Angle-difference identities: (cos(θ − δ), sin(θ − δ)).
        pair[0] = cos_t * cos_d + sin_t * sin_d;
        pair[1] = sin_t * cos_d - cos_t * sin_d;
    }
}
187
188fn adam_angle_update(
189 tw: &mut [f32],
190 grad: &[f32],
191 m: &mut [f32],
192 v: &mut [f32],
193 step: usize,
194 lr: f32,
195 grad_clip: f32,
196 beta1: f32,
197 beta2: f32,
198 eps: f32,
199) {
200 let mut angle_grad = cartesian_grad_to_angle(tw, grad);
201 clip_twiddle_grad(&mut angle_grad, grad_clip);
202 let bc1 = 1.0 - beta1.powi(step as i32);
203 let bc2 = 1.0 - beta2.powi(step as i32);
204 let mut deltas = vec![0f32; angle_grad.len()];
205 for i in 0..angle_grad.len() {
206 m[i] = beta1 * m[i] + (1.0 - beta1) * angle_grad[i];
207 v[i] = beta2 * v[i] + (1.0 - beta2) * angle_grad[i] * angle_grad[i];
208 let m_hat = m[i] / bc1;
209 let v_hat = v[i] / bc2;
210 deltas[i] = lr * m_hat / (v_hat.sqrt() + eps);
211 }
212 apply_angle_deltas(tw, &deltas);
213}
214
215fn diag_precond_angle_update(
216 tw: &mut [f32],
217 grad: &[f32],
218 v: &mut [f32],
219 lr: f32,
220 grad_clip: f32,
221 beta2: f32,
222 eps: f32,
223) {
224 let mut angle_grad = cartesian_grad_to_angle(tw, grad);
225 clip_twiddle_grad(&mut angle_grad, grad_clip);
226 let mut deltas = vec![0f32; angle_grad.len()];
227 for i in 0..angle_grad.len() {
228 v[i] = beta2 * v[i] + (1.0 - beta2) * angle_grad[i] * angle_grad[i];
229 deltas[i] = lr * angle_grad[i] / (v[i].sqrt() + eps);
230 }
231 apply_angle_deltas(tw, &deltas);
232}
233
234pub fn hvp_twiddles_finite_diff<F>(
236 tw: &[f32],
237 direction: &[f32],
238 mut loss_and_grad: F,
239 eps: f32,
240) -> anyhow::Result<Vec<f32>>
241where
242 F: FnMut(&[f32]) -> anyhow::Result<(f32, Vec<f32>)>,
243{
244 anyhow::ensure!(tw.len() == direction.len());
245 let mut plus = tw.to_vec();
246 let mut minus = tw.to_vec();
247 for i in 0..tw.len() {
248 plus[i] += eps * direction[i];
249 minus[i] -= eps * direction[i];
250 }
251 let (_, g_plus) = loss_and_grad(&plus)?;
252 let (_, g_minus) = loss_and_grad(&minus)?;
253 Ok(g_plus
254 .iter()
255 .zip(g_minus.iter())
256 .map(|(a, b)| (a - b) / (2.0 * eps))
257 .collect())
258}
259
260pub fn diag_gn_step(tw: &mut [f32], grad: &[f32], lr: f32, damping: f32, project: bool) {
262 let angle_grad = cartesian_grad_to_angle(tw, grad);
263 let deltas: Vec<f32> = angle_grad
264 .iter()
265 .map(|g| lr * g / (g * g + damping))
266 .collect();
267 apply_angle_deltas(tw, &deltas);
268 if project {
269 project_twiddles_unit_circle(tw);
270 }
271}
272
#[cfg(test)]
mod tests {
    use super::*;

    /// Cartesian gradient tangent to the circle at `tw[0..2]` whose
    /// angle-space projection is `g_theta`. This lets tests drive the
    /// Cartesian-gradient entry points with a known angle gradient.
    fn tangent_grad(tw: &[f32], g_theta: f32) -> Vec<f32> {
        let mag = (tw[0] * tw[0] + tw[1] * tw[1]).sqrt();
        vec![-g_theta * tw[1] / mag, g_theta * tw[0] / mag]
    }

    /// Angle of the first twiddle pair.
    fn angle(tw: &[f32]) -> f32 {
        tw[1].atan2(tw[0])
    }

    #[test]
    fn angle_delta_descent_on_quadratic() {
        let mut tw = vec![0.6f32, 0.8];
        for _ in 0..100 {
            let g_theta = 2.0 * angle(&tw);
            apply_angle_deltas(&mut tw, &[0.05 * g_theta]);
        }
        let theta = angle(&tw);
        assert!(theta.abs() < 0.2, "theta={theta}");
    }

    #[test]
    fn angle_adam_reduces_quadratic_angle() {
        // Exercise the real `adam_angle_update` on L(θ) = θ² (dL/dθ = 2θ)
        // instead of a hand-copied Adam loop, so regressions in it are caught.
        let mut tw = vec![0.6f32, 0.8];
        let mut m = [0.0f32];
        let mut v = [0.0f32];
        for step in 1..=400usize {
            let grad = tangent_grad(&tw, 2.0 * angle(&tw));
            adam_angle_update(
                &mut tw, &grad, &mut m, &mut v, step, 0.15, 0.0, 0.9, 0.999, 1e-8,
            );
        }
        let theta = angle(&tw);
        assert!(theta.abs() < 0.2, "theta={theta}");
    }

    #[test]
    fn step_pair_adam_reduces_both_angles() {
        let mut state = TwiddleOptState::new(TwiddleOptimizer::Adam, 2, 2);
        let mut enc = vec![0.6f32, 0.8];
        let mut dec = vec![0.8f32, -0.6];
        for _ in 0..400 {
            let enc_grad = tangent_grad(&enc, 2.0 * angle(&enc));
            let dec_grad = tangent_grad(&dec, 2.0 * angle(&dec));
            state.step_pair(&mut enc, &mut dec, &enc_grad, &dec_grad, 0.15, 0.0, false);
        }
        let (te, td) = (angle(&enc), angle(&dec));
        assert!(te.abs() < 0.2, "encoder theta={te}");
        assert!(td.abs() < 0.2, "decoder theta={td}");
    }

    #[test]
    fn diag_gn_converges_on_quadratic() {
        let mut tw = vec![0.6f32, 0.8];
        for _ in 0..50 {
            let grad = tangent_grad(&tw, 2.0 * angle(&tw));
            diag_gn_step(&mut tw, &grad, 0.5, 1.0, false);
        }
        let theta = angle(&tw);
        assert!(theta.abs() < 1e-3, "theta={theta}");
        let mag = (tw[0] * tw[0] + tw[1] * tw[1]).sqrt();
        assert!((mag - 1.0).abs() < 1e-5, "mag={mag}");
    }

    #[test]
    fn cartesian_to_angle_chain_rule() {
        let dtheta = cartesian_grad_to_angle(&[1.0, 0.0], &[0.0, 1.0])[0];
        assert!((dtheta - 1.0).abs() < 1e-5);
        // Only the twiddle's direction matters, not its magnitude.
        let dtheta = cartesian_grad_to_angle(&[0.0, 2.0], &[3.0, 0.0])[0];
        assert!((dtheta + 3.0).abs() < 1e-5);
    }

    #[test]
    fn label_round_trips_through_parse() {
        for opt in [
            TwiddleOptimizer::Sgd,
            TwiddleOptimizer::Adam,
            TwiddleOptimizer::DiagPrecond,
        ] {
            assert_eq!(TwiddleOptimizer::parse(opt.label()).unwrap(), opt);
        }
        assert_eq!(
            TwiddleOptimizer::parse("DIAG").unwrap(),
            TwiddleOptimizer::DiagPrecond
        );
        assert!(TwiddleOptimizer::parse("lbfgs").is_err());
    }
}