bashrs 6.66.0

Rust-to-Shell transpiler for deterministic bootstrap scripts
//! Training configuration export for entrenar (SSC v11 S9, CLF-001).
//!
//! Generates an entrenar-compatible training configuration from live
//! corpus statistics. Includes model architecture, an optional LoRA adapter
//! section (left unset by `generate_training_config`), optimizer settings,
//! and computed class weights.

use crate::corpus::baselines::corpus_baseline_entries;
use serde::Serialize;

/// Entrenar training configuration for SSC classifier.
///
/// Top-level document produced by [`generate_training_config`] and rendered by
/// [`format_yaml`] / [`format_json`]. Each field becomes one top-level section
/// (`model`, `training`, `data`, `evaluation`) in the rendered output.
#[derive(Debug, Clone, Serialize)]
pub struct TrainingConfig {
    /// Model architecture and optional LoRA adapter settings.
    pub model: ModelConfig,
    /// Optimizer/schedule hyperparameters, including per-class loss weights.
    pub training: TrainingParams,
    /// Corpus statistics and a description of the data split.
    pub data: DataConfig,
    /// Evaluation metric and acceptance targets.
    pub evaluation: EvalConfig,
}

/// Model architecture configuration.
///
/// [`generate_training_config`] fills this with a `microsoft/codebert-base`
/// encoder (768 hidden size, 12 layers, `cls` pooling) and no LoRA adapter.
#[derive(Debug, Clone, Serialize)]
pub struct ModelConfig {
    /// Architecture family label (the generator uses `"encoder"`).
    pub architecture: String,
    /// Base checkpoint identifier (the generator uses `"microsoft/codebert-base"`).
    pub base_model: String,
    /// Number of output classes (2 in the generated config: safe/unsafe).
    pub num_classes: u32,
    /// Hidden size of the base model (768 in the generated config).
    pub hidden_size: u32,
    /// Layer count of the base model (12 in the generated config).
    pub num_layers: u32,
    /// Pooling strategy name (the generator uses `"cls"`).
    pub pooling: String,
    /// Optional LoRA adapter; when `None`, `format_yaml` omits the `lora:` block.
    pub lora: Option<LoraConfig>,
}

/// LoRA adapter configuration.
///
/// Not populated by [`generate_training_config`] (it sets `lora: None`); when
/// present, `format_yaml` emits it as a nested `lora:` block.
#[derive(Debug, Clone, Serialize)]
pub struct LoraConfig {
    /// Adapter rank.
    pub rank: u32,
    /// LoRA alpha (scaling) value.
    pub alpha: f64,
    /// Target module names, emitted as a YAML list. Presumably names of
    /// projections/layers in the base model — confirm against entrenar.
    pub targets: Vec<String>,
}

/// Training hyperparameters.
///
/// All values except `class_weights` are fixed constants in
/// [`generate_training_config`]; the weights are computed from corpus counts.
#[derive(Debug, Clone, Serialize)]
pub struct TrainingParams {
    /// Number of training epochs (3 in the generated config).
    pub epochs: u32,
    /// Batch size (32 in the generated config).
    pub batch_size: u32,
    /// Learning rate (`2e-4` in the generated config).
    pub learning_rate: f64,
    /// Optimizer name (the generator uses `"AdamW"`).
    pub optimizer: String,
    /// Learning-rate schedule name (the generator uses `"linear_warmup"`).
    pub scheduler: String,
    /// Warmup step count for the scheduler (100 in the generated config).
    pub warmup_steps: u32,
    /// Weight decay (0.01 in the generated config).
    pub weight_decay: f64,
    /// Maximum input sequence length (512 in the generated config); units are
    /// presumably tokens — confirm against entrenar.
    pub max_seq_length: u32,
    /// Per-class loss weights indexed by class label; the generator emits
    /// `[safe, unsafe]` using sqrt-inverse-frequency weighting.
    pub class_weights: Vec<f64>,
}

/// Data configuration.
///
/// Records the corpus counts the class weights were computed from, plus a
/// textual description of how the corpus is split.
#[derive(Debug, Clone, Serialize)]
pub struct DataConfig {
    /// Total number of corpus entries. Equals `safe_count + unsafe_count` only
    /// when every label is 0 or 1.
    pub total_entries: usize,
    /// Number of entries labelled `0` (safe).
    pub safe_count: usize,
    /// Number of entries labelled `1` (unsafe).
    pub unsafe_count: usize,
    /// Split ratio as a string, e.g. `"80/10/10"` (presumably train/val/test).
    pub split_ratio: String,
    /// Free-text description of the split method.
    pub split_method: String,
    /// Flag recorded in the config; the generator always sets `true`.
    /// NOTE(review): presumably means script preambles were removed from the
    /// inputs — confirm with the data pipeline.
    pub preamble_stripped: bool,
}

/// Evaluation configuration.
///
/// Acceptance targets written into the config; this module only records them
/// and does not evaluate against them.
#[derive(Debug, Clone, Serialize)]
pub struct EvalConfig {
    /// Name of the primary metric (the generator uses `"MCC"`).
    pub primary_metric: String,
    /// Accuracy target as a fraction (0.935 in the generated config).
    pub accuracy_target: f64,
    /// Target for the lower bound of the MCC confidence interval
    /// (0.2 in the generated config).
    pub mcc_ci_lower_target: f64,
    /// Target score on the generalization script set (0.50 in the generated config).
    pub generalization_target: f64,
    /// Number of scripts in the generalization set (50 in the generated config).
    pub generalization_scripts: u32,
}

/// Generate a training configuration from live corpus data.
///
/// Loads the labelled corpus via `corpus_baseline_entries`, tallies entries
/// labelled `0` (safe) and `1` (unsafe), and derives one sqrt-inverse-frequency
/// class weight per label, in `[safe, unsafe]` order. Every other setting
/// (model, optimizer, split description, evaluation targets) is a fixed
/// constant. `total_entries` counts every entry, so it can exceed
/// `safe_count + unsafe_count` if other labels are present.
pub fn generate_training_config() -> TrainingConfig {
    let entries = corpus_baseline_entries();
    let total = entries.len();

    // One pass over the corpus; labels other than 0/1 contribute to neither count.
    let (safe_count, unsafe_count) =
        entries
            .iter()
            .fold((0usize, 0usize), |(safe, risky), (_, label)| match *label {
                0 => (safe + 1, risky),
                1 => (safe, risky + 1),
                _ => (safe, risky),
            });

    let class_weights = vec![
        compute_sqrt_inverse_weight(safe_count, total),
        compute_sqrt_inverse_weight(unsafe_count, total),
    ];

    let model = ModelConfig {
        architecture: "encoder".to_string(),
        base_model: "microsoft/codebert-base".to_string(),
        num_classes: 2,
        hidden_size: 768,
        num_layers: 12,
        pooling: "cls".to_string(),
        lora: None,
    };

    let training = TrainingParams {
        epochs: 3,
        batch_size: 32,
        learning_rate: 2e-4,
        optimizer: "AdamW".to_string(),
        scheduler: "linear_warmup".to_string(),
        warmup_steps: 100,
        weight_decay: 0.01,
        max_seq_length: 512,
        class_weights,
    };

    let data = DataConfig {
        total_entries: total,
        safe_count,
        unsafe_count,
        split_ratio: "80/10/10".to_string(),
        split_method: "FNV-1a hash deterministic".to_string(),
        preamble_stripped: true,
    };

    let evaluation = EvalConfig {
        primary_metric: "MCC".to_string(),
        accuracy_target: 0.935,
        mcc_ci_lower_target: 0.2,
        generalization_target: 0.50,
        generalization_scripts: 50,
    };

    TrainingConfig {
        model,
        training,
        data,
        evaluation,
    }
}

/// Compute a sqrt-inverse-frequency class weight, i.e. `sqrt(total / class_count)`.
///
/// Returns the neutral weight `1.0` when either argument is zero, which avoids
/// dividing by zero for an empty class or an empty corpus.
fn compute_sqrt_inverse_weight(class_count: usize, total: usize) -> f64 {
    match (class_count, total) {
        (0, _) | (_, 0) => 1.0,
        (count, all) => {
            let frequency = count as f64 / all as f64;
            // `recip()` is exactly `1.0 / frequency`.
            frequency.recip().sqrt()
        }
    }
}

/// Format training config as YAML (hand-formatted, no serde_yaml dependency).
pub fn format_yaml(config: &TrainingConfig) -> String {
    use std::fmt::Write as _;
    let mut out = String::with_capacity(2048);

    let _ = writeln!(
        out,
        "# SSC v11 Training Configuration (entrenar-compatible)"
    );
    let _ = writeln!(out, "# Generated by bashrs corpus training-config");
    let _ = writeln!(out);

    // Model
    let _ = writeln!(out, "model:");
    let _ = writeln!(out, "  architecture: {}", config.model.architecture);
    let _ = writeln!(out, "  base_model: {}", config.model.base_model);
    let _ = writeln!(out, "  num_classes: {}", config.model.num_classes);
    let _ = writeln!(out, "  hidden_size: {}", config.model.hidden_size);
    let _ = writeln!(out, "  num_layers: {}", config.model.num_layers);
    let _ = writeln!(out, "  pooling: {}", config.model.pooling);
    if let Some(ref lora) = config.model.lora {
        let _ = writeln!(out, "  lora:");
        let _ = writeln!(out, "    rank: {}", lora.rank);
        let _ = writeln!(out, "    alpha: {}", lora.alpha);
        let _ = writeln!(out, "    targets:");
        for t in &lora.targets {
            let _ = writeln!(out, "    - {t}");
        }
    }
    let _ = writeln!(out);

    // Training
    let _ = writeln!(out, "training:");
    let _ = writeln!(out, "  epochs: {}", config.training.epochs);
    let _ = writeln!(out, "  batch_size: {}", config.training.batch_size);
    let _ = writeln!(out, "  learning_rate: {}", config.training.learning_rate);
    let _ = writeln!(out, "  optimizer: {}", config.training.optimizer);
    let _ = writeln!(out, "  scheduler: {}", config.training.scheduler);
    let _ = writeln!(out, "  warmup_steps: {}", config.training.warmup_steps);
    let _ = writeln!(out, "  weight_decay: {}", config.training.weight_decay);
    let _ = writeln!(out, "  max_seq_length: {}", config.training.max_seq_length);
    let _ = writeln!(out, "  class_weights:");
    for (i, w) in config.training.class_weights.iter().enumerate() {
        let _ = writeln!(out, "  - {w:.3}  # class {i}");
    }
    let _ = writeln!(out);

    // Data
    let _ = writeln!(out, "data:");
    let _ = writeln!(out, "  total_entries: {}", config.data.total_entries);
    let _ = writeln!(out, "  safe_count: {}", config.data.safe_count);
    let _ = writeln!(out, "  unsafe_count: {}", config.data.unsafe_count);
    let _ = writeln!(out, "  split_ratio: \"{}\"", config.data.split_ratio);
    let _ = writeln!(out, "  split_method: \"{}\"", config.data.split_method);
    let _ = writeln!(
        out,
        "  preamble_stripped: {}",
        config.data.preamble_stripped
    );
    let _ = writeln!(out);

    // Evaluation
    let _ = writeln!(out, "evaluation:");
    let _ = writeln!(
        out,
        "  primary_metric: {}",
        config.evaluation.primary_metric
    );
    let _ = writeln!(
        out,
        "  accuracy_target: {}",
        config.evaluation.accuracy_target
    );
    let _ = writeln!(
        out,
        "  mcc_ci_lower_target: {}",
        config.evaluation.mcc_ci_lower_target
    );
    let _ = writeln!(
        out,
        "  generalization_target: {}",
        config.evaluation.generalization_target
    );
    let _ = writeln!(
        out,
        "  generalization_scripts: {}",
        config.evaluation.generalization_scripts
    );

    out
}

/// Format training config as pretty-printed JSON.
///
/// Struct field names become JSON keys via the `Serialize` derive. If
/// serialization fails, this falls back to the pretty `Debug` representation
/// rather than panicking — note that the fallback text is not JSON.
pub fn format_json(config: &TrainingConfig) -> String {
    match serde_json::to_string_pretty(config) {
        Ok(json) => json,
        Err(_) => format!("{config:#?}"),
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_generate_training_config_structure() {
        let config = generate_training_config();
        assert_eq!(config.model.architecture, "encoder");
        assert_eq!(config.model.num_classes, 2);
        assert_eq!(config.model.hidden_size, 768);
        assert_eq!(config.training.epochs, 3);
        assert_eq!(config.evaluation.primary_metric, "MCC");
    }

    #[test]
    #[ignore = "requires runtime corpus data (externalized from builtin)"]
    fn test_training_config_has_class_weights() {
        let config = generate_training_config();
        assert_eq!(config.training.class_weights.len(), 2);
        // Unsafe is minority → higher weight
        assert!(
            config.training.class_weights[1] > config.training.class_weights[0],
            "Unsafe weight should be higher than safe weight"
        );
    }

    #[test]
    #[ignore = "requires runtime corpus data (externalized from builtin)"]
    fn test_training_config_corpus_data() {
        let config = generate_training_config();
        assert!(config.data.total_entries > 100, "Must have corpus data");
        assert!(config.data.safe_count > 0);
        assert!(config.data.unsafe_count > 0);
        assert!(config.data.preamble_stripped);
    }

    /// Invariants that hold whether or not corpus data is available.
    #[test]
    fn test_training_config_weights_and_counts_consistent() {
        let config = generate_training_config();
        // One weight per class, regardless of corpus contents.
        assert_eq!(
            config.training.class_weights.len(),
            config.model.num_classes as usize
        );
        // Class counts never exceed the total, so sqrt(total / count) >= 1
        // (and the zero-count guard yields exactly 1.0).
        for w in &config.training.class_weights {
            assert!(w.is_finite() && *w >= 1.0, "weight {w} must be finite and >= 1");
        }
        assert!(config.data.safe_count + config.data.unsafe_count <= config.data.total_entries);
    }

    #[test]
    fn test_format_yaml_produces_yaml() {
        let config = generate_training_config();
        let yaml = format_yaml(&config);
        assert!(yaml.contains("architecture: encoder"), "Must produce YAML");
        assert!(yaml.contains("codebert"), "Must reference CodeBERT");
        assert!(yaml.contains("class_weights:"), "Must have class weights");
    }

    #[test]
    fn test_format_yaml_has_all_sections() {
        let yaml = format_yaml(&generate_training_config());
        for section in ["model:", "training:", "data:", "evaluation:"] {
            assert!(
                yaml.lines().any(|line| line == section),
                "missing top-level section {section}"
            );
        }
    }

    #[test]
    fn test_format_json_produces_json() {
        let config = generate_training_config();
        let json = format_json(&config);
        assert!(json.contains("\"architecture\""), "Must produce JSON");
        // Should be valid JSON
        let parsed: Result<serde_json::Value, _> = serde_json::from_str(&json);
        assert!(parsed.is_ok(), "Must produce valid JSON");
    }

    #[test]
    fn test_compute_sqrt_inverse_weight_balanced() {
        let w = compute_sqrt_inverse_weight(50, 100);
        assert!((w - std::f64::consts::SQRT_2).abs() < 0.001);
    }

    #[test]
    fn test_compute_sqrt_inverse_weight_imbalanced() {
        // A class holding 25% of the data gets weight sqrt(1 / 0.25) = 2.
        assert!((compute_sqrt_inverse_weight(25, 100) - 2.0).abs() < 1e-9);
        // A class covering the whole corpus has frequency 1, hence weight 1.
        assert!((compute_sqrt_inverse_weight(100, 100) - 1.0).abs() < 1e-9);
        // Rarer classes must receive strictly larger weights.
        assert!(compute_sqrt_inverse_weight(10, 100) > compute_sqrt_inverse_weight(90, 100));
    }

    #[test]
    fn test_compute_sqrt_inverse_weight_guards() {
        assert!((compute_sqrt_inverse_weight(0, 100) - 1.0).abs() < 1e-9);
        assert!((compute_sqrt_inverse_weight(50, 0) - 1.0).abs() < 1e-9);
    }
}