//! Optimizers ported from [`mlx/python/mlx/optimizers/optimizers.py`] and
//! [`mlx-swift/Source/MLXOptimizers/Optimizers.swift`].
//!
//! Ten gradient-descent optimizers, each implementing the common
//! [`Optimizer`] trait:
//!
//! - [`SGD`] — stochastic gradient descent + (Nesterov) momentum + weight
//! decay + dampening.
//! - [`RMSprop`] — running-average-of-squared-gradients normalization.
//! - [`Adagrad`] — cumulative-squared-gradients normalization.
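//! - [`AdaDelta`] — per-element step `sqrt(u + ε) / sqrt(v + ε) · g` built from
//! running averages of squared updates (`u`) and squared gradients (`v`);
//! unlike the original paper, MLX still scales the step by `learning_rate`.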
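//! - [`Adam`] / [`AdamW`] / [`Adamax`] — adaptive-moment family built on
//! running averages of the gradient (`m`) and of its square (`v`), with
//! [`AdamW`] adding decoupled weight decay and [`Adamax`] replacing the
//! second moment with an `∞`-norm denominator (a decayed running max of `|g|`).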
//! - [`Lion`] — sign-of-momentum update; keeps one moment buffer per
//! parameter instead of Adam's two, so it needs less optimizer-state memory
//! and compute than Adam.
//! - [`Adafactor`] — sublinear-memory adaptive moments (row+col running
//! averages instead of full per-element `v`).
//! - [`Muon`] — momentum + Newton-Schulz orthogonalization on 2D+ updates.
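//!
//! As a rough illustration of three of the update rules above, here is a
//! minimal sketch over plain `f32` slices. This is an assumption for
//! readability: the real optimizers operate on MLX `Array`s, and the helper
//! names `SgdConfig`, `sgd_step`, `lion_step`, and `adadelta_step` are
//! hypothetical, not this crate's API. The arithmetic is intended to follow
//! the MLX Python `apply_single` methods.
//!
//! ```rust
//! /// Hypothetical SGD hyperparameters (not the crate's `SGD` struct).
//! struct SgdConfig {
//!     lr: f32,
//!     momentum: f32,
//!     dampening: f32,
//!     weight_decay: f32,
//!     nesterov: bool,
//! }
//!
//! /// SGD with weight decay, dampening, and optional Nesterov momentum.
//! fn sgd_step(cfg: &SgdConfig, p: &mut [f32], g: &[f32], v: &mut [f32]) {
//!     for i in 0..p.len() {
//!         // Weight decay is folded into the gradient as an L2 term.
//!         let grad = g[i] + cfg.weight_decay * p[i];
//!         if cfg.momentum <= 0.0 {
//!             p[i] -= cfg.lr * grad; // plain SGD: velocity state is unused
//!             continue;
//!         }
//!         v[i] = cfg.momentum * v[i] + (1.0 - cfg.dampening) * grad;
//!         let update = if cfg.nesterov { grad + cfg.momentum * v[i] } else { v[i] };
//!         p[i] -= cfg.lr * update;
//!     }
//! }
//!
//! /// Lion: step by the sign of an interpolated momentum; one buffer `m` per element.
//! fn lion_step(p: &mut [f32], g: &[f32], m: &mut [f32], lr: f32, b1: f32, b2: f32, wd: f32) {
//!     for i in 0..p.len() {
//!         let c = b1 * m[i] + (1.0 - b1) * g[i]; // direction used for this step
//!         m[i] = b2 * m[i] + (1.0 - b2) * g[i]; // momentum tracked with a second beta
//!         // sign(0) = 0 here; f32::signum(0.0) would return 1.0 instead.
//!         let sign = if c > 0.0 { 1.0 } else if c < 0.0 { -1.0 } else { 0.0 };
//!         p[i] = (1.0 - lr * wd) * p[i] - lr * sign; // decoupled weight decay, then step
//!     }
//! }
//!
//! /// AdaDelta: ratio of running RMS of updates (`u`) to running RMS of gradients (`v`).
//! fn adadelta_step(p: &mut [f32], g: &[f32], v: &mut [f32], u: &mut [f32], lr: f32, rho: f32, eps: f32) {
//!     for i in 0..p.len() {
//!         v[i] = rho * v[i] + (1.0 - rho) * g[i] * g[i];
//!         let d = (u[i] + eps).sqrt() / (v[i] + eps).sqrt() * g[i];
//!         u[i] = rho * u[i] + (1.0 - rho) * d * d;
//!         p[i] -= lr * d;
//!     }
//! }
//! ```
//!
//! Each helper takes its state buffers (`v`, `m`, `u`) explicitly; in the
//! real optimizers that state lives inside the optimizer, keyed by parameter
//! name.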
//!
//! Plus [`MultiOptimizer`] for routing different parameter groups to
//! different optimizer instances, and the [`schedulers`] sub-module for
//! step-driven learning-rate schedules ([`schedulers::cosine_decay`],
//! [`schedulers::exponential_decay`], [`schedulers::step_decay`],
//! [`schedulers::linear_schedule`], [`schedulers::join_schedules`]).
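//!
//! A step-driven schedule is just a function from step index to learning
//! rate. The sketch below mirrors the formulas of MLX's Python schedulers
//! (`mlx.optimizers.linear_schedule`, `cosine_decay`, `join_schedules`) as
//! boxed closures; the `Schedule` alias and these Rust signatures are
//! assumptions, and the crate's `schedulers` functions may take or return
//! different types (for example MLX arrays).
//!
//! ```rust
//! /// A schedule maps a step index to a learning rate (hypothetical alias).
//! type Schedule = Box<dyn Fn(u64) -> f32>;
//!
//! /// Linear ramp from `init` to `end` over `steps` steps, then constant at `end`.
//! fn linear_schedule(init: f32, end: f32, steps: u64) -> Schedule {
//!     assert!(steps >= 1, "steps must be at least 1");
//!     Box::new(move |step: u64| {
//!         let s = step.min(steps) as f32;
//!         init + s * (end - init) / steps as f32
//!     })
//! }
//!
//! /// Half-cosine decay from `init` to `end` over `decay_steps` steps, then constant.
//! fn cosine_decay(init: f32, decay_steps: u64, end: f32) -> Schedule {
//!     Box::new(move |step: u64| {
//!         let s = step.min(decay_steps) as f32;
//!         let decay = 0.5 * (1.0 + (std::f32::consts::PI * s / decay_steps as f32).cos());
//!         end + decay * (init - end)
//!     })
//! }
//!
//! /// Switch schedules at ascending `boundaries`; each later schedule is
//! /// evaluated on the step count relative to the boundary where it starts.
//! fn join_schedules(schedules: Vec<Schedule>, boundaries: Vec<u64>) -> Schedule {
//!     assert_eq!(boundaries.len() + 1, schedules.len(), "need one boundary per switch");
//!     Box::new(move |step: u64| {
//!         let (mut idx, mut offset) = (0, 0);
//!         for (i, &b) in boundaries.iter().enumerate() {
//!             if step >= b {
//!                 idx = i + 1; // last boundary reached so far wins
//!                 offset = b;
//!             }
//!         }
//!         schedules[idx](step - offset)
//!     })
//! }
//! ```
//!
//! Joining `linear_schedule(0.0, 1e-3, 10)` and `cosine_decay(1e-3, 100, 0.0)`
//! at boundary `10` gives a 10-step warmup followed by a 100-step cosine
//! decay. As in MLX, the second schedule sees `step - boundary`, so the
//! cosine starts its own count at 0.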
//!
//! ## Trait shape (deviation from Python)
//!
//! Python keeps state in a nested `dict` keyed by the parameter tree path
//! (`tree_map(apply_single, gradients, parameters, state)`). The Rust port
//! flattens this to a `HashMap<String, ...>` — each optimizer owns its own
//! per-parameter state keyed by the parameter's *flat name* (the same flat
//! string keys [`crate::lm::load::Weights`] uses, e.g.
//! `"model.layers.0.self_attn.q_proj.weight"`). Reasons:
//!
//! - mlxrs's [`crate::lm::load::Weights`] is already a flat
//! `HashMap<String, Array>` (mirroring the safetensors / GGUF on-disk
//! format), and the training loop hands the optimizer gradients and
//! parameters as [`Weights`]-shaped flat maps. The flat shape is the
//! natural Rust idiom.
//! - The Python `tree_map` walks the per-parameter `state` dict in lock-step
//! with the parameter tree; a flat `HashMap` keyed by the same flat path
//! is the structural equivalent, just spelled differently.
//! - It matches the port's general preference for Rust-idiomatic,
//! ndarray-flavored ergonomics over verbatim Python/Swift mirroring.
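//!
//! To make the flat-state idea concrete, the following sketch shows a
//! momentum-SGD optimizer whose velocity buffers live in a `HashMap` keyed by
//! the same flat names as the parameters. It is deliberately simplified:
//! `FlatMap`, `FlatOptimizer`, and `MomentumSgd` are hypothetical names,
//! `Vec<f32>` stands in for an MLX `Array`, and the crate's actual
//! [`Optimizer`] trait may have a different signature.
//!
//! ```rust
//! use std::collections::HashMap;
//!
//! /// Hypothetical flat map: flat parameter name -> flat `f32` buffer.
//! type FlatMap = HashMap<String, Vec<f32>>;
//!
//! /// Hypothetical trait shape; the crate's real `Optimizer` trait may differ.
//! trait FlatOptimizer {
//!     fn update(&mut self, params: &mut FlatMap, grads: &FlatMap);
//! }
//!
//! /// Momentum SGD whose per-parameter state is keyed by the same flat names.
//! struct MomentumSgd {
//!     lr: f32,
//!     momentum: f32,
//!     velocity: HashMap<String, Vec<f32>>,
//! }
//!
//! impl FlatOptimizer for MomentumSgd {
//!     fn update(&mut self, params: &mut FlatMap, grads: &FlatMap) {
//!         let (lr, momentum) = (self.lr, self.momentum);
//!         for (name, g) in grads {
//!             // A gradient with no matching parameter is skipped.
//!             let Some(p) = params.get_mut(name) else { continue };
//!             // Create zeroed state the first time a name is seen.
//!             let v = self
//!                 .velocity
//!                 .entry(name.clone())
//!                 .or_insert_with(|| vec![0.0; g.len()]);
//!             for ((pi, vi), gi) in p.iter_mut().zip(v.iter_mut()).zip(g) {
//!                 *vi = momentum * *vi + *gi;
//!                 *pi -= lr * *vi;
//!             }
//!         }
//!     }
//! }
//! ```
//!
//! Lazily inserting state on first sight of a name plays the role of
//! Python's `init_single`, and no tree walk is needed because parameters,
//! gradients, and state all share one flat key space.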
//!
//! ## Scope cuts
//!
//! - **Distributed training** (`mx.distributed.AllReduce` / `Group.barrier`)
//! is out of scope for v1, which supports single-process training only.
//! It can be added later on top of the `mlxrs_sys::mlx_distributed_*`
//! symbols, which are already bound but not yet wrapped in a safe API.
//! - **MultiOptimizer** ships the trait plus a minimal predicate-routing
//! impl. The Python version's machinery for nested parameter trees
//! (`tree_merge`, and `_split_dictionary`'s `tree_flatten`/`tree_unflatten`
//! round-trip) collapses to plain filtering of the flat map.
//! - **TensorBoard / W&B integrations** are out of scope; callers add their
//! own progress callback (see [`super::trainer::TrainingCallback`]).
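//!
//! A minimal sketch of that flat-map filtering, assuming MLX's routing rule:
//! the first filter that matches a parameter claims it, and anything
//! unmatched falls through to the last optimizer, so there is one more group
//! than there are filters. `Filter` and `split_by_filters` are hypothetical;
//! for brevity the predicates here look only at the flat name, whereas MLX's
//! filters also receive the gradient.
//!
//! ```rust
//! use std::collections::HashMap;
//!
//! /// Hypothetical name-only routing predicate.
//! type Filter = Box<dyn Fn(&str) -> bool>;
//!
//! /// Split a flat map into `filters.len() + 1` groups. The first matching
//! /// filter claims a name; names no filter matches go to the last group.
//! fn split_by_filters<T: Clone>(
//!     grads: &HashMap<String, T>,
//!     filters: &[Filter],
//! ) -> Vec<HashMap<String, T>> {
//!     let mut parts: Vec<HashMap<String, T>> =
//!         (0..=filters.len()).map(|_| HashMap::new()).collect();
//!     for (name, g) in grads {
//!         let idx = filters
//!             .iter()
//!             .position(|f| f(name.as_str()))
//!             .unwrap_or(filters.len()); // default group
//!         parts[idx].insert(name.clone(), g.clone());
//!     }
//!     parts
//! }
//! ```
//!
//! Each group would then be handed to its own optimizer instance; because
//! the maps are already flat, no flatten/unflatten round trip is needed.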
//!
//! [`mlx/python/mlx/optimizers/optimizers.py`]: https://github.com/ml-explore/mlx/blob/main/python/mlx/optimizers/optimizers.py
//! [`mlx-swift/Source/MLXOptimizers/Optimizers.swift`]: https://github.com/ml-explore/mlx-swift/blob/main/Source/MLXOptimizers/Optimizers.swift
//! [`Weights`]: crate::lm::load::Weights
pub use AdaDelta;
pub use Adafactor;
pub use Adagrad;
pub use ;
pub use ;
pub use clip_grad_norm;
pub use Lion;
pub use MultiOptimizer;
pub use Muon;
pub use RMSprop;
pub use ;
pub use SGD;