Skip to main content

deep_delta_learning/
delta_operator.rs

1use burn::prelude::*;
2
3use crate::spectral::{DeltaRegime, SpectralInfo};
4use crate::utils::{mean_of_slice, tensor_to_vec};
5
6#[derive(Debug, Clone)]
7pub struct DeltaOperator<B: Backend> {
8    k: Tensor<B, 2>,
9    beta: Tensor<B, 1>,
10}
11
12impl<B: Backend> DeltaOperator<B> {
13    pub fn new(k: Tensor<B, 2>, beta: Tensor<B, 1>) -> Self {
14        Self { k, beta }
15    }
16
17    pub fn apply(&self, x: &Tensor<B, 3>) -> Tensor<B, 3> {
18        let k_col = self.k.clone().unsqueeze_dim::<3>(2);
19        let k_row = self.k.clone().unsqueeze_dim::<3>(1);
20        let proj = k_row.matmul(x.clone());
21        let beta = self
22            .beta
23            .clone()
24            .unsqueeze_dim::<2>(1)
25            .unsqueeze_dim::<3>(2);
26        x.clone() - k_col.matmul(proj) * beta
27    }
28
29    pub fn apply_vector(&self, x: &Tensor<B, 2>) -> Tensor<B, 2> {
30        let proj = (self.k.clone() * x.clone()).sum_dim(1);
31        let beta = self.beta.clone().unsqueeze_dim::<2>(1);
32        x.clone() - self.k.clone() * proj * beta
33    }
34
35    pub fn k_eigenvalue(&self) -> Tensor<B, 1> {
36        self.beta.clone().ones_like() - self.beta.clone()
37    }
38
39    pub fn determinant(&self) -> Tensor<B, 1> {
40        self.k_eigenvalue()
41    }
42
43    pub fn spectral_info(&self) -> SpectralInfo {
44        self.spectral_info_with_lift(1)
45    }
46
47    pub fn spectral_info_with_lift(&self, d_value: usize) -> SpectralInfo {
48        let beta_values = tensor_to_vec(self.beta.clone());
49        let determinant_values = beta_values
50            .iter()
51            .map(|beta| 1.0 - beta)
52            .collect::<Vec<_>>();
53        let beta_mean = mean_of_slice(&beta_values);
54        let determinant_mean = mean_of_slice(&determinant_values);
55        let lifted_determinant_mean = determinant_values
56            .iter()
57            .map(|determinant| determinant.powi(d_value as i32))
58            .sum::<f32>()
59            / determinant_values.len() as f32;
60
61        SpectralInfo {
62            beta_mean,
63            k_eigenvalue_mean: 1.0 - beta_mean,
64            determinant_mean,
65            lifted_determinant_mean,
66            regime: DeltaRegime::from_beta(beta_mean),
67        }
68    }
69}