use rust_lstm::{LSTMNetwork, ModelMetadata, PersistentModel};
use std::fs;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    println!("🔍 LSTM Model Internal Structure Inspection");
    println!("============================================\n");

    println!("🏗️ Creating a simple LSTM model...");
    let input_size = 3;
    let hidden_size = 4;
    let num_layers = 2;
    let network = LSTMNetwork::new(input_size, hidden_size, num_layers);

    let metadata = ModelMetadata {
        model_name: "Language Model Example".to_string(),
        version: "1.0.0".to_string(),
        created_at: chrono::Utc::now().to_rfc3339(),
        input_size,
        hidden_size,
        num_layers,
        total_epochs: 0,
        final_loss: None,
        description: Some("Example LSTM for language modeling inspection".to_string()),
    };

    fs::create_dir_all("models")?;
    network.save("models/inspection_model.json", metadata.clone())?;

    println!("📊 WHAT'S STORED IN AN LSTM MODEL:");
    println!("=================================\n");

    inspect_architecture(&network);
    inspect_parameters(&network);
    inspect_file_contents()?;
    explain_llm_context();
    calculate_parameter_count(&network);

    Ok(())
}

fn inspect_architecture(network: &LSTMNetwork) {
    println!("🏗️ NETWORK ARCHITECTURE:");
    println!(" 📐 Input Size: {} (size of input vectors, e.g., word embeddings)", network.input_size);
    println!(" 🧠 Hidden Size: {} (memory capacity per layer)", network.hidden_size);
    println!(" 📚 Number of Layers: {} (depth of the network)", network.num_layers);
    println!(" 🔄 Total Cells: {} LSTM cells (one per layer)", network.num_layers);
    println!();
}

fn inspect_parameters(network: &LSTMNetwork) {
    println!("⚙️ PARAMETERS STORED FOR EACH LSTM CELL:");
    println!("==========================================");
    for (layer_idx, cell) in network.get_cells().iter().enumerate() {
        println!("📦 Layer {} LSTM Cell:", layer_idx);
        println!(" 🔢 w_ih (Input→Hidden weights): {}×{} = {} parameters",
            cell.w_ih.shape()[0], cell.w_ih.shape()[1],
            cell.w_ih.shape()[0] * cell.w_ih.shape()[1]);
        println!(" • Controls how input affects all 4 gates (i,f,g,o)");
        println!(" 🔄 w_hh (Hidden→Hidden weights): {}×{} = {} parameters",
            cell.w_hh.shape()[0], cell.w_hh.shape()[1],
            cell.w_hh.shape()[0] * cell.w_hh.shape()[1]);
        println!(" • Controls how previous hidden state affects gates");
        println!(" ➕ b_ih (Input biases): {}×{} = {} parameters",
            cell.b_ih.shape()[0], cell.b_ih.shape()[1],
            cell.b_ih.shape()[0] * cell.b_ih.shape()[1]);
        println!(" • Bias terms for input transformations");
        println!(" ➕ b_hh (Hidden biases): {}×{} = {} parameters",
            cell.b_hh.shape()[0], cell.b_hh.shape()[1],
            cell.b_hh.shape()[0] * cell.b_hh.shape()[1]);
        println!(" • Bias terms for hidden transformations");
        println!(" 📏 Hidden Size: {}", cell.hidden_size);
        println!();
    }
    println!("🧮 WHAT THESE PARAMETERS REPRESENT:");
    println!(" 🚪 4 Gates per cell (each gets 1/4 of the weights):");
    println!(" • Input Gate (i): Decides what new information to store");
    println!(" • Forget Gate (f): Decides what to discard from memory");
    println!(" • Cell Gate (g): Creates candidate values to add");
    println!(" • Output Gate (o): Decides what parts of memory to output");
    println!();
}
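
// The gate breakdown printed above can be made concrete in a few lines of
// code. This is a minimal, self-contained sketch of one LSTM cell step with
// scalar state (hidden_size = 1), using std-only math. The per-gate weights
// are made-up illustrative values, not rust_lstm's internals; in the real
// cell they correspond to rows of w_ih / w_hh and entries of b_ih / b_hh.
#[allow(dead_code)]
fn lstm_cell_step_sketch(x: f64, h_prev: f64, c_prev: f64) -> (f64, f64) {
    // Hypothetical scalar weights (w_x*, w_h*) and biases (b_*), one triple per gate.
    let (w_xi, w_hi, b_i) = (0.5, 0.1, 0.0); // input gate
    let (w_xf, w_hf, b_f) = (0.4, 0.2, 1.0); // forget gate (bias often initialized high)
    let (w_xg, w_hg, b_g) = (0.3, 0.3, 0.0); // cell (candidate) gate
    let (w_xo, w_ho, b_o) = (0.2, 0.4, 0.0); // output gate

    let sigmoid = |z: f64| 1.0 / (1.0 + (-z).exp());

    let i = sigmoid(w_xi * x + w_hi * h_prev + b_i); // what new information to store
    let f = sigmoid(w_xf * x + w_hf * h_prev + b_f); // what old memory to keep
    let g = (w_xg * x + w_hg * h_prev + b_g).tanh(); // candidate values to add
    let o = sigmoid(w_xo * x + w_ho * h_prev + b_o); // what memory to expose

    let c = f * c_prev + i * g; // updated cell state (long-term memory)
    let h = o * c.tanh(); // updated hidden state (the cell's output)
    (h, c)
}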

fn inspect_file_contents() -> Result<(), Box<dyn std::error::Error>> {
    println!("📄 ACTUAL FILE CONTENTS:");
    println!("========================");

    let file_size = fs::metadata("models/inspection_model.json")?.len();
    println!(" 📊 File size: {} bytes", file_size);

    let content = fs::read_to_string("models/inspection_model.json")?;
    let lines: Vec<&str> = content.lines().collect();
    println!(" 📋 JSON Structure (first 30 lines):");
    for (i, line) in lines.iter().enumerate().take(30) {
        println!(" {}: {}", i + 1, line);
    }
    if lines.len() > 30 {
        println!(" ... ({} more lines)", lines.len() - 30);
    }
    println!();
    Ok(())
}

fn explain_llm_context() {
    println!("🤖 FOR LARGE LANGUAGE MODELS (LLMs):");
    println!("=====================================");
    println!("If you trained an LSTM-based LLM, the model would store:");
    println!();
    println!("📚 LEARNED LANGUAGE PATTERNS:");
    println!(" • Word relationships and dependencies");
    println!(" • Grammar and syntax patterns");
    println!(" • Semantic associations between concepts");
    println!(" • Long-term dependencies in text");
    println!();
    println!("🧠 HOW PATTERNS ARE ENCODED:");
    println!(" • w_ih weights: How input words influence gates");
    println!(" - Input gate: Which words trigger memory updates");
    println!(" - Forget gate: Which contexts cause forgetting");
    println!(" - Cell gate: What new information words contribute");
    println!(" - Output gate: What words trigger specific outputs");
    println!();
    println!(" • w_hh weights: How previous context influences current processing");
    println!(" - Encodes how past words affect current decisions");
    println!(" - Captures sequential dependencies and patterns");
    println!(" - Stores grammatical and syntactic relationships");
    println!();
    println!(" • Biases: Default tendencies and adjustments");
    println!(" - Language-specific biases (e.g., word order preferences)");
    println!(" - Default gate behaviors for the language");
    println!();
    println!("📖 EXAMPLE: Training on 'The cat sat on the mat':");
    println!(" • Weights learn that 'The' often precedes nouns");
    println!(" • 'cat' + 'sat' pattern gets encoded in w_hh");
    println!(" • 'on the' creates a strong forward dependency");
    println!(" • Sentence boundaries trigger memory resets");
    println!();
    println!("💾 METADATA ALSO STORED:");
    println!(" • Training information (epochs, loss, date)");
    println!(" • Architecture details (vocabulary size, embedding dim)");
    println!(" • Model version and description");
    println!(" • Performance metrics");
    println!();
}
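
// To see the "previous context" role of w_hh described above in action, this
// toy demo (not called from main) feeds a short token sequence through the
// scalar cell sketch defined earlier. The token values are arbitrary
// stand-ins for word embeddings; the point is that h and c carry earlier
// inputs forward into every later step.
#[allow(dead_code)]
fn context_carryover_demo() {
    let tokens = [0.9_f64, -0.3, 0.7, 0.1, -0.8]; // e.g., "The cat sat on the"
    let (mut h, mut c) = (0.0_f64, 0.0_f64);
    for (step, &x) in tokens.iter().enumerate() {
        let (h_next, c_next) = lstm_cell_step_sketch(x, h, c);
        h = h_next;
        c = c_next;
        println!("step {}: input {:+.1} -> hidden {:+.4}", step, x, h);
    }
}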

fn calculate_parameter_count(network: &LSTMNetwork) {
    println!("🔢 TOTAL PARAMETER COUNT:");
    println!("=========================");

    let mut total_params = 0;
    for (layer_idx, cell) in network.get_cells().iter().enumerate() {
        let w_ih_params = cell.w_ih.len();
        let w_hh_params = cell.w_hh.len();
        let b_ih_params = cell.b_ih.len();
        let b_hh_params = cell.b_hh.len();
        let layer_params = w_ih_params + w_hh_params + b_ih_params + b_hh_params;
        total_params += layer_params;
        println!(" Layer {}: {} parameters", layer_idx, layer_params);
    }

    println!(" 🎯 Total: {} trainable parameters", total_params);
    println!(" 💾 Memory: ~{:.1} KB (f64 precision)", total_params as f64 * 8.0 / 1024.0);
    println!();

    println!("📊 COMPARISON TO MODERN LLMs:");
    println!(" • GPT-3: ~175 billion parameters");
    println!(" • This LSTM: {} parameters", total_params);
    println!(" • Ratio: This model is {:.0}x smaller", 175_000_000_000.0 / total_params as f64);
    println!();

    println!("💡 SCALING FOR LLMs:");
    println!(" For a production LSTM LLM you might use:");
    println!(" • Input size: 512-1024 (embedding dimension)");
    println!(" • Hidden size: 1024-4096 (memory capacity)");
    println!(" • Layers: 6-12 (depth for complex patterns)");
    println!(" • Vocabulary: 50,000-100,000 tokens");

    let llm_input = 512;
    let llm_hidden = 2048;
    let llm_layers = 8;
    // Per-layer count: 4 gates × (input→hidden weights + hidden→hidden weights)
    // plus the two bias vectors (4 × hidden each).
    let llm_params_per_layer =
        (llm_input * llm_hidden * 4) + (llm_hidden * llm_hidden * 4) + (llm_hidden * 4 * 2);
    // Rough estimate: this reuses the first layer's shape for every layer,
    // although deeper layers actually take the previous layer's hidden state
    // as input (see exact_lstm_param_count below for the exact version).
    let llm_total_params = llm_params_per_layer * llm_layers;
    println!(" Example LSTM LLM ({} layers, {} hidden): ~{:.1}M parameters",
        llm_layers, llm_hidden, llm_total_params as f64 / 1_000_000.0);
}
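
// A hedged, exact counterpart to the estimate printed above: layer 0 sees
// `input` features, while every deeper layer takes the `hidden`-sized output
// of the layer below as its input. Illustrative only; it assumes the
// standard LSTM layout (4 gates, two bias vectors per cell) used throughout
// this example.
#[allow(dead_code)]
fn exact_lstm_param_count(input: usize, hidden: usize, layers: usize) -> usize {
    (0..layers)
        .map(|layer| {
            let in_features = if layer == 0 { input } else { hidden };
            // w_ih: (4*hidden × in_features), w_hh: (4*hidden × hidden),
            // b_ih and b_hh: 4*hidden parameters each.
            4 * hidden * in_features + 4 * hidden * hidden + 2 * (4 * hidden)
        })
        .sum()
}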