Skip to main content

entrenar/finetune/
classification.rs

1//! Classification head and corpus loader for fine-tuning
2//!
3//! Provides a classifier head that attaches to a transformer's hidden states
4//! (mean pooling → linear) and a JSONL corpus loader for safety labels.
5//!
6//! # Contract
7//!
8//! See `aprender/contracts/classification-finetune-v1.yaml`:
9//! - F-CLASS-001: Logit shape == num_classes
10//! - F-CLASS-002: Label index < num_classes
11//! - F-CLASS-004: Weight shape == hidden_size * num_classes
12//!
13//! # Architecture
14//!
15//! ```text
16//! hidden_states [seq_len, hidden_size]
17//!   → mean pool → [hidden_size]
18//!   → linear    → [num_classes]
19//!   → softmax   → class probabilities
20//! ```
21
22use crate::autograd::{matmul, Tensor};
23use crate::transformer::ModelArchitecture;
24use serde::Deserialize;
25use std::path::Path;
26
/// How a `[seq_len, hidden_size]` block of hidden states is collapsed into a
/// single `[hidden_size]` vector before the classifier projection.
///
/// By convention, decoder (causal) models pool the last token, because only
/// the final position has attended to the whole input, while encoder
/// (bidirectional) models pool the leading `[CLS]` token.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PoolingStrategy {
    /// Average the hidden states over every position (architecture-agnostic default).
    Mean,
    /// Take the hidden state of the final position (decoder convention).
    LastToken,
    /// Take the hidden state of the first position, i.e. the `[CLS]` token
    /// (encoder convention: BERT/RoBERTa).
    Cls,
}
40
41impl PoolingStrategy {
42    /// Select pooling strategy based on model architecture.
43    pub fn from_architecture(arch: ModelArchitecture) -> Self {
44        match arch {
45            ModelArchitecture::Encoder => Self::Cls,
46            ModelArchitecture::Decoder => Self::Mean, // keep backward compat
47        }
48    }
49}
50
51/// Classification head: mean pool + linear projection.
52///
53/// Maps transformer hidden states to class logits.
54/// Weight shape: [hidden_size * num_classes] (flattened row-major).
55/// Bias shape: [num_classes].
56pub struct ClassificationHead {
57    /// Linear weight [hidden_size, num_classes] flattened
58    pub weight: Tensor,
59    /// Bias [num_classes]
60    pub bias: Tensor,
61    /// Input dimension (model hidden_size)
62    hidden_size: usize,
63    /// Output dimension (number of classes)
64    num_classes: usize,
65}
66
67impl ClassificationHead {
68    /// Create a new classification head with Xavier-initialized weights.
69    ///
70    /// # Arguments
71    /// * `hidden_size` - Transformer hidden dimension (e.g., 896 for Qwen2-0.5B)
72    /// * `num_classes` - Number of output classes (e.g., 5 for shell safety)
73    ///
74    /// # Contract (F-CLASS-004)
75    /// Validates hidden_size > 0 and num_classes >= 2.
76    pub fn new(hidden_size: usize, num_classes: usize) -> Self {
77        assert!(hidden_size > 0, "F-CLASS-004: hidden_size must be > 0");
78        assert!(num_classes >= 2, "F-CLASS-004: num_classes must be >= 2");
79
80        // Xavier uniform initialization: U(-sqrt(6/(fan_in+fan_out)), sqrt(6/(fan_in+fan_out)))
81        let scale = (6.0 / (hidden_size + num_classes) as f32).sqrt();
82        let mut rng_state: u64 = 42;
83        let weight_data: Vec<f32> = (0..hidden_size * num_classes)
84            .map(|_| {
85                // Simple LCG for deterministic init
86                rng_state = rng_state.wrapping_mul(6_364_136_223_846_793_005).wrapping_add(1);
87                let u = (rng_state >> 33) as f32 / (1u64 << 31) as f32;
88                (2.0 * u - 1.0) * scale
89            })
90            .collect();
91
92        let weight = Tensor::from_vec(weight_data, true);
93        let bias = Tensor::zeros(num_classes, true);
94
95        Self { weight, bias, hidden_size, num_classes }
96    }
97
98    /// Forward pass: hidden_states → mean pool → linear → logits.
99    ///
100    /// # Arguments
101    /// * `hidden_states` - Transformer output [seq_len * hidden_size] flattened
102    /// * `seq_len` - Sequence length
103    ///
104    /// # Returns
105    /// Logits tensor [num_classes]
106    ///
107    /// # Contract (F-CLASS-001)
108    /// Output always has exactly num_classes elements.
109    pub fn forward(&self, hidden_states: &Tensor, seq_len: usize) -> Tensor {
110        // Mean pool across sequence dimension
111        let pooled = self.mean_pool(hidden_states, seq_len);
112
113        // Linear: pooled [1, hidden_size] @ weight [hidden_size, num_classes] = [1, num_classes]
114        let logits = matmul(&pooled, &self.weight, 1, self.hidden_size, self.num_classes);
115
116        // Add bias
117        let logits_data: Vec<f32> = logits
118            .data()
119            .as_slice()
120            .expect("contiguous logits data")
121            .iter()
122            .zip(self.bias.data().as_slice().expect("contiguous bias data").iter())
123            .map(|(&l, &b)| l + b)
124            .collect();
125
126        Tensor::from_vec(logits_data, logits.requires_grad())
127    }
128
129    /// Mean pool hidden states across sequence length.
130    ///
131    /// hidden_states: [seq_len * hidden_size] → [hidden_size]
132    pub fn mean_pool(&self, hidden_states: &Tensor, seq_len: usize) -> Tensor {
133        let data = hidden_states.data();
134        let slice = data.as_slice().expect("contiguous hidden states");
135        let h = self.hidden_size;
136
137        let mut pooled = vec![0.0f32; h];
138        for pos in 0..seq_len {
139            let start = pos * h;
140            for j in 0..h {
141                pooled[j] += slice[start + j];
142            }
143        }
144        let inv_len = 1.0 / seq_len as f32;
145        for v in &mut pooled {
146            *v *= inv_len;
147        }
148
149        Tensor::from_vec(pooled, hidden_states.requires_grad())
150    }
151
152    /// CLS pooling: extract the first token's hidden state.
153    ///
154    /// In BERT-family models, position 0 is the [CLS] token which attends
155    /// bidirectionally to all other tokens, making it suitable for classification.
156    ///
157    /// # Contract (ENC-007)
158    /// Output has exactly hidden_size elements (first position of input).
159    pub fn cls_pool(&self, hidden_states: &Tensor) -> Tensor {
160        let data = hidden_states.data();
161        let slice = data.as_slice().expect("contiguous hidden states");
162        let h = self.hidden_size;
163        Tensor::from_vec(slice[..h].to_vec(), hidden_states.requires_grad())
164    }
165
166    /// Last-token pooling: extract the last token's hidden state.
167    ///
168    /// In decoder models, the last token has seen all prior context through
169    /// causal attention, making it suitable for classification.
170    pub fn last_token_pool(&self, hidden_states: &Tensor, seq_len: usize) -> Tensor {
171        let data = hidden_states.data();
172        let slice = data.as_slice().expect("contiguous hidden states");
173        let h = self.hidden_size;
174        let start = (seq_len - 1) * h;
175        Tensor::from_vec(slice[start..start + h].to_vec(), hidden_states.requires_grad())
176    }
177
178    /// Pool hidden states using the specified strategy.
179    pub fn pool(
180        &self,
181        hidden_states: &Tensor,
182        seq_len: usize,
183        strategy: PoolingStrategy,
184    ) -> Tensor {
185        match strategy {
186            PoolingStrategy::Mean => self.mean_pool(hidden_states, seq_len),
187            PoolingStrategy::Cls => self.cls_pool(hidden_states),
188            PoolingStrategy::LastToken => self.last_token_pool(hidden_states, seq_len),
189        }
190    }
191
192    /// Forward pass with configurable pooling strategy.
193    pub fn forward_with_pooling(
194        &self,
195        hidden_states: &Tensor,
196        seq_len: usize,
197        strategy: PoolingStrategy,
198    ) -> Tensor {
199        let pooled = self.pool(hidden_states, seq_len, strategy);
200
201        let logits = matmul(&pooled, &self.weight, 1, self.hidden_size, self.num_classes);
202
203        let logits_data: Vec<f32> = logits
204            .data()
205            .as_slice()
206            .expect("contiguous logits data")
207            .iter()
208            .zip(self.bias.data().as_slice().expect("contiguous bias data").iter())
209            .map(|(&l, &b)| l + b)
210            .collect();
211
212        Tensor::from_vec(logits_data, logits.requires_grad())
213    }
214
215    /// Get trainable parameters (weight + bias).
216    pub fn parameters_mut(&mut self) -> Vec<&mut Tensor> {
217        vec![&mut self.weight, &mut self.bias]
218    }
219
220    /// Get parameters (immutable).
221    pub fn parameters(&self) -> Vec<&Tensor> {
222        vec![&self.weight, &self.bias]
223    }
224
225    /// Number of classes.
226    #[must_use]
227    pub fn num_classes(&self) -> usize {
228        self.num_classes
229    }
230
231    /// Hidden size.
232    #[must_use]
233    pub fn hidden_size(&self) -> usize {
234        self.hidden_size
235    }
236
237    /// Total trainable parameter count.
238    #[must_use]
239    pub fn num_parameters(&self) -> usize {
240        self.hidden_size * self.num_classes + self.num_classes
241    }
242}
243
244// =============================================================================
245// CORPUS LOADER
246// =============================================================================
247
248/// A single shell safety corpus sample (single-label).
249#[derive(Debug, Clone, Deserialize)]
250pub struct SafetySample {
251    /// Shell script content
252    pub input: String,
253    /// Safety class index (0-4)
254    pub label: usize,
255}
256
257impl SafetySample {
258    /// Convert the input text to token IDs using byte-level encoding.
259    ///
260    /// Each byte of the UTF-8 representation is mapped to a `u32` token ID.
261    /// This provides a simple, deterministic tokenization suitable for the
262    /// classification pipeline. For production use with large vocabularies,
263    /// an external tokenizer (BPE, SentencePiece) should be used before
264    /// calling `train_step` directly.
265    #[must_use]
266    pub fn input_ids(&self) -> Vec<u32> {
267        self.input.bytes().map(u32::from).collect()
268    }
269}
270
/// A training sample whose text has already been tokenized (KAIZEN-028).
///
/// Storing the token IDs lets a dataset tokenize each sample once, when it is
/// built, and reuse the result in every epoch instead of re-encoding the text.
///
/// # Contract (C-PRETOK-001)
///
/// The struct has public fields and does not check these itself; whoever
/// constructs it is expected to uphold them:
///
/// - **Precondition**: `token_ids` come from the same tokenizer used for inference
/// - **Postcondition**: `token_ids.len() <= max_seq_len` (truncated at construction)
/// - **Invariant**: `token_ids` is non-empty (at least one token)
#[derive(Debug, Clone)]
pub struct TokenizedSample {
    /// Token IDs computed ahead of time (truncated to max_seq_len).
    pub token_ids: Vec<u32>,
    /// Index of the safety class.
    pub label: usize,
}
288
289/// A multi-label shell safety corpus sample.
290///
291/// A script can have multiple active labels (e.g., both non-deterministic AND needs-quoting).
292/// Labels are a multi-hot vector: `[0.0, 1.0, 1.0, 0.0, 0.0]` means classes 1 and 2 are active.
293#[derive(Debug, Clone, Deserialize)]
294pub struct MultiLabelSafetySample {
295    /// Shell script content
296    pub input: String,
297    /// Multi-hot label vector (length == num_classes)
298    pub labels: Vec<f32>,
299}
300
301impl MultiLabelSafetySample {
302    /// Create from a single-label sample by converting to multi-hot encoding.
303    pub fn from_single_label(sample: &SafetySample, num_classes: usize) -> Self {
304        let mut labels = vec![0.0f32; num_classes];
305        if sample.label < num_classes {
306            labels[sample.label] = 1.0;
307        }
308        Self { input: sample.input.clone(), labels }
309    }
310
311    /// Active class indices (where labels[i] > 0.5).
312    pub fn active_classes(&self) -> Vec<usize> {
313        self.labels.iter().enumerate().filter(|(_, &v)| v > 0.5).map(|(i, _)| i).collect()
314    }
315}
316
/// Summary statistics for a single-label safety corpus.
#[derive(Debug, Clone)]
pub struct SafetyCorpusStats {
    /// Number of samples in the corpus.
    pub total: usize,
    /// Number of samples per class, indexed by class.
    pub class_counts: Vec<usize>,
    /// Mean input length, rounded down. `corpus_stats` measures it in UTF-8
    /// bytes (`str::len`), which equals the character count only for ASCII input.
    pub avg_input_len: usize,
}
327
328/// Load shell safety corpus from JSONL file.
329///
330/// Each line is `{"input": "...", "label": N}` where N is 0-4.
331///
332/// # Contract (F-CLASS-002)
333/// All labels must be < num_classes.
334///
335/// # Errors
336/// Returns error if file cannot be read or contains invalid labels.
337pub fn load_safety_corpus(path: &Path, num_classes: usize) -> crate::Result<Vec<SafetySample>> {
338    let content = std::fs::read_to_string(path)
339        .map_err(|e| crate::Error::Io(format!("Corpus file not found: {}: {e}", path.display())))?;
340
341    let mut samples = Vec::new();
342    for (line_num, line) in content.lines().enumerate() {
343        let line = line.trim();
344        if line.is_empty() {
345            continue;
346        }
347        let sample: SafetySample = serde_json::from_str(line).map_err(|e| {
348            crate::Error::ConfigError(format!("Invalid JSONL at line {}: {e}", line_num + 1))
349        })?;
350
351        // F-CLASS-002: label bounds check
352        if sample.label >= num_classes {
353            return Err(crate::Error::ConfigError(format!(
354                "F-CLASS-002: label {} at line {} out of range (num_classes={num_classes})",
355                sample.label,
356                line_num + 1,
357            )));
358        }
359
360        samples.push(sample);
361    }
362
363    Ok(samples)
364}
365
366/// Compute corpus statistics.
367pub fn corpus_stats(samples: &[SafetySample], num_classes: usize) -> SafetyCorpusStats {
368    let mut class_counts = vec![0usize; num_classes];
369    let mut total_len = 0usize;
370
371    for s in samples {
372        if s.label < num_classes {
373            class_counts[s.label] += 1;
374        }
375        total_len += s.input.len();
376    }
377
378    SafetyCorpusStats {
379        total: samples.len(),
380        class_counts,
381        avg_input_len: if samples.is_empty() { 0 } else { total_len / samples.len() },
382    }
383}
384
385/// Load multi-label corpus from JSONL file.
386///
387/// Supports two formats:
388/// - Single-label: `{"input": "...", "label": N}` → converts to multi-hot
389/// - Multi-label: `{"input": "...", "labels": [0.0, 1.0, 1.0, 0.0, 0.0]}`
390///
391/// # Errors
392/// Returns error if file cannot be read or labels are invalid.
393pub fn load_multi_label_corpus(
394    path: &Path,
395    num_classes: usize,
396) -> crate::Result<Vec<MultiLabelSafetySample>> {
397    let content = std::fs::read_to_string(path)
398        .map_err(|e| crate::Error::Io(format!("Corpus file not found: {}: {e}", path.display())))?;
399
400    let mut samples = Vec::new();
401    for (line_num, line) in content.lines().enumerate() {
402        let line = line.trim();
403        if line.is_empty() {
404            continue;
405        }
406        samples.push(parse_multi_label_line(line, line_num, num_classes)?);
407    }
408
409    Ok(samples)
410}
411
412/// Parse a single JSONL line as multi-label or single-label sample.
413fn parse_multi_label_line(
414    line: &str,
415    line_num: usize,
416    num_classes: usize,
417) -> crate::Result<MultiLabelSafetySample> {
418    // Try multi-label format first
419    if let Ok(sample) = serde_json::from_str::<MultiLabelSafetySample>(line) {
420        if sample.labels.len() != num_classes {
421            return Err(crate::Error::ConfigError(format!(
422                "F-CLASS-001: labels length {} at line {} != num_classes {num_classes}",
423                sample.labels.len(),
424                line_num + 1,
425            )));
426        }
427        return Ok(sample);
428    }
429
430    if let Ok(single) = serde_json::from_str::<SafetySample>(line) {
431        if single.label >= num_classes {
432            return Err(crate::Error::ConfigError(format!(
433                "F-CLASS-002: label {} at line {} out of range (num_classes={num_classes})",
434                single.label,
435                line_num + 1,
436            )));
437        }
438        return Ok(MultiLabelSafetySample::from_single_label(&single, num_classes));
439    }
440
441    Err(crate::Error::ConfigError(format!(
442        "Invalid JSONL at line {}: unrecognized format",
443        line_num + 1,
444    )))
445}
446
447/// BCE with logits loss for multi-label classification.
448///
449/// Per-element: `L_i = max(x_i, 0) - x_i * t_i + log(1 + exp(-|x_i|))`
450/// Total: `L = mean(L_i)`
451///
452/// # Contract (F-CLASS-005)
453/// Output is finite (no NaN/Inf).
454pub fn bce_with_logits_loss(logits: &Tensor, targets: &[f32], num_classes: usize) -> Tensor {
455    let data = logits.data();
456    let slice = data.as_slice().expect("contiguous logits");
457    assert_eq!(slice.len(), num_classes, "F-CLASS-001: logit shape mismatch");
458    assert_eq!(targets.len(), num_classes, "F-CLASS-001: target shape mismatch");
459
460    let total_loss: f32 = slice
461        .iter()
462        .zip(targets.iter())
463        .map(|(&x, &t)| {
464            let relu = x.max(0.0);
465            relu - x * t + (1.0 + (-x.abs()).exp()).ln()
466        })
467        .sum::<f32>()
468        / num_classes as f32;
469
470    // F-CLASS-005: finite check
471    let total_loss = if total_loss.is_finite() { total_loss } else { 100.0 };
472
473    Tensor::from_vec(vec![total_loss], logits.requires_grad())
474}
475
/// How per-class loss weights are derived for an imbalanced dataset.
///
/// In the formulas below `N` is the total sample count, `K` the number of
/// classes and `n_c` the sample count of class `c`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ClassWeightStrategy {
    /// Every class gets the same weight: `w_c = 1.0`.
    Uniform,
    /// Inverse class frequency: `w_c = N / (K * n_c)`.
    InverseFreq,
    /// Square root of the inverse class frequency: `w_c = sqrt(N / (K * n_c))`.
    SqrtInverse,
}
486
487impl std::str::FromStr for ClassWeightStrategy {
488    type Err = String;
489
490    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
491        match s.to_lowercase().as_str() {
492            "uniform" => Ok(Self::Uniform),
493            "inverse_freq" | "inverse" => Ok(Self::InverseFreq),
494            "sqrt_inverse" | "sqrt" => Ok(Self::SqrtInverse),
495            _ => Err(format!(
496                "Unknown class weight strategy: {s}. Use: uniform, inverse_freq, sqrt_inverse"
497            )),
498        }
499    }
500}
501
502impl std::fmt::Display for ClassWeightStrategy {
503    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
504        match self {
505            Self::Uniform => write!(f, "uniform"),
506            Self::InverseFreq => write!(f, "inverse_freq"),
507            Self::SqrtInverse => write!(f, "sqrt_inverse"),
508        }
509    }
510}
511
512/// Compute class weights from corpus statistics.
513///
514/// Weights are normalized so they sum to `num_classes`, preserving
515/// the overall loss scale while rebalancing class contributions.
516///
517/// # Contract (F-TUNE-005)
518/// `abs(sum(weights) - num_classes) < 1e-5`
519///
520/// # Panics
521/// Panics if `stats.class_counts.len() != num_classes` or any class has zero samples.
522pub fn compute_class_weights(
523    stats: &SafetyCorpusStats,
524    strategy: ClassWeightStrategy,
525    num_classes: usize,
526) -> Vec<f32> {
527    assert_eq!(
528        stats.class_counts.len(),
529        num_classes,
530        "F-TUNE-005: class_counts.len() != num_classes"
531    );
532
533    let n = stats.total as f32;
534    let k = num_classes as f32;
535
536    let raw_weights: Vec<f32> = match strategy {
537        ClassWeightStrategy::Uniform => vec![1.0; num_classes],
538        ClassWeightStrategy::InverseFreq => stats
539            .class_counts
540            .iter()
541            .map(|&count| {
542                let count = count.max(1) as f32; // avoid division by zero
543                n / (k * count)
544            })
545            .collect(),
546        ClassWeightStrategy::SqrtInverse => stats
547            .class_counts
548            .iter()
549            .map(|&count| {
550                let count = count.max(1) as f32;
551                (n / (k * count)).sqrt()
552            })
553            .collect(),
554    };
555
556    // Normalize so weights sum to num_classes
557    let sum: f32 = raw_weights.iter().sum();
558    if sum < 1e-10 {
559        return vec![1.0; num_classes];
560    }
561    let scale = k / sum;
562    raw_weights.iter().map(|&w| w * scale).collect()
563}
564
565/// Cross-entropy loss for classification.
566///
567/// # Arguments
568/// * `logits` - Raw logits [num_classes]
569/// * `target` - Target class index
570/// * `num_classes` - Number of classes
571///
572/// # Returns
573/// Scalar loss value (as single-element Tensor)
574///
575/// # Contract (F-CLASS-005)
576/// Output is finite (no NaN/Inf).
577pub fn cross_entropy_loss(logits: &Tensor, target: usize, num_classes: usize) -> Tensor {
578    let data = logits.data();
579    let slice = data.as_slice().expect("contiguous logits");
580    assert_eq!(slice.len(), num_classes, "F-CLASS-001: logit shape mismatch");
581    assert!(target < num_classes, "F-CLASS-002: label out of range");
582
583    // Numerically stable log-softmax: log(softmax(x_i)) = x_i - max - log(sum(exp(x_j - max)))
584    let max_val = slice.iter().copied().fold(f32::NEG_INFINITY, f32::max);
585    let log_sum_exp: f32 = slice.iter().map(|&v| (v - max_val).exp()).sum::<f32>().ln() + max_val;
586    let loss = -(slice[target] - log_sum_exp);
587
588    // F-CLASS-005: finite check
589    let loss = if loss.is_finite() { loss } else { 100.0 };
590
591    Tensor::from_vec(vec![loss], logits.requires_grad())
592}
593
594#[cfg(test)]
595#[allow(clippy::unwrap_used)]
596#[path = "classification_tests.rs"]
597mod tests;