1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
/// Second-order optimization methods
///
/// This module has been refactored into separate submodules for better organization
/// and maintainability. Each optimizer now has its own dedicated module.
pub mod lbfgs;
// Import from original file for remaining optimizers that need to be extracted
use anyhow::Result;
use std::collections::HashMap;
use trustformers_core::tensor::Tensor;
// Re-export the LBFGS components
pub use lbfgs::{LBFGS, LineSearchMethod};
// The following structs would be moved to their own modules in a complete refactor:
/// Kronecker-Factored Approximate Curvature (K-FAC) optimizer.
///
/// Approximates the Fisher Information Matrix with Kronecker-factored
/// covariance blocks (A^T A and G G^T), which keeps second-order updates
/// tractable for neural-network-sized parameter sets.
///
/// Note: This implementation would be moved to its own module (kfac.rs) in a complete refactor.
#[derive(Debug)]
pub struct KFAC {
    pub learning_rate: f32,
    pub momentum: f32,
    pub damping: f32,
    pub weight_decay: f32,
    pub update_freq: usize,
    pub eps: f32,
    // Internal state
    pub step: usize,
    pub momentum_buffer: HashMap<String, Vec<f32>>,
    pub cov_ata: HashMap<String, Vec<Vec<f32>>>, // A^T A covariance matrices
    pub cov_ggt: HashMap<String, Vec<Vec<f32>>>, // G G^T covariance matrices
    pub inv_cov_ata: HashMap<String, Vec<Vec<f32>>>, // Inverse of A^T A
    pub inv_cov_ggt: HashMap<String, Vec<Vec<f32>>>, // Inverse of G G^T
}
impl Default for KFAC {
    /// Conventional K-FAC hyperparameters: lr 1e-3, momentum 0.9,
    /// damping 1e-3, no weight decay, covariance refresh every 10 steps,
    /// and empty state maps (step counter at 0).
    fn default() -> Self {
        KFAC {
            learning_rate: 1e-3,
            momentum: 0.9,
            damping: 1e-3,
            weight_decay: 0.0,
            update_freq: 10,
            eps: 1e-10,
            step: 0,
            momentum_buffer: HashMap::default(),
            cov_ata: HashMap::default(),
            cov_ggt: HashMap::default(),
            inv_cov_ata: HashMap::default(),
            inv_cov_ggt: HashMap::default(),
        }
    }
}
impl KFAC {
    /// Builds a K-FAC optimizer with the given learning rate; every other
    /// hyperparameter keeps its [`Default`] value.
    pub fn new(learning_rate: f32) -> Self {
        let mut optimizer = Self::default();
        optimizer.learning_rate = learning_rate;
        optimizer
    }
}
/// Shampoo optimizer with adaptive preconditioning.
///
/// Maintains per-parameter preconditioner matrices (`h_matrices`) that are
/// refreshed every `update_freq` steps.
///
/// Note: This implementation would be moved to its own module (shampoo.rs) in a complete refactor.
#[derive(Debug)]
pub struct Shampoo {
    pub learning_rate: f32,
    pub eps: f32,
    pub weight_decay: f32,
    pub momentum: f32,
    pub update_freq: usize,
    // Internal state
    pub step: usize,
    pub momentum_buffer: HashMap<String, Vec<f32>>,
    pub h_matrices: HashMap<String, Vec<Vec<f32>>>,
}
impl Default for Shampoo {
    /// Conventional Shampoo hyperparameters: lr 1e-3, eps 1e-4, no weight
    /// decay, no momentum, preconditioner refresh every 10 steps, and empty
    /// state maps (step counter at 0).
    fn default() -> Self {
        Shampoo {
            learning_rate: 1e-3,
            eps: 1e-4,
            weight_decay: 0.0,
            momentum: 0.0,
            update_freq: 10,
            step: 0,
            momentum_buffer: HashMap::default(),
            h_matrices: HashMap::default(),
        }
    }
}
impl Shampoo {
    /// Builds a Shampoo optimizer with the given learning rate; every other
    /// hyperparameter keeps its [`Default`] value.
    pub fn new(learning_rate: f32) -> Self {
        let mut optimizer = Self::default();
        optimizer.learning_rate = learning_rate;
        optimizer
    }
}
/// Natural Gradient Descent optimizer.
///
/// Keeps per-parameter Fisher-information estimates (`fisher_info`) and
/// their inverses (`inv_fisher`), refreshed every `update_freq` steps.
///
/// Note: This implementation would be moved to its own module (natural_gradient.rs) in a complete refactor.
#[derive(Debug)]
pub struct NaturalGradient {
    pub learning_rate: f32,
    pub damping: f32,
    pub update_freq: usize,
    // Internal state
    pub step: usize,
    pub fisher_info: HashMap<String, Vec<Vec<f32>>>,
    pub inv_fisher: HashMap<String, Vec<Vec<f32>>>,
}
impl Default for NaturalGradient {
    /// Conventional hyperparameters: lr 1e-3, damping 1e-3, Fisher refresh
    /// every 10 steps, and empty state maps (step counter at 0).
    fn default() -> Self {
        NaturalGradient {
            learning_rate: 1e-3,
            damping: 1e-3,
            update_freq: 10,
            step: 0,
            fisher_info: HashMap::default(),
            inv_fisher: HashMap::default(),
        }
    }
}
impl NaturalGradient {
    /// Builds a natural-gradient optimizer with the given learning rate;
    /// every other hyperparameter keeps its [`Default`] value.
    pub fn new(learning_rate: f32) -> Self {
        let mut optimizer = Self::default();
        optimizer.learning_rate = learning_rate;
        optimizer
    }
}