use anyhow::Result;
use candle_core::Device;
use clap::{Parser, Subcommand};
use hermes_llm::config::TrainingConfig;
use hermes_llm::data::{DataLoader, Dataset};
use hermes_llm::tokenizer::{BPETrainer, Tokenizer};
use hermes_llm::training::{TextGenerator, Trainer};
use tracing::{Level, info, warn};
use tracing_subscriber::FmtSubscriber;
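/// Command-line interface for the hermes-llm binary.
///
/// Example invocation (illustrative paths):
/// `hermes-llm train --data corpus.txt --tokenizer tok.json --model tiny`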
#[derive(Parser)]
#[command(name = "hermes-llm")]
#[command(about = "Train LLMs from scratch in Rust")]
struct Cli {
#[command(subcommand)]
command: Commands,
}
#[derive(Subcommand)]
enum Commands {
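/// Train a model from scratch or from a checkpoint. With `--num-gpus > 1`,
/// this process acts as the launcher and re-spawns itself once per extra GPU.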
Train {
#[arg(short, long)]
data: Option<String>,
#[arg(short, long)]
tokenizer: String,
#[arg(short, long, default_value = "tiny")]
model: String,
#[arg(short, long, default_value = "checkpoints")]
output: String,
#[arg(short, long, default_value = "32")]
batch_size: usize,
#[arg(short, long, default_value = "1")]
epochs: usize,
#[arg(long, default_value = "256")]
seq_len: usize,
#[arg(long, default_value = "3e-4")]
lr: f64,
#[arg(long, default_value = "1")]
num_gpus: usize,
#[arg(long, default_value = "1")]
grad_accum: usize,
#[arg(long)]
checkpoint: Option<String>,
#[arg(long, default_value = "0")]
freeze_layers: usize,
#[arg(long)]
resume: bool,
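// Internal flags: set automatically when this binary re-launches itself as a
// distributed worker. Not meant to be passed by hand.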
#[arg(long, hide = true, default_value_t = usize::MAX)]
rank: usize,
#[arg(long, hide = true, default_value = "nccl_id.txt")]
comm_file: String,
},
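/// Train a BPE tokenizer from one or more text files.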
TrainTokenizer {
#[arg(short, long, num_args = 1..)]
input: Vec<String>,
#[arg(short, long)]
output: String,
#[arg(short, long, default_value = "32000")]
vocab_size: usize,
},
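/// Generate text from a trained checkpoint.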
Generate {
#[arg(short, long)]
checkpoint: String,
#[arg(long)]
config: String,
#[arg(short, long)]
tokenizer: String,
#[arg(short, long)]
prompt: String,
#[arg(short, long, default_value = "100")]
max_tokens: usize,
#[arg(long, default_value = "0.8")]
temperature: f64,
#[arg(long)]
top_k: Option<usize>,
// A plain `bool` flag with `default_value = "true"` can never be switched off;
// take an explicit value instead so `--gpu false` forces CPU.
#[arg(long, action = clap::ArgAction::Set, default_value_t = true)]
gpu: bool,
},
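/// Print the definition of a built-in or MAL-defined model.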
Info {
#[arg(short, long, default_value = "gpt2-small")]
model: String,
},
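/// Direct Preference Optimization fine-tuning on a preference dataset.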
Dpo {
#[arg(short, long)]
data: String,
#[arg(short, long)]
tokenizer: String,
#[arg(short, long)]
checkpoint: String,
#[arg(long)]
config: String,
#[arg(short, long, default_value = "checkpoints-dpo")]
output: String,
#[arg(short, long, default_value = "4")]
batch_size: usize,
#[arg(short, long, default_value = "1")]
epochs: usize,
#[arg(long, default_value = "5e-7")]
lr: f64,
#[arg(long, default_value = "0.1")]
beta: f64,
#[arg(long, default_value = "512")]
max_len: usize,
},
}
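/// Pick the compute device: Metal or CUDA when the matching feature is compiled
/// in, otherwise fall back to CPU. `gpu_id` is unused in CPU-only builds.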
#[allow(unused_variables)]
fn get_device(use_gpu: bool, gpu_id: usize) -> Result<Device> {
if use_gpu {
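// If both GPU features are somehow enabled, Metal wins: the CUDA block below
// becomes unreachable.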
#[cfg(feature = "metal")]
{
return Ok(Device::new_metal(gpu_id)?);
}
#[cfg(feature = "cuda")]
{
return Ok(Device::new_cuda(gpu_id)?);
}
#[cfg(not(any(feature = "metal", feature = "cuda")))]
{
warn!("No GPU feature enabled, using CPU. Build with --features metal or --features cuda");
return Ok(Device::Cpu);
}
}
Ok(Device::Cpu)
}
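/// Resolve a model definition: first try the built-in registry, then treat
/// `name` as a path to a MAL file, e.g. `get_model_def("tiny")` or a path like
/// `models/custom.mal` (illustrative).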
fn get_model_def(name: &str) -> Result<hermes_llm::ModelDef> {
if let Some(model_def) = hermes_llm::get_builtin_model(name) {
return Ok(model_def);
}
if std::path::Path::new(name).exists() {
match hermes_llm::parse_mal_file(name) {
Ok(model_def) => return Ok(model_def),
Err(e) => {
warn!("Failed to parse MAL file '{}': {}", name, e);
}
}
}
anyhow::bail!(
"Unknown model '{}'. Pass a built-in name ({:?}) or a path to a valid MAL file",
name,
hermes_llm::list_wellknown_models()
);
}
fn main() -> Result<()> {
let subscriber = FmtSubscriber::builder()
.with_max_level(Level::INFO)
.finish();
tracing::subscriber::set_global_default(subscriber)?;
let cli = Cli::parse();
match cli.command {
Commands::Train {
data,
tokenizer: tokenizer_path,
model,
output,
batch_size,
epochs,
seq_len,
lr,
num_gpus,
grad_accum,
checkpoint,
freeze_layers,
resume,
rank,
comm_file,
} => {
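// rank == usize::MAX marks the launcher process: when --num-gpus > 1 it spawns
// one worker per additional GPU, then participates itself as rank 0.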
let children_handle = if num_gpus > 1 && rank == usize::MAX {
use std::process::{Command, Stdio};
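// Pin the launcher (future rank 0) to GPU 0. std::env::set_var is unsafe
// as of Rust 2024, hence the unsafe block.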
unsafe { std::env::set_var("CUDA_VISIBLE_DEVICES", "0") };
let data_path = data
.as_ref()
.ok_or_else(|| anyhow::anyhow!("--data is required for multi-GPU training"))?;
info!("=== Distributed Training ===");
info!("GPUs: {}", num_gpus);
info!("Model: {}", model);
info!(
"Effective batch: {} ({} x {} x {})",
batch_size * grad_accum * num_gpus,
batch_size,
grad_accum,
num_gpus
);
let exe = std::env::current_exe()?;
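// Delete any stale comm file from a previous run so workers don't pick up
// an old NCCL id.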
let _ = std::fs::remove_file(&comm_file);
let mut children = Vec::new();
for r in 1..num_gpus {
info!("Launching rank {} on GPU {}...", r, r);
let mut child_cmd = Command::new(&exe);
child_cmd
.env("CUDA_VISIBLE_DEVICES", r.to_string())
.arg("train")
.arg("--data")
.arg(data_path)
.arg("--tokenizer")
.arg(&tokenizer_path)
.arg("--model")
.arg(&model)
.arg("--output")
.arg(&output)
.arg("--batch-size")
.arg(batch_size.to_string())
.arg("--epochs")
.arg(epochs.to_string())
.arg("--seq-len")
.arg(seq_len.to_string())
.arg("--lr")
.arg(lr.to_string())
.arg("--num-gpus")
.arg(num_gpus.to_string())
.arg("--grad-accum")
.arg(grad_accum.to_string())
.arg("--freeze-layers")
.arg(freeze_layers.to_string())
.arg("--rank")
.arg(r.to_string())
.arg("--comm-file")
.arg(&comm_file);
if let Some(ref ckpt) = checkpoint {
child_cmd.arg("--checkpoint").arg(ckpt);
}
let child = child_cmd
.current_dir(std::env::current_dir()?)
.stdout(Stdio::null())
.stderr(Stdio::null())
.spawn()?;
children.push((r, child));
}
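// Give the workers a head start before rank 0 begins NCCL initialization.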
std::thread::sleep(std::time::Duration::from_secs(2));
Some(std::thread::spawn(move || {
let mut all_ok = true;
for (r, mut c) in children {
if !c.wait().map(|s| s.success()).unwrap_or(false) {
warn!("Rank {} failed", r);
all_ok = false;
}
}
all_ok
}))
} else {
None
};
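// From here on, the launcher (as rank 0) and the spawned workers follow the
// same training path, differing only in their rank.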
let actual_rank = if rank == usize::MAX { 0 } else { rank };
let device = get_device(true, if num_gpus > 1 { 0 } else { actual_rank })?;
let dist_config = hermes_llm::DistributedConfig {
world_size: num_gpus,
rank: actual_rank,
comm_file,
};
if dist_config.is_distributed() {
info!(
"Rank {}/{} on GPU {}",
actual_rank + 1,
num_gpus,
actual_rank
);
}
info!("Using device: {:?}", device);
let tokenizer = if std::path::Path::new(&tokenizer_path).exists() {
info!("Loading tokenizer from {}", tokenizer_path);
Tokenizer::from_file(&tokenizer_path)?
} else {
let data_path = data
.as_ref()
.ok_or_else(|| anyhow::anyhow!("--data required to train tokenizer"))?;
info!("Training new tokenizer...");
BPETrainer::new(32000).train_from_files(&[data_path.as_str()], &tokenizer_path)?
};
info!("Tokenizer vocab size: {}", tokenizer.vocab_size());
let mut config = get_model_def(&model)?;
config.vocab_size = tokenizer.vocab_size();
info!("Model: {}", config.name);
let dataset = match &data {
Some(path) => {
info!("Loading dataset from {}", path);
Dataset::from_file(path, &tokenizer, seq_len)?
}
None => {
info!("Loading dataset from stdin...");
Dataset::from_stdin(&tokenizer, seq_len)?
}
};
info!("Dataset size: {} tokens", dataset.tokens().len());
let mut train_loader = if num_gpus > 1 {
DataLoader::new_distributed(dataset, batch_size, true, actual_rank, num_gpus)
} else {
DataLoader::new(dataset, batch_size, true)
};
info!("Number of batches: {}", train_loader.num_batches());
let training_config = TrainingConfig {
learning_rate: lr,
batch_size,
epochs,
seq_len,
gradient_accumulation_steps: grad_accum,
..Default::default()
};
std::fs::create_dir_all(&output)?;
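// The NCCL communicator only exists in builds with --features nccl;
// distributed runs without it bail out below.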
#[cfg(feature = "nccl")]
let comm = if dist_config.is_distributed() {
Some(hermes_llm::NcclCommunicator::new(&dist_config)?)
} else {
None
};
#[cfg(not(feature = "nccl"))]
let comm: Option<hermes_llm::NcclCommunicator> = None;
if dist_config.is_distributed() && comm.is_none() {
anyhow::bail!("Distributed training requires --features nccl");
}
if let Some(ref c) = comm {
info!("Waiting for all ranks to synchronize...");
c.barrier()?;
info!("All ranks synchronized");
}
let mut trainer = Trainer::new(config.clone(), training_config, device)?;
let resume_state = if resume {
let state = trainer.load_training_state(&output)?;
if state.global_step > 0 {
info!(
"Resuming from epoch {}, step {}, batch {}",
state.epoch + 1,
state.global_step,
state.batch_position
);
Some(state)
} else {
None
}
} else if let Some(ref ckpt_path) = checkpoint {
info!("Loading checkpoint: {}", ckpt_path);
trainer.load_checkpoint(ckpt_path)?;
None
} else {
None
};
if freeze_layers > 0 {
info!("Freezing {} layers", freeze_layers);
trainer.freeze_layers(freeze_layers)?;
}
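// Broadcast the model weights so every rank starts from identical parameters.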
if let Some(ref c) = comm {
info!("Broadcasting model weights...");
hermes_llm::distributed::sync_model(trainer.var_map(), c)?;
}
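// Only the main process writes shared artifacts, so ranks don't clobber
// each other's files.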
if dist_config.is_main_process() {
config.save_json(&format!("{}/config.json", output))?;
info!("Saved config to {}/config.json", output);
}
let completed = trainer.train_resumable(
&mut train_loader,
None,
Some(&output),
comm.as_ref(),
resume_state,
)?;
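// Spawned workers have no children_handle; they exit quietly after the final
// barrier, and the launcher alone reports aggregate status.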
let is_worker = children_handle.is_none() && dist_config.is_distributed();
if let Some(c) = comm {
c.barrier()?;
c.finalize()?;
}
if is_worker {
std::process::exit(0);
}
if let Some(handle) = children_handle {
let all_ok = handle.join().unwrap_or(false);
let _ = std::fs::remove_file(&dist_config.comm_file);
if all_ok {
if completed {
println!("\n=== Training complete ===");
} else {
println!("\n=== Training interrupted, checkpoint saved ===");
println!("Resume with: hermes-llm train --resume --output {}", output);
}
} else {
anyhow::bail!("Some worker processes failed");
}
} else if dist_config.is_main_process() {
if completed {
info!("Training complete!");
} else {
info!("Training interrupted, checkpoint saved to {}", output);
info!("Resume with: hermes-llm train --resume --output {}", output);
}
}
}
Commands::TrainTokenizer {
input,
output,
vocab_size,
} => {
info!("Training BPE tokenizer with vocab size {}", vocab_size);
let trainer = BPETrainer::new(vocab_size);
let files: Vec<&str> = input.iter().map(|s| s.as_str()).collect();
let tokenizer = trainer.train_from_files(&files, &output)?;
info!(
"Tokenizer trained and saved to {} (vocab size: {})",
output,
tokenizer.vocab_size()
);
}
Commands::Generate {
checkpoint,
config: config_path,
tokenizer: tokenizer_path,
prompt,
max_tokens,
temperature,
top_k,
gpu,
} => {
let device = get_device(gpu, 0)?;
info!("Using device: {:?}", device);
let config = hermes_llm::ModelDef::from_json(&config_path)?;
let tokenizer = Tokenizer::from_file(&tokenizer_path)?;
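// Build the model against an empty VarMap first so all variables are
// registered, then load the checkpoint weights into those tensors.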
let mut var_map = candle_nn::VarMap::new();
let vb = candle_nn::VarBuilder::from_varmap(&var_map, candle_core::DType::F32, &device);
let model = hermes_llm::Transformer::new(&config, vb)?;
var_map.load(&checkpoint)?;
info!("Loaded model from {}", checkpoint);
let prompt_tokens = tokenizer.encode(&prompt, false)?;
info!("Prompt tokens: {:?}", prompt_tokens);
let generator = TextGenerator::new(&model, &device);
let output_tokens =
generator.generate(&prompt_tokens, max_tokens, temperature, top_k)?;
let output_text = tokenizer.decode(&output_tokens, true)?;
println!("\n{}", output_text);
}
Commands::Info { model } => {
let model_def = get_model_def(&model)?;
print!("{}", model_def);
}
Commands::Dpo {
data,
tokenizer: tokenizer_path,
checkpoint,
config: config_path,
output,
batch_size,
epochs,
lr,
beta,
max_len,
} => {
let device = get_device(true, 0)?;
info!("Using device: {:?}", device);
let config = hermes_llm::ModelDef::from_json(&config_path)?;
let tokenizer = Tokenizer::from_file(&tokenizer_path)?;
info!("Loading preference dataset from {}", data);
let dataset = hermes_llm::dpo::PreferenceDataset::from_file(&data)?;
info!("Initializing DPO trainer...");
let mut trainer =
hermes_llm::dpo::DpoTrainer::new(config, &checkpoint, device, lr, beta, max_len)?;
std::fs::create_dir_all(&output)?;
trainer.train(&dataset, &tokenizer, epochs, batch_size, Some(&output))?;
info!("DPO training complete!");
}
}
Ok(())
}