quantrs2_ml/keras_api/
callbacks.rs

1//! Callbacks for Keras-like API
2
3use super::TrainingHistory;
4use crate::error::Result;
5use scirs2_core::ndarray::ArrayD;
6use std::collections::HashMap;
7
8/// Callback trait for training
/// Callback trait for training
///
/// Implementors are invoked by the training loop so they can observe the
/// metrics accumulated in [`TrainingHistory`]. The `Send + Sync` bounds
/// allow callbacks to be shared across threads.
pub trait Callback: Send + Sync {
    /// Called at the end of each epoch
    ///
    /// # Errors
    /// Returns an error if the callback fails to process the epoch results.
    fn on_epoch_end(&self, epoch: usize, history: &TrainingHistory) -> Result<()>;
}
13
14/// Early stopping callback
15pub struct EarlyStopping {
16    /// Metric to monitor
17    monitor: String,
18    /// Minimum change to qualify as improvement
19    min_delta: f64,
20    /// Number of epochs with no improvement to wait
21    patience: usize,
22    /// Best value seen so far
23    best: f64,
24    /// Number of epochs without improvement
25    wait: usize,
26    /// Whether to stop training
27    stopped: bool,
28}
29
30impl EarlyStopping {
31    /// Create new early stopping callback
32    pub fn new(monitor: String, min_delta: f64, patience: usize) -> Self {
33        Self {
34            monitor,
35            min_delta,
36            patience,
37            best: f64::INFINITY,
38            wait: 0,
39            stopped: false,
40        }
41    }
42}
43
impl Callback for EarlyStopping {
    /// End-of-epoch hook.
    ///
    /// NOTE(review): currently a no-op — the monitored metric is never
    /// compared against `best`, so early stopping never triggers. The
    /// `&self` receiver prevents updating `wait`/`stopped`; interior
    /// mutability (or a `&mut self` trait method) would be needed to
    /// implement the real logic.
    fn on_epoch_end(&self, _epoch: usize, _history: &TrainingHistory) -> Result<()> {
        Ok(())
    }
}
49
50/// Model checkpoint callback
51pub struct ModelCheckpoint {
52    /// File path pattern
53    filepath: String,
54    /// Metric to monitor
55    monitor: String,
56    /// Save best only
57    save_best_only: bool,
58    /// Mode (min or max)
59    mode: String,
60    /// Best value seen
61    best: f64,
62}
63
64impl ModelCheckpoint {
65    /// Create new model checkpoint callback
66    pub fn new(filepath: impl Into<String>) -> Self {
67        Self {
68            filepath: filepath.into(),
69            monitor: "val_loss".to_string(),
70            save_best_only: false,
71            mode: "min".to_string(),
72            best: f64::INFINITY,
73        }
74    }
75
76    /// Set metric to monitor
77    pub fn monitor(mut self, monitor: impl Into<String>) -> Self {
78        self.monitor = monitor.into();
79        self
80    }
81
82    /// Save best only
83    pub fn save_best_only(mut self, save_best_only: bool) -> Self {
84        self.save_best_only = save_best_only;
85        self
86    }
87
88    /// Set mode
89    pub fn mode(mut self, mode: impl Into<String>) -> Self {
90        self.mode = mode.into();
91        if self.mode == "max" {
92            self.best = f64::NEG_INFINITY;
93        }
94        self
95    }
96}
97
impl Callback for ModelCheckpoint {
    /// End-of-epoch hook.
    ///
    /// NOTE(review): stub — the latest loss is read from the history and
    /// immediately discarded; there is no comparison against `best` and no
    /// file is ever written.
    fn on_epoch_end(&self, _epoch: usize, history: &TrainingHistory) -> Result<()> {
        if let Some(&current) = history.loss.last() {
            // Placeholder: value intentionally unused until saving is implemented.
            let _ = current;
        }
        Ok(())
    }
}
106
107/// CSV logger callback
/// CSV logger callback
///
/// Builder-configured logger that accumulates per-epoch metric rows for
/// export to a CSV file.
pub struct CSVLogger {
    /// File path
    filename: String,
    /// Separator
    separator: String,
    /// Append mode
    append: bool,
    /// Logged data
    logged_data: Vec<HashMap<String, f64>>,
}

impl CSVLogger {
    /// Create new CSV logger
    ///
    /// Defaults to a comma separator and overwrite (non-append) mode.
    pub fn new(filename: impl Into<String>) -> Self {
        CSVLogger {
            filename: filename.into(),
            separator: String::from(","),
            append: false,
            logged_data: Vec::new(),
        }
    }

    /// Set separator
    pub fn separator(self, separator: impl Into<String>) -> Self {
        CSVLogger {
            separator: separator.into(),
            ..self
        }
    }

    /// Set append mode
    pub fn append(self, append: bool) -> Self {
        CSVLogger { append, ..self }
    }

    /// Get logged data
    pub fn get_logged_data(&self) -> &[HashMap<String, f64>] {
        self.logged_data.as_slice()
    }
}
147
impl Callback for CSVLogger {
    /// End-of-epoch hook.
    ///
    /// NOTE(review): stub — both arguments are discarded; nothing is pushed
    /// into `logged_data` and no file I/O happens. The `&self` receiver
    /// would also prevent appending to `logged_data` without interior
    /// mutability.
    fn on_epoch_end(&self, epoch: usize, history: &TrainingHistory) -> Result<()> {
        let _ = (epoch, history);
        Ok(())
    }
}
154
155/// Reduce learning rate on plateau callback
/// Reduce learning rate on plateau callback
///
/// Intended to multiply the learning rate by `factor` after `patience`
/// epochs without improvement in the monitored metric, never dropping
/// below `min_lr`.
///
/// NOTE(review): the reduction logic itself is not yet implemented in
/// `on_epoch_end`; this type currently only carries configuration.
pub struct ReduceLROnPlateau {
    /// Metric to monitor
    monitor: String,
    /// Factor to reduce LR by
    factor: f64,
    /// Patience epochs
    patience: usize,
    /// Minimum learning rate
    min_lr: f64,
    /// Mode (min or max)
    mode: String,
    /// Best value seen (+inf in "min" mode, -inf in "max" mode)
    best: f64,
    /// Wait counter
    wait: usize,
    /// Current learning rate
    current_lr: f64,
}

impl ReduceLROnPlateau {
    /// Create new ReduceLROnPlateau callback
    ///
    /// Defaults: monitor `"val_loss"`, factor 0.1, patience 10,
    /// min_lr 1e-7, `"min"` mode, initial learning rate 0.001.
    pub fn new() -> Self {
        Self {
            monitor: "val_loss".to_string(),
            factor: 0.1,
            patience: 10,
            min_lr: 1e-7,
            mode: "min".to_string(),
            best: f64::INFINITY,
            wait: 0,
            current_lr: 0.001,
        }
    }

    /// Set metric to monitor
    pub fn monitor(mut self, monitor: impl Into<String>) -> Self {
        self.monitor = monitor.into();
        self
    }

    /// Set reduction factor
    pub fn factor(mut self, factor: f64) -> Self {
        self.factor = factor;
        self
    }

    /// Set patience
    pub fn patience(mut self, patience: usize) -> Self {
        self.patience = patience;
        self
    }

    /// Set minimum learning rate floor
    ///
    /// Added for consistency: the `min_lr` field previously had no setter
    /// and could never be customized.
    pub fn min_lr(mut self, min_lr: f64) -> Self {
        self.min_lr = min_lr;
        self
    }

    /// Set mode
    ///
    /// `"max"` means larger metric values are better; anything else is
    /// treated as `"min"`. `best` is re-initialized on every call. Added
    /// for consistency with `ModelCheckpoint::mode`: the `mode` field
    /// previously had no setter.
    pub fn mode(mut self, mode: impl Into<String>) -> Self {
        self.mode = mode.into();
        self.best = if self.mode == "max" {
            f64::NEG_INFINITY
        } else {
            f64::INFINITY
        };
        self
    }

    /// Get current learning rate
    pub fn get_lr(&self) -> f64 {
        self.current_lr
    }
}
213
214impl Default for ReduceLROnPlateau {
215    fn default() -> Self {
216        Self::new()
217    }
218}
219
impl Callback for ReduceLROnPlateau {
    /// End-of-epoch hook.
    ///
    /// NOTE(review): stub — the latest loss is read and discarded; the
    /// learning rate is never actually reduced (the `&self` receiver
    /// prevents mutating `current_lr`/`wait`/`best`).
    fn on_epoch_end(&self, _epoch: usize, history: &TrainingHistory) -> Result<()> {
        if let Some(&current) = history.loss.last() {
            // Placeholder: value intentionally unused until reduction logic lands.
            let _ = current;
        }
        Ok(())
    }
}
228
229/// Regularizer types
230#[derive(Debug, Clone)]
231pub enum RegularizerType {
232    /// L1 regularization
233    L1(f64),
234    /// L2 regularization
235    L2(f64),
236    /// L1L2 regularization
237    L1L2 { l1: f64, l2: f64 },
238}
239
240impl RegularizerType {
241    /// Compute regularization loss
242    pub fn compute(&self, weights: &ArrayD<f64>) -> f64 {
243        match self {
244            RegularizerType::L1(l1) => l1 * weights.iter().map(|w| w.abs()).sum::<f64>(),
245            RegularizerType::L2(l2) => l2 * weights.iter().map(|w| w * w).sum::<f64>(),
246            RegularizerType::L1L2 { l1, l2 } => {
247                l1 * weights.iter().map(|w| w.abs()).sum::<f64>()
248                    + l2 * weights.iter().map(|w| w * w).sum::<f64>()
249            }
250        }
251    }
252}