dfdx 0.13.0

Ergonomic automatic differentiation in Rust, with PyTorch-like APIs.
Documentation
//! Demonstrates how to index into tensors by selecting or gathering sub-tensors along an axis.

use dfdx::{
    shapes::Rank3,
    tensor::{AsArray, AutoDevice, Tensor, TensorFrom},
    tensor_ops::{GatherTo, SelectTo},
};

/// Walks through indexing a rank-3 tensor two ways: `select` picks exactly one
/// element along an axis per index entry, while `gather` picks several (and may
/// repeat them). The `select` results are checked with assertions; the `gather`
/// results are printed with `dbg!`.
fn main() {
    let dev = AutoDevice::default();

    // Source tensor of shape (4, 2, 3). Every value is written as `i.jk`, where
    // (i, j, k) is the element's own position, so results are easy to eyeball.
    let a: Tensor<Rank3<4, 2, 3>, f32, _> = dev.tensor([
        [[0.00, 0.01, 0.02], [0.10, 0.11, 0.12]],
        [[1.00, 1.01, 1.02], [1.10, 1.11, 1.12]],
        [[2.00, 2.01, 2.02], [2.10, 2.11, 2.12]],
        [[3.00, 3.01, 3.02], [3.10, 3.11, 3.12]],
    ]);

    // --- select: one element per index entry ---------------------------------
    // The index tensor's shape is the source shape truncated just before the
    // axis being selected from. For a source of shape (M, N, O):
    //   axis 0 -> index shape ()
    //   axis 1 -> index shape (M,)
    //   axis 2 -> index shape (M, N)
    //
    // A scalar index therefore selects along axis 0: a single value is enough.
    let b = a.clone().select(dev.tensor(0));
    let whole = a.array();
    assert_eq!(b.array(), whole[0]);

    // A 1d index of length 4 selects along axis 1: one choice for each of the
    // four entries of axis 0.
    let axis1_choice = [0, 1, 0, 1];
    let d = a.clone().select(dev.tensor(axis1_choice));
    assert_eq!(
        d.array(),
        [
            [0.00, 0.01, 0.02],
            [1.10, 1.11, 1.12],
            [2.00, 2.01, 2.02],
            [3.10, 3.11, 3.12]
        ]
    );

    // --- gather: several elements per index entry ----------------------------
    // Same index layout as `select`, plus one trailing dimension Z giving how
    // many elements to pull from the axis. Repeating an index is allowed.
    //   axis 0 -> index shape (Z,)
    //   axis 1 -> index shape (M, Z)
    //   axis 2 -> index shape (M, N, Z)
    //
    // A 1d index gathers along axis 0; with 6 entries the result is (6, 2, 3).
    let axis0_picks = [0, 0, 1, 1, 2, 2];
    let c = a.clone().gather(dev.tensor(axis0_picks));
    dbg!(c.array());

    // A 2d index gathers along axis 1; with 6 entries per row the result is
    // (4, 6, 3). This is the last use of `a`, so it is moved rather than cloned.
    let axis1_picks = [[0; 6], [1; 6], [1, 1, 1, 0, 0, 0], [0, 0, 0, 1, 1, 1]];
    let e = a.gather(dev.tensor(axis1_picks));
    dbg!(e.array());
}