burn_tensor/tensor/grid/mod.rs

1mod affine_grid;
2mod meshgrid;
3
4pub use meshgrid::*;
5
6pub use affine_grid::*;
7
/// Index layout used when building a coordinate grid.
///
/// Controls how the dimensions of the generated grid map onto the inputs.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum GridIndexing {
    /// Output dimensions follow the order of the inputs.
    ///
    /// Corresponds to `indexing="ij"` in NumPy and PyTorch. This is the default.
    #[default]
    Matrix,

    /// Like [`GridIndexing::Matrix`], but with the first two dimensions swapped.
    ///
    /// Corresponds to `indexing="xy"` in NumPy and PyTorch.
    Cartesian,
}
20
/// Sparsity mode used when building a coordinate grid.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum GridSparsity {
    /// Every output is broadcast to the full cartesian-product shape.
    ///
    /// This is the default.
    #[default]
    Dense,

    /// Outputs are kept sparse, expanded only along their cardinal dimensions.
    Sparse,
}
31
32/// Grid policy options.
33#[derive(new, Default, Debug, Copy, Clone)]
34pub struct GridOptions {
35    /// Indexing mode.
36    pub indexing: GridIndexing,
37
38    /// Sparsity mode.
39    pub sparsity: GridSparsity,
40}
41
42impl From<GridIndexing> for GridOptions {
43    fn from(value: GridIndexing) -> Self {
44        Self {
45            indexing: value,
46            ..Default::default()
47        }
48    }
49}
50impl From<GridSparsity> for GridOptions {
51    fn from(value: GridSparsity) -> Self {
52        Self {
53            sparsity: value,
54            ..Default::default()
55        }
56    }
57}
58
/// Position of the index dimension.
// `PartialEq`/`Eq` are derived so callers can compare positions with `==`,
// consistent with `GridIndexing` and `GridSparsity`.
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)]
pub enum IndexPos {
    /// The index is in the first dimension. This is the default.
    #[default]
    First,

    /// The index is in the last dimension.
    Last,
}