pub struct NeuralNetwork<L, C>
where
    L: Layer,
{
    pub layers: Vec<L>,
    pub cost: C,
}
Fields
layers: Vec<L>
cost: C
Implementations
impl<L, C, D> NeuralNetwork<L, C>
where
    L: Layer<Input = D, Output = D> + Serialize + for<'de> Deserialize<'de>,
    C: Cost<D>,
    D: Dimension + RemoveAxis,
impl<L, C, D> NeuralNetwork<L, C>
where
    L: Layer<Input = D, Output = D> + Serialize + for<'de> Deserialize<'de>,
    C: Cost<D>,
    D: Dimension + RemoveAxis,
pub fn new(layers: Vec<L>, cost: C) -> Self
pub fn new(layers: Vec<L>, cost: C) -> Self
Examples found in repository?
examples/mnist.rs (line 75)
61fn main() {
62 let model_path = "mnist_model.bin";
63
64 let mut nn = if PathBuf::from(model_path).exists() {
65 println!("Loading existing model...");
66 NeuralNetwork::load(model_path, MSE).expect("Failed to load model")
67 } else {
68 println!("Creating new model...");
69 let output_size = 10;
70 let input_size = 28 * 28;
71
72 let dense_layer_1 = DenseLayer::new(input_size, 128, Sigmoid);
73 let dense_layer_2 = DenseLayer::new(128, output_size, Sigmoid);
74
75 NeuralNetwork::new(vec![dense_layer_1, dense_layer_2], MSE)
76 };
77
78 let train_images_path = PathBuf::from("./train-images.idx3-ubyte");
79 let train_labels_path = PathBuf::from("./train-labels.idx1-ubyte");
80
81 let (images, labels) = match load_mnist_data(train_images_path, train_labels_path) {
82 Ok(data) => data,
83 Err(e) => {
84 eprintln!("Error loading MNIST data: {}", e);
85 return;
86 }
87 };
88
89 println!("Loaded {} training images", images.shape()[0]);
90
91 let learning_rate = 0.01;
92 let num_epochs = 10;
93 let batch_size = 32;
94
95 println!("\nTraining with batch size {}...", batch_size);
96 nn.train(&images, &labels, learning_rate, num_epochs, batch_size);
97
98 println!("\nSaving model to {}...", model_path);
99 nn.save(model_path).expect("Failed to save model");
100
101 let test_images_path = PathBuf::from("./t10k-images.idx3-ubyte");
102 let test_labels_path = PathBuf::from("./t10k-labels.idx1-ubyte");
103
104 let (test_images, test_labels) = match load_mnist_data(test_images_path, test_labels_path) {
105 Ok(data) => data,
106 Err(e) => {
107 eprintln!("Error loading test data: {}", e);
108 return;
109 }
110 };
111
112 let accuracy = nn.accuracy(&test_images, &test_labels);
113 println!("\nTest accuracy: {:.2}%", accuracy * 100.0);
114}
pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<()>
pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<()>
Examples found in repository?
examples/mnist.rs (line 99)
61fn main() {
62 let model_path = "mnist_model.bin";
63
64 let mut nn = if PathBuf::from(model_path).exists() {
65 println!("Loading existing model...");
66 NeuralNetwork::load(model_path, MSE).expect("Failed to load model")
67 } else {
68 println!("Creating new model...");
69 let output_size = 10;
70 let input_size = 28 * 28;
71
72 let dense_layer_1 = DenseLayer::new(input_size, 128, Sigmoid);
73 let dense_layer_2 = DenseLayer::new(128, output_size, Sigmoid);
74
75 NeuralNetwork::new(vec![dense_layer_1, dense_layer_2], MSE)
76 };
77
78 let train_images_path = PathBuf::from("./train-images.idx3-ubyte");
79 let train_labels_path = PathBuf::from("./train-labels.idx1-ubyte");
80
81 let (images, labels) = match load_mnist_data(train_images_path, train_labels_path) {
82 Ok(data) => data,
83 Err(e) => {
84 eprintln!("Error loading MNIST data: {}", e);
85 return;
86 }
87 };
88
89 println!("Loaded {} training images", images.shape()[0]);
90
91 let learning_rate = 0.01;
92 let num_epochs = 10;
93 let batch_size = 32;
94
95 println!("\nTraining with batch size {}...", batch_size);
96 nn.train(&images, &labels, learning_rate, num_epochs, batch_size);
97
98 println!("\nSaving model to {}...", model_path);
99 nn.save(model_path).expect("Failed to save model");
100
101 let test_images_path = PathBuf::from("./t10k-images.idx3-ubyte");
102 let test_labels_path = PathBuf::from("./t10k-labels.idx1-ubyte");
103
104 let (test_images, test_labels) = match load_mnist_data(test_images_path, test_labels_path) {
105 Ok(data) => data,
106 Err(e) => {
107 eprintln!("Error loading test data: {}", e);
108 return;
109 }
110 };
111
112 let accuracy = nn.accuracy(&test_images, &test_labels);
113 println!("\nTest accuracy: {:.2}%", accuracy * 100.0);
114}
pub fn load<P: AsRef<Path>>(path: P, cost: C) -> Result<Self>
pub fn load<P: AsRef<Path>>(path: P, cost: C) -> Result<Self>
Examples found in repository?
examples/mnist.rs (line 66)
61fn main() {
62 let model_path = "mnist_model.bin";
63
64 let mut nn = if PathBuf::from(model_path).exists() {
65 println!("Loading existing model...");
66 NeuralNetwork::load(model_path, MSE).expect("Failed to load model")
67 } else {
68 println!("Creating new model...");
69 let output_size = 10;
70 let input_size = 28 * 28;
71
72 let dense_layer_1 = DenseLayer::new(input_size, 128, Sigmoid);
73 let dense_layer_2 = DenseLayer::new(128, output_size, Sigmoid);
74
75 NeuralNetwork::new(vec![dense_layer_1, dense_layer_2], MSE)
76 };
77
78 let train_images_path = PathBuf::from("./train-images.idx3-ubyte");
79 let train_labels_path = PathBuf::from("./train-labels.idx1-ubyte");
80
81 let (images, labels) = match load_mnist_data(train_images_path, train_labels_path) {
82 Ok(data) => data,
83 Err(e) => {
84 eprintln!("Error loading MNIST data: {}", e);
85 return;
86 }
87 };
88
89 println!("Loaded {} training images", images.shape()[0]);
90
91 let learning_rate = 0.01;
92 let num_epochs = 10;
93 let batch_size = 32;
94
95 println!("\nTraining with batch size {}...", batch_size);
96 nn.train(&images, &labels, learning_rate, num_epochs, batch_size);
97
98 println!("\nSaving model to {}...", model_path);
99 nn.save(model_path).expect("Failed to save model");
100
101 let test_images_path = PathBuf::from("./t10k-images.idx3-ubyte");
102 let test_labels_path = PathBuf::from("./t10k-labels.idx1-ubyte");
103
104 let (test_images, test_labels) = match load_mnist_data(test_images_path, test_labels_path) {
105 Ok(data) => data,
106 Err(e) => {
107 eprintln!("Error loading test data: {}", e);
108 return;
109 }
110 };
111
112 let accuracy = nn.accuracy(&test_images, &test_labels);
113 println!("\nTest accuracy: {:.2}%", accuracy * 100.0);
114}
pub fn forward(
    &mut self,
    input: &ArrayBase<OwnedRepr<f32>, D>,
) -> ArrayBase<OwnedRepr<f32>, D>
pub fn backward(
    &mut self,
    grad_output: &ArrayBase<OwnedRepr<f32>, D>,
    learning_rate: f32,
)
pub fn train(
    &mut self,
    inputs: &ArrayBase<OwnedRepr<f32>, D>,
    targets: &ArrayBase<OwnedRepr<f32>, D>,
    learning_rate: f32,
    epochs: usize,
    batch_size: usize,
)
where
    D: RemoveAxis,
pub fn train(
    &mut self,
    inputs: &ArrayBase<OwnedRepr<f32>, D>,
    targets: &ArrayBase<OwnedRepr<f32>, D>,
    learning_rate: f32,
    epochs: usize,
    batch_size: usize,
)
where
    D: RemoveAxis,
Examples found in repository?
examples/mnist.rs (line 96)
61fn main() {
62 let model_path = "mnist_model.bin";
63
64 let mut nn = if PathBuf::from(model_path).exists() {
65 println!("Loading existing model...");
66 NeuralNetwork::load(model_path, MSE).expect("Failed to load model")
67 } else {
68 println!("Creating new model...");
69 let output_size = 10;
70 let input_size = 28 * 28;
71
72 let dense_layer_1 = DenseLayer::new(input_size, 128, Sigmoid);
73 let dense_layer_2 = DenseLayer::new(128, output_size, Sigmoid);
74
75 NeuralNetwork::new(vec![dense_layer_1, dense_layer_2], MSE)
76 };
77
78 let train_images_path = PathBuf::from("./train-images.idx3-ubyte");
79 let train_labels_path = PathBuf::from("./train-labels.idx1-ubyte");
80
81 let (images, labels) = match load_mnist_data(train_images_path, train_labels_path) {
82 Ok(data) => data,
83 Err(e) => {
84 eprintln!("Error loading MNIST data: {}", e);
85 return;
86 }
87 };
88
89 println!("Loaded {} training images", images.shape()[0]);
90
91 let learning_rate = 0.01;
92 let num_epochs = 10;
93 let batch_size = 32;
94
95 println!("\nTraining with batch size {}...", batch_size);
96 nn.train(&images, &labels, learning_rate, num_epochs, batch_size);
97
98 println!("\nSaving model to {}...", model_path);
99 nn.save(model_path).expect("Failed to save model");
100
101 let test_images_path = PathBuf::from("./t10k-images.idx3-ubyte");
102 let test_labels_path = PathBuf::from("./t10k-labels.idx1-ubyte");
103
104 let (test_images, test_labels) = match load_mnist_data(test_images_path, test_labels_path) {
105 Ok(data) => data,
106 Err(e) => {
107 eprintln!("Error loading test data: {}", e);
108 return;
109 }
110 };
111
112 let accuracy = nn.accuracy(&test_images, &test_labels);
113 println!("\nTest accuracy: {:.2}%", accuracy * 100.0);
114}
pub fn accuracy(
    &mut self,
    test_data: &ArrayBase<OwnedRepr<f32>, D>,
    test_labels: &ArrayBase<OwnedRepr<f32>, D>,
) -> f32
where
    D: RemoveAxis,
pub fn accuracy(
    &mut self,
    test_data: &ArrayBase<OwnedRepr<f32>, D>,
    test_labels: &ArrayBase<OwnedRepr<f32>, D>,
) -> f32
where
    D: RemoveAxis,
Examples found in repository?
examples/mnist.rs (line 112)
61fn main() {
62 let model_path = "mnist_model.bin";
63
64 let mut nn = if PathBuf::from(model_path).exists() {
65 println!("Loading existing model...");
66 NeuralNetwork::load(model_path, MSE).expect("Failed to load model")
67 } else {
68 println!("Creating new model...");
69 let output_size = 10;
70 let input_size = 28 * 28;
71
72 let dense_layer_1 = DenseLayer::new(input_size, 128, Sigmoid);
73 let dense_layer_2 = DenseLayer::new(128, output_size, Sigmoid);
74
75 NeuralNetwork::new(vec![dense_layer_1, dense_layer_2], MSE)
76 };
77
78 let train_images_path = PathBuf::from("./train-images.idx3-ubyte");
79 let train_labels_path = PathBuf::from("./train-labels.idx1-ubyte");
80
81 let (images, labels) = match load_mnist_data(train_images_path, train_labels_path) {
82 Ok(data) => data,
83 Err(e) => {
84 eprintln!("Error loading MNIST data: {}", e);
85 return;
86 }
87 };
88
89 println!("Loaded {} training images", images.shape()[0]);
90
91 let learning_rate = 0.01;
92 let num_epochs = 10;
93 let batch_size = 32;
94
95 println!("\nTraining with batch size {}...", batch_size);
96 nn.train(&images, &labels, learning_rate, num_epochs, batch_size);
97
98 println!("\nSaving model to {}...", model_path);
99 nn.save(model_path).expect("Failed to save model");
100
101 let test_images_path = PathBuf::from("./t10k-images.idx3-ubyte");
102 let test_labels_path = PathBuf::from("./t10k-labels.idx1-ubyte");
103
104 let (test_images, test_labels) = match load_mnist_data(test_images_path, test_labels_path) {
105 Ok(data) => data,
106 Err(e) => {
107 eprintln!("Error loading test data: {}", e);
108 return;
109 }
110 };
111
112 let accuracy = nn.accuracy(&test_images, &test_labels);
113 println!("\nTest accuracy: {:.2}%", accuracy * 100.0);
114}
Auto Trait Implementations
impl<L, C> Freeze for NeuralNetwork<L, C>
where
    C: Freeze,
impl<L, C> RefUnwindSafe for NeuralNetwork<L, C>
where
    C: RefUnwindSafe,
    L: RefUnwindSafe,
impl<L, C> Send for NeuralNetwork<L, C>
impl<L, C> Sync for NeuralNetwork<L, C>
impl<L, C> Unpin for NeuralNetwork<L, C>
impl<L, C> UnwindSafe for NeuralNetwork<L, C>
where
    C: UnwindSafe,
    L: UnwindSafe,
Blanket Implementations
impl<T> BorrowMut<T> for T
where
    T: ?Sized,
impl<T> BorrowMut<T> for T
where
    T: ?Sized,
fn borrow_mut(&mut self) -> &mut T
fn borrow_mut(&mut self) -> &mut T
Mutably borrows from an owned value.