burn_tensor/tensor/linalg/
mod.rs

1mod cosine_similarity;
2mod diag;
3mod lu_decomposition;
4mod outer;
5mod trace;
6mod vector_norm;
7
8pub use cosine_similarity::*;
9pub use diag::*;
10pub use lu_decomposition::*;
11pub use outer::*;
12pub use trace::*;
13pub use vector_norm::*;
14
15use crate::{BasicOps, SliceArg, Tensor, TensorKind, backend::Backend};
16
17/// Swaps two slices of a tensor.
18/// # Arguments
19/// * `tensor` - The input tensor.
20/// * `slices1` - The first slice to swap.
21/// * `slices2` - The second slice to swap.
22/// # Returns
23/// A new tensor with the specified slices swapped.
24/// # Notes
25/// This method will be useful for matrix factorization algorithms.
26fn swap_slices<B: Backend, const D: usize, K, S>(
27    tensor: Tensor<B, D, K>,
28    slices1: S,
29    slices2: S,
30) -> Tensor<B, D, K>
31where
32    S: SliceArg<D> + Clone,
33    K: TensorKind<B> + BasicOps<B>,
34{
35    let temporary = tensor.clone().slice(slices1.clone());
36    let tensor = tensor
37        .clone()
38        .slice_assign(slices1, tensor.slice(slices2.clone()));
39    tensor.slice_assign(slices2, temporary)
40}