use crate::corpus::baselines::corpus_baseline_entries;
use serde::Serialize;
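
/// Top-level training configuration emitted by `bashrs corpus training-config`.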
#[derive(Debug, Clone, Serialize)]
pub struct TrainingConfig {
    pub model: ModelConfig,
    pub training: TrainingParams,
    pub data: DataConfig,
    pub evaluation: EvalConfig,
}
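
/// Encoder model settings: base checkpoint, classifier head size, pooling
/// strategy, and optional LoRA adapter.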
#[derive(Debug, Clone, Serialize)]
pub struct ModelConfig {
    pub architecture: String,
    pub base_model: String,
    pub num_classes: u32,
    pub hidden_size: u32,
    pub num_layers: u32,
    pub pooling: String,
    pub lora: Option<LoraConfig>,
}
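
/// LoRA adapter hyperparameters: rank, scaling alpha, and target module names.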
#[derive(Debug, Clone, Serialize)]
pub struct LoraConfig {
    pub rank: u32,
    pub alpha: f64,
    pub targets: Vec<String>,
}
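
/// Optimization hyperparameters, sequence length, and per-class loss weights.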
#[derive(Debug, Clone, Serialize)]
pub struct TrainingParams {
    pub epochs: u32,
    pub batch_size: u32,
    pub learning_rate: f64,
    pub optimizer: String,
    pub scheduler: String,
    pub warmup_steps: u32,
    pub weight_decay: f64,
    pub max_seq_length: u32,
    pub class_weights: Vec<f64>,
}
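
/// Corpus statistics and how the deterministic train/val/test split is made.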
#[derive(Debug, Clone, Serialize)]
pub struct DataConfig {
    pub total_entries: usize,
    pub safe_count: usize,
    pub unsafe_count: usize,
    pub split_ratio: String,
    pub split_method: String,
    pub preamble_stripped: bool,
}
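
/// Evaluation targets: MCC is the primary metric, with accuracy and
/// generalization thresholds checked alongside it.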
#[derive(Debug, Clone, Serialize)]
pub struct EvalConfig {
    pub primary_metric: String,
    pub accuracy_target: f64,
    pub mcc_ci_lower_target: f64,
    pub generalization_target: f64,
    pub generalization_scripts: u32,
}
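
/// Builds the default SSC v11 training config from the baseline corpus,
/// deriving per-class loss weights from the observed safe/unsafe counts.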
pub fn generate_training_config() -> TrainingConfig {
    let owned = corpus_baseline_entries();
    let total = owned.len();
    let safe_count = owned.iter().filter(|(_, l)| *l == 0).count();
    let unsafe_count = owned.iter().filter(|(_, l)| *l == 1).count();
    let w_safe = compute_sqrt_inverse_weight(safe_count, total);
    let w_unsafe = compute_sqrt_inverse_weight(unsafe_count, total);
    TrainingConfig {
        model: ModelConfig {
            architecture: "encoder".to_string(),
            base_model: "microsoft/codebert-base".to_string(),
            num_classes: 2,
            hidden_size: 768,
            num_layers: 12,
            pooling: "cls".to_string(),
            lora: None,
        },
        training: TrainingParams {
            epochs: 3,
            batch_size: 32,
            learning_rate: 2e-4,
            optimizer: "AdamW".to_string(),
            scheduler: "linear_warmup".to_string(),
            warmup_steps: 100,
            weight_decay: 0.01,
            max_seq_length: 512,
            class_weights: vec![w_safe, w_unsafe],
        },
        data: DataConfig {
            total_entries: total,
            safe_count,
            unsafe_count,
            split_ratio: "80/10/10".to_string(),
            split_method: "FNV-1a hash deterministic".to_string(),
            preamble_stripped: true,
        },
        evaluation: EvalConfig {
            primary_metric: "MCC".to_string(),
            accuracy_target: 0.935,
            mcc_ci_lower_target: 0.2,
            generalization_target: 0.50,
            generalization_scripts: 50,
        },
    }
}
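
/// Square-root inverse-frequency class weight, i.e. sqrt(total / class_count).
/// Returns 1.0 when either the class or the corpus is empty.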
fn compute_sqrt_inverse_weight(class_count: usize, total: usize) -> f64 {
    if class_count == 0 || total == 0 {
        return 1.0;
    }
    let freq = class_count as f64 / total as f64;
    (1.0 / freq).sqrt()
}
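
/// Renders the config as entrenar-compatible YAML via manual string building,
/// writing keys in a fixed order. Abridged output shape (values shown are the
/// fixed defaults from `generate_training_config`):
///
/// ```yaml
/// model:
///   architecture: encoder
///   base_model: microsoft/codebert-base
/// training:
///   epochs: 3
///   batch_size: 32
/// ```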
pub fn format_yaml(config: &TrainingConfig) -> String {
    use std::fmt::Write as _;
    let mut out = String::with_capacity(2048);
    let _ = writeln!(
        out,
        "# SSC v11 Training Configuration (entrenar-compatible)"
    );
    let _ = writeln!(out, "# Generated by bashrs corpus training-config");
    let _ = writeln!(out);
    let _ = writeln!(out, "model:");
    let _ = writeln!(out, "  architecture: {}", config.model.architecture);
    let _ = writeln!(out, "  base_model: {}", config.model.base_model);
    let _ = writeln!(out, "  num_classes: {}", config.model.num_classes);
    let _ = writeln!(out, "  hidden_size: {}", config.model.hidden_size);
    let _ = writeln!(out, "  num_layers: {}", config.model.num_layers);
    let _ = writeln!(out, "  pooling: {}", config.model.pooling);
    if let Some(ref lora) = config.model.lora {
        let _ = writeln!(out, "  lora:");
        let _ = writeln!(out, "    rank: {}", lora.rank);
        let _ = writeln!(out, "    alpha: {}", lora.alpha);
        let _ = writeln!(out, "    targets:");
        for t in &lora.targets {
            let _ = writeln!(out, "      - {t}");
        }
    }
    let _ = writeln!(out);
    let _ = writeln!(out, "training:");
    let _ = writeln!(out, "  epochs: {}", config.training.epochs);
    let _ = writeln!(out, "  batch_size: {}", config.training.batch_size);
    let _ = writeln!(out, "  learning_rate: {}", config.training.learning_rate);
    let _ = writeln!(out, "  optimizer: {}", config.training.optimizer);
    let _ = writeln!(out, "  scheduler: {}", config.training.scheduler);
    let _ = writeln!(out, "  warmup_steps: {}", config.training.warmup_steps);
    let _ = writeln!(out, "  weight_decay: {}", config.training.weight_decay);
    let _ = writeln!(out, "  max_seq_length: {}", config.training.max_seq_length);
    let _ = writeln!(out, "  class_weights:");
    for (i, w) in config.training.class_weights.iter().enumerate() {
        let _ = writeln!(out, "    - {w:.3} # class {i}");
    }
    let _ = writeln!(out);
    let _ = writeln!(out, "data:");
    let _ = writeln!(out, "  total_entries: {}", config.data.total_entries);
    let _ = writeln!(out, "  safe_count: {}", config.data.safe_count);
    let _ = writeln!(out, "  unsafe_count: {}", config.data.unsafe_count);
    let _ = writeln!(out, "  split_ratio: \"{}\"", config.data.split_ratio);
    let _ = writeln!(out, "  split_method: \"{}\"", config.data.split_method);
    let _ = writeln!(
        out,
        "  preamble_stripped: {}",
        config.data.preamble_stripped
    );
    let _ = writeln!(out);
    let _ = writeln!(out, "evaluation:");
    let _ = writeln!(
        out,
        "  primary_metric: {}",
        config.evaluation.primary_metric
    );
    let _ = writeln!(
        out,
        "  accuracy_target: {}",
        config.evaluation.accuracy_target
    );
    let _ = writeln!(
        out,
        "  mcc_ci_lower_target: {}",
        config.evaluation.mcc_ci_lower_target
    );
    let _ = writeln!(
        out,
        "  generalization_target: {}",
        config.evaluation.generalization_target
    );
    let _ = writeln!(
        out,
        "  generalization_scripts: {}",
        config.evaluation.generalization_scripts
    );
    out
}
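
/// Serializes the config to pretty-printed JSON, falling back to the `Debug`
/// representation if serialization fails.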
pub fn format_json(config: &TrainingConfig) -> String {
    serde_json::to_string_pretty(config).unwrap_or_else(|_| format!("{config:#?}"))
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_generate_training_config_structure() {
        let config = generate_training_config();
        assert_eq!(config.model.architecture, "encoder");
        assert_eq!(config.model.num_classes, 2);
        assert_eq!(config.model.hidden_size, 768);
        assert_eq!(config.training.epochs, 3);
        assert_eq!(config.evaluation.primary_metric, "MCC");
    }

    #[test]
    #[ignore = "requires runtime corpus data (externalized from builtin)"]
    fn test_training_config_has_class_weights() {
        let config = generate_training_config();
        assert_eq!(config.training.class_weights.len(), 2);
        assert!(
            config.training.class_weights[1] > config.training.class_weights[0],
            "Unsafe weight should be higher than safe weight"
        );
    }

    #[test]
    #[ignore = "requires runtime corpus data (externalized from builtin)"]
    fn test_training_config_corpus_data() {
        let config = generate_training_config();
        assert!(config.data.total_entries > 100, "Must have corpus data");
        assert!(config.data.safe_count > 0);
        assert!(config.data.unsafe_count > 0);
        assert!(config.data.preamble_stripped);
    }

    #[test]
    fn test_format_yaml_produces_yaml() {
        let config = generate_training_config();
        let yaml = format_yaml(&config);
        assert!(yaml.contains("architecture: encoder"), "Must produce YAML");
        assert!(yaml.contains("codebert"), "Must reference CodeBERT");
        assert!(yaml.contains("class_weights:"), "Must have class weights");
    }

    #[test]
    fn test_format_json_produces_json() {
        let config = generate_training_config();
        let json = format_json(&config);
        assert!(json.contains("\"architecture\""), "Must produce JSON");
        let parsed: Result<serde_json::Value, _> = serde_json::from_str(&json);
        assert!(parsed.is_ok(), "Must produce valid JSON");
    }

    #[test]
    fn test_compute_sqrt_inverse_weight_balanced() {
        let w = compute_sqrt_inverse_weight(50, 100);
        assert!((w - std::f64::consts::SQRT_2).abs() < 0.001);
    }

    #[test]
    fn test_compute_sqrt_inverse_weight_guards() {
        assert!((compute_sqrt_inverse_weight(0, 100) - 1.0).abs() < 1e-9);
        assert!((compute_sqrt_inverse_weight(50, 0) - 1.0).abs() < 1e-9);
    }
}