mod attention;
mod callbacks;
mod conv;
mod layers;
mod quantum_layers;
mod rnn;
mod schedules;

pub use attention::*;
pub use callbacks::*;
pub use conv::*;
pub use layers::*;
pub use quantum_layers::*;
pub use rnn::*;
pub use schedules::*;

use crate::error::{MLError, Result};
use scirs2_core::ndarray::{s, ArrayD};
use std::collections::HashMap;

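/// Common interface for Keras-style layers: build once against an input
/// shape, then call on dynamically shaped `f64` arrays.
///
/// A hedged usage sketch (`ignore`d doctest; it assumes the `Dense` builder
/// API exercised in the tests at the bottom of this file):
///
/// ```ignore
/// let mut layer = Dense::new(10).activation(ActivationFunction::ReLU);
/// layer.build(&[5])?;
/// assert_eq!(layer.compute_output_shape(&[32, 5]), vec![32, 10]);
/// ```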
pub trait KerasLayer: Send + Sync {
    /// Initializes the layer's weights for the given input shape.
    fn build(&mut self, input_shape: &[usize]) -> Result<()>;

    /// Runs the layer's forward pass.
    fn call(&self, inputs: &ArrayD<f64>) -> Result<ArrayD<f64>>;

    /// Returns the output shape produced for `input_shape`.
    fn compute_output_shape(&self, input_shape: &[usize]) -> Vec<usize>;

    /// Returns the layer's name.
    fn name(&self) -> &str;

    /// Returns the layer's weight tensors.
    fn get_weights(&self) -> Vec<ArrayD<f64>>;

    /// Replaces the layer's weight tensors.
    fn set_weights(&mut self, weights: Vec<ArrayD<f64>>) -> Result<()>;

    /// Total number of scalar parameters across all weight tensors.
    fn count_params(&self) -> usize {
        self.get_weights().iter().map(|w| w.len()).sum()
    }

    /// Whether `build` has been called.
    fn built(&self) -> bool;
}

#[derive(Debug, Clone)]
pub enum ActivationFunction {
    Linear,
    ReLU,
    Sigmoid,
    Tanh,
    Softmax,
    /// Leaky ReLU with the given negative slope.
    LeakyReLU(f64),
    /// ELU with the given alpha.
    ELU(f64),
}

#[derive(Debug, Clone)]
pub enum InitializerType {
    Zeros,
    Ones,
    GlorotUniform,
    GlorotNormal,
    HeUniform,
}

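/// A linear stack of layers, modeled on `tf.keras.Sequential`.
///
/// A hedged end-to-end sketch (`ignore`d doctest; it assumes the `Dense`
/// builder API from the `layers` module):
///
/// ```ignore
/// let mut model = Sequential::new();
/// model.add(Box::new(Dense::new(16).activation(ActivationFunction::ReLU)));
/// model.add(Box::new(Dense::new(1).activation(ActivationFunction::Sigmoid)));
/// model.build(vec![32, 8])?;
/// let model = model.compile(
///     LossFunction::BinaryCrossentropy,
///     OptimizerType::SGD { learning_rate: 0.01, momentum: 0.9 },
///     vec![MetricType::Accuracy],
/// );
/// ```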
pub struct Sequential {
    layers: Vec<Box<dyn KerasLayer>>,
    name: String,
    built: bool,
    compiled: bool,
    input_shape: Option<Vec<usize>>,
    loss: Option<LossFunction>,
    optimizer: Option<OptimizerType>,
    metrics: Vec<MetricType>,
}

impl Sequential {
    pub fn new() -> Self {
        Self {
            layers: Vec::new(),
            name: format!("sequential_{}", fastrand::u32(..)),
            built: false,
            compiled: false,
            input_shape: None,
            loss: None,
            optimizer: None,
            metrics: Vec::new(),
        }
    }

    pub fn name(mut self, name: impl Into<String>) -> Self {
        self.name = name.into();
        self
    }

    pub fn add(&mut self, layer: Box<dyn KerasLayer>) {
        self.layers.push(layer);
        self.built = false;
    }

    pub fn build(&mut self, input_shape: Vec<usize>) -> Result<()> {
        self.input_shape = Some(input_shape.clone());
        let mut current_shape = input_shape;

        for layer in &mut self.layers {
            layer.build(&current_shape)?;
            current_shape = layer.compute_output_shape(&current_shape);
        }

        self.built = true;
        Ok(())
    }

    pub fn compile(
        mut self,
        loss: LossFunction,
        optimizer: OptimizerType,
        metrics: Vec<MetricType>,
    ) -> Self {
        self.loss = Some(loss);
        self.optimizer = Some(optimizer);
        self.metrics = metrics;
        self.compiled = true;
        self
    }

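    /// Builds a Keras-style summary. Layer types are currently reported
    /// generically as "Layer", and every parameter is counted as trainable.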
    pub fn summary(&self) -> ModelSummary {
        let mut layers_info = Vec::new();
        let mut total_params = 0;
        let mut trainable_params = 0;

        let mut current_shape = self.input_shape.clone().unwrap_or_default();

        for layer in &self.layers {
            let output_shape = layer.compute_output_shape(&current_shape);
            let params = layer.count_params();

            layers_info.push(LayerInfo {
                name: layer.name().to_string(),
                layer_type: "Layer".to_string(),
                output_shape: output_shape.clone(),
                param_count: params,
            });

            total_params += params;
            trainable_params += params;
            current_shape = output_shape;
        }

        ModelSummary {
            layers: layers_info,
            total_params,
            trainable_params,
            non_trainable_params: 0,
        }
    }

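    /// Runs a forward pass through every layer in order; errors if the model
    /// has not been built.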
    pub fn predict(&self, inputs: &ArrayD<f64>) -> Result<ArrayD<f64>> {
        if !self.built {
            return Err(MLError::InvalidConfiguration(
                "Model must be built before prediction".to_string(),
            ));
        }

        let mut current = inputs.clone();

        for layer in &self.layers {
            current = layer.call(&current)?;
        }

        Ok(current)
    }

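    /// Trains with simple mini-batch iteration over the first axis of `X`.
    /// Batches default to 32 samples; a short partial batch is kept at the end.
    ///
    /// A hedged call sketch (`ignore`d; shapes are assumed to be
    /// `[n_samples, features]` / `[n_samples, targets]`):
    ///
    /// ```ignore
    /// let history = model.fit(&x, &y, 10, Some(32), None, Vec::new())?;
    /// println!("final loss: {:?}", history.loss.last());
    /// ```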
    #[allow(non_snake_case)]
    pub fn fit(
        &mut self,
        X: &ArrayD<f64>,
        y: &ArrayD<f64>,
        epochs: usize,
        batch_size: Option<usize>,
        validation_data: Option<(&ArrayD<f64>, &ArrayD<f64>)>,
        callbacks: Vec<Box<dyn Callback>>,
    ) -> Result<TrainingHistory> {
        if !self.compiled {
            return Err(MLError::InvalidConfiguration(
                "Model must be compiled before training".to_string(),
            ));
        }

        let batch_size = batch_size.unwrap_or(32);
        let n_samples = X.shape()[0];
        let n_batches = n_samples.div_ceil(batch_size);

        let mut history = TrainingHistory::new();

        for epoch in 0..epochs {
            let mut epoch_loss = 0.0;
            let mut epoch_metrics: HashMap<String, f64> = HashMap::new();

            for metric in &self.metrics {
                epoch_metrics.insert(metric.name(), 0.0);
            }

            for batch_idx in 0..n_batches {
                let start_idx = batch_idx * batch_size;
                let end_idx = ((batch_idx + 1) * batch_size).min(n_samples);

                // Materialize each batch once instead of re-cloning per use.
                let X_batch = X.slice(s![start_idx..end_idx, ..]).to_owned().into_dyn();
                let y_batch = y.slice(s![start_idx..end_idx, ..]).to_owned().into_dyn();

                let predictions = self.predict(&X_batch)?;

                let loss = self.compute_loss(&predictions, &y_batch)?;
                epoch_loss += loss;

                self.backward_pass(&predictions, &y_batch)?;

                for metric in &self.metrics {
                    let metric_value = metric.compute(&predictions, &y_batch)?;
                    *epoch_metrics.entry(metric.name()).or_insert(0.0) += metric_value;
                }
            }

            epoch_loss /= n_batches as f64;
            for value in epoch_metrics.values_mut() {
                *value /= n_batches as f64;
            }

            let (val_loss, val_metrics) = if let Some((X_val, y_val)) = validation_data {
                let val_predictions = self.predict(X_val)?;
                let val_loss = self.compute_loss(&val_predictions, y_val)?;

                let mut val_metrics = HashMap::new();
                for metric in &self.metrics {
                    let metric_value = metric.compute(&val_predictions, y_val)?;
                    val_metrics.insert(format!("val_{}", metric.name()), metric_value);
                }

                (Some(val_loss), val_metrics)
            } else {
                (None, HashMap::new())
            };

            history.add_epoch(epoch_loss, epoch_metrics, val_loss, val_metrics);

            for callback in &callbacks {
                callback.on_epoch_end(epoch, &history)?;
            }

            println!("Epoch {}/{} - loss: {:.4}", epoch + 1, epochs, epoch_loss);
        }

        Ok(history)
    }

    #[allow(non_snake_case)]
    pub fn evaluate(
        &self,
        X: &ArrayD<f64>,
        y: &ArrayD<f64>,
        _batch_size: Option<usize>,
    ) -> Result<HashMap<String, f64>> {
        let predictions = self.predict(X)?;
        let loss = self.compute_loss(&predictions, y)?;

        let mut results = HashMap::new();
        results.insert("loss".to_string(), loss);

        for metric in &self.metrics {
            let metric_value = metric.compute(&predictions, y)?;
            results.insert(metric.name(), metric_value);
        }

        Ok(results)
    }

    fn compute_loss(&self, predictions: &ArrayD<f64>, targets: &ArrayD<f64>) -> Result<f64> {
        if let Some(ref loss_fn) = self.loss {
            loss_fn.compute(predictions, targets)
        } else {
            Err(MLError::InvalidConfiguration(
                "Loss function not specified".to_string(),
            ))
        }
    }

    fn backward_pass(&mut self, _predictions: &ArrayD<f64>, _targets: &ArrayD<f64>) -> Result<()> {
        // Placeholder: gradient computation and weight updates are not yet implemented.
        Ok(())
    }
}

impl Default for Sequential {
    fn default() -> Self {
        Self::new()
    }
}

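/// Loss functions accepted by `Sequential::compile`. Only `MeanSquaredError`,
/// `BinaryCrossentropy`, and `MeanAbsoluteError` are implemented so far; the
/// rest return an error from `compute`.
///
/// A hedged MSE sanity check (`ignore`d; assumes `arr1` is re-exported by
/// `scirs2_core::ndarray` as in upstream `ndarray`):
///
/// ```ignore
/// use scirs2_core::ndarray::arr1;
/// let preds = arr1(&[1.0, 2.0]).into_dyn();
/// let targets = arr1(&[0.0, 2.0]).into_dyn();
/// // ((1 - 0)^2 + (2 - 2)^2) / 2 = 0.5
/// let mse = LossFunction::MeanSquaredError.compute(&preds, &targets)?;
/// assert!((mse - 0.5).abs() < 1e-12);
/// ```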
#[derive(Debug, Clone)]
pub enum LossFunction {
    MeanSquaredError,
    BinaryCrossentropy,
    CategoricalCrossentropy,
    SparseCategoricalCrossentropy,
    MeanAbsoluteError,
    /// Huber loss with the given delta.
    Huber(f64),
}

impl LossFunction {
    pub fn compute(&self, predictions: &ArrayD<f64>, targets: &ArrayD<f64>) -> Result<f64> {
        match self {
            LossFunction::MeanSquaredError => {
                let diff = predictions - targets;
                diff.mapv(|x| x * x).mean().ok_or_else(|| {
                    MLError::ComputationError("Failed to compute mean of empty array".to_string())
                })
            }
            LossFunction::BinaryCrossentropy => {
                // Clip predictions away from 0 and 1 to keep the logs finite.
                let epsilon = 1e-15;
                let clipped_preds = predictions.mapv(|x| x.max(epsilon).min(1.0 - epsilon));
                let loss = targets * clipped_preds.mapv(|x| x.ln())
                    + (1.0 - targets) * clipped_preds.mapv(|x| (1.0 - x).ln());
                loss.mean().map(|m| -m).ok_or_else(|| {
                    MLError::ComputationError("Failed to compute mean of empty array".to_string())
                })
            }
            LossFunction::MeanAbsoluteError => {
                let diff = predictions - targets;
                diff.mapv(|x| x.abs()).mean().ok_or_else(|| {
                    MLError::ComputationError("Failed to compute mean of empty array".to_string())
                })
            }
            _ => Err(MLError::InvalidConfiguration(format!(
                "Loss function {:?} not implemented",
                self
            ))),
        }
    }
}

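/// Optimizer configurations carried by the model after `compile`. Note that
/// `backward_pass` is still a placeholder, so these are not yet applied.
///
/// ```ignore
/// let opt = OptimizerType::Adam {
///     learning_rate: 1e-3,
///     beta1: 0.9,
///     beta2: 0.999,
///     epsilon: 1e-8,
/// };
/// ```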
#[derive(Debug, Clone)]
pub enum OptimizerType {
    SGD { learning_rate: f64, momentum: f64 },
    Adam {
        learning_rate: f64,
        beta1: f64,
        beta2: f64,
        epsilon: f64,
    },
    RMSprop {
        learning_rate: f64,
        rho: f64,
        epsilon: f64,
    },
    AdaGrad { learning_rate: f64, epsilon: f64 },
}

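/// Evaluation metrics. `Accuracy` is element-wise binary accuracy with a 0.5
/// threshold; `Precision`, `Recall`, and `F1Score` are not implemented yet.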
#[derive(Debug, Clone)]
pub enum MetricType {
    Accuracy,
    Precision,
    Recall,
    F1Score,
    MeanAbsoluteError,
    MeanSquaredError,
}

impl MetricType {
    pub fn name(&self) -> String {
        match self {
            MetricType::Accuracy => "accuracy".to_string(),
            MetricType::Precision => "precision".to_string(),
            MetricType::Recall => "recall".to_string(),
            MetricType::F1Score => "f1_score".to_string(),
            MetricType::MeanAbsoluteError => "mean_absolute_error".to_string(),
            MetricType::MeanSquaredError => "mean_squared_error".to_string(),
        }
    }

    pub fn compute(&self, predictions: &ArrayD<f64>, targets: &ArrayD<f64>) -> Result<f64> {
        match self {
            MetricType::Accuracy => {
                // Binary accuracy: threshold predictions at 0.5, then compare element-wise.
                let pred_classes = predictions.mapv(|x| if x > 0.5 { 1.0 } else { 0.0 });
                let correct = pred_classes
                    .iter()
                    .zip(targets.iter())
                    .filter(|(&pred, &target)| (pred - target).abs() < 1e-6)
                    .count();
                Ok(correct as f64 / targets.len() as f64)
            }
            MetricType::MeanAbsoluteError => {
                let diff = predictions - targets;
                diff.mapv(|x| x.abs()).mean().ok_or_else(|| {
                    MLError::ComputationError("Failed to compute mean of empty array".to_string())
                })
            }
            MetricType::MeanSquaredError => {
                let diff = predictions - targets;
                diff.mapv(|x| x * x).mean().ok_or_else(|| {
                    MLError::ComputationError("Failed to compute mean of empty array".to_string())
                })
            }
            _ => Err(MLError::InvalidConfiguration(format!(
                "Metric {:?} not implemented",
                self
            ))),
        }
    }
}

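/// Per-epoch records collected by `Sequential::fit`.
///
/// Note: `val_loss` only receives a value on epochs where validation data was
/// provided, while `val_metrics` is pushed every epoch, so the two vectors can
/// differ in length.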
#[derive(Debug, Clone)]
pub struct TrainingHistory {
    pub loss: Vec<f64>,
    pub metrics: Vec<HashMap<String, f64>>,
    pub val_loss: Vec<f64>,
    pub val_metrics: Vec<HashMap<String, f64>>,
}

impl TrainingHistory {
    pub fn new() -> Self {
        Self {
            loss: Vec::new(),
            metrics: Vec::new(),
            val_loss: Vec::new(),
            val_metrics: Vec::new(),
        }
    }

    pub fn add_epoch(
        &mut self,
        loss: f64,
        metrics: HashMap<String, f64>,
        val_loss: Option<f64>,
        val_metrics: HashMap<String, f64>,
    ) {
        self.loss.push(loss);
        self.metrics.push(metrics);

        if let Some(val_loss) = val_loss {
            self.val_loss.push(val_loss);
        }
        self.val_metrics.push(val_metrics);
    }
}

impl Default for TrainingHistory {
    fn default() -> Self {
        Self::new()
    }
}

#[derive(Debug)]
pub struct ModelSummary {
    pub layers: Vec<LayerInfo>,
    pub total_params: usize,
    pub trainable_params: usize,
    pub non_trainable_params: usize,
}

#[derive(Debug)]
pub struct LayerInfo {
    pub name: String,
    pub layer_type: String,
    pub output_shape: Vec<usize>,
    pub param_count: usize,
}

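/// Symbolic input specification (shape, optional name, dtype), analogous to
/// `keras.Input`.
///
/// ```ignore
/// let input = Input::new(vec![28, 28]).name("pixels").dtype(DataType::Float32);
/// ```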
pub struct Input {
    pub shape: Vec<usize>,
    pub name: Option<String>,
    pub dtype: DataType,
}

impl Input {
    pub fn new(shape: Vec<usize>) -> Self {
        Self {
            shape,
            name: None,
            dtype: DataType::Float64,
        }
    }

    pub fn name(mut self, name: impl Into<String>) -> Self {
        self.name = Some(name.into());
        self
    }

    pub fn dtype(mut self, dtype: DataType) -> Self {
        self.dtype = dtype;
        self
    }
}

#[derive(Debug, Clone)]
pub enum DataType {
    Float32,
    Float64,
    Int32,
    Int64,
}

pub mod utils {
    use super::*;

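    /// Builds a plain feed-forward classifier: ReLU hidden layers followed by
    /// a sigmoid (binary) or softmax (multi-class) output layer.
    ///
    /// A hedged sketch (`ignore`d; note that `input_dim` is currently unused,
    /// so the model still needs an explicit `build`):
    ///
    /// ```ignore
    /// let mut model = utils::create_classification_model(8, 3, vec![32, 16]);
    /// model.build(vec![1, 8])?;
    /// ```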
    pub fn create_classification_model(
        _input_dim: usize,
        num_classes: usize,
        hidden_layers: Vec<usize>,
    ) -> Sequential {
        let mut model = Sequential::new();

        for (i, &units) in hidden_layers.iter().enumerate() {
            model.add(Box::new(
                Dense::new(units)
                    .activation(ActivationFunction::ReLU)
                    .name(format!("dense_{}", i)),
            ));
        }

        let output_activation = if num_classes == 2 {
            ActivationFunction::Sigmoid
        } else {
            ActivationFunction::Softmax
        };

        model.add(Box::new(
            Dense::new(num_classes)
                .activation(output_activation)
                .name("output"),
        ));

        model
    }

    pub fn create_quantum_model(
        num_qubits: usize,
        num_classes: usize,
        num_layers: usize,
    ) -> Sequential {
        let mut model = Sequential::new();

        model.add(Box::new(
            QuantumDense::new(num_qubits, num_classes)
                .num_layers(num_layers)
                .ansatz_type(QuantumAnsatzType::HardwareEfficient)
                .name("quantum_layer"),
        ));

        if num_classes > 1 {
            model.add(Box::new(
                Activation::new(ActivationFunction::Softmax).name("softmax"),
            ));
        }

        model
    }

    pub fn create_hybrid_model(
        _input_dim: usize,
        num_qubits: usize,
        num_classes: usize,
        classical_hidden: Vec<usize>,
    ) -> Sequential {
        let mut model = Sequential::new();

        for (i, &units) in classical_hidden.iter().enumerate() {
            model.add(Box::new(
                Dense::new(units)
                    .activation(ActivationFunction::ReLU)
                    .name(format!("classical_{}", i)),
            ));
        }

        model.add(Box::new(
            QuantumDense::new(num_qubits, 64)
                .num_layers(2)
                .ansatz_type(QuantumAnsatzType::HardwareEfficient)
                .name("quantum_layer"),
        ));

        model.add(Box::new(
            Dense::new(num_classes)
                .activation(if num_classes == 2 {
                    ActivationFunction::Sigmoid
                } else {
                    ActivationFunction::Softmax
                })
                .name("output"),
        ));

        model
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_dense_layer() {
        let mut dense = Dense::new(10)
            .activation(ActivationFunction::ReLU)
            .name("test_dense");

        assert!(!dense.built());

        dense.build(&[5]).expect("Should build successfully");

        assert!(dense.built());
        assert_eq!(dense.compute_output_shape(&[32, 5]), vec![32, 10]);
    }

    #[test]
    fn test_sequential_model() {
        let mut model = Sequential::new();
        model.add(Box::new(Dense::new(10)));
        model.add(Box::new(Activation::new(ActivationFunction::ReLU)));
        model.add(Box::new(Dense::new(5)));

        model
            .build(vec![32, 20])
            .expect("Should build successfully");

        let summary = model.summary();
        assert_eq!(summary.layers.len(), 3);
    }

    #[test]
    fn test_activation_functions() {
        let relu = ActivationFunction::ReLU;
        let sigmoid = ActivationFunction::Sigmoid;
        let _tanh = ActivationFunction::Tanh;

        let mut act_relu = Activation::new(relu);
        act_relu.build(&[10]).expect("Should build");

        let mut act_sigmoid = Activation::new(sigmoid);
        act_sigmoid.build(&[10]).expect("Should build");
    }
}
737}