use crate::autograd::{matmul, Tensor};
use crate::transformer::ModelArchitecture;
use serde::Deserialize;
use std::path::Path;
/// Strategy for collapsing per-token hidden states into a single
/// `[hidden_size]` vector before classification.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PoolingStrategy {
    /// Average the hidden states over all positions.
    Mean,
    /// Take the hidden state of the final position (common for decoders).
    LastToken,
    /// Take the hidden state of the first position (CLS-style, common for encoders).
    Cls,
}
impl PoolingStrategy {
    /// Returns the conventional default pooling for a model architecture:
    /// CLS pooling for encoders, mean pooling for decoders.
    pub fn from_architecture(arch: ModelArchitecture) -> Self {
        match arch {
            ModelArchitecture::Decoder => Self::Mean,
            ModelArchitecture::Encoder => Self::Cls,
        }
    }
}
/// Linear classification head: maps a pooled `[hidden_size]` vector
/// to `[num_classes]` logits via a weight matrix and bias.
pub struct ClassificationHead {
    /// Projection weights, flattened `hidden_size * num_classes` values.
    // NOTE(review): layout inferred from `matmul(&pooled, &self.weight, 1,
    // hidden_size, num_classes)` in `forward` — confirm against the matmul contract.
    pub weight: Tensor,
    /// Per-class bias, length `num_classes`.
    pub bias: Tensor,
    // Cached dimensions; read back via `hidden_size()` / `num_classes()`.
    hidden_size: usize,
    num_classes: usize,
}
impl ClassificationHead {
    /// Creates a head with deterministic Xavier/Glorot-uniform weights
    /// (fixed LCG seed 42, so initialization is reproducible) and zero bias.
    ///
    /// # Panics
    /// F-CLASS-004: if `hidden_size == 0` or `num_classes < 2`.
    pub fn new(hidden_size: usize, num_classes: usize) -> Self {
        assert!(hidden_size > 0, "F-CLASS-004: hidden_size must be > 0");
        assert!(num_classes >= 2, "F-CLASS-004: num_classes must be >= 2");
        // Xavier-uniform bound: sqrt(6 / (fan_in + fan_out)).
        let scale = (6.0 / (hidden_size + num_classes) as f32).sqrt();
        // Minimal LCG (Knuth MMIX constants); top 31 bits of state -> u in [0, 1).
        let mut rng_state: u64 = 42;
        let weight_data: Vec<f32> = (0..hidden_size * num_classes)
            .map(|_| {
                rng_state = rng_state.wrapping_mul(6_364_136_223_846_793_005).wrapping_add(1);
                let u = (rng_state >> 33) as f32 / (1u64 << 31) as f32;
                (2.0 * u - 1.0) * scale
            })
            .collect();
        let weight = Tensor::from_vec(weight_data, true);
        let bias = Tensor::zeros(num_classes, true);
        Self { weight, bias, hidden_size, num_classes }
    }

    /// Mean-pools `hidden_states` over `seq_len` positions and projects to logits.
    ///
    /// Kept for backward compatibility; exactly equivalent to
    /// `forward_with_pooling(hidden_states, seq_len, PoolingStrategy::Mean)`,
    /// to which it now delegates (previously the projection code was duplicated).
    pub fn forward(&self, hidden_states: &Tensor, seq_len: usize) -> Tensor {
        self.forward_with_pooling(hidden_states, seq_len, PoolingStrategy::Mean)
    }

    /// Projects a pooled `[hidden_size]` vector to `[num_classes]` logits and
    /// adds the per-class bias. Shared by `forward` and `forward_with_pooling`.
    ///
    /// NOTE(review): the bias add rebuilds the result with `Tensor::from_vec`,
    /// which looks like it detaches the output from any autograd graph built by
    /// `matmul` — confirm the training loop compensates for this.
    fn project(&self, pooled: &Tensor) -> Tensor {
        let logits = matmul(pooled, &self.weight, 1, self.hidden_size, self.num_classes);
        let logits_data: Vec<f32> = logits
            .data()
            .as_slice()
            .expect("contiguous logits data")
            .iter()
            .zip(self.bias.data().as_slice().expect("contiguous bias data").iter())
            .map(|(&l, &b)| l + b)
            .collect();
        Tensor::from_vec(logits_data, logits.requires_grad())
    }

    /// Averages hidden states over all `seq_len` positions.
    ///
    /// # Panics
    /// Panics if `seq_len == 0` (previously this silently produced NaN output
    /// because of the `1.0 / 0` scaling).
    pub fn mean_pool(&self, hidden_states: &Tensor, seq_len: usize) -> Tensor {
        assert!(seq_len > 0, "mean_pool requires seq_len > 0");
        let data = hidden_states.data();
        let slice = data.as_slice().expect("contiguous hidden states");
        let h = self.hidden_size;
        let mut pooled = vec![0.0f32; h];
        for pos in 0..seq_len {
            let start = pos * h;
            for j in 0..h {
                pooled[j] += slice[start + j];
            }
        }
        let inv_len = 1.0 / seq_len as f32;
        for v in &mut pooled {
            *v *= inv_len;
        }
        Tensor::from_vec(pooled, hidden_states.requires_grad())
    }

    /// Returns the hidden state of the first (CLS) position.
    pub fn cls_pool(&self, hidden_states: &Tensor) -> Tensor {
        let data = hidden_states.data();
        let slice = data.as_slice().expect("contiguous hidden states");
        let h = self.hidden_size;
        Tensor::from_vec(slice[..h].to_vec(), hidden_states.requires_grad())
    }

    /// Returns the hidden state of the final position.
    ///
    /// # Panics
    /// Panics if `seq_len == 0` (previously this panicked anyway via usize
    /// underflow; the assert gives a diagnosable message).
    pub fn last_token_pool(&self, hidden_states: &Tensor, seq_len: usize) -> Tensor {
        assert!(seq_len > 0, "last_token_pool requires seq_len > 0");
        let data = hidden_states.data();
        let slice = data.as_slice().expect("contiguous hidden states");
        let h = self.hidden_size;
        let start = (seq_len - 1) * h;
        Tensor::from_vec(slice[start..start + h].to_vec(), hidden_states.requires_grad())
    }

    /// Dispatches to the pooling method selected by `strategy`.
    pub fn pool(
        &self,
        hidden_states: &Tensor,
        seq_len: usize,
        strategy: PoolingStrategy,
    ) -> Tensor {
        match strategy {
            PoolingStrategy::Mean => self.mean_pool(hidden_states, seq_len),
            PoolingStrategy::Cls => self.cls_pool(hidden_states),
            PoolingStrategy::LastToken => self.last_token_pool(hidden_states, seq_len),
        }
    }

    /// Pools `hidden_states` with `strategy`, then projects to class logits.
    pub fn forward_with_pooling(
        &self,
        hidden_states: &Tensor,
        seq_len: usize,
        strategy: PoolingStrategy,
    ) -> Tensor {
        let pooled = self.pool(hidden_states, seq_len, strategy);
        self.project(&pooled)
    }

    /// Mutable references to the trainable parameters (weight, bias).
    pub fn parameters_mut(&mut self) -> Vec<&mut Tensor> {
        vec![&mut self.weight, &mut self.bias]
    }

    /// Shared references to the trainable parameters (weight, bias).
    pub fn parameters(&self) -> Vec<&Tensor> {
        vec![&self.weight, &self.bias]
    }

    /// Number of output classes.
    #[must_use]
    pub fn num_classes(&self) -> usize {
        self.num_classes
    }

    /// Input feature dimension.
    #[must_use]
    pub fn hidden_size(&self) -> usize {
        self.hidden_size
    }

    /// Total scalar parameter count: weights plus bias.
    #[must_use]
    pub fn num_parameters(&self) -> usize {
        self.hidden_size * self.num_classes + self.num_classes
    }
}
/// One single-label training example, deserialized from a JSONL line.
#[derive(Debug, Clone, Deserialize)]
pub struct SafetySample {
    /// Raw input text.
    pub input: String,
    /// Class index; validated against `num_classes` by the corpus loaders.
    pub label: usize,
}
impl SafetySample {
    /// Byte-level "tokenization": each UTF-8 byte of `input` becomes one id.
    #[must_use]
    pub fn input_ids(&self) -> Vec<u32> {
        self.input.as_bytes().iter().map(|&b| u32::from(b)).collect()
    }
}
/// A sample after tokenization: token ids plus its class label.
#[derive(Debug, Clone)]
pub struct TokenizedSample {
    /// Token ids produced by the tokenizer.
    pub token_ids: Vec<u32>,
    /// Class index.
    pub label: usize,
}
/// A multi-label example: one soft label per class.
#[derive(Debug, Clone, Deserialize)]
pub struct MultiLabelSafetySample {
    /// Raw input text.
    pub input: String,
    /// Per-class soft labels; length must equal `num_classes` (F-CLASS-001).
    pub labels: Vec<f32>,
}
impl MultiLabelSafetySample {
    /// Promotes a single-label sample to one-hot multi-label form.
    /// An out-of-range label yields an all-zero vector; range validation
    /// is the loader's responsibility.
    pub fn from_single_label(sample: &SafetySample, num_classes: usize) -> Self {
        let labels: Vec<f32> = (0..num_classes)
            .map(|class| if class == sample.label { 1.0 } else { 0.0 })
            .collect();
        Self { input: sample.input.clone(), labels }
    }

    /// Indices of classes whose soft label exceeds the 0.5 threshold.
    pub fn active_classes(&self) -> Vec<usize> {
        let mut active = Vec::new();
        for (idx, &value) in self.labels.iter().enumerate() {
            if value > 0.5 {
                active.push(idx);
            }
        }
        active
    }
}
/// Aggregate statistics over a loaded safety corpus.
#[derive(Debug, Clone)]
pub struct SafetyCorpusStats {
    /// Total number of samples.
    pub total: usize,
    /// Per-class sample counts, length `num_classes`.
    pub class_counts: Vec<usize>,
    /// Mean input length in bytes (integer division; 0 for an empty corpus).
    pub avg_input_len: usize,
}
pub fn load_safety_corpus(path: &Path, num_classes: usize) -> crate::Result<Vec<SafetySample>> {
let content = std::fs::read_to_string(path)
.map_err(|e| crate::Error::Io(format!("Corpus file not found: {}: {e}", path.display())))?;
let mut samples = Vec::new();
for (line_num, line) in content.lines().enumerate() {
let line = line.trim();
if line.is_empty() {
continue;
}
let sample: SafetySample = serde_json::from_str(line).map_err(|e| {
crate::Error::ConfigError(format!("Invalid JSONL at line {}: {e}", line_num + 1))
})?;
if sample.label >= num_classes {
return Err(crate::Error::ConfigError(format!(
"F-CLASS-002: label {} at line {} out of range (num_classes={num_classes})",
sample.label,
line_num + 1,
)));
}
samples.push(sample);
}
Ok(samples)
}
/// Computes per-class counts and the average input length (bytes) for a
/// corpus. Labels at or beyond `num_classes` are silently skipped in the
/// counts (they still contribute to `total` and the length average).
pub fn corpus_stats(samples: &[SafetySample], num_classes: usize) -> SafetyCorpusStats {
    let mut class_counts = vec![0usize; num_classes];
    for sample in samples {
        // `get_mut` is None for out-of-range labels, matching the old
        // `label < num_classes` guard.
        if let Some(slot) = class_counts.get_mut(sample.label) {
            *slot += 1;
        }
    }
    let total_len: usize = samples.iter().map(|s| s.input.len()).sum();
    let avg_input_len = match samples.len() {
        0 => 0,
        n => total_len / n,
    };
    SafetyCorpusStats { total: samples.len(), class_counts, avg_input_len }
}
/// Loads a JSONL corpus where each line is either a native multi-label
/// record or a legacy single-label record (promoted to one-hot).
///
/// # Errors
/// - `Error::Io` if the file cannot be read.
/// - `Error::ConfigError` from `parse_multi_label_line` on the first bad line.
pub fn load_multi_label_corpus(
    path: &Path,
    num_classes: usize,
) -> crate::Result<Vec<MultiLabelSafetySample>> {
    let content = std::fs::read_to_string(path)
        .map_err(|e| crate::Error::Io(format!("Corpus file not found: {}: {e}", path.display())))?;
    content
        .lines()
        .enumerate()
        .map(|(idx, raw)| (idx, raw.trim()))
        .filter(|(_, line)| !line.is_empty())
        .map(|(idx, line)| parse_multi_label_line(line, idx, num_classes))
        .collect()
}
/// Parses one JSONL line as a multi-label sample.
///
/// Tries the native multi-label format first; if that fails to deserialize,
/// falls back to the legacy single-label format and promotes it to one-hot.
/// `line_num` is 0-based; error messages report it 1-based.
fn parse_multi_label_line(
    line: &str,
    line_num: usize,
    num_classes: usize,
) -> crate::Result<MultiLabelSafetySample> {
    let display_line = line_num + 1;
    // Preferred format: record with an explicit `labels` vector.
    if let Ok(multi) = serde_json::from_str::<MultiLabelSafetySample>(line) {
        return if multi.labels.len() == num_classes {
            Ok(multi)
        } else {
            Err(crate::Error::ConfigError(format!(
                "F-CLASS-001: labels length {} at line {} != num_classes {num_classes}",
                multi.labels.len(),
                display_line,
            )))
        };
    }
    // Fallback: legacy single-label record.
    if let Ok(single) = serde_json::from_str::<SafetySample>(line) {
        return if single.label < num_classes {
            Ok(MultiLabelSafetySample::from_single_label(&single, num_classes))
        } else {
            Err(crate::Error::ConfigError(format!(
                "F-CLASS-002: label {} at line {} out of range (num_classes={num_classes})",
                single.label,
                display_line,
            )))
        };
    }
    Err(crate::Error::ConfigError(format!(
        "Invalid JSONL at line {}: unrecognized format",
        display_line,
    )))
}
/// Mean binary-cross-entropy-with-logits over `num_classes` independent
/// labels, using the numerically stable form
/// `max(x, 0) - x*t + ln(1 + exp(-|x|))`.
/// A non-finite mean is clamped to 100.0, matching `cross_entropy_loss`.
///
/// # Panics
/// F-CLASS-001: if `logits` or `targets` is not exactly `num_classes` long.
pub fn bce_with_logits_loss(logits: &Tensor, targets: &[f32], num_classes: usize) -> Tensor {
    let data = logits.data();
    let slice = data.as_slice().expect("contiguous logits");
    assert_eq!(slice.len(), num_classes, "F-CLASS-001: logit shape mismatch");
    assert_eq!(targets.len(), num_classes, "F-CLASS-001: target shape mismatch");
    let mut accum = 0.0f32;
    for (&x, &t) in slice.iter().zip(targets.iter()) {
        let relu = x.max(0.0);
        accum += relu - x * t + (1.0 + (-x.abs()).exp()).ln();
    }
    let mean = accum / num_classes as f32;
    let mean = if mean.is_finite() { mean } else { 100.0 };
    Tensor::from_vec(vec![mean], logits.requires_grad())
}
/// How per-class loss weights are derived from corpus class frequencies.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ClassWeightStrategy {
    /// Every class weighted equally.
    Uniform,
    /// Weight proportional to inverse class frequency (n / (k * count)).
    InverseFreq,
    /// Square root of the inverse frequency — a softer rebalancing.
    SqrtInverse,
}
impl std::str::FromStr for ClassWeightStrategy {
    type Err = String;

    /// Parses a strategy name case-insensitively, accepting the short
    /// aliases `inverse` and `sqrt`.
    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
        let normalized = s.to_lowercase();
        if normalized == "uniform" {
            Ok(Self::Uniform)
        } else if normalized == "inverse_freq" || normalized == "inverse" {
            Ok(Self::InverseFreq)
        } else if normalized == "sqrt_inverse" || normalized == "sqrt" {
            Ok(Self::SqrtInverse)
        } else {
            Err(format!(
                "Unknown class weight strategy: {s}. Use: uniform, inverse_freq, sqrt_inverse"
            ))
        }
    }
}
impl std::fmt::Display for ClassWeightStrategy {
    /// Writes the canonical (parseable) strategy name.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Self::Uniform => "uniform",
            Self::InverseFreq => "inverse_freq",
            Self::SqrtInverse => "sqrt_inverse",
        };
        write!(f, "{name}")
    }
}
/// Derives per-class loss weights from corpus statistics.
///
/// Raw weights are computed per `strategy` (empty classes are treated as
/// having one sample to avoid division by zero), then rescaled so the
/// weights sum to `num_classes`. A degenerate raw sum falls back to
/// uniform weights.
///
/// # Panics
/// F-TUNE-005: if `stats.class_counts.len() != num_classes`.
pub fn compute_class_weights(
    stats: &SafetyCorpusStats,
    strategy: ClassWeightStrategy,
    num_classes: usize,
) -> Vec<f32> {
    assert_eq!(
        stats.class_counts.len(),
        num_classes,
        "F-TUNE-005: class_counts.len() != num_classes"
    );
    let n = stats.total as f32;
    let k = num_classes as f32;
    // Inverse-frequency weight for a single class; clamps empty classes to 1.
    let inverse = |count: usize| n / (k * count.max(1) as f32);
    let raw_weights: Vec<f32> = match strategy {
        ClassWeightStrategy::Uniform => vec![1.0; num_classes],
        ClassWeightStrategy::InverseFreq => {
            stats.class_counts.iter().map(|&c| inverse(c)).collect()
        }
        ClassWeightStrategy::SqrtInverse => {
            stats.class_counts.iter().map(|&c| inverse(c).sqrt()).collect()
        }
    };
    let sum: f32 = raw_weights.iter().sum();
    if sum < 1e-10 {
        return vec![1.0; num_classes];
    }
    let scale = k / sum;
    raw_weights.iter().map(|&w| w * scale).collect()
}
/// Cross-entropy loss for a single sample: `-log softmax(logits)[target]`,
/// computed via the max-subtracted log-sum-exp for numerical stability.
/// A non-finite loss is clamped to 100.0 (matches `bce_with_logits_loss`).
///
/// # Panics
/// F-CLASS-001 if `logits` is not `num_classes` long;
/// F-CLASS-002 if `target >= num_classes`.
pub fn cross_entropy_loss(logits: &Tensor, target: usize, num_classes: usize) -> Tensor {
    let data = logits.data();
    let slice = data.as_slice().expect("contiguous logits");
    assert_eq!(slice.len(), num_classes, "F-CLASS-001: logit shape mismatch");
    assert!(target < num_classes, "F-CLASS-002: label out of range");
    // Subtract the max before exponentiating so exp() cannot overflow.
    let max_val = slice.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let shifted_exp_sum: f32 = slice.iter().map(|&v| (v - max_val).exp()).sum();
    let log_sum_exp = shifted_exp_sum.ln() + max_val;
    let loss = -(slice[target] - log_sum_exp);
    let loss = if loss.is_finite() { loss } else { 100.0 };
    Tensor::from_vec(vec![loss], logits.requires_grad())
}
#[cfg(test)]
#[allow(clippy::unwrap_used)]
#[path = "classification_tests.rs"]
mod tests;