// quantrs2_ml/quantum_pinns.rs

//! Quantum Physics-Informed Neural Networks (QPINNs)
//!
//! This module implements Quantum Physics-Informed Neural Networks, which incorporate
//! physical laws and constraints directly into quantum neural network architectures.
//! A QPINN approximates the solution of a partial differential equation (PDE) with a
//! parameterized quantum circuit and can additionally enforce physical conservation laws.
7
8use crate::error::Result;
9use scirs2_core::ndarray::{s, Array1, Array2, Array3, ArrayD, IxDyn};
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12
/// Configuration for Quantum Physics-Informed Neural Networks.
///
/// Bundles the circuit size, the space-time domain, the target equation with its
/// boundary/initial conditions, and the loss, ansatz, training and physics-constraint
/// settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QPINNConfig {
    /// Number of qubits in the quantum circuit.
    pub num_qubits: usize,
    /// Number of quantum layers (read by the physics-informed ansatz; the EfficientSU2
    /// ansatz uses `ansatz_config.repetitions` instead).
    pub num_layers: usize,
    /// Spatial domain bounds, one `(min, max)` pair per spatial dimension.
    pub domain_bounds: Vec<(f64, f64)>,
    /// Temporal domain bounds `(t_min, t_max)`.
    pub time_bounds: (f64, f64),
    /// Physical equation to solve.
    pub equation_type: PhysicsEquationType,
    /// Boundary conditions.
    pub boundary_conditions: Vec<BoundaryCondition>,
    /// Initial conditions.
    pub initial_conditions: Vec<InitialCondition>,
    /// Weights of the individual loss terms.
    pub loss_weights: LossWeights,
    /// Quantum ansatz configuration.
    pub ansatz_config: AnsatzConfig,
    /// Training configuration.
    pub training_config: TrainingConfig,
    /// Physics constraints.
    pub physics_constraints: PhysicsConstraints,
}
39
40impl Default for QPINNConfig {
41    fn default() -> Self {
42        Self {
43            num_qubits: 6,
44            num_layers: 4,
45            domain_bounds: vec![(-1.0, 1.0), (-1.0, 1.0)],
46            time_bounds: (0.0, 1.0),
47            equation_type: PhysicsEquationType::Poisson,
48            boundary_conditions: vec![],
49            initial_conditions: vec![],
50            loss_weights: LossWeights::default(),
51            ansatz_config: AnsatzConfig::default(),
52            training_config: TrainingConfig::default(),
53            physics_constraints: PhysicsConstraints::default(),
54        }
55    }
56}
57
/// Types of physics equations that can be solved.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum PhysicsEquationType {
    /// Poisson equation: ∇²u = f
    Poisson,
    /// Heat equation: ∂u/∂t = α∇²u
    Heat,
    /// Wave equation: ∂²u/∂t² = c²∇²u
    Wave,
    /// Schrödinger equation: iℏ∂ψ/∂t = Ĥψ
    Schrodinger,
    /// Navier-Stokes equations
    NavierStokes,
    /// Maxwell equations
    Maxwell,
    /// Klein-Gordon equation
    KleinGordon,
    /// Burgers equation: ∂u/∂t + u·∇u = ν∇²u
    Burgers,
    /// Custom PDE with a user-defined operator.
    Custom {
        /// Differential operator as free-form text (the expected syntax is not defined in
        /// this part of the module — TODO confirm).
        differential_operator: String,
        /// Equation form as free-form text (syntax TODO confirm).
        equation_form: String,
    },
}
83
/// Boundary condition specification: where it applies, what kind it is, and its value.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BoundaryCondition {
    /// Boundary location specification.
    pub boundary: BoundaryLocation,
    /// Type of boundary condition.
    pub condition_type: BoundaryType,
    /// Boundary value as a mathematical expression string (how it is parsed/evaluated is
    /// not visible in this part of the module — TODO confirm syntax).
    pub value_function: String, // Mathematical expression
}
94
/// Boundary location in the domain.
///
/// `Left`/`Right` refer to the first spatial coordinate and `Bottom`/`Top` to the second.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum BoundaryLocation {
    /// Left boundary (x = x_min)
    Left,
    /// Right boundary (x = x_max)
    Right,
    /// Bottom boundary (y = y_min)
    Bottom,
    /// Top boundary (y = y_max)
    Top,
    /// Custom boundary defined by an equation string (syntax TODO confirm).
    Custom(String),
}
109
/// Types of boundary conditions.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum BoundaryType {
    /// Dirichlet boundary condition (fixed value).
    Dirichlet,
    /// Neumann boundary condition (fixed derivative).
    Neumann,
    /// Robin boundary condition (mixed value/derivative) with coefficients `alpha` and
    /// `beta` (presumably `alpha·u + beta·∂u/∂n = g` — TODO confirm convention).
    Robin { alpha: f64, beta: f64 },
    /// Periodic boundary condition.
    Periodic,
}
122
/// Initial condition specification at the start of the time domain.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InitialCondition {
    /// Initial value as a mathematical expression string (syntax TODO confirm).
    pub value_function: String,
    /// Initial time derivative as an expression string, for second-order-in-time
    /// equations; `None` when not specified.
    pub derivative_function: Option<String>,
}
131
/// Weights applied to the individual terms of the training loss.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LossWeights {
    /// Weight for the PDE residual loss.
    pub pde_loss_weight: f64,
    /// Weight for the boundary condition loss.
    pub boundary_loss_weight: f64,
    /// Weight for the initial condition loss.
    pub initial_loss_weight: f64,
    /// Weight for the physics constraint loss.
    pub physics_constraint_weight: f64,
    /// Weight for the data fitting loss (if data is available).
    pub data_loss_weight: f64,
}
146
147impl Default for LossWeights {
148    fn default() -> Self {
149        Self {
150            pde_loss_weight: 1.0,
151            boundary_loss_weight: 10.0,
152            initial_loss_weight: 10.0,
153            physics_constraint_weight: 1.0,
154            data_loss_weight: 1.0,
155        }
156    }
157}
158
/// Quantum ansatz configuration: circuit family, entanglement layout, depth and
/// parameter initialization.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AnsatzConfig {
    /// Type of quantum ansatz.
    pub ansatz_type: QuantumAnsatzType,
    /// Entanglement pattern used by the entangling layers.
    pub entanglement_pattern: EntanglementPattern,
    /// Number of repetitions of the ansatz building block.
    pub repetitions: usize,
    /// Parameter initialization strategy.
    pub parameter_init: ParameterInitialization,
}
171
172impl Default for AnsatzConfig {
173    fn default() -> Self {
174        Self {
175            ansatz_type: QuantumAnsatzType::EfficientSU2,
176            entanglement_pattern: EntanglementPattern::Linear,
177            repetitions: 3,
178            parameter_init: ParameterInitialization::Random,
179        }
180    }
181}
182
/// Types of quantum ansatz circuits.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum QuantumAnsatzType {
    /// Efficient SU(2) ansatz (rotation layers interleaved with entangling layers).
    EfficientSU2,
    /// Two-local ansatz.
    TwoLocal,
    /// Alternating operator ansatz.
    AlternatingOperator,
    /// Hardware-efficient ansatz.
    HardwareEfficient,
    /// Physics-informed ansatz whose structure depends on the target equation.
    PhysicsInformed,
}
197
/// Entanglement patterns for the two-qubit layers of a quantum circuit.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum EntanglementPattern {
    /// Linear entanglement (nearest neighbour, no wrap-around).
    Linear,
    /// Circular entanglement (nearest neighbour, last qubit coupled back to the first).
    Circular,
    /// Full entanglement (all-to-all).
    Full,
    /// Custom entanglement pattern given as explicit qubit-index pairs.
    Custom(Vec<(usize, usize)>),
}
210
/// Parameter initialization strategies for the circuit's trainable angles.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ParameterInitialization {
    /// Random initialization.
    Random,
    /// Xavier/Glorot initialization.
    Xavier,
    /// He initialization.
    He,
    /// Physics-informed initialization (depends on the equation type).
    PhysicsInformed,
    /// Explicit values, one per circuit parameter; the length must equal the circuit's
    /// parameter count.
    Custom(Vec<f64>),
}
225
/// Training configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingConfig {
    /// Number of training epochs.
    pub epochs: usize,
    /// Learning rate.
    pub learning_rate: f64,
    /// Optimizer type.
    pub optimizer: OptimizerType,
    /// Batch size for collocation points.
    pub batch_size: usize,
    /// Number of collocation points sampled over the space-time domain.
    pub num_collocation_points: usize,
    /// Whether adaptive sampling of collocation points is enabled (strategy not visible
    /// in this part of the module — TODO confirm).
    pub adaptive_sampling: bool,
}
242
243impl Default for TrainingConfig {
244    fn default() -> Self {
245        Self {
246            epochs: 1000,
247            learning_rate: 0.001,
248            optimizer: OptimizerType::Adam,
249            batch_size: 128,
250            num_collocation_points: 1000,
251            adaptive_sampling: true,
252        }
253    }
254}
255
/// Optimizer used to update the circuit parameters.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum OptimizerType {
    /// Adam optimizer.
    Adam,
    /// Limited-memory BFGS.
    LBFGS,
    /// Stochastic gradient descent.
    SGD,
    /// Quantum natural gradient.
    QuantumNaturalGradient,
    /// Optimization driven by parameter-shift-rule gradients (exact update rule not
    /// visible in this part of the module — TODO confirm).
    ParameterShift,
}
265
266/// Physics constraints for the problem
267#[derive(Debug, Clone, Serialize, Deserialize)]
268pub struct PhysicsConstraints {
269    /// Conservation laws to enforce
270    pub conservation_laws: Vec<ConservationLaw>,
271    /// Symmetries to preserve
272    pub symmetries: Vec<Symmetry>,
273    /// Physical bounds on the solution
274    pub solution_bounds: Option<(f64, f64)>,
275    /// Energy constraints
276    pub energy_constraints: Vec<EnergyConstraint>,
277}
278
279impl Default for PhysicsConstraints {
280    fn default() -> Self {
281        Self {
282            conservation_laws: vec![],
283            symmetries: vec![],
284            solution_bounds: None,
285            energy_constraints: vec![],
286        }
287    }
288}
289
/// Conservation laws that can be enforced on the solution.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ConservationLaw {
    /// Conservation of mass.
    Mass,
    /// Conservation of momentum.
    Momentum,
    /// Conservation of energy.
    Energy,
    /// Conservation of charge.
    Charge,
    /// Custom conservation law described by a string (format TODO confirm).
    Custom(String),
}
304
/// Symmetries of the problem that the solution should preserve.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum Symmetry {
    /// Translational symmetry.
    Translational,
    /// Rotational symmetry.
    Rotational,
    /// Reflection symmetry.
    Reflection,
    /// Time reversal symmetry.
    TimeReversal,
    /// Custom symmetry described by a string (format TODO confirm).
    Custom(String),
}
319
/// Energy constraint: a target value for an energy quantity plus an allowed tolerance.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EnergyConstraint {
    /// Type of energy constraint.
    pub constraint_type: EnergyConstraintType,
    /// Target energy value.
    pub target_value: f64,
    /// Tolerance for constraint satisfaction.
    pub tolerance: f64,
}
330
/// Types of energy constraints.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum EnergyConstraintType {
    /// Total energy constraint.
    Total,
    /// Kinetic energy constraint.
    Kinetic,
    /// Potential energy constraint.
    Potential,
    /// Custom energy functional described by a string (format TODO confirm).
    Custom(String),
}
343
/// Main Quantum Physics-Informed Neural Network.
///
/// Holds the configuration, the parameterized circuit and its current parameter values,
/// the sampled collocation points, the training history and the PDE-residual evaluator.
#[derive(Debug, Clone)]
pub struct QuantumPINN {
    /// Configuration the network was built from.
    config: QPINNConfig,
    /// Parameterized ansatz circuit.
    quantum_circuit: QuantumCircuit,
    /// Current parameter values, indexed by the parameter indices stored in the gates.
    parameters: Array1<f64>,
    /// Collocation points, one per row: spatial coordinates followed by time.
    collocation_points: Array2<f64>,
    /// Recorded training metrics; empty after construction (presumably appended during
    /// training — TODO confirm).
    training_history: Vec<TrainingMetrics>,
    /// Evaluator for residuals of the configured equation.
    physics_evaluator: PhysicsEvaluator,
}
354
/// Quantum circuit for the PINN: an ordered gate list over `num_qubits` qubits.
#[derive(Debug, Clone)]
pub struct QuantumCircuit {
    /// Gates in application order.
    gates: Vec<QuantumGate>,
    /// Number of qubits the circuit acts on.
    num_qubits: usize,
    /// Maps the index of a parametric gate in `gates` to its parameter index.
    parameter_map: HashMap<usize, usize>, // Gate index to parameter index
}
362
/// A single gate within a [`QuantumCircuit`].
#[derive(Debug, Clone)]
pub struct QuantumGate {
    /// Which gate to apply.
    gate_type: GateType,
    /// Qubits the gate acts on; two-qubit gates are stored as `[control, target]`.
    qubits: Vec<usize>,
    /// Indices into the network's parameter vector (empty for fixed gates).
    parameters: Vec<usize>, // Parameter indices
    /// Whether the rotation angle is read from the parameter vector; non-parametric
    /// rotations are simulated with angle 0.0.
    is_parametric: bool,
}
371
/// Gate types for quantum circuits.
///
/// The simulator in `QuantumPINN::apply_quantum_circuit` currently implements only RX, RY,
/// RZ, CNOT and CZ; the remaining variants are skipped (treated as identity).
#[derive(Debug, Clone)]
pub enum GateType {
    /// Rotation about X.
    RX,
    /// Rotation about Y.
    RY,
    /// Rotation about Z.
    RZ,
    /// Controlled-NOT.
    CNOT,
    /// Controlled-Z.
    CZ,
    /// Controlled-Y.
    CY,
    /// Hadamard gate.
    Hadamard,
    /// S (phase) gate.
    S,
    /// T gate.
    T,
    /// Custom gate identified by name.
    Custom(String),
}
386
/// Physics evaluator for computing PDE residuals.
#[derive(Debug, Clone)]
pub struct PhysicsEvaluator {
    /// Equation whose residual is evaluated.
    equation_type: PhysicsEquationType,
    /// Differential operators keyed by name (naming scheme not visible here — TODO confirm).
    differential_operators: HashMap<String, DifferentialOperator>,
}
393
/// Differential operator used when computing derivatives of the network output.
#[derive(Debug, Clone)]
pub struct DifferentialOperator {
    /// Kind of operator.
    operator_type: OperatorType,
    /// Derivative order.
    order: usize,
    /// Spatial directions (axis indices) involved, e.g. for mixed derivatives.
    direction: Vec<usize>, // Spatial directions for mixed derivatives
}
401
/// Types of differential operators.
#[derive(Debug, Clone)]
pub enum OperatorType {
    /// Gradient ∇u.
    Gradient,
    /// Laplacian ∇²u.
    Laplacian,
    /// Divergence ∇·u.
    Divergence,
    /// Curl ∇×u.
    Curl,
    /// Derivative with respect to time.
    TimeDerivative,
    /// Mixed partial derivative.
    Mixed,
}
412
/// Snapshot of training metrics for a QPINN.
#[derive(Debug, Clone)]
pub struct TrainingMetrics {
    /// Epoch the snapshot belongs to.
    epoch: usize,
    /// Total (weighted) loss.
    total_loss: f64,
    /// PDE residual loss component.
    pde_loss: f64,
    /// Boundary condition loss component.
    boundary_loss: f64,
    /// Initial condition loss component.
    initial_loss: f64,
    /// Physics constraint loss component.
    physics_constraint_loss: f64,
    /// Fidelity measure of the quantum state (definition not visible here — TODO confirm).
    quantum_fidelity: f64,
    /// Energy of the current solution (definition not visible here — TODO confirm).
    solution_energy: f64,
}
425
426impl QuantumPINN {
427    /// Create a new Quantum Physics-Informed Neural Network
428    pub fn new(config: QPINNConfig) -> Result<Self> {
429        let quantum_circuit = Self::build_quantum_circuit(&config)?;
430        let num_parameters = Self::count_parameters(&quantum_circuit);
431        let parameters = Self::initialize_parameters(&config, num_parameters)?;
432        let collocation_points = Self::generate_collocation_points(&config)?;
433        let physics_evaluator = PhysicsEvaluator::new(&config.equation_type)?;
434
435        Ok(Self {
436            config,
437            quantum_circuit,
438            parameters,
439            collocation_points,
440            training_history: Vec::new(),
441            physics_evaluator,
442        })
443    }
444
445    /// Build the quantum circuit based on configuration
446    fn build_quantum_circuit(config: &QPINNConfig) -> Result<QuantumCircuit> {
447        let mut gates = Vec::new();
448        let mut parameter_map = HashMap::new();
449        let mut param_index = 0;
450
451        match config.ansatz_config.ansatz_type {
452            QuantumAnsatzType::EfficientSU2 => {
453                for rep in 0..config.ansatz_config.repetitions {
454                    // Single-qubit rotations
455                    for qubit in 0..config.num_qubits {
456                        // RY gate
457                        gates.push(QuantumGate {
458                            gate_type: GateType::RY,
459                            qubits: vec![qubit],
460                            parameters: vec![param_index],
461                            is_parametric: true,
462                        });
463                        parameter_map.insert(gates.len() - 1, param_index);
464                        param_index += 1;
465
466                        // RZ gate
467                        gates.push(QuantumGate {
468                            gate_type: GateType::RZ,
469                            qubits: vec![qubit],
470                            parameters: vec![param_index],
471                            is_parametric: true,
472                        });
473                        parameter_map.insert(gates.len() - 1, param_index);
474                        param_index += 1;
475                    }
476
477                    // Entangling gates
478                    match config.ansatz_config.entanglement_pattern {
479                        EntanglementPattern::Linear => {
480                            for qubit in 0..config.num_qubits - 1 {
481                                gates.push(QuantumGate {
482                                    gate_type: GateType::CNOT,
483                                    qubits: vec![qubit, qubit + 1],
484                                    parameters: vec![],
485                                    is_parametric: false,
486                                });
487                            }
488                        }
489                        EntanglementPattern::Circular => {
490                            for qubit in 0..config.num_qubits {
491                                gates.push(QuantumGate {
492                                    gate_type: GateType::CNOT,
493                                    qubits: vec![qubit, (qubit + 1) % config.num_qubits],
494                                    parameters: vec![],
495                                    is_parametric: false,
496                                });
497                            }
498                        }
499                        EntanglementPattern::Full => {
500                            for i in 0..config.num_qubits {
501                                for j in i + 1..config.num_qubits {
502                                    gates.push(QuantumGate {
503                                        gate_type: GateType::CNOT,
504                                        qubits: vec![i, j],
505                                        parameters: vec![],
506                                        is_parametric: false,
507                                    });
508                                }
509                            }
510                        }
511                        _ => {
512                            return Err(crate::error::MLError::InvalidConfiguration(
513                                "Unsupported entanglement pattern".to_string(),
514                            ));
515                        }
516                    }
517                }
518            }
519            QuantumAnsatzType::PhysicsInformed => {
520                // Build physics-informed ansatz based on the equation type
521                gates = Self::build_physics_informed_ansatz(
522                    config,
523                    &mut param_index,
524                    &mut parameter_map,
525                )?;
526            }
527            _ => {
528                return Err(crate::error::MLError::InvalidConfiguration(
529                    "Ansatz type not implemented".to_string(),
530                ));
531            }
532        }
533
534        Ok(QuantumCircuit {
535            gates,
536            num_qubits: config.num_qubits,
537            parameter_map,
538        })
539    }
540
541    /// Build physics-informed ansatz specific to the equation type
542    fn build_physics_informed_ansatz(
543        config: &QPINNConfig,
544        param_index: &mut usize,
545        parameter_map: &mut HashMap<usize, usize>,
546    ) -> Result<Vec<QuantumGate>> {
547        let mut gates = Vec::new();
548
549        match config.equation_type {
550            PhysicsEquationType::Schrodinger => {
551                // Use time-evolution inspired ansatz
552                for layer in 0..config.num_layers {
553                    // Kinetic energy terms (hopping)
554                    for qubit in 0..config.num_qubits - 1 {
555                        gates.push(QuantumGate {
556                            gate_type: GateType::RX,
557                            qubits: vec![qubit],
558                            parameters: vec![*param_index],
559                            is_parametric: true,
560                        });
561                        parameter_map.insert(gates.len() - 1, *param_index);
562                        *param_index += 1;
563
564                        gates.push(QuantumGate {
565                            gate_type: GateType::CNOT,
566                            qubits: vec![qubit, qubit + 1],
567                            parameters: vec![],
568                            is_parametric: false,
569                        });
570
571                        gates.push(QuantumGate {
572                            gate_type: GateType::RZ,
573                            qubits: vec![qubit + 1],
574                            parameters: vec![*param_index],
575                            is_parametric: true,
576                        });
577                        parameter_map.insert(gates.len() - 1, *param_index);
578                        *param_index += 1;
579
580                        gates.push(QuantumGate {
581                            gate_type: GateType::CNOT,
582                            qubits: vec![qubit, qubit + 1],
583                            parameters: vec![],
584                            is_parametric: false,
585                        });
586                    }
587
588                    // Potential energy terms
589                    for qubit in 0..config.num_qubits {
590                        gates.push(QuantumGate {
591                            gate_type: GateType::RZ,
592                            qubits: vec![qubit],
593                            parameters: vec![*param_index],
594                            is_parametric: true,
595                        });
596                        parameter_map.insert(gates.len() - 1, *param_index);
597                        *param_index += 1;
598                    }
599                }
600            }
601            PhysicsEquationType::Heat => {
602                // Diffusion-inspired ansatz
603                for layer in 0..config.num_layers {
604                    for qubit in 0..config.num_qubits {
605                        gates.push(QuantumGate {
606                            gate_type: GateType::RY,
607                            qubits: vec![qubit],
608                            parameters: vec![*param_index],
609                            is_parametric: true,
610                        });
611                        parameter_map.insert(gates.len() - 1, *param_index);
612                        *param_index += 1;
613                    }
614
615                    // Nearest-neighbor interactions for diffusion
616                    for qubit in 0..config.num_qubits - 1 {
617                        gates.push(QuantumGate {
618                            gate_type: GateType::CZ,
619                            qubits: vec![qubit, qubit + 1],
620                            parameters: vec![],
621                            is_parametric: false,
622                        });
623                    }
624                }
625            }
626            _ => {
627                // Default to efficient SU(2) for other equation types
628                for qubit in 0..config.num_qubits {
629                    gates.push(QuantumGate {
630                        gate_type: GateType::RY,
631                        qubits: vec![qubit],
632                        parameters: vec![*param_index],
633                        is_parametric: true,
634                    });
635                    parameter_map.insert(gates.len() - 1, *param_index);
636                    *param_index += 1;
637                }
638            }
639        }
640
641        Ok(gates)
642    }
643
644    /// Count parameters in the quantum circuit
645    fn count_parameters(circuit: &QuantumCircuit) -> usize {
646        circuit
647            .gates
648            .iter()
649            .filter(|gate| gate.is_parametric)
650            .map(|gate| gate.parameters.len())
651            .sum()
652    }
653
654    /// Initialize parameters based on configuration
655    fn initialize_parameters(config: &QPINNConfig, num_params: usize) -> Result<Array1<f64>> {
656        match &config.ansatz_config.parameter_init {
657            ParameterInitialization::Random => Ok(Array1::from_shape_fn(num_params, |_| {
658                fastrand::f64() * 2.0 * std::f64::consts::PI
659            })),
660            ParameterInitialization::Xavier => {
661                let limit = (6.0 / num_params as f64).sqrt();
662                Ok(Array1::from_shape_fn(num_params, |_| {
663                    (fastrand::f64() - 0.5) * 2.0 * limit
664                }))
665            }
666            ParameterInitialization::PhysicsInformed => {
667                // Initialize based on physical intuition
668                match config.equation_type {
669                    PhysicsEquationType::Schrodinger => {
670                        // Small random values for quantum evolution
671                        Ok(Array1::from_shape_fn(num_params, |_| {
672                            (fastrand::f64() - 0.5) * 0.1
673                        }))
674                    }
675                    PhysicsEquationType::Heat => {
676                        // Initialize for diffusive behavior
677                        Ok(Array1::from_shape_fn(num_params, |i| {
678                            0.1 * (i as f64 / num_params as f64)
679                        }))
680                    }
681                    _ => {
682                        // Default random initialization
683                        Ok(Array1::from_shape_fn(num_params, |_| {
684                            fastrand::f64() * std::f64::consts::PI
685                        }))
686                    }
687                }
688            }
689            ParameterInitialization::Custom(values) => {
690                if values.len() != num_params {
691                    return Err(crate::error::MLError::InvalidConfiguration(
692                        "Custom parameter length mismatch".to_string(),
693                    ));
694                }
695                Ok(Array1::from_vec(values.clone()))
696            }
697            _ => Ok(Array1::zeros(num_params)),
698        }
699    }
700
701    /// Generate collocation points for training
702    fn generate_collocation_points(config: &QPINNConfig) -> Result<Array2<f64>> {
703        let num_points = config.training_config.num_collocation_points;
704        let num_dims = config.domain_bounds.len() + 1; // spatial + time
705        let mut points = Array2::zeros((num_points, num_dims));
706
707        for i in 0..num_points {
708            // Spatial coordinates
709            for (j, &(min_val, max_val)) in config.domain_bounds.iter().enumerate() {
710                points[[i, j]] = min_val + fastrand::f64() * (max_val - min_val);
711            }
712
713            // Temporal coordinate
714            let (t_min, t_max) = config.time_bounds;
715            points[[i, config.domain_bounds.len()]] = t_min + fastrand::f64() * (t_max - t_min);
716        }
717
718        Ok(points)
719    }
720
721    /// Forward pass through the quantum network
722    pub fn forward(&self, input_points: &Array2<f64>) -> Result<Array2<f64>> {
723        let batch_size = input_points.nrows();
724        let num_outputs = 1; // Single output for scalar PDE solutions
725        let mut outputs = Array2::zeros((batch_size, num_outputs));
726
727        for i in 0..batch_size {
728            let point = input_points.row(i);
729            let quantum_state = self.encode_input(&point.to_owned())?;
730            let evolved_state = self.apply_quantum_circuit(&quantum_state)?;
731            let output = self.decode_output(&evolved_state)?;
732            outputs[[i, 0]] = output;
733        }
734
735        Ok(outputs)
736    }
737
738    /// Encode input coordinates into quantum state
739    fn encode_input(&self, point: &Array1<f64>) -> Result<Array1<f64>> {
740        let num_amplitudes = 1 << self.config.num_qubits;
741        let mut quantum_state = Array1::zeros(num_amplitudes);
742
743        // Amplitude encoding of coordinates
744        let norm = point.iter().map(|x| x * x).sum::<f64>().sqrt();
745        if norm > 1e-10 {
746            for (i, &coord) in point.iter().enumerate() {
747                if i < num_amplitudes {
748                    quantum_state[i] = coord / norm;
749                }
750            }
751        } else {
752            quantum_state[0] = 1.0;
753        }
754
755        // Normalize the quantum state
756        let state_norm = quantum_state.iter().map(|x| x * x).sum::<f64>().sqrt();
757        if state_norm > 1e-10 {
758            quantum_state /= state_norm;
759        }
760
761        Ok(quantum_state)
762    }
763
764    /// Apply the parameterized quantum circuit
765    fn apply_quantum_circuit(&self, input_state: &Array1<f64>) -> Result<Array1<f64>> {
766        let mut state = input_state.clone();
767
768        for gate in &self.quantum_circuit.gates {
769            match gate.gate_type {
770                GateType::RY => {
771                    let angle = if gate.is_parametric {
772                        self.parameters[gate.parameters[0]]
773                    } else {
774                        0.0
775                    };
776                    state = self.apply_ry_gate(&state, gate.qubits[0], angle)?;
777                }
778                GateType::RZ => {
779                    let angle = if gate.is_parametric {
780                        self.parameters[gate.parameters[0]]
781                    } else {
782                        0.0
783                    };
784                    state = self.apply_rz_gate(&state, gate.qubits[0], angle)?;
785                }
786                GateType::RX => {
787                    let angle = if gate.is_parametric {
788                        self.parameters[gate.parameters[0]]
789                    } else {
790                        0.0
791                    };
792                    state = self.apply_rx_gate(&state, gate.qubits[0], angle)?;
793                }
794                GateType::CNOT => {
795                    state = self.apply_cnot_gate(&state, gate.qubits[0], gate.qubits[1])?;
796                }
797                GateType::CZ => {
798                    state = self.apply_cz_gate(&state, gate.qubits[0], gate.qubits[1])?;
799                }
800                _ => {
801                    // Other gates can be implemented as needed
802                }
803            }
804        }
805
806        Ok(state)
807    }
808
809    /// Apply RX gate
810    fn apply_rx_gate(&self, state: &Array1<f64>, qubit: usize, angle: f64) -> Result<Array1<f64>> {
811        let mut new_state = state.clone();
812        let cos_half = (angle / 2.0).cos();
813        let sin_half = (angle / 2.0).sin();
814
815        let qubit_mask = 1 << qubit;
816
817        for i in 0..state.len() {
818            if i & qubit_mask == 0 {
819                let j = i | qubit_mask;
820                if j < state.len() {
821                    let state_0 = state[i];
822                    let state_1 = state[j];
823                    new_state[i] = cos_half * state_0 - sin_half * state_1;
824                    new_state[j] = -sin_half * state_0 + cos_half * state_1;
825                }
826            }
827        }
828
829        Ok(new_state)
830    }
831
832    /// Apply RY gate
833    fn apply_ry_gate(&self, state: &Array1<f64>, qubit: usize, angle: f64) -> Result<Array1<f64>> {
834        let mut new_state = state.clone();
835        let cos_half = (angle / 2.0).cos();
836        let sin_half = (angle / 2.0).sin();
837
838        let qubit_mask = 1 << qubit;
839
840        for i in 0..state.len() {
841            if i & qubit_mask == 0 {
842                let j = i | qubit_mask;
843                if j < state.len() {
844                    let state_0 = state[i];
845                    let state_1 = state[j];
846                    new_state[i] = cos_half * state_0 - sin_half * state_1;
847                    new_state[j] = sin_half * state_0 + cos_half * state_1;
848                }
849            }
850        }
851
852        Ok(new_state)
853    }
854
855    /// Apply RZ gate
856    fn apply_rz_gate(&self, state: &Array1<f64>, qubit: usize, angle: f64) -> Result<Array1<f64>> {
857        let mut new_state = state.clone();
858        let phase_0 = (-angle / 2.0); // For real-valued implementation
859        let phase_1 = (angle / 2.0);
860
861        let qubit_mask = 1 << qubit;
862
863        for i in 0..state.len() {
864            if i & qubit_mask == 0 {
865                new_state[i] *= phase_0.cos(); // Real part only for simplification
866            } else {
867                new_state[i] *= phase_1.cos();
868            }
869        }
870
871        Ok(new_state)
872    }
873
874    /// Apply CNOT gate
875    fn apply_cnot_gate(
876        &self,
877        state: &Array1<f64>,
878        control: usize,
879        target: usize,
880    ) -> Result<Array1<f64>> {
881        let mut new_state = state.clone();
882        let control_mask = 1 << control;
883        let target_mask = 1 << target;
884
885        for i in 0..state.len() {
886            if i & control_mask != 0 {
887                let j = i ^ target_mask;
888                new_state[i] = state[j];
889            }
890        }
891
892        Ok(new_state)
893    }
894
895    /// Apply CZ gate
896    fn apply_cz_gate(
897        &self,
898        state: &Array1<f64>,
899        control: usize,
900        target: usize,
901    ) -> Result<Array1<f64>> {
902        let mut new_state = state.clone();
903        let control_mask = 1 << control;
904        let target_mask = 1 << target;
905
906        for i in 0..state.len() {
907            if (i & control_mask != 0) && (i & target_mask != 0) {
908                new_state[i] *= -1.0; // Apply phase flip
909            }
910        }
911
912        Ok(new_state)
913    }
914
915    /// Decode quantum state to classical output
916    fn decode_output(&self, quantum_state: &Array1<f64>) -> Result<f64> {
917        // Expectation value of Z operator on first qubit
918        let mut expectation = 0.0;
919
920        for (i, &amplitude) in quantum_state.iter().enumerate() {
921            if i & 1 == 0 {
922                expectation += amplitude * amplitude;
923            } else {
924                expectation -= amplitude * amplitude;
925            }
926        }
927
928        Ok(expectation)
929    }
930
931    /// Compute derivatives using automatic differentiation
932    pub fn compute_derivatives(&self, points: &Array2<f64>) -> Result<DerivativeResults> {
933        let h = 1e-5; // Finite difference step
934        let num_points = points.nrows();
935        let num_dims = points.ncols();
936
937        let mut first_derivatives = Array2::zeros((num_points, num_dims));
938        let mut second_derivatives = Array2::zeros((num_points, num_dims));
939        let mut mixed_derivatives = Array3::zeros((num_points, num_dims, num_dims));
940
941        for i in 0..num_points {
942            for j in 0..num_dims {
943                // First derivatives
944                let mut point_plus = points.row(i).to_owned();
945                let mut point_minus = points.row(i).to_owned();
946                point_plus[j] += h;
947                point_minus[j] -= h;
948
949                let output_plus =
950                    self.forward(&point_plus.insert_axis(scirs2_core::ndarray::Axis(0)))?[[0, 0]];
951                let output_minus =
952                    self.forward(&point_minus.insert_axis(scirs2_core::ndarray::Axis(0)))?[[0, 0]];
953
954                first_derivatives[[i, j]] = (output_plus - output_minus) / (2.0 * h);
955
956                // Second derivatives
957                let output_center = self.forward(
958                    &points
959                        .row(i)
960                        .insert_axis(scirs2_core::ndarray::Axis(0))
961                        .to_owned(),
962                )?[[0, 0]];
963                second_derivatives[[i, j]] =
964                    (output_plus - 2.0 * output_center + output_minus) / (h * h);
965
966                // Mixed derivatives
967                for k in j + 1..num_dims {
968                    let mut point_pp = points.row(i).to_owned();
969                    let mut point_pm = points.row(i).to_owned();
970                    let mut point_mp = points.row(i).to_owned();
971                    let mut point_mm = points.row(i).to_owned();
972
973                    point_pp[j] += h;
974                    point_pp[k] += h;
975                    point_pm[j] += h;
976                    point_pm[k] -= h;
977                    point_mp[j] -= h;
978                    point_mp[k] += h;
979                    point_mm[j] -= h;
980                    point_mm[k] -= h;
981
982                    let output_pp =
983                        self.forward(&point_pp.insert_axis(scirs2_core::ndarray::Axis(0)))?[[0, 0]];
984                    let output_pm =
985                        self.forward(&point_pm.insert_axis(scirs2_core::ndarray::Axis(0)))?[[0, 0]];
986                    let output_mp =
987                        self.forward(&point_mp.insert_axis(scirs2_core::ndarray::Axis(0)))?[[0, 0]];
988                    let output_mm =
989                        self.forward(&point_mm.insert_axis(scirs2_core::ndarray::Axis(0)))?[[0, 0]];
990
991                    let mixed_deriv =
992                        (output_pp - output_pm - output_mp + output_mm) / (4.0 * h * h);
993                    mixed_derivatives[[i, j, k]] = mixed_deriv;
994                    mixed_derivatives[[i, k, j]] = mixed_deriv; // Symmetry
995                }
996            }
997        }
998
999        Ok(DerivativeResults {
1000            first_derivatives,
1001            second_derivatives,
1002            mixed_derivatives,
1003        })
1004    }
1005
1006    /// Train the Quantum PINN
1007    pub fn train(&mut self, epochs: Option<usize>) -> Result<()> {
1008        let num_epochs = epochs.unwrap_or(self.config.training_config.epochs);
1009
1010        for epoch in 0..num_epochs {
1011            // Adaptive sampling of collocation points
1012            if self.config.training_config.adaptive_sampling && epoch % 100 == 0 {
1013                self.collocation_points =
1014                    Self::generate_adaptive_collocation_points(&self.config, epoch)?;
1015            }
1016
1017            // Compute total loss
1018            let total_loss = self.compute_total_loss()?;
1019
1020            // Compute gradients
1021            let gradients = self.compute_gradients()?;
1022
1023            // Update parameters
1024            self.update_parameters(&gradients)?;
1025
1026            // Record metrics
1027            let metrics = self.compute_training_metrics(epoch, total_loss)?;
1028            self.training_history.push(metrics);
1029
1030            if epoch % 100 == 0 {
1031                if let Some(last_metrics) = self.training_history.last() {
1032                    println!(
1033                        "Epoch {}: Total Loss = {:.6}, PDE Loss = {:.6}, Boundary Loss = {:.6}",
1034                        epoch,
1035                        last_metrics.total_loss,
1036                        last_metrics.pde_loss,
1037                        last_metrics.boundary_loss
1038                    );
1039                }
1040            }
1041        }
1042
1043        Ok(())
1044    }
1045
1046    /// Generate adaptive collocation points
1047    fn generate_adaptive_collocation_points(
1048        config: &QPINNConfig,
1049        epoch: usize,
1050    ) -> Result<Array2<f64>> {
1051        // For now, use uniform random sampling (adaptive refinement would be more complex)
1052        Self::generate_collocation_points(config)
1053    }
1054
1055    /// Compute total loss
1056    fn compute_total_loss(&self) -> Result<TotalLoss> {
1057        let pde_loss = self.compute_pde_loss()?;
1058        let boundary_loss = self.compute_boundary_loss()?;
1059        let initial_loss = self.compute_initial_loss()?;
1060        let physics_constraint_loss = self.compute_physics_constraint_loss()?;
1061
1062        let weights = &self.config.loss_weights;
1063        let total = weights.pde_loss_weight * pde_loss
1064            + weights.boundary_loss_weight * boundary_loss
1065            + weights.initial_loss_weight * initial_loss
1066            + weights.physics_constraint_weight * physics_constraint_loss;
1067
1068        Ok(TotalLoss {
1069            total,
1070            pde_loss,
1071            boundary_loss,
1072            initial_loss,
1073            physics_constraint_loss,
1074        })
1075    }
1076
1077    /// Compute PDE residual loss
1078    fn compute_pde_loss(&self) -> Result<f64> {
1079        let derivatives = self.compute_derivatives(&self.collocation_points)?;
1080        let residuals = self.physics_evaluator.compute_pde_residual(
1081            &self.collocation_points,
1082            &self.forward(&self.collocation_points)?,
1083            &derivatives,
1084        )?;
1085
1086        Ok(residuals.iter().map(|r| r * r).sum::<f64>() / residuals.len() as f64)
1087    }
1088
1089    /// Compute boundary condition loss
1090    fn compute_boundary_loss(&self) -> Result<f64> {
1091        // Generate boundary points
1092        let boundary_points = self.generate_boundary_points()?;
1093        let boundary_values = self.forward(&boundary_points)?;
1094
1095        let mut total_loss = 0.0;
1096        for (bc, points) in self
1097            .config
1098            .boundary_conditions
1099            .iter()
1100            .zip(boundary_values.rows())
1101        {
1102            let target_values = self.evaluate_boundary_condition(bc, &boundary_points)?;
1103            for (predicted, target) in points.iter().zip(target_values.iter()) {
1104                total_loss += (predicted - target).powi(2);
1105            }
1106        }
1107
1108        Ok(total_loss)
1109    }
1110
1111    /// Compute initial condition loss
1112    fn compute_initial_loss(&self) -> Result<f64> {
1113        // Generate initial time points
1114        let initial_points = self.generate_initial_points()?;
1115        let initial_values = self.forward(&initial_points)?;
1116
1117        let mut total_loss = 0.0;
1118        for (ic, points) in self
1119            .config
1120            .initial_conditions
1121            .iter()
1122            .zip(initial_values.rows())
1123        {
1124            let target_values = self.evaluate_initial_condition(ic, &initial_points)?;
1125            for (predicted, target) in points.iter().zip(target_values.iter()) {
1126                total_loss += (predicted - target).powi(2);
1127            }
1128        }
1129
1130        Ok(total_loss)
1131    }
1132
1133    /// Compute physics constraint loss
1134    fn compute_physics_constraint_loss(&self) -> Result<f64> {
1135        // Implement conservation law and symmetry constraints
1136        let mut constraint_loss = 0.0;
1137
1138        for conservation_law in &self.config.physics_constraints.conservation_laws {
1139            constraint_loss += self.evaluate_conservation_law(conservation_law)?;
1140        }
1141
1142        for symmetry in &self.config.physics_constraints.symmetries {
1143            constraint_loss += self.evaluate_symmetry_constraint(symmetry)?;
1144        }
1145
1146        Ok(constraint_loss)
1147    }
1148
1149    /// Generate boundary points
1150    fn generate_boundary_points(&self) -> Result<Array2<f64>> {
1151        // Simplified boundary point generation
1152        let num_boundary_points = 100;
1153        let num_dims = self.config.domain_bounds.len() + 1;
1154        let mut boundary_points = Array2::zeros((num_boundary_points, num_dims));
1155
1156        // Generate points on each boundary
1157        for i in 0..num_boundary_points {
1158            for (j, &(min_val, max_val)) in self.config.domain_bounds.iter().enumerate() {
1159                if i % 2 == 0 {
1160                    boundary_points[[i, j]] = min_val; // Left/bottom boundary
1161                } else {
1162                    boundary_points[[i, j]] = max_val; // Right/top boundary
1163                }
1164            }
1165
1166            // Random time coordinate
1167            let (t_min, t_max) = self.config.time_bounds;
1168            boundary_points[[i, self.config.domain_bounds.len()]] =
1169                t_min + fastrand::f64() * (t_max - t_min);
1170        }
1171
1172        Ok(boundary_points)
1173    }
1174
1175    /// Generate initial time points
1176    fn generate_initial_points(&self) -> Result<Array2<f64>> {
1177        let num_initial_points = 100;
1178        let num_dims = self.config.domain_bounds.len() + 1;
1179        let mut initial_points = Array2::zeros((num_initial_points, num_dims));
1180
1181        for i in 0..num_initial_points {
1182            // Random spatial coordinates
1183            for (j, &(min_val, max_val)) in self.config.domain_bounds.iter().enumerate() {
1184                initial_points[[i, j]] = min_val + fastrand::f64() * (max_val - min_val);
1185            }
1186
1187            // Initial time
1188            initial_points[[i, self.config.domain_bounds.len()]] = self.config.time_bounds.0;
1189        }
1190
1191        Ok(initial_points)
1192    }
1193
1194    /// Evaluate boundary condition
1195    fn evaluate_boundary_condition(
1196        &self,
1197        _bc: &BoundaryCondition,
1198        _points: &Array2<f64>,
1199    ) -> Result<Array1<f64>> {
1200        // Simplified: return zeros for Dirichlet conditions
1201        Ok(Array1::zeros(_points.nrows()))
1202    }
1203
1204    /// Evaluate initial condition
1205    fn evaluate_initial_condition(
1206        &self,
1207        _ic: &InitialCondition,
1208        _points: &Array2<f64>,
1209    ) -> Result<Array1<f64>> {
1210        // Simplified: return zeros
1211        Ok(Array1::zeros(_points.nrows()))
1212    }
1213
1214    /// Evaluate conservation law constraint
1215    fn evaluate_conservation_law(&self, _law: &ConservationLaw) -> Result<f64> {
1216        // Placeholder implementation
1217        Ok(0.0)
1218    }
1219
1220    /// Evaluate symmetry constraint
1221    fn evaluate_symmetry_constraint(&self, _symmetry: &Symmetry) -> Result<f64> {
1222        // Placeholder implementation
1223        Ok(0.0)
1224    }
1225
1226    /// Compute gradients
1227    fn compute_gradients(&self) -> Result<Array1<f64>> {
1228        let total_loss = self.compute_total_loss()?;
1229        let mut gradients = Array1::zeros(self.parameters.len());
1230        let epsilon = 1e-6;
1231
1232        for i in 0..self.parameters.len() {
1233            let mut params_plus = self.parameters.clone();
1234            params_plus[i] += epsilon;
1235
1236            let mut temp_pinn = self.clone();
1237            temp_pinn.parameters = params_plus;
1238            let loss_plus = temp_pinn.compute_total_loss()?.total;
1239
1240            gradients[i] = (loss_plus - total_loss.total) / epsilon;
1241        }
1242
1243        Ok(gradients)
1244    }
1245
1246    /// Update parameters
1247    fn update_parameters(&mut self, gradients: &Array1<f64>) -> Result<()> {
1248        let learning_rate = self.config.training_config.learning_rate;
1249
1250        for i in 0..self.parameters.len() {
1251            self.parameters[i] -= learning_rate * gradients[i];
1252        }
1253
1254        Ok(())
1255    }
1256
1257    /// Compute training metrics
1258    fn compute_training_metrics(
1259        &self,
1260        epoch: usize,
1261        total_loss: TotalLoss,
1262    ) -> Result<TrainingMetrics> {
1263        Ok(TrainingMetrics {
1264            epoch,
1265            total_loss: total_loss.total,
1266            pde_loss: total_loss.pde_loss,
1267            boundary_loss: total_loss.boundary_loss,
1268            initial_loss: total_loss.initial_loss,
1269            physics_constraint_loss: total_loss.physics_constraint_loss,
1270            quantum_fidelity: 0.9, // Placeholder
1271            solution_energy: 1.0,  // Placeholder
1272        })
1273    }
1274
1275    /// Get training history
1276    pub fn get_training_history(&self) -> &[TrainingMetrics] {
1277        &self.training_history
1278    }
1279
1280    /// Solve PDE and return solution on a grid
1281    pub fn solve_on_grid(&self, grid_points: &Array2<f64>) -> Result<Array1<f64>> {
1282        let solutions = self.forward(grid_points)?;
1283        Ok(solutions.column(0).to_owned())
1284    }
1285}
1286
1287/// Results from derivative computation
1288#[derive(Debug)]
1289pub struct DerivativeResults {
1290    pub first_derivatives: Array2<f64>,
1291    pub second_derivatives: Array2<f64>,
1292    pub mixed_derivatives: Array3<f64>,
1293}
1294
/// Breakdown of the training loss into its individual components.
#[derive(Debug)]
pub struct TotalLoss {
    /// Weighted sum of the four components below, using the configured loss weights.
    pub total: f64,
    /// Mean squared PDE residual over the collocation points.
    pub pde_loss: f64,
    /// Loss from the configured boundary conditions.
    pub boundary_loss: f64,
    /// Loss from the configured initial conditions.
    pub initial_loss: f64,
    /// Sum of conservation-law and symmetry penalties.
    pub physics_constraint_loss: f64,
}
1304
1305impl PhysicsEvaluator {
1306    /// Create a new physics evaluator
1307    pub fn new(equation_type: &PhysicsEquationType) -> Result<Self> {
1308        let mut differential_operators = HashMap::new();
1309
1310        match equation_type {
1311            PhysicsEquationType::Poisson => {
1312                differential_operators.insert(
1313                    "laplacian".to_string(),
1314                    DifferentialOperator {
1315                        operator_type: OperatorType::Laplacian,
1316                        order: 2,
1317                        direction: vec![0, 1], // x and y directions
1318                    },
1319                );
1320            }
1321            PhysicsEquationType::Heat => {
1322                differential_operators.insert(
1323                    "time_derivative".to_string(),
1324                    DifferentialOperator {
1325                        operator_type: OperatorType::TimeDerivative,
1326                        order: 1,
1327                        direction: vec![2], // time direction
1328                    },
1329                );
1330                differential_operators.insert(
1331                    "laplacian".to_string(),
1332                    DifferentialOperator {
1333                        operator_type: OperatorType::Laplacian,
1334                        order: 2,
1335                        direction: vec![0, 1],
1336                    },
1337                );
1338            }
1339            PhysicsEquationType::Wave => {
1340                differential_operators.insert(
1341                    "second_time_derivative".to_string(),
1342                    DifferentialOperator {
1343                        operator_type: OperatorType::TimeDerivative,
1344                        order: 2,
1345                        direction: vec![2],
1346                    },
1347                );
1348                differential_operators.insert(
1349                    "laplacian".to_string(),
1350                    DifferentialOperator {
1351                        operator_type: OperatorType::Laplacian,
1352                        order: 2,
1353                        direction: vec![0, 1],
1354                    },
1355                );
1356            }
1357            _ => {
1358                // Add more equation types as needed
1359            }
1360        }
1361
1362        Ok(Self {
1363            equation_type: equation_type.clone(),
1364            differential_operators,
1365        })
1366    }
1367
1368    /// Compute PDE residual
1369    pub fn compute_pde_residual(
1370        &self,
1371        points: &Array2<f64>,
1372        solution: &Array2<f64>,
1373        derivatives: &DerivativeResults,
1374    ) -> Result<Array1<f64>> {
1375        let num_points = points.nrows();
1376        let mut residuals = Array1::zeros(num_points);
1377
1378        match self.equation_type {
1379            PhysicsEquationType::Poisson => {
1380                // ∇²u = f (assuming f = 0 for simplicity)
1381                for i in 0..num_points {
1382                    let laplacian = derivatives.second_derivatives[[i, 0]]
1383                        + derivatives.second_derivatives[[i, 1]];
1384                    residuals[i] = laplacian; // f = 0
1385                }
1386            }
1387            PhysicsEquationType::Heat => {
1388                // ∂u/∂t = α∇²u (assuming α = 1)
1389                for i in 0..num_points {
1390                    let time_deriv = derivatives.first_derivatives[[i, 2]]; // time direction
1391                    let laplacian = derivatives.second_derivatives[[i, 0]]
1392                        + derivatives.second_derivatives[[i, 1]];
1393                    residuals[i] = time_deriv - laplacian;
1394                }
1395            }
1396            PhysicsEquationType::Wave => {
1397                // ∂²u/∂t² = c²∇²u (assuming c = 1)
1398                for i in 0..num_points {
1399                    let second_time_deriv = derivatives.second_derivatives[[i, 2]];
1400                    let laplacian = derivatives.second_derivatives[[i, 0]]
1401                        + derivatives.second_derivatives[[i, 1]];
1402                    residuals[i] = second_time_deriv - laplacian;
1403                }
1404            }
1405            _ => {
1406                return Err(crate::error::MLError::InvalidConfiguration(
1407                    "PDE type not implemented".to_string(),
1408                ));
1409            }
1410        }
1411
1412        Ok(residuals)
1413    }
1414}
1415
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_qpinn_creation() {
        let config = QPINNConfig::default();
        let qpinn = QuantumPINN::new(config);
        assert!(qpinn.is_ok());
    }

    #[test]
    fn test_forward_pass() {
        let config = QPINNConfig::default();
        let qpinn = QuantumPINN::new(config).expect("Failed to create QPINN");
        let input_points = Array2::from_shape_vec(
            (5, 3),
            vec![
                0.1, 0.2, 0.0, 0.3, 0.4, 0.1, 0.5, 0.6, 0.2, 0.7, 0.8, 0.3, 0.9, 1.0, 0.4,
            ],
        )
        .expect("Failed to create input points");

        let result = qpinn.forward(&input_points);
        assert!(result.is_ok());
        assert_eq!(result.expect("Forward pass should succeed").shape(), [5, 1]);
    }

    #[test]
    fn test_derivative_computation() {
        let config = QPINNConfig::default();
        let qpinn = QuantumPINN::new(config).expect("Failed to create QPINN");
        let points =
            Array2::from_shape_vec((3, 3), vec![0.1, 0.2, 0.0, 0.3, 0.4, 0.1, 0.5, 0.6, 0.2])
                .expect("Failed to create points array");

        let derivs = qpinn
            .compute_derivatives(&points)
            .expect("Derivative computation should succeed");

        assert_eq!(derivs.first_derivatives.shape(), [3, 3]);
        assert_eq!(derivs.second_derivatives.shape(), [3, 3]);
        assert_eq!(derivs.mixed_derivatives.shape(), [3, 3, 3]);

        // Mixed derivatives are stored symmetrically and the diagonal is left at zero.
        // Compare bit patterns so the check is also well-defined for NaN.
        for i in 0..3 {
            for j in 0..3 {
                assert_eq!(derivs.mixed_derivatives[[i, j, j]], 0.0);
                for k in 0..3 {
                    assert_eq!(
                        derivs.mixed_derivatives[[i, j, k]].to_bits(),
                        derivs.mixed_derivatives[[i, k, j]].to_bits()
                    );
                }
            }
        }
    }

    #[test]
    fn test_cnot_and_cz_gates() {
        let qpinn = QuantumPINN::new(QPINNConfig::default()).expect("Failed to create QPINN");

        // Basis index 1 has qubit 0 (control) set, so target qubit 1 flips: |01> -> |11>.
        let state = Array1::from_vec(vec![0.0, 1.0, 0.0, 0.0]);
        let out = qpinn.apply_cnot_gate(&state, 0, 1).expect("CNOT should succeed");
        assert_eq!(out.to_vec(), vec![0.0, 0.0, 0.0, 1.0]);

        // Control clear (index 2): state unchanged.
        let state = Array1::from_vec(vec![0.0, 0.0, 1.0, 0.0]);
        let out = qpinn.apply_cnot_gate(&state, 0, 1).expect("CNOT should succeed");
        assert_eq!(out.to_vec(), vec![0.0, 0.0, 1.0, 0.0]);

        // CZ flips the sign only where both bits are set (index 3).
        let state = Array1::from_vec(vec![0.5, 0.5, 0.5, 0.5]);
        let out = qpinn.apply_cz_gate(&state, 0, 1).expect("CZ should succeed");
        assert_eq!(out.to_vec(), vec![0.5, 0.5, 0.5, -0.5]);
    }

    #[test]
    fn test_ry_gate_and_decode() {
        let qpinn = QuantumPINN::new(QPINNConfig::default()).expect("Failed to create QPINN");

        // RY(pi) maps |0> to |1>.
        let state = Array1::from_vec(vec![1.0, 0.0]);
        let out = qpinn
            .apply_ry_gate(&state, 0, std::f64::consts::PI)
            .expect("RY should succeed");
        assert!(out[0].abs() < 1e-12);
        assert!((out[1] - 1.0).abs() < 1e-12);

        // <Z_0> = |a0|^2 - |a1|^2.
        let z = qpinn
            .decode_output(&Array1::from_vec(vec![0.6, 0.8]))
            .expect("decode should succeed");
        assert!((z - (0.36 - 0.64)).abs() < 1e-12);
    }

    #[test]
    #[ignore]
    fn test_training() {
        let mut config = QPINNConfig::default();
        config.training_config.epochs = 5;
        config.training_config.num_collocation_points = 10;

        let mut qpinn = QuantumPINN::new(config).expect("Failed to create QPINN");
        let result = qpinn.train(Some(5));
        assert!(result.is_ok());
        assert!(!qpinn.get_training_history().is_empty());
    }
}