//! Quantum neural network building blocks: the `QuantumModule` trait,
//! trainable parameters, a sequential container, and a simple training loop.

mod conv;
mod data;
mod layers;
mod loss;
mod rnn;
mod schedulers;
mod transformer;

pub use conv::*;
pub use data::*;
pub use layers::*;
pub use loss::*;
pub use rnn::*;
pub use schedulers::*;
pub use transformer::*;

use crate::error::{MLError, Result};
use crate::scirs2_integration::{SciRS2Array, SciRS2Optimizer};
use scirs2_core::ndarray::{ArrayD, IxDyn};

/// Base trait for quantum neural network modules.
pub trait QuantumModule: Send + Sync {
    /// Runs the forward pass on `input` and returns the module's output.
    fn forward(&mut self, input: &SciRS2Array) -> Result<SciRS2Array>;

    /// Returns this module's trainable parameters.
    fn parameters(&self) -> Vec<Parameter>;

    /// Switches between training (`true`) and evaluation (`false`) mode.
    fn train(&mut self, mode: bool);

    /// Returns `true` while the module is in training mode.
    fn training(&self) -> bool;

    /// Clears any accumulated gradients.
    fn zero_grad(&mut self);

    /// Returns a human-readable name for the module.
    fn name(&self) -> &str;
}
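
// Implementing `QuantumModule` for a custom layer means providing the six
// methods above. A minimal sketch (illustrative only; `Identity` is not part
// of this module):
//
//     struct Identity {
//         training: bool,
//     }
//
//     impl QuantumModule for Identity {
//         fn forward(&mut self, input: &SciRS2Array) -> Result<SciRS2Array> {
//             Ok(input.clone()) // pass the input through unchanged
//         }
//         fn parameters(&self) -> Vec<Parameter> {
//             Vec::new() // no trainable parameters
//         }
//         fn train(&mut self, mode: bool) {
//             self.training = mode;
//         }
//         fn training(&self) -> bool {
//             self.training
//         }
//         fn zero_grad(&mut self) {} // nothing to reset
//         fn name(&self) -> &str {
//             "Identity"
//         }
//     }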

#[derive(Debug, Clone)]
pub struct Parameter {
    /// The parameter's values, stored as a `SciRS2Array`.
    pub data: SciRS2Array,
    /// Name used to identify the parameter.
    pub name: String,
    /// Whether gradients are tracked for this parameter.
    pub requires_grad: bool,
}

impl Parameter {
    /// Creates a parameter that participates in gradient computation.
    pub fn new(data: SciRS2Array, name: impl Into<String>) -> Self {
        Self {
            data,
            name: name.into(),
            requires_grad: true,
        }
    }

    /// Creates a frozen parameter that is excluded from gradient computation.
    pub fn no_grad(data: SciRS2Array, name: impl Into<String>) -> Self {
        Self {
            data,
            name: name.into(),
            requires_grad: false,
        }
    }

    /// Returns the shape of the underlying array.
    pub fn shape(&self) -> &[usize] {
        self.data.data.shape()
    }

    /// Returns the total number of elements.
    pub fn numel(&self) -> usize {
        self.data.data.len()
    }
}
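
// A usage sketch for `Parameter` (mirrors the construction used in the tests
// below; `SciRS2Array::new(data, requires_grad)` comes from the
// scirs2_integration module):
//
//     let data = ArrayD::from_shape_vec(IxDyn(&[2, 3]), vec![0.0; 6])?;
//     let weight = Parameter::new(SciRS2Array::new(data, true), "weight");
//     assert_eq!(weight.numel(), 6); // 2 * 3 elements
//     assert!(weight.requires_grad);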

/// Container that chains modules and runs them in insertion order.
pub struct QuantumSequential {
    modules: Vec<Box<dyn QuantumModule>>,
    training: bool,
}

impl QuantumSequential {
    /// Creates an empty container in training mode.
    pub fn new() -> Self {
        Self {
            modules: Vec::new(),
            training: true,
        }
    }

    /// Appends a module, returning `self` for builder-style chaining.
    pub fn add(mut self, module: Box<dyn QuantumModule>) -> Self {
        self.modules.push(module);
        self
    }

    /// Returns the number of contained modules.
    pub fn len(&self) -> usize {
        self.modules.len()
    }

    /// Returns `true` if no modules have been added.
    pub fn is_empty(&self) -> bool {
        self.modules.is_empty()
    }
}

impl Default for QuantumSequential {
    fn default() -> Self {
        Self::new()
    }
}

impl QuantumModule for QuantumSequential {
    fn forward(&mut self, input: &SciRS2Array) -> Result<SciRS2Array> {
        let mut output = input.clone();

        for module in &mut self.modules {
            output = module.forward(&output)?;
        }

        Ok(output)
    }

    fn parameters(&self) -> Vec<Parameter> {
        let mut all_params = Vec::new();

        for module in &self.modules {
            all_params.extend(module.parameters());
        }

        all_params
    }

    fn train(&mut self, mode: bool) {
        self.training = mode;
        for module in &mut self.modules {
            module.train(mode);
        }
    }

    fn training(&self) -> bool {
        self.training
    }

    fn zero_grad(&mut self) {
        for module in &mut self.modules {
            module.zero_grad();
        }
    }

    fn name(&self) -> &str {
        "QuantumSequential"
    }
}
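
// A construction sketch mirroring the tests below (`QuantumLinear` and
// `QuantumActivation` are re-exported from the `layers` module; `input` is
// any `SciRS2Array` of compatible shape):
//
//     let mut model = QuantumSequential::new()
//         .add(Box::new(QuantumLinear::new(4, 8)?))
//         .add(Box::new(QuantumActivation::relu()))
//         .add(Box::new(QuantumLinear::new(8, 2)?));
//     let output = model.forward(&input)?; // modules run in insertion order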

/// Per-epoch metrics recorded during training.
#[derive(Debug, Clone)]
pub struct TrainingHistory {
    /// Training loss per epoch.
    pub losses: Vec<f64>,
    /// Training accuracy per epoch (recorded when provided).
    pub accuracies: Vec<f64>,
    /// Validation loss per epoch.
    pub val_losses: Vec<f64>,
    /// Validation accuracy per epoch (recorded when provided).
    pub val_accuracies: Vec<f64>,
}

impl TrainingHistory {
    /// Creates an empty history.
    pub fn new() -> Self {
        Self {
            losses: Vec::new(),
            accuracies: Vec::new(),
            val_losses: Vec::new(),
            val_accuracies: Vec::new(),
        }
    }

    /// Records a training loss and, if provided, a training accuracy.
    pub fn add_training(&mut self, loss: f64, accuracy: Option<f64>) {
        self.losses.push(loss);
        if let Some(acc) = accuracy {
            self.accuracies.push(acc);
        }
    }

    /// Records a validation loss and, if provided, a validation accuracy.
    pub fn add_validation(&mut self, loss: f64, accuracy: Option<f64>) {
        self.val_losses.push(loss);
        if let Some(acc) = accuracy {
            self.val_accuracies.push(acc);
        }
    }
}

impl Default for TrainingHistory {
    fn default() -> Self {
        Self::new()
    }
}
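
// Typical per-epoch bookkeeping with `TrainingHistory`:
//
//     let mut history = TrainingHistory::new();
//     history.add_training(0.42, Some(0.81)); // loss plus optional accuracy
//     history.add_validation(0.47, None);     // accuracy can be omitted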

/// Bundles a model with an optimizer and a loss function and tracks
/// per-epoch metrics in a `TrainingHistory`.
pub struct QuantumTrainer {
    model: Box<dyn QuantumModule>,
    optimizer: SciRS2Optimizer,
    loss_fn: Box<dyn QuantumLoss>,
    history: TrainingHistory,
}

impl QuantumTrainer {
    /// Creates a trainer from a model, an optimizer, and a loss function.
    pub fn new(
        model: Box<dyn QuantumModule>,
        optimizer: SciRS2Optimizer,
        loss_fn: Box<dyn QuantumLoss>,
    ) -> Self {
        Self {
            model,
            optimizer,
            loss_fn,
            history: TrainingHistory::new(),
        }
    }

    /// Runs one pass over `dataloader` in training mode and returns the
    /// average loss, which is also recorded in the history.
    pub fn train_epoch<D: DataLoader>(&mut self, dataloader: &mut D) -> Result<f64> {
        self.model.train(true);

        let mut epoch_loss = 0.0;
        let mut batches = 0;

        while let Some((inputs, targets)) = dataloader.next_batch()? {
            self.model.zero_grad();

            let predictions = self.model.forward(&inputs)?;

            let loss = self.loss_fn.forward(&predictions, &targets)?;
            // The loss is a scalar array; read out its single element.
            let loss_val = loss.data.iter().next().copied().unwrap_or(0.0);

            epoch_loss += loss_val;
            batches += 1;
        }

        let avg_loss = if batches > 0 {
            epoch_loss / batches as f64
        } else {
            0.0
        };
        self.history.add_training(avg_loss, None);

        Ok(avg_loss)
    }

    /// Runs one pass over `dataloader` in evaluation mode and returns the
    /// average loss.
    pub fn evaluate<D: DataLoader>(&mut self, dataloader: &mut D) -> Result<f64> {
        self.model.train(false);

        let mut total_loss = 0.0;
        let mut batches = 0;

        while let Some((inputs, targets)) = dataloader.next_batch()? {
            let predictions = self.model.forward(&inputs)?;
            let loss = self.loss_fn.forward(&predictions, &targets)?;
            let loss_val = loss.data.iter().next().copied().unwrap_or(0.0);

            total_loss += loss_val;
            batches += 1;
        }

        let avg_loss = if batches > 0 {
            total_loss / batches as f64
        } else {
            0.0
        };

        Ok(avg_loss)
    }

    /// Returns the recorded training history.
    pub fn history(&self) -> &TrainingHistory {
        &self.history
    }
}
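
// An end-to-end sketch of the training loop (illustrative; assumes a
// `DataLoader` implementation behind `train_loader`/`val_loader` and concrete
// model, optimizer, and loss values built from the sibling modules):
//
//     let mut trainer = QuantumTrainer::new(
//         Box::new(model),
//         optimizer,
//         Box::new(QuantumMSELoss),
//     );
//     for epoch in 0..10 {
//         let train_loss = trainer.train_epoch(&mut train_loader)?;
//         let val_loss = trainer.evaluate(&mut val_loader)?;
//         println!("epoch {epoch}: train {train_loss:.4}, val {val_loss:.4}");
//     }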

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_quantum_linear() {
        let linear = QuantumLinear::new(4, 2).expect("QuantumLinear creation should succeed");
        assert_eq!(linear.in_features, 4);
        assert_eq!(linear.out_features, 2);
        // Weights only until a bias is attached.
        assert_eq!(linear.parameters().len(), 1);
        let _linear_with_bias = linear.with_bias().expect("Adding bias should succeed");
    }

    #[test]
    fn test_quantum_sequential() {
        let model = QuantumSequential::new()
            .add(Box::new(
                QuantumLinear::new(4, 8).expect("QuantumLinear creation should succeed"),
            ))
            .add(Box::new(QuantumActivation::relu()))
            .add(Box::new(
                QuantumLinear::new(8, 2).expect("QuantumLinear creation should succeed"),
            ));

        assert_eq!(model.len(), 3);
        assert!(!model.is_empty());
    }

    #[test]
    fn test_quantum_activation() {
        let mut relu = QuantumActivation::relu();
        let input_data = ArrayD::from_shape_vec(IxDyn(&[2]), vec![-1.0, 1.0])
            .expect("Valid shape for input data");
        let input = SciRS2Array::new(input_data, false);

        let output = relu.forward(&input).expect("Forward pass should succeed");
        assert_eq!(output.data[[0]], 0.0); // ReLU clamps -1.0 to 0.0
        assert_eq!(output.data[[1]], 1.0); // ReLU passes 1.0 through
    }

    #[test]
    #[ignore]
    fn test_quantum_loss() {
        let mse_loss = QuantumMSELoss;

        let pred_data = ArrayD::from_shape_vec(IxDyn(&[2]), vec![1.0, 2.0])
            .expect("Valid shape for predictions");
        let target_data =
            ArrayD::from_shape_vec(IxDyn(&[2]), vec![1.5, 1.8]).expect("Valid shape for targets");

        let predictions = SciRS2Array::new(pred_data, false);
        let targets = SciRS2Array::new(target_data, false);

        let loss = mse_loss
            .forward(&predictions, &targets)
            .expect("Loss computation should succeed");
        assert!(loss.data[[0]] > 0.0); // MSE of distinct values is positive
    }

    #[test]
    fn test_parameter() {
        let data = ArrayD::from_shape_vec(IxDyn(&[2, 3]), vec![1.0; 6])
            .expect("Valid shape for parameter data");
        let param = Parameter::new(SciRS2Array::new(data, true), "test_param");

        assert_eq!(param.name, "test_param");
        assert!(param.requires_grad);
        assert_eq!(param.shape(), &[2, 3]);
        assert_eq!(param.numel(), 6);
    }

    #[test]
    fn test_training_history() {
        let mut history = TrainingHistory::new();
        history.add_training(0.5, Some(0.8));
        history.add_validation(0.6, Some(0.7));

        assert_eq!(history.losses.len(), 1);
        assert_eq!(history.accuracies.len(), 1);
        assert_eq!(history.val_losses.len(), 1);
        assert_eq!(history.val_accuracies.len(), 1);
    }
}
393}