Tensor

Struct Tensor 

Source
/// Basic Tensor struct: a [`Shape`] paired with the [`Data`] it describes.
pub struct Tensor {
    /// The Shape of the Tensor.
    pub shape: Shape,
    /// The Data of the Tensor.
    pub data: Data,
}
Expand description

Basic Tensor struct, pairing a Shape with the Data it describes.

§Fields

  • shape - The Shape of the Tensor.
  • data - The Data of the Tensor.

Fields§

§shape: Shape

§data: Data

Implementations§

Source§

impl Tensor

Source

pub fn zeros(shape: Shape) -> Self

Creates a new Tensor with the given shape, filled with zeros.

§Arguments
  • shape - The shape of the Tensor.
§Returns

A new Tensor with the given shape filled with zeros.

Source

pub fn ones(shape: Shape) -> Self

Creates a new Tensor with the given shape, filled with ones.

§Arguments
  • shape - The shape of the Tensor.
§Returns

A new Tensor with the given shape filled with ones.

Source

pub fn random(shape: Shape, min: f32, max: f32) -> Self

Creates a new Tensor with the given shape, filled with random values.

§Arguments
  • shape - The shape of the Tensor.
  • min - The minimum value of the random values.
  • max - The maximum value of the random values.
§Returns

A new Tensor with the given shape filled with random values.

Examples found in repository?
examples/convolution.rs (line 29)
5fn main() {
6    let mut network = network::Network::new(tensor::Shape::Triple(1, 24, 24));
7
8    network.convolution(
9        5,
10        (3, 3),
11        (1, 1),
12        (0, 0),
13        (1, 1),
14        activation::Activation::ReLU,
15        Some(0.1),
16    );
17    network.convolution(
18        1,
19        (3, 3),
20        (1, 1),
21        (0, 0),
22        (1, 1),
23        activation::Activation::ReLU,
24        Some(0.1),
25    );
26
27    println!("{}", network);
28
29    let x = tensor::Tensor::random(tensor::Shape::Triple(1, 24, 24), 0.0, 1.0);
30    println!("x: {}", &x.shape);
31
32    let (pre, post, _, _) = network.forward(&x);
33    println!("pre-activation: {}", &pre[pre.len() - 1].shape);
34    println!("post-activation: {}", &post[post.len() - 1].shape);
35
36    plot::heatmap(&x, "Input", "./output/convolution-input.png");
37    plot::heatmap(
38        &pre[pre.len() - 1],
39        "Pre-activation",
40        "./output/convolution-pre.png",
41    );
42    plot::heatmap(
43        &post[post.len() - 1],
44        "Post-activation",
45        "./output/convolution-post.png",
46    );
47}
Source

pub fn one_hot(value: usize, max: usize) -> Self

One-hot encodes a value into a Shape::Vector.

Examples found in repository?
examples/benchmark.rs (line 53)
43fn load_labels(file_path: &str, numbers: usize) -> Result<Vec<tensor::Tensor>> {
44    let mut reader = BufReader::new(File::open(file_path)?);
45    let _magic_number = read(&mut reader)?;
46    let num_labels = read(&mut reader)?;
47
48    let mut _labels = vec![0; num_labels as usize];
49    reader.read_exact(&mut _labels)?;
50
51    Ok(_labels
52        .iter()
53        .map(|&x| tensor::Tensor::one_hot(x as usize, numbers))
54        .collect())
55}
More examples
Hide additional examples
examples/compare/fashion-1.rs (line 88)
78fn load_labels(file_path: &str, numbers: usize) -> Result<Vec<tensor::Tensor>> {
79    let mut reader = BufReader::new(File::open(file_path)?);
80    let _magic_number = read(&mut reader)?;
81    let num_labels = read(&mut reader)?;
82
83    let mut _labels = vec![0; num_labels as usize];
84    reader.read_exact(&mut _labels)?;
85
86    Ok(_labels
87        .iter()
88        .map(|&x| tensor::Tensor::one_hot(x as usize, numbers))
89        .collect())
90}
examples/compare/fashion-2.rs (line 88)
78fn load_labels(file_path: &str, numbers: usize) -> Result<Vec<tensor::Tensor>> {
79    let mut reader = BufReader::new(File::open(file_path)?);
80    let _magic_number = read(&mut reader)?;
81    let num_labels = read(&mut reader)?;
82
83    let mut _labels = vec![0; num_labels as usize];
84    reader.read_exact(&mut _labels)?;
85
86    Ok(_labels
87        .iter()
88        .map(|&x| tensor::Tensor::one_hot(x as usize, numbers))
89        .collect())
90}
examples/compare/mnist.rs (line 88)
78fn load_labels(file_path: &str, numbers: usize) -> Result<Vec<tensor::Tensor>> {
79    let mut reader = BufReader::new(File::open(file_path)?);
80    let _magic_number = read(&mut reader)?;
81    let num_labels = read(&mut reader)?;
82
83    let mut _labels = vec![0; num_labels as usize];
84    reader.read_exact(&mut _labels)?;
85
86    Ok(_labels
87        .iter()
88        .map(|&x| tensor::Tensor::one_hot(x as usize, numbers))
89        .collect())
90}
examples/mnist-fashion/feedback.rs (line 50)
40fn load_labels(file_path: &str, numbers: usize) -> Result<Vec<tensor::Tensor>> {
41    let mut reader = BufReader::new(File::open(file_path)?);
42    let _magic_number = read(&mut reader)?;
43    let num_labels = read(&mut reader)?;
44
45    let mut _labels = vec![0; num_labels as usize];
46    reader.read_exact(&mut _labels)?;
47
48    Ok(_labels
49        .iter()
50        .map(|&x| tensor::Tensor::one_hot(x as usize, numbers))
51        .collect())
52}
examples/mnist-fashion/looping.rs (line 51)
41fn load_labels(file_path: &str, numbers: usize) -> Result<Vec<tensor::Tensor>> {
42    let mut reader = BufReader::new(File::open(file_path)?);
43    let _magic_number = read(&mut reader)?;
44    let num_labels = read(&mut reader)?;
45
46    let mut _labels = vec![0; num_labels as usize];
47    reader.read_exact(&mut _labels)?;
48
49    Ok(_labels
50        .iter()
51        .map(|&x| tensor::Tensor::one_hot(x as usize, numbers))
52        .collect())
53}
Source

pub fn single(data: Vec<f32>) -> Self

Creates a new Tensor from the given vector.

Examples found in repository?
examples/timing/ftir-mlp.rs (line 39)
17fn data(
18    path: &str,
19) -> (
20    Vec<tensor::Tensor>,
21    Vec<tensor::Tensor>,
22    Vec<tensor::Tensor>,
23) {
24    let reader = BufReader::new(File::open(&path).unwrap());
25
26    let mut x: Vec<tensor::Tensor> = Vec::new();
27    let mut y: Vec<tensor::Tensor> = Vec::new();
28    let mut c: Vec<tensor::Tensor> = Vec::new();
29
30    for line in reader.lines().skip(1) {
31        let line = line.unwrap();
32        let record: Vec<&str> = line.split(',').collect();
33
34        let mut data: Vec<f32> = Vec::new();
35        for i in 0..571 {
36            data.push(record.get(i).unwrap().parse::<f32>().unwrap());
37        }
38
39        x.push(tensor::Tensor::single(data));
40        y.push(tensor::Tensor::single(vec![record
41            .get(571)
42            .unwrap()
43            .parse::<f32>()
44            .unwrap()]));
45        c.push(tensor::Tensor::one_hot(
46            record.get(572).unwrap().parse::<usize>().unwrap() - 1, // For zero-indexed.
47            28,
48        ));
49    }
50    (x, y, c)
51}
More examples
Hide additional examples
examples/timing/ftir-cnn.rs (lines 40-44)
17fn data(
18    path: &str,
19) -> (
20    Vec<tensor::Tensor>,
21    Vec<tensor::Tensor>,
22    Vec<tensor::Tensor>,
23) {
24    let reader = BufReader::new(File::open(&path).unwrap());
25
26    let mut x: Vec<tensor::Tensor> = Vec::new();
27    let mut y: Vec<tensor::Tensor> = Vec::new();
28    let mut c: Vec<tensor::Tensor> = Vec::new();
29
30    for line in reader.lines().skip(1) {
31        let line = line.unwrap();
32        let record: Vec<&str> = line.split(',').collect();
33
34        let mut data: Vec<f32> = Vec::new();
35        for i in 0..571 {
36            data.push(record.get(i).unwrap().parse::<f32>().unwrap());
37        }
38        let data: Vec<Vec<Vec<f32>>> = vec![vec![data]];
39        x.push(tensor::Tensor::triple(data));
40        y.push(tensor::Tensor::single(vec![record
41            .get(571)
42            .unwrap()
43            .parse::<f32>()
44            .unwrap()]));
45        c.push(tensor::Tensor::one_hot(
46            record.get(572).unwrap().parse::<usize>().unwrap() - 1, // For zero-indexed.
47            28,
48        ));
49    }
50
51    (x, y, c)
52}
examples/bike/feedback.rs (line 24)
10fn data(path: &str) -> (Vec<tensor::Tensor>, Vec<tensor::Tensor>) {
11    let reader = BufReader::new(File::open(&path).unwrap());
12
13    let mut x: Vec<tensor::Tensor> = Vec::new();
14    let mut y: Vec<tensor::Tensor> = Vec::new();
15
16    for line in reader.lines().skip(1) {
17        let line = line.unwrap();
18        let record: Vec<&str> = line.split(',').collect();
19
20        let mut data: Vec<f32> = Vec::new();
21        for i in 2..14 {
22            data.push(record.get(i).unwrap().parse::<f32>().unwrap());
23        }
24        x.push(tensor::Tensor::single(data));
25
26        y.push(tensor::Tensor::single(vec![record
27            .get(16)
28            .unwrap()
29            .parse::<f32>()
30            .unwrap()]));
31    }
32
33    let mut generator = random::Generator::create(12345);
34    let mut indices: Vec<usize> = (0..x.len()).collect();
35    generator.shuffle(&mut indices);
36
37    let x: Vec<tensor::Tensor> = indices.iter().map(|i| x[*i].clone()).collect();
38    let y: Vec<tensor::Tensor> = indices.iter().map(|i| y[*i].clone()).collect();
39
40    (x, y)
41}
examples/bike/looping.rs (line 25)
11fn data(path: &str) -> (Vec<tensor::Tensor>, Vec<tensor::Tensor>) {
12    let reader = BufReader::new(File::open(&path).unwrap());
13
14    let mut x: Vec<tensor::Tensor> = Vec::new();
15    let mut y: Vec<tensor::Tensor> = Vec::new();
16
17    for line in reader.lines().skip(1) {
18        let line = line.unwrap();
19        let record: Vec<&str> = line.split(',').collect();
20
21        let mut data: Vec<f32> = Vec::new();
22        for i in 2..14 {
23            data.push(record.get(i).unwrap().parse::<f32>().unwrap());
24        }
25        x.push(tensor::Tensor::single(data));
26
27        y.push(tensor::Tensor::single(vec![record
28            .get(16)
29            .unwrap()
30            .parse::<f32>()
31            .unwrap()]));
32    }
33
34    let mut generator = random::Generator::create(12345);
35    let mut indices: Vec<usize> = (0..x.len()).collect();
36    generator.shuffle(&mut indices);
37
38    let x: Vec<tensor::Tensor> = indices.iter().map(|i| x[*i].clone()).collect();
39    let y: Vec<tensor::Tensor> = indices.iter().map(|i| y[*i].clone()).collect();
40
41    (x, y)
42}
examples/bike/plain.rs (line 24)
10fn data(path: &str) -> (Vec<tensor::Tensor>, Vec<tensor::Tensor>) {
11    let reader = BufReader::new(File::open(&path).unwrap());
12
13    let mut x: Vec<tensor::Tensor> = Vec::new();
14    let mut y: Vec<tensor::Tensor> = Vec::new();
15
16    for line in reader.lines().skip(1) {
17        let line = line.unwrap();
18        let record: Vec<&str> = line.split(',').collect();
19
20        let mut data: Vec<f32> = Vec::new();
21        for i in 2..14 {
22            data.push(record.get(i).unwrap().parse::<f32>().unwrap());
23        }
24        x.push(tensor::Tensor::single(data));
25
26        y.push(tensor::Tensor::single(vec![record
27            .get(16)
28            .unwrap()
29            .parse::<f32>()
30            .unwrap()]));
31    }
32
33    let mut generator = random::Generator::create(12345);
34    let mut indices: Vec<usize> = (0..x.len()).collect();
35    generator.shuffle(&mut indices);
36
37    let x: Vec<tensor::Tensor> = indices.iter().map(|i| x[*i].clone()).collect();
38    let y: Vec<tensor::Tensor> = indices.iter().map(|i| y[*i].clone()).collect();
39
40    (x, y)
41}
examples/compare/bike.rs (line 62)
48fn data(path: &str) -> (Vec<tensor::Tensor>, Vec<tensor::Tensor>) {
49    let reader = BufReader::new(File::open(&path).unwrap());
50
51    let mut x: Vec<tensor::Tensor> = Vec::new();
52    let mut y: Vec<tensor::Tensor> = Vec::new();
53
54    for line in reader.lines().skip(1) {
55        let line = line.unwrap();
56        let record: Vec<&str> = line.split(',').collect();
57
58        let mut data: Vec<f32> = Vec::new();
59        for i in 2..14 {
60            data.push(record.get(i).unwrap().parse::<f32>().unwrap());
61        }
62        x.push(tensor::Tensor::single(data));
63
64        y.push(tensor::Tensor::single(vec![record
65            .get(16)
66            .unwrap()
67            .parse::<f32>()
68            .unwrap()]));
69    }
70
71    let mut generator = random::Generator::create(12345);
72    let mut indices: Vec<usize> = (0..x.len()).collect();
73    generator.shuffle(&mut indices);
74
75    let x: Vec<tensor::Tensor> = indices.iter().map(|i| x[*i].clone()).collect();
76    let y: Vec<tensor::Tensor> = indices.iter().map(|i| y[*i].clone()).collect();
77
78    (x, y)
79}
Source

pub fn double(data: Vec<Vec<f32>>) -> Self

Creates a new Tensor from the given two-dimensional vector.

Source

pub fn triple(data: Vec<Vec<Vec<f32>>>) -> Self

Creates a new Tensor from the given three-dimensional vector.

Examples found in repository?
examples/compare/mnist.rs (line 72)
52fn load_mnist(path: &str) -> Result<Vec<tensor::Tensor>> {
53    let mut reader = BufReader::new(File::open(path)?);
54    let mut images: Vec<tensor::Tensor> = Vec::new();
55
56    let _magic_number = read(&mut reader)?;
57    let num_images = read(&mut reader)?;
58    let num_rows = read(&mut reader)?;
59    let num_cols = read(&mut reader)?;
60
61    for _ in 0..num_images {
62        let mut image: Vec<Vec<f32>> = Vec::new();
63        for _ in 0..num_rows {
64            let mut row: Vec<f32> = Vec::new();
65            for _ in 0..num_cols {
66                let mut pixel = [0];
67                reader.read_exact(&mut pixel)?;
68                row.push(pixel[0] as f32 / 255.0);
69            }
70            image.push(row);
71        }
72        images.push(tensor::Tensor::triple(vec![image]));
73    }
74
75    Ok(images)
76}
More examples
Hide additional examples
examples/timing/mnist.rs (line 43)
23fn load_mnist(path: &str) -> Result<Vec<tensor::Tensor>> {
24    let mut reader = BufReader::new(File::open(path)?);
25    let mut images: Vec<tensor::Tensor> = Vec::new();
26
27    let _magic_number = read(&mut reader)?;
28    let num_images = read(&mut reader)?;
29    let num_rows = read(&mut reader)?;
30    let num_cols = read(&mut reader)?;
31
32    for _ in 0..num_images {
33        let mut image: Vec<Vec<f32>> = Vec::new();
34        for _ in 0..num_rows {
35            let mut row: Vec<f32> = Vec::new();
36            for _ in 0..num_cols {
37                let mut pixel = [0];
38                reader.read_exact(&mut pixel)?;
39                row.push(pixel[0] as f32 / 255.0);
40            }
41            image.push(row);
42        }
43        images.push(tensor::Tensor::triple(vec![image]));
44    }
45
46    Ok(images)
47}
examples/compare/fashion-1.rs (line 72)
52fn load_mnist(path: &str) -> Result<Vec<tensor::Tensor>> {
53    let mut reader = BufReader::new(File::open(path)?);
54    let mut images: Vec<tensor::Tensor> = Vec::new();
55
56    let _magic_number = read(&mut reader)?;
57    let num_images = read(&mut reader)?;
58    let num_rows = read(&mut reader)?;
59    let num_cols = read(&mut reader)?;
60
61    for _ in 0..num_images {
62        let mut image: Vec<Vec<f32>> = Vec::new();
63        for _ in 0..num_rows {
64            let mut row: Vec<f32> = Vec::new();
65            for _ in 0..num_cols {
66                let mut pixel = [0];
67                reader.read_exact(&mut pixel)?;
68                row.push(pixel[0] as f32 / 255.0);
69            }
70            image.push(row);
71        }
72        images.push(tensor::Tensor::triple(vec![image]).resize(tensor::Shape::Triple(1, 14, 14)));
73    }
74
75    Ok(images)
76}
examples/compare/fashion-2.rs (line 72)
52fn load_mnist(path: &str) -> Result<Vec<tensor::Tensor>> {
53    let mut reader = BufReader::new(File::open(path)?);
54    let mut images: Vec<tensor::Tensor> = Vec::new();
55
56    let _magic_number = read(&mut reader)?;
57    let num_images = read(&mut reader)?;
58    let num_rows = read(&mut reader)?;
59    let num_cols = read(&mut reader)?;
60
61    for _ in 0..num_images {
62        let mut image: Vec<Vec<f32>> = Vec::new();
63        for _ in 0..num_rows {
64            let mut row: Vec<f32> = Vec::new();
65            for _ in 0..num_cols {
66                let mut pixel = [0];
67                reader.read_exact(&mut pixel)?;
68                row.push(pixel[0] as f32 / 255.0);
69            }
70            image.push(row);
71        }
72        images.push(tensor::Tensor::triple(vec![image]).resize(tensor::Shape::Triple(1, 14, 14)));
73    }
74
75    Ok(images)
76}
examples/mnist-fashion/feedback.rs (line 34)
14fn load_mnist(path: &str) -> Result<Vec<tensor::Tensor>> {
15    let mut reader = BufReader::new(File::open(path)?);
16    let mut images: Vec<tensor::Tensor> = Vec::new();
17
18    let _magic_number = read(&mut reader)?;
19    let num_images = read(&mut reader)?;
20    let num_rows = read(&mut reader)?;
21    let num_cols = read(&mut reader)?;
22
23    for _ in 0..num_images {
24        let mut image: Vec<Vec<f32>> = Vec::new();
25        for _ in 0..num_rows {
26            let mut row: Vec<f32> = Vec::new();
27            for _ in 0..num_cols {
28                let mut pixel = [0];
29                reader.read_exact(&mut pixel)?;
30                row.push(pixel[0] as f32 / 255.0);
31            }
32            image.push(row);
33        }
34        images.push(tensor::Tensor::triple(vec![image]).resize(tensor::Shape::Triple(1, 14, 14)));
35    }
36
37    Ok(images)
38}
examples/mnist-fashion/looping.rs (line 35)
15fn load_mnist(path: &str) -> Result<Vec<tensor::Tensor>> {
16    let mut reader = BufReader::new(File::open(path)?);
17    let mut images: Vec<tensor::Tensor> = Vec::new();
18
19    let _magic_number = read(&mut reader)?;
20    let num_images = read(&mut reader)?;
21    let num_rows = read(&mut reader)?;
22    let num_cols = read(&mut reader)?;
23
24    for _ in 0..num_images {
25        let mut image: Vec<Vec<f32>> = Vec::new();
26        for _ in 0..num_rows {
27            let mut row: Vec<f32> = Vec::new();
28            for _ in 0..num_cols {
29                let mut pixel = [0];
30                reader.read_exact(&mut pixel)?;
31                row.push(pixel[0] as f32 / 255.0);
32            }
33            image.push(row);
34        }
35        images.push(tensor::Tensor::triple(vec![image]).resize(tensor::Shape::Triple(1, 14, 14)));
36    }
37
38    Ok(images)
39}
Source

pub fn quadruple(data: Vec<Vec<Vec<Vec<f32>>>>) -> Self

Creates a new Tensor from the given four-dimensional vector.

Source

pub fn quintuple(data: Vec<Vec<Vec<Vec<Vec<(usize, usize)>>>>>) -> Self

Creates a new Tensor from the given five-dimensional vector of (usize, usize) pairs.

Source

pub fn nested(tensors: Vec<Tensor>) -> Self

Convert a vector of Tensors into a single nested Tensor.

Source

pub fn nestedoptional(tensors: Vec<Option<Tensor>>) -> Self

Convert a vector of optional Tensors into a single nested Tensor.

Source

pub fn unnested(&self) -> Vec<Tensor>

Convert a single nested Tensor into a vector of Tensors.

Source

pub fn unnestedoptional(&self) -> Vec<Option<Tensor>>

Convert a single nested Tensor into a vector of optional Tensors.

Source

pub fn flatten(&self) -> Self

Flatten the Tensor’s data. Returns a new Tensor with the updated shape and data.

§Returns

A new Tensor with the same data but in a vector format.

Source

pub fn get_flat(&self) -> Vec<f32>

Flatten the Tensor’s data into a vector.

§Returns

A flat Vec<f32> containing the Tensor's values.

Source

pub fn as_triple(&self) -> &Vec<Vec<Vec<f32>>>

Get a reference to the data of the Tensor as a three-dimensional vector.

§Returns

A reference to the data of the Tensor as a three-dimensional vector (Vec<Vec<Vec<f32>>>).

Source

pub fn get_triple(&self, outputs: &Shape) -> Vec<Vec<Vec<f32>>>

Get a copy of the data of the Tensor as a three-dimensional vector.

§Arguments
  • outputs - The output shape of the data.
§Returns

The data of the Tensor as a three-dimensional vector (Vec<Vec<Vec<f32>>>).

§Notes

The outputs shape is needed when the Tensor's data is a one-dimensional vector: it determines how that vector is reshaped into the correct three-dimensional shape.

Source

pub fn quadruple_to_vec_triple(&self) -> Vec<Tensor>

Get the data of the Shape::Quadruple Tensor as a vector of Shape::Triple Tensors.

Source

pub fn reshape(self, shape: Shape) -> Self

Reshape a Tensor into the given shape.

§Arguments
  • shape - The new shape of the Tensor.
§Returns

A new Tensor with the given shape.

Source

pub fn resize(&self, shape: Shape) -> Self

Resize the Tensor to the given shape. Average pooling is used to resize the Tensor.

Examples found in repository?
examples/compare/fashion-1.rs (line 72)
52fn load_mnist(path: &str) -> Result<Vec<tensor::Tensor>> {
53    let mut reader = BufReader::new(File::open(path)?);
54    let mut images: Vec<tensor::Tensor> = Vec::new();
55
56    let _magic_number = read(&mut reader)?;
57    let num_images = read(&mut reader)?;
58    let num_rows = read(&mut reader)?;
59    let num_cols = read(&mut reader)?;
60
61    for _ in 0..num_images {
62        let mut image: Vec<Vec<f32>> = Vec::new();
63        for _ in 0..num_rows {
64            let mut row: Vec<f32> = Vec::new();
65            for _ in 0..num_cols {
66                let mut pixel = [0];
67                reader.read_exact(&mut pixel)?;
68                row.push(pixel[0] as f32 / 255.0);
69            }
70            image.push(row);
71        }
72        images.push(tensor::Tensor::triple(vec![image]).resize(tensor::Shape::Triple(1, 14, 14)));
73    }
74
75    Ok(images)
76}
More examples
Hide additional examples
examples/compare/fashion-2.rs (line 72)
52fn load_mnist(path: &str) -> Result<Vec<tensor::Tensor>> {
53    let mut reader = BufReader::new(File::open(path)?);
54    let mut images: Vec<tensor::Tensor> = Vec::new();
55
56    let _magic_number = read(&mut reader)?;
57    let num_images = read(&mut reader)?;
58    let num_rows = read(&mut reader)?;
59    let num_cols = read(&mut reader)?;
60
61    for _ in 0..num_images {
62        let mut image: Vec<Vec<f32>> = Vec::new();
63        for _ in 0..num_rows {
64            let mut row: Vec<f32> = Vec::new();
65            for _ in 0..num_cols {
66                let mut pixel = [0];
67                reader.read_exact(&mut pixel)?;
68                row.push(pixel[0] as f32 / 255.0);
69            }
70            image.push(row);
71        }
72        images.push(tensor::Tensor::triple(vec![image]).resize(tensor::Shape::Triple(1, 14, 14)));
73    }
74
75    Ok(images)
76}
examples/mnist-fashion/feedback.rs (line 34)
14fn load_mnist(path: &str) -> Result<Vec<tensor::Tensor>> {
15    let mut reader = BufReader::new(File::open(path)?);
16    let mut images: Vec<tensor::Tensor> = Vec::new();
17
18    let _magic_number = read(&mut reader)?;
19    let num_images = read(&mut reader)?;
20    let num_rows = read(&mut reader)?;
21    let num_cols = read(&mut reader)?;
22
23    for _ in 0..num_images {
24        let mut image: Vec<Vec<f32>> = Vec::new();
25        for _ in 0..num_rows {
26            let mut row: Vec<f32> = Vec::new();
27            for _ in 0..num_cols {
28                let mut pixel = [0];
29                reader.read_exact(&mut pixel)?;
30                row.push(pixel[0] as f32 / 255.0);
31            }
32            image.push(row);
33        }
34        images.push(tensor::Tensor::triple(vec![image]).resize(tensor::Shape::Triple(1, 14, 14)));
35    }
36
37    Ok(images)
38}
examples/mnist-fashion/looping.rs (line 35)
15fn load_mnist(path: &str) -> Result<Vec<tensor::Tensor>> {
16    let mut reader = BufReader::new(File::open(path)?);
17    let mut images: Vec<tensor::Tensor> = Vec::new();
18
19    let _magic_number = read(&mut reader)?;
20    let num_images = read(&mut reader)?;
21    let num_rows = read(&mut reader)?;
22    let num_cols = read(&mut reader)?;
23
24    for _ in 0..num_images {
25        let mut image: Vec<Vec<f32>> = Vec::new();
26        for _ in 0..num_rows {
27            let mut row: Vec<f32> = Vec::new();
28            for _ in 0..num_cols {
29                let mut pixel = [0];
30                reader.read_exact(&mut pixel)?;
31                row.push(pixel[0] as f32 / 255.0);
32            }
33            image.push(row);
34        }
35        images.push(tensor::Tensor::triple(vec![image]).resize(tensor::Shape::Triple(1, 14, 14)));
36    }
37
38    Ok(images)
39}
examples/mnist-fashion/plain.rs (line 35)
15fn load_mnist(path: &str) -> Result<Vec<tensor::Tensor>> {
16    let mut reader = BufReader::new(File::open(path)?);
17    let mut images: Vec<tensor::Tensor> = Vec::new();
18
19    let _magic_number = read(&mut reader)?;
20    let num_images = read(&mut reader)?;
21    let num_rows = read(&mut reader)?;
22    let num_cols = read(&mut reader)?;
23
24    for _ in 0..num_images {
25        let mut image: Vec<Vec<f32>> = Vec::new();
26        for _ in 0..num_rows {
27            let mut row: Vec<f32> = Vec::new();
28            for _ in 0..num_cols {
29                let mut pixel = [0];
30                reader.read_exact(&mut pixel)?;
31                row.push(pixel[0] as f32 / 255.0);
32            }
33            image.push(row);
34        }
35        images.push(tensor::Tensor::triple(vec![image]).resize(tensor::Shape::Triple(1, 14, 14)));
36    }
37
38    Ok(images)
39}
examples/mnist-fashion/skip.rs (line 34)
14fn load_mnist(path: &str) -> Result<Vec<tensor::Tensor>> {
15    let mut reader = BufReader::new(File::open(path)?);
16    let mut images: Vec<tensor::Tensor> = Vec::new();
17
18    let _magic_number = read(&mut reader)?;
19    let num_images = read(&mut reader)?;
20    let num_rows = read(&mut reader)?;
21    let num_cols = read(&mut reader)?;
22
23    for _ in 0..num_images {
24        let mut image: Vec<Vec<f32>> = Vec::new();
25        for _ in 0..num_rows {
26            let mut row: Vec<f32> = Vec::new();
27            for _ in 0..num_cols {
28                let mut pixel = [0];
29                reader.read_exact(&mut pixel)?;
30                row.push(pixel[0] as f32 / 255.0);
31            }
32            image.push(row);
33        }
34        images.push(tensor::Tensor::triple(vec![image]).resize(tensor::Shape::Triple(1, 14, 14)));
35    }
36
37    Ok(images)
38}
Source

pub fn argmax(&self) -> usize

Get the index of the maximum value in the Tensor.

Examples found in repository?
examples/iris.rs (line 120)
49fn main() {
50    // Load the iris dataset
51    let (x, y) = data("./examples/datasets/iris.csv");
52
53    let split = (x.len() as f32 * 0.8) as usize;
54    let x = x.split_at(split);
55    let y = y.split_at(split);
56
57    let x_train: Vec<&tensor::Tensor> = x.0.iter().collect();
58    let y_train: Vec<&tensor::Tensor> = y.0.iter().collect();
59    let x_test: Vec<&tensor::Tensor> = x.1.iter().collect();
60    let y_test: Vec<&tensor::Tensor> = y.1.iter().collect();
61
62    println!(
63        "Train data {}x{}: {} => {}",
64        x_train.len(),
65        x_train[0].shape,
66        x_train[0].data,
67        y_train[0].data
68    );
69    println!(
70        "Test data {}x{}: {} => {}",
71        x_test.len(),
72        x_test[0].shape,
73        x_test[0].data,
74        y_test[0].data
75    );
76
77    // Create the network
78    let mut network = network::Network::new(tensor::Shape::Single(4));
79
80    network.dense(50, activation::Activation::ReLU, false, None);
81    network.dense(50, activation::Activation::ReLU, false, None);
82    network.dense(3, activation::Activation::Softmax, false, None);
83
84    network.set_optimizer(optimizer::RMSprop::create(
85        0.0001,     // Learning rate
86        0.0,        // Alpha
87        1e-8,       // Epsilon
88        Some(0.01), // Decay
89        Some(0.01), // Momentum
90        true,       // Centered
91    ));
92    network.set_objective(
93        objective::Objective::CrossEntropy, // Objective function
94        Some((-1f32, 1f32)),                // Gradient clipping
95    );
96
97    // Train the network
98    let (_train_loss, _val_loss, _val_acc) = network.learn(
99        &x_train,
100        &y_train,
101        Some((&x_test, &y_test, 5)),
102        1,
103        5,
104        Some(1),
105    );
106
107    // Validate the network
108    let (val_loss, val_acc) = network.validate(&x_test, &y_test, 1e-6);
109    println!(
110        "Final validation accuracy: {:.2} % and loss: {:.5}",
111        val_acc * 100.0,
112        val_loss
113    );
114
115    // Use the network
116    let prediction = network.predict(x_test.get(0).unwrap());
117    println!(
118        "Prediction on input: {}. Target: {}. Output: {}.",
119        x_test[0].data,
120        y_test[0].argmax(),
121        prediction.argmax()
122    );
123}
More examples
Hide additional examples
examples/mnist-fashion/skip.rs (line 146)
54fn main() {
55    let x_train = load_mnist("./examples/datasets/mnist-fashion/train-images-idx3-ubyte").unwrap();
56    let y_train = load_labels(
57        "./examples/datasets/mnist-fashion/train-labels-idx1-ubyte",
58        10,
59    )
60    .unwrap();
61    let x_test = load_mnist("./examples/datasets/mnist-fashion/t10k-images-idx3-ubyte").unwrap();
62    let y_test = load_labels(
63        "./examples/datasets/mnist-fashion/t10k-labels-idx1-ubyte",
64        10,
65    )
66    .unwrap();
67    println!(
68        "Train: {} images, Test: {} images",
69        x_train.len(),
70        x_test.len()
71    );
72
73    let x_train: Vec<&tensor::Tensor> = x_train.iter().collect();
74    let y_train: Vec<&tensor::Tensor> = y_train.iter().collect();
75    let x_test: Vec<&tensor::Tensor> = x_test.iter().collect();
76    let y_test: Vec<&tensor::Tensor> = y_test.iter().collect();
77
78    let mut network = network::Network::new(tensor::Shape::Triple(1, 14, 14));
79
80    network.convolution(
81        1,
82        (3, 3),
83        (1, 1),
84        (1, 1),
85        (1, 1),
86        activation::Activation::ReLU,
87        None,
88    );
89    network.convolution(
90        1,
91        (3, 3),
92        (1, 1),
93        (1, 1),
94        (1, 1),
95        activation::Activation::ReLU,
96        None,
97    );
98    network.convolution(
99        1,
100        (3, 3),
101        (1, 1),
102        (1, 1),
103        (1, 1),
104        activation::Activation::ReLU,
105        None,
106    );
107    network.maxpool((2, 2), (2, 2));
108    network.dense(10, activation::Activation::Softmax, true, None);
109
110    network.connect(0, 2);
111
112    network.set_optimizer(optimizer::Adam::create(0.001, 0.9, 0.999, 1e-8, None));
113    network.set_objective(objective::Objective::CrossEntropy, None);
114
115    println!("{}", network);
116
117    // Train the network
118    let (train_loss, val_loss, val_acc) = network.learn(
119        &x_train,
120        &y_train,
121        Some((&x_test, &y_test, 10)),
122        32,
123        25,
124        Some(5),
125    );
126    plot::loss(
127        &train_loss,
128        &val_loss,
129        &val_acc,
130        "SKIP : Fashion-MNIST",
131        "./output/mnist-fashion/skip.png",
132    );
133
134    // Validate the network
135    let (val_loss, val_acc) = network.validate(&x_test, &y_test, 1e-6);
136    println!(
137        "Final validation accuracy: {:.2} % and loss: {:.5}",
138        val_acc * 100.0,
139        val_loss
140    );
141
142    // Use the network
143    let prediction = network.predict(x_test.get(0).unwrap());
144    println!(
145        "Prediction on input: Target: {}. Output: {}.",
146        y_test[0].argmax(),
147        prediction.argmax()
148    );
149
150    // let x = x_test.get(5).unwrap();
151    // let y = y_test.get(5).unwrap();
152    // Plot the pre- and post-activation heatmaps for each (image) layer.
153    // let (pre, post, _) = network.forward(x);
154    // for (i, (i_pre, i_post)) in pre.iter().zip(post.iter()).enumerate() {
155    //     let pre_title = format!("layer_{}_pre", i);
156    //     let post_title = format!("layer_{}_post", i);
157    //     let pre_file = format!("layer_{}_pre.png", i);
158    //     let post_file = format!("layer_{}_post.png", i);
159    //     plot::heatmap(&i_pre, &pre_title, &pre_file);
160    //     plot::heatmap(&i_post, &post_title, &post_file);
161    // }
162}
examples/mnist/plain.rs (line 139)
54fn main() {
55    let x_train = load_mnist("./examples/datasets/mnist/train-images-idx3-ubyte").unwrap();
56    let y_train = load_labels("./examples/datasets/mnist/train-labels-idx1-ubyte", 10).unwrap();
57    let x_test = load_mnist("./examples/datasets/mnist/t10k-images-idx3-ubyte").unwrap();
58    let y_test = load_labels("./examples/datasets/mnist/t10k-labels-idx1-ubyte", 10).unwrap();
59    println!(
60        "Train: {} images, Test: {} images",
61        x_train.len(),
62        x_test.len()
63    );
64
65    let x_train: Vec<&tensor::Tensor> = x_train.iter().collect();
66    let y_train: Vec<&tensor::Tensor> = y_train.iter().collect();
67    let x_test: Vec<&tensor::Tensor> = x_test.iter().collect();
68    let y_test: Vec<&tensor::Tensor> = y_test.iter().collect();
69
70    let mut network = network::Network::new(tensor::Shape::Triple(1, 14, 14));
71
72    network.convolution(
73        1,
74        (3, 3),
75        (1, 1),
76        (1, 1),
77        (1, 1),
78        activation::Activation::ReLU,
79        None,
80    );
81    network.convolution(
82        1,
83        (3, 3),
84        (1, 1),
85        (1, 1),
86        (1, 1),
87        activation::Activation::ReLU,
88        None,
89    );
90    network.convolution(
91        1,
92        (3, 3),
93        (1, 1),
94        (1, 1),
95        (1, 1),
96        activation::Activation::ReLU,
97        None,
98    );
99    network.maxpool((2, 2), (2, 2));
100    network.dense(10, activation::Activation::Softmax, true, None);
101
102    network.set_optimizer(optimizer::Adam::create(0.001, 0.9, 0.999, 1e-8, None));
103    network.set_objective(
104        objective::Objective::CrossEntropy, // Objective function
105        None,                               // Gradient clipping
106    );
107
108    println!("{}", network);
109
110    // Train the network
111    let (train_loss, val_loss, val_acc) = network.learn(
112        &x_train,
113        &y_train,
114        Some((&x_test, &y_test, 10)),
115        32,
116        25,
117        Some(5),
118    );
119    plot::loss(
120        &train_loss,
121        &val_loss,
122        &val_acc,
123        "PLAIN : MNIST",
124        "./output/mnist/plain.png",
125    );
126
127    // Validate the network
128    let (val_loss, val_acc) = network.validate(&x_test, &y_test, 1e-6);
129    println!(
130        "Final validation accuracy: {:.2} % and loss: {:.5}",
131        val_acc * 100.0,
132        val_loss
133    );
134
135    // Use the network
136    let prediction = network.predict(x_test.get(0).unwrap());
137    println!(
138        "Prediction on input: Target: {}. Output: {}.",
139        y_test[0].argmax(),
140        prediction.argmax()
141    );
142
143    // let x = x_test.get(5).unwrap();
144    // let y = y_test.get(5).unwrap();
145    // plot::heatmap(
146    //     &x,
147    //     &format!("Target: {}", y.argmax()),
148    //     "./output/mnist/input.png",
149    // );
150
151    // Plot the pre- and post-activation heatmaps for each (image) layer.
152    // let (pre, post, _) = network.forward(x);
153    // for (i, (i_pre, i_post)) in pre.iter().zip(post.iter()).enumerate() {
154    //     let pre_title = format!("layer_{}_pre", i);
155    //     let post_title = format!("layer_{}_post", i);
156    //     let pre_file = format!("layer_{}_pre.png", i);
157    //     let post_file = format!("layer_{}_post.png", i);
158    //     plot::heatmap(&i_pre, &pre_title, &pre_file);
159    //     plot::heatmap(&i_post, &post_title, &post_file);
160    // }
161}
examples/mnist/skip.rs (line 141)
54fn main() {
55    let x_train = load_mnist("./examples/datasets/mnist/train-images-idx3-ubyte").unwrap();
56    let y_train = load_labels("./examples/datasets/mnist/train-labels-idx1-ubyte", 10).unwrap();
57    let x_test = load_mnist("./examples/datasets/mnist/t10k-images-idx3-ubyte").unwrap();
58    let y_test = load_labels("./examples/datasets/mnist/t10k-labels-idx1-ubyte", 10).unwrap();
59    println!(
60        "Train: {} images, Test: {} images",
61        x_train.len(),
62        x_test.len()
63    );
64
65    let x_train: Vec<&tensor::Tensor> = x_train.iter().collect();
66    let y_train: Vec<&tensor::Tensor> = y_train.iter().collect();
67    let x_test: Vec<&tensor::Tensor> = x_test.iter().collect();
68    let y_test: Vec<&tensor::Tensor> = y_test.iter().collect();
69
70    let mut network = network::Network::new(tensor::Shape::Triple(1, 14, 14));
71
72    network.convolution(
73        1,
74        (3, 3),
75        (1, 1),
76        (1, 1),
77        (1, 1),
78        activation::Activation::ReLU,
79        None,
80    );
81    network.convolution(
82        1,
83        (3, 3),
84        (1, 1),
85        (1, 1),
86        (1, 1),
87        activation::Activation::ReLU,
88        None,
89    );
90    network.convolution(
91        1,
92        (3, 3),
93        (1, 1),
94        (1, 1),
95        (1, 1),
96        activation::Activation::ReLU,
97        None,
98    );
99    network.maxpool((2, 2), (2, 2));
100    network.dense(10, activation::Activation::Softmax, true, None);
101
102    network.connect(0, 3);
103
104    network.set_optimizer(optimizer::Adam::create(0.001, 0.9, 0.999, 1e-8, None));
105    network.set_objective(
106        objective::Objective::CrossEntropy, // Objective function
107        None,                               // Gradient clipping
108    );
109
110    println!("{}", network);
111
112    // Train the network
113    let (train_loss, val_loss, val_acc) = network.learn(
114        &x_train,
115        &y_train,
116        Some((&x_test, &y_test, 10)),
117        32,
118        25,
119        Some(5),
120    );
121    plot::loss(
122        &train_loss,
123        &val_loss,
124        &val_acc,
125        "SKIP : MNIST",
126        "./output/mnist/skip.png",
127    );
128
129    // Validate the network
130    let (val_loss, val_acc) = network.validate(&x_test, &y_test, 1e-6);
131    println!(
132        "Final validation accuracy: {:.2} % and loss: {:.5}",
133        val_acc * 100.0,
134        val_loss
135    );
136
137    // Use the network
138    let prediction = network.predict(x_test.get(0).unwrap());
139    println!(
140        "Prediction on input: Target: {}. Output: {}.",
141        y_test[0].argmax(),
142        prediction.argmax()
143    );
144
145    // let x = x_test.get(5).unwrap();
146    // let y = y_test.get(5).unwrap();
147    // plot::heatmap(
148    //     &x,
149    //     &format!("Target: {}", y.argmax()),
150    //     "./output/mnist/input.png",
151    // );
152
153    // Plot the pre- and post-activation heatmaps for each (image) layer.
154    // let (pre, post, _) = network.forward(x);
155    // for (i, (i_pre, i_post)) in pre.iter().zip(post.iter()).enumerate() {
156    //     let pre_title = format!("layer_{}_pre", i);
157    //     let post_title = format!("layer_{}_post", i);
158    //     let pre_file = format!("layer_{}_pre.png", i);
159    //     let post_file = format!("layer_{}_post.png", i);
160    //     plot::heatmap(&i_pre, &pre_title, &pre_file);
161    //     plot::heatmap(&i_post, &post_title, &post_file);
162    // }
163}
examples/mnist/deconvolution.rs (line 139)
54fn main() {
55    let x_train = load_mnist("./examples/datasets/mnist/train-images-idx3-ubyte").unwrap();
56    let y_train = load_labels("./examples/datasets/mnist/train-labels-idx1-ubyte", 10).unwrap();
57    let x_test = load_mnist("./examples/datasets/mnist/t10k-images-idx3-ubyte").unwrap();
58    let y_test = load_labels("./examples/datasets/mnist/t10k-labels-idx1-ubyte", 10).unwrap();
59    println!(
60        "Train: {} images, Test: {} images",
61        x_train.len(),
62        x_test.len()
63    );
64
65    let x_train: Vec<&tensor::Tensor> = x_train.iter().collect();
66    let y_train: Vec<&tensor::Tensor> = y_train.iter().collect();
67    let x_test: Vec<&tensor::Tensor> = x_test.iter().collect();
68    let y_test: Vec<&tensor::Tensor> = y_test.iter().collect();
69
70    let mut network = network::Network::new(tensor::Shape::Triple(1, 14, 14));
71
72    network.convolution(
73        1,
74        (3, 3),
75        (1, 1),
76        (0, 0),
77        (1, 1),
78        activation::Activation::ReLU,
79        None,
80    );
81    network.maxpool((2, 2), (2, 2));
82    network.convolution(
83        4,
84        (3, 3),
85        (1, 1),
86        (0, 0),
87        (1, 1),
88        activation::Activation::ReLU,
89        None,
90    );
91    network.deconvolution(
92        4,
93        (3, 3),
94        (1, 1),
95        (0, 0),
96        activation::Activation::ReLU,
97        None,
98    );
99    network.maxpool((2, 2), (2, 2));
100    network.dense(10, activation::Activation::Softmax, true, None);
101
102    network.set_optimizer(optimizer::Adam::create(0.001, 0.9, 0.999, 1e-8, None));
103    network.set_objective(
104        objective::Objective::CrossEntropy, // Objective function
105        None,                               // Gradient clipping
106    );
107
108    println!("{}", network);
109
110    // Train the network
111    let (train_loss, val_loss, val_acc) = network.learn(
112        &x_train,
113        &y_train,
114        Some((&x_test, &y_test, 10)),
115        32,
116        25,
117        Some(5),
118    );
119    plot::loss(
120        &train_loss,
121        &val_loss,
122        &val_acc,
123        "DECONV : MNIST",
124        "./output/mnist/deconvolution.png",
125    );
126
127    // Validate the network
128    let (val_loss, val_acc) = network.validate(&x_test, &y_test, 1e-6);
129    println!(
130        "Final validation accuracy: {:.2} % and loss: {:.5}",
131        val_acc * 100.0,
132        val_loss
133    );
134
135    // Use the network
136    let prediction = network.predict(x_test.get(0).unwrap());
137    println!(
138        "Prediction on input: Target: {}. Output: {}.",
139        y_test[0].argmax(),
140        prediction.argmax()
141    );
142
143    // let x = x_test.get(5).unwrap();
144    // let y = y_test.get(5).unwrap();
145    // plot::heatmap(
146    //     &x,
147    //     &format!("Target: {}", y.argmax()),
148    //     "./output/mnist/input.png",
149    // );
150
151    // Plot the pre- and post-activation heatmaps for each (image) layer.
152    // let (pre, post, _) = network.forward(x);
153    // for (i, (i_pre, i_post)) in pre.iter().zip(post.iter()).enumerate() {
154    //     let pre_title = format!("layer_{}_pre", i);
155    //     let post_title = format!("layer_{}_post", i);
156    //     let pre_file = format!("layer_{}_pre.png", i);
157    //     let post_file = format!("layer_{}_post.png", i);
158    //     plot::heatmap(&i_pre, &pre_title, &pre_file);
159    //     plot::heatmap(&i_post, &post_title, &post_file);
160    // }
161}
examples/mnist-fashion/feedback.rs (line 154)
54fn main() {
55    let x_train = load_mnist("./examples/datasets/mnist-fashion/train-images-idx3-ubyte").unwrap();
56    let y_train = load_labels(
57        "./examples/datasets/mnist-fashion/train-labels-idx1-ubyte",
58        10,
59    )
60    .unwrap();
61    let x_test = load_mnist("./examples/datasets/mnist-fashion/t10k-images-idx3-ubyte").unwrap();
62    let y_test = load_labels(
63        "./examples/datasets/mnist-fashion/t10k-labels-idx1-ubyte",
64        10,
65    )
66    .unwrap();
67    println!(
68        "Train: {} images, Test: {} images",
69        x_train.len(),
70        x_test.len()
71    );
72
73    let x_train: Vec<&tensor::Tensor> = x_train.iter().collect();
74    let y_train: Vec<&tensor::Tensor> = y_train.iter().collect();
75    let x_test: Vec<&tensor::Tensor> = x_test.iter().collect();
76    let y_test: Vec<&tensor::Tensor> = y_test.iter().collect();
77
78    let mut network = network::Network::new(tensor::Shape::Triple(1, 14, 14));
79
80    network.convolution(
81        1,
82        (3, 3),
83        (1, 1),
84        (1, 1),
85        (1, 1),
86        activation::Activation::ReLU,
87        None,
88    );
89    network.feedback(
90        vec![feedback::Layer::Convolution(
91            1,
92            activation::Activation::ReLU,
93            (3, 3),
94            (1, 1),
95            (1, 1),
96            (1, 1),
97            None,
98        )],
99        3,
100        false,
101        false,
102        feedback::Accumulation::Mean,
103    );
104    network.convolution(
105        1,
106        (3, 3),
107        (1, 1),
108        (1, 1),
109        (1, 1),
110        activation::Activation::ReLU,
111        None,
112    );
113    network.maxpool((2, 2), (2, 2));
114    network.dense(10, activation::Activation::Softmax, true, None);
115
116    // Include skip connection bypassing the feedback block
117    // network.connect(1, 2);
118    // network.set_accumulation(feedback::Accumulation::Add);
119
120    network.set_optimizer(optimizer::Adam::create(0.001, 0.9, 0.999, 1e-8, None));
121    network.set_objective(objective::Objective::CrossEntropy, None);
122
123    println!("{}", network);
124
125    // Train the network
126    let (train_loss, val_loss, val_acc) = network.learn(
127        &x_train,
128        &y_train,
129        Some((&x_test, &y_test, 10)),
130        32,
131        25,
132        Some(5),
133    );
134    plot::loss(
135        &train_loss,
136        &val_loss,
137        &val_acc,
138        "FEEDBACK : Fashion-MNIST",
139        "./output/mnist-fashion/feedback.png",
140    );
141
142    // Validate the network
143    let (val_loss, val_acc) = network.validate(&x_test, &y_test, 1e-6);
144    println!(
145        "Final validation accuracy: {:.2} % and loss: {:.5}",
146        val_acc * 100.0,
147        val_loss
148    );
149
150    // Use the network
151    let prediction = network.predict(x_test.get(0).unwrap());
152    println!(
153        "Prediction on input: Target: {}. Output: {}.",
154        y_test[0].argmax(),
155        prediction.argmax()
156    );
157
158    // let x = x_test.get(5).unwrap();
159    // let y = y_test.get(5).unwrap();
160    // Plot the pre- and post-activation heatmaps for each (image) layer.
161    // let (pre, post, _) = network.forward(x);
162    // for (i, (i_pre, i_post)) in pre.iter().zip(post.iter()).enumerate() {
163    //     let pre_title = format!("layer_{}_pre", i);
164    //     let post_title = format!("layer_{}_post", i);
165    //     let pre_file = format!("layer_{}_pre.png", i);
166    //     let post_file = format!("layer_{}_post.png", i);
167    //     plot::heatmap(&i_pre, &pre_title, &pre_file);
168    //     plot::heatmap(&i_post, &post_title, &post_file);
169    // }
170}
Source

pub fn add_inplace(&mut self, other: &Tensor)

In-place element-wise addition of two Tensors. Validates their shapes beforehand.

Source

pub fn sub_inplace(&mut self, other: &Tensor)

In-place element-wise subtraction of two Tensors. Validates their shapes beforehand.

Source

pub fn mul_inplace(&mut self, other: &Tensor)

In-place element-wise multiplication of two Tensors. Validates their shapes beforehand.

Source

pub fn div_scalar_inplace(&mut self, scalar: f32)

Source

pub fn mean_inplace(&mut self, others: &Vec<&Tensor>)

In-place element-wise mean involving the Tensors in others; the result is written to this Tensor. Validates their shapes beforehand.

Source

pub fn hadamard(&mut self, other: &Tensor, scalar: f32)

Element-wise (Hadamard) product of two Tensors of the same shape: elements at matching indices are multiplied together. Validates their shapes beforehand.

§Arguments
  • other - The Tensor to multiply with.
  • scalar - A scalar value to multiply the result by (e.g., 1.0 / self.loops).
Source

pub fn product(&self, other: &Tensor) -> Self

Outer product of two Tensors.

Source

pub fn dot(&self, other: &Tensor) -> Self

Dot product of two Tensors. This Tensor must be Shape::Double and the other Tensor must be Shape::Single. The number of columns of this Tensor must equal the number of rows (length) of the other Tensor.

Source

pub fn dropout(&mut self, dropout: f32)

Randomly set elements of the Tensor to zero, each with probability dropout.

Source

pub fn clamp(self, min: f32, max: f32) -> Self

Clamp the values of the Tensor to a given interval [min, max].

Source

pub fn transpose(&self) -> Self

Transpose the Tensor.

Source

pub fn extend(&mut self, other: &Tensor)

Trait Implementations§

Source§

impl Clone for Tensor

Source§

fn clone(&self) -> Tensor

Returns a duplicate of the value. Read more
1.0.0 · Source§

fn clone_from(&mut self, source: &Self)

Performs copy-assignment from source. Read more
Source§

impl Debug for Tensor

Source§

fn fmt(&self, f: &mut Formatter<'_>) -> Result

Formats the value using the given formatter. Read more
Source§

impl Display for Tensor

Source§

fn fmt(&self, f: &mut Formatter<'_>) -> Result

Formats the value using the given formatter. Read more

Auto Trait Implementations§

§

impl Freeze for Tensor

§

impl RefUnwindSafe for Tensor

§

impl Send for Tensor

§

impl Sync for Tensor

§

impl Unpin for Tensor

§

impl UnwindSafe for Tensor

Blanket Implementations§

Source§

impl<T> Any for T
where T: 'static + ?Sized,

Source§

fn type_id(&self) -> TypeId

Gets the TypeId of self. Read more
Source§

impl<T> Borrow<T> for T
where T: ?Sized,

Source§

fn borrow(&self) -> &T

Immutably borrows from an owned value. Read more
Source§

impl<T> BorrowMut<T> for T
where T: ?Sized,

Source§

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value. Read more
Source§

impl<T> CloneToUninit for T
where T: Clone,

Source§

unsafe fn clone_to_uninit(&self, dest: *mut u8)

🔬This is a nightly-only experimental API. (clone_to_uninit)
Performs copy-assignment from self to dest. Read more
Source§

impl<T> From<T> for T

Source§

fn from(t: T) -> T

Returns the argument unchanged.

Source§

impl<T, U> Into<U> for T
where U: From<T>,

Source§

fn into(self) -> U

Calls U::from(self).

That is, this conversion is whatever the implementation of From<T> for U chooses to do.

Source§

impl<T> IntoEither for T

Source§

fn into_either(self, into_left: bool) -> Either<Self, Self>

Converts self into a Left variant of Either<Self, Self> if into_left is true. Converts self into a Right variant of Either<Self, Self> otherwise. Read more
Source§

fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
where F: FnOnce(&Self) -> bool,

Converts self into a Left variant of Either<Self, Self> if into_left(&self) returns true. Converts self into a Right variant of Either<Self, Self> otherwise. Read more
Source§

impl<T> Pointable for T

Source§

const ALIGN: usize

The alignment of the pointer.
Source§

type Init = T

The type for initializers.
Source§

unsafe fn init(init: <T as Pointable>::Init) -> usize

Initializes an object with the given initializer. Read more
Source§

unsafe fn deref<'a>(ptr: usize) -> &'a T

Dereferences the given pointer. Read more
Source§

unsafe fn deref_mut<'a>(ptr: usize) -> &'a mut T

Mutably dereferences the given pointer. Read more
Source§

unsafe fn drop(ptr: usize)

Drops the object pointed to by the given pointer. Read more
Source§

impl<T> ToOwned for T
where T: Clone,

Source§

type Owned = T

The resulting type after obtaining ownership.
Source§

fn to_owned(&self) -> T

Creates owned data from borrowed data, usually by cloning. Read more
Source§

fn clone_into(&self, target: &mut T)

Uses borrowed data to replace owned data, usually by cloning. Read more
Source§

impl<T> ToString for T
where T: Display + ?Sized,

Source§

fn to_string(&self) -> String

Converts the given value to a String. Read more
Source§

impl<T, U> TryFrom<U> for T
where U: Into<T>,

Source§

type Error = Infallible

The type returned in the event of a conversion error.
Source§

fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>

Performs the conversion.
Source§

impl<T, U> TryInto<U> for T
where U: TryFrom<T>,

Source§

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.
Source§

fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>

Performs the conversion.