optirs_core/second_order/kfac/
utils.rs1use crate::error::Result;
8use scirs2_core::ndarray::{Array1, Array2};
9use scirs2_core::numeric::Float;
10use std::fmt::Debug;
11
/// Namespace struct bundling stateless helper routines for K-FAC
/// (Kronecker-Factored Approximate Curvature) second-order optimization.
pub struct KFACUtils;
14
15impl KFACUtils {
16 pub fn conv_kfac_update<T: Float + 'static>(
18 input_patches: &Array2<T>,
19 output_gradients: &Array2<T>,
20 kernel_size: (usize, usize),
21 stride: (usize, usize),
22 padding: (usize, usize),
23 ) -> Result<Array2<T>> {
24 let batch_size = input_patches.nrows();
27 let input_dim = input_patches.ncols();
28 let output_dim = output_gradients.ncols();
29
30 let mut update = Array2::zeros((kernel_size.0 * kernel_size.1, output_dim));
32
33 if batch_size > 0 {
35 let scale = T::one() / T::from(batch_size).unwrap_or_else(|| T::zero());
36 for i in 0..update.nrows() {
37 for j in 0..update.ncols() {
38 let input_idx = i % input_dim;
39 let output_idx = j % output_dim;
40
41 let mut sum = T::zero();
42 for b in 0..batch_size {
43 if input_idx < input_dim && output_idx < output_dim {
44 sum = sum
45 + input_patches[[b, input_idx]] * output_gradients[[b, output_idx]];
46 }
47 }
48 update[[i, j]] = sum * scale;
49 }
50 }
51 }
52
53 Ok(update)
54 }
55
56 pub fn batchnorm_statistics<T: Float + scirs2_core::numeric::FromPrimitive>(
58 input: &Array2<T>,
59 eps: T,
60 ) -> Result<(Array1<T>, Array1<T>)> {
61 let batch_size = input.nrows();
62 let num_features = input.ncols();
63
64 if batch_size == 0 {
65 return Ok((Array1::zeros(num_features), Array1::ones(num_features)));
66 }
67
68 let batch_size_t = T::from(batch_size).unwrap_or_else(|| T::zero());
69
70 let mean = input
72 .mean_axis(scirs2_core::ndarray::Axis(0))
73 .expect("unwrap failed");
74
75 let mut var = Array1::zeros(num_features);
77 for i in 0..num_features {
78 let mut sum_sq_diff = T::zero();
79 for j in 0..batch_size {
80 let diff = input[[j, i]] - mean[i];
81 sum_sq_diff = sum_sq_diff + diff * diff;
82 }
83 var[i] = sum_sq_diff / batch_size_t + eps;
84 }
85
86 Ok((mean, var))
87 }
88
89 pub fn grouped_conv_kfac<T: Float + scirs2_core::ndarray::ScalarOperand>(
91 input: &Array2<T>,
92 gradients: &Array2<T>,
93 num_groups: usize,
94 ) -> Result<Array2<T>> {
95 let batch_size = input.nrows();
96 let input_channels = input.ncols();
97 let output_channels = gradients.ncols();
98
99 if num_groups == 0 {
100 return Err(crate::error::OptimError::InvalidParameter(
101 "Number of groups must be positive".to_string(),
102 ));
103 }
104
105 let input_per_group = input_channels / num_groups;
106 let output_per_group = output_channels / num_groups;
107
108 let mut result = Array2::zeros((input_channels, output_channels));
109
110 for group in 0..num_groups {
112 let input_start = group * input_per_group;
113 let input_end = input_start + input_per_group;
114 let output_start = group * output_per_group;
115 let output_end = output_start + output_per_group;
116
117 let group_input = input.slice(scirs2_core::ndarray::s![.., input_start..input_end]);
119 let group_gradients =
120 gradients.slice(scirs2_core::ndarray::s![.., output_start..output_end]);
121
122 let group_update = group_input.t().dot(&group_gradients);
124
125 result
127 .slice_mut(scirs2_core::ndarray::s![
128 input_start..input_end,
129 output_start..output_end
130 ])
131 .assign(&group_update);
132 }
133
134 if batch_size > 0 {
136 let scale = T::one() / T::from(batch_size).unwrap_or_else(|| T::zero());
137 result = result * scale;
138 }
139
140 Ok(result)
141 }
142
143 pub fn eigenvalue_regularization<T: Float + Debug + Send + Sync + 'static>(
145 matrix: &Array2<T>,
146 min_eigenvalue: T,
147 ) -> Array2<T> {
148 let n = matrix.nrows();
149 let mut regularized = matrix.clone();
150
151 for i in 0..n {
153 if regularized[[i, i]] < min_eigenvalue {
154 regularized[[i, i]] = min_eigenvalue;
155 }
156 }
157
158 regularized
159 }
160
161 pub fn kronecker_product_approx<T: Float + Debug + Send + Sync + 'static>(
163 a: &Array2<T>,
164 b: &Array2<T>,
165 ) -> Array2<T> {
166 let (a_rows, a_cols) = a.dim();
167 let (b_rows, b_cols) = b.dim();
168
169 let mut result = Array2::zeros((a_rows * b_rows, a_cols * b_cols));
170
171 for i in 0..a_rows {
172 for j in 0..a_cols {
173 let a_val = a[[i, j]];
174 for k in 0..b_rows {
175 for l in 0..b_cols {
176 result[[i * b_rows + k, j * b_cols + l]] = a_val * b[[k, l]];
177 }
178 }
179 }
180 }
181
182 result
183 }
184
185 pub fn trace<T: Float + Debug + Send + Sync + 'static>(matrix: &Array2<T>) -> T {
187 let n = matrix.nrows().min(matrix.ncols());
188 let mut trace = T::zero();
189
190 for i in 0..n {
191 trace = trace + matrix[[i, i]];
192 }
193
194 trace
195 }
196
197 pub fn frobenius_norm<T: Float + std::iter::Sum>(matrix: &Array2<T>) -> T {
199 matrix.iter().map(|&x| x * x).sum::<T>().sqrt()
200 }
201
202 pub fn matrices_approx_equal<T: Float + Debug + Send + Sync + 'static>(
204 a: &Array2<T>,
205 b: &Array2<T>,
206 tolerance: T,
207 ) -> bool {
208 if a.dim() != b.dim() {
209 return false;
210 }
211
212 for (a_val, b_val) in a.iter().zip(b.iter()) {
213 if (*a_val - *b_val).abs() > tolerance {
214 return false;
215 }
216 }
217
218 true
219 }
220
221 pub fn exponential_moving_average<T: Float + Debug + Send + Sync + 'static>(
223 current_value: T,
224 new_value: T,
225 decay: T,
226 ) -> T {
227 decay * current_value + (T::one() - decay) * new_value
228 }
229
230 pub fn clamp_eigenvalues<T: Float + Debug + Send + Sync + 'static>(
232 eigenvalues: &mut Array1<T>,
233 min_val: T,
234 max_val: T,
235 ) {
236 for eigenval in eigenvalues.iter_mut() {
237 *eigenval = (*eigenval).max(min_val).min(max_val);
238 }
239 }
240
241 pub fn condition_number_svd_approx<T: Float + Debug + Send + Sync + 'static>(
243 matrix: &Array2<T>,
244 ) -> T {
245 let diag = matrix.diag();
247 let max_diag = diag
248 .iter()
249 .fold(T::neg_infinity(), |acc, &x| acc.max(x.abs()));
250 let min_diag = diag.iter().fold(T::infinity(), |acc, &x| acc.min(x.abs()));
251
252 if min_diag > T::zero() {
253 max_diag / min_diag
254 } else {
255 T::infinity()
256 }
257 }
258
259 pub fn diag_matrix<T: Float + Clone>(diagonal: &Array1<T>) -> Array2<T> {
261 let n = diagonal.len();
262 let mut matrix = Array2::zeros((n, n));
263
264 for i in 0..n {
265 matrix[[i, i]] = diagonal[i];
266 }
267
268 matrix
269 }
270
271 pub fn symmetrize<T: Float + Debug + Send + Sync + 'static>(matrix: &Array2<T>) -> Array2<T> {
273 let n = matrix.nrows();
274 let mut result = Array2::zeros((n, n));
275
276 for i in 0..n {
277 for j in 0..n {
278 result[[i, j]] =
279 (matrix[[i, j]] + matrix[[j, i]]) / T::from(2.0).unwrap_or_else(|| T::zero());
280 }
281 }
282
283 result
284 }
285}
286
/// Newtype wrapper giving floats a total order (`Eq`/`Ord` below) so they
/// can be used in sorted collections; NaN values compare equal to each other.
#[derive(Debug, Clone, Copy)]
pub struct OrderedFloat<T: Float + Debug + Send + Sync + 'static>(pub T);
290
291impl<T: Float + Debug + Send + Sync + 'static> PartialEq for OrderedFloat<T> {
292 fn eq(&self, other: &Self) -> bool {
293 self.0 == other.0 || (self.0.is_nan() && other.0.is_nan())
294 }
295}
296
// Sound because `eq` is reflexive even for NaN (NaN == NaN in this wrapper).
impl<T: Float + Debug + Send + Sync + 'static> Eq for OrderedFloat<T> {}
298
299impl<T: Float + Debug + Send + Sync + 'static> Ord for OrderedFloat<T> {
300 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
301 self.0
302 .partial_cmp(&other.0)
303 .unwrap_or(std::cmp::Ordering::Equal)
304 }
305}
306
impl<T: Float + Debug + Send + Sync + 'static> PartialOrd for OrderedFloat<T> {
    /// Delegates to the total order in `Ord`, as the `Ord`/`PartialOrd`
    /// consistency contract requires.
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}
312
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_trace_computation() {
        let matrix =
            Array2::from_shape_vec((3, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0])
                .expect("unwrap failed");
        // Diagonal is 1 + 5 + 9.
        assert!((KFACUtils::trace(&matrix) - 15.0).abs() < 1e-10);
    }

    #[test]
    fn test_frobenius_norm() {
        // 3-4-5 triangle: sqrt(9 + 16) = 5.
        let matrix =
            Array2::from_shape_vec((2, 2), vec![3.0, 4.0, 0.0, 0.0]).expect("unwrap failed");
        assert!((KFACUtils::frobenius_norm(&matrix) - 5.0).abs() < 1e-10);
    }

    #[test]
    fn test_exponential_moving_average() {
        let result = KFACUtils::exponential_moving_average(10.0, 20.0, 0.9);
        let expected = 0.9 * 10.0 + 0.1 * 20.0;
        assert!((result - expected).abs() < 1e-10);
    }

    #[test]
    fn test_matrices_approx_equal() {
        let a = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).expect("unwrap failed");
        let b = Array2::from_shape_vec((2, 2), vec![1.001, 2.001, 3.001, 4.001])
            .expect("unwrap failed");

        // Entries differ by 0.001: inside a 0.01 tolerance, outside 0.0001.
        assert!(KFACUtils::matrices_approx_equal(&a, &b, 0.01));
        assert!(!KFACUtils::matrices_approx_equal(&a, &b, 0.0001));
    }

    #[test]
    fn test_symmetrize() {
        let matrix =
            Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).expect("unwrap failed");
        let symmetric = KFACUtils::symmetrize(&matrix);

        // Off-diagonal entries average to (2 + 3) / 2 = 2.5.
        for (idx, expected) in [([0, 0], 1.0), ([0, 1], 2.5), ([1, 0], 2.5), ([1, 1], 4.0)] {
            assert!((symmetric[idx] - expected).abs() < 1e-10);
        }
    }

    #[test]
    fn test_diag_matrix() {
        let matrix = KFACUtils::diag_matrix(&Array1::from_vec(vec![1.0, 2.0, 3.0]));

        assert_eq!(matrix.dim(), (3, 3));
        for i in 0..3 {
            assert!((matrix[[i, i]] - (i as f64 + 1.0)).abs() < 1e-10);
        }
        // Off-diagonal entries stay zero.
        assert!((matrix[[0, 1]]).abs() < 1e-10);
    }

    #[test]
    fn test_ordered_float() {
        let a = OrderedFloat(1.5);
        let b = OrderedFloat(2.5);
        let c = OrderedFloat(1.5);

        assert!(a < b);
        assert!(a == c);
        assert!(b > a);
    }

    #[test]
    fn test_batchnorm_statistics() {
        let input = Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])
            .expect("unwrap failed");

        let (mean, var) = KFACUtils::batchnorm_statistics(&input, 1e-8).expect("unwrap failed");

        // Column means of [1, 3, 5, 7] and [2, 4, 6, 8].
        assert!((mean[0] - 4.0).abs() < 1e-6);
        assert!((mean[1] - 5.0).abs() < 1e-6);

        assert!(var[0] > 0.0);
        assert!(var[1] > 0.0);
    }
}