optirs_core/second_order/kfac/
utils.rs1use crate::error::Result;
8use scirs2_core::ndarray::{Array1, Array2};
9use scirs2_core::numeric::Float;
10use std::fmt::Debug;
11
/// Stateless collection of helper routines (associated functions only) used by the
/// K-FAC code in this module: factor updates, batch statistics and small matrix utilities.
#[derive(Debug, Clone, Copy, Default)]
pub struct KFACUtils;
14
15impl KFACUtils {
16 pub fn conv_kfac_update<T: Float + 'static>(
18 input_patches: &Array2<T>,
19 output_gradients: &Array2<T>,
20 kernel_size: (usize, usize),
21 stride: (usize, usize),
22 padding: (usize, usize),
23 ) -> Result<Array2<T>> {
24 let batch_size = input_patches.nrows();
27 let input_dim = input_patches.ncols();
28 let output_dim = output_gradients.ncols();
29
30 let mut update = Array2::zeros((kernel_size.0 * kernel_size.1, output_dim));
32
33 if batch_size > 0 {
35 let scale = T::one() / T::from(batch_size).unwrap_or_else(|| T::zero());
36 for i in 0..update.nrows() {
37 for j in 0..update.ncols() {
38 let input_idx = i % input_dim;
39 let output_idx = j % output_dim;
40
41 let mut sum = T::zero();
42 for b in 0..batch_size {
43 if input_idx < input_dim && output_idx < output_dim {
44 sum = sum
45 + input_patches[[b, input_idx]] * output_gradients[[b, output_idx]];
46 }
47 }
48 update[[i, j]] = sum * scale;
49 }
50 }
51 }
52
53 Ok(update)
54 }
55
56 pub fn batchnorm_statistics<T: Float + scirs2_core::numeric::FromPrimitive>(
58 input: &Array2<T>,
59 eps: T,
60 ) -> Result<(Array1<T>, Array1<T>)> {
61 let batch_size = input.nrows();
62 let num_features = input.ncols();
63
64 if batch_size == 0 {
65 return Ok((Array1::zeros(num_features), Array1::ones(num_features)));
66 }
67
68 let batch_size_t = T::from(batch_size).unwrap_or_else(|| T::zero());
69
70 let mean = input.mean_axis(scirs2_core::ndarray::Axis(0)).unwrap();
72
73 let mut var = Array1::zeros(num_features);
75 for i in 0..num_features {
76 let mut sum_sq_diff = T::zero();
77 for j in 0..batch_size {
78 let diff = input[[j, i]] - mean[i];
79 sum_sq_diff = sum_sq_diff + diff * diff;
80 }
81 var[i] = sum_sq_diff / batch_size_t + eps;
82 }
83
84 Ok((mean, var))
85 }
86
87 pub fn grouped_conv_kfac<T: Float + scirs2_core::ndarray::ScalarOperand>(
89 input: &Array2<T>,
90 gradients: &Array2<T>,
91 num_groups: usize,
92 ) -> Result<Array2<T>> {
93 let batch_size = input.nrows();
94 let input_channels = input.ncols();
95 let output_channels = gradients.ncols();
96
97 if num_groups == 0 {
98 return Err(crate::error::OptimError::InvalidParameter(
99 "Number of groups must be positive".to_string(),
100 ));
101 }
102
103 let input_per_group = input_channels / num_groups;
104 let output_per_group = output_channels / num_groups;
105
106 let mut result = Array2::zeros((input_channels, output_channels));
107
108 for group in 0..num_groups {
110 let input_start = group * input_per_group;
111 let input_end = input_start + input_per_group;
112 let output_start = group * output_per_group;
113 let output_end = output_start + output_per_group;
114
115 let group_input = input.slice(scirs2_core::ndarray::s![.., input_start..input_end]);
117 let group_gradients =
118 gradients.slice(scirs2_core::ndarray::s![.., output_start..output_end]);
119
120 let group_update = group_input.t().dot(&group_gradients);
122
123 result
125 .slice_mut(scirs2_core::ndarray::s![
126 input_start..input_end,
127 output_start..output_end
128 ])
129 .assign(&group_update);
130 }
131
132 if batch_size > 0 {
134 let scale = T::one() / T::from(batch_size).unwrap_or_else(|| T::zero());
135 result = result * scale;
136 }
137
138 Ok(result)
139 }
140
141 pub fn eigenvalue_regularization<T: Float + Debug + Send + Sync + 'static>(
143 matrix: &Array2<T>,
144 min_eigenvalue: T,
145 ) -> Array2<T> {
146 let n = matrix.nrows();
147 let mut regularized = matrix.clone();
148
149 for i in 0..n {
151 if regularized[[i, i]] < min_eigenvalue {
152 regularized[[i, i]] = min_eigenvalue;
153 }
154 }
155
156 regularized
157 }
158
159 pub fn kronecker_product_approx<T: Float + Debug + Send + Sync + 'static>(
161 a: &Array2<T>,
162 b: &Array2<T>,
163 ) -> Array2<T> {
164 let (a_rows, a_cols) = a.dim();
165 let (b_rows, b_cols) = b.dim();
166
167 let mut result = Array2::zeros((a_rows * b_rows, a_cols * b_cols));
168
169 for i in 0..a_rows {
170 for j in 0..a_cols {
171 let a_val = a[[i, j]];
172 for k in 0..b_rows {
173 for l in 0..b_cols {
174 result[[i * b_rows + k, j * b_cols + l]] = a_val * b[[k, l]];
175 }
176 }
177 }
178 }
179
180 result
181 }
182
183 pub fn trace<T: Float + Debug + Send + Sync + 'static>(matrix: &Array2<T>) -> T {
185 let n = matrix.nrows().min(matrix.ncols());
186 let mut trace = T::zero();
187
188 for i in 0..n {
189 trace = trace + matrix[[i, i]];
190 }
191
192 trace
193 }
194
195 pub fn frobenius_norm<T: Float + std::iter::Sum>(matrix: &Array2<T>) -> T {
197 matrix.iter().map(|&x| x * x).sum::<T>().sqrt()
198 }
199
200 pub fn matrices_approx_equal<T: Float + Debug + Send + Sync + 'static>(
202 a: &Array2<T>,
203 b: &Array2<T>,
204 tolerance: T,
205 ) -> bool {
206 if a.dim() != b.dim() {
207 return false;
208 }
209
210 for (a_val, b_val) in a.iter().zip(b.iter()) {
211 if (*a_val - *b_val).abs() > tolerance {
212 return false;
213 }
214 }
215
216 true
217 }
218
219 pub fn exponential_moving_average<T: Float + Debug + Send + Sync + 'static>(
221 current_value: T,
222 new_value: T,
223 decay: T,
224 ) -> T {
225 decay * current_value + (T::one() - decay) * new_value
226 }
227
228 pub fn clamp_eigenvalues<T: Float + Debug + Send + Sync + 'static>(
230 eigenvalues: &mut Array1<T>,
231 min_val: T,
232 max_val: T,
233 ) {
234 for eigenval in eigenvalues.iter_mut() {
235 *eigenval = (*eigenval).max(min_val).min(max_val);
236 }
237 }
238
239 pub fn condition_number_svd_approx<T: Float + Debug + Send + Sync + 'static>(
241 matrix: &Array2<T>,
242 ) -> T {
243 let diag = matrix.diag();
245 let max_diag = diag
246 .iter()
247 .fold(T::neg_infinity(), |acc, &x| acc.max(x.abs()));
248 let min_diag = diag.iter().fold(T::infinity(), |acc, &x| acc.min(x.abs()));
249
250 if min_diag > T::zero() {
251 max_diag / min_diag
252 } else {
253 T::infinity()
254 }
255 }
256
257 pub fn diag_matrix<T: Float + Clone>(diagonal: &Array1<T>) -> Array2<T> {
259 let n = diagonal.len();
260 let mut matrix = Array2::zeros((n, n));
261
262 for i in 0..n {
263 matrix[[i, i]] = diagonal[i];
264 }
265
266 matrix
267 }
268
269 pub fn symmetrize<T: Float + Debug + Send + Sync + 'static>(matrix: &Array2<T>) -> Array2<T> {
271 let n = matrix.nrows();
272 let mut result = Array2::zeros((n, n));
273
274 for i in 0..n {
275 for j in 0..n {
276 result[[i, j]] =
277 (matrix[[i, j]] + matrix[[j, i]]) / T::from(2.0).unwrap_or_else(|| T::zero());
278 }
279 }
280
281 result
282 }
283}
284
285#[derive(Debug, Clone, Copy)]
287pub struct OrderedFloat<T: Float + Debug + Send + Sync + 'static>(pub T);
288
289impl<T: Float + Debug + Send + Sync + 'static> PartialEq for OrderedFloat<T> {
290 fn eq(&self, other: &Self) -> bool {
291 self.0 == other.0 || (self.0.is_nan() && other.0.is_nan())
292 }
293}
294
295impl<T: Float + Debug + Send + Sync + 'static> Eq for OrderedFloat<T> {}
296
297impl<T: Float + Debug + Send + Sync + 'static> Ord for OrderedFloat<T> {
298 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
299 self.0
300 .partial_cmp(&other.0)
301 .unwrap_or(std::cmp::Ordering::Equal)
302 }
303}
304
305impl<T: Float + Debug + Send + Sync + 'static> PartialOrd for OrderedFloat<T> {
306 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
307 Some(self.cmp(other))
308 }
309}
310
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_trace_computation() {
        let matrix =
            Array2::from_shape_vec((3, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0])
                .unwrap();
        let trace = KFACUtils::trace(&matrix);
        // 1 + 5 + 9
        assert!((trace - 15.0).abs() < 1e-10);
    }

    #[test]
    fn test_frobenius_norm() {
        let matrix = Array2::from_shape_vec((2, 2), vec![3.0, 4.0, 0.0, 0.0]).unwrap();
        let norm = KFACUtils::frobenius_norm(&matrix);
        // sqrt(3^2 + 4^2)
        assert!((norm - 5.0).abs() < 1e-10);
    }

    #[test]
    fn test_exponential_moving_average() {
        let current = 10.0;
        let new_val = 20.0;
        let decay = 0.9;

        let result = KFACUtils::exponential_moving_average(current, new_val, decay);
        let expected = 0.9 * 10.0 + 0.1 * 20.0;
        assert!((result - expected).abs() < 1e-10);
    }

    #[test]
    fn test_matrices_approx_equal() {
        let a = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let b = Array2::from_shape_vec((2, 2), vec![1.001, 2.001, 3.001, 4.001]).unwrap();

        assert!(KFACUtils::matrices_approx_equal(&a, &b, 0.01));
        assert!(!KFACUtils::matrices_approx_equal(&a, &b, 0.0001));
    }

    #[test]
    fn test_symmetrize() {
        let matrix = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let symmetric = KFACUtils::symmetrize(&matrix);

        assert!((symmetric[[0, 0]] - 1.0).abs() < 1e-10);
        // (2 + 3) / 2 on both off-diagonal entries
        assert!((symmetric[[0, 1]] - 2.5).abs() < 1e-10);
        assert!((symmetric[[1, 0]] - 2.5).abs() < 1e-10);
        assert!((symmetric[[1, 1]] - 4.0).abs() < 1e-10);
    }

    #[test]
    fn test_diag_matrix() {
        let diagonal = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let matrix = KFACUtils::diag_matrix(&diagonal);

        assert_eq!(matrix.dim(), (3, 3));
        assert!((matrix[[0, 0]] - 1.0).abs() < 1e-10);
        assert!((matrix[[1, 1]] - 2.0).abs() < 1e-10);
        assert!((matrix[[2, 2]] - 3.0).abs() < 1e-10);
        // Off-diagonal entries are zero.
        assert!((matrix[[0, 1]]).abs() < 1e-10);
    }

    #[test]
    fn test_ordered_float() {
        let a = OrderedFloat(1.5);
        let b = OrderedFloat(2.5);
        let c = OrderedFloat(1.5);

        assert!(a < b);
        assert!(a == c);
        assert!(b > a);
    }

    #[test]
    fn test_batchnorm_statistics() {
        let input =
            Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();

        let (mean, var) = KFACUtils::batchnorm_statistics(&input, 1e-8).unwrap();

        // Columns are {1, 3, 5, 7} and {2, 4, 6, 8}.
        assert!((mean[0] - 4.0).abs() < 1e-6);
        assert!((mean[1] - 5.0).abs() < 1e-6);

        // Population variance of each column is (9 + 1 + 1 + 9) / 4 = 5, plus eps.
        assert!((var[0] - 5.0).abs() < 1e-6);
        assert!((var[1] - 5.0).abs() < 1e-6);
    }

    #[test]
    fn test_kronecker_product() {
        let a = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let b = Array2::from_shape_vec((2, 2), vec![0.0, 1.0, 1.0, 0.0]).unwrap();
        let product = KFACUtils::kronecker_product_approx(&a, &b);

        let expected = Array2::from_shape_vec(
            (4, 4),
            vec![
                0.0, 1.0, 0.0, 2.0, //
                1.0, 0.0, 2.0, 0.0, //
                0.0, 3.0, 0.0, 4.0, //
                3.0, 0.0, 4.0, 0.0,
            ],
        )
        .unwrap();
        assert!(KFACUtils::matrices_approx_equal(&product, &expected, 1e-12));
    }

    #[test]
    fn test_clamp_eigenvalues() {
        let mut eigenvalues = Array1::from_vec(vec![-1.0, 0.5, 10.0]);
        KFACUtils::clamp_eigenvalues(&mut eigenvalues, 0.0, 2.0);
        assert_eq!(eigenvalues.to_vec(), vec![0.0, 0.5, 2.0]);
    }

    #[test]
    fn test_grouped_conv_kfac() {
        let input = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let gradients = Array2::from_shape_vec((2, 2), vec![1.0, 0.0, 0.0, 1.0]).unwrap();

        let result = KFACUtils::grouped_conv_kfac(&input, &gradients, 2).unwrap();
        // Group 0: [1, 3]·[1, 0] = 1; group 1: [2, 4]·[0, 1] = 4; both averaged over 2 samples.
        assert!((result[[0, 0]] - 0.5).abs() < 1e-12);
        assert!(result[[0, 1]].abs() < 1e-12);
        assert!(result[[1, 0]].abs() < 1e-12);
        assert!((result[[1, 1]] - 2.0).abs() < 1e-12);

        assert!(KFACUtils::grouped_conv_kfac(&input, &gradients, 0).is_err());
    }
}