//! Keras-style training callbacks and weight regularizers (`quantrs2_ml/keras_api/callbacks.rs`).
callbacks.rs1use super::TrainingHistory;
4use crate::error::Result;
5use scirs2_core::ndarray::ArrayD;
6use std::collections::HashMap;
7
/// End-of-epoch hook for training callbacks.
///
/// Implementors must be `Send + Sync`, so a callback object can be shared across threads.
/// NOTE(review): the caller that drives this hook is not in this file; presumably the
/// training loop calls it once per epoch — confirm.
pub trait Callback: Send + Sync {
    /// Called with the index of the epoch that just finished and the accumulated history.
    ///
    /// The method takes `&self`, so implementors cannot mutate their own fields here without
    /// interior mutability.
    ///
    /// # Errors
    /// Implementations may return an error from `crate::error::Result`.
    fn on_epoch_end(&self, epoch: usize, history: &TrainingHistory) -> Result<()>;
}
13
/// Configuration and bookkeeping for an early-stopping callback.
///
/// NOTE(review): `Callback::on_epoch_end` takes `&self`, so `best`, `wait` and `stopped`
/// are only ever set by [`EarlyStopping::new`] and are never updated in this file.
pub struct EarlyStopping {
    /// Name of the monitored quantity, as supplied by the caller.
    monitor: String,
    /// Presumably the minimum change that counts as an improvement (not read in this file).
    min_delta: f64,
    /// Presumably the number of non-improving epochs to tolerate (not read in this file).
    patience: usize,
    /// Best monitored value seen so far; starts at `f64::INFINITY`.
    best: f64,
    /// Epochs since the last improvement; starts at 0.
    wait: usize,
    /// Whether a stop has been signalled; starts as `false`.
    stopped: bool,
}

impl EarlyStopping {
    /// Builds an early-stopping callback for `monitor` with the given tolerance settings.
    ///
    /// Bookkeeping starts fresh: `best = f64::INFINITY`, `wait = 0`, `stopped = false`.
    pub fn new(monitor: String, min_delta: f64, patience: usize) -> Self {
        let initial_best = f64::INFINITY;
        Self {
            patience,
            min_delta,
            monitor,
            best: initial_best,
            stopped: false,
            wait: 0,
        }
    }
}
43
44impl Callback for EarlyStopping {
45 fn on_epoch_end(&self, _epoch: usize, _history: &TrainingHistory) -> Result<()> {
46 Ok(())
47 }
48}
49
/// Settings for a model-checkpointing callback, configured with chained setters.
pub struct ModelCheckpoint {
    /// Destination path supplied by the caller.
    filepath: String,
    /// Name of the monitored quantity; defaults to `"val_loss"`.
    monitor: String,
    /// Whether only improving epochs should be saved; defaults to `false`.
    save_best_only: bool,
    /// Comparison mode; defaults to `"min"`.
    mode: String,
    /// Best monitored value so far: `f64::NEG_INFINITY` when `mode == "max"`,
    /// otherwise `f64::INFINITY`.
    best: f64,
}

impl ModelCheckpoint {
    /// Creates a checkpoint configuration writing to `filepath`.
    ///
    /// Defaults: monitor `"val_loss"`, `save_best_only = false`, mode `"min"`,
    /// `best = f64::INFINITY`.
    pub fn new(filepath: impl Into<String>) -> Self {
        Self {
            filepath: filepath.into(),
            monitor: "val_loss".to_string(),
            save_best_only: false,
            mode: "min".to_string(),
            best: f64::INFINITY,
        }
    }

    /// Sets the name of the monitored quantity.
    pub fn monitor(mut self, monitor: impl Into<String>) -> Self {
        self.monitor = monitor.into();
        self
    }

    /// Sets whether only improving epochs should be saved.
    pub fn save_best_only(mut self, save_best_only: bool) -> Self {
        self.save_best_only = save_best_only;
        self
    }

    /// Sets the comparison mode and resets `best` to the matching starting value.
    ///
    /// `"max"` starts from `f64::NEG_INFINITY`; any other mode starts from `f64::INFINITY`.
    pub fn mode(mut self, mode: impl Into<String>) -> Self {
        self.mode = mode.into();
        // Recompute on every call so the baseline always agrees with the current mode,
        // even when the mode is changed more than once (e.g. "max" then "min").
        self.best = if self.mode == "max" {
            f64::NEG_INFINITY
        } else {
            f64::INFINITY
        };
        self
    }
}
97
98impl Callback for ModelCheckpoint {
99 fn on_epoch_end(&self, _epoch: usize, history: &TrainingHistory) -> Result<()> {
100 if let Some(¤t) = history.loss.last() {
101 let _ = current;
102 }
103 Ok(())
104 }
105}
106
/// Settings and in-memory row store for a CSV logging callback.
pub struct CSVLogger {
    /// Output file name supplied by the caller.
    filename: String,
    /// Field separator; defaults to `","`.
    separator: String,
    /// Whether output should be appended to an existing file; defaults to `false`.
    append: bool,
    /// Rows recorded so far, one map per row; starts empty.
    logged_data: Vec<HashMap<String, f64>>,
}

impl CSVLogger {
    /// Creates a logger for `filename` with separator `","`, append disabled and no rows.
    pub fn new(filename: impl Into<String>) -> Self {
        let filename: String = filename.into();
        Self {
            logged_data: Vec::new(),
            append: false,
            separator: String::from(","),
            filename,
        }
    }

    /// Replaces the field separator.
    pub fn separator(self, separator: impl Into<String>) -> Self {
        Self {
            separator: separator.into(),
            ..self
        }
    }

    /// Sets whether output should be appended.
    pub fn append(self, append: bool) -> Self {
        Self { append, ..self }
    }

    /// Returns the rows recorded so far.
    pub fn get_logged_data(&self) -> &[HashMap<String, f64>] {
        self.logged_data.as_slice()
    }
}
147
148impl Callback for CSVLogger {
149 fn on_epoch_end(&self, epoch: usize, history: &TrainingHistory) -> Result<()> {
150 let _ = (epoch, history);
151 Ok(())
152 }
153}
154
/// Settings and state for a reduce-learning-rate-on-plateau callback.
pub struct ReduceLROnPlateau {
    /// Name of the monitored quantity; defaults to `"val_loss"`.
    monitor: String,
    /// Reduction factor; defaults to `0.1`.
    factor: f64,
    /// Patience in epochs; defaults to `10`.
    patience: usize,
    /// Lower bound for the learning rate; defaults to `1e-7`.
    min_lr: f64,
    /// Comparison mode; defaults to `"min"`.
    mode: String,
    /// Best monitored value so far: `f64::NEG_INFINITY` when `mode == "max"`,
    /// otherwise `f64::INFINITY`.
    best: f64,
    /// Epochs since the last improvement; starts at 0.
    wait: usize,
    /// Learning rate reported by [`ReduceLROnPlateau::get_lr`]; starts at `0.001`.
    current_lr: f64,
}

impl ReduceLROnPlateau {
    /// Creates the callback with its default settings (see field docs).
    pub fn new() -> Self {
        Self {
            monitor: "val_loss".to_string(),
            factor: 0.1,
            patience: 10,
            min_lr: 1e-7,
            mode: "min".to_string(),
            best: f64::INFINITY,
            wait: 0,
            current_lr: 0.001,
        }
    }

    /// Sets the name of the monitored quantity.
    pub fn monitor(mut self, monitor: impl Into<String>) -> Self {
        self.monitor = monitor.into();
        self
    }

    /// Sets the reduction factor.
    pub fn factor(mut self, factor: f64) -> Self {
        self.factor = factor;
        self
    }

    /// Sets the patience in epochs.
    pub fn patience(mut self, patience: usize) -> Self {
        self.patience = patience;
        self
    }

    /// Sets the lower bound for the learning rate.
    pub fn min_lr(mut self, min_lr: f64) -> Self {
        self.min_lr = min_lr;
        self
    }

    /// Sets the comparison mode and resets `best` to the matching starting value.
    ///
    /// `"max"` starts from `f64::NEG_INFINITY`; any other mode starts from `f64::INFINITY`.
    pub fn mode(mut self, mode: impl Into<String>) -> Self {
        self.mode = mode.into();
        // Keep the baseline consistent with the mode, even across repeated calls.
        self.best = if self.mode == "max" {
            f64::NEG_INFINITY
        } else {
            f64::INFINITY
        };
        self
    }

    /// Returns the current learning rate.
    pub fn get_lr(&self) -> f64 {
        self.current_lr
    }
}

impl Default for ReduceLROnPlateau {
    /// Same as [`ReduceLROnPlateau::new`].
    fn default() -> Self {
        Self::new()
    }
}
219
220impl Callback for ReduceLROnPlateau {
221 fn on_epoch_end(&self, _epoch: usize, history: &TrainingHistory) -> Result<()> {
222 if let Some(¤t) = history.loss.last() {
223 let _ = current;
224 }
225 Ok(())
226 }
227}
228
229#[derive(Debug, Clone)]
231pub enum RegularizerType {
232 L1(f64),
234 L2(f64),
236 L1L2 { l1: f64, l2: f64 },
238}
239
240impl RegularizerType {
241 pub fn compute(&self, weights: &ArrayD<f64>) -> f64 {
243 match self {
244 RegularizerType::L1(l1) => l1 * weights.iter().map(|w| w.abs()).sum::<f64>(),
245 RegularizerType::L2(l2) => l2 * weights.iter().map(|w| w * w).sum::<f64>(),
246 RegularizerType::L1L2 { l1, l2 } => {
247 l1 * weights.iter().map(|w| w.abs()).sum::<f64>()
248 + l2 * weights.iter().map(|w| w * w).sum::<f64>()
249 }
250 }
251 }
252}