use super::*;
use crate::lm::tuner::optimizers::sgd::SGD;
/// Minimal in-memory dataset fixture holding pre-built [`Example`] values.
struct FakeDataset {
    /// Examples served back by `process`.
    samples: Vec<Example>,
}

impl FakeDataset {
    /// Creates `n` examples of `len` tokens each.
    ///
    /// Token `k` of example `i` is `((i + k) as u32) % 32`; the second tuple
    /// element of every example is `0`.
    fn new(n: usize, len: usize) -> Self {
        let mut samples = Vec::with_capacity(n);
        for i in 0..n {
            let tokens = (0..len).map(|k| ((i + k) as u32) % 32).collect();
            samples.push((tokens, 0_usize));
        }
        Self { samples }
    }
}
impl Dataset for FakeDataset {
fn len(&self) -> usize {
self.samples.len()
}
fn get(&self, _idx: usize) -> Result<&serde_json::Value> {
Err(Error::InvariantViolation(InvariantViolationPayload::new(
"FakeDataset::get",
"is not used by the trainer iterator",
)))
}
fn process(&self, idx: usize) -> Result<Example> {
Ok((self.samples[idx].0.clone(), self.samples[idx].1))
}
}
/// Stateless model fixture whose output depends only on the input shape.
struct FakeModel;

impl Model for FakeModel {
    /// Reads the first two dimensions of `tokens` and returns an array of
    /// shape `(dim0, dim1, 8)` in which every element is `0.1`. The cache
    /// argument is ignored.
    fn forward(&self, tokens: &Array, _cache: &mut [Box<dyn KvCache>]) -> Result<Array> {
        let vocab = 8;
        let dims = tokens.shape();
        let (rows, cols) = (dims[0], dims[1]);
        let logits = vec![0.1_f32; rows * cols * vocab];
        Array::from_slice::<f32>(&logits, &(rows, cols, vocab))
    }
}
/// Pins every value reported by the getters of `TrainingArgs::default()`.
#[test]
fn training_args_default_matches_python() {
    let defaults = TrainingArgs::default();

    // Numeric defaults.
    assert_eq!(defaults.batch_size(), 4);
    assert_eq!(defaults.iters(), 100);
    assert_eq!(defaults.val_batches(), Some(25));
    assert_eq!(defaults.steps_per_report(), 10);
    assert_eq!(defaults.steps_per_eval(), 200);
    assert_eq!(defaults.steps_per_save(), 100);
    assert_eq!(defaults.max_seq_length(), 2048);

    // Flags and accumulation default to "off" / 1.
    assert!(!defaults.grad_checkpoint());
    assert_eq!(defaults.grad_accumulation_steps(), 1);
    assert!(!defaults.acknowledge_no_real_gradients());
}
/// For a single 3-token row with lengths `[1, 3]`, `default_loss` on
/// `FakeModel` must report a loss of ln(8) and a token count of 2.
#[test]
fn default_loss_matches_masked_cross_entropy() -> Result<()> {
    let model = FakeModel;
    let tokens = Array::from_slice::<i32>(&[1, 2, 3], &(1, 3))?;
    let bounds = Array::from_slice::<i32>(&[1, 3], &(1, 2))?;

    let (mut loss, mut ntoks) = default_loss(&model, &tokens, &bounds)?;
    let loss_v = loss.item::<f32>()?;
    let ntoks_v = ntoks.item::<f32>()?;

    let expected_loss = 8.0_f32.ln();
    assert!((loss_v - expected_loss).abs() < 1e-4, "got loss {loss_v}");
    assert!((ntoks_v - 2.0).abs() < 1e-6, "got ntoks {ntoks_v}");
    Ok(())
}
/// A 4-token row `[1, 2, 0, 0]` with lengths `[0, 2]` must count exactly one
/// supervised token and still report a loss of ln(8).
#[test]
fn default_loss_excludes_padded_target_at_length_boundary() -> Result<()> {
    let model = FakeModel;
    let tokens = Array::from_slice::<i32>(&[1, 2, 0, 0], &(1, 4))?;
    let bounds = Array::from_slice::<i32>(&[0, 2], &(1, 2))?;

    let (mut loss, mut ntoks) = default_loss(&model, &tokens, &bounds)?;
    let loss_v = loss.item::<f32>()?;
    let ntoks_v = ntoks.item::<f32>()?;

    let token_count_error = (ntoks_v - 1.0).abs();
    assert!(
        token_count_error < 1e-6,
        "expected ntoks=1 (boundary pad excluded by `<` upper bound), got {ntoks_v}",
    );

    let loss_error = (loss_v - 8.0_f32.ln()).abs();
    assert!(
        loss_error < 1e-4,
        "expected loss=log(8) for single supervised token, got {loss_v}",
    );
    Ok(())
}
/// Eight examples at batch size 4 must yield exactly two batches, each with
/// 4 token rows and a lengths array of shape `[4, 2]`.
#[test]
fn iterate_batches_emits_expected_shape_for_known_dataset_size() -> Result<()> {
    let dataset = FakeDataset::new(8, 4);
    let batches = iterate_batches(&dataset, 4, 64, false, None)?;

    let mut emitted = 0;
    for item in batches {
        let batch = item?;
        assert_eq!(batch.tokens_ref().shape()[0], 4);
        assert_eq!(batch.lengths_ref().shape(), &[4, 2]);
        emitted += 1;
    }

    assert_eq!(emitted, 2, "8/4=2 batches expected");
    Ok(())
}
/// A dataset of 2 examples cannot fill a batch of 4, so constructing the
/// iterator must return an error.
#[test]
fn iterate_batches_rejects_too_small_dataset() {
    let tiny = FakeDataset::new(2, 4);
    assert!(iterate_batches(&tiny, 4, 64, false, None).is_err());
}
/// With the fourth argument set to `true` (the loop-forever mode this test is
/// named after) and a fixed seed, a dataset of 4 examples at batch size 4
/// must still hand out five items in a row.
#[test]
fn iterate_batches_loop_forever_yields_more_batches_than_dataset_size() -> Result<()> {
    let dataset = FakeDataset::new(4, 4);
    let mut batches = iterate_batches(&dataset, 4, 64, true, Some(0xCAFE))?;

    for _round in 0..5 {
        let item = batches.next();
        assert!(item.is_some());
    }
    Ok(())
}
/// Evaluating `FakeModel` on a 4-example set (batch size 4, at most one
/// batch) with `default_loss` must return ln(8).
#[test]
fn evaluate_returns_correct_loss_for_known_eval_set() -> Result<()> {
    let eval_set = FakeDataset::new(4, 6);
    let model = FakeModel;

    let loss = evaluate(&model, &eval_set, 4, Some(1), 64, |mdl, toks, lens| {
        default_loss(mdl, toks, lens)
    })?;

    let deviation = (loss - 8.0_f32.ln()).abs();
    assert!(deviation < 1e-4, "got {loss}");
    Ok(())
}
/// Callback fixture that counts how often each hook is invoked.
struct CountingCallback {
    /// Number of `on_train_loss_report` calls.
    train_reports: usize,
    /// Number of `on_val_loss_report` calls.
    val_reports: usize,
    /// Number of `on_save` calls.
    saves: usize,
}

impl TrainingCallback for CountingCallback {
    /// Counts one training-loss report; the payload is ignored.
    fn on_train_loss_report(&mut self, _info: &TrainInfo) {
        self.train_reports += 1;
    }

    /// Counts one validation-loss report; the payload is ignored.
    fn on_val_loss_report(&mut self, _info: &ValInfo) {
        self.val_reports += 1;
    }

    /// Counts one save request and always succeeds; nothing is written.
    fn on_save(&mut self, _it: usize, _adapter_file: &str) -> Result<()> {
        self.saves += 1;
        Ok(())
    }
}
/// Runs `train` for 6 iterations (report every 2, eval every 4, save every 3)
/// with the same dataset for training and validation, and expects three
/// train reports, three validation reports and three saves.
#[test]
fn train_completes_n_iters_with_progress_callback() -> Result<()> {
    let data = FakeDataset::new(4, 6);
    let model = FakeModel;

    let mut weights: Weights = HashMap::new();
    weights.insert("w".into(), Array::full::<f32>(&[0i32; 0], 1.0)?);

    let mut optimizer = SGD::vanilla(0.01)?;
    let mut counter = CountingCallback {
        train_reports: 0,
        val_reports: 0,
        saves: 0,
    };

    let args = TrainingArgs::new()
        .with_iters(6)
        .with_steps_per_report(2)
        .with_steps_per_eval(4)
        .with_steps_per_save(3)
        .with_batch_size(4)
        .with_max_seq_length(64)
        .with_val_batches(Some(1))
        .with_acknowledge_no_real_gradients(true);

    train(
        &model,
        &mut optimizer,
        &mut weights,
        &data,
        Some(&data),
        &args,
        default_loss,
        &mut counter,
    )?;

    assert_eq!(counter.train_reports, 3);
    assert_eq!(counter.val_reports, 3);
    assert_eq!(counter.saves, 3);
    Ok(())
}
/// Wrapping a squaring closure with `grad_checkpoint` must leave its result
/// unchanged: a scalar 3.0 still maps to 9.0.
#[test]
fn grad_checkpoint_wraps_layer_without_changing_output() -> Result<()> {
    let square_first = |inputs: &[Array]| Ok(vec![crate::ops::arithmetic::square(&inputs[0])?]);
    let checkpointed = grad_checkpoint(square_first)?;

    let three = Array::full::<f32>(&[0i32; 0], 3.0)?;
    let mut outputs = checkpointed(&[three])?;

    assert_eq!(outputs[0].item::<f32>()?, 9.0);
    Ok(())
}
/// Without the `acknowledge_no_real_gradients` opt-in, `train` must fail with
/// an `InvariantViolation` carrying the exact context and requirement strings
/// checked below.
#[test]
fn train_rejects_when_acknowledge_no_real_gradients_is_false() -> Result<()> {
    let train_set = FakeDataset::new(4, 6);
    let model = FakeModel;

    let mut weights: Weights = HashMap::new();
    weights.insert("w".into(), Array::full::<f32>(&[0i32; 0], 1.0)?);

    let mut optimizer = SGD::vanilla(0.01)?;
    let mut callback = NoopCallback;

    // The opt-in flag is deliberately left at its default.
    let args = TrainingArgs::new()
        .with_iters(1)
        .with_batch_size(4)
        .with_max_seq_length(64)
        .with_val_batches(Some(1));
    assert!(!args.acknowledge_no_real_gradients());

    let outcome = train(
        &model,
        &mut optimizer,
        &mut weights,
        &train_set,
        None,
        &args,
        default_loss,
        &mut callback,
    );

    match outcome {
        Err(Error::InvariantViolation(violation)) => {
            assert_eq!(
                violation.context(),
                "train: TrainingArgs::acknowledge_no_real_gradients"
            );
            assert_eq!(
                violation.requirement(),
                "must be set to `true` to run the v1 mechanics-only training path"
            );
        }
        other => panic!("expected Err(InvariantViolation), got {other:?}"),
    }
    Ok(())
}
/// With the `acknowledge_no_real_gradients` opt-in set, a one-iteration
/// `train` call without a validation set must return `Ok`.
#[test]
fn train_runs_when_acknowledge_no_real_gradients_is_true() -> Result<()> {
    let train_set = FakeDataset::new(4, 6);
    let model = FakeModel;

    let mut weights: Weights = HashMap::new();
    weights.insert("w".into(), Array::full::<f32>(&[0i32; 0], 1.0)?);

    let mut optimizer = SGD::vanilla(0.01)?;
    let mut callback = NoopCallback;

    let args = TrainingArgs::new()
        .with_iters(1)
        .with_batch_size(4)
        .with_max_seq_length(64)
        .with_val_batches(Some(1))
        .with_acknowledge_no_real_gradients(true);

    let res = train(
        &model,
        &mut optimizer,
        &mut weights,
        &train_set,
        None,
        &args,
        default_loss,
        &mut callback,
    );

    assert!(
        res.is_ok(),
        "train should run when opt-in is set; got {res:?}"
    );
    Ok(())
}
/// Baseline arguments shared by the "rejects zero ..." tests: one iteration,
/// batch size 4, max sequence length 64, one validation batch, and the
/// `acknowledge_no_real_gradients` opt-in enabled.
fn args_for_zero_interval_tests() -> TrainingArgs {
    let sized = TrainingArgs::new()
        .with_iters(1)
        .with_batch_size(4)
        .with_max_seq_length(64);
    sized
        .with_val_batches(Some(1))
        .with_acknowledge_no_real_gradients(true)
}
/// Runs `train` with `args` on a fresh fixture and returns its result.
///
/// The fixture is a `FakeModel`, a `FakeDataset` of 4 examples with 6 tokens
/// each, a single scalar weight `"w"`, vanilla SGD at 0.01, no validation
/// set, `default_loss`, and a `NoopCallback`.
fn run_train_with_args(args: &TrainingArgs) -> crate::Result<()> {
    let train_set = FakeDataset::new(4, 6);
    let model = FakeModel;

    let mut weights: Weights = HashMap::new();
    weights.insert("w".into(), Array::full::<f32>(&[0i32; 0], 1.0)?);

    let mut optimizer = SGD::vanilla(0.01)?;
    let mut callback = NoopCallback;

    train(
        &model,
        &mut optimizer,
        &mut weights,
        &train_set,
        None,
        args,
        default_loss,
        &mut callback,
    )
}
/// `steps_per_report = 0` must make `train` fail with an
/// `InvariantViolation` that names the argument and the `>= 1` requirement.
#[test]
fn train_rejects_zero_steps_per_report() {
    let bad_args = args_for_zero_interval_tests().with_steps_per_report(0);

    match run_train_with_args(&bad_args) {
        Err(Error::InvariantViolation(violation)) => {
            assert_eq!(violation.context(), "train: steps_per_report");
            assert_eq!(violation.requirement(), "must be >= 1");
        }
        other => panic!("expected Err(InvariantViolation) for steps_per_report=0; got {other:?}"),
    }
}
/// `steps_per_eval = 0` must make `train` fail with an `InvariantViolation`
/// that names the argument and the `>= 1` requirement.
#[test]
fn train_rejects_zero_steps_per_eval() {
    let bad_args = args_for_zero_interval_tests().with_steps_per_eval(0);

    match run_train_with_args(&bad_args) {
        Err(Error::InvariantViolation(violation)) => {
            assert_eq!(violation.context(), "train: steps_per_eval");
            assert_eq!(violation.requirement(), "must be >= 1");
        }
        other => panic!("expected Err(InvariantViolation) for steps_per_eval=0; got {other:?}"),
    }
}
/// `steps_per_save = 0` must make `train` fail with an `InvariantViolation`
/// that names the argument and the `>= 1` requirement.
#[test]
fn train_rejects_zero_steps_per_save() {
    let bad_args = args_for_zero_interval_tests().with_steps_per_save(0);

    match run_train_with_args(&bad_args) {
        Err(Error::InvariantViolation(violation)) => {
            assert_eq!(violation.context(), "train: steps_per_save");
            assert_eq!(violation.requirement(), "must be >= 1");
        }
        other => panic!("expected Err(InvariantViolation) for steps_per_save=0; got {other:?}"),
    }
}
/// `grad_accumulation_steps = 0` must make `train` fail with an
/// `InvariantViolation` that names the argument and the `>= 1` requirement.
#[test]
fn train_rejects_zero_grad_accumulation_steps() {
    let bad_args = args_for_zero_interval_tests().with_grad_accumulation_steps(0);

    match run_train_with_args(&bad_args) {
        Err(Error::InvariantViolation(violation)) => {
            assert_eq!(violation.context(), "train: grad_accumulation_steps");
            assert_eq!(violation.requirement(), "must be >= 1");
        }
        other => {
            panic!("expected Err(InvariantViolation) for grad_accumulation_steps=0; got {other:?}")
        }
    }
}
/// Optimizer fixture that records how often gradients are applied and never
/// touches the parameters.
struct CountingOptimizer {
    /// Number of `apply_gradients` calls.
    apply_calls: usize,
    /// Value returned by `step`; incremented alongside `apply_calls`.
    step_count: usize,
    /// Value returned by `learning_rate`.
    lr: f32,
}

impl crate::lm::tuner::optimizers::Optimizer for CountingOptimizer {
    /// Nothing to initialise; always succeeds.
    fn init(&mut self, _params: &Weights) -> Result<()> {
        Ok(())
    }

    /// Bumps both counters by one and leaves gradients and parameters alone.
    fn apply_gradients(&mut self, _gradients: &Weights, _params: &mut Weights) -> Result<()> {
        self.apply_calls += 1;
        self.step_count += 1;
        Ok(())
    }

    /// Number of gradient applications recorded so far.
    fn step(&self) -> usize {
        self.step_count
    }

    /// The fixed learning rate stored in the fixture.
    fn learning_rate(&self) -> f32 {
        self.lr
    }
}
/// Builds the pieces shared by the gradient-accumulation tests.
///
/// Returns, in order: a `FakeModel`, a `FakeDataset` of 4 examples with 6
/// tokens each, a weight map holding a single scalar `"w"` set to 1.0, a
/// `NoopCallback`, and a zeroed `CountingOptimizer` with learning rate 0.0.
fn build_train_fixture() -> Result<(
    FakeModel,
    FakeDataset,
    Weights,
    NoopCallback,
    CountingOptimizer,
)> {
    let train_set = FakeDataset::new(4, 6);

    let mut weights: Weights = HashMap::new();
    weights.insert("w".into(), Array::full::<f32>(&[0i32; 0], 1.0)?);

    let optimizer = CountingOptimizer {
        apply_calls: 0,
        step_count: 0,
        lr: 0.0,
    };

    Ok((FakeModel, train_set, weights, NoopCallback, optimizer))
}
/// Ten iterations with `grad_accumulation_steps = 2` must result in exactly
/// five `apply_gradients` calls on the optimizer.
#[test]
fn grad_accumulation_steps_2_calls_optimizer_every_other_iter() -> Result<()> {
    let (model, train_set, mut weights, mut callback, mut optimizer) = build_train_fixture()?;

    // Report/eval/save intervals are set far beyond `iters`.
    let args = TrainingArgs::new()
        .with_iters(10)
        .with_grad_accumulation_steps(2)
        .with_steps_per_report(100)
        .with_steps_per_eval(100)
        .with_steps_per_save(100)
        .with_batch_size(4)
        .with_max_seq_length(64)
        .with_val_batches(Some(1))
        .with_acknowledge_no_real_gradients(true);

    train(
        &model,
        &mut optimizer,
        &mut weights,
        &train_set,
        None,
        &args,
        default_loss,
        &mut callback,
    )?;

    assert_eq!(
        optimizer.apply_calls, 5,
        "iters=10 + grad_accumulation_steps=2 must produce 5 optimizer steps; got {}",
        optimizer.apply_calls,
    );
    Ok(())
}
/// Eleven iterations with `grad_accumulation_steps = 4` must result in two
/// `apply_gradients` calls; the trailing three microbatches do not trigger a
/// third call.
#[test]
fn grad_accumulation_steps_partial_window_at_end_drops() -> Result<()> {
    let (model, train_set, mut weights, mut callback, mut optimizer) = build_train_fixture()?;

    // Report/eval/save intervals are set far beyond `iters`.
    let args = TrainingArgs::new()
        .with_iters(11)
        .with_grad_accumulation_steps(4)
        .with_steps_per_report(100)
        .with_steps_per_eval(100)
        .with_steps_per_save(100)
        .with_batch_size(4)
        .with_max_seq_length(64)
        .with_val_batches(Some(1))
        .with_acknowledge_no_real_gradients(true);

    train(
        &model,
        &mut optimizer,
        &mut weights,
        &train_set,
        None,
        &args,
        default_loss,
        &mut callback,
    )?;

    assert_eq!(
        optimizer.apply_calls, 2,
        "iters=11 + grad_accumulation_steps=4 must drop the final partial \
         window of 3 microbatches; expected 2 optimizer calls, got {}",
        optimizer.apply_calls,
    );
    Ok(())
}
/// With `grad_accumulation_steps = 1`, seven iterations must produce seven
/// `apply_gradients` calls.
#[test]
fn grad_accumulation_steps_1_is_identity_to_microbatch_count() -> Result<()> {
    let (model, train_set, mut weights, mut callback, mut optimizer) = build_train_fixture()?;

    // Report/eval/save intervals are set far beyond `iters`.
    let args = TrainingArgs::new()
        .with_iters(7)
        .with_grad_accumulation_steps(1)
        .with_steps_per_report(100)
        .with_steps_per_eval(100)
        .with_steps_per_save(100)
        .with_batch_size(4)
        .with_max_seq_length(64)
        .with_val_batches(Some(1))
        .with_acknowledge_no_real_gradients(true);

    train(
        &model,
        &mut optimizer,
        &mut weights,
        &train_set,
        None,
        &args,
        default_loss,
        &mut callback,
    )?;

    assert_eq!(optimizer.apply_calls, 7);
    Ok(())
}
/// Callback fixture that keeps every reported training loss, in order.
struct LossRecordingCallback {
    /// One entry per `on_train_loss_report` call.
    losses: Vec<f32>,
}

impl TrainingCallback for LossRecordingCallback {
    /// Appends the report's `train_loss()` to `losses`.
    fn on_train_loss_report(&mut self, report: &TrainInfo) {
        let reported = report.train_loss();
        self.losses.push(reported);
    }
}
/// Trains for 12 iterations with `grad_accumulation_steps = 4`, reporting at
/// every step, using a loss function that always returns 2.0 with a token
/// count of 1. Expects three reports, each with `train_loss` equal to 2.0.
#[test]
fn grad_accumulation_steps_4_reports_constant_loss_at_2_not_8() -> Result<()> {
    let train_set = FakeDataset::new(4, 6);
    let model = FakeModel;

    let mut weights: Weights = HashMap::new();
    weights.insert("w".into(), Array::full::<f32>(&[0i32; 0], 1.0)?);

    let mut optimizer = CountingOptimizer {
        apply_calls: 0,
        step_count: 0,
        lr: 0.0,
    };
    let mut recorder = LossRecordingCallback { losses: Vec::new() };

    // Ignores its inputs: the loss is always 2.0 and the token count 1.0.
    let const_loss_fn =
        |_m: &FakeModel, _batch: &Array, _lengths: &Array| -> Result<(Array, Array)> {
            let two = Array::full::<f32>(&[0i32; 0], 2.0)?;
            let one_token = Array::full::<f32>(&[0i32; 0], 1.0)?;
            Ok((two, one_token))
        };

    let args = TrainingArgs::new()
        .with_iters(12)
        .with_grad_accumulation_steps(4)
        .with_steps_per_report(1)
        .with_steps_per_eval(100)
        .with_steps_per_save(100)
        .with_batch_size(4)
        .with_max_seq_length(64)
        .with_val_batches(Some(1))
        .with_acknowledge_no_real_gradients(true);

    train(
        &model,
        &mut optimizer,
        &mut weights,
        &train_set,
        None,
        &args,
        const_loss_fn,
        &mut recorder,
    )?;

    assert_eq!(
        recorder.losses.len(),
        3,
        "iters=12 + grad_accumulation_steps=4 + steps_per_report=1 must fire 3 train-loss reports; got {}",
        recorder.losses.len(),
    );

    for (i, loss) in recorder.losses.iter().copied().enumerate() {
        assert!(
            (loss - 2.0).abs() < 1e-6,
            "report #{i} train_loss = {loss}, expected 2.0 (per-microbatch loss); dividing \
             `window_loss / window_steps` (4×constant-2.0 by 1 optimizer-step) would wrongly \
             report 8.0",
        );
    }
    Ok(())
}
/// A 2x2 batch whose lengths rows are both `[0, 1]` must be rejected by
/// `default_loss` with `Error::EmptyInput` mentioning "0 supervised tokens".
#[test]
fn default_loss_rejects_zero_token_batch_after_mask() -> Result<()> {
    let model = FakeModel;
    let tokens = Array::from_slice::<i32>(&[0, 0, 0, 0], &(2, 2))?;
    let bounds = Array::from_slice::<i32>(&[0, 1, 0, 1], &(2, 2))?;

    let failure = default_loss(&model, &tokens, &bounds)
        .expect_err("expected default_loss to reject zero-token batch");

    match failure {
        Error::EmptyInput(payload) => {
            assert!(
                payload.context().contains("0 supervised tokens"),
                "expected context to mention '0 supervised tokens', got: {}",
                payload.context(),
            );
        }
        other => panic!("expected Error::EmptyInput, got: {other:?}"),
    }
    Ok(())
}
/// A `[3, 2]` lengths array paired with a `[2, 2]` batch must be rejected
/// with `ShapePairMismatch` reporting expected `[2, 2]` and actual `[3, 2]`.
#[test]
fn default_loss_rejects_lengths_with_extra_batch_row() -> Result<()> {
    let model = FakeModel;
    let tokens = Array::from_slice::<i32>(&[1, 2, 3, 4], &(2, 2))?;
    let bounds = Array::from_slice::<i32>(&[0, 2, 0, 2, 0, 2], &(3, 2))?;

    let failure = default_loss(&model, &tokens, &bounds)
        .expect_err("expected ShapePairMismatch for extra length row");

    match failure {
        Error::ShapePairMismatch(payload) => {
            assert_eq!(payload.expected(), &[2_usize, 2_usize][..]);
            assert_eq!(payload.actual(), &[3_usize, 2_usize][..]);
        }
        other => panic!("expected Error::ShapePairMismatch, got: {other:?}"),
    }
    Ok(())
}
/// A `[1, 2]` lengths array paired with a `[2, 2]` batch must be rejected
/// with `ShapePairMismatch` reporting expected `[2, 2]` and actual `[1, 2]`.
#[test]
fn default_loss_rejects_lengths_with_missing_batch_row() -> Result<()> {
    let model = FakeModel;
    let tokens = Array::from_slice::<i32>(&[1, 2, 3, 4], &(2, 2))?;
    let bounds = Array::from_slice::<i32>(&[0, 2], &(1, 2))?;

    let failure = default_loss(&model, &tokens, &bounds)
        .expect_err("expected ShapePairMismatch for missing length row");

    match failure {
        Error::ShapePairMismatch(payload) => {
            assert_eq!(payload.expected(), &[2_usize, 2_usize][..]);
            assert_eq!(payload.actual(), &[1_usize, 2_usize][..]);
        }
        other => panic!("expected Error::ShapePairMismatch, got: {other:?}"),
    }
    Ok(())
}
/// `clear_cache_threshold` starts at 0 and reflects the value passed to
/// `with_clear_cache_threshold`.
#[test]
fn training_args_clear_cache_threshold_getter_and_builder() {
    let initial = TrainingArgs::new();
    assert_eq!(initial.clear_cache_threshold(), 0);

    let updated = initial.with_clear_cache_threshold(4096);
    assert_eq!(updated.clear_cache_threshold(), 4096);
}
/// `adapter_file` starts as "adapters.safetensors", and `with_adapter_file`
/// accepts both a string literal and an owned `String`.
#[test]
fn training_args_adapter_file_builder_accepts_str_and_string() {
    let initial = TrainingArgs::new();
    assert_eq!(initial.adapter_file(), "adapters.safetensors");

    // Borrowed string literal.
    let from_str = initial.with_adapter_file("custom/path.safetensors");
    assert_eq!(from_str.adapter_file(), "custom/path.safetensors");

    // Owned `String`.
    let owned_path = String::from("owned/adapters.safetensors");
    let from_string = from_str.with_adapter_file(owned_path);
    assert_eq!(from_string.adapter_file(), "owned/adapters.safetensors");
}
/// `grad_checkpoint` is false on a new `TrainingArgs` and true after
/// `with_grad_checkpoint(true)`.
#[test]
fn training_args_grad_checkpoint_builder_flips_flag() {
    let initial = TrainingArgs::new();
    assert!(!initial.grad_checkpoint());

    let enabled = initial.with_grad_checkpoint(true);
    assert!(enabled.grad_checkpoint());
}
/// A rank-1 batch of shape `[3]` must be rejected with `RankMismatch`
/// carrying the context string, actual rank 1 and actual shape `[3]`.
#[test]
fn default_loss_rejects_non_rank_2_batch() -> Result<()> {
    let model = FakeModel;
    let flat_tokens = Array::from_slice::<i32>(&[1, 2, 3], &(3usize,))?;
    let bounds = Array::from_slice::<i32>(&[0, 3], &(1, 2))?;

    let failure = default_loss(&model, &flat_tokens, &bounds)
        .expect_err("expected RankMismatch for rank-1 batch");

    match failure {
        Error::RankMismatch(payload) => {
            assert_eq!(payload.context(), "default_loss: batch must be rank-2 [B, S]");
            assert_eq!(payload.actual(), 1);
            assert_eq!(payload.actual_shape(), &[3_usize][..]);
        }
        other => panic!("expected Error::RankMismatch, got: {other:?}"),
    }
    Ok(())
}
/// A rank-3 batch of shape `[1, 1, 2]` must be rejected with `RankMismatch`
/// reporting actual rank 3 and that shape.
#[test]
fn default_loss_rejects_rank_3_batch() -> Result<()> {
    let model = FakeModel;
    let cube_tokens = Array::from_slice::<i32>(&[1, 2], &(1usize, 1usize, 2usize))?;
    let bounds = Array::from_slice::<i32>(&[0, 2], &(1, 2))?;

    let failure = default_loss(&model, &cube_tokens, &bounds)
        .expect_err("expected RankMismatch for rank-3 batch");

    match failure {
        Error::RankMismatch(payload) => {
            assert_eq!(payload.actual(), 3);
            assert_eq!(payload.actual_shape(), &[1_usize, 1, 2][..]);
        }
        other => panic!("expected Error::RankMismatch, got: {other:?}"),
    }
    Ok(())
}
/// A `[1, 1]` batch must be rejected with `OutOfRange` whose context,
/// requirement and value strings match the ones asserted below.
#[test]
fn default_loss_rejects_seq_len_below_2() -> Result<()> {
    let model = FakeModel;
    let single_token = Array::from_slice::<i32>(&[5], &(1usize, 1usize))?;
    let bounds = Array::from_slice::<i32>(&[0, 1], &(1, 2))?;

    let failure = default_loss(&model, &single_token, &bounds)
        .expect_err("expected OutOfRange for S < 2");

    match failure {
        Error::OutOfRange(payload) => {
            assert_eq!(payload.context(), "default_loss: batch S");
            assert_eq!(payload.requirement(), "must be >= 2 for next-token prediction");
            assert_eq!(payload.value(), "1");
        }
        other => panic!("expected Error::OutOfRange, got: {other:?}"),
    }
    Ok(())
}
/// Each `TrainInfo` accessor must return the matching positional argument
/// that was passed to `TrainInfo::new`.
#[test]
fn train_info_accessors_return_constructor_inputs() {
    let report = TrainInfo::new(7, 1.5, 0.001, 12.0, 480.0, 3_200);

    assert_eq!(report.iteration(), 7);
    assert_eq!(report.train_loss(), 1.5);
    assert_eq!(report.learning_rate(), 0.001);
    assert_eq!(report.iterations_per_second(), 12.0);
    assert_eq!(report.tokens_per_second(), 480.0);
    assert_eq!(report.trained_tokens(), 3_200);
}
/// Each `ValInfo` accessor must return the matching positional argument
/// that was passed to `ValInfo::new`.
#[test]
fn val_info_accessors_return_constructor_inputs() {
    let report = ValInfo::new(42, 2.25, 0.75);

    assert_eq!(report.iteration(), 42);
    assert_eq!(report.val_loss(), 2.25);
    assert_eq!(report.val_time(), 0.75);
}
/// Calls all three `TrainingCallback` hooks on `NoopCallback`; the test passes
/// as long as none of them panics and `on_save` returns `Ok`.
#[test]
fn noop_callback_default_methods_are_no_ops() -> Result<()> {
    let mut callback = NoopCallback;

    let train_report = TrainInfo::new(1, 0.0, 0.0, 0.0, 0.0, 0);
    callback.on_train_loss_report(&train_report);

    let val_report = ValInfo::new(0, 0.0, 0.0);
    callback.on_val_loss_report(&val_report);

    callback.on_save(3, "adapters.safetensors")?;
    Ok(())
}
/// With `max_seq_length = 8`, every batch built from 4-token examples must
/// have a token array of shape `[4, 8]`, and at least one batch must appear.
#[test]
fn build_batch_clamps_padded_length_to_max_seq_length() -> Result<()> {
    let dataset = FakeDataset::new(4, 4);
    let batches = iterate_batches(&dataset, 4, 8, false, None)?;

    let mut batches_seen = 0_usize;
    for item in batches {
        let batch = item?;
        assert_eq!(
            batch.tokens_ref().shape(),
            &[4, 8],
            "padded width must be clamped to max_seq_length=8 (un-clamped would be 33)",
        );
        batches_seen += 1;
    }

    assert!(batches_seen > 0, "expected at least one batch");
    Ok(())
}
/// Shuffling `0..16` twice with the same seed must give identical orderings,
/// and sorting the result must recover `0..16` exactly.
#[test]
fn fisher_yates_shuffle_is_a_deterministic_permutation() {
    let identity: Vec<usize> = (0..16).collect();

    let mut a = identity.clone();
    let mut b = identity.clone();
    fisher_yates_shuffle(&mut a, 0xDEAD_BEEF);
    fisher_yates_shuffle(&mut b, 0xDEAD_BEEF);
    assert_eq!(a, b, "same seed must produce the same permutation");

    let mut resorted = a.clone();
    resorted.sort_unstable();
    assert_eq!(
        resorted,
        identity,
        "shuffle must be a permutation (no lost/duplicated elements)",
    );
}
/// Shuffling `0..16` with seeds 1 and 999_999 must give different orderings.
#[test]
fn fisher_yates_shuffle_different_seeds_can_differ() {
    let mut with_seed_one: Vec<usize> = (0..16).collect();
    let mut with_large_seed: Vec<usize> = (0..16).collect();

    fisher_yates_shuffle(&mut with_seed_one, 1);
    fisher_yates_shuffle(&mut with_large_seed, 999_999);

    assert_ne!(
        with_seed_one, with_large_seed,
        "distinct seeds should yield distinct permutations"
    );
}
/// With 8 examples (two batches of 4), the fourth argument `true` and a fixed
/// seed, the iterator must yield four `Ok` batches in a row, each with 4
/// token rows.
#[test]
fn iterate_batches_shuffle_over_multiple_batches_runs_shuffle_body() -> Result<()> {
    let dataset = FakeDataset::new(8, 4);
    let mut batches = iterate_batches(&dataset, 4, 64, true, Some(0x1234))?;

    for _round in 0..4 {
        let batch = batches.next().expect("loop_forever must not exhaust")?;
        let rows = batch.tokens_ref().shape()[0];
        assert_eq!(rows, 4);
    }
    Ok(())
}
/// A hand-built `BatchIter` with an empty `batch_idx` must return `None` from
/// `next()` even though `loop_forever` is `true`.
#[test]
fn batch_iter_with_empty_batch_idx_yields_none() {
    let dataset = FakeDataset::new(4, 4);

    // Constructed directly so that `batch_idx` can be left empty.
    let mut empty_iter = BatchIter {
        dataset: &dataset,
        batch_idx: Vec::new(),
        max_seq_length: 64,
        cursor: 0,
        order: Vec::new(),
        loop_forever: true,
        shuffle_seed: None,
        rng_state: None,
        first_pass: true,
    };

    let first = empty_iter.next();
    assert!(
        first.is_none(),
        "empty batch_idx must yield None even with loop_forever=true",
    );
}
/// Evaluates an 8-example set at batch size 4 with the batch cap `Some(1)`
/// and expects the reported loss to be ln(8).
///
/// NOTE(review): `FakeModel` emits the same logits for every batch, so this
/// assertion would presumably also hold if more than one batch were
/// evaluated — confirm whether an explicit call-count check is wanted.
#[test]
fn evaluate_stops_after_num_batches_cap() -> Result<()> {
    let eval_set = FakeDataset::new(8, 6);
    let model = FakeModel;

    let loss = evaluate(&model, &eval_set, 4, Some(1), 64, |mdl, toks, lens| {
        default_loss(mdl, toks, lens)
    })?;

    let deviation = (loss - 8.0_f32.ln()).abs();
    assert!(
        deviation < 1e-4,
        "capped eval over 1 batch must report log(8); got {loss}",
    );
    Ok(())
}
/// When the loss function always reports a token count of 0, `evaluate` must
/// fail with `Error::EmptyInput` mentioning "produced no batches with tokens".
#[test]
fn evaluate_rejects_eval_set_that_produces_no_tokens() {
    let eval_set = FakeDataset::new(4, 6);
    let model = FakeModel;

    // Ignores its inputs: loss 1.0, token count 0.0.
    let zero_tok = |_m: &FakeModel, _b: &Array, _l: &Array| -> Result<(Array, Array)> {
        let unit_loss = Array::full::<f32>(&[0i32; 0], 1.0)?;
        let no_tokens = Array::full::<f32>(&[0i32; 0], 0.0)?;
        Ok((unit_loss, no_tokens))
    };

    let failure = evaluate(&model, &eval_set, 4, Some(1), 64, zero_tok)
        .expect_err("expected EmptyInput when eval produces no tokens");

    match failure {
        Error::EmptyInput(payload) => {
            assert!(
                payload.context().contains("produced no batches with tokens"),
                "unexpected context: {}",
                payload.context(),
            );
        }
        other => panic!("expected Error::EmptyInput, got: {other:?}"),
    }
}
/// With `iters = 0`, `train` must return `Ok` and leave all three counters of
/// the `CountingCallback` at zero.
#[test]
fn train_with_zero_iters_returns_ok_without_firing_callbacks() -> Result<()> {
    let data = FakeDataset::new(4, 6);
    let model = FakeModel;

    let mut weights: Weights = HashMap::new();
    weights.insert("w".into(), Array::full::<f32>(&[0i32; 0], 1.0)?);

    let mut optimizer = SGD::vanilla(0.01)?;
    let mut counter = CountingCallback {
        train_reports: 0,
        val_reports: 0,
        saves: 0,
    };

    let args = TrainingArgs::new()
        .with_iters(0)
        .with_batch_size(4)
        .with_max_seq_length(64)
        .with_val_batches(Some(1))
        .with_acknowledge_no_real_gradients(true);

    train(
        &model,
        &mut optimizer,
        &mut weights,
        &data,
        Some(&data),
        &args,
        default_loss,
        &mut counter,
    )?;

    assert_eq!(counter.train_reports, 0, "no train reports for iters=0");
    assert_eq!(counter.val_reports, 0, "no eval for iters=0");
    assert_eq!(
        counter.saves, 0,
        "iters=0 returns before the final save hook, so no save fires",
    );
    Ok(())
}
/// Adding a two-key map to a one-key map must fail with `LengthMismatch`
/// reporting expected 2 and actual 1.
#[test]
fn add_weights_rejects_mismatched_key_counts() -> Result<()> {
    let mut lhs: Weights = HashMap::new();
    lhs.insert("x".into(), Array::full::<f32>(&[0i32; 0], 1.0)?);
    lhs.insert("y".into(), Array::full::<f32>(&[0i32; 0], 2.0)?);

    let mut rhs: Weights = HashMap::new();
    rhs.insert("x".into(), Array::full::<f32>(&[0i32; 0], 3.0)?);

    let failure =
        add_weights(&lhs, &rhs).expect_err("expected LengthMismatch for key-count skew");

    match failure {
        Error::LengthMismatch(payload) => {
            assert_eq!(payload.context(), "trainer::add_weights: lhs vs rhs key counts");
            assert_eq!(payload.expected(), 2);
            assert_eq!(payload.actual(), 1);
        }
        other => panic!("expected Error::LengthMismatch, got: {other:?}"),
    }
    Ok(())
}
/// Two one-key maps with different keys (`"x"` vs `"y"`) must make
/// `add_weights` fail with `MissingKey` naming `"x"`.
#[test]
fn add_weights_rejects_key_present_in_lhs_but_missing_from_rhs() -> Result<()> {
    let mut lhs: Weights = HashMap::new();
    lhs.insert("x".into(), Array::full::<f32>(&[0i32; 0], 1.0)?);

    let mut rhs: Weights = HashMap::new();
    rhs.insert("y".into(), Array::full::<f32>(&[0i32; 0], 2.0)?);

    let failure =
        add_weights(&lhs, &rhs).expect_err("expected MissingKey for disjoint key sets");

    match failure {
        Error::MissingKey(payload) => {
            assert_eq!(payload.context(), "trainer::add_weights: key missing from rhs");
            assert_eq!(payload.key(), "x");
        }
        other => panic!("expected Error::MissingKey, got: {other:?}"),
    }
    Ok(())
}