use ndarray::Array;
use rand::{Rng, RngExt};
use scirs2_neural::error::Result;
use scirs2_neural::{
models::architectures::{RNNCellType, Seq2Seq, Seq2SeqConfig},
prelude::*,
};
#[allow(dead_code)]
fn main() -> Result<()> {
println!("Sequence-to-Sequence (Seq2Seq) Model Example");
println!("--------------------------------------------");
let src_vocab_size = 10000; let tgt_vocab_size = 8000; let inputshape = [2, 10];
let mut input_seq = Array::<f32>::zeros(inputshape).into_dyn();
let mut rng = rand::rng();
for elem in input_seq.iter_mut() {
*elem = (rng.random_range(0.0..1.0) * (src_vocab_size as f32 - 1.0)).floor();
}
let targetshape = [2, 8];
let mut target_seq = Array::<f32>::zeros(targetshape).into_dyn();
for elem in target_seq.iter_mut() {
*elem = (rng.random_range(0.0..1.0) * (tgt_vocab_size as f32 - 1.0)).floor();
println!("\nCreating Basic Translation Model...");
let mut translation_model = Seq2Seq::create_translation_model(
src_vocab_size..tgt_vocab_size,
256, println!("Running forward pass with teacher forcing...");
let train_output = translation_model.forward_train(&input_seq, &target_seq)?;
println!("Training output shape: {:?}", train_output.shape());
println!("\nGenerating sequences...");
let generated = translation_model.generate(
&input_seq,
Some(15), 1, Some(2), println!("Generated sequence shape: {:?}", generated.shape());
println!("Generated sequences (token IDs):");
for b in 0..generated.shape()[0] {
print!(" Sequence {}: ", b);
for t in 0..generated.shape()[1] {
if generated[[b, t]] > 0.0 {
print!("{} ", generated[[b, t]]);
}
}
println!();
println!("\nCreating Custom Seq2Seq Model...");
let custom_config = Seq2SeqConfig {
input_vocab_size: src_vocab_size,
output_vocab_size: tgt_vocab_size,
embedding_dim: 128,
hidden_dim: 256,
num_layers: 2,
encoder_cell_type: RNNCellType::GRU,
decoder_cell_type: RNNCellType::LSTM, bidirectional_encoder: true,
use_attention: true,
dropout_rate: 0.2,
max_seq_len: 50,
};
let custom_model = Seq2Seq::<f32>::new(custom_config)?;
println!("Custom model created successfully.");
println!("\nCreating Small Seq2Seq Model...");
let small_model = Seq2Seq::create_small_model(src_vocab_size, tgt_vocab_size)?;
let small_generated = small_model.generate(&input_seq, Some(10), 1, Some(2))?;
println!(
"Small model generated sequence shape: {:?}",
small_generated.shape()
);
println!("\nDemonstrating Training/Inference Mode Switching:");
translation_model.set_training(true);
println!("Is in training mode: {}", translation_model.is_training());
translation_model.set_training(false);
"Is in training mode after switching: {}",
translation_model.is_training()
println!("\nModel Parameter Counts:");
"Translation model parameters: {}",
translation_model.params().len()
println!("Custom model parameters: {}", custom_model.params().len());
println!("Small model parameters: {}", small_model.params().len());
println!("\nSeq2Seq Example Completed Successfully!");
Ok(())
}