#[cfg(test)]
mod tests {
use super::*;
use std::io::Write;
use tempfile::NamedTempFile;
/// Checks string parsing into `DistillStrategy`: each accepted name yields its
/// variant, `"kl"` resolves to `Standard`, and an unrecognized name is an error.
#[test]
fn test_distill_strategy_parse() {
    let standard: Result<DistillStrategy, _> = "standard".parse();
    assert!(matches!(standard, Ok(DistillStrategy::Standard)));

    // "kl" is expected to land on the same variant as "standard".
    let kl: Result<DistillStrategy, _> = "kl".parse();
    assert!(matches!(kl, Ok(DistillStrategy::Standard)));

    let progressive: Result<DistillStrategy, _> = "progressive".parse();
    assert!(matches!(progressive, Ok(DistillStrategy::Progressive)));

    let ensemble: Result<DistillStrategy, _> = "ensemble".parse();
    assert!(matches!(ensemble, Ok(DistillStrategy::Ensemble)));

    // Anything outside the accepted set must be rejected.
    let unrecognized: Result<DistillStrategy, _> = "unknown".parse();
    assert!(unrecognized.is_err());
}
/// A teacher path that does not exist on disk must make `run` return
/// `CliError::FileNotFound`.
#[test]
fn test_run_teacher_not_found() {
    let result = run(
        Some(Path::new("/nonexistent.apr")), // teacher: deliberately missing
        None,
        None,
        Some(Path::new("/tmp/out.apr")),
        "standard",
        3.0,
        0.7,
        3,
        false,
        None,
        None,
        false,
    );
    // One pattern check covers both "is an error" and "is the right error";
    // the message surfaces the actual outcome when the expectation fails.
    assert!(
        matches!(result, Err(CliError::FileNotFound(_))),
        "Expected FileNotFound for a missing teacher, got {result:?}"
    );
}
/// With an existing teacher file and `0.0` in the sixth argument (presumably
/// the temperature; the other tests pass `3.0` there), `run` must fail with
/// `CliError::ValidationFailed` and a message mentioning "Temperature".
#[test]
fn test_run_invalid_temperature() {
    let input = NamedTempFile::with_suffix(".apr").expect("create input");
    let result = run(
        Some(input.path()),
        None,
        None,
        Some(Path::new("/tmp/out.apr")),
        "standard",
        0.0, // value under test
        0.7,
        3,
        false,
        None,
        None,
        false,
    );
    match result {
        Err(CliError::ValidationFailed(msg)) => assert!(
            msg.contains("Temperature"),
            "validation message should mention Temperature, got {msg:?}"
        ),
        // Report what actually came back so a failure is diagnosable from the log.
        other => panic!("Expected ValidationFailed, got {other:?}"),
    }
}
/// With an existing teacher file and `1.5` in the seventh argument (presumably
/// alpha; the other tests pass values such as `0.7` there), `run` must fail
/// with `CliError::ValidationFailed` and a message mentioning "Alpha".
#[test]
fn test_run_invalid_alpha() {
    let input = NamedTempFile::with_suffix(".apr").expect("create input");
    let result = run(
        Some(input.path()),
        None,
        None,
        Some(Path::new("/tmp/out.apr")),
        "standard",
        3.0,
        1.5, // value under test
        3,
        false,
        None,
        None,
        false,
    );
    match result {
        Err(CliError::ValidationFailed(msg)) => assert!(
            msg.contains("Alpha"),
            "validation message should mention Alpha, got {msg:?}"
        ),
        // Report what actually came back so a failure is diagnosable from the log.
        other => panic!("Expected ValidationFailed, got {other:?}"),
    }
}
/// With the "standard" strategy, an existing teacher file and no student
/// (second argument `None`), `run` must fail with
/// `CliError::ValidationFailed` and a message mentioning "Student".
#[test]
fn test_run_no_student() {
    // The teacher only needs to exist with some content; 512 zero bytes are written.
    let mut input = NamedTempFile::with_suffix(".apr").expect("create input");
    input.write_all(&[0u8; 512]).expect("write");
    let result = run(
        Some(input.path()),
        None, // student intentionally omitted
        None,
        Some(Path::new("/tmp/out.apr")),
        "standard",
        3.0,
        0.7,
        3,
        false,
        None,
        None,
        false,
    );
    match result {
        Err(CliError::ValidationFailed(msg)) => assert!(
            msg.contains("Student"),
            "validation message should mention Student, got {msg:?}"
        ),
        // Report what actually came back so a failure is diagnosable from the log.
        other => panic!("Expected ValidationFailed, got {other:?}"),
    }
}
/// With teacher and student files present but no output path (fourth argument
/// `None`), `run` must fail with `CliError::ValidationFailed` and a message
/// mentioning "Output".
#[test]
fn test_run_no_output() {
    let mut teacher = NamedTempFile::with_suffix(".apr").expect("create teacher");
    teacher.write_all(&[0u8; 512]).expect("write");
    let mut student = NamedTempFile::with_suffix(".apr").expect("create student");
    student.write_all(&[0u8; 256]).expect("write");
    let result = run(
        Some(teacher.path()),
        Some(student.path()),
        None,
        None, // output intentionally omitted
        "standard",
        3.0,
        0.7,
        3,
        false,
        None,
        None,
        false,
    );
    match result {
        Err(CliError::ValidationFailed(msg)) => assert!(
            msg.contains("Output"),
            "validation message should mention Output, got {msg:?}"
        ),
        // Report what actually came back so a failure is diagnosable from the log.
        other => panic!("Expected ValidationFailed, got {other:?}"),
    }
}
/// Serializes a small APR model into a temporary `.apr` file and returns the
/// file handle.
///
/// The model carries a `model_type` metadata entry plus four `f32` tensors:
/// an 8x8 `q_proj` weight for layer 0 and for layer 1, an 8-element norm
/// weight, and a 10x8 embedding.
fn make_test_model() -> NamedTempFile {
    let mut apr = aprender::serialization::apr::AprWriter::new();
    apr.set_metadata("model_type", serde_json::json!("test"));

    // Layer weights are linear ramps with a different slope per layer.
    let layer0_ramp: Vec<f32> = (0..64).map(|idx| (idx as f32) * 0.01).collect();
    apr.add_tensor_f32(
        "model.layers.0.self_attn.q_proj.weight",
        vec![8, 8],
        &layer0_ramp,
    );
    let layer1_ramp: Vec<f32> = (0..64).map(|idx| (idx as f32) * 0.02).collect();
    apr.add_tensor_f32(
        "model.layers.1.self_attn.q_proj.weight",
        vec![8, 8],
        &layer1_ramp,
    );

    // Tensors whose names carry no layer index.
    apr.add_tensor_f32("model.norm.weight", vec![8], &vec![1.0; 8]);
    apr.add_tensor_f32("model.embed_tokens.weight", vec![10, 8], &vec![0.1; 80]);

    let model_file = NamedTempFile::with_suffix(".apr").expect("create model");
    let serialized = apr.to_bytes().expect("serialize");
    std::fs::write(model_file.path(), serialized).expect("write");
    model_file
}
/// Full run with a teacher, a student and an output path: the call must
/// succeed and the output must open as APR, contain tensors, and carry a
/// `distillation_teacher` metadata entry.
#[test]
fn test_run_valid() {
    let (teacher, student) = (make_test_model(), make_test_model());
    let output = NamedTempFile::with_suffix(".apr").expect("create output");
    let output_path = output.path();

    let result = run(
        Some(teacher.path()),
        Some(student.path()),
        None,
        Some(output_path),
        "standard",
        3.0,
        0.7,
        3,
        false,
        None,
        None,
        true,
    );
    assert!(result.is_ok(), "Distill should succeed: {result:?}");

    // Re-open what was written and inspect it.
    let written = aprender::serialization::apr::AprReader::open(output_path)
        .expect("output should be valid APR");
    let has_tensors = !written.tensors.is_empty();
    assert!(has_tensors, "Output should have tensors");
    let teacher_meta = written.get_metadata("distillation_teacher");
    assert!(teacher_meta.is_some());
}
/// With only a teacher model and the ninth argument set to `true` (presumably
/// the plan flag, per the test name), `run` must succeed even though no
/// student or output path is given.
#[test]
fn test_plan_mode() {
    let teacher = make_test_model();
    let result = run(
        Some(teacher.path()),
        None,
        None,
        None,
        "standard",
        3.0,
        0.7,
        3,
        true, // differs from the non-plan tests, which pass `false` here
        None,
        None,
        false,
    );
    // Print the error on failure, matching the other success-path tests in this module.
    assert!(result.is_ok(), "Plan mode should succeed: {result:?}");
}
/// Same shape as the plan-mode test, but with the "progressive" strategy,
/// different numeric settings, and the last flag set to `true` (presumably
/// JSON output, per the test name). `run` must succeed.
#[test]
fn test_plan_json() {
    let teacher = make_test_model();
    let result = run(
        Some(teacher.path()),
        None,
        None,
        None,
        "progressive",
        4.0,
        0.5,
        5,
        true,
        None,
        None,
        true,
    );
    // Print the error on failure, matching the other success-path tests in this module.
    assert!(result.is_ok(), "Plan JSON should succeed: {result:?}");
}
/// With the "progressive" strategy and no student, `run` must succeed and the
/// written APR file must keep the teacher's layer-0 tensors while containing
/// no layer-1 tensors.
#[test]
fn test_progressive_no_student() {
    let teacher = make_test_model();
    let output = NamedTempFile::with_suffix(".apr").expect("create output");
    let result = run(
        Some(teacher.path()),
        None,
        None,
        Some(output.path()),
        "progressive",
        3.0,
        0.7,
        3,
        false,
        None,
        None,
        true,
    );
    assert!(result.is_ok(), "Progressive should succeed: {result:?}");

    let reader = aprender::serialization::apr::AprReader::open(output.path())
        .expect("output should be valid APR");

    // Only presence matters, so short-circuit with `any` rather than
    // collecting the matching tensors into a temporary Vec.
    let has_layer1 = reader.tensors.iter().any(|t| t.name.contains("layers.1."));
    assert!(
        !has_layer1,
        "Layer 1 should be dropped by progressive distillation"
    );
    let has_layer0 = reader.tensors.iter().any(|t| t.name.contains("layers.0."));
    assert!(has_layer0, "Layer 0 should be kept");
}
/// Table-driven check of `extract_layer_number`: names with a layer index
/// (`model.layers.N.` and `blk.N.` forms) yield `Some(N)`, names without one
/// yield `None`.
#[test]
fn test_extract_layer_number() {
    let cases = [
        ("model.layers.5.self_attn.q_proj.weight", Some(5)),
        ("blk.0.attn_q.weight", Some(0)),
        ("model.norm.weight", None),
        ("lm_head.weight", None),
    ];
    for (tensor_name, expected) in cases {
        assert_eq!(extract_layer_number(tensor_name), expected);
    }
}
/// For a teacher with layers 0..=3 plus a norm tensor, the progressive
/// strategy must produce a student of three tensors: layers 0 and 2 and the
/// norm, with layers 1 and 3 absent.
#[test]
fn test_create_student_progressive() {
    let tensors = std::collections::BTreeMap::from([
        ("model.layers.0.weight".to_owned(), (vec![1.0; 4], vec![2, 2])),
        ("model.layers.1.weight".to_owned(), (vec![2.0; 4], vec![2, 2])),
        ("model.layers.2.weight".to_owned(), (vec![3.0; 4], vec![2, 2])),
        ("model.layers.3.weight".to_owned(), (vec![4.0; 4], vec![2, 2])),
        ("model.norm.weight".to_owned(), (vec![1.0; 2], vec![2])),
    ]);

    let student = create_student_from_teacher(&tensors, DistillStrategy::Progressive);
    assert_eq!(student.len(), 3);

    // Tensors expected to survive.
    for kept in [
        "model.layers.0.weight",
        "model.layers.2.weight",
        "model.norm.weight",
    ] {
        assert!(student.contains_key(kept));
    }
    // Tensors expected to be removed.
    for dropped in ["model.layers.1.weight", "model.layers.3.weight"] {
        assert!(!student.contains_key(dropped));
    }
}
/// The standard strategy must carry every teacher tensor over to the student:
/// two tensors in, two tensors out.
#[test]
fn test_create_student_standard() {
    let tensors = std::collections::BTreeMap::from([
        ("a".to_owned(), (vec![1.0], vec![1])),
        ("b".to_owned(), (vec![2.0], vec![1])),
    ]);
    let student = create_student_from_teacher(&tensors, DistillStrategy::Standard);
    assert_eq!(student.len(), 2, "Standard copies all tensors");
}
}