deep_delta_learning/
delta_operator.rs1use burn::prelude::*;
2
3use crate::spectral::{DeltaRegime, SpectralInfo};
4use crate::utils::{mean_of_slice, tensor_to_vec};
5
6#[derive(Debug, Clone)]
7pub struct DeltaOperator<B: Backend> {
8 k: Tensor<B, 2>,
9 beta: Tensor<B, 1>,
10}
11
12impl<B: Backend> DeltaOperator<B> {
13 pub fn new(k: Tensor<B, 2>, beta: Tensor<B, 1>) -> Self {
14 Self { k, beta }
15 }
16
17 pub fn apply(&self, x: &Tensor<B, 3>) -> Tensor<B, 3> {
18 let k_col = self.k.clone().unsqueeze_dim::<3>(2);
19 let k_row = self.k.clone().unsqueeze_dim::<3>(1);
20 let proj = k_row.matmul(x.clone());
21 let beta = self
22 .beta
23 .clone()
24 .unsqueeze_dim::<2>(1)
25 .unsqueeze_dim::<3>(2);
26 x.clone() - k_col.matmul(proj) * beta
27 }
28
29 pub fn apply_vector(&self, x: &Tensor<B, 2>) -> Tensor<B, 2> {
30 let proj = (self.k.clone() * x.clone()).sum_dim(1);
31 let beta = self.beta.clone().unsqueeze_dim::<2>(1);
32 x.clone() - self.k.clone() * proj * beta
33 }
34
35 pub fn k_eigenvalue(&self) -> Tensor<B, 1> {
36 self.beta.clone().ones_like() - self.beta.clone()
37 }
38
39 pub fn determinant(&self) -> Tensor<B, 1> {
40 self.k_eigenvalue()
41 }
42
43 pub fn spectral_info(&self) -> SpectralInfo {
44 self.spectral_info_with_lift(1)
45 }
46
47 pub fn spectral_info_with_lift(&self, d_value: usize) -> SpectralInfo {
48 let beta_values = tensor_to_vec(self.beta.clone());
49 let determinant_values = beta_values
50 .iter()
51 .map(|beta| 1.0 - beta)
52 .collect::<Vec<_>>();
53 let beta_mean = mean_of_slice(&beta_values);
54 let determinant_mean = mean_of_slice(&determinant_values);
55 let lifted_determinant_mean = determinant_values
56 .iter()
57 .map(|determinant| determinant.powi(d_value as i32))
58 .sum::<f32>()
59 / determinant_values.len() as f32;
60
61 SpectralInfo {
62 beta_mean,
63 k_eigenvalue_mean: 1.0 - beta_mean,
64 determinant_mean,
65 lifted_determinant_mean,
66 regime: DeltaRegime::from_beta(beta_mean),
67 }
68 }
69}