optirs_core/second_order/kfac/

//! Utility functions for K-FAC optimization
//!
//! This module contains helper functions and utilities used throughout
//! the K-FAC implementation, including layer-specific operations and
//! mathematical utilities.

use crate::error::Result;
use scirs2_core::ndarray::{Array1, Array2};
use scirs2_core::numeric::Float;
use std::fmt::Debug;

/// K-FAC utilities for layer-specific operations
pub struct KFACUtils;

impl KFACUtils {
    /// Compute K-FAC update for convolutional layers
    pub fn conv_kfac_update<T: Float + 'static>(
        input_patches: &Array2<T>,
        output_gradients: &Array2<T>,
        kernel_size: (usize, usize),
        _stride: (usize, usize),
        _padding: (usize, usize),
    ) -> Result<Array2<T>> {
        // Simplified convolution K-FAC update: `stride` and `padding` are
        // accepted for API completeness but unused here. In practice this
        // would involve proper patch extraction and reshaping.
        let batch_size = input_patches.nrows();
        let input_dim = input_patches.ncols();
        let output_dim = output_gradients.ncols();

        // Placeholder update matrix with one row per kernel element
        let mut update = Array2::zeros((kernel_size.0 * kernel_size.1, output_dim));

        // Average the input/gradient products across the batch; the modulo
        // indexing keeps both indices in bounds by construction.
        if batch_size > 0 && input_dim > 0 && output_dim > 0 {
            let scale = T::one() / T::from(batch_size).unwrap_or_else(T::one);
            for i in 0..update.nrows() {
                for j in 0..update.ncols() {
                    let input_idx = i % input_dim;
                    let output_idx = j % output_dim;

                    let mut sum = T::zero();
                    for b in 0..batch_size {
                        sum = sum
                            + input_patches[[b, input_idx]] * output_gradients[[b, output_idx]];
                    }
                    update[[i, j]] = sum * scale;
                }
            }
        }

        Ok(update)
    }

    /// Compute batch normalization statistics (mean and eps-regularized
    /// variance) for K-FAC
    pub fn batchnorm_statistics<T: Float + scirs2_core::numeric::FromPrimitive>(
        input: &Array2<T>,
        eps: T,
    ) -> Result<(Array1<T>, Array1<T>)> {
        let batch_size = input.nrows();
        let num_features = input.ncols();

        if batch_size == 0 {
            return Ok((Array1::zeros(num_features), Array1::ones(num_features)));
        }

        let batch_size_t = T::from(batch_size).unwrap_or_else(T::one);

        // Compute per-feature mean (safe to unwrap: batch_size > 0)
        let mean = input.mean_axis(scirs2_core::ndarray::Axis(0)).unwrap();

        // Compute biased per-feature variance, regularized by eps
        let mut var = Array1::zeros(num_features);
        for i in 0..num_features {
            let mut sum_sq_diff = T::zero();
            for j in 0..batch_size {
                let diff = input[[j, i]] - mean[i];
                sum_sq_diff = sum_sq_diff + diff * diff;
            }
            var[i] = sum_sq_diff / batch_size_t + eps;
        }

        Ok((mean, var))
    }

    /// Compute K-FAC update for grouped convolution layers
    pub fn grouped_conv_kfac<T: Float + scirs2_core::ndarray::ScalarOperand>(
        input: &Array2<T>,
        gradients: &Array2<T>,
        num_groups: usize,
    ) -> Result<Array2<T>> {
        let batch_size = input.nrows();
        let input_channels = input.ncols();
        let output_channels = gradients.ncols();

        if num_groups == 0 {
            return Err(crate::error::OptimError::InvalidParameter(
                "Number of groups must be positive".to_string(),
            ));
        }

        if input_channels % num_groups != 0 || output_channels % num_groups != 0 {
            return Err(crate::error::OptimError::InvalidParameter(
                "Channel counts must be divisible by the number of groups".to_string(),
            ));
        }

        let input_per_group = input_channels / num_groups;
        let output_per_group = output_channels / num_groups;

        let mut result = Array2::zeros((input_channels, output_channels));

        // Process each group separately
        for group in 0..num_groups {
            let input_start = group * input_per_group;
            let input_end = input_start + input_per_group;
            let output_start = group * output_per_group;
            let output_end = output_start + output_per_group;

            // Extract this group's input activations and output gradients
            let group_input = input.slice(scirs2_core::ndarray::s![.., input_start..input_end]);
            let group_gradients =
                gradients.slice(scirs2_core::ndarray::s![.., output_start..output_end]);

            // Compute the group's input/gradient cross-covariance
            let group_update = group_input.t().dot(&group_gradients);

            // Write the block back into the block-diagonal result
            result
                .slice_mut(scirs2_core::ndarray::s![
                    input_start..input_end,
                    output_start..output_end
                ])
                .assign(&group_update);
        }

        // Normalize by batch size
        if batch_size > 0 {
            let scale = T::one() / T::from(batch_size).unwrap_or_else(T::one);
            result = result * scale;
        }

        Ok(result)
    }

    /// Apply eigenvalue-based regularization by flooring diagonal entries
    pub fn eigenvalue_regularization<T: Float + Debug + Send + Sync + 'static>(
        matrix: &Array2<T>,
        min_eigenvalue: T,
    ) -> Array2<T> {
        let n = matrix.nrows().min(matrix.ncols());
        let mut regularized = matrix.clone();

        // Simple diagonal regularization (a full implementation would clamp
        // the eigenvalues of a proper eigendecomposition)
        for i in 0..n {
            if regularized[[i, i]] < min_eigenvalue {
                regularized[[i, i]] = min_eigenvalue;
            }
        }

        regularized
    }

    /// Compute Kronecker product approximation for two matrices
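    ///
    /// In K-FAC, a layer's Fisher block is approximated by the Kronecker
    /// product `A ⊗ G` of its input-activation and output-gradient
    /// covariance factors; this helper materializes that product densely,
    /// so it is only practical for small factor matrices.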
    pub fn kronecker_product_approx<T: Float + Debug + Send + Sync + 'static>(
        a: &Array2<T>,
        b: &Array2<T>,
    ) -> Array2<T> {
        let (a_rows, a_cols) = a.dim();
        let (b_rows, b_cols) = b.dim();

        let mut result = Array2::zeros((a_rows * b_rows, a_cols * b_cols));

        for i in 0..a_rows {
            for j in 0..a_cols {
                let a_val = a[[i, j]];
                for k in 0..b_rows {
                    for l in 0..b_cols {
                        result[[i * b_rows + k, j * b_cols + l]] = a_val * b[[k, l]];
                    }
                }
            }
        }

        result
    }

    /// Compute trace of a matrix
    pub fn trace<T: Float + Debug + Send + Sync + 'static>(matrix: &Array2<T>) -> T {
        let n = matrix.nrows().min(matrix.ncols());
        let mut trace = T::zero();

        for i in 0..n {
            trace = trace + matrix[[i, i]];
        }

        trace
    }

    /// Compute Frobenius norm of a matrix
    pub fn frobenius_norm<T: Float + std::iter::Sum>(matrix: &Array2<T>) -> T {
        matrix.iter().map(|&x| x * x).sum::<T>().sqrt()
    }

    /// Check if two matrices are approximately equal
    pub fn matrices_approx_equal<T: Float + Debug + Send + Sync + 'static>(
        a: &Array2<T>,
        b: &Array2<T>,
        tolerance: T,
    ) -> bool {
        if a.dim() != b.dim() {
            return false;
        }

        for (a_val, b_val) in a.iter().zip(b.iter()) {
            if (*a_val - *b_val).abs() > tolerance {
                return false;
            }
        }

        true
    }

    /// Compute running average with exponential decay
    pub fn exponential_moving_average<T: Float + Debug + Send + Sync + 'static>(
        current_value: T,
        new_value: T,
        decay: T,
    ) -> T {
        decay * current_value + (T::one() - decay) * new_value
    }

    /// Clamp eigenvalues to prevent numerical instability
    pub fn clamp_eigenvalues<T: Float + Debug + Send + Sync + 'static>(
        eigenvalues: &mut Array1<T>,
        min_val: T,
        max_val: T,
    ) {
        for eigenval in eigenvalues.iter_mut() {
            *eigenval = (*eigenval).max(min_val).min(max_val);
        }
    }

    /// Approximate the condition number of a matrix
    ///
    /// A full implementation would use singular values; this uses the ratio
    /// of the largest to smallest absolute diagonal entry as a cheap proxy.
    pub fn condition_number_svd_approx<T: Float + Debug + Send + Sync + 'static>(
        matrix: &Array2<T>,
    ) -> T {
        let diag = matrix.diag();
        let max_diag = diag
            .iter()
            .fold(T::neg_infinity(), |acc, &x| acc.max(x.abs()));
        let min_diag = diag.iter().fold(T::infinity(), |acc, &x| acc.min(x.abs()));

        if min_diag > T::zero() {
            max_diag / min_diag
        } else {
            T::infinity()
        }
    }

    /// Create a diagonal matrix from a vector of diagonal elements
    pub fn diag_matrix<T: Float + Clone>(diagonal: &Array1<T>) -> Array2<T> {
        let n = diagonal.len();
        let mut matrix = Array2::zeros((n, n));

        for i in 0..n {
            matrix[[i, i]] = diagonal[i];
        }

        matrix
    }

    /// Symmetrize a matrix: (A + A^T) / 2
    pub fn symmetrize<T: Float + Debug + Send + Sync + 'static>(matrix: &Array2<T>) -> Array2<T> {
        let n = matrix.nrows();
        let two = T::one() + T::one();
        let mut result = Array2::zeros((n, n));

        for i in 0..n {
            for j in 0..n {
                result[[i, j]] = (matrix[[i, j]] + matrix[[j, i]]) / two;
            }
        }

        result
    }
}

/// Ordered float wrapper for comparison operations
///
/// NaNs compare equal to each other and sort after all other values,
/// giving a total order suitable for sorting and max/min selection.
#[derive(Debug, Clone, Copy)]
pub struct OrderedFloat<T: Float + Debug + Send + Sync + 'static>(pub T);

impl<T: Float + Debug + Send + Sync + 'static> PartialEq for OrderedFloat<T> {
    fn eq(&self, other: &Self) -> bool {
        self.0 == other.0 || (self.0.is_nan() && other.0.is_nan())
    }
}

impl<T: Float + Debug + Send + Sync + 'static> Eq for OrderedFloat<T> {}

impl<T: Float + Debug + Send + Sync + 'static> Ord for OrderedFloat<T> {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        // Order NaNs consistently with `eq`: a NaN equals another NaN and
        // is greater than any non-NaN value, so `cmp` agrees with `Eq`.
        match (self.0.is_nan(), other.0.is_nan()) {
            (true, true) => std::cmp::Ordering::Equal,
            (true, false) => std::cmp::Ordering::Greater,
            (false, true) => std::cmp::Ordering::Less,
            (false, false) => self
                .0
                .partial_cmp(&other.0)
                .unwrap_or(std::cmp::Ordering::Equal),
        }
    }
}

impl<T: Float + Debug + Send + Sync + 'static> PartialOrd for OrderedFloat<T> {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_trace_computation() {
        let matrix =
            Array2::from_shape_vec((3, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0])
                .unwrap();
        let trace = KFACUtils::trace(&matrix);
        assert!((trace - 15.0).abs() < 1e-10); // 1 + 5 + 9 = 15
    }
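
    // Sanity check for the simplified convolutional update: with a 1x1
    // kernel the single output row is the batch-averaged product of the
    // first input column with each gradient column.
    #[test]
    fn test_conv_kfac_update() {
        let input = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let grads = Array2::from_shape_vec((2, 2), vec![1.0, 0.0, 0.0, 1.0]).unwrap();

        let update =
            KFACUtils::conv_kfac_update(&input, &grads, (1, 1), (1, 1), (0, 0)).unwrap();

        assert_eq!(update.dim(), (1, 2));
        assert!((update[[0, 0]] - 0.5).abs() < 1e-10); // (1*1 + 3*0) / 2
        assert!((update[[0, 1]] - 1.5).abs() < 1e-10); // (1*0 + 3*1) / 2
    }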

    #[test]
    fn test_frobenius_norm() {
        let matrix = Array2::from_shape_vec((2, 2), vec![3.0, 4.0, 0.0, 0.0]).unwrap();
        let norm = KFACUtils::frobenius_norm(&matrix);
        assert!((norm - 5.0).abs() < 1e-10); // sqrt(9 + 16) = 5
    }
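
    // Grouped update: with two groups of one channel each, only the
    // block-diagonal entries are populated; cross-group entries stay zero.
    #[test]
    fn test_grouped_conv_kfac() {
        let input = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let grads = Array2::from_shape_vec((2, 2), vec![5.0, 6.0, 7.0, 8.0]).unwrap();

        let result = KFACUtils::grouped_conv_kfac(&input, &grads, 2).unwrap();

        assert!((result[[0, 0]] - 13.0).abs() < 1e-10); // (1*5 + 3*7) / 2
        assert!((result[[1, 1]] - 22.0).abs() < 1e-10); // (2*6 + 4*8) / 2
        assert!(result[[0, 1]].abs() < 1e-10); // cross-group entries untouched
        assert!(result[[1, 0]].abs() < 1e-10);

        // Zero groups is rejected
        assert!(KFACUtils::grouped_conv_kfac(&input, &grads, 0).is_err());
    }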

    #[test]
    fn test_exponential_moving_average() {
        let current = 10.0;
        let new_val = 20.0;
        let decay = 0.9;

        let result = KFACUtils::exponential_moving_average(current, new_val, decay);
        let expected = 0.9 * 10.0 + 0.1 * 20.0; // 9.0 + 2.0 = 11.0
        assert!((result - expected).abs() < 1e-10);
    }
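
    // Diagonal regularization: entries below the floor are raised to it,
    // everything else is left untouched.
    #[test]
    fn test_eigenvalue_regularization() {
        let matrix = Array2::from_shape_vec((2, 2), vec![1e-12, 0.5, 0.5, 2.0]).unwrap();
        let reg = KFACUtils::eigenvalue_regularization(&matrix, 1e-6);

        assert!((reg[[0, 0]] - 1e-6).abs() < 1e-12); // clamped up to the floor
        assert!((reg[[1, 1]] - 2.0).abs() < 1e-10); // already above the floor
        assert!((reg[[0, 1]] - 0.5).abs() < 1e-10); // off-diagonal untouched
    }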

    #[test]
    fn test_matrices_approx_equal() {
        let a = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let b = Array2::from_shape_vec((2, 2), vec![1.001, 2.001, 3.001, 4.001]).unwrap();

        assert!(KFACUtils::matrices_approx_equal(&a, &b, 0.01));
        assert!(!KFACUtils::matrices_approx_equal(&a, &b, 0.0001));
    }
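
    // The Kronecker product of two 2x2 matrices is 4x4, with each block
    // equal to a scaled copy of the second factor.
    #[test]
    fn test_kronecker_product_approx() {
        let a = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let b = Array2::from_shape_vec((2, 2), vec![0.0, 1.0, 1.0, 0.0]).unwrap();

        let kron = KFACUtils::kronecker_product_approx(&a, &b);

        assert_eq!(kron.dim(), (4, 4));
        assert!((kron[[0, 1]] - 1.0).abs() < 1e-10); // a[0,0] * b[0,1]
        assert!((kron[[0, 3]] - 2.0).abs() < 1e-10); // a[0,1] * b[0,1]
        assert!((kron[[2, 1]] - 3.0).abs() < 1e-10); // a[1,0] * b[0,1]
        assert!((kron[[2, 3]] - 4.0).abs() < 1e-10); // a[1,1] * b[0,1]
        assert!(kron[[0, 0]].abs() < 1e-10); // a[0,0] * b[0,0] = 0
    }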

    #[test]
    fn test_symmetrize() {
        let matrix = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let symmetric = KFACUtils::symmetrize(&matrix);

        assert!((symmetric[[0, 0]] - 1.0).abs() < 1e-10);
        assert!((symmetric[[0, 1]] - 2.5).abs() < 1e-10); // (2 + 3) / 2
        assert!((symmetric[[1, 0]] - 2.5).abs() < 1e-10); // (3 + 2) / 2
        assert!((symmetric[[1, 1]] - 4.0).abs() < 1e-10);
    }
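
    // Eigenvalue clamping: values are confined to [min, max] in place.
    #[test]
    fn test_clamp_eigenvalues() {
        let mut eigenvalues = Array1::from_vec(vec![1e-12, 0.5, 1e12]);
        KFACUtils::clamp_eigenvalues(&mut eigenvalues, 1e-6, 1e6);

        assert!((eigenvalues[0] - 1e-6).abs() < 1e-10); // raised to the floor
        assert!((eigenvalues[1] - 0.5).abs() < 1e-10); // inside the range
        assert!((eigenvalues[2] - 1e6).abs() < 1e-6); // lowered to the cap
    }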

    #[test]
    fn test_diag_matrix() {
        let diagonal = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let matrix = KFACUtils::diag_matrix(&diagonal);

        assert_eq!(matrix.dim(), (3, 3));
        assert!((matrix[[0, 0]] - 1.0).abs() < 1e-10);
        assert!((matrix[[1, 1]] - 2.0).abs() < 1e-10);
        assert!((matrix[[2, 2]] - 3.0).abs() < 1e-10);
        assert!((matrix[[0, 1]]).abs() < 1e-10); // Off-diagonal should be zero
    }
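
    // The diagonal-based condition number proxy: max |diag| / min |diag|,
    // and infinity when a diagonal entry is zero.
    #[test]
    fn test_condition_number_approx() {
        let matrix = Array2::from_shape_vec((2, 2), vec![10.0, 0.0, 0.0, 2.0]).unwrap();
        let cond = KFACUtils::condition_number_svd_approx(&matrix);
        assert!((cond - 5.0).abs() < 1e-10);

        let singular = Array2::from_shape_vec((2, 2), vec![1.0, 0.0, 0.0, 0.0]).unwrap();
        assert!(KFACUtils::condition_number_svd_approx(&singular).is_infinite());
    }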

    #[test]
    fn test_ordered_float() {
        let a = OrderedFloat(1.5);
        let b = OrderedFloat(2.5);
        let c = OrderedFloat(1.5);

        assert!(a < b);
        assert!(a == c);
        assert!(b > a);
    }
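
    // NaN handling in OrderedFloat: NaNs compare equal to each other and
    // sort after every non-NaN value, so the ordering is total.
    #[test]
    fn test_ordered_float_nan() {
        let nan = OrderedFloat(f64::NAN);
        let one = OrderedFloat(1.0);

        assert!(nan == OrderedFloat(f64::NAN));
        assert!(nan > one);
        assert!(one < nan);
        assert_eq!(nan.cmp(&OrderedFloat(f64::NAN)), std::cmp::Ordering::Equal);
    }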

    #[test]
    fn test_batchnorm_statistics() {
        let input =
            Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();

        let (mean, var) = KFACUtils::batchnorm_statistics(&input, 1e-8).unwrap();

        // Expected mean: [4.0, 5.0] (column-wise average)
        assert!((mean[0] - 4.0).abs() < 1e-6);
        assert!((mean[1] - 5.0).abs() < 1e-6);

        // Variance should be positive
        assert!(var[0] > 0.0);
        assert!(var[1] > 0.0);
    }
}