//! burn_tensor/tensor/linalg/vector_norm.rs — vector norms and normalization helpers.
1use burn_backend::tensor::Ordered;
2
3use crate::backend::Backend;
4use crate::tensor::{BasicOps, Tensor};
5use crate::{ElementConversion, Numeric};
6#[allow(unused_imports)]
7use num_traits::float::Float;
/// Specifies the type of norm to compute.
///
/// Selects which specialized implementation [`vector_norm`] dispatches to.
/// The numeric `From` conversions (`u32`, `i32`, `f32`, `f64`) map the
/// well-known exponent values onto the dedicated variants.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Norm {
    /// L0 norm (count of non-zero elements).
    L0,

    /// L1 norm (sum of absolute values).
    L1,

    /// L2 norm (Euclidean norm).
    L2,

    /// L:INFINITY norm (maximum absolute value).
    LInf,

    /// L:NEG_INFINITY norm (minimum absolute value).
    LNegInf,

    /// Lp norm (generalized norm with exponent `p`).
    Lp(f64),
}
29
30impl Norm {
31 /// Get the exponent of the norm.
32 pub fn to_exponent(self) -> f64 {
33 use Norm::*;
34 match self {
35 L0 => 0.0,
36 L1 => 1.0,
37 L2 => 2.0,
38 LInf => f64::INFINITY,
39 LNegInf => f64::NEG_INFINITY,
40 Lp(p) => p,
41 }
42 }
43}
44
45impl From<u32> for Norm {
46 fn from(value: u32) -> Self {
47 use Norm::*;
48 match value {
49 0 => L0,
50 1 => L1,
51 2 => L2,
52 u32::MAX => LInf,
53 _ => Lp(value as f64),
54 }
55 }
56}
57
58impl From<i32> for Norm {
59 fn from(value: i32) -> Self {
60 use Norm::*;
61 match value {
62 0 => L0,
63 1 => L1,
64 2 => L2,
65 i32::MAX => LInf,
66 i32::MIN => LNegInf,
67 _ => Lp(value as f64),
68 }
69 }
70}
71
72impl From<f32> for Norm {
73 fn from(value: f32) -> Self {
74 use Norm::*;
75 match value {
76 0.0 => L0,
77 1.0 => L1,
78 2.0 => L2,
79 f32::INFINITY => LInf,
80 f32::NEG_INFINITY => LNegInf,
81 _ => Lp(value as f64),
82 }
83 }
84}
85
86impl From<f64> for Norm {
87 fn from(value: f64) -> Self {
88 use Norm::*;
89 match value {
90 0.0 => L0,
91 1.0 => L1,
92 2.0 => L2,
93 f64::INFINITY => LInf,
94 f64::NEG_INFINITY => LNegInf,
95 _ => Lp(value),
96 }
97 }
98}
99
100/// Computes the vector norm of a tensor along a specified dimension.
101///
102/// Generic dispatch wrapper over specialized / optimized norms.
103///
104/// See:
105/// - [torch.linalg.vector_norm](https://pytorch.org/docs/stable/generated/torch.linalg.vector_norm.html)
106/// - [numpy.linalg.vector_norm](https://numpy.org/doc/stable/reference/generated/numpy.linalg.vector_norm.html)
107///
108/// # Arguments
109///
110/// * `x` - The input tensor.
111/// * `norm` - The selected norm.
112/// * `dim` - The dimension to compute the norm over.
113///
114/// # Returns
115///
116/// The vector norm of the input tensor.
117pub fn vector_norm<B: Backend, const D: usize>(
118 x: Tensor<B, D>,
119 norm: impl Into<Norm>,
120 dim: usize,
121) -> Tensor<B, D> {
122 lp_norm(x, norm.into().to_exponent(), dim)
123}
124
125/// Computes the general ``L(p)`` norm of a tensor along a specified dimension.
126///
127/// Uses the specialized implementations for:
128/// * 0.0
129/// * 1.0
130/// * 2.0
131/// * 2 * N for integral N,
132/// * f64::INFINITY,
133/// * f64::NEG_INFINITY,
134///
135/// # Arguments
136///
137/// * `x` - The input tensor.
138/// * `p` - The exponent of the Lp norm.
139/// * `dim` - The dimension to compute the norm over.
140///
141/// # Returns
142///
143/// The ``L(p)`` norm of the input tensor.
144pub fn lp_norm<B: Backend, const D: usize>(x: Tensor<B, D>, p: f64, dim: usize) -> Tensor<B, D> {
145 match p {
146 0.0 => l0_norm(x, dim),
147 1.0 => l1_norm(x, dim),
148 2.0 => l2_norm(x, dim),
149 p if is_even_integer(p) => lp_signed_norm(x, p as u32, dim),
150 f64::INFINITY => max_abs_norm(x, dim),
151 f64::NEG_INFINITY => min_abs_norm(x, dim),
152 _ => lp_norm_base(x, p, dim),
153 }
154}
155
156/// Normalize a tensor versus its `vector_norm`.
157///
158/// Equivalent to ``x.clone() / vector_norm(x, norm, dim).clamp_min(eps)``.
159///
160/// # Arguments
161///
162/// * `x` - The input tensor.
163/// * `norm` - The selected norm.
164/// * `dim` - The dimension to compute the norm over.
165/// * `eps` - The epsilon for the norm.
166///
167/// # Returns
168///
169/// The normalized tensor.
170pub fn vector_normalize<B: Backend, const D: usize, E: ElementConversion>(
171 x: Tensor<B, D>,
172 norm: impl Into<Norm>,
173 dim: usize,
174 eps: E,
175) -> Tensor<B, D> {
176 let norm = vector_norm(x.clone(), norm, dim).clamp_min(eps);
177 x / norm
178}
179
180/// Computes the L0 norm of a tensor along a specified dimension.
181///
182/// # Arguments
183///
184/// * `x` - The input tensor.
185/// * `dim` - The dimension to compute the norm over.
186///
187/// # Returns
188///
189/// The L0 norm of the input tensor.
190pub fn l0_norm<B: Backend, const D: usize, K>(x: Tensor<B, D, K>, dim: usize) -> Tensor<B, D, K>
191where
192 K: BasicOps<B> + Numeric<B>,
193{
194 x.zeros_like()
195 .mask_fill(x.not_equal_elem(0), 1)
196 .sum_dim(dim)
197}
198
199/// Computes the L1 norm of a tensor along a specified dimension.
200///
201/// This is a convenience function that wraps `vector_norm` with `p = 1.0`.
202///
203/// # Arguments
204///
205/// * `x` - The input tensor.
206/// * `dim` - The dimension to compute the norm over.
207///
208/// # Returns
209///
210/// The L1 norm of the input tensor.
211pub fn l1_norm<B: Backend, const D: usize, K>(x: Tensor<B, D, K>, dim: usize) -> Tensor<B, D, K>
212where
213 K: BasicOps<B> + Numeric<B>,
214{
215 x.abs().sum_dim(dim)
216}
217
218/// Computes the L2 norm of a tensor along a specified dimension.
219///
220/// # Arguments
221///
222/// * `x` - The input tensor.
223/// * `dim` - The dimension to compute the norm over.
224///
225/// # Returns
226///
227/// The L2 norm of the input tensor.
228pub fn l2_norm<B: Backend, const D: usize>(x: Tensor<B, D>, dim: usize) -> Tensor<B, D> {
229 x.square().sum_dim(dim).sqrt()
230}
231
/// Returns `true` when `x` is a finite even integer whose magnitude fits
/// in an `i64`.
///
/// Values outside the `i64` range are rejected explicitly: the `as i64`
/// cast saturates for them, which previously yielded accidental results
/// (e.g. `-1e19` saturated to `i64::MIN`, which is even, and so reported
/// `true`). Rejecting them sends callers down the generic norm path,
/// which is always correct.
fn is_even_integer(x: f64) -> bool {
    // 2^63: the first magnitude that no longer fits in i64.
    const I64_LIMIT: f64 = 9_223_372_036_854_775_808.0;
    x.is_finite() && x.fract() == 0.0 && x.abs() < I64_LIMIT && (x as i64) % 2 == 0
}
235
236/// Computes ``L(2*n)`` for even integer ``n``.
237///
238/// This lets us skip the abs.
239fn lp_signed_norm<B: Backend, const D: usize>(x: Tensor<B, D>, p: u32, dim: usize) -> Tensor<B, D> {
240 x.powi_scalar(p).sum_dim(dim).powf_scalar(1. / (p as f64))
241}
242
243/// Computes the general ``L(p)`` using the generalized method.
244///
245/// This uses no specialized implementations and cannot handle:
246/// * 0.0
247/// * f64::INFINITY,
248/// * f64::NEG_INFINITY,
249fn lp_norm_base<B: Backend, const D: usize>(x: Tensor<B, D>, p: f64, dim: usize) -> Tensor<B, D> {
250 x.abs().powf_scalar(p).sum_dim(dim).powf_scalar(1. / p)
251}
252
/// Computes the L:INFINITY norm of a tensor along a specified dimension.
///
/// This is the maximum absolute value along `dim`, delegated directly to
/// `max_abs_dim`.
///
/// # Arguments
///
/// * `x` - The input tensor.
/// * `dim` - The dimension to compute the norm over.
///
/// # Returns
///
/// The L:INFINITY norm of the input tensor.
pub fn max_abs_norm<B: Backend, const D: usize, K>(
    x: Tensor<B, D, K>,
    dim: usize,
) -> Tensor<B, D, K>
where
    K: Ordered<B>,
{
    x.max_abs_dim(dim)
}
272
273/// Computes the L:NEG_INFINITY norm of a tensor along a specified dimension.
274///
275/// # Arguments
276///
277/// * `x` - The input tensor.
278/// * `dim` - The dimension to compute the norm over.
279///
280/// # Returns
281///
282/// The L:NEG_INFINITY norm of the input tensor.
283pub fn min_abs_norm<B: Backend, const D: usize, K>(
284 x: Tensor<B, D, K>,
285 dim: usize,
286) -> Tensor<B, D, K>
287where
288 K: Ordered<B>,
289{
290 x.abs().min_dim(dim)
291}