use ndarray::arr2;
use rust_lstm::layers::linear::LinearLayer;
use rust_lstm::optimizers::{SGD, Adam};
use rust_lstm::models::lstm_network::LSTMNetwork;
/// Trains a standalone `LinearLayer` with SGD on a tiny two-sample,
/// three-class target using mean-squared error, printing the loss
/// every other epoch and the final predictions.
fn basic_classification_example() {
    println!("=== Basic Classification Example ===");

    // 4 input features mapped to 3 output classes; columns are samples.
    let mut linear = LinearLayer::new(4, 3);
    let mut optimizer = SGD::new(0.1);

    // Two samples stored column-wise: shape (features, samples) = (4, 2).
    let input = arr2(&[
        [1.0, 0.5],
        [0.8, -0.2],
        [1.2, 0.9],
        [-0.1, 0.3],
    ]);
    // One-hot-style targets, shape (classes, samples) = (3, 2).
    let targets = arr2(&[
        [1.0, 0.0],
        [0.0, 1.0],
        [0.0, 0.0],
    ]);

    println!("Input shape: {:?}", input.shape());
    println!("Target shape: {:?}", targets.shape());

    for epoch in 0..10 {
        let output = linear.forward(&input);

        // Mean-squared error over all elements.
        let diff = &output - &targets;
        let n = output.len() as f64;
        let loss = diff.map(|x| x * x).sum() / n;

        // d(MSE)/d(output) = 2 * (output - target) / n, element-wise.
        let grad_output = diff.map(|&x| 2.0 * x / n);
        let (gradients, _input_grad) = linear.backward(&grad_output);
        linear.update_parameters(&gradients, &mut optimizer, "classifier");

        if epoch % 2 == 0 {
            println!("Epoch {}: Loss = {:.4}", epoch, loss);
        }
    }

    let final_output = linear.forward(&input);
    println!("Final output:\n{:.3}", final_output);
    println!("Target:\n{:.3}", targets);
    println!();
}
/// Runs a short sequence through an LSTM and classifies the final
/// hidden state with a `LinearLayer`. Only the linear classifier is
/// trained here — the gradient w.r.t. the LSTM output is computed by
/// `backward` but intentionally discarded (`_lstm_grad`).
fn lstm_with_linear_example() {
    println!("=== LSTM + LinearLayer Example ===");

    // 5 input features, 8 hidden units, 1 LSTM layer.
    let mut lstm = LSTMNetwork::new(5, 8, 1);
    // Maps the 8-dim hidden state to 3 class scores.
    let mut classifier = LinearLayer::new(8, 3);
    let mut optimizer = Adam::new(0.001);

    // Four time steps, each a (5, 1) column vector.
    let sequence = vec![
        arr2(&[[1.0], [0.5], [0.2], [0.8], [0.1]]),
        arr2(&[[0.9], [0.6], [0.3], [0.7], [0.2]]),
        arr2(&[[0.8], [0.7], [0.4], [0.6], [0.3]]),
        arr2(&[[0.7], [0.8], [0.5], [0.5], [0.4]]),
    ];
    // One-hot class target, shape (3, 1).
    let target = arr2(&[[0.0], [1.0], [0.0]]);

    println!("Sequence length: {}", sequence.len());
    println!("Input features: {}", sequence[0].nrows());
    println!("LSTM hidden size: {}", 8);
    println!("Output classes: {}", target.nrows());

    for epoch in 0..20 {
        let (lstm_outputs, _) = lstm.forward_sequence_with_cache(&sequence);
        // Classify from the hidden state of the last time step.
        let last_hidden = &lstm_outputs.last().unwrap().0;
        let class_logits = classifier.forward(last_hidden);

        // MSE loss and its gradient w.r.t. the logits.
        let diff = &class_logits - &target;
        let n = class_logits.len() as f64;
        let loss = diff.map(|x| x * x).sum() / n;
        let grad_output = diff.map(|&x| 2.0 * x / n);

        // Only the classifier receives parameter updates; the LSTM stays fixed.
        let (linear_grads, _lstm_grad) = classifier.backward(&grad_output);
        classifier.update_parameters(&linear_grads, &mut optimizer, "classifier");

        if epoch % 5 == 0 {
            println!("Epoch {}: Loss = {:.4}", epoch, loss);
        }
    }

    // Report the trained classifier's output on the same sequence.
    let (final_lstm_outputs, _) = lstm.forward_sequence_with_cache(&sequence);
    let final_hidden = &final_lstm_outputs.last().unwrap().0;
    let final_prediction = classifier.forward(final_hidden);
    println!("Final prediction: [{:.3}, {:.3}, {:.3}]",
        final_prediction[[0, 0]], final_prediction[[1, 0]], final_prediction[[2, 0]]);
    println!("Target: [{:.3}, {:.3}, {:.3}]",
        target[[0, 0]], target[[1, 0]], target[[2, 0]]);
    println!();
}
fn multilayer_perceptron_example() {
println!("=== Multi-Layer Perceptron Example ===");
let mut layer1 = LinearLayer::new(2, 4);
let mut layer2 = LinearLayer::new(4, 4);
let mut layer3 = LinearLayer::new(4, 1);
let mut optimizer = Adam::new(0.01);
let inputs = arr2(&[
[0.0, 1.0, 0.0, 1.0], [0.0, 0.0, 1.0, 1.0] ]);
let targets = arr2(&[[0.0, 1.0, 1.0, 0.0]]);
println!("Training MLP on XOR problem...");
println!("Input shape: {:?}", inputs.shape());
println!("Target shape: {:?}", targets.shape());
for epoch in 0..100 {
let h1 = layer1.forward(&inputs);
let h1_relu = h1.map(|&x| if x > 0.0 { x } else { 0.0 });
let h2 = layer2.forward(&h1_relu);
let h2_relu = h2.map(|&x| if x > 0.0 { x } else { 0.0 });
let output = layer3.forward(&h2_relu);
let loss = (&output - &targets).map(|x| x * x).sum() / (output.len() as f64);
let grad_output = 2.0 * (&output - &targets) / (output.len() as f64);
let (grad3, grad_h2) = layer3.backward(&grad_output);
let grad_h2_relu = &grad_h2 * &h2.map(|&x| if x > 0.0 { 1.0 } else { 0.0 });
let (grad2, grad_h1) = layer2.backward(&grad_h2_relu);
let grad_h1_relu = &grad_h1 * &h1.map(|&x| if x > 0.0 { 1.0 } else { 0.0 });
let (grad1, _) = layer1.backward(&grad_h1_relu);
layer1.update_parameters(&grad1, &mut optimizer, "layer1");
layer2.update_parameters(&grad2, &mut optimizer, "layer2");
layer3.update_parameters(&grad3, &mut optimizer, "layer3");
if epoch % 20 == 0 {
println!("Epoch {}: Loss = {:.4}", epoch, loss);
}
}
let h1 = layer1.forward(&inputs);
let h1_relu = h1.map(|&x| if x > 0.0 { x } else { 0.0 });
let h2 = layer2.forward(&h1_relu);
let h2_relu = h2.map(|&x| if x > 0.0 { x } else { 0.0 });
let final_output = layer3.forward(&h2_relu);
println!("Final predictions:");
for i in 0..4 {
let input_vals = (inputs[[0, i]], inputs[[1, i]]);
let prediction = final_output[[0, i]];
let target_val = targets[[0, i]];
println!(" {:?} -> {:.3} (target: {:.1})", input_vals, prediction, target_val);
}
println!();
}
/// Demonstrates the three ways to construct a `LinearLayer`:
/// random initialization, all-zero initialization, and construction
/// from caller-supplied weight/bias matrices.
fn initialization_example() {
    println!("=== Initialization Methods Example ===");

    // Default constructor: randomly initialized weights.
    let layer_random = LinearLayer::new(3, 2);
    println!("Random initialization:");
    // partial_cmp is safe to unwrap here only because the initializer is
    // not expected to produce NaNs — TODO confirm against LinearLayer::new.
    println!(" Weight range: [{:.3}, {:.3}]",
        layer_random.weight.iter().min_by(|a, b| a.partial_cmp(b).unwrap()).unwrap(),
        layer_random.weight.iter().max_by(|a, b| a.partial_cmp(b).unwrap()).unwrap());

    // All-zero weights and biases.
    let layer_zeros = LinearLayer::new_zeros(3, 2);
    println!("Zero initialization:");
    println!(" All weights: {}", layer_zeros.weight.iter().all(|&x| x == 0.0));

    // Explicit (2, 3) weight matrix and (2, 1) bias column vector.
    // The matrices are moved directly into the layer: the previous
    // `.clone()` calls were redundant since the locals were never
    // used afterward (clippy::redundant_clone).
    let custom_weights = arr2(&[[1.0, 0.5, -0.2], [0.8, -0.1, 0.3]]);
    let custom_bias = arr2(&[[0.1], [-0.05]]);
    let layer_custom = LinearLayer::from_weights(custom_weights, custom_bias);
    println!("Custom initialization:");
    println!(" Custom weights shape: {:?}", layer_custom.weight.shape());
    println!(" Custom bias shape: {:?}", layer_custom.bias.shape());
    println!("Layer dimensions: {:?}", layer_custom.dimensions());
    println!("Number of parameters: {}", layer_custom.num_parameters());
    println!();
}
/// Entry point: runs each `LinearLayer` demonstration in sequence,
/// then prints a summary of takeaways.
fn main() {
println!("LinearLayer Examples");
println!("===================\n");
// Each example trains and prints to stdout independently.
basic_classification_example();
lstm_with_linear_example();
multilayer_perceptron_example();
initialization_example();
println!("All examples completed successfully! 🎉");
println!("\nKey takeaways:");
println!("- LinearLayer enables standard neural network architectures");
println!("- Works seamlessly with LSTM networks for classification");
println!("- Supports multiple initialization methods");
println!("- Integrates with all existing optimizers");
println!("- Essential for text generation and classification tasks");
}