use crate::Float;
/// Tunable limits and reporting options for a [`TrainingLoop`].
///
/// Start from [`TrainingConfig::new`] (identical to `TrainingConfig::default()`)
/// and adjust individual settings with the chainable `with_*` methods. Each
/// builder consumes `self` and returns the updated value, so the result must be
/// kept (the methods are `#[must_use]` to catch accidentally dropped updates).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TrainingConfig {
    /// Upper bound on computation-graph size (default `50_000`).
    ///
    /// NOTE(review): not read anywhere in this module; presumably enforced by
    /// the caller — confirm.
    pub max_graph_nodes: usize,
    /// Whether the graph should be cleared automatically (default `true`).
    ///
    /// NOTE(review): not read anywhere in this module; confirm where it is honoured.
    pub auto_clear: bool,
    /// Graph size above which `TrainingLoop::record_graph_stats` emits a
    /// warning (default `25_000`). The comparison is strict (`>`).
    pub warn_threshold: usize,
    /// Enables the stderr warning emitted by `TrainingLoop::record_graph_stats`
    /// (default `true`).
    pub print_warnings: bool,
    /// `Some(n)` makes `TrainingLoop::should_checkpoint` report `true` on every
    /// `n`-th step; `None` (the default) disables checkpointing.
    pub checkpoint_interval: Option<usize>,
}

impl Default for TrainingConfig {
    /// Returns the stock configuration: 50k max nodes, auto-clear on, warnings
    /// above 25k nodes printed, no checkpointing.
    fn default() -> Self {
        Self {
            max_graph_nodes: 50_000,
            auto_clear: true,
            warn_threshold: 25_000,
            print_warnings: true,
            checkpoint_interval: None,
        }
    }
}

impl TrainingConfig {
    /// Creates the default configuration; see [`Default`] for the values.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets [`max_graph_nodes`](Self::max_graph_nodes).
    #[must_use]
    pub fn with_max_graph_nodes(mut self, max: usize) -> Self {
        self.max_graph_nodes = max;
        self
    }

    /// Sets [`auto_clear`](Self::auto_clear).
    #[must_use]
    pub fn with_auto_clear(mut self, auto_clear: bool) -> Self {
        self.auto_clear = auto_clear;
        self
    }

    /// Sets [`warn_threshold`](Self::warn_threshold).
    #[must_use]
    pub fn with_warn_threshold(mut self, threshold: usize) -> Self {
        self.warn_threshold = threshold;
        self
    }

    /// Sets [`print_warnings`](Self::print_warnings).
    #[must_use]
    pub fn with_print_warnings(mut self, print: bool) -> Self {
        self.print_warnings = print;
        self
    }

    /// Enables checkpointing every `interval` steps.
    ///
    /// An `interval` of `0` is accepted but never triggers a checkpoint.
    #[must_use]
    pub fn with_checkpoint_interval(mut self, interval: usize) -> Self {
        self.checkpoint_interval = Some(interval);
        self
    }
}
/// Bookkeeping for a training loop: counts steps and tracks computation-graph
/// sizes so that runaway graph growth can be spotted.
///
/// `F` is the floating-point type of the surrounding computation; it is only
/// carried as a type marker (no values of `F` are stored).
#[derive(Debug)]
pub struct TrainingLoop<F: Float> {
    config: TrainingConfig,
    /// Steps recorded via `increment_step` since creation or the last `reset`.
    step_count: usize,
    /// Sum of every `node_count` passed to `record_graph_stats` (saturating).
    total_nodes_created: usize,
    /// Largest single `node_count` passed to `record_graph_stats`.
    peak_graph_size: usize,
    _phantom: std::marker::PhantomData<F>,
}

impl<F: Float> TrainingLoop<F> {
    /// Creates a tracker using `config`, with every counter at zero.
    pub fn new(config: TrainingConfig) -> Self {
        Self {
            config,
            step_count: 0,
            total_nodes_created: 0,
            peak_graph_size: 0,
            _phantom: std::marker::PhantomData,
        }
    }

    /// Number of steps recorded so far.
    pub fn step_count(&self) -> usize {
        self.step_count
    }

    /// Sum of all graph sizes passed to [`record_graph_stats`](Self::record_graph_stats),
    /// saturating at `usize::MAX`.
    pub fn total_nodes_created(&self) -> usize {
        self.total_nodes_created
    }

    /// Largest graph size passed to [`record_graph_stats`](Self::record_graph_stats).
    pub fn peak_graph_size(&self) -> usize {
        self.peak_graph_size
    }

    /// The configuration this tracker was created with.
    pub fn config(&self) -> &TrainingConfig {
        &self.config
    }

    /// Advances the step counter by one and returns the new count.
    ///
    /// Saturates at `usize::MAX` instead of overflowing.
    pub fn increment_step(&mut self) -> usize {
        self.step_count = self.step_count.saturating_add(1);
        self.step_count
    }

    /// Records the size of the graph built in the current step.
    ///
    /// Adds `node_count` to the running total, updates the peak, and (when
    /// `print_warnings` is enabled) writes a warning to stderr if `node_count`
    /// is strictly greater than `warn_threshold`.
    pub fn record_graph_stats(&mut self, node_count: usize) {
        // Saturate rather than `+=`: on 32-bit targets a long run can exceed
        // `usize::MAX`, which would panic in debug builds and silently wrap
        // (corrupting the statistics) in release builds.
        self.total_nodes_created = self.total_nodes_created.saturating_add(node_count);
        self.peak_graph_size = self.peak_graph_size.max(node_count);

        if self.config.print_warnings && node_count > self.config.warn_threshold {
            eprintln!(
                "Warning: Graph size ({}) exceeds warning threshold ({}). \
                Consider using ctx.clear_graph() or restructuring your training loop.",
                node_count, self.config.warn_threshold
            );
        }
    }

    /// Returns `true` when a checkpoint interval is configured and the current
    /// step count is a positive multiple of it.
    ///
    /// Step 0 never checkpoints, and neither does an interval of 0
    /// (`n.is_multiple_of(0)` is only true for `n == 0`, which is excluded).
    pub fn should_checkpoint(&self) -> bool {
        self.config.checkpoint_interval.is_some_and(|interval| {
            self.step_count > 0 && self.step_count.is_multiple_of(interval)
        })
    }

    /// Zeroes all counters; the configuration is left unchanged.
    pub fn reset(&mut self) {
        self.step_count = 0;
        self.total_nodes_created = 0;
        self.peak_graph_size = 0;
    }

    /// Renders a multi-line, human-readable summary of the counters.
    ///
    /// The average is total nodes divided by step count, or `0.0` before any
    /// step has been recorded.
    pub fn stats_string(&self) -> String {
        let avg_nodes_per_step = if self.step_count > 0 {
            self.total_nodes_created as f64 / self.step_count as f64
        } else {
            0.0
        };
        format!(
            "TrainingLoop Stats:\n\
             - Steps: {}\n\
             - Total nodes created: {}\n\
             - Peak graph size: {}\n\
             - Avg nodes per step: {:.1}",
            self.step_count, self.total_nodes_created, self.peak_graph_size, avg_nodes_per_step
        )
    }
}

impl<F: Float> Default for TrainingLoop<F> {
    /// Equivalent to `TrainingLoop::new(TrainingConfig::default())`.
    fn default() -> Self {
        Self::new(TrainingConfig::default())
    }
}
/// Guard holding mutable access to a `crate::Context`.
///
/// NOTE(review): no constructor or `Drop` impl is visible in this file; the
/// name suggests it is meant to clear the context's graph when dropped, but
/// that is unimplemented here. Confirm the intended behaviour before use.
// `dead_code` is allowed because the struct is neither constructed nor is its
// `ctx` field read anywhere in this file.
// NOTE(review): `&'a mut crate::Context<'a, F>` ties the borrow's lifetime to
// the context's own lifetime parameter. Because `&mut T` is invariant in `T`,
// this likely forces the context to stay mutably borrowed for all of `'a`,
// making the guard hard to use. Consider separate lifetimes (e.g.
// `&'b mut crate::Context<'a, F>`) before wiring it up.
#[allow(dead_code)]
pub struct GraphClearGuard<'a, F: Float> {
    ctx: &'a mut crate::Context<'a, F>,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Each `with_*` setter must write through to its corresponding field.
    #[test]
    fn test_training_config_builder() {
        let cfg = TrainingConfig::new()
            .with_max_graph_nodes(100_000)
            .with_auto_clear(false)
            .with_warn_threshold(50_000)
            .with_checkpoint_interval(100);

        assert_eq!(cfg.max_graph_nodes, 100_000);
        assert!(!cfg.auto_clear);
        assert_eq!(cfg.warn_threshold, 50_000);
        assert_eq!(cfg.checkpoint_interval, Some(100));
    }

    /// Totals accumulate across steps while the peak tracks the largest graph.
    #[test]
    fn test_training_loop_stats() {
        let mut tracker = TrainingLoop::<f64>::new(TrainingConfig::default());

        for graph_size in [1000, 2000] {
            tracker.increment_step();
            tracker.record_graph_stats(graph_size);
        }

        assert_eq!(tracker.step_count(), 2);
        assert_eq!(tracker.total_nodes_created(), 3000);
        assert_eq!(tracker.peak_graph_size(), 2000);
    }

    /// With an interval of 10, steps 1..=9 are not checkpoints and step 10 is.
    #[test]
    fn test_should_checkpoint() {
        let mut tracker =
            TrainingLoop::<f64>::new(TrainingConfig::new().with_checkpoint_interval(10));

        for step in 1..=10 {
            tracker.increment_step();
            assert_eq!(tracker.should_checkpoint(), step == 10, "step {step}");
        }
    }
}