use crate::models::llama_370m::Llama370MConfig;
use crate::train::pretrain::{CheckpointFn, EpochArtifact, StepFn, ValFn};
use crate::train::transformer_trainer::{LMBatch, TransformerTrainConfig, TransformerTrainer};
use crate::transformer::{ModelArchitecture, Transformer, TransformerConfig};
use crate::Tensor;
use std::cell::RefCell;
use std::collections::BTreeMap;
use std::path::Path;
use std::rc::Rc;
/// Trainer handle shared by the step, validation and checkpoint adapters in this
/// module (`RealStepFn`, `RealValFn`, `AprCheckpointFn` each hold a clone).
/// `Rc<RefCell<..>>` is single-threaded (not `Send`), and overlapping
/// `borrow_mut`/`borrow` calls panic at runtime.
pub type SharedTrainer = Rc<RefCell<TransformerTrainer>>;
/// Load every named tensor from an APR file to warm-start pretraining (`--init`).
///
/// Returns the map produced by `aprender::format::converter::load_model_tensors`:
/// tensor name -> (flat `f32` data, shape).
///
/// # Errors
///
/// Any loader failure is re-reported as a `String` tagged with the falsifier id
/// `FALSIFY-APR-PRETRAIN-INIT-006`. The message names the offending path and
/// includes the loader's own error text.
pub fn load_init_tensors_from_apr(
    path: impl AsRef<Path>,
) -> Result<BTreeMap<String, (Vec<f32>, Vec<usize>)>, String> {
    let apr_path: &Path = path.as_ref();
    let loaded = aprender::format::converter::load_model_tensors(apr_path);
    loaded.map_err(|err| {
        format!(
            "FALSIFY-APR-PRETRAIN-INIT-006: failed to load init tensors from APR file {}: {err}",
            apr_path.display()
        )
    })
}
/// Gate an `--init` checkpoint's architecture against this decoder-only trainer.
///
/// [`ModelArchitecture::Decoder`] configs pass.
///
/// # Errors
///
/// [`ModelArchitecture::Encoder`] configs are rejected with a
/// `FALSIFY-APR-PRETRAIN-ARCH-007` message. The message explains the mismatch and
/// records the offending hidden size, layer count, vocab size and HF architecture
/// string.
pub fn validate_pretrain_init_arch_compatible(cfg: &TransformerConfig) -> Result<(), String> {
    // Deliberately exhaustive (no `_` arm): a new architecture variant must be
    // classified here explicitly rather than silently accepted or rejected.
    match cfg.architecture {
        ModelArchitecture::Encoder => {
            let hidden = &cfg.hidden_size;
            let layers = &cfg.num_hidden_layers;
            let vocab = &cfg.vocab_size;
            let hf_arch = &cfg.hf_architecture;
            Err(format!(
                "FALSIFY-APR-PRETRAIN-ARCH-007: --init checkpoint has architecture=Encoder \
                 (e.g., BERT/RoBERTa/CodeBERT) but the pretrain trainer is decoder-only \
                 (Llama/Qwen-class causal LMs). Loading encoder weights into a decoder \
                 trainer would produce nonsense gradients. Architectural details: \
                 hidden_size={hidden}, num_layers={layers}, vocab_size={vocab}, hf_architecture={hf_arch:?}"
            ))
        }
        ModelArchitecture::Decoder => Ok(()),
    }
}
/// Overwrite every trainable parameter of `transformer` with the tensor of the same
/// name from `init_tensors`, typically as returned by `load_init_tensors_from_apr`.
///
/// The trainer's own `named_parameters()` list is authoritative:
/// - Every entry must be present in `init_tensors` with exactly the same element count.
/// - Extra entries in `init_tensors` are ignored.
/// - The shape half of each map entry is not inspected; only the flat length is compared.
///
/// All parameters are validated before anything is assigned. A missing or mis-sized
/// tensor therefore leaves `transformer` untouched, and no tensor data is copied.
///
/// Returns the number of parameters populated, which on success equals the
/// parameter count.
///
/// # Errors
///
/// Returns a `FALSIFY-APR-PRETRAIN-INIT-007` message with the total number of
/// failures and the first few individual messages when a parameter is missing, has
/// the wrong length, or is rejected by `Transformer::set_named_parameter`.
/// A rejection can only be observed while assigning, so in that case earlier
/// parameters may already have been overwritten.
pub fn populate_trainer_from_init_tensors(
    transformer: &mut Transformer,
    init_tensors: &BTreeMap<String, (Vec<f32>, Vec<usize>)>,
) -> Result<usize, String> {
    // Only the first few failures are spelled out; the total is always reported.
    const MAX_REPORTED: usize = 5;
    let report = |errors: &[String]| -> String {
        let total = errors.len();
        let shown = total.min(MAX_REPORTED);
        let head = errors[..shown].join("; ");
        format!(
            "FALSIFY-APR-PRETRAIN-INIT-007: populate_trainer_from_init_tensors \
             failed for {total} parameter(s); first {shown} of {total}: {head}"
        )
    };

    // Snapshot names/lengths first so the shared borrow of `transformer` ends
    // before any `set_named_parameter` call.
    let expected: Vec<(String, usize)> = transformer
        .named_parameters()
        .into_iter()
        .map(|(name, t)| (name, t.len()))
        .collect();

    // Pass 1: resolve and size-check every parameter without mutating anything.
    let mut staged: Vec<(&String, &Vec<f32>)> = Vec::with_capacity(expected.len());
    let mut errors: Vec<String> = Vec::new();
    for (name, expected_len) in &expected {
        match init_tensors.get(name) {
            Some((data, _shape)) if data.len() == *expected_len => staged.push((name, data)),
            Some((data, _shape)) => errors.push(format!(
                "{name}: init length {} != trainer expected {expected_len}",
                data.len()
            )),
            None => errors.push(format!("{name}: not present in init APR tensors")),
        }
    }
    if !errors.is_empty() {
        return Err(report(&errors));
    }

    // Pass 2: commit. Data is cloned only now that the whole set is known to fit.
    let mut populated = 0usize;
    for (name, data) in staged {
        if transformer.set_named_parameter(name, Tensor::from_vec(data.clone(), true)) {
            populated += 1;
        } else {
            errors.push(format!("{name}: set_named_parameter rejected the assignment"));
        }
    }
    if !errors.is_empty() {
        return Err(report(&errors));
    }
    Ok(populated)
}
/// Model config for the in-repo Llama-370M baseline, built from the
/// [`Llama370MConfig`] constants. It is decoder-only, has no bias terms and no
/// head-dim override, ties the word embeddings, and is tagged with HF
/// architecture `LlamaForCausalLM` and model type `llama`.
pub fn llama_370m_transformer_config() -> TransformerConfig {
    TransformerConfig {
        // Family / provenance tags.
        architecture: ModelArchitecture::Decoder,
        hf_architecture: Some(String::from("LlamaForCausalLM")),
        hf_model_type: Some(String::from("llama")),
        // Shape.
        vocab_size: Llama370MConfig::VOCAB_SIZE,
        hidden_size: Llama370MConfig::HIDDEN_DIM,
        intermediate_size: Llama370MConfig::INTERMEDIATE_DIM,
        num_hidden_layers: Llama370MConfig::NUM_LAYERS,
        num_attention_heads: Llama370MConfig::NUM_HEADS,
        num_kv_heads: Llama370MConfig::NUM_KV_HEADS,
        head_dim_override: None,
        max_position_embeddings: Llama370MConfig::MAX_POSITION_EMBEDDINGS,
        // Numerics and structural switches.
        rms_norm_eps: Llama370MConfig::RMS_NORM_EPS,
        rope_theta: Llama370MConfig::ROPE_THETA,
        use_bias: false,
        tie_word_embeddings: true,
    }
}
pub fn build_transformer_config(init: Option<&TransformerConfig>) -> TransformerConfig {
match init {
None => llama_370m_transformer_config(),
Some(cfg) => cfg.clone(),
}
}
/// Training config for the Llama-370M baseline. It is
/// [`llama_370m_transformer_config`] plus this run's learning rate, maximum
/// sequence length and seed; every other setting keeps its
/// `TransformerTrainConfig::new` default.
pub fn llama_370m_train_config(lr: f32, seq_length: usize, seed: u64) -> TransformerTrainConfig {
    let mut train_cfg = TransformerTrainConfig::new(llama_370m_transformer_config());
    train_cfg.seed = seed;
    train_cfg.max_seq_len = seq_length;
    train_cfg.lr = lr;
    train_cfg
}
/// Step adapter: trains the shared trainer on batches pulled from an iterator.
pub struct RealStepFn {
    trainer: SharedTrainer,
    batches: Box<dyn Iterator<Item = LMBatch>>,
}
impl RealStepFn {
    /// Pair a shared trainer with the batch stream consumed one batch per `step`.
    pub fn new(trainer: SharedTrainer, batches: Box<dyn Iterator<Item = LMBatch>>) -> Self {
        RealStepFn { trainer, batches }
    }
}
impl StepFn for RealStepFn {
    /// Train on the next batch from the iterator and return `(loss, grad_norm)`.
    ///
    /// `_step`, `_lr` and `_batch_tokens` are ignored; in particular, this adapter
    /// does not forward the scheduler's learning rate to the trainer. Once the
    /// iterator is exhausted, no training happens and the finite placeholder
    /// `(1.0, 1.0)` is returned.
    ///
    /// NOTE(review): `grad_norm` is always the constant `1.0`. The trainer's real
    /// gradient norm is not surfaced, so consumers of this value get no signal.
    fn step(&mut self, _step: u64, _lr: f32, _batch_tokens: u64) -> (f32, f32) {
        const PLACEHOLDER_GRAD_NORM: f32 = 1.0;
        match self.batches.next() {
            Some(batch) => {
                let loss = self.trainer.borrow_mut().train_batch(&batch);
                (loss, PLACEHOLDER_GRAD_NORM)
            }
            None => (1.0, 1.0),
        }
    }
    /// Digest of the optimizer state, delegated to
    /// `TransformerTrainer::optimizer_state_sha256`. Always `Some`.
    fn optimizer_state_sha256(&self) -> Option<String> {
        let digest = self.trainer.borrow().optimizer_state_sha256();
        Some(digest)
    }
}
/// Validation adapter: scores a fixed set of held-out batches with the shared trainer.
pub struct RealValFn {
    trainer: SharedTrainer,
    held_out: Vec<LMBatch>,
}
impl RealValFn {
    /// Wrap `trainer` with the held-out batches that every `validate` call scores in full.
    pub fn new(trainer: SharedTrainer, held_out: Vec<LMBatch>) -> Self {
        RealValFn { trainer, held_out }
    }
}
impl ValFn for RealValFn {
    /// Mean per-row validation loss over every held-out batch.
    ///
    /// Each row `i < batch.batch_size` is scored with
    /// `TransformerTrainer::forward_single` when both its input and target are
    /// available. Rows where either getter returns `None` are skipped. Returns `NaN`
    /// when there are no held-out batches or no scorable rows, so callers see
    /// "no signal" instead of a fake 0.
    fn validate(&mut self, _epoch: usize) -> f32 {
        // Bail out before touching the RefCell when there is nothing to score.
        if self.held_out.is_empty() {
            return f32::NAN;
        }
        let trainer = self.trainer.borrow();
        let mut loss_sum = 0.0_f32;
        let mut scored = 0_usize;
        for batch in self.held_out.iter() {
            for row in 0..batch.batch_size {
                // The target is only fetched when the input exists.
                let pair = batch
                    .get_input(row)
                    .and_then(|inp| batch.get_target(row).map(move |tgt| (inp, tgt)));
                if let Some((inp, tgt)) = pair {
                    let (row_loss, _loss_tensor, _logits) = trainer.forward_single(inp, tgt);
                    loss_sum += row_loss;
                    scored += 1;
                }
            }
        }
        match scored {
            0 => f32::NAN,
            n => loss_sum / n as f32,
        }
    }
}
/// Checkpoint adapter that writes the shared trainer out through `save_apr`.
pub struct AprCheckpointFn {
    trainer: SharedTrainer,
    model_name: String,
    architecture: String,
}
impl AprCheckpointFn {
    /// Store the trainer plus the `model_name` / `architecture` labels that are
    /// passed to `TransformerTrainer::save_apr` on every checkpoint.
    pub fn new(
        trainer: SharedTrainer,
        model_name: impl Into<String>,
        architecture: impl Into<String>,
    ) -> Self {
        let model_name: String = model_name.into();
        let architecture: String = architecture.into();
        AprCheckpointFn { trainer, model_name, architecture }
    }
}
impl CheckpointFn for AprCheckpointFn {
    /// Write the trainer to `artifact.checkpoint_path` via `TransformerTrainer::save_apr`,
    /// labelled with the configured model name and architecture string.
    ///
    /// # Errors
    ///
    /// The `save_apr` error, re-wrapped as `"save_apr failed: {e}"`.
    fn save(&mut self, _epoch: usize, artifact: &EpochArtifact) -> Result<(), String> {
        let guard = self.trainer.borrow();
        let destination = &artifact.checkpoint_path;
        guard
            .save_apr(destination, &self.model_name, &self.architecture)
            .map_err(|err| format!("save_apr failed: {err}"))
    }
}
/// Build a freshly constructed Llama-370M [`SharedTrainer`] (no checkpoint
/// weights are loaded) from [`llama_370m_train_config`].
///
/// # Panics
///
/// In debug builds only, panics if invariant `INV-ARCH-370M-001` is violated,
/// i.e. the total parameter count falls outside `[366M, 374M]`.
pub fn build_shared_trainer(lr: f32, seq_length: usize, seed: u64) -> SharedTrainer {
    let cfg = llama_370m_train_config(lr, seq_length, seed);
    let trainer = TransformerTrainer::new(cfg);
    // Summing the parameter lengths walks every tensor, so the whole check is
    // compiled only into debug builds.
    #[cfg(debug_assertions)]
    {
        let param_count: usize = trainer.model().parameters().iter().map(|t| t.len()).sum();
        debug_assert!(
            (366_000_000..=374_000_000).contains(&param_count),
            "INV-ARCH-370M-001: parameter count {param_count} outside [366M, 374M] band",
        );
    }
    Rc::new(RefCell::new(trainer))
}
/// Build a [`SharedTrainer`], optionally warm-started from an APR checkpoint.
///
/// - `init_arch == None && init_path == None`: the model is shaped by the Llama-370M
///   baseline config and no weights are loaded.
/// - Both `Some`: the model is shaped by `init_arch`, and every parameter is
///   overwritten from the tensors in `init_path`.
///
/// The init file is read *before* the trainer is allocated. A missing or unreadable
/// checkpoint therefore fails fast without first constructing a full-size model.
///
/// # Errors
///
/// - Exactly one of `init_arch` / `init_path` is `Some` (caller bug).
/// - `FALSIFY-APR-PRETRAIN-ARCH-007`: `init_arch` is an encoder architecture.
/// - `FALSIFY-APR-PRETRAIN-INIT-006`: the APR file cannot be loaded.
/// - `FALSIFY-APR-PRETRAIN-INIT-007`: the loaded tensors do not cover the model's
///   parameters.
pub fn build_shared_trainer_with_init(
    lr: f32,
    seq_length: usize,
    seed: u64,
    init_arch: Option<&TransformerConfig>,
    init_path: Option<&Path>,
) -> Result<SharedTrainer, String> {
    // Validate arguments, gate the architecture, and load the weights, all before
    // paying for the model allocation.
    let init_tensors = match (init_arch, init_path) {
        (None, None) => None,
        (Some(cfg), Some(path)) => {
            validate_pretrain_init_arch_compatible(cfg)?;
            Some(load_init_tensors_from_apr(path)?)
        }
        (arch, path) => {
            return Err(format!(
                "build_shared_trainer_with_init: init_arch and init_path must both be Some \
                 or both None (caller bug; init_arch.is_some()={}, init_path.is_some()={})",
                arch.is_some(),
                path.is_some()
            ));
        }
    };
    let mut train_cfg = TransformerTrainConfig::new(build_transformer_config(init_arch));
    train_cfg.lr = lr;
    train_cfg.max_seq_len = seq_length;
    train_cfg.seed = seed;
    let mut trainer = TransformerTrainer::new(train_cfg);
    if let Some(tensors) = &init_tensors {
        populate_trainer_from_init_tensors(trainer.model_mut(), tensors)?;
    }
    Ok(Rc::new(RefCell::new(trainer)))
}
// Unit tests for the pretrain glue: APR init loading, the encoder/decoder arch gate,
// model-config dispatch, the Step/Val adapters, parameter population, and trainer
// construction with and without `--init`.
#[cfg(test)]
mod tests {
use super::*;
use crate::train::transformer_trainer::LMBatch;
// A nonexistent init path must fail with INIT-006 and name the file in the message.
#[test]
fn load_init_tensors_missing_file_errors_with_falsifier_id() {
let tmp = tempfile::TempDir::new().expect("tempdir");
let missing = tmp.path().join("does-not-exist.apr");
let err = load_init_tensors_from_apr(&missing)
.expect_err("missing init APR file MUST fail-fast");
assert!(
err.contains("FALSIFY-APR-PRETRAIN-INIT-006"),
"error must cite falsifier id (auditability): {err}"
);
assert!(
err.contains("does-not-exist.apr"),
"error must name the missing path (operator-experience): {err}"
);
}
// Compile-time check: the loader is callable with `&Path` and returns the
// documented `name -> (data, shape)` map type.
#[test]
fn load_init_tensors_signature_compile_bind() {
fn _check_signature<F>(_f: F)
where
F: Fn(
&Path,
)
-> Result<BTreeMap<String, (Vec<f32>, Vec<usize>)>, String>,
{
}
_check_signature(|p| load_init_tensors_from_apr(p));
}
// The baseline model config mirrors every Llama370MConfig constant.
#[test]
fn transformer_config_matches_llama_370m_constants() {
let cfg = llama_370m_transformer_config();
assert_eq!(cfg.hidden_size, Llama370MConfig::HIDDEN_DIM);
assert_eq!(cfg.num_hidden_layers, Llama370MConfig::NUM_LAYERS);
assert_eq!(cfg.num_attention_heads, Llama370MConfig::NUM_HEADS);
assert_eq!(cfg.num_kv_heads, Llama370MConfig::NUM_KV_HEADS);
assert_eq!(cfg.intermediate_size, Llama370MConfig::INTERMEDIATE_DIM);
assert_eq!(cfg.vocab_size, Llama370MConfig::VOCAB_SIZE);
assert!((cfg.rope_theta - Llama370MConfig::ROPE_THETA).abs() < f32::EPSILON);
assert!((cfg.rms_norm_eps - Llama370MConfig::RMS_NORM_EPS).abs() < f32::EPSILON);
assert!(!cfg.use_bias, "INV-ARCH-370M-008: no bias");
assert!(cfg.tie_word_embeddings, "INV-ARCH-370M-004: tied embeddings");
}
// init=None dispatches to the Llama-370M baseline, field for field.
#[test]
fn build_transformer_config_no_init_matches_llama370m() {
let baseline = llama_370m_transformer_config();
let result = build_transformer_config(None);
assert_eq!(result.hidden_size, baseline.hidden_size);
assert_eq!(result.num_attention_heads, baseline.num_attention_heads);
assert_eq!(result.num_kv_heads, baseline.num_kv_heads);
assert_eq!(result.intermediate_size, baseline.intermediate_size);
assert_eq!(result.num_hidden_layers, baseline.num_hidden_layers);
assert_eq!(result.vocab_size, baseline.vocab_size);
assert_eq!(
result.max_position_embeddings,
baseline.max_position_embeddings
);
assert!((result.rms_norm_eps - baseline.rms_norm_eps).abs() < f32::EPSILON);
assert!((result.rope_theta - baseline.rope_theta).abs() < f32::EPSILON);
assert_eq!(result.use_bias, baseline.use_bias);
assert_eq!(result.tie_word_embeddings, baseline.tie_word_embeddings);
assert_eq!(result.architecture, baseline.architecture);
assert_eq!(result.hf_architecture, baseline.hf_architecture);
assert_eq!(result.hf_model_type, baseline.hf_model_type);
}
// init=Some(cfg) returns cfg unchanged (Qwen2-0.5B here), including its GQA ratio.
#[test]
fn build_transformer_config_qwen_init_matches_input() {
let qwen = TransformerConfig::qwen2_0_5b();
let result = build_transformer_config(Some(&qwen));
assert_eq!(result.hidden_size, qwen.hidden_size, "hidden_size");
assert_eq!(
result.num_attention_heads, qwen.num_attention_heads,
"num_attention_heads"
);
assert_eq!(result.num_kv_heads, qwen.num_kv_heads, "num_kv_heads");
assert_eq!(
result.intermediate_size, qwen.intermediate_size,
"intermediate_size"
);
assert_eq!(
result.num_hidden_layers, qwen.num_hidden_layers,
"num_hidden_layers"
);
assert_eq!(result.vocab_size, qwen.vocab_size, "vocab_size");
assert_eq!(
result.max_position_embeddings, qwen.max_position_embeddings,
"max_position_embeddings"
);
assert_eq!(result.use_bias, qwen.use_bias, "use_bias");
assert_eq!(
result.tie_word_embeddings, qwen.tie_word_embeddings,
"tie_word_embeddings"
);
assert_eq!(result.architecture, qwen.architecture, "architecture");
assert_eq!(
result.num_attention_heads / result.num_kv_heads,
7,
"GQA ratio must preserve as 7:1 (Qwen2.5-0.5B canonical)"
);
}
// None and Some(qwen) must yield observably different configs (hidden and vocab size).
#[test]
fn build_transformer_config_dispatch_mutually_exclusive() {
let qwen = TransformerConfig::qwen2_0_5b();
let none_result = build_transformer_config(None);
let some_result = build_transformer_config(Some(&qwen));
assert_ne!(
none_result.hidden_size, some_result.hidden_size,
"dispatch must differentiate None vs Some — Llama370M hidden=1024 vs Qwen=896"
);
assert_ne!(
none_result.vocab_size, some_result.vocab_size,
"dispatch must differentiate None vs Some — Llama370M vocab=50257 vs Qwen=151936"
);
}
// A decoder-family config (Qwen2-0.5B) passes the arch gate.
#[test]
fn validate_pretrain_init_arch_accepts_decoder() {
let qwen = TransformerConfig::qwen2_0_5b();
assert_eq!(qwen.architecture, ModelArchitecture::Decoder);
validate_pretrain_init_arch_compatible(&qwen)
.expect("decoder-family config (Qwen2.5-0.5B) MUST pass arch-compat gate");
}
// A hand-built RoBERTa-shaped Encoder config is rejected. The error must carry the
// falsifier id, the family name, the reason, and the HF architecture string.
#[test]
fn validate_pretrain_init_arch_rejects_encoder() {
let bert = TransformerConfig {
hidden_size: 768,
num_attention_heads: 12,
num_kv_heads: 12,
intermediate_size: 3072,
num_hidden_layers: 12,
vocab_size: 50265,
max_position_embeddings: 514,
rms_norm_eps: 1e-12,
rope_theta: 10_000.0,
use_bias: true,
head_dim_override: None,
architecture: ModelArchitecture::Encoder,
hf_architecture: Some("RobertaModel".to_string()),
hf_model_type: Some("roberta".to_string()),
tie_word_embeddings: false,
};
let err = validate_pretrain_init_arch_compatible(&bert).expect_err(
"encoder-family config (CodeBERT/RoBERTa) MUST fail arch-compat gate — \
silent acceptance would corrupt §49 fine-tune trajectory before any \
FALSIFY-006 check could measure it",
);
assert!(
err.contains("FALSIFY-APR-PRETRAIN-ARCH-007"),
"error must cite falsifier id: {err}"
);
assert!(
err.contains("Encoder"),
"error must name the architecture family: {err}"
);
assert!(
err.contains("decoder-only"),
"error must explain why this is wrong (decoder trainer): {err}"
);
assert!(
err.contains("RobertaModel"),
"error must name the offending hf_architecture: {err}"
);
}
// The default Llama-370M config is a Decoder and passes the gate.
#[test]
fn validate_pretrain_init_arch_accepts_llama370m_baseline() {
let llama = llama_370m_transformer_config();
assert_eq!(
llama.architecture,
ModelArchitecture::Decoder,
"Llama370M baseline MUST be Decoder (regression-free)"
);
validate_pretrain_init_arch_compatible(&llama)
.expect("Llama370M baseline (Decoder) MUST pass arch-compat gate");
}
// With an empty batch iterator, `step` returns a finite, non-negative placeholder
// rather than panicking. Uses a tiny model for speed.
#[test]
fn real_step_fn_exhausted_iterator_returns_finite_placeholder() {
let mut tiny = TransformerConfig::llama2_7b();
tiny.hidden_size = 64;
tiny.num_attention_heads = 4;
tiny.num_kv_heads = 4;
tiny.num_hidden_layers = 2;
tiny.intermediate_size = 128;
tiny.vocab_size = 256;
let cfg = TransformerTrainConfig::new(tiny);
let trainer = Rc::new(RefCell::new(TransformerTrainer::new(cfg)));
let empty_iter: Box<dyn Iterator<Item = LMBatch>> = Box::new(std::iter::empty::<LMBatch>());
let mut step = RealStepFn::new(trainer, empty_iter);
let (loss, grad_norm) = step.step(0, 1.0e-4, 128);
assert!(loss.is_finite(), "exhausted iter must return finite loss");
assert!(grad_norm.is_finite(), "grad_norm must be finite");
assert!(grad_norm >= 0.0, "INV-TRAIN-008: grad_norm non-negative");
}
// No held-out batches means validation reports NaN instead of a fake loss.
#[test]
fn real_val_fn_empty_held_out_returns_nan() {
let mut tiny = TransformerConfig::llama2_7b();
tiny.hidden_size = 64;
tiny.num_attention_heads = 4;
tiny.num_kv_heads = 4;
tiny.num_hidden_layers = 2;
tiny.intermediate_size = 128;
tiny.vocab_size = 256;
let cfg = TransformerTrainConfig::new(tiny);
let trainer = Rc::new(RefCell::new(TransformerTrainer::new(cfg)));
let mut val = RealValFn::new(trainer, Vec::new());
let loss = val.validate(0);
assert!(loss.is_nan(), "empty held_out must surface as NaN to the guard");
}
// Small 2-layer transformer (hidden=32, vocab=16) so the populate tests stay fast.
fn tiny_test_transformer() -> Transformer {
let mut tiny = TransformerConfig::llama2_7b();
tiny.hidden_size = 32;
tiny.num_attention_heads = 2;
tiny.num_kv_heads = 2;
tiny.num_hidden_layers = 2;
tiny.intermediate_size = 64;
tiny.vocab_size = 16;
Transformer::new(&tiny)
}
// Init map with one correctly sized, deterministic tensor per parameter of
// `transformer`. Each shape is recorded as the flat length `[len]`.
fn tensors_map_from_transformer(
transformer: &Transformer,
) -> BTreeMap<String, (Vec<f32>, Vec<usize>)> {
let mut map = BTreeMap::new();
for (name, t) in transformer.named_parameters() {
let len = t.len();
let data: Vec<f32> = (0..len).map(|i| i as f32 * 0.001).collect();
map.insert(name, (data, vec![len]));
}
map
}
// A complete, correctly sized map populates every parameter.
#[test]
fn populate_trainer_from_init_tensors_happy_path() {
let mut transformer = tiny_test_transformer();
let init_tensors = tensors_map_from_transformer(&transformer);
let expected_count = transformer.named_parameters().len();
let result = populate_trainer_from_init_tensors(&mut transformer, &init_tensors);
assert!(
result.is_ok(),
"happy-path populate must succeed: {result:?}"
);
assert_eq!(
result.unwrap(),
expected_count,
"populated count must equal named_parameters().len()"
);
}
// Entries with no matching trainer parameter are ignored, not treated as errors.
#[test]
fn populate_trainer_from_init_tensors_extra_entries_silently_ignored() {
let mut transformer = tiny_test_transformer();
let mut init_tensors = tensors_map_from_transformer(&transformer);
init_tensors.insert(
"model.layers.999.fictitious.weight".to_string(),
(vec![0.0; 4], vec![4]),
);
let expected_count = transformer.named_parameters().len();
let result = populate_trainer_from_init_tensors(&mut transformer, &init_tensors);
assert!(
result.is_ok(),
"extra init entries must NOT cause Err: {result:?}"
);
assert_eq!(result.unwrap(), expected_count);
}
// A wrong-length tensor yields INIT-007, naming the parameter and the actual length.
#[test]
fn populate_trainer_from_init_tensors_rejects_length_mismatch() {
let mut transformer = tiny_test_transformer();
let mut init_tensors = tensors_map_from_transformer(&transformer);
let any_name = transformer.named_parameters()[0].0.clone();
init_tensors.insert(any_name.clone(), (vec![0.0; 7], vec![7]));
let result = populate_trainer_from_init_tensors(&mut transformer, &init_tensors);
assert!(
result.is_err(),
"length-mismatch must Err, not silently truncate"
);
let err = result.unwrap_err();
assert!(
err.contains("FALSIFY-APR-PRETRAIN-INIT-007"),
"error must cite falsifier id; got: {err}"
);
assert!(
err.contains(&any_name),
"error must name the offending parameter; got: {err}"
);
assert!(
err.contains("init length 7"),
"error must report the actual init length; got: {err}"
);
}
// A trainer parameter absent from the map yields INIT-007 naming it.
#[test]
fn populate_trainer_from_init_tensors_rejects_missing_required_param() {
let mut transformer = tiny_test_transformer();
let mut init_tensors = tensors_map_from_transformer(&transformer);
let any_name = transformer.named_parameters()[0].0.clone();
init_tensors.remove(&any_name);
let result = populate_trainer_from_init_tensors(&mut transformer, &init_tensors);
assert!(
result.is_err(),
"missing-required must Err, not silently leave random init"
);
let err = result.unwrap_err();
assert!(
err.contains("FALSIFY-APR-PRETRAIN-INIT-007"),
"error must cite falsifier id; got: {err}"
);
assert!(
err.contains(&any_name),
"error must name the missing parameter; got: {err}"
);
assert!(
err.contains("not present in init APR"),
"error must say what was missing; got: {err}"
);
}
// With no init, the first parameter has Llama370M's vocab x hidden size. This builds a
// full-size Llama-370M trainer, so it is much heavier than the other tests here.
// NOTE(review): assumes named_parameters()[0] is the token embedding — confirm
// against Transformer::named_parameters ordering.
#[test]
fn build_shared_trainer_with_init_none_uses_llama370m_shape() {
let trainer = build_shared_trainer_with_init(1.0e-4, 128, 42, None, None)
.expect("None case must succeed");
let model = trainer.borrow();
let embed_len = model.model().named_parameters()[0].1.len();
let expected_embed_len =
Llama370MConfig::VOCAB_SIZE * Llama370MConfig::HIDDEN_DIM;
assert_eq!(
embed_len, expected_embed_len,
"init=None must produce Llama370M-shaped embedding (vocab={} × hidden={})",
Llama370MConfig::VOCAB_SIZE,
Llama370MConfig::HIDDEN_DIM
);
}
// Exactly one of init_arch / init_path being Some is a caller bug and must Err.
#[test]
fn build_shared_trainer_with_init_rejects_unpaired_args() {
let cfg = TransformerConfig::qwen2_0_5b();
let result = build_shared_trainer_with_init(1.0e-4, 128, 42, Some(&cfg), None);
assert!(
result.is_err(),
"unpaired (arch=Some, path=None) must Err"
);
let dummy_path = std::path::PathBuf::from("/dev/null");
let result = build_shared_trainer_with_init(1.0e-4, 128, 42, None, Some(&dummy_path));
assert!(
result.is_err(),
"unpaired (arch=None, path=Some) must Err"
);
}
// An encoder config is rejected with ARCH-007. The path does not exist, so an
// INIT-006 error would mean the gate ran too late.
#[test]
fn build_shared_trainer_with_init_rejects_encoder_family() {
let mut encoder_cfg = TransformerConfig::qwen2_0_5b();
encoder_cfg.architecture = ModelArchitecture::Encoder;
let dummy_path = std::path::PathBuf::from("/nonexistent/encoder.apr");
let result =
build_shared_trainer_with_init(1.0e-4, 128, 42, Some(&encoder_cfg), Some(&dummy_path));
let err = match result {
Ok(_) => panic!("encoder family must be rejected before tensor load"),
Err(e) => e,
};
assert!(
err.contains("FALSIFY-APR-PRETRAIN-ARCH-007"),
"error must cite falsifier id; got: {err}"
);
}
// A decoder config passes the arch gate, so the missing file surfaces as INIT-006
// (never ARCH-007).
#[test]
fn build_shared_trainer_with_init_decoder_family_proceeds_to_tensor_load() {
let cfg = TransformerConfig::qwen2_0_5b();
let dummy_path = std::path::PathBuf::from("/nonexistent/decoder.apr");
let result = build_shared_trainer_with_init(1.0e-4, 128, 42, Some(&cfg), Some(&dummy_path));
let err = match result {
Ok(_) => panic!("missing tensor path must Err"),
Err(e) => e,
};
assert!(
err.contains("FALSIFY-APR-PRETRAIN-INIT-006"),
"decoder family proceeds to tensor load; failure cites INIT-006 not ARCH-007; got: {err}"
);
assert!(
!err.contains("FALSIFY-APR-PRETRAIN-ARCH-007"),
"decoder family must NOT trigger encoder-rejection; got: {err}"
);
}
}