1use crate::autograd::{matmul, Tensor};
23use crate::transformer::ModelArchitecture;
24use serde::Deserialize;
25use std::path::Path;
26
/// How a flattened sequence of per-token hidden states is reduced to a single
/// `hidden_size`-wide vector before classification.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PoolingStrategy {
    /// Average the hidden states over every position in the sequence.
    Mean,
    /// Take the hidden state of the final position.
    LastToken,
    /// Take the hidden state of the first position.
    Cls,
}
40
41impl PoolingStrategy {
42 pub fn from_architecture(arch: ModelArchitecture) -> Self {
44 match arch {
45 ModelArchitecture::Encoder => Self::Cls,
46 ModelArchitecture::Decoder => Self::Mean, }
48 }
49}
50
51pub struct ClassificationHead {
57 pub weight: Tensor,
59 pub bias: Tensor,
61 hidden_size: usize,
63 num_classes: usize,
65}
66
67impl ClassificationHead {
68 pub fn new(hidden_size: usize, num_classes: usize) -> Self {
77 assert!(hidden_size > 0, "F-CLASS-004: hidden_size must be > 0");
78 assert!(num_classes >= 2, "F-CLASS-004: num_classes must be >= 2");
79
80 let scale = (6.0 / (hidden_size + num_classes) as f32).sqrt();
82 let mut rng_state: u64 = 42;
83 let weight_data: Vec<f32> = (0..hidden_size * num_classes)
84 .map(|_| {
85 rng_state = rng_state.wrapping_mul(6_364_136_223_846_793_005).wrapping_add(1);
87 let u = (rng_state >> 33) as f32 / (1u64 << 31) as f32;
88 (2.0 * u - 1.0) * scale
89 })
90 .collect();
91
92 let weight = Tensor::from_vec(weight_data, true);
93 let bias = Tensor::zeros(num_classes, true);
94
95 Self { weight, bias, hidden_size, num_classes }
96 }
97
98 pub fn forward(&self, hidden_states: &Tensor, seq_len: usize) -> Tensor {
110 let pooled = self.mean_pool(hidden_states, seq_len);
112
113 let logits = matmul(&pooled, &self.weight, 1, self.hidden_size, self.num_classes);
115
116 let logits_data: Vec<f32> = logits
118 .data()
119 .as_slice()
120 .expect("contiguous logits data")
121 .iter()
122 .zip(self.bias.data().as_slice().expect("contiguous bias data").iter())
123 .map(|(&l, &b)| l + b)
124 .collect();
125
126 Tensor::from_vec(logits_data, logits.requires_grad())
127 }
128
129 pub fn mean_pool(&self, hidden_states: &Tensor, seq_len: usize) -> Tensor {
133 let data = hidden_states.data();
134 let slice = data.as_slice().expect("contiguous hidden states");
135 let h = self.hidden_size;
136
137 let mut pooled = vec![0.0f32; h];
138 for pos in 0..seq_len {
139 let start = pos * h;
140 for j in 0..h {
141 pooled[j] += slice[start + j];
142 }
143 }
144 let inv_len = 1.0 / seq_len as f32;
145 for v in &mut pooled {
146 *v *= inv_len;
147 }
148
149 Tensor::from_vec(pooled, hidden_states.requires_grad())
150 }
151
152 pub fn cls_pool(&self, hidden_states: &Tensor) -> Tensor {
160 let data = hidden_states.data();
161 let slice = data.as_slice().expect("contiguous hidden states");
162 let h = self.hidden_size;
163 Tensor::from_vec(slice[..h].to_vec(), hidden_states.requires_grad())
164 }
165
166 pub fn last_token_pool(&self, hidden_states: &Tensor, seq_len: usize) -> Tensor {
171 let data = hidden_states.data();
172 let slice = data.as_slice().expect("contiguous hidden states");
173 let h = self.hidden_size;
174 let start = (seq_len - 1) * h;
175 Tensor::from_vec(slice[start..start + h].to_vec(), hidden_states.requires_grad())
176 }
177
178 pub fn pool(
180 &self,
181 hidden_states: &Tensor,
182 seq_len: usize,
183 strategy: PoolingStrategy,
184 ) -> Tensor {
185 match strategy {
186 PoolingStrategy::Mean => self.mean_pool(hidden_states, seq_len),
187 PoolingStrategy::Cls => self.cls_pool(hidden_states),
188 PoolingStrategy::LastToken => self.last_token_pool(hidden_states, seq_len),
189 }
190 }
191
192 pub fn forward_with_pooling(
194 &self,
195 hidden_states: &Tensor,
196 seq_len: usize,
197 strategy: PoolingStrategy,
198 ) -> Tensor {
199 let pooled = self.pool(hidden_states, seq_len, strategy);
200
201 let logits = matmul(&pooled, &self.weight, 1, self.hidden_size, self.num_classes);
202
203 let logits_data: Vec<f32> = logits
204 .data()
205 .as_slice()
206 .expect("contiguous logits data")
207 .iter()
208 .zip(self.bias.data().as_slice().expect("contiguous bias data").iter())
209 .map(|(&l, &b)| l + b)
210 .collect();
211
212 Tensor::from_vec(logits_data, logits.requires_grad())
213 }
214
215 pub fn parameters_mut(&mut self) -> Vec<&mut Tensor> {
217 vec![&mut self.weight, &mut self.bias]
218 }
219
220 pub fn parameters(&self) -> Vec<&Tensor> {
222 vec![&self.weight, &self.bias]
223 }
224
225 #[must_use]
227 pub fn num_classes(&self) -> usize {
228 self.num_classes
229 }
230
231 #[must_use]
233 pub fn hidden_size(&self) -> usize {
234 self.hidden_size
235 }
236
237 #[must_use]
239 pub fn num_parameters(&self) -> usize {
240 self.hidden_size * self.num_classes + self.num_classes
241 }
242}
243
244#[derive(Debug, Clone, Deserialize)]
250pub struct SafetySample {
251 pub input: String,
253 pub label: usize,
255}
256
257impl SafetySample {
258 #[must_use]
266 pub fn input_ids(&self) -> Vec<u32> {
267 self.input.bytes().map(u32::from).collect()
268 }
269}
270
/// A sample whose input is already represented as token ids.
#[derive(Clone, Debug)]
pub struct TokenizedSample {
    /// Token ids of the sample's input.
    pub token_ids: Vec<u32>,
    /// Class index of this sample.
    pub label: usize,
}
288
289#[derive(Debug, Clone, Deserialize)]
294pub struct MultiLabelSafetySample {
295 pub input: String,
297 pub labels: Vec<f32>,
299}
300
301impl MultiLabelSafetySample {
302 pub fn from_single_label(sample: &SafetySample, num_classes: usize) -> Self {
304 let mut labels = vec![0.0f32; num_classes];
305 if sample.label < num_classes {
306 labels[sample.label] = 1.0;
307 }
308 Self { input: sample.input.clone(), labels }
309 }
310
311 pub fn active_classes(&self) -> Vec<usize> {
313 self.labels.iter().enumerate().filter(|(_, &v)| v > 0.5).map(|(i, _)| i).collect()
314 }
315}
316
/// Summary statistics for a single-label corpus.
#[derive(Clone, Debug)]
pub struct SafetyCorpusStats {
    /// Number of samples in the corpus.
    pub total: usize,
    /// Number of samples per class, indexed by class.
    pub class_counts: Vec<usize>,
    /// Average input length.
    pub avg_input_len: usize,
}
327
328pub fn load_safety_corpus(path: &Path, num_classes: usize) -> crate::Result<Vec<SafetySample>> {
338 let content = std::fs::read_to_string(path)
339 .map_err(|e| crate::Error::Io(format!("Corpus file not found: {}: {e}", path.display())))?;
340
341 let mut samples = Vec::new();
342 for (line_num, line) in content.lines().enumerate() {
343 let line = line.trim();
344 if line.is_empty() {
345 continue;
346 }
347 let sample: SafetySample = serde_json::from_str(line).map_err(|e| {
348 crate::Error::ConfigError(format!("Invalid JSONL at line {}: {e}", line_num + 1))
349 })?;
350
351 if sample.label >= num_classes {
353 return Err(crate::Error::ConfigError(format!(
354 "F-CLASS-002: label {} at line {} out of range (num_classes={num_classes})",
355 sample.label,
356 line_num + 1,
357 )));
358 }
359
360 samples.push(sample);
361 }
362
363 Ok(samples)
364}
365
366pub fn corpus_stats(samples: &[SafetySample], num_classes: usize) -> SafetyCorpusStats {
368 let mut class_counts = vec![0usize; num_classes];
369 let mut total_len = 0usize;
370
371 for s in samples {
372 if s.label < num_classes {
373 class_counts[s.label] += 1;
374 }
375 total_len += s.input.len();
376 }
377
378 SafetyCorpusStats {
379 total: samples.len(),
380 class_counts,
381 avg_input_len: if samples.is_empty() { 0 } else { total_len / samples.len() },
382 }
383}
384
385pub fn load_multi_label_corpus(
394 path: &Path,
395 num_classes: usize,
396) -> crate::Result<Vec<MultiLabelSafetySample>> {
397 let content = std::fs::read_to_string(path)
398 .map_err(|e| crate::Error::Io(format!("Corpus file not found: {}: {e}", path.display())))?;
399
400 let mut samples = Vec::new();
401 for (line_num, line) in content.lines().enumerate() {
402 let line = line.trim();
403 if line.is_empty() {
404 continue;
405 }
406 samples.push(parse_multi_label_line(line, line_num, num_classes)?);
407 }
408
409 Ok(samples)
410}
411
412fn parse_multi_label_line(
414 line: &str,
415 line_num: usize,
416 num_classes: usize,
417) -> crate::Result<MultiLabelSafetySample> {
418 if let Ok(sample) = serde_json::from_str::<MultiLabelSafetySample>(line) {
420 if sample.labels.len() != num_classes {
421 return Err(crate::Error::ConfigError(format!(
422 "F-CLASS-001: labels length {} at line {} != num_classes {num_classes}",
423 sample.labels.len(),
424 line_num + 1,
425 )));
426 }
427 return Ok(sample);
428 }
429
430 if let Ok(single) = serde_json::from_str::<SafetySample>(line) {
431 if single.label >= num_classes {
432 return Err(crate::Error::ConfigError(format!(
433 "F-CLASS-002: label {} at line {} out of range (num_classes={num_classes})",
434 single.label,
435 line_num + 1,
436 )));
437 }
438 return Ok(MultiLabelSafetySample::from_single_label(&single, num_classes));
439 }
440
441 Err(crate::Error::ConfigError(format!(
442 "Invalid JSONL at line {}: unrecognized format",
443 line_num + 1,
444 )))
445}
446
447pub fn bce_with_logits_loss(logits: &Tensor, targets: &[f32], num_classes: usize) -> Tensor {
455 let data = logits.data();
456 let slice = data.as_slice().expect("contiguous logits");
457 assert_eq!(slice.len(), num_classes, "F-CLASS-001: logit shape mismatch");
458 assert_eq!(targets.len(), num_classes, "F-CLASS-001: target shape mismatch");
459
460 let total_loss: f32 = slice
461 .iter()
462 .zip(targets.iter())
463 .map(|(&x, &t)| {
464 let relu = x.max(0.0);
465 relu - x * t + (1.0 + (-x.abs()).exp()).ln()
466 })
467 .sum::<f32>()
468 / num_classes as f32;
469
470 let total_loss = if total_loss.is_finite() { total_loss } else { 100.0 };
472
473 Tensor::from_vec(vec![total_loss], logits.requires_grad())
474}
475
/// Scheme used to derive per-class loss weights from class frequencies.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ClassWeightStrategy {
    /// Every class gets the same weight.
    Uniform,
    /// Weight proportional to the inverse class frequency.
    InverseFreq,
    /// Weight proportional to the square root of the inverse class frequency.
    SqrtInverse,
}

impl std::str::FromStr for ClassWeightStrategy {
    type Err = String;

    /// Parses a strategy name, case-insensitively.
    ///
    /// Accepted names: `uniform`, `inverse_freq` (alias `inverse`) and
    /// `sqrt_inverse` (alias `sqrt`). Any other input yields an error message
    /// that echoes the original string.
    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
        let normalized = s.to_lowercase();
        let strategy = match normalized.as_str() {
            "uniform" => Self::Uniform,
            "inverse_freq" | "inverse" => Self::InverseFreq,
            "sqrt_inverse" | "sqrt" => Self::SqrtInverse,
            _ => {
                return Err(format!(
                    "Unknown class weight strategy: {s}. Use: uniform, inverse_freq, sqrt_inverse"
                ))
            }
        };
        Ok(strategy)
    }
}

impl std::fmt::Display for ClassWeightStrategy {
    /// Writes the canonical name, which `from_str` parses back to the same variant.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Self::Uniform => "uniform",
            Self::InverseFreq => "inverse_freq",
            Self::SqrtInverse => "sqrt_inverse",
        };
        f.write_str(name)
    }
}
511
512pub fn compute_class_weights(
523 stats: &SafetyCorpusStats,
524 strategy: ClassWeightStrategy,
525 num_classes: usize,
526) -> Vec<f32> {
527 assert_eq!(
528 stats.class_counts.len(),
529 num_classes,
530 "F-TUNE-005: class_counts.len() != num_classes"
531 );
532
533 let n = stats.total as f32;
534 let k = num_classes as f32;
535
536 let raw_weights: Vec<f32> = match strategy {
537 ClassWeightStrategy::Uniform => vec![1.0; num_classes],
538 ClassWeightStrategy::InverseFreq => stats
539 .class_counts
540 .iter()
541 .map(|&count| {
542 let count = count.max(1) as f32; n / (k * count)
544 })
545 .collect(),
546 ClassWeightStrategy::SqrtInverse => stats
547 .class_counts
548 .iter()
549 .map(|&count| {
550 let count = count.max(1) as f32;
551 (n / (k * count)).sqrt()
552 })
553 .collect(),
554 };
555
556 let sum: f32 = raw_weights.iter().sum();
558 if sum < 1e-10 {
559 return vec![1.0; num_classes];
560 }
561 let scale = k / sum;
562 raw_weights.iter().map(|&w| w * scale).collect()
563}
564
565pub fn cross_entropy_loss(logits: &Tensor, target: usize, num_classes: usize) -> Tensor {
578 let data = logits.data();
579 let slice = data.as_slice().expect("contiguous logits");
580 assert_eq!(slice.len(), num_classes, "F-CLASS-001: logit shape mismatch");
581 assert!(target < num_classes, "F-CLASS-002: label out of range");
582
583 let max_val = slice.iter().copied().fold(f32::NEG_INFINITY, f32::max);
585 let log_sum_exp: f32 = slice.iter().map(|&v| (v - max_val).exp()).sum::<f32>().ln() + max_val;
586 let loss = -(slice[target] - log_sum_exp);
587
588 let loss = if loss.is_finite() { loss } else { 100.0 };
590
591 Tensor::from_vec(vec![loss], logits.requires_grad())
592}
593
594#[cfg(test)]
595#[allow(clippy::unwrap_used)]
596#[path = "classification_tests.rs"]
597mod tests;