use std::collections::HashMap;
use std::path::PathBuf;
use crate::error::Result;
/// Mutable state handed to every `Callback` hook.
///
/// Hooks receive it by `&mut`, so they can read the counters and losses and
/// also raise the `stop_training` / `skip_batch` flags.
pub struct CallbackContext {
    /// Current epoch index; starts at 0 (callbacks in this file log it as `epoch + 1`).
    pub epoch: usize,
    /// Total number of epochs.
    pub n_epochs: usize,
    /// Current batch index inside the epoch; starts at 0.
    pub batch: usize,
    /// Number of batches per epoch.
    pub n_batches: usize,
    /// Current learning rate.
    pub lr: f64,
    /// Latest training loss, if one has been recorded.
    pub train_loss: Option<f32>,
    /// Latest validation loss, if one has been recorded.
    pub valid_loss: Option<f32>,
    /// Named metric values.
    pub metrics: HashMap<String, f32>,
    /// Set by callbacks to request that training stop.
    pub stop_training: bool,
    /// Flag for requesting that a batch be skipped.
    // NOTE(review): presumably consumed by the training loop — confirm against the caller.
    pub skip_batch: bool,
}

impl CallbackContext {
    /// Creates a context positioned at epoch 0 / batch 0 with no losses or
    /// metrics recorded and both control flags cleared.
    pub fn new(n_epochs: usize, n_batches: usize) -> Self {
        Self {
            epoch: 0,
            n_epochs,
            batch: 0,
            n_batches,
            lr: 0.0,
            train_loss: None,
            valid_loss: None,
            metrics: HashMap::new(),
            stop_training: false,
            skip_batch: false,
        }
    }

    /// Fraction of all batches reached so far:
    /// `(epoch * n_batches + batch) / (n_epochs * n_batches)`.
    ///
    /// Returns `0.0` when there is no work at all (`n_epochs == 0` or
    /// `n_batches == 0`) instead of the NaN a `0 / 0` division would produce.
    pub fn progress(&self) -> f32 {
        // Saturating arithmetic: absurdly large counters must not overflow.
        let total_batches = self.n_epochs.saturating_mul(self.n_batches);
        if total_batches == 0 {
            return 0.0;
        }
        let current = self
            .epoch
            .saturating_mul(self.n_batches)
            .saturating_add(self.batch);
        current as f32 / total_batches as f32
    }
}
/// Set of optional hooks a training callback can implement.
///
/// Every hook receives the shared [`CallbackContext`] mutably and defaults to
/// a no-op returning `Ok(())`, so implementors override only what they need.
pub trait Callback: Send + Sync {
    /// Hook for the start of fitting. Default: no-op.
    fn before_fit(&mut self, _ctx: &mut CallbackContext) -> Result<()> {
        Ok(())
    }

    /// Hook for the end of fitting. Default: no-op.
    fn after_fit(&mut self, _ctx: &mut CallbackContext) -> Result<()> {
        Ok(())
    }

    /// Hook for the start of an epoch. Default: no-op.
    fn before_epoch(&mut self, _ctx: &mut CallbackContext) -> Result<()> {
        Ok(())
    }

    /// Hook for the end of an epoch. Default: no-op.
    fn after_epoch(&mut self, _ctx: &mut CallbackContext) -> Result<()> {
        Ok(())
    }

    /// Hook for the start of a batch. Default: no-op.
    fn before_batch(&mut self, _ctx: &mut CallbackContext) -> Result<()> {
        Ok(())
    }

    /// Hook for the end of a batch. Default: no-op.
    fn after_batch(&mut self, _ctx: &mut CallbackContext) -> Result<()> {
        Ok(())
    }

    /// Hook for the start of validation. Default: no-op.
    fn before_validate(&mut self, _ctx: &mut CallbackContext) -> Result<()> {
        Ok(())
    }

    /// Hook for the end of validation. Default: no-op.
    fn after_validate(&mut self, _ctx: &mut CallbackContext) -> Result<()> {
        Ok(())
    }

    /// Name of the callback; defaults to the implementing type's
    /// `std::any::type_name`.
    fn name(&self) -> &str {
        std::any::type_name::<Self>()
    }
}
/// Ordered collection of boxed callbacks.
///
/// Each hook method forwards to the same hook on every registered callback,
/// in insertion order, and stops at the first error.
#[derive(Default)]
pub struct CallbackList {
    callbacks: Vec<Box<dyn Callback>>,
}

impl CallbackList {
    /// Creates an empty list.
    pub fn new() -> Self {
        Self::default()
    }

    /// Appends `callback` to the end of the list.
    pub fn add<C: Callback + 'static>(&mut self, callback: C) {
        let boxed: Box<dyn Callback> = Box::new(callback);
        self.callbacks.push(boxed);
    }

    /// Runs `before_fit` on every callback; the first error aborts the pass.
    pub fn before_fit(&mut self, ctx: &mut CallbackContext) -> Result<()> {
        self.callbacks
            .iter_mut()
            .try_for_each(|hook| hook.before_fit(ctx))
    }

    /// Runs `after_fit` on every callback; the first error aborts the pass.
    pub fn after_fit(&mut self, ctx: &mut CallbackContext) -> Result<()> {
        self.callbacks
            .iter_mut()
            .try_for_each(|hook| hook.after_fit(ctx))
    }

    /// Runs `before_epoch` on every callback; the first error aborts the pass.
    pub fn before_epoch(&mut self, ctx: &mut CallbackContext) -> Result<()> {
        self.callbacks
            .iter_mut()
            .try_for_each(|hook| hook.before_epoch(ctx))
    }

    /// Runs `after_epoch` on every callback; the first error aborts the pass.
    pub fn after_epoch(&mut self, ctx: &mut CallbackContext) -> Result<()> {
        self.callbacks
            .iter_mut()
            .try_for_each(|hook| hook.after_epoch(ctx))
    }

    /// Runs `before_batch` on every callback; the first error aborts the pass.
    pub fn before_batch(&mut self, ctx: &mut CallbackContext) -> Result<()> {
        self.callbacks
            .iter_mut()
            .try_for_each(|hook| hook.before_batch(ctx))
    }

    /// Runs `after_batch` on every callback; the first error aborts the pass.
    pub fn after_batch(&mut self, ctx: &mut CallbackContext) -> Result<()> {
        self.callbacks
            .iter_mut()
            .try_for_each(|hook| hook.after_batch(ctx))
    }

    /// Runs `before_validate` on every callback; the first error aborts the pass.
    pub fn before_validate(&mut self, ctx: &mut CallbackContext) -> Result<()> {
        self.callbacks
            .iter_mut()
            .try_for_each(|hook| hook.before_validate(ctx))
    }

    /// Runs `after_validate` on every callback; the first error aborts the pass.
    pub fn after_validate(&mut self, ctx: &mut CallbackContext) -> Result<()> {
        self.callbacks
            .iter_mut()
            .try_for_each(|hook| hook.after_validate(ctx))
    }
}
/// Callback that reports training progress through `tracing`.
pub struct ProgressCallback {
    // Stored but not read by any code visible in this file, hence the allow.
    #[allow(dead_code)]
    show_batch: bool,
}

impl ProgressCallback {
    /// Builds the callback, remembering the `show_batch` preference.
    pub fn new(show_batch: bool) -> Self {
        ProgressCallback { show_batch }
    }
}
impl Callback for ProgressCallback {
    /// Logs the number of epochs about to be run.
    fn before_fit(&mut self, ctx: &mut CallbackContext) -> Result<()> {
        tracing::info!("Starting training for {} epochs", ctx.n_epochs);
        Ok(())
    }

    /// Logs the epoch summary line followed by one line per metric.
    fn after_epoch(&mut self, ctx: &mut CallbackContext) -> Result<()> {
        // A missing loss is rendered as an empty string.
        let render = |loss: Option<f32>| match loss {
            Some(value) => format!("{:.4}", value),
            None => String::new(),
        };
        let train_loss = render(ctx.train_loss);
        let valid_loss = render(ctx.valid_loss);
        tracing::info!(
            "Epoch {}/{}: train_loss={}, valid_loss={}, lr={:.6}",
            ctx.epoch + 1,
            ctx.n_epochs,
            train_loss,
            valid_loss,
            ctx.lr
        );
        ctx.metrics.iter().for_each(|(name, value)| {
            tracing::info!(" {}: {:.4}", name, value);
        });
        Ok(())
    }

    /// Logs that training has finished.
    fn after_fit(&mut self, _ctx: &mut CallbackContext) -> Result<()> {
        tracing::info!("Training completed");
        Ok(())
    }

    fn name(&self) -> &str {
        "ProgressCallback"
    }
}
/// Stops training once the validation loss has failed to improve for
/// `patience` consecutive epochs.
pub struct EarlyStoppingCallback {
    /// Number of non-improving epochs tolerated before stopping.
    patience: usize,
    /// Margin by which the value must beat the best one to count as improved.
    min_delta: f32,
    /// Best value observed so far.
    best_loss: f32,
    /// Consecutive epochs without improvement.
    counter: usize,
    /// Direction in which the monitored value should move.
    mode: EarlyStoppingMode,
}

/// Direction of improvement for [`EarlyStoppingCallback`].
pub enum EarlyStoppingMode {
    /// Lower values are better.
    Min,
    /// Higher values are better.
    Max,
}

impl EarlyStoppingCallback {
    /// Creates the callback with the "best" value initialised to the worst
    /// possible one for `mode` (`+inf` for `Min`, `-inf` for `Max`).
    pub fn new(patience: usize, min_delta: f32, mode: EarlyStoppingMode) -> Self {
        Self {
            best_loss: match &mode {
                EarlyStoppingMode::Min => f32::INFINITY,
                EarlyStoppingMode::Max => f32::NEG_INFINITY,
            },
            counter: 0,
            patience,
            min_delta,
            mode,
        }
    }
}
impl Callback for EarlyStoppingCallback {
    /// Compares the epoch's validation loss with the best value seen so far
    /// and sets `ctx.stop_training` after `patience` epochs in a row without
    /// an improvement larger than `min_delta`.
    ///
    /// An epoch with no validation loss counts as "no improvement" in both
    /// modes.
    fn after_epoch(&mut self, ctx: &mut CallbackContext) -> Result<()> {
        // Substitute the worst value *for the active mode* when the loss is
        // missing. Using +inf unconditionally would look like the best
        // possible result in `Max` mode and pin `best_loss` at infinity.
        let worst = match self.mode {
            EarlyStoppingMode::Min => f32::INFINITY,
            EarlyStoppingMode::Max => f32::NEG_INFINITY,
        };
        let current = ctx.valid_loss.unwrap_or(worst);
        let improved = match self.mode {
            EarlyStoppingMode::Min => current < self.best_loss - self.min_delta,
            EarlyStoppingMode::Max => current > self.best_loss + self.min_delta,
        };
        if improved {
            self.best_loss = current;
            self.counter = 0;
        } else {
            self.counter += 1;
            if self.counter >= self.patience {
                tracing::info!(
                    "Early stopping triggered after {} epochs without improvement",
                    self.patience
                );
                ctx.stop_training = true;
            }
        }
        Ok(())
    }

    fn name(&self) -> &str {
        "EarlyStoppingCallback"
    }
}
/// Rule deciding when [`SaveModelCallback`] treats an epoch as worth saving.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SaveModelMode {
    /// Save when the monitored value drops below the best one so far.
    Min,
    /// Save when the monitored value rises above the best one so far.
    Max,
    /// Every epoch qualifies.
    Every,
}

/// Writes checkpoint metadata files into a directory as training proceeds.
pub struct SaveModelCallback {
    /// Directory receiving the checkpoint files.
    save_dir: PathBuf,
    /// Comparison rule for the monitored value.
    mode: SaveModelMode,
    /// Best monitored value recorded so far.
    best_value: f32,
    /// Metric to monitor; `None` means the validation loss is used.
    metric_name: Option<String>,
    /// When `true`, only epochs that qualify as best are written.
    save_best_only: bool,
    /// Prefix of every checkpoint file name.
    filename_prefix: String,
    /// Epoch index at which `best_value` was recorded.
    best_epoch: usize,
}
impl SaveModelCallback {
pub fn new<P: Into<PathBuf>>(save_dir: P, mode: SaveModelMode) -> Self {
let best_value = match mode {
SaveModelMode::Min => f32::INFINITY,
SaveModelMode::Max => f32::NEG_INFINITY,
SaveModelMode::Every => 0.0,
};
Self {
save_dir: save_dir.into(),
mode,
best_value,
metric_name: None,
save_best_only: true,
filename_prefix: "checkpoint".to_string(),
best_epoch: 0,
}
}
#[must_use]
pub fn with_metric(mut self, name: &str) -> Self {
self.metric_name = Some(name.to_string());
self
}
#[must_use]
pub fn save_best_only(mut self, value: bool) -> Self {
self.save_best_only = value;
self
}
#[must_use]
pub fn with_prefix(mut self, prefix: &str) -> Self {
self.filename_prefix = prefix.to_string();
self
}
pub fn best_checkpoint_path(&self) -> PathBuf {
self.save_dir.join(format!("{}_best.json", self.filename_prefix))
}
pub fn epoch_checkpoint_path(&self, epoch: usize) -> PathBuf {
self.save_dir
.join(format!("{}_epoch_{}.json", self.filename_prefix, epoch))
}
pub fn best_epoch(&self) -> usize {
self.best_epoch
}
pub fn best_value(&self) -> f32 {
self.best_value
}
fn get_current_value(&self, ctx: &CallbackContext) -> Option<f32> {
if let Some(ref metric_name) = self.metric_name {
ctx.metrics.get(metric_name).copied()
} else {
ctx.valid_loss
}
}
fn should_save(&self, current: f32) -> bool {
match self.mode {
SaveModelMode::Min => current < self.best_value,
SaveModelMode::Max => current > self.best_value,
SaveModelMode::Every => true,
}
}
fn save_checkpoint(&self, ctx: &CallbackContext, is_best: bool) -> Result<()> {
std::fs::create_dir_all(&self.save_dir).map_err(|e| {
crate::error::TrainError::CheckpointError(format!(
"Failed to create checkpoint directory: {}",
e
))
})?;
let checkpoint = CheckpointMetadata {
epoch: ctx.epoch,
train_loss: ctx.train_loss,
valid_loss: ctx.valid_loss,
metrics: ctx.metrics.clone(),
is_best,
};
let epoch_path = self.epoch_checkpoint_path(ctx.epoch);
let json = serde_json::to_string_pretty(&checkpoint).map_err(|e| {
crate::error::TrainError::SerializationError(format!(
"Failed to serialize checkpoint: {}",
e
))
})?;
std::fs::write(&epoch_path, json).map_err(|e| {
crate::error::TrainError::CheckpointError(format!("Failed to write checkpoint: {}", e))
})?;
if is_best {
let best_path = self.best_checkpoint_path();
std::fs::copy(&epoch_path, &best_path).map_err(|e| {
crate::error::TrainError::CheckpointError(format!(
"Failed to copy best checkpoint: {}",
e
))
})?;
}
Ok(())
}
}
/// Per-epoch record serialised to JSON by `SaveModelCallback`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct CheckpointMetadata {
    /// Epoch index the record belongs to.
    pub epoch: usize,
    /// Training loss at that epoch, if available.
    pub train_loss: Option<f32>,
    /// Validation loss at that epoch, if available.
    pub valid_loss: Option<f32>,
    /// Snapshot of the named metrics.
    pub metrics: HashMap<String, f32>,
    /// Whether the epoch was considered the best so far when saved.
    pub is_best: bool,
}
impl Callback for SaveModelCallback {
fn before_fit(&mut self, _ctx: &mut CallbackContext) -> Result<()> {
std::fs::create_dir_all(&self.save_dir).map_err(|e| {
crate::error::TrainError::CheckpointError(format!(
"Failed to create checkpoint directory: {}",
e
))
})?;
tracing::info!("Checkpoints will be saved to: {:?}", self.save_dir);
Ok(())
}
fn after_epoch(&mut self, ctx: &mut CallbackContext) -> Result<()> {
let Some(current) = self.get_current_value(ctx) else {
return Ok(());
};
let is_best = self.should_save(current);
if is_best {
self.best_value = current;
self.best_epoch = ctx.epoch;
}
if !self.save_best_only || is_best {
self.save_checkpoint(ctx, is_best)?;
if is_best {
let metric_display = self
.metric_name
.as_deref()
.unwrap_or("valid_loss");
tracing::info!(
"Epoch {}: {} improved to {:.4}, saving checkpoint",
ctx.epoch + 1,
metric_display,
current
);
}
}
Ok(())
}
fn after_fit(&mut self, _ctx: &mut CallbackContext) -> Result<()> {
tracing::info!(
"Best model from epoch {} with value {:.4}",
self.best_epoch + 1,
self.best_value
);
Ok(())
}
fn name(&self) -> &str {
"SaveModelCallback"
}
}
/// How gradients are limited, together with the limit itself.
#[derive(Debug, Clone, Copy)]
pub enum GradientClipMode {
    /// Limit expressed as a per-value bound.
    Value(f32),
    /// Limit expressed as a norm bound.
    Norm(f32),
}

/// Holds the gradient-clipping configuration and clip statistics.
pub struct GradientClipCallback {
    /// Configured clipping mode and threshold.
    mode: GradientClipMode,
    /// Number of batches in which clipping was applied.
    clip_count: usize,
    /// Number of batches seen.
    total_batches: usize,
}

impl GradientClipCallback {
    /// Creates a callback for `mode` with zeroed statistics.
    pub fn new(mode: GradientClipMode) -> Self {
        GradientClipCallback {
            mode,
            clip_count: 0,
            total_batches: 0,
        }
    }

    /// Shorthand for `new(GradientClipMode::Norm(max_norm))`.
    pub fn by_norm(max_norm: f32) -> Self {
        let mode = GradientClipMode::Norm(max_norm);
        Self::new(mode)
    }

    /// Shorthand for `new(GradientClipMode::Value(max_value))`.
    pub fn by_value(max_value: f32) -> Self {
        let mode = GradientClipMode::Value(max_value);
        Self::new(mode)
    }

    /// Configured clipping mode.
    pub fn mode(&self) -> GradientClipMode {
        self.mode
    }

    /// Threshold carried by the mode, whichever variant it is.
    pub fn clip_value(&self) -> f32 {
        let (GradientClipMode::Value(limit) | GradientClipMode::Norm(limit)) = self.mode;
        limit
    }
}
impl Callback for GradientClipCallback {
    /// Zeroes both statistics counters.
    fn before_fit(&mut self, _ctx: &mut CallbackContext) -> Result<()> {
        self.total_batches = 0;
        self.clip_count = 0;
        Ok(())
    }

    /// Counts one more processed batch.
    fn after_batch(&mut self, _ctx: &mut CallbackContext) -> Result<()> {
        self.total_batches += 1;
        Ok(())
    }

    /// Logs the percentage of batches in which clipping was applied, when
    /// that percentage is above zero.
    // NOTE(review): nothing visible in this file increments `clip_count`, so
    // the message is only emitted if some other code updates it — confirm.
    fn after_fit(&mut self, _ctx: &mut CallbackContext) -> Result<()> {
        if self.total_batches == 0 {
            return Ok(());
        }
        let clip_rate = self.clip_count as f32 / self.total_batches as f32 * 100.0;
        if clip_rate > 0.0 {
            tracing::info!(
                "Gradient clipping was applied in {:.1}% of batches",
                clip_rate
            );
        }
        Ok(())
    }

    fn name(&self) -> &str {
        "GradientClipCallback"
    }
}
/// Accumulates per-epoch losses, learning rates and metric values.
#[derive(Default)]
pub struct HistoryCallback {
    /// Training losses, one per epoch that reported one.
    train_losses: Vec<f32>,
    /// Validation losses, one per epoch that reported one.
    valid_losses: Vec<f32>,
    /// Learning rate recorded at the end of every epoch.
    learning_rates: Vec<f64>,
    /// History of every named metric.
    metrics_history: HashMap<String, Vec<f32>>,
}

impl HistoryCallback {
    /// Creates an empty history.
    pub fn new() -> Self {
        Self::default()
    }

    /// Recorded training losses in insertion order.
    pub fn train_losses(&self) -> &[f32] {
        &self.train_losses
    }

    /// Recorded validation losses in insertion order.
    pub fn valid_losses(&self) -> &[f32] {
        &self.valid_losses
    }

    /// Recorded learning rates in insertion order.
    pub fn learning_rates(&self) -> &[f64] {
        &self.learning_rates
    }

    /// History of the metric called `name`, if it was ever recorded.
    pub fn metric_history(&self, name: &str) -> Option<&[f32]> {
        self.metrics_history.get(name).map(|v| v.as_slice())
    }

    /// Names of all metrics with recorded history (unordered).
    pub fn metric_names(&self) -> Vec<&str> {
        self.metrics_history.keys().map(|s| s.as_str()).collect()
    }

    /// Index of the smallest recorded validation loss (first one on ties).
    ///
    /// NaN entries are ignored instead of triggering a panic; the result is
    /// `None` when no non-NaN validation loss has been recorded.
    ///
    /// The index refers to a position in `valid_losses()`. It equals the
    /// epoch index only when every epoch reported a validation loss.
    pub fn best_epoch(&self) -> Option<usize> {
        self.valid_losses
            .iter()
            .enumerate()
            .filter(|(_, loss)| !loss.is_nan())
            // NaNs are filtered out above, so `partial_cmp` always succeeds;
            // the fallback only exists to avoid a panic path.
            .min_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
            .map(|(i, _)| i)
    }
}
impl Callback for HistoryCallback {
    /// Appends the epoch's losses (when present), learning rate and every
    /// metric value to the stored history.
    fn after_epoch(&mut self, ctx: &mut CallbackContext) -> Result<()> {
        // `Option` iterates over zero or one item, so a missing loss adds nothing.
        self.train_losses.extend(ctx.train_loss);
        self.valid_losses.extend(ctx.valid_loss);
        self.learning_rates.push(ctx.lr);
        for (metric, value) in ctx.metrics.iter() {
            let series = self.metrics_history.entry(metric.clone()).or_default();
            series.push(*value);
        }
        Ok(())
    }

    fn name(&self) -> &str {
        "HistoryCallback"
    }
}
/// Tracks a loss-scale value that shrinks on overflow and grows after a run
/// of batches without one.
pub struct MixedPrecisionCallback {
    /// Scale restored at the start of every fit.
    initial_scale: f32,
    /// Scale currently in effect.
    current_scale: f32,
    /// Multiplier applied when the scale grows.
    growth_factor: f32,
    /// Multiplier applied when an overflow is reported.
    backoff_factor: f32,
    /// Number of batches between growth steps.
    growth_interval: usize,
    /// Batches counted since the scale last changed.
    batches_since_rescale: usize,
    /// Number of overflows reported.
    overflow_count: usize,
}

impl MixedPrecisionCallback {
    /// Creates the callback with growth ×2.0, backoff ×0.5 and a growth
    /// interval of 2000 batches.
    pub fn new(initial_scale: f32) -> Self {
        MixedPrecisionCallback {
            initial_scale,
            current_scale: initial_scale,
            growth_factor: 2.0,
            backoff_factor: 0.5,
            growth_interval: 2000,
            batches_since_rescale: 0,
            overflow_count: 0,
        }
    }

    /// Scale currently in effect.
    pub fn current_scale(&self) -> f32 {
        self.current_scale
    }

    /// Records an overflow: multiplies the scale by the backoff factor and
    /// restarts the growth countdown.
    pub fn report_overflow(&mut self) {
        self.batches_since_rescale = 0;
        self.current_scale *= self.backoff_factor;
        self.overflow_count += 1;
    }
}
impl Callback for MixedPrecisionCallback {
    /// Restores the initial scale and clears the counters.
    fn before_fit(&mut self, _ctx: &mut CallbackContext) -> Result<()> {
        self.overflow_count = 0;
        self.batches_since_rescale = 0;
        self.current_scale = self.initial_scale;
        Ok(())
    }

    /// Counts the batch and, once `growth_interval` batches have passed since
    /// the last rescale, multiplies the scale by the growth factor.
    fn after_batch(&mut self, _ctx: &mut CallbackContext) -> Result<()> {
        self.batches_since_rescale += 1;
        if self.batches_since_rescale < self.growth_interval {
            return Ok(());
        }
        self.current_scale *= self.growth_factor;
        self.batches_since_rescale = 0;
        Ok(())
    }

    /// Logs the overflow count and final scale when any overflow occurred.
    fn after_fit(&mut self, _ctx: &mut CallbackContext) -> Result<()> {
        let had_overflow = self.overflow_count > 0;
        if had_overflow {
            tracing::info!(
                "Mixed precision: {} overflow events, final scale = {:.0}",
                self.overflow_count,
                self.current_scale
            );
        }
        Ok(())
    }

    fn name(&self) -> &str {
        "MixedPrecisionCallback"
    }
}
/// Requests a training stop when the training loss becomes NaN or infinite.
pub struct TerminateOnNanCallback {
    /// Number of non-finite losses detected.
    nan_count: usize,
}

impl TerminateOnNanCallback {
    /// Creates the callback with a zero detection count.
    pub fn new() -> Self {
        TerminateOnNanCallback { nan_count: 0 }
    }
}

impl Default for TerminateOnNanCallback {
    /// Same as [`TerminateOnNanCallback::new`].
    fn default() -> Self {
        TerminateOnNanCallback::new()
    }
}
impl Callback for TerminateOnNanCallback {
    /// Checks the training loss after each batch; a NaN or infinite value is
    /// counted, logged as an error and turns on `ctx.stop_training`.
    fn after_batch(&mut self, ctx: &mut CallbackContext) -> Result<()> {
        match ctx.train_loss {
            // `!is_finite()` is exactly "NaN or ±infinity".
            Some(loss) if !loss.is_finite() => {
                self.nan_count += 1;
                tracing::error!("NaN/Inf detected in training loss at batch {}", ctx.batch);
                ctx.stop_training = true;
            }
            _ => {}
        }
        Ok(())
    }

    fn name(&self) -> &str {
        "TerminateOnNanCallback"
    }
}
/// Collects per-epoch losses/metrics and prints them as text-mode charts.
pub struct ShowGraphCallback {
    /// Training losses collected so far.
    train_losses: Vec<f32>,
    /// Validation losses collected so far.
    valid_losses: Vec<f32>,
    /// Collected values of every tracked metric.
    metrics_history: HashMap<String, Vec<f32>>,
    /// Names of the metrics to track and plot.
    metric_names: Vec<String>,
    /// Plot area width in columns.
    width: usize,
    /// Plot area height in rows.
    height: usize,
    /// Whether charts are printed after every epoch (otherwise once at the end).
    show_per_epoch: bool,
}

impl Default for ShowGraphCallback {
    /// 50×10 plot area, no extra metrics, charts printed after every epoch.
    fn default() -> Self {
        Self {
            train_losses: Vec::new(),
            valid_losses: Vec::new(),
            metrics_history: HashMap::new(),
            metric_names: Vec::new(),
            width: 50,
            height: 10,
            show_per_epoch: true,
        }
    }
}

impl ShowGraphCallback {
    /// Creates a callback with the default settings.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the plot width; values below 20 are raised to 20.
    #[must_use]
    pub fn with_width(mut self, width: usize) -> Self {
        self.width = width.max(20);
        self
    }

    /// Sets the plot height; values below 5 are raised to 5.
    #[must_use]
    pub fn with_height(mut self, height: usize) -> Self {
        self.height = height.max(5);
        self
    }

    /// Selects the metrics to track and plot.
    #[must_use]
    pub fn with_metrics(mut self, names: Vec<&str>) -> Self {
        self.metric_names = names.into_iter().map(|s| s.to_string()).collect();
        self
    }

    /// Chooses between per-epoch display and a single display at the end.
    #[must_use]
    pub fn show_per_epoch(mut self, show: bool) -> Self {
        self.show_per_epoch = show;
        self
    }

    /// Renders `values` as a boxed chart titled `label`.
    ///
    /// `color_start` / `color_end` are written verbatim around each plot row
    /// (callers pass ANSI escape sequences). Returns an empty string for
    /// empty input.
    ///
    /// Every box line — top border, plot rows, bottom border — is
    /// `width + 9` columns wide (ignoring the colour sequences): two border
    /// characters, a 7-column axis gutter and the plot area. Axis values
    /// that need more than 6 characters in `{:.3}` format still widen their
    /// row.
    fn render_graph(&self, label: &str, values: &[f32], color_start: &str, color_end: &str) -> String {
        // Width of the axis gutter: a `{:>6.3}` value plus one space.
        const GUTTER_COLS: usize = 7;

        if values.is_empty() {
            return String::new();
        }
        let inner_cols = self.width + GUTTER_COLS;
        let mut output = String::new();
        let min_val = values.iter().cloned().fold(f32::INFINITY, f32::min);
        let max_val = values.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        // Floor the range so a flat series does not divide by zero.
        let range = (max_val - min_val).max(1e-6);

        // Top border. The title is measured in characters (not bytes) so
        // non-ASCII labels keep the box aligned.
        let title = format!("─ {} ", label);
        output.push('┌');
        output.push_str(&title);
        output.push_str(&"─".repeat(inner_cols.saturating_sub(title.chars().count())));
        output.push_str("┐\n");

        // One mark per column; each column samples the series proportionally.
        let mut grid = vec![vec![' '; self.width]; self.height];
        let step = values.len() as f32 / self.width as f32;
        for col in 0..self.width {
            let idx = (col as f32 * step) as usize;
            if idx < values.len() {
                let val = values[idx];
                let normalized = (val - min_val) / range;
                // Row 0 is the top of the plot, hence the inversion.
                let row = ((1.0 - normalized) * (self.height - 1) as f32) as usize;
                let row = row.min(self.height - 1);
                grid[row][col] = '█';
            }
        }

        let last_row = self.height - 1;
        for (i, row) in grid.iter().enumerate() {
            output.push('│');
            if i == 0 {
                output.push_str(&format!("{:>6.3} ", max_val));
            } else if i == last_row {
                output.push_str(&format!("{:>6.3} ", min_val));
            } else {
                // Blank gutter, same width as the labelled rows.
                output.push_str(&" ".repeat(GUTTER_COLS));
            }
            output.push_str(color_start);
            output.extend(row.iter());
            output.push_str(color_end);
            output.push_str("│\n");
        }

        output.push('└');
        output.push_str(&"─".repeat(inner_cols));
        output.push_str("┘\n");
        output.push_str(&format!(" Epochs: 1 → {}\n", values.len()));
        output
    }

    /// Prints the banner, one chart per non-empty series and the latest value
    /// of each series to stdout.
    fn display_graphs(&self) {
        // Horizontal rule of the banner; the title is centred to its width so
        // the right-hand border lines up.
        const BANNER_RULE: &str = "══════════════════════════════════════════════════════════════";

        let mut output = String::new();
        output.push_str(&format!("\n╔{}╗\n", BANNER_RULE));
        output.push_str(&format!(
            "║{:^width$}║\n",
            "Training Progress",
            width = BANNER_RULE.chars().count()
        ));
        output.push_str(&format!("╚{}╝\n\n", BANNER_RULE));

        if !self.train_losses.is_empty() {
            output.push_str(&self.render_graph("Train Loss", &self.train_losses, "\x1b[33m", "\x1b[0m"));
            output.push('\n');
        }
        if !self.valid_losses.is_empty() {
            output.push_str(&self.render_graph("Valid Loss", &self.valid_losses, "\x1b[36m", "\x1b[0m"));
            output.push('\n');
        }
        for name in &self.metric_names {
            if let Some(values) = self.metrics_history.get(name) {
                if !values.is_empty() {
                    output.push_str(&self.render_graph(name, values, "\x1b[32m", "\x1b[0m"));
                    output.push('\n');
                }
            }
        }

        output.push_str("Current Values:\n");
        if let Some(train) = self.train_losses.last() {
            output.push_str(&format!(" Train Loss: {:.4}\n", train));
        }
        if let Some(valid) = self.valid_losses.last() {
            output.push_str(&format!(" Valid Loss: {:.4}\n", valid));
        }
        for name in &self.metric_names {
            if let Some(values) = self.metrics_history.get(name) {
                if let Some(val) = values.last() {
                    output.push_str(&format!(" {}: {:.4}\n", name, val));
                }
            }
        }
        print!("{}", output);
    }
}
impl Callback for ShowGraphCallback {
    /// Discards any previously collected series and announces the callback.
    fn before_fit(&mut self, _ctx: &mut CallbackContext) -> Result<()> {
        self.metrics_history.clear();
        self.valid_losses.clear();
        self.train_losses.clear();
        tracing::info!("ShowGraph enabled - will display training curves");
        Ok(())
    }

    /// Records the epoch's losses and tracked metrics, then prints the charts
    /// when per-epoch display is enabled.
    fn after_epoch(&mut self, ctx: &mut CallbackContext) -> Result<()> {
        // Extending from an `Option` adds the value only when it is present.
        self.train_losses.extend(ctx.train_loss);
        self.valid_losses.extend(ctx.valid_loss);
        for tracked in &self.metric_names {
            let Some(&value) = ctx.metrics.get(tracked) else {
                continue;
            };
            self.metrics_history
                .entry(tracked.clone())
                .or_default()
                .push(value);
        }
        if self.show_per_epoch {
            self.display_graphs();
        }
        Ok(())
    }

    /// Prints the charts once at the end when per-epoch display is disabled
    /// and at least one loss value was collected.
    fn after_fit(&mut self, _ctx: &mut CallbackContext) -> Result<()> {
        let nothing_recorded = self.train_losses.is_empty() && self.valid_losses.is_empty();
        if !self.show_per_epoch && !nothing_recorded {
            self.display_graphs();
        }
        Ok(())
    }

    fn name(&self) -> &str {
        "ShowGraphCallback"
    }
}
/// How a transform's application probability evolves over the epochs.
#[derive(Debug, Clone)]
pub enum TransformSchedule {
    /// Fixed probability.
    Constant(f32),
    /// Rises linearly from 0 to `max_p` over the first `warmup_epochs`.
    LinearWarmup {
        max_p: f32,
        warmup_epochs: usize,
    },
    /// Stays at `max_p`, then falls linearly over the last `cooldown_epochs`.
    LinearCooldown {
        max_p: f32,
        cooldown_epochs: usize,
    },
    /// Half-cosine from `max_p` (first epoch) down to `min_p` (last epoch).
    CosineAnnealing {
        min_p: f32,
        max_p: f32,
    },
    /// Piecewise-constant: `(epoch, p)` pairs, each taking effect from its epoch.
    Step {
        schedule: Vec<(usize, f32)>,
    },
    /// 0 until `start_epoch`, then `p`.
    DelayedStart {
        p: f32,
        start_epoch: usize,
    },
}

/// Computes the scheduled probability of a named transform at each epoch.
pub struct TransformSchedulerCallback {
    /// Name of the transform being scheduled (used in log output).
    transform_name: String,
    /// Schedule that drives the probability.
    schedule: TransformSchedule,
    /// Probability currently in effect.
    current_p: f32,
    // Set to `true` on construction; not read by code visible in this file.
    #[allow(dead_code)]
    is_active: bool,
}

impl TransformSchedulerCallback {
    /// Creates a scheduler for `transform_name`.
    ///
    /// The initial probability is a placeholder derived from the schedule
    /// alone (the epoch count is not known yet); `before_fit` in this file's
    /// `Callback` impl recomputes it for epoch 0.
    pub fn new(transform_name: &str, schedule: TransformSchedule) -> Self {
        let initial_p = match &schedule {
            TransformSchedule::Constant(p) => *p,
            TransformSchedule::LinearWarmup { .. } => 0.0,
            TransformSchedule::LinearCooldown { max_p, .. } => *max_p,
            TransformSchedule::CosineAnnealing { min_p, max_p } => (*min_p + *max_p) / 2.0,
            TransformSchedule::Step { schedule } => schedule.first().map(|(_, p)| *p).unwrap_or(0.5),
            TransformSchedule::DelayedStart { .. } => 0.0,
        };
        Self {
            transform_name: transform_name.to_string(),
            schedule,
            current_p: initial_p,
            is_active: true,
        }
    }

    /// Probability currently in effect.
    pub fn current_probability(&self) -> f32 {
        self.current_p
    }

    /// Name of the scheduled transform.
    pub fn transform_name(&self) -> &str {
        &self.transform_name
    }

    /// Probability prescribed by the schedule at `epoch` out of `n_epochs`.
    ///
    /// Epochs at or beyond `n_epochs` are treated as the end of the schedule:
    /// the cooldown bottoms out at 0 and the cosine stays at `min_p` rather
    /// than producing negative or rebounding values. A zero-length warmup or
    /// cooldown leaves the probability at `max_p`.
    fn compute_probability(&self, epoch: usize, n_epochs: usize) -> f32 {
        match &self.schedule {
            TransformSchedule::Constant(p) => *p,
            TransformSchedule::LinearWarmup { max_p, warmup_epochs } => {
                // `epoch >= warmup_epochs` also covers `warmup_epochs == 0`,
                // so the division below never sees a zero denominator.
                if epoch >= *warmup_epochs {
                    *max_p
                } else {
                    *max_p * (epoch as f32 / *warmup_epochs as f32)
                }
            }
            TransformSchedule::LinearCooldown { max_p, cooldown_epochs } => {
                let start_cooldown = n_epochs.saturating_sub(*cooldown_epochs);
                // No cooldown window at all: avoid the 0/0 below.
                if *cooldown_epochs == 0 || epoch < start_cooldown {
                    *max_p
                } else {
                    let progress = (epoch - start_cooldown) as f32 / *cooldown_epochs as f32;
                    // Clamp so epochs past the end cannot go negative.
                    *max_p * (1.0 - progress.min(1.0))
                }
            }
            TransformSchedule::CosineAnnealing { min_p, max_p } => {
                if n_epochs <= 1 {
                    (*min_p + *max_p) / 2.0
                } else {
                    // Clamp so the cosine does not rise again past the last epoch.
                    let progress = (epoch as f32 / (n_epochs - 1) as f32).min(1.0);
                    let cosine = (1.0 + (progress * std::f32::consts::PI).cos()) / 2.0;
                    *min_p + (*max_p - *min_p) * cosine
                }
            }
            TransformSchedule::Step { schedule } => {
                // Entries are scanned in order and the scan stops at the
                // first entry whose epoch is still in the future, so the
                // list is expected to be sorted by epoch.
                let mut current_p = schedule.first().map(|(_, p)| *p).unwrap_or(0.5);
                for &(step_epoch, p) in schedule {
                    if epoch >= step_epoch {
                        current_p = p;
                    } else {
                        break;
                    }
                }
                current_p
            }
            TransformSchedule::DelayedStart { p, start_epoch } => {
                if epoch >= *start_epoch {
                    *p
                } else {
                    0.0
                }
            }
        }
    }
}
impl Callback for TransformSchedulerCallback {
    /// Sets the probability for epoch 0 and logs the starting value.
    fn before_fit(&mut self, ctx: &mut CallbackContext) -> Result<()> {
        let starting_p = self.compute_probability(0, ctx.n_epochs);
        self.current_p = starting_p;
        tracing::info!(
            "TransformScheduler: {} starting with p={:.3}",
            self.transform_name,
            self.current_p
        );
        Ok(())
    }

    /// Recomputes the probability for the epoch about to start.
    fn before_epoch(&mut self, ctx: &mut CallbackContext) -> Result<()> {
        let epoch_p = self.compute_probability(ctx.epoch, ctx.n_epochs);
        self.current_p = epoch_p;
        tracing::debug!(
            "TransformScheduler: {} epoch {} p={:.3}",
            self.transform_name,
            ctx.epoch + 1,
            self.current_p
        );
        Ok(())
    }

    /// Logs the probability in effect when training finished.
    fn after_fit(&mut self, _ctx: &mut CallbackContext) -> Result<()> {
        tracing::info!(
            "TransformScheduler: {} finished with final p={:.3}",
            self.transform_name,
            self.current_p
        );
        Ok(())
    }

    fn name(&self) -> &str {
        "TransformSchedulerCallback"
    }
}
/// Strategy used to derive loss weights.
#[derive(Debug, Clone)]
pub enum WeightStrategy {
    /// Every weight is 1.
    Uniform,
    /// Weight `total / (n_classes * count)` per class.
    InverseFrequency {
        class_counts: Option<Vec<usize>>,
    },
    /// Weights based on the effective number `(1 - beta^n) / (1 - beta)`.
    EffectiveNumber {
        beta: f64,
        class_counts: Option<Vec<usize>>,
    },
    /// Caller-supplied weights.
    Custom(Vec<f32>),
    /// Weights ranked by loss.
    Curriculum {
        sample_losses: Vec<f32>,
        easy_first: bool,
    },
}

/// Computes and stores loss weights according to a [`WeightStrategy`].
pub struct WeightedPerSampleLossCallback {
    /// Active weighting strategy.
    strategy: WeightStrategy,
    /// Computed weights, normalised to mean 1 when their mean is positive.
    weights: Vec<f32>,
    /// Sample count passed to the last weight initialisation.
    total_samples: usize,
    /// Optional class count; stored only, not read by code visible in this file.
    num_classes: Option<usize>,
}

impl WeightedPerSampleLossCallback {
    /// Creates a callback with no weights computed yet.
    pub fn new(strategy: WeightStrategy) -> Self {
        Self {
            strategy,
            weights: Vec::new(),
            total_samples: 0,
            num_classes: None,
        }
    }

    /// Shorthand for the inverse-frequency strategy with known class counts.
    pub fn inverse_frequency(class_counts: Vec<usize>) -> Self {
        Self::new(WeightStrategy::InverseFrequency {
            class_counts: Some(class_counts),
        })
    }

    /// Shorthand for the effective-number strategy with known class counts.
    pub fn effective_number(beta: f64, class_counts: Vec<usize>) -> Self {
        Self::new(WeightStrategy::EffectiveNumber {
            beta,
            class_counts: Some(class_counts),
        })
    }

    /// Shorthand for caller-supplied weights.
    pub fn custom(weights: Vec<f32>) -> Self {
        Self::new(WeightStrategy::Custom(weights))
    }

    /// Records the number of classes.
    #[must_use]
    pub fn with_num_classes(mut self, num_classes: usize) -> Self {
        self.num_classes = Some(num_classes);
        self
    }

    /// All computed weights.
    pub fn weights(&self) -> &[f32] {
        &self.weights
    }

    /// Weight stored at `sample_idx`, or `1.0` when the index is out of range.
    pub fn get_weight(&self, sample_idx: usize) -> f32 {
        self.weights.get(sample_idx).copied().unwrap_or(1.0)
    }

    /// Weights for a batch of indices (see [`Self::get_weight`]).
    pub fn get_batch_weights(&self, indices: &[usize]) -> Vec<f32> {
        indices.iter().map(|&i| self.get_weight(i)).collect()
    }

    /// `total / (n_classes * count)` per class; empty classes get `1.0`.
    fn compute_inverse_frequency_weights(class_counts: &[usize]) -> Vec<f32> {
        let total: usize = class_counts.iter().sum();
        let n_classes = class_counts.len();
        class_counts
            .iter()
            .map(|&count| {
                if count > 0 {
                    total as f32 / (n_classes as f32 * count as f32)
                } else {
                    1.0
                }
            })
            .collect()
    }

    /// Weights derived from the effective number `(1 - beta^n) / (1 - beta)`
    /// of each class; empty classes use an effective number of `1.0`.
    ///
    /// For `beta == 1` the formula is `0 / 0`; its limit `n` is used instead,
    /// which makes the result equal to inverse-frequency weighting.
    fn compute_effective_number_weights(beta: f64, class_counts: &[usize]) -> Vec<f32> {
        let beta_is_one = (1.0 - beta).abs() < f64::EPSILON;
        let effective_nums: Vec<f64> = class_counts
            .iter()
            .map(|&n| {
                if n == 0 {
                    1.0
                } else if beta_is_one {
                    n as f64
                } else {
                    // Saturate instead of wrapping for counts above i32::MAX.
                    let exponent = i32::try_from(n).unwrap_or(i32::MAX);
                    (1.0 - beta.powi(exponent)) / (1.0 - beta)
                }
            })
            .collect();
        let total: f64 = effective_nums.iter().sum();
        let n_classes = class_counts.len();
        effective_nums
            .iter()
            .map(|&eff| (total / (n_classes as f64 * eff)) as f32)
            .collect()
    }

    /// Rank-based weights: the sample ranked first gets `1.0` and weights
    /// fall linearly towards `0.1` with rank.
    ///
    /// Ranking is by ascending loss when `easy_first`, descending otherwise.
    /// `f32::total_cmp` is used so NaN losses cannot panic the sort; a
    /// positive NaN orders above every other value.
    fn compute_curriculum_weights(sample_losses: &[f32], easy_first: bool) -> Vec<f32> {
        let n = sample_losses.len();
        let mut order: Vec<usize> = (0..n).collect();
        // Stable sort: ties keep their original relative order.
        order.sort_by(|&a, &b| {
            let ascending = sample_losses[a].total_cmp(&sample_losses[b]);
            if easy_first {
                ascending
            } else {
                ascending.reverse()
            }
        });
        let mut weights = vec![0.0; n];
        for (rank, &orig_idx) in order.iter().enumerate() {
            weights[orig_idx] = 1.0 - 0.9 * (rank as f32 / n as f32);
        }
        weights
    }

    /// Recomputes `weights` from the active strategy, then rescales them to
    /// mean 1 (skipped when the mean is not positive).
    ///
    /// `n_samples` is only used for uniform weights and for the fallbacks
    /// without class counts; the other strategies size the vector from their
    /// own data.
    fn initialize_weights(&mut self, n_samples: usize) {
        self.total_samples = n_samples;
        self.weights = match &self.strategy {
            WeightStrategy::Uniform => vec![1.0; n_samples],
            WeightStrategy::InverseFrequency { class_counts } => match class_counts {
                Some(counts) => Self::compute_inverse_frequency_weights(counts),
                None => vec![1.0; n_samples],
            },
            WeightStrategy::EffectiveNumber { beta, class_counts } => match class_counts {
                Some(counts) => Self::compute_effective_number_weights(*beta, counts),
                None => vec![1.0; n_samples],
            },
            WeightStrategy::Custom(w) => w.clone(),
            WeightStrategy::Curriculum { sample_losses, easy_first } => {
                Self::compute_curriculum_weights(sample_losses, *easy_first)
            }
        };
        if !self.weights.is_empty() {
            let mean: f32 = self.weights.iter().sum::<f32>() / self.weights.len() as f32;
            if mean > 0.0 {
                for w in &mut self.weights {
                    *w /= mean;
                }
            }
        }
    }

    /// Switches to the curriculum strategy with fresh per-sample losses and
    /// recomputes the weights.
    pub fn update_curriculum_weights(&mut self, sample_losses: Vec<f32>, easy_first: bool) {
        // Take the length first so the vector can be moved, not cloned.
        let n_samples = sample_losses.len();
        self.strategy = WeightStrategy::Curriculum {
            sample_losses,
            easy_first,
        };
        self.initialize_weights(n_samples);
    }
}
impl Callback for WeightedPerSampleLossCallback {
    /// Computes the weights when none exist yet and logs how many there are.
    ///
    /// The sample count passed to the initialisation is estimated as
    /// `n_batches * 32`.
    // NOTE(review): the factor 32 looks like an assumed batch size — confirm.
    fn before_fit(&mut self, ctx: &mut CallbackContext) -> Result<()> {
        let approx_samples = ctx.n_batches * 32;
        let needs_init = self.weights.is_empty();
        if needs_init {
            self.initialize_weights(approx_samples);
        }
        tracing::info!(
            "WeightedPerSampleLoss: initialized with {} weights",
            self.weights.len()
        );
        Ok(())
    }

    fn name(&self) -> &str {
        "WeightedPerSampleLossCallback"
    }
}
/// Rule used to pick which samples of a batch are kept.
#[derive(Debug, Clone)]
pub enum SubsampleStrategy {
    /// Keep a random `fraction` of the batch.
    Random {
        fraction: f32,
    },
    /// Keep the `fraction` of samples with the highest loss.
    HardExamples {
        fraction: f32,
    },
    /// Keep the lowest-loss samples; the kept fraction moves from
    /// `start_fraction` to `end_fraction` as training progresses.
    Curriculum {
        start_fraction: f32,
        end_fraction: f32,
    },
    /// Currently selects exactly like `Random` (no label information is
    /// available to this callback).
    Stratified {
        fraction: f32,
    },
}

impl Default for SubsampleStrategy {
    /// Random subsampling that keeps half of each batch.
    fn default() -> Self {
        Self::Random { fraction: 0.5 }
    }
}

/// Selects a subset of every batch according to a [`SubsampleStrategy`].
pub struct BatchSubsamplerCallback {
    /// Active selection strategy.
    strategy: SubsampleStrategy,
    /// Fraction of each batch currently kept.
    current_fraction: f32,
    /// Total number of indices returned by `get_subsample_indices`.
    samples_kept: usize,
    /// Total batch size seen by `get_subsample_indices`.
    samples_total: usize,
    /// Last losses passed to `record_batch_losses`.
    batch_losses: Vec<f32>,
    /// Seed for the next random selection; advanced after every draw.
    seed: u64,
}

impl BatchSubsamplerCallback {
    /// Creates a subsampler starting at the strategy's initial fraction, with
    /// seed 42.
    pub fn new(strategy: SubsampleStrategy) -> Self {
        let initial_fraction = match &strategy {
            SubsampleStrategy::Random { fraction }
            | SubsampleStrategy::HardExamples { fraction }
            | SubsampleStrategy::Stratified { fraction } => *fraction,
            SubsampleStrategy::Curriculum { start_fraction, .. } => *start_fraction,
        };
        Self {
            strategy,
            current_fraction: initial_fraction,
            samples_kept: 0,
            samples_total: 0,
            batch_losses: Vec::new(),
            seed: 42,
        }
    }

    /// Shorthand for [`SubsampleStrategy::Random`].
    pub fn random(fraction: f32) -> Self {
        Self::new(SubsampleStrategy::Random { fraction })
    }

    /// Shorthand for [`SubsampleStrategy::HardExamples`].
    pub fn hard_examples(fraction: f32) -> Self {
        Self::new(SubsampleStrategy::HardExamples { fraction })
    }

    /// Shorthand for [`SubsampleStrategy::Curriculum`].
    pub fn curriculum(start_fraction: f32, end_fraction: f32) -> Self {
        Self::new(SubsampleStrategy::Curriculum {
            start_fraction,
            end_fraction,
        })
    }

    /// Overrides the random seed.
    #[must_use]
    pub fn with_seed(mut self, seed: u64) -> Self {
        self.seed = seed;
        self
    }

    /// Fraction of each batch currently kept.
    pub fn current_fraction(&self) -> f32 {
        self.current_fraction
    }

    /// Returns the sorted positions (within the batch) of the samples to keep.
    ///
    /// The number kept is `round(batch_size * current_fraction)`, clamped to
    /// `1..=batch_size`; an empty batch yields an empty result. Loss-based
    /// strategies rank by `losses` and fall back to the first positions when
    /// no losses are supplied. Every call also adds to the kept/total
    /// counters reported at the end of training.
    pub fn get_subsample_indices(&mut self, batch_size: usize, losses: Option<&[f32]>) -> Vec<usize> {
        if batch_size == 0 {
            return Vec::new();
        }
        // Clamp so a fraction above 1 can never produce out-of-batch indices
        // in the no-loss fallback.
        let keep_count =
            ((batch_size as f32 * self.current_fraction).round() as usize).clamp(1, batch_size);

        // `Some(true)` = rank hardest first, `Some(false)` = easiest first,
        // `None` = random selection.
        let hardest_first = match &self.strategy {
            SubsampleStrategy::Random { .. } | SubsampleStrategy::Stratified { .. } => None,
            SubsampleStrategy::HardExamples { .. } => Some(true),
            SubsampleStrategy::Curriculum { .. } => Some(false),
        };
        let indices = match (hardest_first, losses) {
            (None, _) => self.random_indices(batch_size, keep_count),
            (Some(descending), Some(loss_vals)) => {
                Self::indices_ranked_by_loss(loss_vals, keep_count, descending)
            }
            (Some(_), None) => (0..keep_count).collect(),
        };

        self.samples_total += batch_size;
        self.samples_kept += indices.len();
        indices
    }

    /// Draws `keep_count` distinct positions out of `0..batch_size` with a
    /// ChaCha8 generator seeded from `self.seed`, then advances the seed so
    /// the next batch gets a different draw.
    fn random_indices(&mut self, batch_size: usize, keep_count: usize) -> Vec<usize> {
        use rand::prelude::*;
        use rand_chacha::ChaCha8Rng;

        let mut rng = ChaCha8Rng::seed_from_u64(self.seed);
        self.seed = self.seed.wrapping_add(1);
        let mut indices: Vec<usize> = (0..batch_size).collect();
        indices.shuffle(&mut rng);
        indices.truncate(keep_count);
        indices.sort_unstable();
        indices
    }

    /// Positions of the `keep_count` highest (or lowest) losses, sorted
    /// ascending by position.
    ///
    /// Uses `f32::total_cmp`, a total order, so NaN losses can neither panic
    /// nor corrupt the sort.
    fn indices_ranked_by_loss(losses: &[f32], keep_count: usize, hardest_first: bool) -> Vec<usize> {
        let mut order: Vec<usize> = (0..losses.len()).collect();
        order.sort_by(|&a, &b| {
            let ascending = losses[a].total_cmp(&losses[b]);
            if hardest_first {
                ascending.reverse()
            } else {
                ascending
            }
        });
        order.truncate(keep_count);
        order.sort_unstable();
        order
    }

    /// Stores the latest per-sample losses.
    pub fn record_batch_losses(&mut self, losses: Vec<f32>) {
        self.batch_losses = losses;
    }

    /// For the curriculum strategy, interpolates the kept fraction linearly
    /// between its start and end values at `progress`; no-op otherwise.
    fn update_curriculum(&mut self, progress: f32) {
        if let SubsampleStrategy::Curriculum {
            start_fraction,
            end_fraction,
        } = &self.strategy
        {
            self.current_fraction = start_fraction + (end_fraction - start_fraction) * progress;
        }
    }
}
impl Callback for BatchSubsamplerCallback {
    /// Clears the counters, restores the strategy's initial fraction and logs it.
    fn before_fit(&mut self, _ctx: &mut CallbackContext) -> Result<()> {
        self.samples_total = 0;
        self.samples_kept = 0;
        let starting_fraction = match &self.strategy {
            SubsampleStrategy::Curriculum { start_fraction, .. } => *start_fraction,
            SubsampleStrategy::Random { fraction }
            | SubsampleStrategy::HardExamples { fraction }
            | SubsampleStrategy::Stratified { fraction } => *fraction,
        };
        self.current_fraction = starting_fraction;
        tracing::info!(
            "BatchSubsampler: starting with {:.1}% subsampling",
            self.current_fraction * 100.0
        );
        Ok(())
    }

    /// Feeds the epoch progress (`epoch / max(n_epochs, 1)`) into the
    /// curriculum update.
    fn before_epoch(&mut self, ctx: &mut CallbackContext) -> Result<()> {
        let total_epochs = ctx.n_epochs.max(1) as f32;
        self.update_curriculum(ctx.epoch as f32 / total_epochs);
        Ok(())
    }

    /// Logs the overall keep rate when any samples were counted.
    fn after_fit(&mut self, _ctx: &mut CallbackContext) -> Result<()> {
        if self.samples_total == 0 {
            return Ok(());
        }
        let keep_rate = self.samples_kept as f32 / self.samples_total as f32 * 100.0;
        tracing::info!(
            "BatchSubsampler: kept {:.1}% of samples ({}/{})",
            keep_rate,
            self.samples_kept,
            self.samples_total
        );
        Ok(())
    }

    fn name(&self) -> &str {
        "BatchSubsamplerCallback"
    }
}
/// Which predictions are tracked; the default is `All`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum PredictionTrackingMode {
    /// Track every sample.
    #[default]
    All,
    /// Track only changes.
    ChangesOnly,
    /// Track a subset of the given size.
    Subset(usize),
}
/// Per-epoch prediction record of a single sample.
#[derive(Debug, Clone, Default)]
pub struct SamplePredictionHistory {
    /// Index of the sample.
    pub sample_idx: usize,
    /// Ground-truth label of the sample.
    pub true_label: i64,
    /// Predicted label at each recorded epoch.
    pub predictions: Vec<i64>,
    /// Confidence at each recorded epoch.
    pub confidences: Vec<f32>,
    /// Loss at each recorded epoch.
    pub losses: Vec<f32>,
}

impl SamplePredictionHistory {
    /// Creates an empty history for the given sample and label.
    pub fn new(sample_idx: usize, true_label: i64) -> Self {
        Self {
            sample_idx,
            true_label,
            ..Self::default()
        }
    }

    /// Appends one epoch's prediction, confidence and loss.
    pub fn add_epoch(&mut self, prediction: i64, confidence: f32, loss: f32) {
        self.losses.push(loss);
        self.confidences.push(confidence);
        self.predictions.push(prediction);
    }

    /// Whether any recorded prediction matched the true label.
    pub fn ever_correct(&self) -> bool {
        self.predictions.contains(&self.true_label)
    }

    /// Whether at least one prediction exists and all of them are correct.
    pub fn always_correct(&self) -> bool {
        if self.predictions.is_empty() {
            return false;
        }
        !self.predictions.iter().any(|&p| p != self.true_label)
    }

    /// Number of epochs whose prediction differs from the previous epoch's.
    pub fn flip_count(&self) -> usize {
        // `windows(2)` yields nothing for fewer than two predictions.
        self.predictions
            .windows(2)
            .filter(|pair| pair[0] != pair[1])
            .count()
    }

    /// Position of the first correct prediction, if any.
    pub fn first_correct_epoch(&self) -> Option<usize> {
        let label = self.true_label;
        self.predictions.iter().position(|&p| p == label)
    }

    /// Position of the last incorrect prediction, if any.
    pub fn last_incorrect_epoch(&self) -> Option<usize> {
        let label = self.true_label;
        self.predictions.iter().rposition(|&p| p != label)
    }

    /// Whether an incorrect prediction occurs after a correct one.
    pub fn has_regression(&self) -> bool {
        match self.first_correct_epoch() {
            Some(first) => self.predictions[first..]
                .iter()
                .any(|&p| p != self.true_label),
            None => false,
        }
    }

    /// `1 - flips / transitions`; `1.0` when there are fewer than two predictions.
    pub fn stability(&self) -> f32 {
        let transitions = self.predictions.len().saturating_sub(1);
        if transitions == 0 {
            return 1.0;
        }
        1.0 - (self.flip_count() as f32 / transitions as f32)
    }
}
/// Tracks how each sample's prediction evolves across epochs.
pub struct PredictionDynamicsCallback {
    // Stored on construction; not read by the code visible in this file.
    #[allow(dead_code)]
    mode: PredictionTrackingMode,
    /// Accumulated history per sample index.
    histories: HashMap<usize, SamplePredictionHistory>,
    /// Known ground-truth label per sample index.
    true_labels: HashMap<usize, i64>,
    /// Records buffered for the current epoch:
    /// `(sample index, prediction, confidence, loss)`.
    current_predictions: Vec<(usize, i64, f32, f32)>,
    /// Number of epochs committed so far.
    epochs_tracked: usize,
    /// Whether per-epoch logging is enabled.
    log_per_epoch: bool,
}
impl PredictionDynamicsCallback {
    /// Builds an empty tracker using `mode`, with per-epoch logging disabled.
    pub fn new(mode: PredictionTrackingMode) -> Self {
        Self {
            mode,
            histories: HashMap::new(),
            true_labels: HashMap::new(),
            current_predictions: Vec::new(),
            epochs_tracked: 0,
            log_per_epoch: false,
        }
    }

    /// Builder-style switch for the per-epoch log line.
    #[must_use]
    pub fn with_logging(self, enabled: bool) -> Self {
        Self {
            log_per_epoch: enabled,
            ..self
        }
    }

    /// Replaces the table of known true labels (sample index -> label).
    pub fn set_true_labels(&mut self, labels: HashMap<usize, i64>) {
        self.true_labels = labels;
    }

    /// Stages predictions for the current epoch.
    ///
    /// One entry is staged per element of `sample_indices`. If `predictions`,
    /// `confidences` or `losses` is shorter than `sample_indices`, the missing
    /// values are filled with `0`, `0.0` and `0.0` respectively.
    pub fn record_predictions(
        &mut self,
        sample_indices: &[usize],
        predictions: &[i64],
        confidences: &[f32],
        losses: &[f32],
    ) {
        let staged = sample_indices.iter().enumerate().map(|(pos, &sample)| {
            (
                sample,
                predictions.get(pos).copied().unwrap_or(0),
                confidences.get(pos).copied().unwrap_or(0.0),
                losses.get(pos).copied().unwrap_or(0.0),
            )
        });
        self.current_predictions.extend(staged);
    }

    /// Moves every staged prediction into its sample's history and bumps the
    /// epoch counter.
    ///
    /// `-1` is used as the "label unknown" placeholder: a history created
    /// without a known label gets it filled in once one becomes available.
    fn commit_epoch(&mut self) {
        for (sample, guess, confidence, loss) in self.current_predictions.drain(..) {
            let known_label = match self.true_labels.get(&sample) {
                Some(&label) => label,
                None => -1,
            };
            let record = self
                .histories
                .entry(sample)
                .or_insert_with(|| SamplePredictionHistory::new(sample, known_label));
            if known_label != -1 && record.true_label == -1 {
                record.true_label = known_label;
            }
            record.add_epoch(guess, confidence, loss);
        }
        self.epochs_tracked += 1;
    }

    /// Collects the sample indices of all histories accepted by `keep`.
    fn indices_where<F>(&self, keep: F) -> Vec<usize>
    where
        F: Fn(&SamplePredictionHistory) -> bool,
    {
        self.histories
            .values()
            .filter(|history| keep(history))
            .map(|history| history.sample_idx)
            .collect()
    }

    /// Samples with at least one recorded epoch and no correct prediction.
    pub fn get_never_correct_samples(&self) -> Vec<usize> {
        self.indices_where(|history| {
            !(history.predictions.is_empty() || history.ever_correct())
        })
    }

    /// Samples whose stability is strictly below `stability_threshold`,
    /// paired with that stability value.
    pub fn get_unstable_samples(&self, stability_threshold: f32) -> Vec<(usize, f32)> {
        let mut flagged = Vec::new();
        for history in self.histories.values() {
            let score = history.stability();
            if score < stability_threshold {
                flagged.push((history.sample_idx, score));
            }
        }
        flagged
    }

    /// Samples that were predicted correctly and later predicted incorrectly.
    pub fn get_regression_samples(&self) -> Vec<usize> {
        self.indices_where(|history| history.has_regression())
    }

    /// Samples that were predicted correctly in every recorded epoch.
    pub fn get_always_correct_samples(&self) -> Vec<usize> {
        self.indices_where(|history| history.always_correct())
    }

    /// Looks up the history recorded for `sample_idx`, if any.
    pub fn get_sample_history(&self, sample_idx: usize) -> Option<&SamplePredictionHistory> {
        self.histories.get(&sample_idx)
    }

    /// Mean recorded confidence per sample; `0.0` for a sample with no
    /// recorded confidences.
    pub fn get_average_confidences(&self) -> HashMap<usize, f32> {
        let mut averages = HashMap::with_capacity(self.histories.len());
        for (&sample, history) in &self.histories {
            let observed = &history.confidences;
            let mean = if observed.is_empty() {
                0.0
            } else {
                observed.iter().sum::<f32>() / observed.len() as f32
            };
            averages.insert(sample, mean);
        }
        averages
    }

    /// Heuristic difficulty per sample: the mean of three terms.
    ///
    /// * first-correct epoch divided by the number of tracked epochs (at
    ///   least 1), or `1.0` if the sample was never predicted correctly;
    /// * `1.0 - stability`;
    /// * `0.0` if the latest prediction is correct, otherwise `0.5`.
    pub fn get_difficulty_scores(&self) -> HashMap<usize, f32> {
        let horizon = (self.epochs_tracked as f32).max(1.0);
        let mut scores = HashMap::with_capacity(self.histories.len());
        for (&sample, history) in &self.histories {
            let first_correct_score = history
                .first_correct_epoch()
                .map_or(1.0, |epoch| epoch as f32 / horizon);
            let stability_score = 1.0 - history.stability();
            let final_correct = match history.predictions.last() {
                Some(&latest) if latest == history.true_label => 0.0,
                _ => 0.5,
            };
            let score = (first_correct_score + stability_score + final_correct) / 3.0;
            scores.insert(sample, score);
        }
        scores
    }

    /// Aggregates the tracked histories into a [`PredictionDynamicsSummary`].
    ///
    /// "Unstable" here means a stability below `0.7`.
    pub fn summary(&self) -> PredictionDynamicsSummary {
        let total_samples = self.histories.len();
        let average_stability = match total_samples {
            0 => 0.0,
            count => {
                self.histories.values().map(|h| h.stability()).sum::<f32>() / count as f32
            }
        };
        PredictionDynamicsSummary {
            total_samples,
            never_correct: self.get_never_correct_samples().len(),
            always_correct: self.get_always_correct_samples().len(),
            regression_count: self.get_regression_samples().len(),
            unstable_count: self.get_unstable_samples(0.7).len(),
            average_stability,
            epochs_tracked: self.epochs_tracked,
        }
    }
}
/// Aggregate counts describing how predictions evolved across the tracked
/// samples.
#[derive(Debug, Clone)]
pub struct PredictionDynamicsSummary {
    /// Number of samples with a recorded history.
    pub total_samples: usize,
    /// Samples that were never predicted correctly.
    pub never_correct: usize,
    /// Samples that were predicted correctly in every epoch.
    pub always_correct: usize,
    /// Samples that went from correct to incorrect at some point.
    pub regression_count: usize,
    /// Samples counted as unstable.
    pub unstable_count: usize,
    /// Mean stability over all samples.
    pub average_stability: f32,
    /// Number of epochs committed.
    pub epochs_tracked: usize,
}

impl std::fmt::Display for PredictionDynamicsSummary {
    /// Renders the summary as a multi-line report, one statistic per line,
    /// with counts also shown as a percentage of `total_samples`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Use at least 1 as the denominator so an empty summary prints 0.0%.
        let denominator = self.total_samples.max(1) as f32;
        let share = |count: usize| 100.0 * count as f32 / denominator;

        writeln!(f, "=== Prediction Dynamics Summary ===")?;
        writeln!(f, "Epochs tracked: {}", self.epochs_tracked)?;
        writeln!(f, "Total samples: {}", self.total_samples)?;
        writeln!(
            f,
            "Always correct: {} ({:.1}%)",
            self.always_correct,
            share(self.always_correct)
        )?;
        writeln!(
            f,
            "Never correct: {} ({:.1}%)",
            self.never_correct,
            share(self.never_correct)
        )?;
        writeln!(
            f,
            "Regressions: {} ({:.1}%)",
            self.regression_count,
            share(self.regression_count)
        )?;
        writeln!(
            f,
            "Unstable: {} ({:.1}%)",
            self.unstable_count,
            share(self.unstable_count)
        )?;
        writeln!(f, "Average stability: {:.3}", self.average_stability)
    }
}
impl Callback for PredictionDynamicsCallback {
    /// Resets histories, staged predictions and the epoch counter.
    ///
    /// The true-label table is left untouched.
    fn before_fit(&mut self, _ctx: &mut CallbackContext) -> Result<()> {
        self.epochs_tracked = 0;
        self.current_predictions.clear();
        self.histories.clear();
        tracing::info!("PredictionDynamics: tracking enabled");
        Ok(())
    }

    /// Commits the staged predictions and, if per-epoch logging is enabled
    /// and at least one sample is tracked, logs a one-line summary.
    fn after_epoch(&mut self, ctx: &mut CallbackContext) -> Result<()> {
        self.commit_epoch();
        if !self.log_per_epoch || self.histories.is_empty() {
            return Ok(());
        }

        let summary = self.summary();
        let denominator = summary.total_samples.max(1) as f32;
        let always_pct = 100.0 * summary.always_correct as f32 / denominator;
        let never_pct = 100.0 * summary.never_correct as f32 / denominator;
        tracing::info!(
            "Epoch {}: {} samples tracked, {:.1}% always correct, {:.1}% never correct",
            ctx.epoch + 1,
            summary.total_samples,
            always_pct,
            never_pct
        );
        Ok(())
    }

    /// Logs the full summary once training has finished.
    fn after_fit(&mut self, _ctx: &mut CallbackContext) -> Result<()> {
        tracing::info!("\n{}", self.summary());
        Ok(())
    }

    /// Fixed, human-readable callback name.
    fn name(&self) -> &str {
        "PredictionDynamicsCallback"
    }
}
/// Rule used to decide which candidate pseudo-labels are kept.
///
/// The filter is applied after candidates below the callback's minimum
/// confidence have already been discarded.
#[derive(Debug, Clone)]
pub enum PseudoLabelFilter {
    /// Keep every remaining candidate.
    All,
    /// Keep candidates whose confidence is at least the given threshold.
    ConfidenceThreshold(f32),
    /// Keep the most confident candidates.
    TopK {
        /// Share of the candidates to keep; the resulting count is rounded
        /// and never drops below one.
        fraction: f32,
    },
    /// Keep the most confident candidates separately for each predicted label.
    ClassBalancedTopK {
        /// Maximum number of candidates kept per label.
        samples_per_class: usize,
    },
}

impl Default for PseudoLabelFilter {
    /// Defaults to a confidence threshold of `0.9`.
    fn default() -> Self {
        PseudoLabelFilter::ConfidenceThreshold(0.9)
    }
}
/// Noise configuration exposed by `NoisyStudentCallback::noise_config`.
///
/// NOTE(review): nothing visible in this file interprets these variants, so
/// the descriptions below are inferred from the names — confirm against the
/// code that consumes the configuration.
#[derive(Debug, Clone)]
pub enum NoiseInjection {
    /// No noise.
    None,
    /// Presumably dropout applied to the inputs; payload is likely a rate.
    InputDropout(f32),
    /// Presumably stochastic depth; payload is likely a drop probability.
    StochasticDepth(f32),
    /// Presumably RandAugment-style augmentation.
    RandAugment {
        /// Likely the number of augmentation operations — TODO confirm.
        n_ops: usize,
        /// Likely the augmentation strength — TODO confirm.
        magnitude: f32,
    },
    /// Presumably additive Gaussian noise; payload is likely a scale.
    GaussianNoise(f32),
    /// Several noise configurations grouped together.
    Combined(Vec<NoiseInjection>),
}

impl Default for NoiseInjection {
    /// Defaults to `InputDropout(0.5)`.
    fn default() -> Self {
        NoiseInjection::InputDropout(0.5)
    }
}
/// A label assigned to a sample from a model prediction rather than from
/// ground truth.
#[derive(Debug, Clone)]
pub struct PseudoLabel {
/// Index of the sample the label applies to.
pub sample_idx: usize,
/// The predicted label.
pub label: i64,
/// Confidence attached to the prediction.
pub confidence: f32,
/// Value of the callback's iteration counter when the label was generated.
pub teacher_iteration: usize,
}
/// Callback holding the state of an iterative pseudo-labelling
/// ("noisy student") training loop: the current pseudo-labels, the iteration
/// counter, the noise configuration and running statistics.
pub struct NoisyStudentCallback {
/// Rule deciding which candidate pseudo-labels are kept.
filter: PseudoLabelFilter,
/// Noise configuration exposed through `noise_config`.
noise: NoiseInjection,
/// Number of iterations after which `is_complete` reports `true`.
n_iterations: usize,
/// Zero-based iteration counter, advanced by `next_iteration`.
current_iteration: usize,
/// Pseudo-labels produced by the latest `generate_pseudo_labels` call.
pseudo_labels: Vec<PseudoLabel>,
/// Labeled-sample count set via `set_data_counts`.
n_labeled: usize,
/// Unlabeled-sample count set via `set_data_counts`; not read by the
/// methods visible here.
n_unlabeled: usize,
/// Becomes `true` on the first `next_iteration` call.
is_student_phase: bool,
/// Temperature set via `with_temperature` (kept at or above 0.1); not read
/// by the methods visible here.
temperature: f32,
/// Candidates with a confidence below this value are always discarded.
min_confidence: f32,
/// Statistics accumulated across iterations.
stats: NoisyStudentStats,
}
/// Statistics collected by `NoisyStudentCallback`.
#[derive(Debug, Clone, Default)]
pub struct NoisyStudentStats {
/// Number of pseudo-labels kept, one entry per `generate_pseudo_labels` call.
pub labels_per_iteration: Vec<usize>,
/// Mean confidence of the kept pseudo-labels, one entry per
/// `generate_pseudo_labels` call (`0.0` when none were kept).
pub avg_confidence_per_iteration: Vec<f32>,
/// Accuracy values appended via `record_validation_accuracy`.
pub validation_accuracy: Vec<f32>,
/// Count of kept pseudo-labels per label for the most recent generation;
/// cleared and rebuilt on every `generate_pseudo_labels` call.
pub class_distribution: HashMap<i64, usize>,
}
impl NoisyStudentCallback {
    /// Creates a callback with the default filter and noise, 3 iterations,
    /// temperature `1.0` and a minimum confidence of `0.5`.
    pub fn new() -> Self {
        Self {
            filter: PseudoLabelFilter::default(),
            noise: NoiseInjection::default(),
            n_iterations: 3,
            current_iteration: 0,
            pseudo_labels: Vec::new(),
            n_labeled: 0,
            n_unlabeled: 0,
            is_student_phase: false,
            temperature: 1.0,
            min_confidence: 0.5,
            stats: NoisyStudentStats::default(),
        }
    }

    /// Sets the rule used to select pseudo-labels.
    #[must_use]
    pub fn with_filter(mut self, filter: PseudoLabelFilter) -> Self {
        self.filter = filter;
        self
    }

    /// Sets the noise configuration reported by [`Self::noise_config`].
    #[must_use]
    pub fn with_noise(mut self, noise: NoiseInjection) -> Self {
        self.noise = noise;
        self
    }

    /// Sets the number of iterations after which [`Self::is_complete`]
    /// returns `true`.
    #[must_use]
    pub fn with_iterations(mut self, n: usize) -> Self {
        self.n_iterations = n;
        self
    }

    /// Sets the temperature; values below `0.1` are raised to `0.1`.
    #[must_use]
    pub fn with_temperature(mut self, t: f32) -> Self {
        self.temperature = t.max(0.1);
        self
    }

    /// Sets the minimum confidence, clamped to `[0.0, 1.0]`.
    #[must_use]
    pub fn with_min_confidence(mut self, conf: f32) -> Self {
        self.min_confidence = conf.clamp(0.0, 1.0);
        self
    }

    /// Records how many labeled and unlabeled samples are available.
    pub fn set_data_counts(&mut self, n_labeled: usize, n_unlabeled: usize) {
        self.n_labeled = n_labeled;
        self.n_unlabeled = n_unlabeled;
    }

    /// Orders pseudo-labels from most to least confident.
    ///
    /// Incomparable values are treated as equal instead of panicking. NaN
    /// confidences never reach the sort because they fail the minimum
    /// confidence check, so this only removes a panic path.
    fn by_confidence_desc(a: &PseudoLabel, b: &PseudoLabel) -> std::cmp::Ordering {
        b.confidence
            .partial_cmp(&a.confidence)
            .unwrap_or(std::cmp::Ordering::Equal)
    }

    /// Replaces the current pseudo-labels with a fresh selection.
    ///
    /// The three slices are read in parallel and truncated to the shortest.
    /// Candidates with a confidence below the minimum are dropped first; the
    /// configured [`PseudoLabelFilter`] is then applied. Per-iteration
    /// statistics are appended and the class distribution is rebuilt.
    ///
    /// The result order is deterministic: `All` and `ConfidenceThreshold`
    /// keep input order, `TopK` is sorted by descending confidence, and
    /// `ClassBalancedTopK` is grouped by ascending label with descending
    /// confidence inside each group.
    pub fn generate_pseudo_labels(
        &mut self,
        sample_indices: &[usize],
        predictions: &[i64],
        confidences: &[f32],
    ) {
        let teacher_iteration = self.current_iteration;
        let min_confidence = self.min_confidence;
        let mut candidates: Vec<PseudoLabel> = sample_indices
            .iter()
            .zip(predictions)
            .zip(confidences)
            .map(|((&sample_idx, &label), &confidence)| PseudoLabel {
                sample_idx,
                label,
                confidence,
                teacher_iteration,
            })
            .filter(|pl| pl.confidence >= min_confidence)
            .collect();

        self.pseudo_labels = match &self.filter {
            PseudoLabelFilter::All => candidates,
            PseudoLabelFilter::ConfidenceThreshold(thresh) => {
                candidates.retain(|pl| pl.confidence >= *thresh);
                candidates
            }
            PseudoLabelFilter::TopK { fraction } => {
                candidates.sort_by(Self::by_confidence_desc);
                // Always keep at least one candidate when any exist.
                let keep = ((candidates.len() as f32 * fraction).round() as usize).max(1);
                candidates.truncate(keep);
                candidates
            }
            PseudoLabelFilter::ClassBalancedTopK { samples_per_class } => {
                let mut by_class: HashMap<i64, Vec<PseudoLabel>> = HashMap::new();
                for pl in candidates {
                    by_class.entry(pl.label).or_default().push(pl);
                }
                // HashMap iteration order is arbitrary; visit the classes in
                // ascending label order so the output is reproducible.
                let mut classes: Vec<(i64, Vec<PseudoLabel>)> = by_class.into_iter().collect();
                classes.sort_unstable_by_key(|(label, _)| *label);

                let mut result = Vec::new();
                for (_, mut class_samples) in classes {
                    class_samples.sort_by(Self::by_confidence_desc);
                    class_samples.truncate(*samples_per_class);
                    result.extend(class_samples);
                }
                result
            }
        };

        self.stats.labels_per_iteration.push(self.pseudo_labels.len());
        let avg_conf = if self.pseudo_labels.is_empty() {
            0.0
        } else {
            self.pseudo_labels.iter().map(|pl| pl.confidence).sum::<f32>()
                / self.pseudo_labels.len() as f32
        };
        self.stats.avg_confidence_per_iteration.push(avg_conf);

        self.stats.class_distribution.clear();
        for pl in &self.pseudo_labels {
            *self.stats.class_distribution.entry(pl.label).or_default() += 1;
        }

        tracing::info!(
            "NoisyStudent: generated {} pseudo-labels (avg conf: {:.3})",
            self.pseudo_labels.len(),
            avg_conf
        );
    }

    /// Pseudo-labels from the most recent generation.
    pub fn get_pseudo_labels(&self) -> &[PseudoLabel] {
        &self.pseudo_labels
    }

    /// Finds the pseudo-label for `sample_idx`, if one was kept.
    ///
    /// This is a linear scan over the current pseudo-labels.
    pub fn get_sample_label(&self, sample_idx: usize) -> Option<&PseudoLabel> {
        self.pseudo_labels
            .iter()
            .find(|pl| pl.sample_idx == sample_idx)
    }

    /// `true` once the student phase has started.
    pub fn should_apply_noise(&self) -> bool {
        self.is_student_phase
    }

    /// The configured noise settings.
    pub fn noise_config(&self) -> &NoiseInjection {
        &self.noise
    }

    /// Advances the iteration counter and switches to the student phase.
    pub fn next_iteration(&mut self) {
        self.current_iteration += 1;
        self.is_student_phase = true;
        tracing::info!(
            "NoisyStudent: starting iteration {} of {}",
            self.current_iteration + 1,
            self.n_iterations
        );
    }

    /// `true` when the iteration counter has reached the configured number
    /// of iterations.
    pub fn is_complete(&self) -> bool {
        self.current_iteration >= self.n_iterations
    }

    /// Zero-based index of the current iteration.
    pub fn current_iteration(&self) -> usize {
        self.current_iteration
    }

    /// Statistics accumulated so far.
    pub fn stats(&self) -> &NoisyStudentStats {
        &self.stats
    }

    /// Appends a validation accuracy value to the statistics.
    pub fn record_validation_accuracy(&mut self, accuracy: f32) {
        self.stats.validation_accuracy.push(accuracy);
    }

    /// Training weight for a sample.
    ///
    /// Labeled samples weigh `1.0`. Unlabeled samples weigh their
    /// pseudo-label confidence, or `0.0` when they have no pseudo-label.
    pub fn get_sample_weight(&self, sample_idx: usize, is_labeled: bool) -> f32 {
        if is_labeled {
            return 1.0;
        }
        self.get_sample_label(sample_idx)
            .map_or(0.0, |pl| pl.confidence)
    }

    /// Returns the labeled indices (flagged `true`) followed by the
    /// pseudo-labeled sample indices (flagged `false`).
    pub fn get_combined_indices(&self, labeled_indices: &[usize]) -> Vec<(usize, bool)> {
        labeled_indices
            .iter()
            .map(|&idx| (idx, true))
            .chain(self.pseudo_labels.iter().map(|pl| (pl.sample_idx, false)))
            .collect()
    }

    /// Builds a multi-line, human-readable status report.
    ///
    /// The class distribution section is included only when it is non-empty
    /// and lists classes in ascending label order.
    pub fn summary(&self) -> String {
        use std::fmt::Write as _;

        let mut output = String::new();
        // Formatting into a `String` cannot fail, so the results are ignored.
        output.push_str("=== Noisy Student Summary ===\n");
        let _ = writeln!(
            output,
            "Iteration: {} / {}",
            self.current_iteration + 1,
            self.n_iterations
        );
        let _ = writeln!(output, "Labeled samples: {}", self.n_labeled);
        let _ = writeln!(output, "Pseudo-labeled: {}", self.pseudo_labels.len());
        let _ = writeln!(output, "Student phase: {}", self.is_student_phase);

        if !self.stats.class_distribution.is_empty() {
            output.push_str("\nClass distribution:\n");
            let mut sorted: Vec<_> = self.stats.class_distribution.iter().collect();
            sorted.sort_by_key(|(k, _)| *k);
            for (class, count) in sorted {
                let _ = writeln!(output, " Class {}: {}", class, count);
            }
        }
        output
    }
}
impl Default for NoisyStudentCallback {
fn default() -> Self {
Self::new()
}
}
impl Callback for NoisyStudentCallback {
    /// Logs which iteration is starting and, during the student phase, how
    /// many pseudo-labels are in use.
    fn before_fit(&mut self, ctx: &mut CallbackContext) -> Result<()> {
        let iteration_number = self.current_iteration + 1;
        tracing::info!(
            "NoisyStudent: starting iteration {} with {} epochs",
            iteration_number,
            ctx.n_epochs
        );
        if self.is_student_phase {
            let label_count = self.pseudo_labels.len();
            tracing::info!(
                "NoisyStudent: student phase with {} pseudo-labels",
                label_count
            );
        }
        Ok(())
    }

    /// Logs the final validation loss when present and records the
    /// `"accuracy"` metric when the context carries one.
    fn after_fit(&mut self, ctx: &mut CallbackContext) -> Result<()> {
        match ctx.valid_loss {
            Some(valid_loss) => {
                tracing::info!(
                    "NoisyStudent iteration {} complete: valid_loss = {:.4}",
                    self.current_iteration + 1,
                    valid_loss
                );
            }
            None => {}
        }
        let accuracy = ctx.metrics.get("accuracy").copied();
        if let Some(acc) = accuracy {
            self.record_validation_accuracy(acc);
        }
        Ok(())
    }

    /// No per-batch work is performed.
    fn before_batch(&mut self, _ctx: &mut CallbackContext) -> Result<()> {
        Ok(())
    }

    /// Fixed, human-readable callback name.
    fn name(&self) -> &str {
        "NoisyStudentCallback"
    }
}
// Unit tests for the callback types defined in this file.
#[cfg(test)]
mod tests {
use super::*;
// A freshly built context starts at epoch 0, keeps the requested epoch
// count and reports zero progress.
#[test]
fn test_callback_context() {
let ctx = CallbackContext::new(10, 100);
assert_eq!(ctx.epoch, 0);
assert_eq!(ctx.n_epochs, 10);
assert_eq!(ctx.progress(), 0.0);
}
// Smoke test: adding a callback to an empty list must not panic.
// Nothing is asserted beyond that.
#[test]
fn test_callback_list() {
let mut list = CallbackList::new();
list.add(ProgressCallback::new(false));
}
// Both constructors must report the value they were given.
#[test]
fn test_gradient_clip_callback() {
let clip = GradientClipCallback::by_norm(1.0);
assert_eq!(clip.clip_value(), 1.0);
let clip = GradientClipCallback::by_value(0.5);
assert_eq!(clip.clip_value(), 0.5);
}
// After one `after_epoch` call the history holds exactly the losses and
// learning rate that were present on the context.
#[test]
fn test_history_callback() {
let mut history = HistoryCallback::new();
let mut ctx = CallbackContext::new(10, 100);
ctx.train_loss = Some(0.5);
ctx.valid_loss = Some(0.4);
ctx.lr = 0.001;
history.after_epoch(&mut ctx).unwrap();
assert_eq!(history.train_losses(), &[0.5]);
assert_eq!(history.valid_losses(), &[0.4]);
assert_eq!(history.learning_rates(), &[0.001]);
}
// The builder methods must store width, height and metric names as given.
#[test]
fn test_show_graph_callback_config() {
let callback = ShowGraphCallback::new()
.with_width(60)
.with_height(15)
.with_metrics(vec!["accuracy", "f1"]);
assert_eq!(callback.width, 60);
assert_eq!(callback.height, 15);
assert_eq!(callback.metric_names, vec!["accuracy", "f1"]);
}
// Five epochs with both losses set must yield five recorded entries for
// each loss series.
#[test]
fn test_show_graph_callback_render() {
let mut callback = ShowGraphCallback::new()
.with_width(30)
.with_height(5)
.show_per_epoch(false);
let mut ctx = CallbackContext::new(10, 100);
for i in 0..5 {
ctx.train_loss = Some(1.0 - i as f32 * 0.1);
ctx.valid_loss = Some(0.9 - i as f32 * 0.08);
callback.after_epoch(&mut ctx).unwrap();
}
assert_eq!(callback.train_losses.len(), 5);
assert_eq!(callback.valid_losses.len(), 5);
}
// A constant schedule reports its probability immediately after
// construction.
#[test]
fn test_transform_scheduler_constant() {
let callback = TransformSchedulerCallback::new(
"TestTransform",
TransformSchedule::Constant(0.8),
);
assert_eq!(callback.current_probability(), 0.8);
}
// With a 5-epoch warmup to 1.0, the probability is 0.0 after `before_fit`,
// about 0.4 at epoch 2 and 1.0 at epoch 5.
#[test]
fn test_transform_scheduler_linear_warmup() {
let mut callback = TransformSchedulerCallback::new(
"TestTransform",
TransformSchedule::LinearWarmup {
max_p: 1.0,
warmup_epochs: 5,
},
);
let mut ctx = CallbackContext::new(10, 100);
callback.before_fit(&mut ctx).unwrap();
assert_eq!(callback.current_probability(), 0.0);
ctx.epoch = 2;
callback.before_epoch(&mut ctx).unwrap();
assert!((callback.current_probability() - 0.4).abs() < 0.01);
ctx.epoch = 5;
callback.before_epoch(&mut ctx).unwrap();
assert_eq!(callback.current_probability(), 1.0);
}
// A delayed start reports 0.0 before `start_epoch` and the configured
// probability once `start_epoch` is reached.
#[test]
fn test_transform_scheduler_delayed_start() {
let mut callback = TransformSchedulerCallback::new(
"TestTransform",
TransformSchedule::DelayedStart {
p: 0.7,
start_epoch: 5,
},
);
let mut ctx = CallbackContext::new(10, 100);
ctx.epoch = 3;
callback.before_epoch(&mut ctx).unwrap();
assert_eq!(callback.current_probability(), 0.0);
ctx.epoch = 5;
callback.before_epoch(&mut ctx).unwrap();
assert_eq!(callback.current_probability(), 0.7);
}
}