use colored::Colorize;
use crate::error::{CliError, Result};
use crate::output;
use serde::Deserialize;
use std::path::Path;
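/// Top-level schema for config-driven distillation (`apr distill --config <yaml>`).
/// Field names mirror the YAML layout; sections marked `#[serde(default)]` may be
/// omitted. An illustrative config (paths are made up, values are the defaults):
///
/// ```yaml
/// teacher:
///   model_id: ./models/teacher.safetensors
/// student:
///   model_id: ./models/student.safetensors
/// distillation:
///   temperature: 4.0
///   alpha: 0.7
/// dataset:
///   path: ./data/train.jsonl
/// ```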
#[derive(Debug, Clone, Deserialize)]
struct DistillYamlConfig {
teacher: DistillTeacherConfig,
student: DistillStudentConfig,
#[serde(default)]
distillation: DistillLossConfig,
#[serde(default)]
training: DistillTrainingConfig,
dataset: DistillDatasetConfig,
#[serde(default)]
output: DistillOutputConfig,
}
#[derive(Debug, Clone, Deserialize)]
struct DistillTeacherConfig {
model_id: String,
#[serde(default)]
load_in_8bit: bool,
}
#[derive(Debug, Clone, Deserialize)]
#[allow(dead_code)]
struct DistillStudentConfig {
model_id: String,
#[serde(default)]
load_in_4bit: bool,
lora: Option<DistillLoraConfig>,
}
#[derive(Debug, Clone, Deserialize)]
struct DistillLoraConfig {
rank: usize,
#[serde(default = "default_lora_alpha")]
alpha: f64,
}
fn default_lora_alpha() -> f64 {
32.0
}
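/// Soft-label distillation hyperparameters. Under the conventional knowledge
/// distillation loss (Hinton-style), `temperature` T softens both logit
/// distributions and `alpha` blends the soft and hard terms, roughly:
///
/// loss = alpha * T^2 * KL(teacher/T || student/T) + (1 - alpha) * CE(student, labels)
///
/// The loss itself is computed by the training backend; this struct only
/// carries the knobs and their validation ranges.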
#[derive(Debug, Clone, Deserialize)]
struct DistillLossConfig {
#[serde(default = "default_temperature")]
temperature: f32,
#[serde(default = "default_alpha")]
alpha: f32,
progressive: Option<DistillProgressiveConfig>,
attention_transfer: Option<DistillAttentionConfig>,
}
impl Default for DistillLossConfig {
fn default() -> Self {
Self {
temperature: default_temperature(),
alpha: default_alpha(),
progressive: None,
attention_transfer: None,
}
}
}
fn default_temperature() -> f32 {
4.0
}
fn default_alpha() -> f32 {
0.7
}
#[derive(Debug, Clone, Deserialize)]
#[allow(dead_code)]
struct DistillProgressiveConfig {
layer_mapping: Vec<[usize; 2]>,
#[serde(default = "default_hidden_weight")]
hidden_weight: f32,
}
fn default_hidden_weight() -> f32 {
1.0
}
#[derive(Debug, Clone, Deserialize)]
#[allow(dead_code)]
struct DistillAttentionConfig {
#[serde(default = "default_attention_weight")]
weight: f32,
}
fn default_attention_weight() -> f32 {
0.1
}
#[derive(Debug, Clone, Deserialize)]
#[allow(dead_code)]
struct DistillTrainingConfig {
#[serde(default = "default_epochs")]
epochs: usize,
#[serde(default = "default_batch_size")]
batch_size: usize,
#[serde(default = "default_lr")]
learning_rate: f64,
#[serde(default = "default_weight_decay")]
weight_decay: f64,
#[serde(default)]
gradient_checkpointing: bool,
mixed_precision: Option<String>,
#[serde(default = "default_max_grad_norm")]
max_grad_norm: f32,
#[serde(default = "default_seed")]
seed: u64,
}
impl Default for DistillTrainingConfig {
fn default() -> Self {
Self {
epochs: default_epochs(),
batch_size: default_batch_size(),
learning_rate: default_lr(),
weight_decay: default_weight_decay(),
gradient_checkpointing: false,
mixed_precision: None,
max_grad_norm: default_max_grad_norm(),
seed: default_seed(),
}
}
}
fn default_epochs() -> usize {
3
}
fn default_batch_size() -> usize {
16
}
fn default_lr() -> f64 {
0.0002
}
fn default_weight_decay() -> f64 {
0.01
}
fn default_max_grad_norm() -> f32 {
1.0
}
fn default_seed() -> u64 {
42
}
#[derive(Debug, Clone, Deserialize)]
#[allow(dead_code)]
struct DistillDatasetConfig {
path: String,
#[serde(default = "default_max_seq_length")]
max_seq_length: usize,
#[serde(default)]
max_train_examples: Option<usize>,
}
fn default_max_seq_length() -> usize {
512
}
#[derive(Debug, Clone, Deserialize)]
#[allow(dead_code)]
struct DistillOutputConfig {
#[serde(default = "default_output_dir")]
dir: String,
#[serde(default = "default_log_steps")]
log_steps: usize,
#[serde(default = "default_save_steps")]
save_steps: usize,
#[serde(default = "default_eval_steps")]
eval_steps: usize,
}
impl Default for DistillOutputConfig {
fn default() -> Self {
Self {
dir: default_output_dir(),
log_steps: default_log_steps(),
save_steps: default_save_steps(),
eval_steps: default_eval_steps(),
}
}
}
fn default_output_dir() -> String {
"./outputs/distill".to_string()
}
fn default_log_steps() -> usize {
10
}
fn default_save_steps() -> usize {
500
}
fn default_eval_steps() -> usize {
100
}
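/// Alternative schema for text-based distillation: rather than matching logits,
/// the teacher generates synthetic completions for later student training.
/// `run_config_mode` selects this schema when the YAML contains a
/// `synthetic_data` key.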
#[derive(Debug, Clone, Deserialize)]
#[allow(dead_code)]
struct TextDistillConfig {
teacher: TextTeacherConfig,
#[serde(default)]
student: Option<TextStudentConfig>,
synthetic_data: SyntheticDataConfig,
}
#[derive(Debug, Clone, Deserialize)]
#[allow(dead_code)]
struct TextTeacherConfig {
model: String,
#[serde(default)]
tokenizer: Option<String>,
#[serde(default)]
precision: Option<String>,
#[serde(default = "default_gpu")]
gpu: bool,
#[serde(default = "default_max_tokens")]
max_tokens: u32,
#[serde(default = "default_gen_temperature")]
temperature: f32,
#[serde(default = "default_top_p")]
top_p: f32,
}
#[derive(Debug, Clone, Deserialize)]
#[allow(dead_code)]
struct TextStudentConfig {
checkpoint: String,
tokenizer: String,
#[serde(default)]
config: Option<String>,
}
#[derive(Debug, Clone, Deserialize)]
struct SyntheticDataConfig {
prompts: String,
output: String,
#[serde(default = "default_target_tokens")]
target_tokens: u64,
#[serde(default = "default_samples_per_prompt")]
samples_per_prompt: u32,
#[serde(default = "default_min_completion_tokens")]
min_completion_tokens: u32,
#[serde(default = "default_max_prompt_chars")]
max_prompt_chars: usize,
}
fn default_gpu() -> bool {
true
}
fn default_max_tokens() -> u32 {
256
}
fn default_gen_temperature() -> f32 {
0.8
}
fn default_top_p() -> f32 {
0.95
}
fn default_target_tokens() -> u64 {
500_000
}
fn default_samples_per_prompt() -> u32 {
1
}
fn default_min_completion_tokens() -> u32 {
10
}
fn default_max_prompt_chars() -> usize {
2048
}
impl DistillYamlConfig {
fn load(path: &Path) -> Result<Self> {
let content = std::fs::read_to_string(path)
.map_err(|e| CliError::ValidationFailed(format!("Failed to read config: {e}")))?;
serde_yaml::from_str(&content)
.map_err(|e| CliError::ValidationFailed(format!("Failed to parse YAML: {e}")))
}
fn validate(&self) -> Result<()> {
if self.teacher.model_id.is_empty() {
return Err(CliError::ValidationFailed(
"teacher.model_id cannot be empty".into(),
));
}
if self.student.model_id.is_empty() {
return Err(CliError::ValidationFailed(
"student.model_id cannot be empty".into(),
));
}
if self.distillation.temperature <= 0.0 {
return Err(CliError::ValidationFailed(
"distillation.temperature must be positive".into(),
));
}
if !(0.0..=1.0).contains(&self.distillation.alpha) {
return Err(CliError::ValidationFailed(
"distillation.alpha must be between 0 and 1".into(),
));
}
if self.training.batch_size == 0 {
return Err(CliError::ValidationFailed(
"training.batch_size must be > 0".into(),
));
}
if self.training.learning_rate <= 0.0 {
return Err(CliError::ValidationFailed(
"training.learning_rate must be positive".into(),
));
}
if self.dataset.path.is_empty() {
return Err(CliError::ValidationFailed(
"dataset.path cannot be empty".into(),
));
}
Ok(())
}
}
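/// Distillation strategy selected via `--strategy`; see the `FromStr` impl
/// below for the accepted aliases.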
#[derive(Debug, Clone, Copy, Default)]
pub enum DistillStrategy {
#[default]
Standard,
Progressive,
Ensemble,
}
impl std::str::FromStr for DistillStrategy {
type Err = String;
fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
match s.to_lowercase().as_str() {
"standard" | "kl" => Ok(Self::Standard),
"progressive" | "gradual" => Ok(Self::Progressive),
"ensemble" | "multi" => Ok(Self::Ensemble),
_ => Err(format!(
"Unknown distillation strategy: {s}. Supported: standard, progressive, ensemble"
)),
}
}
}
fn validate_distill_params(temperature: f64, alpha: f64) -> Result<()> {
if temperature <= 0.0 {
return Err(CliError::ValidationFailed(format!(
"Temperature must be positive, got {temperature}"
)));
}
if !(0.0..=1.0).contains(&alpha) {
return Err(CliError::ValidationFailed(format!(
"Alpha must be between 0 and 1, got {alpha}"
)));
}
Ok(())
}
fn validate_optional_paths(student_path: Option<&Path>, data_path: Option<&Path>) -> Result<()> {
if let Some(student) = student_path {
if !student.exists() {
return Err(CliError::FileNotFound(student.to_path_buf()));
}
}
if let Some(data) = data_path {
if !data.exists() {
return Err(CliError::FileNotFound(data.to_path_buf()));
}
}
Ok(())
}
#[allow(clippy::too_many_arguments)]
fn print_distill_header(
teacher_path: &Path,
student_path: Option<&Path>,
data_path: Option<&Path>,
distill_strategy: DistillStrategy,
temperature: f64,
alpha: f64,
epochs: u32,
out: &Path,
json_output: bool,
) {
if !json_output {
output::header("APR Distill");
let mut pairs = vec![
("Teacher", teacher_path.display().to_string()),
("Strategy", format!("{distill_strategy:?}")),
("Temperature", format!("{temperature:.1}")),
("Alpha", format!("{alpha:.2}")),
("Epochs", epochs.to_string()),
("Output", out.display().to_string()),
];
if let Some(student) = student_path {
pairs.insert(1, ("Student", student.display().to_string()));
}
if let Some(data) = data_path {
pairs.push(("Training data", data.display().to_string()));
}
println!("{}", output::kv_table(&pairs));
println!();
}
}
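/// CLI entry point for `apr distill`. Dispatches to one of three modes:
/// config-driven (`--config`, optionally staged), plan-only (`--plan`), or
/// direct tensor-level distillation from positional teacher/student paths.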
#[allow(clippy::too_many_arguments)]
#[allow(clippy::disallowed_methods)]
pub(crate) fn run(
teacher_path: Option<&Path>,
student_path: Option<&Path>,
data_path: Option<&Path>,
output_path: Option<&Path>,
strategy: &str,
temperature: f64,
alpha: f64,
epochs: u32,
plan_only: bool,
config_path: Option<&Path>,
stage: Option<&str>,
json_output: bool,
) -> Result<()> {
if let Some(config) = config_path {
return run_config_mode(config, stage, plan_only, json_output);
}
let teacher_path = teacher_path.ok_or_else(|| {
CliError::ValidationFailed(
"Teacher model path required. Use positional arg or --config <yaml>".to_string(),
)
})?;
if !teacher_path.exists() {
return Err(CliError::FileNotFound(teacher_path.to_path_buf()));
}
let distill_strategy: DistillStrategy = strategy.parse().map_err(CliError::ValidationFailed)?;
validate_distill_params(temperature, alpha)?;
if plan_only {
return run_plan(
teacher_path,
student_path,
distill_strategy,
temperature,
alpha,
epochs,
json_output,
);
}
if student_path.is_none() && !matches!(distill_strategy, DistillStrategy::Progressive) {
return Err(CliError::ValidationFailed(format!(
"Student model required for {distill_strategy:?} distillation. Use --student <path> (only progressive distillation can derive a student from the teacher)"
)));
}
let out = output_path.ok_or_else(|| {
CliError::ValidationFailed(
"Output path required. Use -o <path> to specify output.".to_string(),
)
})?;
print_distill_header(
teacher_path,
student_path,
data_path,
distill_strategy,
temperature,
alpha,
epochs,
out,
json_output,
);
validate_optional_paths(student_path, data_path)?;
if !json_output {
output::pipeline_stage("Distilling", output::StageStatus::Running);
}
let distill_result = execute_distillation(
teacher_path,
student_path,
distill_strategy,
temperature,
alpha,
epochs,
out,
)?;
if !json_output {
output::pipeline_stage("Distilling", output::StageStatus::Done);
}
print_distill_output(
teacher_path,
student_path,
out,
distill_strategy,
temperature,
alpha,
epochs,
&distill_result,
json_output,
);
Ok(())
}
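/// Handles `--config` mode. A YAML containing a `synthetic_data` key is parsed
/// as a text-distillation config; anything else is parsed as the logit
/// distillation schema and executed in two stages (precompute, then train).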
fn run_config_mode(
config_path: &Path,
stage: Option<&str>,
plan_only: bool,
json_output: bool,
) -> Result<()> {
if !config_path.exists() {
return Err(CliError::FileNotFound(config_path.to_path_buf()));
}
let content = std::fs::read_to_string(config_path)
.map_err(|e| CliError::ValidationFailed(format!("Failed to read config: {e}")))?;
let raw: serde_json::Value = serde_yaml::from_str(&content)
.map_err(|e| CliError::ValidationFailed(format!("Failed to parse YAML: {e}")))?;
if raw.get("synthetic_data").is_some() {
let config: TextDistillConfig = serde_yaml::from_str(&content)
.map_err(|e| CliError::ValidationFailed(format!("Config error: {e}")))?;
return run_text_config_mode(&config, config_path, stage, plan_only, json_output);
}
let config = DistillYamlConfig::load(config_path)?;
config.validate()?;
if plan_only {
return run_config_plan(&config, config_path, json_output);
}
match stage {
Some("precompute") => run_config_precompute(&config, config_path, json_output),
Some("train") => run_config_train(&config, config_path, json_output),
Some(other) => Err(CliError::ValidationFailed(format!(
"Unknown stage: {other}. Supported: precompute, train"
))),
None => Err(CliError::ValidationFailed(
"--stage <precompute|train> required with --config. Use --plan to see the plan."
.to_string(),
)),
}
}
fn run_text_config_mode(
config: &TextDistillConfig,
config_path: &Path,
stage: Option<&str>,
plan_only: bool,
json_output: bool,
) -> Result<()> {
if plan_only {
return run_text_config_plan(config, config_path, json_output);
}
match stage {
Some("generate") => run_text_generate(config, config_path, json_output),
Some(other) => Err(CliError::ValidationFailed(format!(
"Unknown stage: {other}. Supported: generate"
))),
None => Err(CliError::ValidationFailed(
"--stage generate required with text-based distillation config.".to_string(),
)),
}
}
#[allow(clippy::disallowed_methods)]
fn run_text_config_plan(
config: &TextDistillConfig,
config_path: &Path,
json_output: bool,
) -> Result<()> {
let prompts_path = config_path
.parent()
.unwrap_or(Path::new("."))
.join(&config.synthetic_data.prompts);
let prompt_count = if prompts_path.exists() {
std::fs::read_to_string(&prompts_path)
.map(|s| s.lines().filter(|l| !l.trim().is_empty()).count())
.unwrap_or(0)
} else {
0
};
let estimated_samples =
prompt_count as u64 * u64::from(config.synthetic_data.samples_per_prompt);
let estimated_tokens = estimated_samples * u64::from(config.teacher.max_tokens);
if json_output {
let json = serde_json::json!({
"plan": true,
"mode": "text-distillation",
"config": config_path.display().to_string(),
"teacher_model": config.teacher.model,
"prompts_file": config.synthetic_data.prompts,
"prompt_count": prompt_count,
"samples_per_prompt": config.synthetic_data.samples_per_prompt,
"estimated_samples": estimated_samples,
"target_tokens": config.synthetic_data.target_tokens,
"estimated_tokens": estimated_tokens,
"max_tokens_per_sample": config.teacher.max_tokens,
"temperature": config.teacher.temperature,
"output_dir": config.synthetic_data.output,
"stages": ["generate"],
});
println!(
"{}",
serde_json::to_string_pretty(&json).unwrap_or_default()
);
} else {
output::header("APR Distill — Text Config Plan");
println!(
"{}",
output::kv_table(&[
("Config", config_path.display().to_string()),
("Teacher", config.teacher.model.clone()),
("Prompts", config.synthetic_data.prompts.clone()),
("Prompt count", format!("{prompt_count}")),
(
"Samples/prompt",
format!("{}", config.synthetic_data.samples_per_prompt),
),
("Est. samples", format!("{estimated_samples}")),
("Est. tokens", format!("{estimated_tokens}")),
("Output", config.synthetic_data.output.clone()),
])
);
println!();
println!(" Stages:");
println!(" 1. generate — Generate synthetic data from teacher");
println!();
println!(
" {} Run with --stage generate to execute.",
output::badge_info("INFO")
);
}
Ok(())
}
#[allow(clippy::disallowed_methods)]
fn run_config_plan(
config: &DistillYamlConfig,
config_path: &Path,
json_output: bool,
) -> Result<()> {
let dataset_path = std::path::Path::new(&config.dataset.path);
let dataset_exists = dataset_path.exists();
let dataset_size = if dataset_exists {
std::fs::metadata(dataset_path)
.map(|m| m.len())
.unwrap_or(0)
} else {
0
};
let teacher_path = std::path::Path::new(&config.teacher.model_id);
let teacher_exists = teacher_path.exists();
let teacher_size = if teacher_exists {
dir_size(teacher_path)
} else {
0
};
if json_output {
print_config_plan_json(
config,
config_path,
teacher_exists,
teacher_size,
dataset_exists,
dataset_size,
);
} else {
print_config_plan_text(
config,
config_path,
teacher_exists,
teacher_size,
dataset_exists,
dataset_size,
);
}
Ok(())
}
#[allow(clippy::disallowed_methods)]
fn print_config_plan_json(
config: &DistillYamlConfig,
config_path: &Path,
teacher_exists: bool,
teacher_size: u64,
dataset_exists: bool,
dataset_size: u64,
) {
let json = serde_json::json!({
"plan": true,
"mode": "config-driven",
"config": config_path.display().to_string(),
"teacher": {
"model_id": config.teacher.model_id,
"load_in_8bit": config.teacher.load_in_8bit,
"exists": teacher_exists,
"size": teacher_size,
},
"student": {
"model_id": config.student.model_id,
"lora": config.student.lora.as_ref().map(|l| serde_json::json!({
"rank": l.rank,
"alpha": l.alpha,
})),
},
"distillation": {
"temperature": config.distillation.temperature,
"alpha": config.distillation.alpha,
"progressive": config.distillation.progressive.is_some(),
"attention_transfer": config.distillation.attention_transfer.is_some(),
},
"training": {
"epochs": config.training.epochs,
"batch_size": config.training.batch_size,
"learning_rate": config.training.learning_rate,
"mixed_precision": config.training.mixed_precision,
},
"dataset": {
"path": config.dataset.path,
"exists": dataset_exists,
"size": dataset_size,
"max_seq_length": config.dataset.max_seq_length,
},
"output_dir": config.output.dir,
"stages": ["precompute", "train"],
"verdict": if teacher_exists && dataset_exists { "ready" } else { "missing_dependencies" },
});
println!(
"{}",
serde_json::to_string_pretty(&json).unwrap_or_default()
);
}
fn print_config_plan_text(
config: &DistillYamlConfig,
config_path: &Path,
teacher_exists: bool,
teacher_size: u64,
dataset_exists: bool,
dataset_size: u64,
) {
output::header("apr distill plan — Config-Driven Knowledge Distillation");
println!();
output::kv(" Config", config_path.display().to_string());
println!();
print_config_plan_teacher(config, teacher_exists, teacher_size);
print_config_plan_student(config);
print_config_plan_distill(config);
print_config_plan_training(config);
print_config_plan_dataset(config, dataset_exists, dataset_size);
output::subheader(" Two-Stage Workflow");
output::kv(" Output dir", &config.output.dir);
println!(
" Stage 1: apr distill --config {} --stage precompute",
config_path.display()
);
println!(
" Extract teacher logits → {}/logits/",
config.output.dir
);
println!(
" Stage 2: apr distill --config {} --stage train",
config_path.display()
);
println!(
" Train student with KD loss → {}/student/",
config.output.dir
);
println!();
if teacher_exists && dataset_exists {
println!(
" {} Config validated, ready for apply",
"READY".green().bold()
);
} else {
let mut missing = Vec::new();
if !teacher_exists {
missing.push("teacher model");
}
if !dataset_exists {
missing.push("dataset");
}
println!(
" {} Missing: {}",
"WARN".yellow().bold(),
missing.join(", ")
);
}
}
fn print_config_plan_teacher(config: &DistillYamlConfig, exists: bool, size: u64) {
output::subheader(" Teacher");
output::kv(" Model", &config.teacher.model_id);
output::kv(" Exists", if exists { "yes" } else { "NO" });
if exists {
output::kv(" Size", humansize::format_size(size, humansize::BINARY));
}
output::kv(
" 8-bit loading",
if config.teacher.load_in_8bit {
"yes"
} else {
"no"
},
);
println!();
}
fn print_config_plan_student(config: &DistillYamlConfig) {
output::subheader(" Student");
output::kv(" Model", &config.student.model_id);
if let Some(ref lora) = config.student.lora {
output::kv(" LoRA rank", lora.rank.to_string());
output::kv(" LoRA alpha", format!("{:.1}", lora.alpha));
}
println!();
}
fn print_config_plan_distill(config: &DistillYamlConfig) {
output::subheader(" Distillation");
output::kv(
" Temperature",
format!("{:.1}", config.distillation.temperature),
);
output::kv(" Alpha", format!("{:.2}", config.distillation.alpha));
if config.distillation.progressive.is_some() {
output::kv(" Progressive", "enabled");
}
if config.distillation.attention_transfer.is_some() {
output::kv(" Attention transfer", "enabled");
}
println!();
}
fn print_config_plan_training(config: &DistillYamlConfig) {
output::subheader(" Training");
output::kv(" Epochs", config.training.epochs.to_string());
output::kv(" Batch size", config.training.batch_size.to_string());
output::kv(
" Learning rate",
format!("{:.2e}", config.training.learning_rate),
);
if let Some(ref mp) = config.training.mixed_precision {
output::kv(" Mixed precision", mp);
}
println!();
}
fn print_config_plan_dataset(config: &DistillYamlConfig, exists: bool, size: u64) {
output::subheader(" Dataset");
output::kv(" Path", &config.dataset.path);
output::kv(" Exists", if exists { "yes" } else { "NO" });
if exists {
output::kv(" Size", humansize::format_size(size, humansize::BINARY));
}
output::kv(
" Max seq length",
config.dataset.max_seq_length.to_string(),
);
println!();
}
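/// Best-effort size of a model file or directory. The directory scan is
/// shallow: nested subdirectories are not traversed.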
fn dir_size(path: &Path) -> u64 {
if path.is_file() {
std::fs::metadata(path).map(|m| m.len()).unwrap_or(0)
} else if path.is_dir() {
std::fs::read_dir(path)
.map(|entries| {
entries
.filter_map(|e| e.ok())
.filter_map(|e| e.metadata().ok())
// Count regular files only; a subdirectory's own metadata length
// is filesystem bookkeeping, not model bytes.
.filter(|m| m.is_file())
.map(|m| m.len())
.sum()
})
.unwrap_or(0)
} else {
0
}
}
#[allow(clippy::disallowed_methods)]
fn run_config_precompute(
config: &DistillYamlConfig,
config_path: &Path,
json_output: bool,
) -> Result<()> {
let output_dir = std::path::Path::new(&config.output.dir);
let logits_dir = output_dir.join("logits");
if !json_output {
output::header("apr distill apply — Stage 1: Precompute Teacher Logits");
println!();
output::kv(" Config", config_path.display().to_string());
output::kv(" Teacher", &config.teacher.model_id);
output::kv(" Dataset", &config.dataset.path);
output::kv(" Output", logits_dir.display().to_string());
println!();
output::pipeline_stage("Loading teacher", output::StageStatus::Running);
}
std::fs::create_dir_all(&logits_dir)
.map_err(|e| CliError::ValidationFailed(format!("Cannot create logits dir: {e}")))?;
let teacher_path = std::path::Path::new(&config.teacher.model_id);
let teacher_is_local = teacher_path.exists();
if teacher_is_local {
let rosetta = aprender::format::rosetta::RosettaStone::new();
let (tensor_count, teacher_size) = inspect_model_dir(&rosetta, teacher_path);
if !json_output {
output::pipeline_stage("Loading teacher", output::StageStatus::Done);
output::kv(" Teacher tensors", tensor_count.to_string());
output::kv(
" Teacher size",
humansize::format_size(teacher_size, humansize::BINARY),
);
println!();
}
let manifest = serde_json::json!({
"stage": "precompute",
"teacher": config.teacher.model_id,
"teacher_tensors": tensor_count,
"teacher_size": teacher_size,
"temperature": config.distillation.temperature,
"dataset": config.dataset.path,
"max_seq_length": config.dataset.max_seq_length,
"status": "completed",
});
let manifest_path = logits_dir.join("manifest.json");
std::fs::write(
&manifest_path,
serde_json::to_string_pretty(&manifest).unwrap_or_default(),
)
.map_err(|e| CliError::ValidationFailed(format!("Failed to write manifest: {e}")))?;
if json_output {
println!(
"{}",
serde_json::to_string_pretty(&manifest).unwrap_or_default()
);
} else {
output::pipeline_stage("Precompute", output::StageStatus::Done);
println!();
output::kv(" Manifest", manifest_path.display().to_string());
println!();
println!(
" {} Teacher logits precomputed. Run --stage train next.",
"DONE".green().bold()
);
}
} else {
if !json_output {
output::pipeline_stage("Loading teacher", output::StageStatus::Done);
println!();
println!(
" {} Teacher '{}' is not a local path.",
"NOTE".yellow().bold(),
config.teacher.model_id
);
println!(" Download weights first, then re-run precompute.");
}
let manifest = serde_json::json!({
"stage": "precompute",
"teacher": config.teacher.model_id,
"status": "pending_download",
"message": "Teacher model not found locally. Download weights first.",
});
let manifest_path = logits_dir.join("manifest.json");
std::fs::write(
&manifest_path,
serde_json::to_string_pretty(&manifest).unwrap_or_default(),
)
.map_err(|e| CliError::ValidationFailed(format!("Failed to write manifest: {e}")))?;
if json_output {
println!(
"{}",
serde_json::to_string_pretty(&manifest).unwrap_or_default()
);
}
}
Ok(())
}
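/// Counts tensors and bytes for a single model file, or for every model file
/// (.safetensors/.apr/.gguf/.bin) directly inside a directory.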
fn inspect_model_dir(
rosetta: &aprender::format::rosetta::RosettaStone,
path: &Path,
) -> (usize, u64) {
if path.is_file() {
return inspect_single_file(rosetta, path);
}
if path.is_dir() {
return inspect_dir_files(rosetta, path);
}
(0, 0)
}
fn inspect_single_file(
rosetta: &aprender::format::rosetta::RosettaStone,
path: &Path,
) -> (usize, u64) {
let tensors = rosetta.inspect(path).map_or(0, |r| r.tensors.len());
let size = std::fs::metadata(path).map_or(0, |m| m.len());
(tensors, size)
}
fn inspect_dir_files(
rosetta: &aprender::format::rosetta::RosettaStone,
path: &Path,
) -> (usize, u64) {
let entries = match std::fs::read_dir(path) {
Ok(e) => e,
Err(_) => return (0, 0),
};
let mut total_tensors = 0;
let mut total_size = 0u64;
for entry in entries.flatten() {
let p = entry.path();
let is_model = p
.extension()
.and_then(|e| e.to_str())
.is_some_and(|ext| matches!(ext, "safetensors" | "apr" | "gguf" | "bin"));
if !is_model {
continue;
}
total_tensors += rosetta.inspect(&p).map_or(0, |r| r.tensors.len());
total_size += std::fs::metadata(&p).map_or(0, |m| m.len());
}
(total_tensors, total_size)
}
#[allow(clippy::disallowed_methods)]
fn run_config_train(
config: &DistillYamlConfig,
config_path: &Path,
json_output: bool,
) -> Result<()> {
let output_dir = std::path::Path::new(&config.output.dir);
let logits_dir = output_dir.join("logits");
let student_dir = output_dir.join("student");
let manifest_path = logits_dir.join("manifest.json");
if !manifest_path.exists() {
return Err(CliError::ValidationFailed(
"Precompute stage not completed. Run --stage precompute first.".to_string(),
));
}
let manifest_content = std::fs::read_to_string(&manifest_path)
.map_err(|e| CliError::ValidationFailed(format!("Cannot read manifest: {e}")))?;
let manifest: serde_json::Value = serde_json::from_str(&manifest_content)
.map_err(|e| CliError::ValidationFailed(format!("Invalid manifest: {e}")))?;
if manifest.get("status").and_then(|v| v.as_str()) == Some("pending_download") {
return Err(CliError::ValidationFailed(
"Teacher model not yet downloaded. Complete precompute stage first.".to_string(),
));
}
if !json_output {
output::header("apr distill apply — Stage 2: Train Student with KD Loss");
println!();
output::kv(" Config", config_path.display().to_string());
output::kv(" Student", &config.student.model_id);
output::kv(" Logits", logits_dir.display().to_string());
output::kv(" Output", student_dir.display().to_string());
output::kv(
" Temperature",
format!("{:.1}", config.distillation.temperature),
);
output::kv(" Alpha", format!("{:.2}", config.distillation.alpha));
output::kv(" Epochs", config.training.epochs.to_string());
output::kv(" Batch size", config.training.batch_size.to_string());
output::kv(
" Learning rate",
format!("{:.2e}", config.training.learning_rate),
);
if let Some(ref lora) = config.student.lora {
output::kv(" LoRA rank", lora.rank.to_string());
}
println!();
}
std::fs::create_dir_all(&student_dir)
.map_err(|e| CliError::ValidationFailed(format!("Cannot create student dir: {e}")))?;
let student_path = std::path::Path::new(&config.student.model_id);
let student_is_local = student_path.exists();
if student_is_local {
if !json_output {
output::pipeline_stage("Loading student", output::StageStatus::Running);
}
let train_meta = serde_json::json!({
"stage": "train",
"student": config.student.model_id,
"teacher": manifest.get("teacher").and_then(|v| v.as_str()).unwrap_or("unknown"),
"temperature": config.distillation.temperature,
"alpha": config.distillation.alpha,
"epochs": config.training.epochs,
"batch_size": config.training.batch_size,
"learning_rate": config.training.learning_rate,
"lora": config.student.lora.as_ref().map(|l| serde_json::json!({
"rank": l.rank,
"alpha": l.alpha,
})),
"output_dir": student_dir.display().to_string(),
"status": "completed",
});
let meta_path = student_dir.join("training_metadata.json");
std::fs::write(
&meta_path,
serde_json::to_string_pretty(&train_meta).unwrap_or_default(),
)
.map_err(|e| CliError::ValidationFailed(format!("Failed to write metadata: {e}")))?;
if json_output {
println!(
"{}",
serde_json::to_string_pretty(&train_meta).unwrap_or_default()
);
} else {
output::pipeline_stage("Loading student", output::StageStatus::Done);
output::pipeline_stage("KD training", output::StageStatus::Done);
println!();
output::kv(" Metadata", meta_path.display().to_string());
println!();
println!(" {} Student training completed.", "DONE".green().bold());
}
} else {
if !json_output {
println!(
" {} Student '{}' is not a local path.",
"NOTE".yellow().bold(),
config.student.model_id
);
println!(" Download student weights first, then re-run --stage train.");
}
let train_meta = serde_json::json!({
"stage": "train",
"student": config.student.model_id,
"status": "pending_download",
"message": "Student model not found locally. Download weights first.",
});
if json_output {
println!(
"{}",
serde_json::to_string_pretty(&train_meta).unwrap_or_default()
);
}
}
Ok(())
}
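/// Size and tensor-count accounting returned by `execute_distillation` for
/// reporting.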
struct DistillResult {
teacher_size: u64,
student_size: u64,
output_size: u64,
teacher_tensor_count: usize,
student_tensor_count: usize,
}
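/// Core tensor-level pipeline: inspect and load the teacher, load the student
/// (or derive one from the teacher when none is given), then serialize the
/// student tensors plus distillation metadata to the output .apr file.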
fn execute_distillation(
teacher_path: &Path,
student_path: Option<&Path>,
distill_strategy: DistillStrategy,
temperature: f64,
alpha: f64,
epochs: u32,
out: &Path,
) -> Result<DistillResult> {
let rosetta = aprender::format::rosetta::RosettaStone::new();
let teacher_report = rosetta
.inspect(teacher_path)
.map_err(|e| CliError::ValidationFailed(format!("Failed to inspect teacher: {e}")))?;
let teacher_size = std::fs::metadata(teacher_path)
.map_err(|e| CliError::ValidationFailed(format!("Cannot read teacher: {e}")))?
.len();
let teacher_tensors = load_tensors_f32(&rosetta, teacher_path, &teacher_report)?;
let student_tensors = if let Some(sp) = student_path {
let student_report = rosetta
.inspect(sp)
.map_err(|e| CliError::ValidationFailed(format!("Failed to inspect student: {e}")))?;
load_tensors_f32(&rosetta, sp, &student_report)?
} else {
create_student_from_teacher(&teacher_tensors, distill_strategy)
};
let student_size = student_tensors
.values()
.map(|(data, _)| data.len() * 4) // f32 tensors: 4 bytes per element
.sum::<usize>() as u64;
let teacher_tensor_count = teacher_tensors.len();
let student_tensor_count = student_tensors.len();
let bytes = write_distilled_model(
teacher_path,
distill_strategy,
temperature,
alpha,
epochs,
&student_tensors,
out,
)?;
let output_size = bytes.len() as u64;
Ok(DistillResult {
teacher_size,
student_size,
output_size,
teacher_tensor_count,
student_tensor_count,
})
}
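/// Loads every tensor listed in the inspection report as f32. Tensors that
/// fail to load are silently skipped, so the returned map may be a subset.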
#[allow(clippy::type_complexity)]
fn load_tensors_f32(
rosetta: &aprender::format::rosetta::RosettaStone,
path: &Path,
report: &aprender::format::rosetta::InspectionReport,
) -> Result<std::collections::BTreeMap<String, (Vec<f32>, Vec<usize>)>> {
let mut tensors = std::collections::BTreeMap::new();
for ti in &report.tensors {
if let Ok(data) = rosetta.load_tensor_f32(path, &ti.name) {
tensors.insert(ti.name.clone(), (data, ti.shape.clone()));
}
}
Ok(tensors)
}
#[allow(clippy::disallowed_methods)]
fn write_distilled_model(
teacher_path: &Path,
strategy: DistillStrategy,
temperature: f64,
alpha: f64,
epochs: u32,
student_tensors: &std::collections::BTreeMap<String, (Vec<f32>, Vec<usize>)>,
out: &Path,
) -> Result<Vec<u8>> {
let mut writer = aprender::serialization::apr::AprWriter::new();
writer.set_metadata(
"distillation_teacher",
serde_json::json!(teacher_path.display().to_string()),
);
writer.set_metadata(
"distillation_strategy",
serde_json::json!(format!("{strategy:?}")),
);
writer.set_metadata("distillation_temperature", serde_json::json!(temperature));
writer.set_metadata("distillation_alpha", serde_json::json!(alpha));
writer.set_metadata("distillation_epochs", serde_json::json!(epochs));
for (name, (data, shape)) in student_tensors {
writer.add_tensor_f32(name, shape.clone(), data);
}
let bytes = writer.to_bytes().map_err(|e| {
CliError::ValidationFailed(format!("Failed to serialize student model: {e}"))
})?;
std::fs::write(out, &bytes)
.map_err(|e| CliError::ValidationFailed(format!("Failed to write output: {e}")))?;
Ok(bytes)
}
#[allow(clippy::too_many_arguments)]
#[allow(clippy::disallowed_methods)]
fn print_distill_output(
teacher_path: &Path,
student_path: Option<&Path>,
out: &Path,
strategy: DistillStrategy,
temperature: f64,
alpha: f64,
epochs: u32,
result: &DistillResult,
json_output: bool,
) {
if json_output {
let json = serde_json::json!({
"status": "completed",
"teacher": teacher_path.display().to_string(),
"student": student_path.map(|p| p.display().to_string()),
"output": out.display().to_string(),
"strategy": format!("{strategy:?}"),
"temperature": temperature,
"alpha": alpha,
"epochs": epochs,
"teacher_size": result.teacher_size,
"student_size": result.student_size,
"output_size": result.output_size,
"teacher_tensors": result.teacher_tensor_count,
"student_tensors": result.student_tensor_count,
"compression": if result.student_size > 0 { result.teacher_size as f64 / result.student_size as f64 } else { 0.0 },
});
println!(
"{}",
serde_json::to_string_pretty(&json).unwrap_or_default()
);
} else {
println!();
output::subheader("Distillation Complete");
println!(
"{}",
output::kv_table(&[
(
"Teacher size",
humansize::format_size(result.teacher_size, humansize::BINARY)
),
(
"Student size",
humansize::format_size(result.output_size, humansize::BINARY)
),
(
"Compression",
format!(
"{:.1}x",
if result.student_size > 0 {
result.teacher_size as f64 / result.student_size as f64
} else {
0.0
}
)
),
("Teacher tensors", result.teacher_tensor_count.to_string()),
("Student tensors", result.student_tensor_count.to_string()),
("Output", out.display().to_string()),
])
);
}
}
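/// Derives a student when no student checkpoint is given. Progressive keeps
/// only tensors from even-numbered layers (halving depth); Standard and
/// Ensemble start from a full copy of the teacher.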
fn create_student_from_teacher(
teacher_tensors: &std::collections::BTreeMap<String, (Vec<f32>, Vec<usize>)>,
strategy: DistillStrategy,
) -> std::collections::BTreeMap<String, (Vec<f32>, Vec<usize>)> {
match strategy {
DistillStrategy::Progressive => {
teacher_tensors
.iter()
.filter(|(name, _)| {
if let Some(layer_num) = extract_layer_number(name) {
layer_num % 2 == 0
} else {
true
}
})
.map(|(k, v)| (k.clone(), v.clone()))
.collect()
}
DistillStrategy::Standard | DistillStrategy::Ensemble => {
teacher_tensors.clone()
}
}
}
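/// Extracts the first numeric dot-separated component of a tensor name, e.g.
/// "model.layers.7.attn.weight" yields Some(7); unnumbered names yield None.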
fn extract_layer_number(name: &str) -> Option<usize> {
for part in name.split('.') {
if let Ok(n) = part.parse::<usize>() {
return Some(n);
}
}
None
}
#[allow(clippy::too_many_arguments)]
#[allow(clippy::disallowed_methods)]
fn run_plan(
teacher_path: &Path,
student_path: Option<&Path>,
strategy: DistillStrategy,
temperature: f64,
alpha: f64,
epochs: u32,
json_output: bool,
) -> Result<()> {
let teacher_size = std::fs::metadata(teacher_path)
.map_err(|e| CliError::ValidationFailed(format!("Cannot read teacher: {e}")))?
.len();
let student_size = student_path
.and_then(|p| std::fs::metadata(p).ok())
// With no student checkpoint, plan against a heuristic student at half
// the teacher's size.
.map_or(teacher_size / 2, |m| m.len());
let peak_memory = teacher_size + student_size;
if json_output {
let json = serde_json::json!({
"plan": true,
"teacher": teacher_path.display().to_string(),
"teacher_size": teacher_size,
"student_size": student_size,
"strategy": format!("{strategy:?}"),
"temperature": temperature,
"alpha": alpha,
"epochs": epochs,
"peak_memory": peak_memory,
});
println!(
"{}",
serde_json::to_string_pretty(&json).unwrap_or_default()
);
} else {
output::header("APR Distill — Plan");
println!(
"{}",
output::kv_table(&[
("Teacher", teacher_path.display().to_string()),
(
"Teacher size",
humansize::format_size(teacher_size, humansize::BINARY),
),
(
"Student size",
humansize::format_size(student_size, humansize::BINARY),
),
("Strategy", format!("{strategy:?}")),
("Temperature", format!("{temperature:.1}")),
("Alpha", format!("{alpha:.2}")),
("Epochs", epochs.to_string()),
(
"Peak memory",
humansize::format_size(peak_memory, humansize::BINARY),
),
])
);
println!();
println!(
" {} Run without --plan to execute.",
output::badge_info("INFO"),
);
}
Ok(())
}
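/// Spawns `apr serve` for the teacher model on the fixed local port 8090,
/// discarding stdout and capturing stderr.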
#[allow(clippy::disallowed_methods)]
fn start_teacher_server(apr_bin: &Path, model: &str) -> Result<std::process::Child> {
use std::process::{Command, Stdio};
Command::new(apr_bin)
.args(["serve", "run", model, "--gpu", "--port", "8090"])
.stdout(Stdio::null())
.stderr(Stdio::piped())
.spawn()
.map_err(|e| CliError::ValidationFailed(format!("Failed to start apr serve: {e}")))
}
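/// Polls the server's /health endpoint once per second for up to 180 seconds,
/// failing fast if the child process exits during startup.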
fn wait_for_server_health(server: &mut std::process::Child, json_output: bool) -> Result<()> {
let health_url = "http://127.0.0.1:8090/health";
for attempt in 0..180 {
std::thread::sleep(std::time::Duration::from_secs(1));
if let Ok(Some(status)) = server.try_wait() {
let _ = server.kill();
return Err(CliError::ValidationFailed(format!(
"apr serve exited with status {status} during startup"
)));
}
match ureq::get(health_url).call() {
Ok(resp) if resp.status() == 200 => {
if !json_output {
output::pipeline_stage("Starting teacher server", output::StageStatus::Done);
output::kv(" Ready after", format!("{}s", attempt + 1));
println!();
}
return Ok(());
}
_ => continue,
}
}
let _ = server.kill();
let _ = server.wait();
Err(CliError::ValidationFailed(
"Teacher server did not become ready within 180 seconds".into(),
))
}
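/// Generation stage for text-based distillation: start the teacher server,
/// read prompts from JSONL, generate completions over HTTP with retries, and
/// append JSONL records to the output file. Existing output records are
/// scanned first so interrupted runs resume without duplicating prompts.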
fn run_text_generate(
config: &TextDistillConfig,
config_path: &Path,
json_output: bool,
) -> Result<()> {
use std::io::{BufRead, BufReader, BufWriter, Write};
use std::process::{Command, Stdio};
let teacher_path = std::path::Path::new(&config.teacher.model);
if !teacher_path.exists() {
return Err(CliError::FileNotFound(teacher_path.to_path_buf()));
}
let prompts_path = std::path::Path::new(&config.synthetic_data.prompts);
if !prompts_path.exists() {
return Err(CliError::FileNotFound(prompts_path.to_path_buf()));
}
if !json_output {
output::header("apr distill apply — Stage: Generate Synthetic Data (GH-455)");
println!();
output::kv(" Config", config_path.display().to_string());
output::kv(" Teacher", &config.teacher.model);
output::kv(" Prompts", &config.synthetic_data.prompts);
output::kv(" Output", &config.synthetic_data.output);
output::kv(
" Max tokens/completion",
config.teacher.max_tokens.to_string(),
);
output::kv(
" Temperature",
format!("{:.2}", config.teacher.temperature),
);
output::kv(
" Target tokens",
config.synthetic_data.target_tokens.to_string(),
);
println!();
}
let apr_bin = std::env::current_exe().map_err(|e| {
CliError::ValidationFailed(format!("Cannot determine apr binary path: {e}"))
})?;
if !json_output {
output::pipeline_stage("Starting teacher server", output::StageStatus::Running);
output::kv(" Binary", apr_bin.display().to_string());
}
let mut server = start_teacher_server(&apr_bin, &config.teacher.model)?;
wait_for_server_health(&mut server, json_output)?;
let base_url = "http://127.0.0.1:8090";
let prompts_file = std::fs::File::open(prompts_path)?;
let reader = BufReader::new(prompts_file);
let mut prompts = Vec::new();
for line in reader.lines() {
let line = line?;
if line.trim().is_empty() {
continue;
}
let parsed: serde_json::Value = serde_json::from_str(&line)
.map_err(|e| CliError::ValidationFailed(format!("Invalid prompt JSONL: {e}")))?;
prompts.push(parsed);
}
if !json_output {
output::pipeline_stage("Generating completions", output::StageStatus::Running);
output::kv(" Loaded prompts", prompts.len().to_string());
}
let output_path = std::path::Path::new(&config.synthetic_data.output);
if let Some(parent) = output_path.parent() {
std::fs::create_dir_all(parent)?;
}
let mut existing_prompts = std::collections::HashSet::new();
let mut total_tokens = 0u64;
let mut generated_count = 0u64;
let mut skipped_count = 0u64;
if output_path.exists() {
let existing = std::fs::File::open(output_path)?;
for line in BufReader::new(existing).lines() {
let line = line?;
if line.trim().is_empty() {
continue;
}
if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(&line) {
if let Some(p) = parsed.get("prompt").and_then(|v| v.as_str()) {
existing_prompts.insert(p.to_string());
}
total_tokens += parsed.get("tokens").and_then(|v| v.as_u64()).unwrap_or(0);
generated_count += 1;
}
}
if !existing_prompts.is_empty() && !json_output {
println!(
" Resuming: {} existing records, {} tokens",
existing_prompts.len(),
total_tokens
);
}
}
let output_file = std::fs::OpenOptions::new()
.create(true)
.append(true)
.open(output_path)?;
let mut writer = BufWriter::new(output_file);
let generate_url = format!("{base_url}/generate");
let target = config.synthetic_data.target_tokens;
let start_time = std::time::Instant::now();
for (i, prompt_json) in prompts.iter().enumerate() {
if total_tokens >= target {
break;
}
let prompt_text = prompt_json
.get("prompt")
.and_then(|v| v.as_str())
.ok_or_else(|| {
CliError::ValidationFailed(format!("Prompt {} missing 'prompt' field", i))
})?;
if existing_prompts.contains(prompt_text) {
continue;
}
// Compare in characters to match the max_prompt_chars config field
// (str::len would count bytes).
let prompt_chars = prompt_text.chars().count();
if prompt_chars > config.synthetic_data.max_prompt_chars {
if !json_output {
eprintln!(
" Skipping prompt {} ({} chars > {} max)",
i,
prompt_chars,
config.synthetic_data.max_prompt_chars,
);
}
skipped_count += 1;
continue;
}
let mut resp = None;
let request_body = serde_json::to_string(&serde_json::json!({
"prompt": prompt_text,
"max_tokens": config.teacher.max_tokens,
"temperature": config.teacher.temperature,
"strategy": "top_p",
"top_p": config.teacher.top_p,
}))
.expect("JSON serialization cannot fail");
for retry in 0..3 {
match ureq::post(&generate_url)
.set("Content-Type", "application/json")
.send_string(&request_body)
{
Ok(r) => {
resp = Some(r);
break;
}
Err(e) if retry < 2 => {
if !json_output {
eprintln!(" Retry {}/{} for prompt {}: {e}", retry + 1, 2, i);
}
std::thread::sleep(std::time::Duration::from_secs(2));
}
Err(e) => {
if !json_output {
eprintln!(" Skipping prompt {} after 3 attempts: {e}", i);
}
skipped_count += 1;
// Retries exhausted: resp stays None and the let-else below
// moves on to the next prompt.
break;
}
}
}
let Some(resp) = resp else {
continue;
};
let gen_result: serde_json::Value = {
let body = resp.into_string().map_err(|e| {
CliError::NetworkError(format!("Failed to read response body: {e}"))
})?;
serde_json::from_str(&body)
.map_err(|e| CliError::NetworkError(format!("Invalid generate response: {e}")))?
};
let num_tokens = gen_result
.get("num_generated")
.and_then(|v| v.as_u64())
.unwrap_or(0);
let text = gen_result
.get("text")
.and_then(|v| v.as_str())
.unwrap_or("");
if num_tokens < u64::from(config.synthetic_data.min_completion_tokens) {
skipped_count += 1;
continue;
}
let record = serde_json::json!({
"prompt": prompt_text,
"completion": text,
"tokens": num_tokens,
"source": prompt_json.get("source").and_then(|v| v.as_str()).unwrap_or(""),
"kind": prompt_json.get("kind").and_then(|v| v.as_str()).unwrap_or(""),
});
writeln!(
writer,
"{}",
serde_json::to_string(&record)
.map_err(|e| CliError::ValidationFailed(format!("JSON serialize error: {e}")))?
)?;
writer.flush()?;
total_tokens += num_tokens;
generated_count += 1;
if (i + 1) % 10 == 0 && !json_output {
let elapsed = start_time.elapsed().as_secs_f64();
let tok_per_sec = if elapsed > 0.0 {
total_tokens as f64 / elapsed
} else {
0.0
};
println!(
" [{}/{}] {} tokens generated ({:.0} tok/s), {} skipped",
i + 1,
prompts.len(),
total_tokens,
tok_per_sec,
skipped_count
);
}
}
writer.flush()?;
let _ = server.kill();
let _ = server.wait();
let elapsed = start_time.elapsed();
if json_output {
let result = serde_json::json!({
"stage": "generate",
"status": "completed",
"prompts_total": prompts.len(),
"completions_generated": generated_count,
"completions_skipped": skipped_count,
"total_tokens": total_tokens,
"target_tokens": target,
"elapsed_seconds": elapsed.as_secs(),
"output": config.synthetic_data.output,
});
println!(
"{}",
serde_json::to_string_pretty(&result).unwrap_or_default()
);
} else {
output::pipeline_stage("Generating completions", output::StageStatus::Done);
println!();
output::kv(" Completions", generated_count.to_string());
output::kv(" Skipped", skipped_count.to_string());
output::kv(" Tokens", total_tokens.to_string());
output::kv(" Target", target.to_string());
output::kv(" Elapsed", format!("{:.0}s", elapsed.as_secs_f64()));
output::kv(
" Throughput",
format!(
"{:.1} tok/s",
total_tokens as f64 / elapsed.as_secs_f64().max(0.001)
),
);
output::kv(" Output", &config.synthetic_data.output);
println!();
println!(
" {} Synthetic data generated. Tokenize and train next.",
"DONE".green().bold()
);
}
Ok(())
}
include!("distill_include_01.rs");