quantrs2_ml/pytorch_api/mod.rs

//! PyTorch-like API for quantum machine learning models
//!
//! This module provides a familiar PyTorch-style interface for building,
//! training, and deploying quantum ML models, making it easier for classical
//! ML practitioners to adopt quantum algorithms.
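//!
//! A minimal usage sketch (hypothetical `input`: any prepared `SciRS2Array`;
//! the layer types are re-exported from the submodules below):
//!
//! ```ignore
//! let mut model = QuantumSequential::new()
//!     .add(Box::new(QuantumLinear::new(4, 8)?))
//!     .add(Box::new(QuantumActivation::relu()))
//!     .add(Box::new(QuantumLinear::new(8, 2)?));
//!
//! model.train(false); // inference mode
//! let output = model.forward(&input)?;
//! ```
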
mod conv;
mod data;
mod layers;
mod loss;
mod rnn;
mod schedulers;
mod transformer;

pub use conv::*;
pub use data::*;
pub use layers::*;
pub use loss::*;
pub use rnn::*;
pub use schedulers::*;
pub use transformer::*;

use crate::error::{MLError, Result};
use crate::scirs2_integration::{SciRS2Array, SciRS2Optimizer};
use scirs2_core::ndarray::{ArrayD, IxDyn};

/// Base trait for all quantum ML modules
pub trait QuantumModule: Send + Sync {
    /// Forward pass
    fn forward(&mut self, input: &SciRS2Array) -> Result<SciRS2Array>;

    /// Get all parameters
    fn parameters(&self) -> Vec<Parameter>;

    /// Set training mode
    fn train(&mut self, mode: bool);

    /// Check if module is in training mode
    fn training(&self) -> bool;

    /// Zero gradients of all parameters
    fn zero_grad(&mut self);

    /// Module name for debugging
    fn name(&self) -> &str;
}
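
// A minimal sketch of implementing this trait for a custom layer (hypothetical
// parameter-free `Identity` module; real layers would also expose their
// `Parameter`s and clear gradients in `zero_grad`):
//
//     struct Identity {
//         training: bool,
//     }
//
//     impl QuantumModule for Identity {
//         fn forward(&mut self, input: &SciRS2Array) -> Result<SciRS2Array> {
//             Ok(input.clone())
//         }
//         fn parameters(&self) -> Vec<Parameter> {
//             Vec::new() // no trainable parameters
//         }
//         fn train(&mut self, mode: bool) {
//             self.training = mode;
//         }
//         fn training(&self) -> bool {
//             self.training
//         }
//         fn zero_grad(&mut self) {} // nothing to clear
//         fn name(&self) -> &str {
//             "Identity"
//         }
//     }
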
/// Quantum parameter wrapper
#[derive(Debug, Clone)]
pub struct Parameter {
    /// Parameter data
    pub data: SciRS2Array,
    /// Parameter name
    pub name: String,
    /// Whether parameter requires gradient
    pub requires_grad: bool,
}

impl Parameter {
    /// Create new parameter
    pub fn new(data: SciRS2Array, name: impl Into<String>) -> Self {
        Self {
            data,
            name: name.into(),
            requires_grad: true,
        }
    }

    /// Create parameter without gradients
    pub fn no_grad(data: SciRS2Array, name: impl Into<String>) -> Self {
        Self {
            data,
            name: name.into(),
            requires_grad: false,
        }
    }

    /// Get parameter shape
    pub fn shape(&self) -> &[usize] {
        self.data.data.shape()
    }

    /// Get parameter size (number of elements)
    pub fn numel(&self) -> usize {
        self.data.data.len()
    }
}
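
// Example: a 2x3 weight parameter with gradients enabled (mirrors the
// `test_parameter` unit test below):
//
//     let data = ArrayD::from_shape_vec(IxDyn(&[2, 3]), vec![0.0; 6])
//         .expect("valid shape");
//     let weight = Parameter::new(SciRS2Array::new(data, true), "weight");
//     assert_eq!(weight.shape(), &[2, 3]);
//     assert_eq!(weight.numel(), 6);
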
/// Sequential container for quantum modules
pub struct QuantumSequential {
    /// Ordered modules
    modules: Vec<Box<dyn QuantumModule>>,
    /// Training mode
    training: bool,
}

impl QuantumSequential {
    /// Create new sequential container
    pub fn new() -> Self {
        Self {
            modules: Vec::new(),
            training: true,
        }
    }

    /// Add module to sequence (builder-style: consumes and returns `self`)
    pub fn add(mut self, module: Box<dyn QuantumModule>) -> Self {
        self.modules.push(module);
        self
    }

    /// Get number of modules
    pub fn len(&self) -> usize {
        self.modules.len()
    }

    /// Check if empty
    pub fn is_empty(&self) -> bool {
        self.modules.is_empty()
    }
}

impl Default for QuantumSequential {
    fn default() -> Self {
        Self::new()
    }
}

impl QuantumModule for QuantumSequential {
    fn forward(&mut self, input: &SciRS2Array) -> Result<SciRS2Array> {
        let mut output = input.clone();

        for module in &mut self.modules {
            output = module.forward(&output)?;
        }

        Ok(output)
    }

    fn parameters(&self) -> Vec<Parameter> {
        let mut all_params = Vec::new();

        for module in &self.modules {
            all_params.extend(module.parameters());
        }

        all_params
    }

    fn train(&mut self, mode: bool) {
        self.training = mode;
        for module in &mut self.modules {
            module.train(mode);
        }
    }

    fn training(&self) -> bool {
        self.training
    }

    fn zero_grad(&mut self) {
        for module in &mut self.modules {
            module.zero_grad();
        }
    }

    fn name(&self) -> &str {
        "QuantumSequential"
    }
}

/// Training history
#[derive(Debug, Clone)]
pub struct TrainingHistory {
    /// Loss values per epoch
    pub losses: Vec<f64>,
    /// Accuracy values per epoch (if applicable)
    pub accuracies: Vec<f64>,
    /// Validation losses
    pub val_losses: Vec<f64>,
    /// Validation accuracies
    pub val_accuracies: Vec<f64>,
}

impl TrainingHistory {
    /// Create new training history
    pub fn new() -> Self {
        Self {
            losses: Vec::new(),
            accuracies: Vec::new(),
            val_losses: Vec::new(),
            val_accuracies: Vec::new(),
        }
    }

    /// Add training metrics
    pub fn add_training(&mut self, loss: f64, accuracy: Option<f64>) {
        self.losses.push(loss);
        if let Some(acc) = accuracy {
            self.accuracies.push(acc);
        }
    }

    /// Add validation metrics
    pub fn add_validation(&mut self, loss: f64, accuracy: Option<f64>) {
        self.val_losses.push(loss);
        if let Some(acc) = accuracy {
            self.val_accuracies.push(acc);
        }
    }
}

impl Default for TrainingHistory {
    fn default() -> Self {
        Self::new()
    }
}
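
// Example: record per-epoch metrics (values are illustrative; mirrors the
// `test_training_history` unit test below):
//
//     let mut history = TrainingHistory::new();
//     history.add_training(0.52, Some(0.81));   // train loss, train accuracy
//     history.add_validation(0.61, Some(0.76)); // val loss, val accuracy
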
/// Training utilities
pub struct QuantumTrainer {
    /// Model to train
    model: Box<dyn QuantumModule>,
    /// Optimizer
    optimizer: SciRS2Optimizer,
    /// Loss function
    loss_fn: Box<dyn QuantumLoss>,
    /// Training history
    history: TrainingHistory,
}

impl QuantumTrainer {
    /// Create new trainer
    pub fn new(
        model: Box<dyn QuantumModule>,
        optimizer: SciRS2Optimizer,
        loss_fn: Box<dyn QuantumLoss>,
    ) -> Self {
        Self {
            model,
            optimizer,
            loss_fn,
            history: TrainingHistory::new(),
        }
    }

    /// Train for one epoch
    pub fn train_epoch<D: DataLoader>(&mut self, dataloader: &mut D) -> Result<f64> {
        self.model.train(true);

        let mut epoch_loss = 0.0;
        let mut batches = 0;

        while let Some((inputs, targets)) = dataloader.next_batch()? {
            // Zero gradients
            self.model.zero_grad();

            // Forward pass
            let predictions = self.model.forward(&inputs)?;

            // Compute loss (a scalar stored as a one-element array)
            let loss = self.loss_fn.forward(&predictions, &targets)?;
            let loss_val = loss.data.iter().next().copied().unwrap_or(0.0);

            // NOTE: the backward pass and the `self.optimizer` parameter update
            // are not wired in yet; this loop currently only accumulates the loss.

            epoch_loss += loss_val;
            batches += 1;
        }

        let avg_loss = if batches > 0 {
            epoch_loss / batches as f64
        } else {
            0.0
        };
        self.history.add_training(avg_loss, None);

        Ok(avg_loss)
    }

    /// Evaluate model
    pub fn evaluate<D: DataLoader>(&mut self, dataloader: &mut D) -> Result<f64> {
        self.model.train(false);

        let mut total_loss = 0.0;
        let mut batches = 0;

        while let Some((inputs, targets)) = dataloader.next_batch()? {
            let predictions = self.model.forward(&inputs)?;
            let loss = self.loss_fn.forward(&predictions, &targets)?;
            let loss_val = loss.data.iter().next().copied().unwrap_or(0.0);

            total_loss += loss_val;
            batches += 1;
        }

        let avg_loss = if batches > 0 {
            total_loss / batches as f64
        } else {
            0.0
        };

        Ok(avg_loss)
    }

    /// Get training history
    pub fn history(&self) -> &TrainingHistory {
        &self.history
    }
}
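
// A minimal training-loop sketch (hypothetical `model`, `optimizer`, `loss_fn`,
// `train_loader`, and `val_loader` values; `DataLoader` is re-exported from the
// `data` submodule above):
//
//     let mut trainer = QuantumTrainer::new(model, optimizer, loss_fn);
//     for epoch in 0..10 {
//         let train_loss = trainer.train_epoch(&mut train_loader)?;
//         let val_loss = trainer.evaluate(&mut val_loader)?;
//         println!("epoch {epoch}: train={train_loss:.4} val={val_loss:.4}");
//     }
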
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_quantum_linear() {
        let linear = QuantumLinear::new(4, 2).expect("QuantumLinear creation should succeed");
        assert_eq!(linear.in_features, 4);
        assert_eq!(linear.out_features, 2);
        assert_eq!(linear.parameters().len(), 1); // weights only

        let _linear_with_bias = linear.with_bias().expect("Adding bias should succeed");
        // Would have 2 parameters: weights and bias
    }

    #[test]
    fn test_quantum_sequential() {
        let model = QuantumSequential::new()
            .add(Box::new(
                QuantumLinear::new(4, 8).expect("QuantumLinear creation should succeed"),
            ))
            .add(Box::new(QuantumActivation::relu()))
            .add(Box::new(
                QuantumLinear::new(8, 2).expect("QuantumLinear creation should succeed"),
            ));

        assert_eq!(model.len(), 3);
        assert!(!model.is_empty());
    }

    #[test]
    fn test_quantum_activation() {
        let mut relu = QuantumActivation::relu();
        let input_data = ArrayD::from_shape_vec(IxDyn(&[2]), vec![-1.0, 1.0])
            .expect("Valid shape for input data");
        let input = SciRS2Array::new(input_data, false);

        let output = relu.forward(&input).expect("Forward pass should succeed");
        assert_eq!(output.data[[0]], 0.0); // ReLU(-1) = 0
        assert_eq!(output.data[[1]], 1.0); // ReLU(1) = 1
    }

    #[test]
    #[ignore]
    fn test_quantum_loss() {
        let mse_loss = QuantumMSELoss;

        let pred_data = ArrayD::from_shape_vec(IxDyn(&[2]), vec![1.0, 2.0])
            .expect("Valid shape for predictions");
        let target_data =
            ArrayD::from_shape_vec(IxDyn(&[2]), vec![1.5, 1.8]).expect("Valid shape for targets");

        let predictions = SciRS2Array::new(pred_data, false);
        let targets = SciRS2Array::new(target_data, false);

        let loss = mse_loss
            .forward(&predictions, &targets)
            .expect("Loss computation should succeed");
        assert!(loss.data[[0]] > 0.0); // Should have positive loss
    }

    #[test]
    fn test_parameter() {
        let data = ArrayD::from_shape_vec(IxDyn(&[2, 3]), vec![1.0; 6])
            .expect("Valid shape for parameter data");
        let param = Parameter::new(SciRS2Array::new(data, true), "test_param");

        assert_eq!(param.name, "test_param");
        assert!(param.requires_grad);
        assert_eq!(param.shape(), &[2, 3]);
        assert_eq!(param.numel(), 6);
    }

    #[test]
    fn test_training_history() {
        let mut history = TrainingHistory::new();
        history.add_training(0.5, Some(0.8));
        history.add_validation(0.6, Some(0.7));

        assert_eq!(history.losses.len(), 1);
        assert_eq!(history.accuracies.len(), 1);
        assert_eq!(history.val_losses.len(), 1);
        assert_eq!(history.val_accuracies.len(), 1);
    }
}
393}