burn_tensor/tensor/linalg/
vector_norm.rs

1use crate::backend::Backend;
2use crate::tensor::{BasicOps, Tensor};
3use crate::{ElementConversion, Numeric};
4
/// Specifies the type of norm to compute.
///
/// Values of this type are usually produced through the `From<i32>`,
/// `From<f32>` and `From<f64>` conversions, which map the well-known
/// exponents onto their dedicated variants.
///
/// `PartialEq` is derived so that a selected norm can be compared directly
/// (e.g. in tests or when special-casing a variant). Because the `Lp`
/// payload is an `f64`, `Norm::Lp(f64::NAN)` never compares equal to itself.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Norm {
    /// L0 norm (count of non-zero elements).
    L0,

    /// L1 norm (sum of absolute values).
    L1,

    /// L2 norm (Euclidean norm).
    L2,

    /// L(+infinity) norm (maximum absolute value).
    LInf,

    /// L(-infinity) norm (minimum absolute value).
    LNegInf,

    /// Lp norm (generalized norm); the payload is the exponent `p`.
    Lp(f64),
}
26
27impl From<i32> for Norm {
28    fn from(value: i32) -> Self {
29        match value {
30            0 => Norm::L0,
31            1 => Norm::L1,
32            2 => Norm::L2,
33            _ => Norm::Lp(value as f64),
34        }
35    }
36}
37
38impl From<f32> for Norm {
39    fn from(value: f32) -> Self {
40        match value {
41            0.0 => Norm::L0,
42            1.0 => Norm::L1,
43            2.0 => Norm::L2,
44            f32::INFINITY => Norm::LInf,
45            f32::NEG_INFINITY => Norm::LNegInf,
46            _ => Norm::Lp(value as f64),
47        }
48    }
49}
50
51impl From<f64> for Norm {
52    fn from(value: f64) -> Self {
53        match value {
54            0.0 => Norm::L0,
55            1.0 => Norm::L1,
56            2.0 => Norm::L2,
57            f64::INFINITY => Norm::LInf,
58            f64::NEG_INFINITY => Norm::LNegInf,
59            _ => Norm::Lp(value),
60        }
61    }
62}
63
64/// Computes the vector norm of a tensor along a specified dimension.
65///
66/// Generic dispatch wrapper over specialized / optimized norms.
67///
68/// See:
69/// - [torch.linalg.vector_norm](https://pytorch.org/docs/stable/generated/torch.linalg.vector_norm.html)
70/// - [numpy.linalg.vector_norm](https://numpy.org/doc/stable/reference/generated/numpy.linalg.vector_norm.html)
71///
72/// # Arguments
73///
74/// * `x` - The input tensor.
75/// * `norm` - The selected norm.
76/// * `dim` - The dimension to compute the norm over.
77///
78/// # Returns
79///
80/// The vector norm of the input tensor.
81pub fn vector_norm<B: Backend, const D: usize>(
82    x: Tensor<B, D>,
83    norm: impl Into<Norm>,
84    dim: usize,
85) -> Tensor<B, D> {
86    let norm = norm.into();
87    match norm {
88        Norm::L0 => l0_norm(x, dim),
89        Norm::L1 => l1_norm(x, dim),
90        Norm::L2 => l2_norm(x, dim),
91        Norm::LInf => max_abs_norm(x, dim),
92        Norm::LNegInf => min_abs_norm(x, dim),
93        Norm::Lp(p) => lp_norm(x, p, dim),
94    }
95}
96
97/// Normalize a tensor versus its `vector_norm`.
98///
99/// Equivalent to ``x.clone() / vector_norm(x, norm, dim).clamp_min(eps)``.
100///
101/// # Arguments
102///
103/// * `x` - The input tensor.
104/// * `norm` - The selected norm.
105/// * `dim` - The dimension to compute the norm over.
106/// * `eps` - The epsilon for the norm.
107///
108/// # Returns
109///
110/// The normalized tensor.
111pub fn vector_normalize<B: Backend, const D: usize, E: ElementConversion>(
112    x: Tensor<B, D>,
113    norm: impl Into<Norm>,
114    dim: usize,
115    eps: E,
116) -> Tensor<B, D> {
117    let norm = vector_norm(x.clone(), norm, dim).clamp_min(eps);
118    x / norm
119}
120
121/// Computes the L0 norm of a tensor along a specified dimension.
122///
123/// # Arguments
124///
125/// * `x` - The input tensor.
126/// * `dim` - The dimension to compute the norm over.
127///
128/// # Returns
129///
130/// The L0 norm of the input tensor.
131pub fn l0_norm<B: Backend, const D: usize, K>(x: Tensor<B, D, K>, dim: usize) -> Tensor<B, D, K>
132where
133    K: BasicOps<B> + Numeric<B>,
134{
135    x.zeros_like()
136        .mask_fill(x.not_equal_elem(0), 1)
137        .sum_dim(dim)
138}
139
140/// Computes the L1 norm of a tensor along a specified dimension.
141///
142/// This is a convenience function that wraps `vector_norm` with `p = 1.0`.
143///
144/// # Arguments
145///
146/// * `x` - The input tensor.
147/// * `dim` - The dimension to compute the norm over.
148///
149/// # Returns
150///
151/// The L1 norm of the input tensor.
152pub fn l1_norm<B: Backend, const D: usize, K>(x: Tensor<B, D, K>, dim: usize) -> Tensor<B, D, K>
153where
154    K: BasicOps<B> + Numeric<B>,
155{
156    x.abs().sum_dim(dim)
157}
158
159/// Computes the L2 norm of a tensor along a specified dimension.
160///
161/// # Arguments
162///
163/// * `x` - The input tensor.
164/// * `dim` - The dimension to compute the norm over.
165///
166/// # Returns
167///
168/// The L2 norm of the input tensor.
169pub fn l2_norm<B: Backend, const D: usize>(x: Tensor<B, D>, dim: usize) -> Tensor<B, D> {
170    x.abs().powi_scalar(2).sum_dim(dim).sqrt()
171}
172
173/// Computes the general ``L(p)`` norm of a tensor along a specified dimension.
174///
175/// # Arguments
176///
177/// * `x` - The input tensor.
178/// * `p` - The exponent of the Lp norm.
179/// * `dim` - The dimension to compute the norm over.
180///
181/// # Returns
182///
183/// The ``L(p)`` norm of the input tensor.
184pub fn lp_norm<B: Backend, const D: usize>(x: Tensor<B, D>, p: f64, dim: usize) -> Tensor<B, D> {
185    x.abs().powf_scalar(p).sum_dim(dim).powf_scalar(1. / p)
186}
187
188/// Computes the L:INFINITY norm of a tensor along a specified dimension.
189///
190/// # Arguments
191///
192/// * `x` - The input tensor.
193/// * `dim` - The dimension to compute the norm over.
194///
195/// # Returns
196///
197/// The L:INFINITY norm of the input tensor.
198pub fn max_abs_norm<B: Backend, const D: usize, K>(
199    x: Tensor<B, D, K>,
200    dim: usize,
201) -> Tensor<B, D, K>
202where
203    K: BasicOps<B> + Numeric<B>,
204{
205    x.max_abs_dim(dim)
206}
207
208/// Computes the L:NEG_INFINITY norm of a tensor along a specified dimension.
209///
210/// # Arguments
211///
212/// * `x` - The input tensor.
213/// * `dim` - The dimension to compute the norm over.
214///
215/// # Returns
216///
217/// The L:NEG_INFINITY norm of the input tensor.
218pub fn min_abs_norm<B: Backend, const D: usize, K>(
219    x: Tensor<B, D, K>,
220    dim: usize,
221) -> Tensor<B, D, K>
222where
223    K: BasicOps<B> + Numeric<B>,
224{
225    x.abs().min_dim(dim)
226}