// optirs_core/second_order/kfac/utils.rs
1// Utility functions for K-FAC optimization
2//
3// This module contains helper functions and utilities used throughout
4// the K-FAC implementation, including layer-specific operations and
5// mathematical utilities.
6
7use crate::error::Result;
8use scirs2_core::ndarray::{Array1, Array2};
9use scirs2_core::numeric::Float;
10use std::fmt::Debug;
11
/// K-FAC utilities for layer-specific operations.
///
/// A stateless namespace type: all functionality lives in associated
/// functions on the `impl` block below, so the struct carries no data.
pub struct KFACUtils;
14
15impl KFACUtils {
16    /// Compute K-FAC update for convolutional layers
17    pub fn conv_kfac_update<T: Float + 'static>(
18        input_patches: &Array2<T>,
19        output_gradients: &Array2<T>,
20        kernel_size: (usize, usize),
21        stride: (usize, usize),
22        padding: (usize, usize),
23    ) -> Result<Array2<T>> {
24        // Simplified convolution K-FAC update
25        // In practice, this would involve more complex patch extraction and reshaping
26        let batch_size = input_patches.nrows();
27        let input_dim = input_patches.ncols();
28        let output_dim = output_gradients.ncols();
29
30        // Create a placeholder update matrix
31        let mut update = Array2::zeros((kernel_size.0 * kernel_size.1, output_dim));
32
33        // Simple averaging across batch
34        if batch_size > 0 {
35            let scale = T::one() / T::from(batch_size).unwrap_or_else(|| T::zero());
36            for i in 0..update.nrows() {
37                for j in 0..update.ncols() {
38                    let input_idx = i % input_dim;
39                    let output_idx = j % output_dim;
40
41                    let mut sum = T::zero();
42                    for b in 0..batch_size {
43                        if input_idx < input_dim && output_idx < output_dim {
44                            sum = sum
45                                + input_patches[[b, input_idx]] * output_gradients[[b, output_idx]];
46                        }
47                    }
48                    update[[i, j]] = sum * scale;
49                }
50            }
51        }
52
53        Ok(update)
54    }
55
56    /// Compute batch normalization statistics for K-FAC
57    pub fn batchnorm_statistics<T: Float + scirs2_core::numeric::FromPrimitive>(
58        input: &Array2<T>,
59        eps: T,
60    ) -> Result<(Array1<T>, Array1<T>)> {
61        let batch_size = input.nrows();
62        let num_features = input.ncols();
63
64        if batch_size == 0 {
65            return Ok((Array1::zeros(num_features), Array1::ones(num_features)));
66        }
67
68        let batch_size_t = T::from(batch_size).unwrap_or_else(|| T::zero());
69
70        // Compute mean
71        let mean = input
72            .mean_axis(scirs2_core::ndarray::Axis(0))
73            .expect("unwrap failed");
74
75        // Compute variance
76        let mut var = Array1::zeros(num_features);
77        for i in 0..num_features {
78            let mut sum_sq_diff = T::zero();
79            for j in 0..batch_size {
80                let diff = input[[j, i]] - mean[i];
81                sum_sq_diff = sum_sq_diff + diff * diff;
82            }
83            var[i] = sum_sq_diff / batch_size_t + eps;
84        }
85
86        Ok((mean, var))
87    }
88
89    /// Compute K-FAC update for grouped convolution layers
90    pub fn grouped_conv_kfac<T: Float + scirs2_core::ndarray::ScalarOperand>(
91        input: &Array2<T>,
92        gradients: &Array2<T>,
93        num_groups: usize,
94    ) -> Result<Array2<T>> {
95        let batch_size = input.nrows();
96        let input_channels = input.ncols();
97        let output_channels = gradients.ncols();
98
99        if num_groups == 0 {
100            return Err(crate::error::OptimError::InvalidParameter(
101                "Number of groups must be positive".to_string(),
102            ));
103        }
104
105        let input_per_group = input_channels / num_groups;
106        let output_per_group = output_channels / num_groups;
107
108        let mut result = Array2::zeros((input_channels, output_channels));
109
110        // Process each group separately
111        for group in 0..num_groups {
112            let input_start = group * input_per_group;
113            let input_end = input_start + input_per_group;
114            let output_start = group * output_per_group;
115            let output_end = output_start + output_per_group;
116
117            // Extract group data
118            let group_input = input.slice(scirs2_core::ndarray::s![.., input_start..input_end]);
119            let group_gradients =
120                gradients.slice(scirs2_core::ndarray::s![.., output_start..output_end]);
121
122            // Compute group covariance
123            let group_update = group_input.t().dot(&group_gradients);
124
125            // Place back in result
126            result
127                .slice_mut(scirs2_core::ndarray::s![
128                    input_start..input_end,
129                    output_start..output_end
130                ])
131                .assign(&group_update);
132        }
133
134        // Normalize by batch size
135        if batch_size > 0 {
136            let scale = T::one() / T::from(batch_size).unwrap_or_else(|| T::zero());
137            result = result * scale;
138        }
139
140        Ok(result)
141    }
142
143    /// Compute eigenvalue-based regularization
144    pub fn eigenvalue_regularization<T: Float + Debug + Send + Sync + 'static>(
145        matrix: &Array2<T>,
146        min_eigenvalue: T,
147    ) -> Array2<T> {
148        let n = matrix.nrows();
149        let mut regularized = matrix.clone();
150
151        // Simple diagonal regularization (in practice, would use proper eigendecomposition)
152        for i in 0..n {
153            if regularized[[i, i]] < min_eigenvalue {
154                regularized[[i, i]] = min_eigenvalue;
155            }
156        }
157
158        regularized
159    }
160
161    /// Compute Kronecker product approximation for two matrices
162    pub fn kronecker_product_approx<T: Float + Debug + Send + Sync + 'static>(
163        a: &Array2<T>,
164        b: &Array2<T>,
165    ) -> Array2<T> {
166        let (a_rows, a_cols) = a.dim();
167        let (b_rows, b_cols) = b.dim();
168
169        let mut result = Array2::zeros((a_rows * b_rows, a_cols * b_cols));
170
171        for i in 0..a_rows {
172            for j in 0..a_cols {
173                let a_val = a[[i, j]];
174                for k in 0..b_rows {
175                    for l in 0..b_cols {
176                        result[[i * b_rows + k, j * b_cols + l]] = a_val * b[[k, l]];
177                    }
178                }
179            }
180        }
181
182        result
183    }
184
185    /// Compute trace of a matrix
186    pub fn trace<T: Float + Debug + Send + Sync + 'static>(matrix: &Array2<T>) -> T {
187        let n = matrix.nrows().min(matrix.ncols());
188        let mut trace = T::zero();
189
190        for i in 0..n {
191            trace = trace + matrix[[i, i]];
192        }
193
194        trace
195    }
196
197    /// Compute Frobenius norm of a matrix
198    pub fn frobenius_norm<T: Float + std::iter::Sum>(matrix: &Array2<T>) -> T {
199        matrix.iter().map(|&x| x * x).sum::<T>().sqrt()
200    }
201
202    /// Check if two matrices are approximately equal
203    pub fn matrices_approx_equal<T: Float + Debug + Send + Sync + 'static>(
204        a: &Array2<T>,
205        b: &Array2<T>,
206        tolerance: T,
207    ) -> bool {
208        if a.dim() != b.dim() {
209            return false;
210        }
211
212        for (a_val, b_val) in a.iter().zip(b.iter()) {
213            if (*a_val - *b_val).abs() > tolerance {
214                return false;
215            }
216        }
217
218        true
219    }
220
221    /// Compute running average with exponential decay
222    pub fn exponential_moving_average<T: Float + Debug + Send + Sync + 'static>(
223        current_value: T,
224        new_value: T,
225        decay: T,
226    ) -> T {
227        decay * current_value + (T::one() - decay) * new_value
228    }
229
230    /// Clamp eigenvalues to prevent numerical instability
231    pub fn clamp_eigenvalues<T: Float + Debug + Send + Sync + 'static>(
232        eigenvalues: &mut Array1<T>,
233        min_val: T,
234        max_val: T,
235    ) {
236        for eigenval in eigenvalues.iter_mut() {
237            *eigenval = (*eigenval).max(min_val).min(max_val);
238        }
239    }
240
241    /// Compute condition number using singular values (approximation)
242    pub fn condition_number_svd_approx<T: Float + Debug + Send + Sync + 'static>(
243        matrix: &Array2<T>,
244    ) -> T {
245        // Simple approximation using diagonal elements
246        let diag = matrix.diag();
247        let max_diag = diag
248            .iter()
249            .fold(T::neg_infinity(), |acc, &x| acc.max(x.abs()));
250        let min_diag = diag.iter().fold(T::infinity(), |acc, &x| acc.min(x.abs()));
251
252        if min_diag > T::zero() {
253            max_diag / min_diag
254        } else {
255            T::infinity()
256        }
257    }
258
259    /// Extract diagonal elements and create diagonal matrix
260    pub fn diag_matrix<T: Float + Clone>(diagonal: &Array1<T>) -> Array2<T> {
261        let n = diagonal.len();
262        let mut matrix = Array2::zeros((n, n));
263
264        for i in 0..n {
265            matrix[[i, i]] = diagonal[i];
266        }
267
268        matrix
269    }
270
271    /// Symmetrize a matrix: (A + A^T) / 2
272    pub fn symmetrize<T: Float + Debug + Send + Sync + 'static>(matrix: &Array2<T>) -> Array2<T> {
273        let n = matrix.nrows();
274        let mut result = Array2::zeros((n, n));
275
276        for i in 0..n {
277            for j in 0..n {
278                result[[i, j]] =
279                    (matrix[[i, j]] + matrix[[j, i]]) / T::from(2.0).unwrap_or_else(|| T::zero());
280            }
281        }
282
283        result
284    }
285}
286
/// Total-order wrapper around a floating-point value.
///
/// Floats implement only `PartialOrd`/`PartialEq` because of NaN; this
/// newtype supplies `Eq`/`Ord` impls (below) so wrapped values can be used
/// in ordered contexts such as sorting or `BTreeMap` keys.
#[derive(Debug, Clone, Copy)]
pub struct OrderedFloat<T: Float + Debug + Send + Sync + 'static>(pub T);
290
291impl<T: Float + Debug + Send + Sync + 'static> PartialEq for OrderedFloat<T> {
292    fn eq(&self, other: &Self) -> bool {
293        self.0 == other.0 || (self.0.is_nan() && other.0.is_nan())
294    }
295}
296
297impl<T: Float + Debug + Send + Sync + 'static> Eq for OrderedFloat<T> {}
298
299impl<T: Float + Debug + Send + Sync + 'static> Ord for OrderedFloat<T> {
300    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
301        self.0
302            .partial_cmp(&other.0)
303            .unwrap_or(std::cmp::Ordering::Equal)
304    }
305}
306
impl<T: Float + Debug + Send + Sync + 'static> PartialOrd for OrderedFloat<T> {
    // Delegates to `cmp` so the partial and total orderings can never
    // disagree (the pattern clippy's `incorrect_partial_ord_impl_on_ord_type`
    // lint expects).
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}
312
#[cfg(test)]
mod tests {
    use super::*;

    /// Trace of a 3x3 matrix is the sum of its diagonal entries.
    #[test]
    fn test_trace_computation() {
        let matrix =
            Array2::from_shape_vec((3, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0])
                .expect("unwrap failed");
        let trace = KFACUtils::trace(&matrix);
        assert!((trace - 15.0).abs() < 1e-10); // 1 + 5 + 9 = 15
    }

    /// Frobenius norm of a classic 3-4-5 layout.
    #[test]
    fn test_frobenius_norm() {
        let matrix =
            Array2::from_shape_vec((2, 2), vec![3.0, 4.0, 0.0, 0.0]).expect("unwrap failed");
        let norm = KFACUtils::frobenius_norm(&matrix);
        assert!((norm - 5.0).abs() < 1e-10); // sqrt(9 + 16) = 5
    }

    /// EMA weights the old value by `decay` and the new by `1 - decay`.
    #[test]
    fn test_exponential_moving_average() {
        let current = 10.0;
        let new_val = 20.0;
        let decay = 0.9;

        let result = KFACUtils::exponential_moving_average(current, new_val, decay);
        let expected = 0.9 * 10.0 + 0.1 * 20.0; // 9.0 + 2.0 = 11.0
        assert!((result - expected).abs() < 1e-10);
    }

    /// Approximate equality depends on the supplied tolerance.
    #[test]
    fn test_matrices_approx_equal() {
        let a = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).expect("unwrap failed");
        let b = Array2::from_shape_vec((2, 2), vec![1.001, 2.001, 3.001, 4.001])
            .expect("unwrap failed");

        assert!(KFACUtils::matrices_approx_equal(&a, &b, 0.01));
        assert!(!KFACUtils::matrices_approx_equal(&a, &b, 0.0001));
    }

    /// Symmetrization averages each off-diagonal pair; diagonal is unchanged.
    #[test]
    fn test_symmetrize() {
        let matrix =
            Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).expect("unwrap failed");
        let symmetric = KFACUtils::symmetrize(&matrix);

        assert!((symmetric[[0, 0]] - 1.0).abs() < 1e-10);
        assert!((symmetric[[0, 1]] - 2.5).abs() < 1e-10); // (2 + 3) / 2
        assert!((symmetric[[1, 0]] - 2.5).abs() < 1e-10); // (3 + 2) / 2
        assert!((symmetric[[1, 1]] - 4.0).abs() < 1e-10);
    }

    /// diag_matrix places the vector on the diagonal, zeros elsewhere.
    #[test]
    fn test_diag_matrix() {
        let diagonal = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let matrix = KFACUtils::diag_matrix(&diagonal);

        assert_eq!(matrix.dim(), (3, 3));
        assert!((matrix[[0, 0]] - 1.0).abs() < 1e-10);
        assert!((matrix[[1, 1]] - 2.0).abs() < 1e-10);
        assert!((matrix[[2, 2]] - 3.0).abs() < 1e-10);
        assert!((matrix[[0, 1]]).abs() < 1e-10); // Off-diagonal should be zero
    }

    /// OrderedFloat supports total comparisons on ordinary (non-NaN) values.
    #[test]
    fn test_ordered_float() {
        let a = OrderedFloat(1.5);
        let b = OrderedFloat(2.5);
        let c = OrderedFloat(1.5);

        assert!(a < b);
        assert!(a == c);
        assert!(b > a);
    }

    /// Mean is column-wise; variance is strictly positive thanks to eps.
    #[test]
    fn test_batchnorm_statistics() {
        let input = Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])
            .expect("unwrap failed");

        let (mean, var) = KFACUtils::batchnorm_statistics(&input, 1e-8).expect("unwrap failed");

        // Expected mean: [4.0, 5.0] (column-wise average)
        assert!((mean[0] - 4.0).abs() < 1e-6);
        assert!((mean[1] - 5.0).abs() < 1e-6);

        // Variance should be positive
        assert!(var[0] > 0.0);
        assert!(var[1] > 0.0);
    }
}
405}