// rust-lstm 0.8.0
//
// A complete LSTM neural network library with training capabilities, multiple optimizers, and peephole variants.
// Documentation
use ndarray::Array2;
use rust_lstm::models::lstm_network::LSTMNetwork;
use rust_lstm::layers::linear::LinearLayer;
use rust_lstm::text::{TextVocabulary, CharacterEmbedding, sample_with_temperature};
use rust_lstm::training::LSTMTrainer;
use rust_lstm::loss::CrossEntropyLoss;
use rust_lstm::optimizers::Adam;
use std::collections::HashMap;

/// Character-level text model: vocabulary + embedding + LSTM + linear output
/// layer, plus the trainer that holds the trained copy of the network.
struct CharacterLSTM {
    /// Character <-> index mapping, built from the text given to `new`.
    vocab: TextVocabulary,
    /// Per-character embedding table, constructed with `(vocab.size(), embed_dim)`.
    embedding: CharacterEmbedding,
    /// The LSTM as constructed in `new`; `train` clones it into the trainer.
    network: LSTMNetwork,
    /// Linear layer constructed with `(hidden_size, vocab.size())`; `generate`
    /// feeds it the final hidden state to obtain logits.
    output_layer: LinearLayer,
    /// `None` until `train` has run; afterwards owns the trained network that
    /// `generate` clones for inference.
    trainer: Option<LSTMTrainer<CrossEntropyLoss, Adam>>,
    /// Hidden dimension of the LSTM; also the length of the training targets.
    hidden_size: usize,
    /// Number of characters per training window and per generation context.
    sequence_length: usize,
}

impl CharacterLSTM {
    /// Creates an untrained model whose vocabulary is built from `text`.
    ///
    /// * `text` - corpus handed to `TextVocabulary::from_text`.
    /// * `sequence_length` - characters per training window and per
    ///   generation context.
    /// * `hidden_size` - hidden dimension of the LSTM and input width of the
    ///   output layer.
    /// * `embed_dim` - embedding width, which is also the LSTM input width.
    ///
    /// Prints the vocabulary size and a one-line architecture summary.
    fn new(text: &str, sequence_length: usize, hidden_size: usize, embed_dim: usize) -> Self {
        let vocab = TextVocabulary::from_text(text);
        let vocab_size = vocab.size();
        let embedding = CharacterEmbedding::new(vocab_size, embed_dim);
        // NOTE(review): the trailing `1` is presumably the layer count — confirm.
        let network = LSTMNetwork::new(embed_dim, hidden_size, 1);
        let output_layer = LinearLayer::new(hidden_size, vocab_size);

        println!("Vocabulary size: {}", vocab_size);
        println!("Network: embed({}) -> LSTM({}) -> Linear({})", embed_dim, hidden_size, vocab_size);

        Self {
            vocab,
            embedding,
            network,
            output_layer,
            trainer: None,
            hidden_size,
            sequence_length,
        }
    }

    /// Cuts `text` into overlapping training windows.
    ///
    /// Every run of `sequence_length + 1` consecutive characters yields one
    /// `(inputs, targets)` pair with `sequence_length` steps each:
    ///
    /// * `inputs[j]` is the embedding of character `j` as an `(n, 1)` column.
    /// * `targets[j]` is a `(hidden_size, 1)` column that is zero except for a
    ///   `1.0` at row `next_index % hidden_size`, where `next_index` is the
    ///   vocabulary index of character `j + 1`.
    ///
    /// Characters missing from the vocabulary fall back to index 0. Returns an
    /// empty vector when `sequence_length` is 0 or `text` has fewer than
    /// `sequence_length + 1` characters.
    fn create_sequences(&self, text: &str) -> Vec<(Vec<Array2<f64>>, Vec<Array2<f64>>)> {
        // A zero-length window carries no training signal.
        if self.sequence_length == 0 {
            return Vec::new();
        }

        let chars: Vec<char> = text.chars().collect();
        // A window needs `sequence_length` inputs plus one look-ahead character.
        // `windows` yields every such run, including the final one that ends on
        // the last character of `text`.
        let window_len = self.sequence_length + 1;
        let mut sequences = Vec::with_capacity(chars.len().saturating_sub(self.sequence_length));

        for window in chars.windows(window_len) {
            let mut inputs = Vec::with_capacity(self.sequence_length);
            let mut targets = Vec::with_capacity(self.sequence_length);

            // Adjacent pairs inside the window are (current char, char to predict).
            for pair in window.windows(2) {
                let char_idx = self.vocab.char_to_index(pair[0]).unwrap_or(0);
                let next_idx = self.vocab.char_to_index(pair[1]).unwrap_or(0);

                let emb = self.embedding.lookup(char_idx);
                let input = Array2::from_shape_vec((emb.len(), 1), emb.to_vec()).unwrap();
                inputs.push(input);

                // NOTE(review): the target is sized to `hidden_size`, not the
                // vocabulary, presumably because the trainer compares against
                // the raw LSTM output. With a vocabulary larger than
                // `hidden_size` the modulo maps several characters onto the
                // same target row — confirm this is intended.
                let mut target = Array2::zeros((self.hidden_size, 1));
                target[[next_idx % self.hidden_size, 0]] = 1.0;
                targets.push(target);
            }

            sequences.push((inputs, targets));
        }

        sequences
    }

    /// Trains a clone of `self.network` on `text` for `epochs` epochs and
    /// stores the resulting trainer in `self.trainer`.
    ///
    /// The sequences from `create_sequences` are split 90/10 into training and
    /// validation sets; validation is skipped when its share is empty. Uses
    /// cross-entropy loss, `Adam::new(0.002)` and gradient clipping at 5.0.
    /// Returns early (after printing a message) when no sequence can be built.
    fn train(&mut self, text: &str, epochs: usize) {
        println!("Creating sequences...");
        let sequences = self.create_sequences(text);

        if sequences.is_empty() {
            println!("No sequences created!");
            return;
        }

        // 90/10 split. `sequences` is non-empty here, so clamping to 1 keeps at
        // least one training sequence even when truncation would give 0.
        let split = ((sequences.len() as f64 * 0.9) as usize).max(1);
        let (train, val) = sequences.split_at(split);

        println!("Training on {} sequences, validating on {}", train.len(), val.len());

        let loss_fn = CrossEntropyLoss;
        let optimizer = Adam::new(0.002);
        let mut trainer = LSTMTrainer::new(self.network.clone(), loss_fn, optimizer);

        let mut config = rust_lstm::training::TrainingConfig::default();
        config.epochs = epochs;
        // Roughly five progress reports per run. Clamped to 1 because
        // `epochs / 5` is 0 for fewer than five epochs.
        // NOTE(review): a zero interval would be a modulo-by-zero if the
        // trainer computes `epoch % print_every` — confirm.
        config.print_every = (epochs / 5).max(1);
        config.clip_gradient = Some(5.0);
        trainer = trainer.with_config(config);

        trainer.train(train, if val.is_empty() { None } else { Some(val) });
        self.trainer = Some(trainer);

        println!("Training complete!");
    }

    /// Extends `seed` by up to `length` sampled characters.
    ///
    /// For every new character the last `sequence_length` characters are run
    /// through a clone of the trained network starting from zero hidden and
    /// cell state. The final hidden state goes through `output_layer`, and the
    /// next index is drawn with `sample_with_temperature`. Seeds shorter than
    /// `sequence_length` are left-padded with spaces for context only; the
    /// padding is not part of the returned string.
    ///
    /// Returns `seed` unchanged when `train` has not been run. Sampled indices
    /// that have no character in the vocabulary are skipped, so the result may
    /// gain fewer than `length` characters.
    fn generate(&mut self, seed: &str, length: usize, temperature: f64) -> String {
        let mut inference_net = match self.trainer.as_ref() {
            Some(trainer) => trainer.network.clone(),
            None => return seed.to_string(),
        };
        inference_net.eval();

        let mut generated = seed.to_string();

        // Left-pad the context with spaces up to `sequence_length`.
        let pad = self.sequence_length.saturating_sub(seed.chars().count());
        let mut chars: Vec<char> = std::iter::repeat(' ').take(pad).chain(seed.chars()).collect();

        for _ in 0..length {
            let start = chars.len().saturating_sub(self.sequence_length);

            // State is reset for every generated character; only the sliding
            // window carries context forward.
            let mut h = Array2::zeros((self.hidden_size, 1));
            let mut c = Array2::zeros((self.hidden_size, 1));

            for &ch in &chars[start..] {
                let idx = self.vocab.char_to_index(ch).unwrap_or(0);
                let emb = self.embedding.lookup(idx);
                let input = Array2::from_shape_vec((emb.len(), 1), emb.to_vec()).unwrap();
                let (new_h, new_c) = inference_net.forward(&input, &h, &c);
                h = new_h;
                c = new_c;
            }

            // NOTE(review): `output_layer` is never handed to the trainer in
            // this file, so unless it is fitted elsewhere these logits come
            // from its initial weights — confirm.
            let logits_2d = self.output_layer.forward(&h);
            let logits = logits_2d.column(0).to_owned();

            let next_idx = sample_with_temperature(&logits, temperature);
            if let Some(next_char) = self.vocab.index_to_char(next_idx) {
                generated.push(next_char);
                chars.push(next_char);
            }
        }

        generated
    }
}

/// Returns the built-in demo corpora, keyed by domain name
/// (`"poetry"`, `"code"`, `"prose"`).
fn get_sample_texts() -> HashMap<&'static str, &'static str> {
    const POETRY: &str =
        "The woods are lovely, dark and deep, But I have promises to keep, \
         And miles to go before I sleep, And miles to go before I sleep. \
         Two roads diverged in a yellow wood, And sorry I could not travel both.";

    const CODE: &str =
        "fn main() { println!(\"Hello, world!\"); let x = 42; \
         if x > 10 { println!(\"x is greater\"); } for i in 0..5 { println!(\"i = {}\", i); } }";

    const PROSE: &str =
        "In a hole in the ground there lived a hobbit. Not a nasty, dirty, wet hole, \
         filled with the ends of worms and an oozy smell, nor yet a dry, bare, sandy hole.";

    HashMap::from([("poetry", POETRY), ("code", CODE), ("prose", PROSE)])
}

/// Demo driver: for every sample corpus, builds a model, trains it for ten
/// epochs and prints text generated at three sampling temperatures.
fn main() {
    println!("Character-Level LSTM with Text Utilities");
    println!("=========================================\n");

    let texts = get_sample_texts();

    for (domain, text) in &texts {
        // Truncate by characters rather than bytes: byte slicing panics when
        // the cut lands inside a multi-byte character.
        let preview: String = text.chars().take(60).collect();
        println!("Domain: {}", domain);
        println!("Text: {}...\n", preview);

        // sequence_length = 8, hidden_size = 32, embed_dim = 16
        let mut model = CharacterLSTM::new(text, 8, 32, 16);
        model.train(text, 10);

        // The seed is the same for every temperature, so build it once.
        let seed: String = text.chars().take(5).collect();

        println!("\nGenerated samples:");
        for temp in [0.5, 1.0, 1.5] {
            let output = model.generate(&seed, 50, temp);
            println!("  temp={:.1}: {}", temp, output);
        }

        println!("\n{}\n", "=".repeat(50));
    }
}