1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
//! Bridge between [`WeightStore`] and any [`rlx_optim::Optimizer`].
//!
//! The bundled [`crate::adam::AdamState`] is the host-side reference
//! Adam used by the parametric UMAP trainer. When you want a
//! different rule — AdamW, Lion, Muon, Sophia, … — you can drive any
//! [`rlx_optim::Optimizer`] over the same [`WeightStore`] / gradient
//! pair via [`step_weight_store`].
//!
//! Matrix-aware optimizers (Adafactor, SOAP, Muon, Kron-PSGD) need
//! the per-parameter shape. Pass it via the `shapes` map keyed by
//! the same name used in the `WeightStore`. Names with no entry in
//! `shapes` are treated as 1-D vectors of length `data.len()`.
//!
//! ```ignore
//! use rlx_optim::AdamW;
//! use rlx_umap::optim_adapter::step_weight_store;
//! use std::collections::HashMap;
//!
//! let mut opt = AdamW::new(3e-4).with_weight_decay(0.1);
//! let mut shapes: HashMap<String, Vec<usize>> = HashMap::new();
//! shapes.insert("fc.weight".to_string(), vec![128, 64]);
//! shapes.insert("fc.bias".to_string(), vec![128]);
//!
//! step_weight_store(&mut opt, &mut weights, &grads, &shapes);
//! opt.end_iteration();
//! ```
use HashMap;
use Optimizer;
use crateWeightStore;
/// Drive `opt` for every name in `grads` that also appears in
/// `weights`. Missing grads (no entry in `grads.0` for a name held
/// by `weights`) are silently skipped — useful for partial training
/// passes (e.g. freezing the projection head).
///
/// Returns the number of parameters actually stepped.
/// Convenience: drive `opt` with all parameters treated as flat 1-D
/// vectors. Equivalent to `step_weight_store(opt, weights, grads, &HashMap::new())`.
/// Matrix-aware optimizers fall back to their elementwise path in this case.