Skip to main content

burn_tensor/tensor/linalg/
vector_norm.rs

1use burn_backend::tensor::Ordered;
2
3use crate::backend::Backend;
4use crate::tensor::{BasicOps, Tensor};
5use crate::{ElementConversion, Numeric};
6#[allow(unused_imports)]
7use num_traits::float::Float;
/// Selects which vector norm to compute.
///
/// Every variant corresponds to an exponent `p` (see `Norm::to_exponent`), and the
/// norm is evaluated as ``(sum |x|^p)^(1/p)`` or its limiting case.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Norm {
    /// `p = 0`: number of non-zero entries.
    L0,

    /// `p = 1`: sum of absolute values.
    L1,

    /// `p = 2`: Euclidean length.
    L2,

    /// `p = +inf`: largest absolute value.
    LInf,

    /// `p = -inf`: smallest absolute value.
    LNegInf,

    /// Arbitrary exponent `p`.
    Lp(f64),
}
29
30impl Norm {
31    /// Get the exponent of the norm.
32    pub fn to_exponent(self) -> f64 {
33        use Norm::*;
34        match self {
35            L0 => 0.0,
36            L1 => 1.0,
37            L2 => 2.0,
38            LInf => f64::INFINITY,
39            LNegInf => f64::NEG_INFINITY,
40            Lp(p) => p,
41        }
42    }
43}
44
45impl From<u32> for Norm {
46    fn from(value: u32) -> Self {
47        use Norm::*;
48        match value {
49            0 => L0,
50            1 => L1,
51            2 => L2,
52            u32::MAX => LInf,
53            _ => Lp(value as f64),
54        }
55    }
56}
57
58impl From<i32> for Norm {
59    fn from(value: i32) -> Self {
60        use Norm::*;
61        match value {
62            0 => L0,
63            1 => L1,
64            2 => L2,
65            i32::MAX => LInf,
66            i32::MIN => LNegInf,
67            _ => Lp(value as f64),
68        }
69    }
70}
71
72impl From<f32> for Norm {
73    fn from(value: f32) -> Self {
74        use Norm::*;
75        match value {
76            0.0 => L0,
77            1.0 => L1,
78            2.0 => L2,
79            f32::INFINITY => LInf,
80            f32::NEG_INFINITY => LNegInf,
81            _ => Lp(value as f64),
82        }
83    }
84}
85
86impl From<f64> for Norm {
87    fn from(value: f64) -> Self {
88        use Norm::*;
89        match value {
90            0.0 => L0,
91            1.0 => L1,
92            2.0 => L2,
93            f64::INFINITY => LInf,
94            f64::NEG_INFINITY => LNegInf,
95            _ => Lp(value),
96        }
97    }
98}
99
100/// Computes the vector norm of a tensor along a specified dimension.
101///
102/// Generic dispatch wrapper over specialized / optimized norms.
103///
104/// See:
105/// - [torch.linalg.vector_norm](https://pytorch.org/docs/stable/generated/torch.linalg.vector_norm.html)
106/// - [numpy.linalg.vector_norm](https://numpy.org/doc/stable/reference/generated/numpy.linalg.vector_norm.html)
107///
108/// # Arguments
109///
110/// * `x` - The input tensor.
111/// * `norm` - The selected norm.
112/// * `dim` - The dimension to compute the norm over.
113///
114/// # Returns
115///
116/// The vector norm of the input tensor.
117pub fn vector_norm<B: Backend, const D: usize>(
118    x: Tensor<B, D>,
119    norm: impl Into<Norm>,
120    dim: usize,
121) -> Tensor<B, D> {
122    lp_norm(x, norm.into().to_exponent(), dim)
123}
124
125/// Computes the general ``L(p)`` norm of a tensor along a specified dimension.
126///
127/// Uses the specialized implementations for:
128/// * 0.0
129/// * 1.0
130/// * 2.0
131/// * 2 * N for integral N,
132/// * f64::INFINITY,
133/// * f64::NEG_INFINITY,
134///
135/// # Arguments
136///
137/// * `x` - The input tensor.
138/// * `p` - The exponent of the Lp norm.
139/// * `dim` - The dimension to compute the norm over.
140///
141/// # Returns
142///
143/// The ``L(p)`` norm of the input tensor.
144pub fn lp_norm<B: Backend, const D: usize>(x: Tensor<B, D>, p: f64, dim: usize) -> Tensor<B, D> {
145    match p {
146        0.0 => l0_norm(x, dim),
147        1.0 => l1_norm(x, dim),
148        2.0 => l2_norm(x, dim),
149        p if is_even_integer(p) => lp_signed_norm(x, p as u32, dim),
150        f64::INFINITY => max_abs_norm(x, dim),
151        f64::NEG_INFINITY => min_abs_norm(x, dim),
152        _ => lp_norm_base(x, p, dim),
153    }
154}
155
156/// Normalize a tensor versus its `vector_norm`.
157///
158/// Equivalent to ``x.clone() / vector_norm(x, norm, dim).clamp_min(eps)``.
159///
160/// # Arguments
161///
162/// * `x` - The input tensor.
163/// * `norm` - The selected norm.
164/// * `dim` - The dimension to compute the norm over.
165/// * `eps` - The epsilon for the norm.
166///
167/// # Returns
168///
169/// The normalized tensor.
170pub fn vector_normalize<B: Backend, const D: usize, E: ElementConversion>(
171    x: Tensor<B, D>,
172    norm: impl Into<Norm>,
173    dim: usize,
174    eps: E,
175) -> Tensor<B, D> {
176    let norm = vector_norm(x.clone(), norm, dim).clamp_min(eps);
177    x / norm
178}
179
180/// Computes the L0 norm of a tensor along a specified dimension.
181///
182/// # Arguments
183///
184/// * `x` - The input tensor.
185/// * `dim` - The dimension to compute the norm over.
186///
187/// # Returns
188///
189/// The L0 norm of the input tensor.
190pub fn l0_norm<B: Backend, const D: usize, K>(x: Tensor<B, D, K>, dim: usize) -> Tensor<B, D, K>
191where
192    K: BasicOps<B> + Numeric<B>,
193{
194    x.zeros_like()
195        .mask_fill(x.not_equal_elem(0), 1)
196        .sum_dim(dim)
197}
198
199/// Computes the L1 norm of a tensor along a specified dimension.
200///
201/// This is a convenience function that wraps `vector_norm` with `p = 1.0`.
202///
203/// # Arguments
204///
205/// * `x` - The input tensor.
206/// * `dim` - The dimension to compute the norm over.
207///
208/// # Returns
209///
210/// The L1 norm of the input tensor.
211pub fn l1_norm<B: Backend, const D: usize, K>(x: Tensor<B, D, K>, dim: usize) -> Tensor<B, D, K>
212where
213    K: BasicOps<B> + Numeric<B>,
214{
215    x.abs().sum_dim(dim)
216}
217
218/// Computes the L2 norm of a tensor along a specified dimension.
219///
220/// # Arguments
221///
222/// * `x` - The input tensor.
223/// * `dim` - The dimension to compute the norm over.
224///
225/// # Returns
226///
227/// The L2 norm of the input tensor.
228pub fn l2_norm<B: Backend, const D: usize>(x: Tensor<B, D>, dim: usize) -> Tensor<B, D> {
229    x.square().sum_dim(dim).sqrt()
230}
231
/// Returns `true` when `x` has no fractional part and its `i64` truncation is even.
///
/// Non-finite inputs yield `false`, because their `fract()` is NaN. The `as i64` cast
/// saturates, so the parity result is only meaningful for `|x| < 2^63`.
fn is_even_integer(x: f64) -> bool {
    if x.fract() != 0.0 {
        return false;
    }
    ((x as i64) & 1) == 0
}
235
236/// Computes ``L(2*n)`` for even integer ``n``.
237///
238/// This lets us skip the abs.
239fn lp_signed_norm<B: Backend, const D: usize>(x: Tensor<B, D>, p: u32, dim: usize) -> Tensor<B, D> {
240    x.powi_scalar(p).sum_dim(dim).powf_scalar(1. / (p as f64))
241}
242
243/// Computes the general ``L(p)`` using the generalized method.
244///
245/// This uses no specialized implementations and cannot handle:
246/// * 0.0
247/// * f64::INFINITY,
248/// * f64::NEG_INFINITY,
249fn lp_norm_base<B: Backend, const D: usize>(x: Tensor<B, D>, p: f64, dim: usize) -> Tensor<B, D> {
250    x.abs().powf_scalar(p).sum_dim(dim).powf_scalar(1. / p)
251}
252
253/// Computes the L:INFINITY norm of a tensor along a specified dimension.
254///
255/// # Arguments
256///
257/// * `x` - The input tensor.
258/// * `dim` - The dimension to compute the norm over.
259///
260/// # Returns
261///
262/// The L:INFINITY norm of the input tensor.
263pub fn max_abs_norm<B: Backend, const D: usize, K>(
264    x: Tensor<B, D, K>,
265    dim: usize,
266) -> Tensor<B, D, K>
267where
268    K: Ordered<B>,
269{
270    x.max_abs_dim(dim)
271}
272
273/// Computes the L:NEG_INFINITY norm of a tensor along a specified dimension.
274///
275/// # Arguments
276///
277/// * `x` - The input tensor.
278/// * `dim` - The dimension to compute the norm over.
279///
280/// # Returns
281///
282/// The L:NEG_INFINITY norm of the input tensor.
283pub fn min_abs_norm<B: Backend, const D: usize, K>(
284    x: Tensor<B, D, K>,
285    dim: usize,
286) -> Tensor<B, D, K>
287where
288    K: Ordered<B>,
289{
290    x.abs().min_dim(dim)
291}