quantrs2_ml/keras_api/mod.rs

//! Keras-style model building API for QuantRS2-ML
//!
//! This module provides a Keras-like interface for building quantum machine learning
//! models, with both Sequential and Functional API patterns familiar to Keras users.

mod attention;
mod callbacks;
mod conv;
mod layers;
mod quantum_layers;
mod rnn;
mod schedules;

pub use attention::*;
pub use callbacks::*;
pub use conv::*;
pub use layers::*;
pub use quantum_layers::*;
pub use rnn::*;
pub use schedules::*;

use crate::error::{MLError, Result};
use scirs2_core::ndarray::{s, ArrayD};
use std::collections::HashMap;

/// Keras-style layer trait
pub trait KerasLayer: Send + Sync {
    /// Build the layer (called during model compilation)
    fn build(&mut self, input_shape: &[usize]) -> Result<()>;

    /// Forward pass through the layer
    fn call(&self, inputs: &ArrayD<f64>) -> Result<ArrayD<f64>>;

    /// Compute output shape given input shape
    fn compute_output_shape(&self, input_shape: &[usize]) -> Vec<usize>;

    /// Get layer name
    fn name(&self) -> &str;

    /// Get trainable parameters
    fn get_weights(&self) -> Vec<ArrayD<f64>>;

    /// Set trainable parameters
    fn set_weights(&mut self, weights: Vec<ArrayD<f64>>) -> Result<()>;

    /// Get number of parameters
    fn count_params(&self) -> usize {
        self.get_weights().iter().map(|w| w.len()).sum()
    }

    /// Check if layer is built
    fn built(&self) -> bool;
}
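
// A minimal sketch of a custom layer implementing `KerasLayer`: it scales its
// input by a single trainable factor. `ScaleLayer` is illustrative only, not
// part of the crate's API, and is compiled only under `cfg(test)`.
#[cfg(test)]
#[allow(dead_code)]
mod keras_layer_example {
    use super::*;

    struct ScaleLayer {
        name: String,
        scale: Option<ArrayD<f64>>,
    }

    impl KerasLayer for ScaleLayer {
        fn build(&mut self, _input_shape: &[usize]) -> Result<()> {
            // One trainable scalar, stored as a 1-element dynamic array
            self.scale = Some(ArrayD::ones(vec![1]));
            Ok(())
        }

        fn call(&self, inputs: &ArrayD<f64>) -> Result<ArrayD<f64>> {
            let s = self
                .scale
                .as_ref()
                .and_then(|w| w.iter().next().copied())
                .unwrap_or(1.0);
            Ok(inputs.mapv(|x| x * s))
        }

        fn compute_output_shape(&self, input_shape: &[usize]) -> Vec<usize> {
            input_shape.to_vec() // elementwise scaling preserves shape
        }

        fn name(&self) -> &str {
            &self.name
        }

        fn get_weights(&self) -> Vec<ArrayD<f64>> {
            self.scale.iter().cloned().collect()
        }

        fn set_weights(&mut self, weights: Vec<ArrayD<f64>>) -> Result<()> {
            self.scale = weights.into_iter().next();
            Ok(())
        }

        fn built(&self) -> bool {
            self.scale.is_some()
        }
    }
}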

/// Activation function types
#[derive(Debug, Clone)]
pub enum ActivationFunction {
    /// Linear activation (identity)
    Linear,
    /// ReLU activation
    ReLU,
    /// Sigmoid activation
    Sigmoid,
    /// Tanh activation
    Tanh,
    /// Softmax activation
    Softmax,
    /// Leaky ReLU with alpha
    LeakyReLU(f64),
    /// ELU with alpha
    ELU(f64),
}
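
// A hedged sketch of the scalar maps these variants conventionally denote.
// The crate's real implementation lives in the `Activation` layer; this
// helper is illustrative only. Softmax is omitted because it normalizes
// across a whole axis rather than acting elementwise.
#[cfg(test)]
#[allow(dead_code)]
mod activation_example {
    use super::ActivationFunction;

    fn apply_scalar(f: &ActivationFunction, x: f64) -> f64 {
        match f {
            ActivationFunction::Linear => x,
            ActivationFunction::ReLU => x.max(0.0),
            ActivationFunction::Sigmoid => 1.0 / (1.0 + (-x).exp()),
            ActivationFunction::Tanh => x.tanh(),
            // Scaled identity below zero
            ActivationFunction::LeakyReLU(alpha) => {
                if x > 0.0 {
                    x
                } else {
                    alpha * x
                }
            }
            // Smooth saturation below zero
            ActivationFunction::ELU(alpha) => {
                if x > 0.0 {
                    x
                } else {
                    alpha * (x.exp() - 1.0)
                }
            }
            // Not elementwise; see the note above
            ActivationFunction::Softmax => x,
        }
    }
}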

/// Weight initializer types
#[derive(Debug, Clone)]
pub enum InitializerType {
    /// All zeros
    Zeros,
    /// All ones
    Ones,
    /// Glorot uniform (Xavier uniform)
    GlorotUniform,
    /// Glorot normal (Xavier normal)
    GlorotNormal,
    /// He uniform
    HeUniform,
}
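
// A sketch of the scale each initializer conventionally uses (Glorot/Xavier
// and He initialization). The actual sampling happens inside the layer
// implementations; `init_scale` is illustrative only: it returns the uniform
// limit for the uniform variants and the standard deviation for the normal
// variant.
#[cfg(test)]
#[allow(dead_code)]
mod initializer_example {
    use super::InitializerType;

    fn init_scale(init: &InitializerType, fan_in: usize, fan_out: usize) -> f64 {
        match init {
            // Constant initializers need no scale
            InitializerType::Zeros | InitializerType::Ones => 0.0,
            // limit = sqrt(6 / (fan_in + fan_out))
            InitializerType::GlorotUniform => (6.0 / (fan_in + fan_out) as f64).sqrt(),
            // stddev = sqrt(2 / (fan_in + fan_out))
            InitializerType::GlorotNormal => (2.0 / (fan_in + fan_out) as f64).sqrt(),
            // limit = sqrt(6 / fan_in)
            InitializerType::HeUniform => (6.0 / fan_in as f64).sqrt(),
        }
    }
}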

/// Sequential model
pub struct Sequential {
    /// Layers in the model
    layers: Vec<Box<dyn KerasLayer>>,
    /// Model name
    name: String,
    /// Built flag
    built: bool,
    /// Compiled flag
    compiled: bool,
    /// Input shape
    input_shape: Option<Vec<usize>>,
    /// Loss function
    loss: Option<LossFunction>,
    /// Optimizer
    optimizer: Option<OptimizerType>,
    /// Metrics
    metrics: Vec<MetricType>,
}

impl Sequential {
    /// Create new sequential model
    pub fn new() -> Self {
        Self {
            layers: Vec::new(),
            name: format!("sequential_{}", fastrand::u32(..)),
            built: false,
            compiled: false,
            input_shape: None,
            loss: None,
            optimizer: None,
            metrics: Vec::new(),
        }
    }

    /// Set model name
    pub fn name(mut self, name: impl Into<String>) -> Self {
        self.name = name.into();
        self
    }

    /// Add layer to model
    pub fn add(&mut self, layer: Box<dyn KerasLayer>) {
        self.layers.push(layer);
        self.built = false;
    }

    /// Build the model with given input shape
    pub fn build(&mut self, input_shape: Vec<usize>) -> Result<()> {
        self.input_shape = Some(input_shape.clone());
        let mut current_shape = input_shape;

        for layer in &mut self.layers {
            layer.build(&current_shape)?;
            current_shape = layer.compute_output_shape(&current_shape);
        }

        self.built = true;
        Ok(())
    }

    /// Compile the model
    pub fn compile(
        mut self,
        loss: LossFunction,
        optimizer: OptimizerType,
        metrics: Vec<MetricType>,
    ) -> Self {
        self.loss = Some(loss);
        self.optimizer = Some(optimizer);
        self.metrics = metrics;
        self.compiled = true;
        self
    }

    /// Get model summary
    pub fn summary(&self) -> ModelSummary {
        let mut layers_info = Vec::new();
        let mut total_params = 0;
        let mut trainable_params = 0;

        let mut current_shape = self.input_shape.clone().unwrap_or_default();

        for layer in &self.layers {
            let output_shape = layer.compute_output_shape(&current_shape);
            let params = layer.count_params();

            layers_info.push(LayerInfo {
                name: layer.name().to_string(),
                layer_type: "Layer".to_string(),
                output_shape: output_shape.clone(),
                param_count: params,
            });

            total_params += params;
            trainable_params += params;
            current_shape = output_shape;
        }

        ModelSummary {
            layers: layers_info,
            total_params,
            trainable_params,
            non_trainable_params: 0,
        }
    }

    /// Forward pass (predict)
    pub fn predict(&self, inputs: &ArrayD<f64>) -> Result<ArrayD<f64>> {
        if !self.built {
            return Err(MLError::InvalidConfiguration(
                "Model must be built before prediction".to_string(),
            ));
        }

        let mut current = inputs.clone();

        for layer in &self.layers {
            current = layer.call(&current)?;
        }

        Ok(current)
    }

    /// Train the model
    ///
    /// Note: `backward_pass` is currently a placeholder, so parameters are not
    /// yet updated; this loop computes losses and metrics and drives callbacks.
    #[allow(non_snake_case)]
    pub fn fit(
        &mut self,
        X: &ArrayD<f64>,
        y: &ArrayD<f64>,
        epochs: usize,
        batch_size: Option<usize>,
        validation_data: Option<(&ArrayD<f64>, &ArrayD<f64>)>,
        callbacks: Vec<Box<dyn Callback>>,
    ) -> Result<TrainingHistory> {
        if !self.compiled {
            return Err(MLError::InvalidConfiguration(
                "Model must be compiled before training".to_string(),
            ));
        }

        let batch_size = batch_size.unwrap_or(32);
        let n_samples = X.shape()[0];
        // Ceiling division so a final partial batch is not dropped
        let n_batches = (n_samples + batch_size - 1) / batch_size;

        let mut history = TrainingHistory::new();

        for epoch in 0..epochs {
            let mut epoch_loss = 0.0;
            let mut epoch_metrics: HashMap<String, f64> = HashMap::new();

            for metric in &self.metrics {
                epoch_metrics.insert(metric.name(), 0.0);
            }

            for batch_idx in 0..n_batches {
                let start_idx = batch_idx * batch_size;
                let end_idx = ((batch_idx + 1) * batch_size).min(n_samples);

                // Slicing with `s![.., ..]` assumes 2-D (batch, features) inputs
                let X_batch = X.slice(s![start_idx..end_idx, ..]).to_owned().into_dyn();
                let y_batch = y.slice(s![start_idx..end_idx, ..]).to_owned().into_dyn();

                let predictions = self.predict(&X_batch)?;

                let loss = self.compute_loss(&predictions, &y_batch)?;
                epoch_loss += loss;

                self.backward_pass(&predictions, &y_batch)?;

                for metric in &self.metrics {
                    let metric_value = metric.compute(&predictions, &y_batch)?;
                    *epoch_metrics.entry(metric.name()).or_insert(0.0) += metric_value;
                }
            }

            epoch_loss /= n_batches as f64;
            for value in epoch_metrics.values_mut() {
                *value /= n_batches as f64;
            }

            let (val_loss, val_metrics) = if let Some((X_val, y_val)) = validation_data {
                let val_predictions = self.predict(X_val)?;
                let val_loss = self.compute_loss(&val_predictions, y_val)?;

                let mut val_metrics = HashMap::new();
                for metric in &self.metrics {
                    let metric_value = metric.compute(&val_predictions, y_val)?;
                    val_metrics.insert(format!("val_{}", metric.name()), metric_value);
                }

                (Some(val_loss), val_metrics)
            } else {
                (None, HashMap::new())
            };

            history.add_epoch(epoch_loss, epoch_metrics, val_loss, val_metrics);

            for callback in &callbacks {
                callback.on_epoch_end(epoch, &history)?;
            }

            println!("Epoch {}/{} - loss: {:.4}", epoch + 1, epochs, epoch_loss);
        }

        Ok(history)
    }

    /// Evaluate the model
    #[allow(non_snake_case)]
    pub fn evaluate(
        &self,
        X: &ArrayD<f64>,
        y: &ArrayD<f64>,
        _batch_size: Option<usize>,
    ) -> Result<HashMap<String, f64>> {
        let predictions = self.predict(X)?;
        let loss = self.compute_loss(&predictions, y)?;

        let mut results = HashMap::new();
        results.insert("loss".to_string(), loss);

        for metric in &self.metrics {
            let metric_value = metric.compute(&predictions, y)?;
            results.insert(metric.name(), metric_value);
        }

        Ok(results)
    }

    /// Compute loss
    fn compute_loss(&self, predictions: &ArrayD<f64>, targets: &ArrayD<f64>) -> Result<f64> {
        if let Some(ref loss_fn) = self.loss {
            loss_fn.compute(predictions, targets)
        } else {
            Err(MLError::InvalidConfiguration(
                "Loss function not specified".to_string(),
            ))
        }
    }

    /// Backward pass (placeholder: gradient computation and parameter updates
    /// are not implemented yet)
    fn backward_pass(&mut self, _predictions: &ArrayD<f64>, _targets: &ArrayD<f64>) -> Result<()> {
        Ok(())
    }
}

impl Default for Sequential {
    fn default() -> Self {
        Self::new()
    }
}
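
// End-to-end usage sketch for `Sequential`: add layers, build with an
// explicit (batch, features) shape, compile, and run a forward pass. The
// layer constructors come from this module's re-exports; all shapes and
// hyperparameters here are arbitrary illustrations.
#[cfg(test)]
mod sequential_usage_example {
    use super::*;
    use scirs2_core::ndarray::Array;

    #[test]
    fn builds_compiles_and_predicts() {
        let mut model = Sequential::new().name("example");
        model.add(Box::new(Dense::new(8).activation(ActivationFunction::ReLU)));
        model.add(Box::new(Dense::new(1).activation(ActivationFunction::Sigmoid)));

        // Build with a concrete input shape, then compile (which consumes
        // `self` and returns the compiled model, builder-style)
        model.build(vec![4, 3]).expect("build");
        let model = model.compile(
            LossFunction::MeanSquaredError,
            OptimizerType::SGD {
                learning_rate: 0.01,
                momentum: 0.9,
            },
            vec![MetricType::Accuracy],
        );

        let x = Array::zeros((4, 3)).into_dyn();
        let _pred = model.predict(&x).expect("predict");
    }
}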

/// Loss functions
#[derive(Debug, Clone)]
pub enum LossFunction {
    /// Mean squared error
    MeanSquaredError,
    /// Binary crossentropy
    BinaryCrossentropy,
    /// Categorical crossentropy
    CategoricalCrossentropy,
    /// Sparse categorical crossentropy
    SparseCategoricalCrossentropy,
    /// Mean absolute error
    MeanAbsoluteError,
    /// Huber loss
    Huber(f64),
}

impl LossFunction {
    /// Compute loss
    pub fn compute(&self, predictions: &ArrayD<f64>, targets: &ArrayD<f64>) -> Result<f64> {
        match self {
            LossFunction::MeanSquaredError => {
                let diff = predictions - targets;
                diff.mapv(|x| x * x).mean().ok_or_else(|| {
                    MLError::ComputationError("Failed to compute mean of empty array".to_string())
                })
            }
            LossFunction::BinaryCrossentropy => {
                // Clip predictions away from 0 and 1 so the logs stay finite
                let epsilon = 1e-15;
                let clipped_preds = predictions.mapv(|x| x.clamp(epsilon, 1.0 - epsilon));
                let loss = targets * clipped_preds.mapv(|x| x.ln())
                    + (1.0 - targets) * clipped_preds.mapv(|x| (1.0 - x).ln());
                loss.mean().map(|m| -m).ok_or_else(|| {
                    MLError::ComputationError("Failed to compute mean of empty array".to_string())
                })
            }
            LossFunction::MeanAbsoluteError => {
                let diff = predictions - targets;
                diff.mapv(|x| x.abs()).mean().ok_or_else(|| {
                    MLError::ComputationError("Failed to compute mean of empty array".to_string())
                })
            }
            _ => Err(MLError::InvalidConfiguration(format!(
                "Loss function {:?} is not implemented",
                self
            ))),
        }
    }
}
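
// Quick numeric sanity check for the implemented losses: the MSE of
// predictions [0, 1] against targets [1, 1] is ((-1)^2 + 0^2) / 2 = 0.5.
// Assumes `scirs2_core::ndarray` re-exports `arr1` like the other ndarray
// items used in this file.
#[cfg(test)]
mod loss_example {
    use super::LossFunction;
    use scirs2_core::ndarray::arr1;

    #[test]
    fn mse_matches_hand_computation() {
        let preds = arr1(&[0.0, 1.0]).into_dyn();
        let targets = arr1(&[1.0, 1.0]).into_dyn();
        let loss = LossFunction::MeanSquaredError
            .compute(&preds, &targets)
            .expect("mse");
        assert!((loss - 0.5).abs() < 1e-12);
    }
}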

/// Optimizer types
#[derive(Debug, Clone)]
pub enum OptimizerType {
    /// Stochastic Gradient Descent
    SGD { learning_rate: f64, momentum: f64 },
    /// Adam optimizer
    Adam {
        learning_rate: f64,
        beta1: f64,
        beta2: f64,
        epsilon: f64,
    },
    /// RMSprop optimizer
    RMSprop {
        learning_rate: f64,
        rho: f64,
        epsilon: f64,
    },
    /// AdaGrad optimizer
    AdaGrad { learning_rate: f64, epsilon: f64 },
}
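
// The enum above only stores hyperparameters; the updates they describe would
// live in the training loop. As a sketch, SGD with momentum applies
// v <- momentum * v - learning_rate * grad, then p <- p + v. `sgd_step` below
// is illustrative only, not part of the crate's API.
#[cfg(test)]
#[allow(dead_code)]
mod optimizer_example {
    use super::OptimizerType;

    /// One in-place SGD(+momentum) step over flat parameter/gradient slices
    fn sgd_step(opt: &OptimizerType, params: &mut [f64], grads: &[f64], velocity: &mut [f64]) {
        if let OptimizerType::SGD {
            learning_rate,
            momentum,
        } = opt
        {
            for ((p, g), v) in params.iter_mut().zip(grads).zip(velocity.iter_mut()) {
                *v = momentum * *v - learning_rate * g;
                *p += *v;
            }
        }
    }
}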

/// Metric types
#[derive(Debug, Clone)]
pub enum MetricType {
    /// Accuracy
    Accuracy,
    /// Precision
    Precision,
    /// Recall
    Recall,
    /// F1 Score
    F1Score,
    /// Mean Absolute Error
    MeanAbsoluteError,
    /// Mean Squared Error
    MeanSquaredError,
}

impl MetricType {
    /// Get metric name
    pub fn name(&self) -> String {
        match self {
            MetricType::Accuracy => "accuracy".to_string(),
            MetricType::Precision => "precision".to_string(),
            MetricType::Recall => "recall".to_string(),
            MetricType::F1Score => "f1_score".to_string(),
            MetricType::MeanAbsoluteError => "mean_absolute_error".to_string(),
            MetricType::MeanSquaredError => "mean_squared_error".to_string(),
        }
    }

    /// Compute metric
    pub fn compute(&self, predictions: &ArrayD<f64>, targets: &ArrayD<f64>) -> Result<f64> {
        match self {
            MetricType::Accuracy => {
                // Binary accuracy: raw outputs are thresholded at 0.5
                let pred_classes = predictions.mapv(|x| if x > 0.5 { 1.0 } else { 0.0 });
                let correct = pred_classes
                    .iter()
                    .zip(targets.iter())
                    .filter(|(&pred, &target)| (pred - target).abs() < 1e-6)
                    .count();
                Ok(correct as f64 / targets.len() as f64)
            }
            MetricType::MeanAbsoluteError => {
                let diff = predictions - targets;
                diff.mapv(|x| x.abs()).mean().ok_or_else(|| {
                    MLError::ComputationError("Failed to compute mean of empty array".to_string())
                })
            }
            MetricType::MeanSquaredError => {
                let diff = predictions - targets;
                diff.mapv(|x| x * x).mean().ok_or_else(|| {
                    MLError::ComputationError("Failed to compute mean of empty array".to_string())
                })
            }
            _ => Err(MLError::InvalidConfiguration(format!(
                "Metric {:?} is not implemented",
                self
            ))),
        }
    }
}
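
// `Accuracy` thresholds raw outputs at 0.5, so it is meaningful for binary
// (sigmoid) outputs. A quick check: predictions [0.9, 0.2, 0.7, 0.1] against
// targets [1, 0, 0, 0] agree on 3 of 4 samples. As above, this assumes
// `scirs2_core::ndarray` re-exports `arr1`.
#[cfg(test)]
mod metric_example {
    use super::MetricType;
    use scirs2_core::ndarray::arr1;

    #[test]
    fn binary_accuracy() {
        let preds = arr1(&[0.9, 0.2, 0.7, 0.1]).into_dyn();
        let targets = arr1(&[1.0, 0.0, 0.0, 0.0]).into_dyn();
        let acc = MetricType::Accuracy
            .compute(&preds, &targets)
            .expect("accuracy");
        assert!((acc - 0.75).abs() < 1e-12);
    }
}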

/// Training history
#[derive(Debug, Clone)]
pub struct TrainingHistory {
    /// Training loss for each epoch
    pub loss: Vec<f64>,
    /// Training metrics for each epoch
    pub metrics: Vec<HashMap<String, f64>>,
    /// Validation loss for each epoch
    pub val_loss: Vec<f64>,
    /// Validation metrics for each epoch
    pub val_metrics: Vec<HashMap<String, f64>>,
}

impl TrainingHistory {
    /// Create new training history
    pub fn new() -> Self {
        Self {
            loss: Vec::new(),
            metrics: Vec::new(),
            val_loss: Vec::new(),
            val_metrics: Vec::new(),
        }
    }

    /// Add epoch results
    ///
    /// Note: without validation data, `val_loss` stays empty while
    /// `val_metrics` still receives an (empty) map for the epoch.
    pub fn add_epoch(
        &mut self,
        loss: f64,
        metrics: HashMap<String, f64>,
        val_loss: Option<f64>,
        val_metrics: HashMap<String, f64>,
    ) {
        self.loss.push(loss);
        self.metrics.push(metrics);

        if let Some(val_loss) = val_loss {
            self.val_loss.push(val_loss);
        }
        self.val_metrics.push(val_metrics);
    }
}

impl Default for TrainingHistory {
    fn default() -> Self {
        Self::new()
    }
}

/// Model summary information
#[derive(Debug)]
pub struct ModelSummary {
    /// Layer information
    pub layers: Vec<LayerInfo>,
    /// Total number of parameters
    pub total_params: usize,
    /// Number of trainable parameters
    pub trainable_params: usize,
    /// Number of non-trainable parameters
    pub non_trainable_params: usize,
}

/// Layer information for summary
#[derive(Debug)]
pub struct LayerInfo {
    /// Layer name
    pub name: String,
    /// Layer type
    pub layer_type: String,
    /// Output shape
    pub output_shape: Vec<usize>,
    /// Parameter count
    pub param_count: usize,
}

/// Model input specification
pub struct Input {
    /// Input shape (excluding batch dimension)
    pub shape: Vec<usize>,
    /// Input name
    pub name: Option<String>,
    /// Data type
    pub dtype: DataType,
}

impl Input {
    /// Create new input specification
    pub fn new(shape: Vec<usize>) -> Self {
        Self {
            shape,
            name: None,
            dtype: DataType::Float64,
        }
    }

    /// Set input name
    pub fn name(mut self, name: impl Into<String>) -> Self {
        self.name = Some(name.into());
        self
    }

    /// Set data type
    pub fn dtype(mut self, dtype: DataType) -> Self {
        self.dtype = dtype;
        self
    }
}

/// Data types
#[derive(Debug, Clone)]
pub enum DataType {
    /// 32-bit float
    Float32,
    /// 64-bit float
    Float64,
    /// 32-bit integer
    Int32,
    /// 64-bit integer
    Int64,
}

/// Utility functions for building models
pub mod utils {
    use super::*;

    /// Create a simple sequential model for classification
    ///
    /// `input_dim` is currently unused: the input shape is supplied when the
    /// model is built.
    pub fn create_classification_model(
        _input_dim: usize,
        num_classes: usize,
        hidden_layers: Vec<usize>,
    ) -> Sequential {
        let mut model = Sequential::new();

        for (i, &units) in hidden_layers.iter().enumerate() {
            model.add(Box::new(
                Dense::new(units)
                    .activation(ActivationFunction::ReLU)
                    .name(format!("dense_{}", i)),
            ));
        }

        // Sigmoid for binary classification, softmax otherwise
        let output_activation = if num_classes == 2 {
            ActivationFunction::Sigmoid
        } else {
            ActivationFunction::Softmax
        };

        model.add(Box::new(
            Dense::new(num_classes)
                .activation(output_activation)
                .name("output"),
        ));

        model
    }

    /// Create a quantum neural network model
    pub fn create_quantum_model(
        num_qubits: usize,
        num_classes: usize,
        num_layers: usize,
    ) -> Sequential {
        let mut model = Sequential::new();

        model.add(Box::new(
            QuantumDense::new(num_qubits, num_classes)
                .num_layers(num_layers)
                .ansatz_type(QuantumAnsatzType::HardwareEfficient)
                .name("quantum_layer"),
        ));

        if num_classes > 1 {
            model.add(Box::new(
                Activation::new(ActivationFunction::Softmax).name("softmax"),
            ));
        }

        model
    }

    /// Create a hybrid quantum-classical model
    ///
    /// `input_dim` is currently unused: the input shape is supplied when the
    /// model is built.
    pub fn create_hybrid_model(
        _input_dim: usize,
        num_qubits: usize,
        num_classes: usize,
        classical_hidden: Vec<usize>,
    ) -> Sequential {
        let mut model = Sequential::new();

        for (i, &units) in classical_hidden.iter().enumerate() {
            model.add(Box::new(
                Dense::new(units)
                    .activation(ActivationFunction::ReLU)
                    .name(format!("classical_{}", i)),
            ));
        }

        // Quantum middle block with a fixed 64-unit output
        model.add(Box::new(
            QuantumDense::new(num_qubits, 64)
                .num_layers(2)
                .ansatz_type(QuantumAnsatzType::HardwareEfficient)
                .name("quantum_layer"),
        ));

        model.add(Box::new(
            Dense::new(num_classes)
                .activation(if num_classes == 2 {
                    ActivationFunction::Sigmoid
                } else {
                    ActivationFunction::Softmax
                })
                .name("output"),
        ));

        model
    }
}
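
// Usage sketch for the helpers above: a classifier with two hidden layers.
// The dimensions are arbitrary; as noted, `input_dim` is ignored because the
// input shape is supplied at build time.
#[cfg(test)]
mod utils_example {
    use super::*;

    #[test]
    fn classification_helper_builds() {
        let mut model = utils::create_classification_model(16, 3, vec![32, 16]);
        model.build(vec![8, 16]).expect("build");
        // Two hidden Dense layers plus the softmax output layer
        assert_eq!(model.summary().layers.len(), 3);
    }
}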

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_dense_layer() {
        let mut dense = Dense::new(10)
            .activation(ActivationFunction::ReLU)
            .name("test_dense");

        assert!(!dense.built());

        dense.build(&[5]).expect("Should build successfully");

        assert!(dense.built());
        assert_eq!(dense.compute_output_shape(&[32, 5]), vec![32, 10]);
    }

    #[test]
    fn test_sequential_model() {
        let mut model = Sequential::new();
        model.add(Box::new(Dense::new(10)));
        model.add(Box::new(Activation::new(ActivationFunction::ReLU)));
        model.add(Box::new(Dense::new(5)));

        model
            .build(vec![32, 20])
            .expect("Should build successfully");

        let summary = model.summary();
        assert_eq!(summary.layers.len(), 3);
    }

    #[test]
    fn test_activation_functions() {
        let relu = ActivationFunction::ReLU;
        let sigmoid = ActivationFunction::Sigmoid;
        let _tanh = ActivationFunction::Tanh;

        let mut act_relu = Activation::new(relu);
        act_relu.build(&[10]).expect("Should build");

        let mut act_sigmoid = Activation::new(sigmoid);
        act_sigmoid.build(&[10]).expect("Should build");
    }
}