use super::transformer_encoder::{TransformerEncoderConfig, TransformerTextEncoder};
use crate::error::{Result, TextError};
use scirs2_core::ndarray::{Array1, Array2, Axis};
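/// Configuration for a BERT-style sequence classifier: a transformer text
/// encoder followed by a linear classification head trained with softmax
/// cross-entropy.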
#[derive(Debug, Clone)]
pub struct BertClassifierConfig {
    /// Configuration forwarded to the underlying [`TransformerTextEncoder`].
    pub encoder_config: TransformerEncoderConfig,
    /// Number of output classes; must be at least 2.
    pub num_classes: usize,
    /// Dropout probability; carried in the config but not applied by `fine_tune`.
    pub dropout: f32,
    /// SGD step size for updates to the classification head.
    pub learning_rate: f32,
    /// Number of passes over the training data in `fine_tune`.
    pub epochs: usize,
    /// Batch size; carried in the config but unused (`fine_tune` updates per example).
    pub batch_size: usize,
    /// Seed for the deterministic weight initialization.
    pub seed: u64,
}
impl Default for BertClassifierConfig {
    fn default() -> Self {
        Self {
            encoder_config: TransformerEncoderConfig::default(),
            num_classes: 2,
            dropout: 0.1,
            learning_rate: 0.01,
            epochs: 10,
            batch_size: 8,
            seed: 123,
        }
    }
}
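/// BERT-style classifier: sentence embeddings from a [`TransformerTextEncoder`]
/// are projected to class logits by a trainable `hidden_size × num_classes`
/// linear head. Only the head is updated during fine-tuning.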
pub struct BertClassifier {
    encoder: TransformerTextEncoder,
    pub classifier_weights: Array2<f32>,
    pub classifier_bias: Array1<f32>,
    config: BertClassifierConfig,
}
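/// Affine projection `x · W + b`, mapping a hidden vector to per-class logits.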
fn linear(x: &Array1<f32>, w: &Array2<f32>, b: &Array1<f32>) -> Array1<f32> {
    x.dot(w) + b
}
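/// Numerically stable softmax over a 1-D logit vector: subtracting the maximum
/// before exponentiating avoids overflow without changing the result. The
/// `sum > 0.0` guard is purely defensive; after the shift at least one term
/// equals 1.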
fn softmax1(logits: &Array1<f32>) -> Array1<f32> {
    let max = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exp: Array1<f32> = logits.mapv(|v| (v - max).exp());
    let sum = exp.sum();
    if sum > 0.0 {
        exp / sum
    } else {
        exp
    }
}
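/// Xavier/Glorot uniform initialization: draws from
/// `[-sqrt(6 / (rows + cols)), sqrt(6 / (rows + cols))]` using a small inline
/// LCG so that weights are reproducible from `seed` without an RNG dependency.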
fn xavier_init(rows: usize, cols: usize, seed: &mut u64) -> Array2<f32> {
    let scale = (6.0_f32 / (rows + cols) as f32).sqrt();
    Array2::from_shape_fn((rows, cols), |_| {
        // Advance a 64-bit LCG (Knuth's MMIX constants) for every sample.
        *seed = seed
            .wrapping_mul(6364136223846793005)
            .wrapping_add(1442695040888963407);
        // Keep the high 32 bits so `v` spans [0, 1] before recentering to
        // [-scale, scale]; a shorter shift would bias the weights negative.
        let v = (*seed >> 32) as f32 / (u32::MAX as f32);
        (v - 0.5) * 2.0 * scale
    })
}
impl BertClassifier {
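    /// Builds the encoder and a freshly initialized classification head.
    ///
    /// Returns an error if `num_classes < 2` or if the encoder rejects its
    /// configuration.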
    pub fn new(config: BertClassifierConfig) -> Result<Self> {
        if config.num_classes < 2 {
            return Err(TextError::InvalidInput(
                "num_classes must be ≥ 2".to_string(),
            ));
        }
        let encoder = TransformerTextEncoder::new(config.encoder_config.clone())?;
        let hidden = config.encoder_config.hidden_size;
        let mut seed = config.seed;
        let classifier_weights = xavier_init(hidden, config.num_classes, &mut seed);
        let classifier_bias = Array1::zeros(config.num_classes);
        Ok(Self {
            encoder,
            classifier_weights,
            classifier_bias,
            config,
        })
    }
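    /// Encodes one token sequence and returns its class probability
    /// distribution.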
    pub fn predict_proba(&self, tokens: &[usize]) -> Result<Array1<f32>> {
        let embedding = self.encoder.encode_sentence(tokens)?;
        let logits = linear(&embedding, &self.classifier_weights, &self.classifier_bias);
        Ok(softmax1(&logits))
    }
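    /// Returns the most probable class index for one token sequence.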
    pub fn predict(&self, tokens: &[usize]) -> Result<usize> {
        let proba = self.predict_proba(tokens)?;
        proba
            .iter()
            .enumerate()
            .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
            .map(|(i, _)| i)
            .ok_or_else(|| TextError::InvalidInput("Empty probability vector".to_string()))
    }
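    /// Fine-tunes the classification head with per-example SGD under softmax
    /// cross-entropy; encoder weights are left untouched. Returns the mean
    /// training loss for each epoch.
    ///
    /// A sketch of the expected call pattern (the token ids are illustrative):
    ///
    /// ```ignore
    /// let mut clf = BertClassifier::new(BertClassifierConfig::default())?;
    /// let data = vec![(vec![1, 2, 3], 0), (vec![4, 5, 6], 1)];
    /// let losses = clf.fine_tune(&data)?;
    /// assert_eq!(losses.len(), clf.config().epochs);
    /// ```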
    pub fn fine_tune(&mut self, data: &[(Vec<usize>, usize)]) -> Result<Vec<f32>> {
        if data.is_empty() {
            return Err(TextError::InvalidInput(
                "Training data is empty".to_string(),
            ));
        }
        let n_classes = self.config.num_classes;
        let lr = self.config.learning_rate;
        let mut epoch_losses = Vec::with_capacity(self.config.epochs);
        for _epoch in 0..self.config.epochs {
            let mut total_loss = 0.0_f32;
            let mut count = 0usize;
            for (tokens, label) in data.iter() {
                let label = *label;
                if label >= n_classes {
                    return Err(TextError::InvalidInput(format!(
                        "Label {label} out of range for {n_classes} classes"
                    )));
                }
                // Forward pass: frozen encoder followed by the linear head.
                let emb = self.encoder.encode_sentence(tokens)?;
                let logits = linear(&emb, &self.classifier_weights, &self.classifier_bias);
                let proba = softmax1(&logits);
                // Cross-entropy loss; clamp the probability to avoid ln(0).
                let prob_correct = proba[label].max(1e-12);
                total_loss -= prob_correct.ln();
                count += 1;
                // For softmax + cross-entropy, d(loss)/d(logits) = proba - one_hot(label).
                let mut grad_logits = proba.clone();
                grad_logits[label] -= 1.0;
                // Per-example SGD update of the linear head; the encoder stays frozen.
                let hidden = emb.len();
                for i in 0..hidden {
                    for j in 0..n_classes {
                        self.classifier_weights[[i, j]] -= lr * emb[i] * grad_logits[j];
                    }
                }
                for j in 0..n_classes {
                    self.classifier_bias[j] -= lr * grad_logits[j];
                }
            }
            epoch_losses.push(if count > 0 {
                total_loss / count as f32
            } else {
                0.0
            });
        }
        Ok(epoch_losses)
    }
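    /// Fraction of `(tokens, label)` pairs whose prediction matches the label.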
    pub fn accuracy(&self, data: &[(Vec<usize>, usize)]) -> Result<f32> {
        if data.is_empty() {
            return Err(TextError::InvalidInput("Test data is empty".to_string()));
        }
        let mut correct = 0usize;
        for (tokens, label) in data.iter() {
            if self.predict(tokens)? == *label {
                correct += 1;
            }
        }
        Ok(correct as f32 / data.len() as f32)
    }
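    /// Read-only access to the underlying encoder.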
    pub fn encoder(&self) -> &TransformerTextEncoder {
        &self.encoder
    }
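    /// Read-only access to the classifier configuration.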
    pub fn config(&self) -> &BertClassifierConfig {
        &self.config
    }
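    /// Encodes each sequence into one row of a `(batch, hidden_size)` matrix.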
    pub fn encode_batch(&self, sequences: &[Vec<usize>]) -> Result<Array2<f32>> {
        if sequences.is_empty() {
            return Err(TextError::InvalidInput("Empty batch".to_string()));
        }
        let hidden = self.config.encoder_config.hidden_size;
        let mut out = Array2::zeros((sequences.len(), hidden));
        for (i, tokens) in sequences.iter().enumerate() {
            let emb = self.encoder.encode_sentence(tokens)?;
            out.row_mut(i).assign(&emb);
        }
        Ok(out)
    }
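    /// Predicts a class index for every sequence, propagating the first error.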
    pub fn predict_batch(&self, sequences: &[Vec<usize>]) -> Result<Vec<usize>> {
        sequences
            .iter()
            .map(|tokens| self.predict(tokens))
            .collect()
    }
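    /// Mean-pools a `(seq_len, hidden)` context matrix into a single vector.
    /// Unused by the current pipeline but kept for alternative pooling
    /// strategies.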
    #[allow(dead_code)]
    fn mean_pool(ctx: &Array2<f32>) -> Result<Array1<f32>> {
        ctx.mean_axis(Axis(0))
            .ok_or_else(|| TextError::InvalidInput("Cannot mean-pool empty context".to_string()))
    }
}
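
#[cfg(test)]
mod tests {
    use super::*;

    // Minimal smoke-test sketch. It assumes `TransformerEncoderConfig::default()`
    // builds a small encoder whose vocabulary covers the toy token ids below;
    // adjust the ids to your tokenizer if it does not.
    #[test]
    fn fine_tune_runs_on_a_toy_task() {
        let mut clf =
            BertClassifier::new(BertClassifierConfig::default()).expect("classifier should build");
        // Two trivially distinct "sentences", one per class.
        let data = vec![(vec![1, 2, 3], 0), (vec![4, 5, 6], 1)];
        let losses = clf.fine_tune(&data).expect("fine-tuning should succeed");
        // One mean loss per configured epoch, all finite.
        assert_eq!(losses.len(), clf.config().epochs);
        assert!(losses.iter().all(|l| l.is_finite()));
        // Accuracy on the same toy set is a valid probability.
        let acc = clf.accuracy(&data).expect("accuracy should succeed");
        assert!((0.0..=1.0).contains(&acc));
    }
}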