use candle_core::Device;
use intellichip_rs::{TRMConfig, data::{NumpyDataset, NumpyDataLoader}};
use intellichip_rs::training::{Trainer, TrainingConfig};
fn main() -> anyhow::Result<()> {
env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init();
log::info!("=== TinyRecursiveModel - Sudoku Parity Training ===");
log::info!("Goal: Match Python TRM performance (75-87% accuracy)");
let device = if candle_core::utils::cuda_is_available() {
Device::new_cuda(0)?
} else {
Device::Cpu
};
log::info!("Using device: {:?}", device);
let data_path = "E:\\repos\\TinyRecursiveModels\\data\\sudoku-extreme-100-aug-1000\\train\\";
log::info!("Loading Sudoku dataset from: {}", data_path);
let dataset = NumpyDataset::from_directory(data_path)?;
log::info!("Dataset loaded:");
log::info!(" - Total examples: {}", dataset.len());
log::info!(" - Vocab size: {}", dataset.vocab_size());
log::info!(" - Sequence length: {}", dataset.seq_len());
log::info!(" - Description: {}", dataset.metadata().description);
let batch_size = if device.is_cuda() { 16 } else { 16 };
let mut dataloader = NumpyDataLoader::new(dataset, batch_size, true);
log::info!("Data loader created:");
log::info!(" - Batch size: {}", batch_size);
log::info!(" - Num batches: {}", dataloader.num_batches());
let model_config = TRMConfig {
vocab_size: dataloader.dataset().vocab_size(), num_outputs: dataloader.dataset().vocab_size(), hidden_size: 512, h_cycles: 2, l_cycles: 4, l_layers: 2, num_heads: 8, expansion: 4.0, pos_encodings: "rope".to_string(),
mlp_t: false,
halt_max_steps: 10,
dropout: 0.0, };
log::info!("Model configuration (OPTIMIZED - H=2, L=4): {:#?}", model_config);
let embed_params = model_config.vocab_size * model_config.hidden_size;
let layer_params = model_config.hidden_size * model_config.hidden_size * 4 * model_config.l_layers;
let head_params = model_config.hidden_size * model_config.num_outputs;
let total_params = embed_params + layer_params + head_params;
log::info!("Approximate parameters: ~{:.2}M", total_params as f64 / 1_000_000.0);
let num_batches = dataloader.num_batches();
let num_epochs = 10; let total_steps = num_batches * num_epochs;
let training_config = TrainingConfig {
num_epochs,
batch_size,
learning_rate: 1e-4, lr_min: 1e-4, warmup_steps: 2000, total_steps,
weight_decay: 0.1, grad_clip: Some(1.0), ema_decay: 0.999, save_every: 10000, eval_every: 1000,
checkpoint_dir: "checkpoints_sudoku".to_string(),
};
log::info!("Training configuration (Python-matched):");
log::info!(" - Epochs: {}", training_config.num_epochs);
log::info!(" - Batch size: {}", training_config.batch_size);
log::info!(" - Learning rate: {:.6}", training_config.learning_rate);
log::info!(" - LR min: {:.6}", training_config.lr_min);
log::info!(" - Warmup steps: {}", training_config.warmup_steps);
log::info!(" - Total steps: {}", training_config.total_steps);
log::info!(" - Weight decay: {}", training_config.weight_decay);
log::info!(" - EMA decay: {}", training_config.ema_decay);
log::info!("\n=== Expected Performance ===");
log::info!("Target accuracy: 75-87% (Python baseline)");
log::info!("Initial loss: ~2.4 (ln(11) for random init)");
log::info!("Training time: 8-12 hours on CPU (estimated)");
log::info!("Checkpoint directory: {}", training_config.checkpoint_dir);
log::info!("\nInitializing trainer...");
let mut trainer = Trainer::new(model_config, training_config, device)?;
log::info!("\nStarting training...");
log::info!("This will take a while. Monitor loss convergence.");
log::info!("Press Ctrl+C to stop (checkpoints are saved every 5000 steps)");
trainer.train(&mut dataloader)?;
log::info!("\n=== Training Complete ===");
log::info!("Check final checkpoint for accuracy evaluation");
log::info!("Compare training curve with Python results");
Ok(())
}