pub struct Tensor {
pub shape: Shape,
pub data: Data,
}
Fields
- shape: Shape
- data: Data

Implementations
impl Tensor
pub fn random(shape: Shape, min: f32, max: f32) -> Self
Creates a new Tensor with the given shape, filled with random values.
Arguments
- `shape` – The shape of the Tensor.
- `min` – The minimum value of the random values.
- `max` – The maximum value of the random values.
Returns
A new Tensor with the given shape, filled with random values.
Examples found in repository?
5fn main() {
6 let mut network = network::Network::new(tensor::Shape::Triple(1, 24, 24));
7
8 network.convolution(
9 5,
10 (3, 3),
11 (1, 1),
12 (0, 0),
13 (1, 1),
14 activation::Activation::ReLU,
15 Some(0.1),
16 );
17 network.convolution(
18 1,
19 (3, 3),
20 (1, 1),
21 (0, 0),
22 (1, 1),
23 activation::Activation::ReLU,
24 Some(0.1),
25 );
26
27 println!("{}", network);
28
29 let x = tensor::Tensor::random(tensor::Shape::Triple(1, 24, 24), 0.0, 1.0);
30 println!("x: {}", &x.shape);
31
32 let (pre, post, _, _) = network.forward(&x);
33 println!("pre-activation: {}", &pre[pre.len() - 1].shape);
34 println!("post-activation: {}", &post[post.len() - 1].shape);
35
36 plot::heatmap(&x, "Input", "./output/convolution-input.png");
37 plot::heatmap(
38 &pre[pre.len() - 1],
39 "Pre-activation",
40 "./output/convolution-pre.png",
41 );
42 plot::heatmap(
43 &post[post.len() - 1],
44 "Post-activation",
45 "./output/convolution-post.png",
46 );
47}
pub fn one_hot(value: usize, max: usize) -> Self
One-hot encodes a value into a Shape::Vector.
Examples found in repository?
43fn load_labels(file_path: &str, numbers: usize) -> Result<Vec<tensor::Tensor>> {
44 let mut reader = BufReader::new(File::open(file_path)?);
45 let _magic_number = read(&mut reader)?;
46 let num_labels = read(&mut reader)?;
47
48 let mut _labels = vec![0; num_labels as usize];
49 reader.read_exact(&mut _labels)?;
50
51 Ok(_labels
52 .iter()
53 .map(|&x| tensor::Tensor::one_hot(x as usize, numbers))
54 .collect())
55}More examples
78fn load_labels(file_path: &str, numbers: usize) -> Result<Vec<tensor::Tensor>> {
79 let mut reader = BufReader::new(File::open(file_path)?);
80 let _magic_number = read(&mut reader)?;
81 let num_labels = read(&mut reader)?;
82
83 let mut _labels = vec![0; num_labels as usize];
84 reader.read_exact(&mut _labels)?;
85
86 Ok(_labels
87 .iter()
88 .map(|&x| tensor::Tensor::one_hot(x as usize, numbers))
89 .collect())
90}78fn load_labels(file_path: &str, numbers: usize) -> Result<Vec<tensor::Tensor>> {
79 let mut reader = BufReader::new(File::open(file_path)?);
80 let _magic_number = read(&mut reader)?;
81 let num_labels = read(&mut reader)?;
82
83 let mut _labels = vec![0; num_labels as usize];
84 reader.read_exact(&mut _labels)?;
85
86 Ok(_labels
87 .iter()
88 .map(|&x| tensor::Tensor::one_hot(x as usize, numbers))
89 .collect())
90}78fn load_labels(file_path: &str, numbers: usize) -> Result<Vec<tensor::Tensor>> {
79 let mut reader = BufReader::new(File::open(file_path)?);
80 let _magic_number = read(&mut reader)?;
81 let num_labels = read(&mut reader)?;
82
83 let mut _labels = vec![0; num_labels as usize];
84 reader.read_exact(&mut _labels)?;
85
86 Ok(_labels
87 .iter()
88 .map(|&x| tensor::Tensor::one_hot(x as usize, numbers))
89 .collect())
90}40fn load_labels(file_path: &str, numbers: usize) -> Result<Vec<tensor::Tensor>> {
41 let mut reader = BufReader::new(File::open(file_path)?);
42 let _magic_number = read(&mut reader)?;
43 let num_labels = read(&mut reader)?;
44
45 let mut _labels = vec![0; num_labels as usize];
46 reader.read_exact(&mut _labels)?;
47
48 Ok(_labels
49 .iter()
50 .map(|&x| tensor::Tensor::one_hot(x as usize, numbers))
51 .collect())
52}41fn load_labels(file_path: &str, numbers: usize) -> Result<Vec<tensor::Tensor>> {
42 let mut reader = BufReader::new(File::open(file_path)?);
43 let _magic_number = read(&mut reader)?;
44 let num_labels = read(&mut reader)?;
45
46 let mut _labels = vec![0; num_labels as usize];
47 reader.read_exact(&mut _labels)?;
48
49 Ok(_labels
50 .iter()
51 .map(|&x| tensor::Tensor::one_hot(x as usize, numbers))
52 .collect())
53}- examples/mnist-fashion/plain.rs
- examples/mnist-fashion/skip.rs
- examples/mnist/deconvolution.rs
- examples/mnist/feedback.rs
- examples/mnist/looping.rs
- examples/mnist/plain.rs
- examples/mnist/skip.rs
- examples/timing/mnist.rs
- examples/cifar10/feedback.rs
- examples/cifar10/looping.rs
- examples/cifar10/plain.rs
- examples/timing/ftir-mlp.rs
- examples/timing/ftir-cnn.rs
- examples/compare/iris.rs
- examples/iris.rs
- examples/timing/iris.rs
- examples/compare/ftir-mlp.rs
- examples/ftir/mlp/feedback.rs
- examples/ftir/mlp/looping.rs
- examples/ftir/mlp/plain.rs
- examples/ftir/mlp/skip.rs
- examples/compare/ftir-cnn.rs
- examples/ftir/cnn/feedback.rs
- examples/ftir/cnn/looping.rs
- examples/ftir/cnn/plain.rs
- examples/ftir/cnn/skip.rs
pub fn single(data: Vec<f32>) -> Self
Creates a new Tensor from the given one-dimensional vector.
Examples found in repository?
17fn data(
18 path: &str,
19) -> (
20 Vec<tensor::Tensor>,
21 Vec<tensor::Tensor>,
22 Vec<tensor::Tensor>,
23) {
24 let reader = BufReader::new(File::open(&path).unwrap());
25
26 let mut x: Vec<tensor::Tensor> = Vec::new();
27 let mut y: Vec<tensor::Tensor> = Vec::new();
28 let mut c: Vec<tensor::Tensor> = Vec::new();
29
30 for line in reader.lines().skip(1) {
31 let line = line.unwrap();
32 let record: Vec<&str> = line.split(',').collect();
33
34 let mut data: Vec<f32> = Vec::new();
35 for i in 0..571 {
36 data.push(record.get(i).unwrap().parse::<f32>().unwrap());
37 }
38
39 x.push(tensor::Tensor::single(data));
40 y.push(tensor::Tensor::single(vec![record
41 .get(571)
42 .unwrap()
43 .parse::<f32>()
44 .unwrap()]));
45 c.push(tensor::Tensor::one_hot(
46 record.get(572).unwrap().parse::<usize>().unwrap() - 1, // For zero-indexed.
47 28,
48 ));
49 }
50 (x, y, c)
51}More examples
17fn data(
18 path: &str,
19) -> (
20 Vec<tensor::Tensor>,
21 Vec<tensor::Tensor>,
22 Vec<tensor::Tensor>,
23) {
24 let reader = BufReader::new(File::open(&path).unwrap());
25
26 let mut x: Vec<tensor::Tensor> = Vec::new();
27 let mut y: Vec<tensor::Tensor> = Vec::new();
28 let mut c: Vec<tensor::Tensor> = Vec::new();
29
30 for line in reader.lines().skip(1) {
31 let line = line.unwrap();
32 let record: Vec<&str> = line.split(',').collect();
33
34 let mut data: Vec<f32> = Vec::new();
35 for i in 0..571 {
36 data.push(record.get(i).unwrap().parse::<f32>().unwrap());
37 }
38 let data: Vec<Vec<Vec<f32>>> = vec![vec![data]];
39 x.push(tensor::Tensor::triple(data));
40 y.push(tensor::Tensor::single(vec![record
41 .get(571)
42 .unwrap()
43 .parse::<f32>()
44 .unwrap()]));
45 c.push(tensor::Tensor::one_hot(
46 record.get(572).unwrap().parse::<usize>().unwrap() - 1, // For zero-indexed.
47 28,
48 ));
49 }
50
51 (x, y, c)
52}10fn data(path: &str) -> (Vec<tensor::Tensor>, Vec<tensor::Tensor>) {
11 let reader = BufReader::new(File::open(&path).unwrap());
12
13 let mut x: Vec<tensor::Tensor> = Vec::new();
14 let mut y: Vec<tensor::Tensor> = Vec::new();
15
16 for line in reader.lines().skip(1) {
17 let line = line.unwrap();
18 let record: Vec<&str> = line.split(',').collect();
19
20 let mut data: Vec<f32> = Vec::new();
21 for i in 2..14 {
22 data.push(record.get(i).unwrap().parse::<f32>().unwrap());
23 }
24 x.push(tensor::Tensor::single(data));
25
26 y.push(tensor::Tensor::single(vec![record
27 .get(16)
28 .unwrap()
29 .parse::<f32>()
30 .unwrap()]));
31 }
32
33 let mut generator = random::Generator::create(12345);
34 let mut indices: Vec<usize> = (0..x.len()).collect();
35 generator.shuffle(&mut indices);
36
37 let x: Vec<tensor::Tensor> = indices.iter().map(|i| x[*i].clone()).collect();
38 let y: Vec<tensor::Tensor> = indices.iter().map(|i| y[*i].clone()).collect();
39
40 (x, y)
41}11fn data(path: &str) -> (Vec<tensor::Tensor>, Vec<tensor::Tensor>) {
12 let reader = BufReader::new(File::open(&path).unwrap());
13
14 let mut x: Vec<tensor::Tensor> = Vec::new();
15 let mut y: Vec<tensor::Tensor> = Vec::new();
16
17 for line in reader.lines().skip(1) {
18 let line = line.unwrap();
19 let record: Vec<&str> = line.split(',').collect();
20
21 let mut data: Vec<f32> = Vec::new();
22 for i in 2..14 {
23 data.push(record.get(i).unwrap().parse::<f32>().unwrap());
24 }
25 x.push(tensor::Tensor::single(data));
26
27 y.push(tensor::Tensor::single(vec![record
28 .get(16)
29 .unwrap()
30 .parse::<f32>()
31 .unwrap()]));
32 }
33
34 let mut generator = random::Generator::create(12345);
35 let mut indices: Vec<usize> = (0..x.len()).collect();
36 generator.shuffle(&mut indices);
37
38 let x: Vec<tensor::Tensor> = indices.iter().map(|i| x[*i].clone()).collect();
39 let y: Vec<tensor::Tensor> = indices.iter().map(|i| y[*i].clone()).collect();
40
41 (x, y)
42}10fn data(path: &str) -> (Vec<tensor::Tensor>, Vec<tensor::Tensor>) {
11 let reader = BufReader::new(File::open(&path).unwrap());
12
13 let mut x: Vec<tensor::Tensor> = Vec::new();
14 let mut y: Vec<tensor::Tensor> = Vec::new();
15
16 for line in reader.lines().skip(1) {
17 let line = line.unwrap();
18 let record: Vec<&str> = line.split(',').collect();
19
20 let mut data: Vec<f32> = Vec::new();
21 for i in 2..14 {
22 data.push(record.get(i).unwrap().parse::<f32>().unwrap());
23 }
24 x.push(tensor::Tensor::single(data));
25
26 y.push(tensor::Tensor::single(vec![record
27 .get(16)
28 .unwrap()
29 .parse::<f32>()
30 .unwrap()]));
31 }
32
33 let mut generator = random::Generator::create(12345);
34 let mut indices: Vec<usize> = (0..x.len()).collect();
35 generator.shuffle(&mut indices);
36
37 let x: Vec<tensor::Tensor> = indices.iter().map(|i| x[*i].clone()).collect();
38 let y: Vec<tensor::Tensor> = indices.iter().map(|i| y[*i].clone()).collect();
39
40 (x, y)
41}48fn data(path: &str) -> (Vec<tensor::Tensor>, Vec<tensor::Tensor>) {
49 let reader = BufReader::new(File::open(&path).unwrap());
50
51 let mut x: Vec<tensor::Tensor> = Vec::new();
52 let mut y: Vec<tensor::Tensor> = Vec::new();
53
54 for line in reader.lines().skip(1) {
55 let line = line.unwrap();
56 let record: Vec<&str> = line.split(',').collect();
57
58 let mut data: Vec<f32> = Vec::new();
59 for i in 2..14 {
60 data.push(record.get(i).unwrap().parse::<f32>().unwrap());
61 }
62 x.push(tensor::Tensor::single(data));
63
64 y.push(tensor::Tensor::single(vec![record
65 .get(16)
66 .unwrap()
67 .parse::<f32>()
68 .unwrap()]));
69 }
70
71 let mut generator = random::Generator::create(12345);
72 let mut indices: Vec<usize> = (0..x.len()).collect();
73 generator.shuffle(&mut indices);
74
75 let x: Vec<tensor::Tensor> = indices.iter().map(|i| x[*i].clone()).collect();
76 let y: Vec<tensor::Tensor> = indices.iter().map(|i| y[*i].clone()).collect();
77
78 (x, y)
79}- examples/timing/bike.rs
- examples/compare/iris.rs
- examples/iris.rs
- examples/timing/iris.rs
- examples/xor.rs
- examples/compare/ftir-mlp.rs
- examples/ftir/mlp/feedback.rs
- examples/ftir/mlp/looping.rs
- examples/ftir/mlp/plain.rs
- examples/ftir/mlp/skip.rs
- examples/compare/ftir-cnn.rs
- examples/ftir/cnn/feedback.rs
- examples/ftir/cnn/looping.rs
- examples/ftir/cnn/plain.rs
- examples/ftir/cnn/skip.rs
pub fn double(data: Vec<Vec<f32>>) -> Self
Creates a new Tensor from the given two-dimensional vector.
pub fn triple(data: Vec<Vec<Vec<f32>>>) -> Self
Creates a new Tensor from the given three-dimensional vector.
Examples found in repository?
52fn load_mnist(path: &str) -> Result<Vec<tensor::Tensor>> {
53 let mut reader = BufReader::new(File::open(path)?);
54 let mut images: Vec<tensor::Tensor> = Vec::new();
55
56 let _magic_number = read(&mut reader)?;
57 let num_images = read(&mut reader)?;
58 let num_rows = read(&mut reader)?;
59 let num_cols = read(&mut reader)?;
60
61 for _ in 0..num_images {
62 let mut image: Vec<Vec<f32>> = Vec::new();
63 for _ in 0..num_rows {
64 let mut row: Vec<f32> = Vec::new();
65 for _ in 0..num_cols {
66 let mut pixel = [0];
67 reader.read_exact(&mut pixel)?;
68 row.push(pixel[0] as f32 / 255.0);
69 }
70 image.push(row);
71 }
72 images.push(tensor::Tensor::triple(vec![image]));
73 }
74
75 Ok(images)
76}More examples
23fn load_mnist(path: &str) -> Result<Vec<tensor::Tensor>> {
24 let mut reader = BufReader::new(File::open(path)?);
25 let mut images: Vec<tensor::Tensor> = Vec::new();
26
27 let _magic_number = read(&mut reader)?;
28 let num_images = read(&mut reader)?;
29 let num_rows = read(&mut reader)?;
30 let num_cols = read(&mut reader)?;
31
32 for _ in 0..num_images {
33 let mut image: Vec<Vec<f32>> = Vec::new();
34 for _ in 0..num_rows {
35 let mut row: Vec<f32> = Vec::new();
36 for _ in 0..num_cols {
37 let mut pixel = [0];
38 reader.read_exact(&mut pixel)?;
39 row.push(pixel[0] as f32 / 255.0);
40 }
41 image.push(row);
42 }
43 images.push(tensor::Tensor::triple(vec![image]));
44 }
45
46 Ok(images)
47}52fn load_mnist(path: &str) -> Result<Vec<tensor::Tensor>> {
53 let mut reader = BufReader::new(File::open(path)?);
54 let mut images: Vec<tensor::Tensor> = Vec::new();
55
56 let _magic_number = read(&mut reader)?;
57 let num_images = read(&mut reader)?;
58 let num_rows = read(&mut reader)?;
59 let num_cols = read(&mut reader)?;
60
61 for _ in 0..num_images {
62 let mut image: Vec<Vec<f32>> = Vec::new();
63 for _ in 0..num_rows {
64 let mut row: Vec<f32> = Vec::new();
65 for _ in 0..num_cols {
66 let mut pixel = [0];
67 reader.read_exact(&mut pixel)?;
68 row.push(pixel[0] as f32 / 255.0);
69 }
70 image.push(row);
71 }
72 images.push(tensor::Tensor::triple(vec![image]).resize(tensor::Shape::Triple(1, 14, 14)));
73 }
74
75 Ok(images)
76}52fn load_mnist(path: &str) -> Result<Vec<tensor::Tensor>> {
53 let mut reader = BufReader::new(File::open(path)?);
54 let mut images: Vec<tensor::Tensor> = Vec::new();
55
56 let _magic_number = read(&mut reader)?;
57 let num_images = read(&mut reader)?;
58 let num_rows = read(&mut reader)?;
59 let num_cols = read(&mut reader)?;
60
61 for _ in 0..num_images {
62 let mut image: Vec<Vec<f32>> = Vec::new();
63 for _ in 0..num_rows {
64 let mut row: Vec<f32> = Vec::new();
65 for _ in 0..num_cols {
66 let mut pixel = [0];
67 reader.read_exact(&mut pixel)?;
68 row.push(pixel[0] as f32 / 255.0);
69 }
70 image.push(row);
71 }
72 images.push(tensor::Tensor::triple(vec![image]).resize(tensor::Shape::Triple(1, 14, 14)));
73 }
74
75 Ok(images)
76}14fn load_mnist(path: &str) -> Result<Vec<tensor::Tensor>> {
15 let mut reader = BufReader::new(File::open(path)?);
16 let mut images: Vec<tensor::Tensor> = Vec::new();
17
18 let _magic_number = read(&mut reader)?;
19 let num_images = read(&mut reader)?;
20 let num_rows = read(&mut reader)?;
21 let num_cols = read(&mut reader)?;
22
23 for _ in 0..num_images {
24 let mut image: Vec<Vec<f32>> = Vec::new();
25 for _ in 0..num_rows {
26 let mut row: Vec<f32> = Vec::new();
27 for _ in 0..num_cols {
28 let mut pixel = [0];
29 reader.read_exact(&mut pixel)?;
30 row.push(pixel[0] as f32 / 255.0);
31 }
32 image.push(row);
33 }
34 images.push(tensor::Tensor::triple(vec![image]).resize(tensor::Shape::Triple(1, 14, 14)));
35 }
36
37 Ok(images)
38}15fn load_mnist(path: &str) -> Result<Vec<tensor::Tensor>> {
16 let mut reader = BufReader::new(File::open(path)?);
17 let mut images: Vec<tensor::Tensor> = Vec::new();
18
19 let _magic_number = read(&mut reader)?;
20 let num_images = read(&mut reader)?;
21 let num_rows = read(&mut reader)?;
22 let num_cols = read(&mut reader)?;
23
24 for _ in 0..num_images {
25 let mut image: Vec<Vec<f32>> = Vec::new();
26 for _ in 0..num_rows {
27 let mut row: Vec<f32> = Vec::new();
28 for _ in 0..num_cols {
29 let mut pixel = [0];
30 reader.read_exact(&mut pixel)?;
31 row.push(pixel[0] as f32 / 255.0);
32 }
33 image.push(row);
34 }
35 images.push(tensor::Tensor::triple(vec![image]).resize(tensor::Shape::Triple(1, 14, 14)));
36 }
37
38 Ok(images)
39}- examples/mnist-fashion/plain.rs
- examples/mnist-fashion/skip.rs
- examples/mnist/deconvolution.rs
- examples/mnist/feedback.rs
- examples/mnist/looping.rs
- examples/mnist/plain.rs
- examples/mnist/skip.rs
- examples/benchmark.rs
- examples/cifar10/feedback.rs
- examples/cifar10/looping.rs
- examples/cifar10/plain.rs
- examples/timing/ftir-cnn.rs
- examples/compare/ftir-cnn.rs
- examples/ftir/cnn/feedback.rs
- examples/ftir/cnn/looping.rs
- examples/ftir/cnn/plain.rs
- examples/ftir/cnn/skip.rs
pub fn quadruple(data: Vec<Vec<Vec<Vec<f32>>>>) -> Self
Creates a new Tensor from the given four-dimensional vector.
pub fn quintuple(data: Vec<Vec<Vec<Vec<Vec<(usize, usize)>>>>>) -> Self
Creates a new Tensor from the given five-dimensional vector of `(usize, usize)` pairs.
pub fn nested(tensors: Vec<Tensor>) -> Self
Convert a vector of Tensors into a single nested Tensor.
pub fn nestedoptional(tensors: Vec<Option<Tensor>>) -> Self
Convert a vector of optional Tensors (`Option<Tensor>`) into a single nested Tensor.
pub fn unnestedoptional(&self) -> Vec<Option<Tensor>>
Convert a single nested Tensor into a vector of optional Tensors (`Option<Tensor>`).
pub fn flatten(&self) -> Self
Flatten the Tensor's data. Returns a new Tensor with the updated shape and data.
Returns
A new Tensor with the same data, in a flattened vector format.
pub fn quadruple_to_vec_triple(&self) -> Vec<Tensor>
Get the data of the Shape::Quadruple Tensor as a vector of Shape::Triple Tensors.
pub fn resize(&self, shape: Shape) -> Self
Resize the Tensor to the given shape.
Average pooling is used to resize the Tensor.
Examples found in repository?
52fn load_mnist(path: &str) -> Result<Vec<tensor::Tensor>> {
53 let mut reader = BufReader::new(File::open(path)?);
54 let mut images: Vec<tensor::Tensor> = Vec::new();
55
56 let _magic_number = read(&mut reader)?;
57 let num_images = read(&mut reader)?;
58 let num_rows = read(&mut reader)?;
59 let num_cols = read(&mut reader)?;
60
61 for _ in 0..num_images {
62 let mut image: Vec<Vec<f32>> = Vec::new();
63 for _ in 0..num_rows {
64 let mut row: Vec<f32> = Vec::new();
65 for _ in 0..num_cols {
66 let mut pixel = [0];
67 reader.read_exact(&mut pixel)?;
68 row.push(pixel[0] as f32 / 255.0);
69 }
70 image.push(row);
71 }
72 images.push(tensor::Tensor::triple(vec![image]).resize(tensor::Shape::Triple(1, 14, 14)));
73 }
74
75 Ok(images)
76}More examples
52fn load_mnist(path: &str) -> Result<Vec<tensor::Tensor>> {
53 let mut reader = BufReader::new(File::open(path)?);
54 let mut images: Vec<tensor::Tensor> = Vec::new();
55
56 let _magic_number = read(&mut reader)?;
57 let num_images = read(&mut reader)?;
58 let num_rows = read(&mut reader)?;
59 let num_cols = read(&mut reader)?;
60
61 for _ in 0..num_images {
62 let mut image: Vec<Vec<f32>> = Vec::new();
63 for _ in 0..num_rows {
64 let mut row: Vec<f32> = Vec::new();
65 for _ in 0..num_cols {
66 let mut pixel = [0];
67 reader.read_exact(&mut pixel)?;
68 row.push(pixel[0] as f32 / 255.0);
69 }
70 image.push(row);
71 }
72 images.push(tensor::Tensor::triple(vec![image]).resize(tensor::Shape::Triple(1, 14, 14)));
73 }
74
75 Ok(images)
76}14fn load_mnist(path: &str) -> Result<Vec<tensor::Tensor>> {
15 let mut reader = BufReader::new(File::open(path)?);
16 let mut images: Vec<tensor::Tensor> = Vec::new();
17
18 let _magic_number = read(&mut reader)?;
19 let num_images = read(&mut reader)?;
20 let num_rows = read(&mut reader)?;
21 let num_cols = read(&mut reader)?;
22
23 for _ in 0..num_images {
24 let mut image: Vec<Vec<f32>> = Vec::new();
25 for _ in 0..num_rows {
26 let mut row: Vec<f32> = Vec::new();
27 for _ in 0..num_cols {
28 let mut pixel = [0];
29 reader.read_exact(&mut pixel)?;
30 row.push(pixel[0] as f32 / 255.0);
31 }
32 image.push(row);
33 }
34 images.push(tensor::Tensor::triple(vec![image]).resize(tensor::Shape::Triple(1, 14, 14)));
35 }
36
37 Ok(images)
38}15fn load_mnist(path: &str) -> Result<Vec<tensor::Tensor>> {
16 let mut reader = BufReader::new(File::open(path)?);
17 let mut images: Vec<tensor::Tensor> = Vec::new();
18
19 let _magic_number = read(&mut reader)?;
20 let num_images = read(&mut reader)?;
21 let num_rows = read(&mut reader)?;
22 let num_cols = read(&mut reader)?;
23
24 for _ in 0..num_images {
25 let mut image: Vec<Vec<f32>> = Vec::new();
26 for _ in 0..num_rows {
27 let mut row: Vec<f32> = Vec::new();
28 for _ in 0..num_cols {
29 let mut pixel = [0];
30 reader.read_exact(&mut pixel)?;
31 row.push(pixel[0] as f32 / 255.0);
32 }
33 image.push(row);
34 }
35 images.push(tensor::Tensor::triple(vec![image]).resize(tensor::Shape::Triple(1, 14, 14)));
36 }
37
38 Ok(images)
39}15fn load_mnist(path: &str) -> Result<Vec<tensor::Tensor>> {
16 let mut reader = BufReader::new(File::open(path)?);
17 let mut images: Vec<tensor::Tensor> = Vec::new();
18
19 let _magic_number = read(&mut reader)?;
20 let num_images = read(&mut reader)?;
21 let num_rows = read(&mut reader)?;
22 let num_cols = read(&mut reader)?;
23
24 for _ in 0..num_images {
25 let mut image: Vec<Vec<f32>> = Vec::new();
26 for _ in 0..num_rows {
27 let mut row: Vec<f32> = Vec::new();
28 for _ in 0..num_cols {
29 let mut pixel = [0];
30 reader.read_exact(&mut pixel)?;
31 row.push(pixel[0] as f32 / 255.0);
32 }
33 image.push(row);
34 }
35 images.push(tensor::Tensor::triple(vec![image]).resize(tensor::Shape::Triple(1, 14, 14)));
36 }
37
38 Ok(images)
39}14fn load_mnist(path: &str) -> Result<Vec<tensor::Tensor>> {
15 let mut reader = BufReader::new(File::open(path)?);
16 let mut images: Vec<tensor::Tensor> = Vec::new();
17
18 let _magic_number = read(&mut reader)?;
19 let num_images = read(&mut reader)?;
20 let num_rows = read(&mut reader)?;
21 let num_cols = read(&mut reader)?;
22
23 for _ in 0..num_images {
24 let mut image: Vec<Vec<f32>> = Vec::new();
25 for _ in 0..num_rows {
26 let mut row: Vec<f32> = Vec::new();
27 for _ in 0..num_cols {
28 let mut pixel = [0];
29 reader.read_exact(&mut pixel)?;
30 row.push(pixel[0] as f32 / 255.0);
31 }
32 image.push(row);
33 }
34 images.push(tensor::Tensor::triple(vec![image]).resize(tensor::Shape::Triple(1, 14, 14)));
35 }
36
37 Ok(images)
38}
pub fn argmax(&self) -> usize
Get the index of the maximum value in the Tensor.
Examples found in repository?
49fn main() {
50 // Load the iris dataset
51 let (x, y) = data("./examples/datasets/iris.csv");
52
53 let split = (x.len() as f32 * 0.8) as usize;
54 let x = x.split_at(split);
55 let y = y.split_at(split);
56
57 let x_train: Vec<&tensor::Tensor> = x.0.iter().collect();
58 let y_train: Vec<&tensor::Tensor> = y.0.iter().collect();
59 let x_test: Vec<&tensor::Tensor> = x.1.iter().collect();
60 let y_test: Vec<&tensor::Tensor> = y.1.iter().collect();
61
62 println!(
63 "Train data {}x{}: {} => {}",
64 x_train.len(),
65 x_train[0].shape,
66 x_train[0].data,
67 y_train[0].data
68 );
69 println!(
70 "Test data {}x{}: {} => {}",
71 x_test.len(),
72 x_test[0].shape,
73 x_test[0].data,
74 y_test[0].data
75 );
76
77 // Create the network
78 let mut network = network::Network::new(tensor::Shape::Single(4));
79
80 network.dense(50, activation::Activation::ReLU, false, None);
81 network.dense(50, activation::Activation::ReLU, false, None);
82 network.dense(3, activation::Activation::Softmax, false, None);
83
84 network.set_optimizer(optimizer::RMSprop::create(
85 0.0001, // Learning rate
86 0.0, // Alpha
87 1e-8, // Epsilon
88 Some(0.01), // Decay
89 Some(0.01), // Momentum
90 true, // Centered
91 ));
92 network.set_objective(
93 objective::Objective::CrossEntropy, // Objective function
94 Some((-1f32, 1f32)), // Gradient clipping
95 );
96
97 // Train the network
98 let (_train_loss, _val_loss, _val_acc) = network.learn(
99 &x_train,
100 &y_train,
101 Some((&x_test, &y_test, 5)),
102 1,
103 5,
104 Some(1),
105 );
106
107 // Validate the network
108 let (val_loss, val_acc) = network.validate(&x_test, &y_test, 1e-6);
109 println!(
110 "Final validation accuracy: {:.2} % and loss: {:.5}",
111 val_acc * 100.0,
112 val_loss
113 );
114
115 // Use the network
116 let prediction = network.predict(x_test.get(0).unwrap());
117 println!(
118 "Prediction on input: {}. Target: {}. Output: {}.",
119 x_test[0].data,
120 y_test[0].argmax(),
121 prediction.argmax()
122 );
123}More examples
54fn main() {
55 let x_train = load_mnist("./examples/datasets/mnist-fashion/train-images-idx3-ubyte").unwrap();
56 let y_train = load_labels(
57 "./examples/datasets/mnist-fashion/train-labels-idx1-ubyte",
58 10,
59 )
60 .unwrap();
61 let x_test = load_mnist("./examples/datasets/mnist-fashion/t10k-images-idx3-ubyte").unwrap();
62 let y_test = load_labels(
63 "./examples/datasets/mnist-fashion/t10k-labels-idx1-ubyte",
64 10,
65 )
66 .unwrap();
67 println!(
68 "Train: {} images, Test: {} images",
69 x_train.len(),
70 x_test.len()
71 );
72
73 let x_train: Vec<&tensor::Tensor> = x_train.iter().collect();
74 let y_train: Vec<&tensor::Tensor> = y_train.iter().collect();
75 let x_test: Vec<&tensor::Tensor> = x_test.iter().collect();
76 let y_test: Vec<&tensor::Tensor> = y_test.iter().collect();
77
78 let mut network = network::Network::new(tensor::Shape::Triple(1, 14, 14));
79
80 network.convolution(
81 1,
82 (3, 3),
83 (1, 1),
84 (1, 1),
85 (1, 1),
86 activation::Activation::ReLU,
87 None,
88 );
89 network.convolution(
90 1,
91 (3, 3),
92 (1, 1),
93 (1, 1),
94 (1, 1),
95 activation::Activation::ReLU,
96 None,
97 );
98 network.convolution(
99 1,
100 (3, 3),
101 (1, 1),
102 (1, 1),
103 (1, 1),
104 activation::Activation::ReLU,
105 None,
106 );
107 network.maxpool((2, 2), (2, 2));
108 network.dense(10, activation::Activation::Softmax, true, None);
109
110 network.connect(0, 2);
111
112 network.set_optimizer(optimizer::Adam::create(0.001, 0.9, 0.999, 1e-8, None));
113 network.set_objective(objective::Objective::CrossEntropy, None);
114
115 println!("{}", network);
116
117 // Train the network
118 let (train_loss, val_loss, val_acc) = network.learn(
119 &x_train,
120 &y_train,
121 Some((&x_test, &y_test, 10)),
122 32,
123 25,
124 Some(5),
125 );
126 plot::loss(
127 &train_loss,
128 &val_loss,
129 &val_acc,
130 "SKIP : Fashion-MNIST",
131 "./output/mnist-fashion/skip.png",
132 );
133
134 // Validate the network
135 let (val_loss, val_acc) = network.validate(&x_test, &y_test, 1e-6);
136 println!(
137 "Final validation accuracy: {:.2} % and loss: {:.5}",
138 val_acc * 100.0,
139 val_loss
140 );
141
142 // Use the network
143 let prediction = network.predict(x_test.get(0).unwrap());
144 println!(
145 "Prediction on input: Target: {}. Output: {}.",
146 y_test[0].argmax(),
147 prediction.argmax()
148 );
149
150 // let x = x_test.get(5).unwrap();
151 // let y = y_test.get(5).unwrap();
152 // Plot the pre- and post-activation heatmaps for each (image) layer.
153 // let (pre, post, _) = network.forward(x);
154 // for (i, (i_pre, i_post)) in pre.iter().zip(post.iter()).enumerate() {
155 // let pre_title = format!("layer_{}_pre", i);
156 // let post_title = format!("layer_{}_post", i);
157 // let pre_file = format!("layer_{}_pre.png", i);
158 // let post_file = format!("layer_{}_post.png", i);
159 // plot::heatmap(&i_pre, &pre_title, &pre_file);
160 // plot::heatmap(&i_post, &post_title, &post_file);
161 // }
162}54fn main() {
55 let x_train = load_mnist("./examples/datasets/mnist/train-images-idx3-ubyte").unwrap();
56 let y_train = load_labels("./examples/datasets/mnist/train-labels-idx1-ubyte", 10).unwrap();
57 let x_test = load_mnist("./examples/datasets/mnist/t10k-images-idx3-ubyte").unwrap();
58 let y_test = load_labels("./examples/datasets/mnist/t10k-labels-idx1-ubyte", 10).unwrap();
59 println!(
60 "Train: {} images, Test: {} images",
61 x_train.len(),
62 x_test.len()
63 );
64
65 let x_train: Vec<&tensor::Tensor> = x_train.iter().collect();
66 let y_train: Vec<&tensor::Tensor> = y_train.iter().collect();
67 let x_test: Vec<&tensor::Tensor> = x_test.iter().collect();
68 let y_test: Vec<&tensor::Tensor> = y_test.iter().collect();
69
70 let mut network = network::Network::new(tensor::Shape::Triple(1, 14, 14));
71
72 network.convolution(
73 1,
74 (3, 3),
75 (1, 1),
76 (1, 1),
77 (1, 1),
78 activation::Activation::ReLU,
79 None,
80 );
81 network.convolution(
82 1,
83 (3, 3),
84 (1, 1),
85 (1, 1),
86 (1, 1),
87 activation::Activation::ReLU,
88 None,
89 );
90 network.convolution(
91 1,
92 (3, 3),
93 (1, 1),
94 (1, 1),
95 (1, 1),
96 activation::Activation::ReLU,
97 None,
98 );
99 network.maxpool((2, 2), (2, 2));
100 network.dense(10, activation::Activation::Softmax, true, None);
101
102 network.set_optimizer(optimizer::Adam::create(0.001, 0.9, 0.999, 1e-8, None));
103 network.set_objective(
104 objective::Objective::CrossEntropy, // Objective function
105 None, // Gradient clipping
106 );
107
108 println!("{}", network);
109
110 // Train the network
111 let (train_loss, val_loss, val_acc) = network.learn(
112 &x_train,
113 &y_train,
114 Some((&x_test, &y_test, 10)),
115 32,
116 25,
117 Some(5),
118 );
119 plot::loss(
120 &train_loss,
121 &val_loss,
122 &val_acc,
123 "PLAIN : MNIST",
124 "./output/mnist/plain.png",
125 );
126
127 // Validate the network
128 let (val_loss, val_acc) = network.validate(&x_test, &y_test, 1e-6);
129 println!(
130 "Final validation accuracy: {:.2} % and loss: {:.5}",
131 val_acc * 100.0,
132 val_loss
133 );
134
135 // Use the network
136 let prediction = network.predict(x_test.get(0).unwrap());
137 println!(
138 "Prediction on input: Target: {}. Output: {}.",
139 y_test[0].argmax(),
140 prediction.argmax()
141 );
142
143 // let x = x_test.get(5).unwrap();
144 // let y = y_test.get(5).unwrap();
145 // plot::heatmap(
146 // &x,
147 // &format!("Target: {}", y.argmax()),
148 // "./output/mnist/input.png",
149 // );
150
151 // Plot the pre- and post-activation heatmaps for each (image) layer.
152 // let (pre, post, _) = network.forward(x);
153 // for (i, (i_pre, i_post)) in pre.iter().zip(post.iter()).enumerate() {
154 // let pre_title = format!("layer_{}_pre", i);
155 // let post_title = format!("layer_{}_post", i);
156 // let pre_file = format!("layer_{}_pre.png", i);
157 // let post_file = format!("layer_{}_post.png", i);
158 // plot::heatmap(&i_pre, &pre_title, &pre_file);
159 // plot::heatmap(&i_post, &post_title, &post_file);
160 // }
161}54fn main() {
55 let x_train = load_mnist("./examples/datasets/mnist/train-images-idx3-ubyte").unwrap();
56 let y_train = load_labels("./examples/datasets/mnist/train-labels-idx1-ubyte", 10).unwrap();
57 let x_test = load_mnist("./examples/datasets/mnist/t10k-images-idx3-ubyte").unwrap();
58 let y_test = load_labels("./examples/datasets/mnist/t10k-labels-idx1-ubyte", 10).unwrap();
59 println!(
60 "Train: {} images, Test: {} images",
61 x_train.len(),
62 x_test.len()
63 );
64
65 let x_train: Vec<&tensor::Tensor> = x_train.iter().collect();
66 let y_train: Vec<&tensor::Tensor> = y_train.iter().collect();
67 let x_test: Vec<&tensor::Tensor> = x_test.iter().collect();
68 let y_test: Vec<&tensor::Tensor> = y_test.iter().collect();
69
70 let mut network = network::Network::new(tensor::Shape::Triple(1, 14, 14));
71
72 network.convolution(
73 1,
74 (3, 3),
75 (1, 1),
76 (1, 1),
77 (1, 1),
78 activation::Activation::ReLU,
79 None,
80 );
81 network.convolution(
82 1,
83 (3, 3),
84 (1, 1),
85 (1, 1),
86 (1, 1),
87 activation::Activation::ReLU,
88 None,
89 );
90 network.convolution(
91 1,
92 (3, 3),
93 (1, 1),
94 (1, 1),
95 (1, 1),
96 activation::Activation::ReLU,
97 None,
98 );
99 network.maxpool((2, 2), (2, 2));
100 network.dense(10, activation::Activation::Softmax, true, None);
101
102 network.connect(0, 3);
103
104 network.set_optimizer(optimizer::Adam::create(0.001, 0.9, 0.999, 1e-8, None));
105 network.set_objective(
106 objective::Objective::CrossEntropy, // Objective function
107 None, // Gradient clipping
108 );
109
110 println!("{}", network);
111
112 // Train the network
113 let (train_loss, val_loss, val_acc) = network.learn(
114 &x_train,
115 &y_train,
116 Some((&x_test, &y_test, 10)),
117 32,
118 25,
119 Some(5),
120 );
121 plot::loss(
122 &train_loss,
123 &val_loss,
124 &val_acc,
125 "SKIP : MNIST",
126 "./output/mnist/skip.png",
127 );
128
129 // Validate the network
130 let (val_loss, val_acc) = network.validate(&x_test, &y_test, 1e-6);
131 println!(
132 "Final validation accuracy: {:.2} % and loss: {:.5}",
133 val_acc * 100.0,
134 val_loss
135 );
136
137 // Use the network
138 let prediction = network.predict(x_test.get(0).unwrap());
139 println!(
140 "Prediction on input: Target: {}. Output: {}.",
141 y_test[0].argmax(),
142 prediction.argmax()
143 );
144
145 // let x = x_test.get(5).unwrap();
146 // let y = y_test.get(5).unwrap();
147 // plot::heatmap(
148 // &x,
149 // &format!("Target: {}", y.argmax()),
150 // "./output/mnist/input.png",
151 // );
152
153 // Plot the pre- and post-activation heatmaps for each (image) layer.
154 // let (pre, post, _) = network.forward(x);
155 // for (i, (i_pre, i_post)) in pre.iter().zip(post.iter()).enumerate() {
156 // let pre_title = format!("layer_{}_pre", i);
157 // let post_title = format!("layer_{}_post", i);
158 // let pre_file = format!("layer_{}_pre.png", i);
159 // let post_file = format!("layer_{}_post.png", i);
160 // plot::heatmap(&i_pre, &pre_title, &pre_file);
161 // plot::heatmap(&i_post, &post_title, &post_file);
162 // }
163}54fn main() {
55 let x_train = load_mnist("./examples/datasets/mnist/train-images-idx3-ubyte").unwrap();
56 let y_train = load_labels("./examples/datasets/mnist/train-labels-idx1-ubyte", 10).unwrap();
57 let x_test = load_mnist("./examples/datasets/mnist/t10k-images-idx3-ubyte").unwrap();
58 let y_test = load_labels("./examples/datasets/mnist/t10k-labels-idx1-ubyte", 10).unwrap();
59 println!(
60 "Train: {} images, Test: {} images",
61 x_train.len(),
62 x_test.len()
63 );
64
65 let x_train: Vec<&tensor::Tensor> = x_train.iter().collect();
66 let y_train: Vec<&tensor::Tensor> = y_train.iter().collect();
67 let x_test: Vec<&tensor::Tensor> = x_test.iter().collect();
68 let y_test: Vec<&tensor::Tensor> = y_test.iter().collect();
69
70 let mut network = network::Network::new(tensor::Shape::Triple(1, 14, 14));
71
72 network.convolution(
73 1,
74 (3, 3),
75 (1, 1),
76 (0, 0),
77 (1, 1),
78 activation::Activation::ReLU,
79 None,
80 );
81 network.maxpool((2, 2), (2, 2));
82 network.convolution(
83 4,
84 (3, 3),
85 (1, 1),
86 (0, 0),
87 (1, 1),
88 activation::Activation::ReLU,
89 None,
90 );
91 network.deconvolution(
92 4,
93 (3, 3),
94 (1, 1),
95 (0, 0),
96 activation::Activation::ReLU,
97 None,
98 );
99 network.maxpool((2, 2), (2, 2));
100 network.dense(10, activation::Activation::Softmax, true, None);
101
102 network.set_optimizer(optimizer::Adam::create(0.001, 0.9, 0.999, 1e-8, None));
103 network.set_objective(
104 objective::Objective::CrossEntropy, // Objective function
105 None, // Gradient clipping
106 );
107
108 println!("{}", network);
109
110 // Train the network
111 let (train_loss, val_loss, val_acc) = network.learn(
112 &x_train,
113 &y_train,
114 Some((&x_test, &y_test, 10)),
115 32,
116 25,
117 Some(5),
118 );
119 plot::loss(
120 &train_loss,
121 &val_loss,
122 &val_acc,
123 "DECONV : MNIST",
124 "./output/mnist/deconvolution.png",
125 );
126
127 // Validate the network
128 let (val_loss, val_acc) = network.validate(&x_test, &y_test, 1e-6);
129 println!(
130 "Final validation accuracy: {:.2} % and loss: {:.5}",
131 val_acc * 100.0,
132 val_loss
133 );
134
135 // Use the network
136 let prediction = network.predict(x_test.get(0).unwrap());
137 println!(
138 "Prediction on input: Target: {}. Output: {}.",
139 y_test[0].argmax(),
140 prediction.argmax()
141 );
142
143 // let x = x_test.get(5).unwrap();
144 // let y = y_test.get(5).unwrap();
145 // plot::heatmap(
146 // &x,
147 // &format!("Target: {}", y.argmax()),
148 // "./output/mnist/input.png",
149 // );
150
151 // Plot the pre- and post-activation heatmaps for each (image) layer.
152 // let (pre, post, _) = network.forward(x);
153 // for (i, (i_pre, i_post)) in pre.iter().zip(post.iter()).enumerate() {
154 // let pre_title = format!("layer_{}_pre", i);
155 // let post_title = format!("layer_{}_post", i);
156 // let pre_file = format!("layer_{}_pre.png", i);
157 // let post_file = format!("layer_{}_post.png", i);
158 // plot::heatmap(&i_pre, &pre_title, &pre_file);
159 // plot::heatmap(&i_post, &post_title, &post_file);
160 // }
161}54fn main() {
55 let x_train = load_mnist("./examples/datasets/mnist-fashion/train-images-idx3-ubyte").unwrap();
56 let y_train = load_labels(
57 "./examples/datasets/mnist-fashion/train-labels-idx1-ubyte",
58 10,
59 )
60 .unwrap();
61 let x_test = load_mnist("./examples/datasets/mnist-fashion/t10k-images-idx3-ubyte").unwrap();
62 let y_test = load_labels(
63 "./examples/datasets/mnist-fashion/t10k-labels-idx1-ubyte",
64 10,
65 )
66 .unwrap();
67 println!(
68 "Train: {} images, Test: {} images",
69 x_train.len(),
70 x_test.len()
71 );
72
73 let x_train: Vec<&tensor::Tensor> = x_train.iter().collect();
74 let y_train: Vec<&tensor::Tensor> = y_train.iter().collect();
75 let x_test: Vec<&tensor::Tensor> = x_test.iter().collect();
76 let y_test: Vec<&tensor::Tensor> = y_test.iter().collect();
77
78 let mut network = network::Network::new(tensor::Shape::Triple(1, 14, 14));
79
80 network.convolution(
81 1,
82 (3, 3),
83 (1, 1),
84 (1, 1),
85 (1, 1),
86 activation::Activation::ReLU,
87 None,
88 );
89 network.feedback(
90 vec![feedback::Layer::Convolution(
91 1,
92 activation::Activation::ReLU,
93 (3, 3),
94 (1, 1),
95 (1, 1),
96 (1, 1),
97 None,
98 )],
99 3,
100 false,
101 false,
102 feedback::Accumulation::Mean,
103 );
104 network.convolution(
105 1,
106 (3, 3),
107 (1, 1),
108 (1, 1),
109 (1, 1),
110 activation::Activation::ReLU,
111 None,
112 );
113 network.maxpool((2, 2), (2, 2));
114 network.dense(10, activation::Activation::Softmax, true, None);
115
116 // Include skip connection bypassing the feedback block
117 // network.connect(1, 2);
118 // network.set_accumulation(feedback::Accumulation::Add);
119
120 network.set_optimizer(optimizer::Adam::create(0.001, 0.9, 0.999, 1e-8, None));
121 network.set_objective(objective::Objective::CrossEntropy, None);
122
123 println!("{}", network);
124
125 // Train the network
126 let (train_loss, val_loss, val_acc) = network.learn(
127 &x_train,
128 &y_train,
129 Some((&x_test, &y_test, 10)),
130 32,
131 25,
132 Some(5),
133 );
134 plot::loss(
135 &train_loss,
136 &val_loss,
137 &val_acc,
138 "FEEDBACK : Fashion-MNIST",
139 "./output/mnist-fashion/feedback.png",
140 );
141
142 // Validate the network
143 let (val_loss, val_acc) = network.validate(&x_test, &y_test, 1e-6);
144 println!(
145 "Final validation accuracy: {:.2} % and loss: {:.5}",
146 val_acc * 100.0,
147 val_loss
148 );
149
150 // Use the network
151 let prediction = network.predict(x_test.get(0).unwrap());
152 println!(
153 "Prediction on input: Target: {}. Output: {}.",
154 y_test[0].argmax(),
155 prediction.argmax()
156 );
157
158 // let x = x_test.get(5).unwrap();
159 // let y = y_test.get(5).unwrap();
160 // Plot the pre- and post-activation heatmaps for each (image) layer.
161 // let (pre, post, _) = network.forward(x);
162 // for (i, (i_pre, i_post)) in pre.iter().zip(post.iter()).enumerate() {
163 // let pre_title = format!("layer_{}_pre", i);
164 // let post_title = format!("layer_{}_post", i);
165 // let pre_file = format!("layer_{}_pre.png", i);
166 // let post_file = format!("layer_{}_post.png", i);
167 // plot::heatmap(&i_pre, &pre_title, &pre_file);
168 // plot::heatmap(&i_post, &post_title, &post_file);
169 // }
170}- examples/mnist-fashion/plain.rs
- examples/mnist/feedback.rs
- examples/cifar10/plain.rs
- examples/mnist-fashion/looping.rs
- examples/cifar10/looping.rs
- examples/mnist/looping.rs
- examples/ftir/mlp/plain.rs
- examples/cifar10/feedback.rs
- examples/ftir/cnn/plain.rs
- examples/ftir/mlp/skip.rs
- examples/ftir/mlp/looping.rs
- examples/ftir/cnn/skip.rs
- examples/ftir/cnn/looping.rs
- examples/ftir/cnn/feedback.rs
- examples/ftir/mlp/feedback.rs
Sourcepub fn add_inplace(&mut self, other: &Tensor)
pub fn add_inplace(&mut self, other: &Tensor)
In-place element-wise addition of two Tensors.
Validates their shapes beforehand.
Sourcepub fn sub_inplace(&mut self, other: &Tensor)
pub fn sub_inplace(&mut self, other: &Tensor)
In-place element-wise subtraction of two Tensors.
Validates their shapes beforehand.
Sourcepub fn mul_inplace(&mut self, other: &Tensor)
pub fn mul_inplace(&mut self, other: &Tensor)
In-place element-wise multiplication of two Tensors.
Validates their shapes beforehand.
pub fn div_scalar_inplace(&mut self, scalar: f32)
Sourcepub fn mean_inplace(&mut self, others: &Vec<&Tensor>)
pub fn mean_inplace(&mut self, others: &Vec<&Tensor>)
In-place element-wise mean of multiple Tensors (the Tensors in `others`).
Validates their shapes beforehand.
Sourcepub fn hadamard(&mut self, other: &Tensor, scalar: f32)
pub fn hadamard(&mut self, other: &Tensor, scalar: f32)
Multiply the i == j elements of two Tensors of the same shape together.
Validates their shapes beforehand.
§Arguments
other - The Tensor to multiply with.
scalar - A scalar value to multiply the result by (e.g., 1.0 / self.loops).
Sourcepub fn dot(&self, other: &Tensor) -> Self
pub fn dot(&self, other: &Tensor) -> Self
Dot product of two Tensors.
This Tensor must be Shape::Double and the other Tensor must be Shape::Single.
This Tensor's column count must equal the other Tensor's row count.
Sourcepub fn dropout(&mut self, dropout: f32)
pub fn dropout(&mut self, dropout: f32)
Randomly set elements of the Tensor to zero with a given probability dropout.
Sourcepub fn clamp(self, min: f32, max: f32) -> Self
pub fn clamp(self, min: f32, max: f32) -> Self
Clamp the values of the Tensor to a given interval [min, max].
pub fn extend(&mut self, other: &Tensor)
Trait Implementations§
Auto Trait Implementations§
impl Freeze for Tensor
impl RefUnwindSafe for Tensor
impl Send for Tensor
impl Sync for Tensor
impl Unpin for Tensor
impl UnwindSafe for Tensor
Blanket Implementations§
Source§impl<T> BorrowMut<T> for Twhere
T: ?Sized,
impl<T> BorrowMut<T> for Twhere
T: ?Sized,
Source§fn borrow_mut(&mut self) -> &mut T
fn borrow_mut(&mut self) -> &mut T
Source§impl<T> CloneToUninit for Twhere
T: Clone,
impl<T> CloneToUninit for Twhere
T: Clone,
Source§impl<T> IntoEither for T
impl<T> IntoEither for T
Source§fn into_either(self, into_left: bool) -> Either<Self, Self>
fn into_either(self, into_left: bool) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self>
if into_left is true.
Converts self into a Right variant of Either<Self, Self>
otherwise.
Source§fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self>
if into_left(&self) returns true.
Converts self into a Right variant of Either<Self, Self>
otherwise.