use std::time::Duration;
/// A named phase of a training run.
///
/// Variants are declared in ascending order of the percentage returned by
/// [`TrainingPhase::progress_pct`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TrainingPhase {
    /// Displayed as "Profiling"; nominal progress 10%.
    Profiling,
    /// Displayed as "Preprocessing"; nominal progress 20%.
    Preprocessing,
    /// Displayed as "Feature Engineering"; nominal progress 30%.
    FeatureEngineering,
    /// Displayed as "Dataset Preparation"; nominal progress 40%.
    DatasetPreparation,
    /// Displayed as "Analysis"; nominal progress 50%.
    Analysis,
    /// Displayed as "Tuning"; nominal progress 70%.
    Tuning,
    /// Displayed as "Training"; nominal progress 90%.
    Training,
    /// Displayed as "Complete"; nominal progress 100%.
    Complete,
}

impl TrainingPhase {
    /// Returns the human-readable label of this phase.
    ///
    /// Multi-word variants are rendered with a space between the words
    /// (for example `FeatureEngineering` becomes `"Feature Engineering"`).
    pub fn name(&self) -> &'static str {
        // No catch-all arm on purpose: adding a variant must fail to compile
        // here until it is given a label.
        match *self {
            TrainingPhase::Profiling => "Profiling",
            TrainingPhase::Preprocessing => "Preprocessing",
            TrainingPhase::FeatureEngineering => "Feature Engineering",
            TrainingPhase::DatasetPreparation => "Dataset Preparation",
            TrainingPhase::Analysis => "Analysis",
            TrainingPhase::Tuning => "Tuning",
            TrainingPhase::Training => "Training",
            TrainingPhase::Complete => "Complete",
        }
    }

    /// Returns the fixed percentage (10 to 100) associated with this phase.
    ///
    /// The steps are not uniform: the first five phases are 10 points apart,
    /// while `Tuning` and `Training` jump to 70 and 90.
    pub fn progress_pct(&self) -> u8 {
        match *self {
            TrainingPhase::Profiling => 10,
            TrainingPhase::Preprocessing => 20,
            TrainingPhase::FeatureEngineering => 30,
            TrainingPhase::DatasetPreparation => 40,
            TrainingPhase::Analysis => 50,
            TrainingPhase::Tuning => 70,
            TrainingPhase::Training => 90,
            TrainingPhase::Complete => 100,
        }
    }
}
/// A single progress snapshot passed by reference to
/// [`ProgressCallback::on_progress`].
#[derive(Debug, Clone)]
pub struct ProgressUpdate {
/// Phase this update refers to.
pub phase: TrainingPhase,
/// Percentage to report. This is an independent field: nothing in this type
/// ties it to `phase.progress_pct()`, and the `u8` range permits values above
/// 100. `ConsoleProgress` scales it against 100 to fill its bar.
pub progress_pct: u8,
/// Elapsed time to report. The reference start point is chosen by whoever
/// builds the update.
pub elapsed: Duration,
/// Optional free-form detail text. `ConsoleProgress` prints it only in
/// detailed mode.
pub message: Option<String>,
}
/// Receiver for [`ProgressUpdate`] notifications.
///
/// Implementors must be `Send + Sync`, and the hook takes `&self`, so an
/// implementation that keeps mutable state needs interior mutability.
pub trait ProgressCallback: Send + Sync {
/// Handles one progress snapshot. The update is only borrowed; clone it if it
/// must outlive the call.
fn on_progress(&self, update: &ProgressUpdate);
}
/// Progress reporter that prints one line per update to standard output.
///
/// The default value (and [`ConsoleProgress::new`]) is the compact mode with
/// `detailed` set to `false`; `Default` is derived, which yields exactly that.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct ConsoleProgress {
    /// When `true`, each printed line also carries the elapsed time and, if
    /// present, the update's message.
    pub detailed: bool,
}

impl ConsoleProgress {
    /// Creates a reporter in compact mode (`detailed == false`).
    #[must_use]
    pub fn new() -> Self {
        Self { detailed: false }
    }

    /// Creates a reporter in detailed mode (`detailed == true`).
    #[must_use]
    pub fn detailed() -> Self {
        Self { detailed: true }
    }
}
impl ProgressCallback for ConsoleProgress {
    /// Prints one progress line to standard output.
    ///
    /// Every line starts with the percentage (right-aligned to three columns),
    /// the phase name and a 30-cell bar. In detailed mode the `Debug`
    /// rendering of `update.elapsed` is appended, followed by ` - <message>`
    /// when the update carries a message. In compact mode both are omitted.
    ///
    /// The bar is filled in proportion to `update.progress_pct` and saturates
    /// at 100: larger values draw a full bar. The numeric label always shows
    /// the value exactly as given.
    fn on_progress(&self, update: &ProgressUpdate) {
        const BAR_WIDTH: usize = 30;

        // `progress_pct` is a plain `u8`, so values up to 255 can arrive here.
        // Without the clamp `filled` would exceed `BAR_WIDTH` and the
        // subtraction below would underflow: a panic in debug builds, and an
        // absurdly large `repeat` count when overflow checks are off.
        let pct = usize::from(update.progress_pct.min(100));
        let filled = pct * BAR_WIDTH / 100;
        let bar = format!("{}{}", "█".repeat(filled), "░".repeat(BAR_WIDTH - filled));

        match (self.detailed, update.message.as_deref()) {
            (true, Some(msg)) => println!(
                "[{:3}%] {} │{}│ {:?} - {}",
                update.progress_pct,
                update.phase.name(),
                bar,
                update.elapsed,
                msg
            ),
            (true, None) => println!(
                "[{:3}%] {} │{}│ {:?}",
                update.progress_pct,
                update.phase.name(),
                bar,
                update.elapsed
            ),
            // Compact mode ignores both the elapsed time and the message.
            (false, _) => println!(
                "[{:3}%] {} │{}│",
                update.progress_pct,
                update.phase.name(),
                bar
            ),
        }
    }
}
/// Progress callback that discards every update.
///
/// It can be passed wherever a [`ProgressCallback`] is required but no output
/// is wanted.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct QuietProgress;

impl ProgressCallback for QuietProgress {
    /// Does nothing; the update is ignored.
    fn on_progress(&self, _update: &ProgressUpdate) {}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `ProgressUpdate` fixture from its four field values.
    fn make_update(
        phase: TrainingPhase,
        progress_pct: u8,
        elapsed_secs: u64,
        message: Option<&str>,
    ) -> ProgressUpdate {
        ProgressUpdate {
            phase,
            progress_pct,
            elapsed: Duration::from_secs(elapsed_secs),
            message: message.map(String::from),
        }
    }

    /// The first and last phases map to 10% and 100%.
    #[test]
    fn test_phase_progress() {
        let first = TrainingPhase::Profiling;
        let last = TrainingPhase::Complete;
        assert_eq!(first.progress_pct(), 10);
        assert_eq!(last.progress_pct(), 100);
    }

    /// Single-word phases use their variant name as the label.
    #[test]
    fn test_phase_name() {
        let tuning = TrainingPhase::Tuning;
        let training = TrainingPhase::Training;
        assert_eq!(tuning.name(), "Tuning");
        assert_eq!(training.name(), "Training");
    }

    /// Smoke test: the compact console reporter handles an update without
    /// panicking (output is not captured).
    #[test]
    fn test_console_progress() {
        let update = make_update(
            TrainingPhase::Profiling,
            10,
            5,
            Some("Analyzing 50 columns"),
        );
        ConsoleProgress::new().on_progress(&update);
    }

    /// Smoke test: the quiet reporter accepts an update without panicking.
    #[test]
    fn test_quiet_progress() {
        let update = make_update(TrainingPhase::Training, 90, 120, None);
        QuietProgress.on_progress(&update);
    }
}