quantrs2_ml/explainable_ai.rs

//! Quantum Explainable AI (XAI)
//!
//! This module implements explainability and interpretability tools specifically
//! designed for quantum neural networks and quantum machine learning models,
//! helping users understand quantum model decisions and internal representations.

use crate::error::Result;
use crate::qnn::{QNNLayerType, QuantumNeuralNetwork};
use scirs2_core::ndarray::{Array1, Array2};
use std::collections::HashMap;
use std::f64::consts::PI;

/// Explanation methods for quantum models
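///
/// A minimal configuration sketch (illustrative values only):
/// ```ignore
/// let method = ExplanationMethod::QuantumFeatureAttribution {
///     method: AttributionMethod::IntegratedGradients,
///     num_samples: 50,
///     baseline: None,
/// };
/// ```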
#[derive(Debug, Clone)]
pub enum ExplanationMethod {
    /// Quantum feature attribution
    QuantumFeatureAttribution {
        method: AttributionMethod,
        num_samples: usize,
        baseline: Option<Array1<f64>>,
    },

    /// Quantum circuit visualization
    CircuitVisualization {
        include_measurements: bool,
        parameter_sensitivity: bool,
    },

    /// Quantum state analysis
    StateAnalysis {
        entanglement_measures: bool,
        coherence_analysis: bool,
        superposition_analysis: bool,
    },

    /// Quantum saliency maps
    SaliencyMapping {
        perturbation_method: PerturbationMethod,
        aggregation: AggregationMethod,
    },

    /// Quantum LIME (Local Interpretable Model-agnostic Explanations)
    QuantumLIME {
        num_perturbations: usize,
        kernel_width: f64,
        local_model: LocalModelType,
    },

    /// Quantum SHAP (SHapley Additive exPlanations)
    QuantumSHAP {
        num_coalitions: usize,
        background_samples: usize,
    },

    /// Layer-wise Relevance Propagation for quantum circuits
    QuantumLRP {
        propagation_rule: LRPRule,
        epsilon: f64,
    },

    /// Quantum concept activation vectors
    ConceptActivation {
        concept_datasets: Vec<String>,
        activation_threshold: f64,
    },
}

/// Attribution methods for quantum features
#[derive(Debug, Clone)]
pub enum AttributionMethod {
    /// Integrated gradients
    IntegratedGradients,
    /// Gradient × Input
    GradientInput,
    /// Gradient SHAP
    GradientSHAP,
    /// Quantum-specific attribution
    QuantumAttribution,
}

/// Perturbation methods for saliency
#[derive(Debug, Clone)]
pub enum PerturbationMethod {
    /// Gaussian noise
    Gaussian { sigma: f64 },
    /// Quantum phase perturbation
    QuantumPhase { magnitude: f64 },
    /// Feature masking
    FeatureMasking,
    /// Circuit parameter perturbation
    ParameterPerturbation { strength: f64 },
}

/// Aggregation methods for explanations
#[derive(Debug, Clone)]
pub enum AggregationMethod {
    /// Mean aggregation
    Mean,
    /// Maximum magnitude
    MaxMagnitude,
    /// Variance-based
    Variance,
    /// Quantum coherence-weighted
    CoherenceWeighted,
}

/// Local model types for LIME
#[derive(Debug, Clone)]
pub enum LocalModelType {
    /// Linear regression
    LinearRegression,
    /// Decision tree
    DecisionTree,
    /// Quantum linear model
    QuantumLinear,
}

/// Layer-wise relevance propagation rules
#[derive(Debug, Clone)]
pub enum LRPRule {
    /// Epsilon rule
    Epsilon,
    /// Gamma rule
    Gamma { gamma: f64 },
    /// Alpha-beta rule
    AlphaBeta { alpha: f64, beta: f64 },
    /// Quantum-specific rule
    QuantumRule,
}

/// Explanation result containing multiple types of explanations
#[derive(Debug, Clone)]
pub struct ExplanationResult {
    /// Feature attributions
    pub feature_attributions: Option<Array1<f64>>,

    /// Saliency map
    pub saliency_map: Option<Array2<f64>>,

    /// Circuit explanation
    pub circuit_explanation: Option<CircuitExplanation>,

    /// Quantum state properties
    pub state_properties: Option<QuantumStateProperties>,

    /// Concept activations
    pub concept_activations: Option<HashMap<String, f64>>,

    /// Textual explanation
    pub textual_explanation: String,

    /// Confidence scores
    pub confidence_scores: HashMap<String, f64>,
}

/// Circuit-specific explanation
#[derive(Debug, Clone)]
pub struct CircuitExplanation {
    /// Parameter importance scores
    pub parameter_importance: Array1<f64>,

    /// Gate-wise contributions
    pub gate_contributions: Vec<GateContribution>,

    /// Layer-wise analysis
    pub layer_analysis: Vec<LayerAnalysis>,

    /// Critical path through circuit
    pub critical_path: Vec<usize>,
}

/// Individual gate contribution
#[derive(Debug, Clone)]
pub struct GateContribution {
    /// Gate index in circuit
    pub gate_index: usize,

    /// Gate type
    pub gate_type: String,

    /// Contribution magnitude
    pub contribution: f64,

    /// Qubits affected
    pub qubits: Vec<usize>,

    /// Parameter values (if parameterized)
    pub parameters: Option<Array1<f64>>,
}

/// Layer-wise analysis
#[derive(Debug, Clone)]
pub struct LayerAnalysis {
    /// Layer type
    pub layer_type: QNNLayerType,

    /// Information gain
    pub information_gain: f64,

    /// Entanglement generated
    pub entanglement_generated: f64,

    /// Feature transformations
    pub feature_transformations: Array2<f64>,

    /// Activation patterns
    pub activation_patterns: Array1<f64>,
}

/// Quantum state properties for explanation
#[derive(Debug, Clone)]
pub struct QuantumStateProperties {
    /// Entanglement entropy
    pub entanglement_entropy: f64,

    /// Coherence measures
    pub coherence_measures: HashMap<String, f64>,

    /// Superposition analysis
    pub superposition_components: Array1<f64>,

    /// Measurement probabilities
    pub measurement_probabilities: Array1<f64>,

    /// State fidelity with pure states
    pub state_fidelities: HashMap<String, f64>,
}

/// Main quantum explainable AI engine
pub struct QuantumExplainableAI {
    /// Target model to explain
    model: QuantumNeuralNetwork,

    /// Explanation methods to use
    methods: Vec<ExplanationMethod>,

    /// Background/baseline data for explanations
    background_data: Option<Array2<f64>>,

    /// Pre-computed concept vectors
    concept_vectors: HashMap<String, Array1<f64>>,

    /// Explanation cache
    explanation_cache: HashMap<String, ExplanationResult>,
}

impl QuantumExplainableAI {
    /// Create a new quantum explainable AI instance
    pub fn new(model: QuantumNeuralNetwork, methods: Vec<ExplanationMethod>) -> Self {
        Self {
            model,
            methods,
            background_data: None,
            concept_vectors: HashMap::new(),
            explanation_cache: HashMap::new(),
        }
    }

    /// Set background data for explanations
    pub fn set_background_data(&mut self, data: Array2<f64>) {
        self.background_data = Some(data);
    }

    /// Add concept vector
    pub fn add_concept(&mut self, name: String, vector: Array1<f64>) {
        self.concept_vectors.insert(name, vector);
    }

    /// Generate comprehensive explanation for an input
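    ///
    /// A minimal usage sketch (assuming a trained model and an input of the
    /// right dimension; see `create_default_xai_config` at the bottom of
    /// this module):
    /// ```ignore
    /// let mut xai = QuantumExplainableAI::new(model, create_default_xai_config());
    /// let explanation = xai.explain(&input)?;
    /// println!("{}", explanation.textual_explanation);
    /// ```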
    pub fn explain(&mut self, input: &Array1<f64>) -> Result<ExplanationResult> {
        let mut result = ExplanationResult {
            feature_attributions: None,
            saliency_map: None,
            circuit_explanation: None,
            state_properties: None,
            concept_activations: None,
            textual_explanation: String::new(),
            confidence_scores: HashMap::new(),
        };

        // Apply each explanation method. Note that the attribution-style
        // methods (feature attribution, LIME, SHAP, and LRP) all write to
        // `feature_attributions`, so when several are configured the last
        // one listed wins.
        let methods = self.methods.clone();
        for method in &methods {
            match method {
                ExplanationMethod::QuantumFeatureAttribution {
                    method: attr_method,
                    num_samples,
                    baseline,
                } => {
                    let attributions = self.compute_feature_attributions(
                        input,
                        attr_method,
                        *num_samples,
                        baseline.as_ref(),
                    )?;
                    result.feature_attributions = Some(attributions);
                }

                ExplanationMethod::CircuitVisualization {
                    include_measurements,
                    parameter_sensitivity,
                } => {
                    let circuit_explanation =
                        self.analyze_circuit(input, *include_measurements, *parameter_sensitivity)?;
                    result.circuit_explanation = Some(circuit_explanation);
                }

                ExplanationMethod::StateAnalysis {
                    entanglement_measures,
                    coherence_analysis,
                    superposition_analysis,
                } => {
                    let state_props = self.analyze_quantum_state(
                        input,
                        *entanglement_measures,
                        *coherence_analysis,
                        *superposition_analysis,
                    )?;
                    result.state_properties = Some(state_props);
                }

                ExplanationMethod::SaliencyMapping {
                    perturbation_method,
                    aggregation,
                } => {
                    let saliency =
                        self.compute_saliency_map(input, perturbation_method, aggregation)?;
                    result.saliency_map = Some(saliency);
                }

                ExplanationMethod::QuantumLIME {
                    num_perturbations,
                    kernel_width,
                    local_model,
                } => {
                    let lime_attributions = self.explain_with_lime(
                        input,
                        *num_perturbations,
                        *kernel_width,
                        local_model,
                    )?;
                    result.feature_attributions = Some(lime_attributions);
                }

                ExplanationMethod::QuantumSHAP {
                    num_coalitions,
                    background_samples,
                } => {
                    let shap_values =
                        self.compute_shap_values(input, *num_coalitions, *background_samples)?;
                    result.feature_attributions = Some(shap_values);
                }

                ExplanationMethod::QuantumLRP {
                    propagation_rule,
                    epsilon,
                } => {
                    let lrp_scores =
                        self.layer_wise_relevance_propagation(input, propagation_rule, *epsilon)?;
                    result.feature_attributions = Some(lrp_scores);
                }

                ExplanationMethod::ConceptActivation {
                    concept_datasets,
                    activation_threshold,
                } => {
                    let concept_activations = self.compute_concept_activations(
                        input,
                        concept_datasets,
                        *activation_threshold,
                    )?;
                    result.concept_activations = Some(concept_activations);
                }
            }
        }

        // Generate textual explanation
        result.textual_explanation = self.generate_textual_explanation(&result)?;

        // Compute confidence scores
        result.confidence_scores = self.compute_confidence_scores(&result)?;

        Ok(result)
    }

    /// Compute feature attributions using various methods
    fn compute_feature_attributions(
        &self,
        input: &Array1<f64>,
        method: &AttributionMethod,
        num_samples: usize,
        baseline: Option<&Array1<f64>>,
    ) -> Result<Array1<f64>> {
        match method {
            AttributionMethod::IntegratedGradients => {
                self.integrated_gradients(input, baseline, num_samples)
            }

            AttributionMethod::GradientInput => {
                let gradient = self.compute_gradient(input)?;
                Ok(&gradient * input)
            }

            AttributionMethod::GradientSHAP => self.gradient_shap(input, num_samples),

            AttributionMethod::QuantumAttribution => self.quantum_specific_attribution(input),
        }
    }

    /// Integrated gradients implementation
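    ///
    /// Approximates the integrated-gradients attribution
    /// `IG_i(x) = (x_i - x'_i) * ∫_0^1 ∂F(x' + α(x - x'))/∂x_i dα`
    /// with a Riemann sum over `num_samples` points interpolated between
    /// the baseline `x'` (zeros by default) and the input `x`.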
    fn integrated_gradients(
        &self,
        input: &Array1<f64>,
        baseline: Option<&Array1<f64>>,
        num_samples: usize,
    ) -> Result<Array1<f64>> {
        let default_baseline = Array1::zeros(input.len());
        let baseline = baseline.unwrap_or(&default_baseline);
        let mut integrated_grad: Array1<f64> = Array1::zeros(input.len());

        // Guard against degenerate sample counts (0 or 1) that would
        // otherwise divide by zero below
        let num_samples = num_samples.max(1);

        for i in 0..num_samples {
            let alpha = if num_samples > 1 {
                i as f64 / (num_samples - 1) as f64
            } else {
                1.0
            };
            let interpolated = baseline + alpha * (input - baseline);
            let gradient = self.compute_gradient(&interpolated)?;
            integrated_grad = integrated_grad + gradient;
        }

        integrated_grad = integrated_grad / num_samples as f64;
        let attribution = &integrated_grad * (input - baseline);

        Ok(attribution)
    }

    /// Gradient SHAP implementation
    fn gradient_shap(&self, input: &Array1<f64>, num_samples: usize) -> Result<Array1<f64>> {
        if let Some(ref background) = self.background_data {
            let mut total_attribution = Array1::zeros(input.len());

            for _ in 0..num_samples {
                // Sample random background
                let bg_idx = fastrand::usize(0..background.nrows());
                let baseline = background.row(bg_idx).to_owned();

                // Compute integrated gradients with this baseline
                let attribution = self.integrated_gradients(input, Some(&baseline), 50)?;
                total_attribution = total_attribution + attribution;
            }

            Ok(total_attribution / num_samples as f64)
        } else {
            // Fallback to regular integrated gradients
            self.integrated_gradients(input, None, num_samples)
        }
    }

    /// Quantum-specific attribution method
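    ///
    /// Scores each feature by an empirical quantum Fisher information
    /// estimate and normalizes by the state's entanglement entropy, so
    /// features that steer the quantum state more strongly receive larger
    /// attributions.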
    fn quantum_specific_attribution(&self, input: &Array1<f64>) -> Result<Array1<f64>> {
        let mut attribution = Array1::zeros(input.len());

        // Compute quantum Fisher information for each feature
        for i in 0..input.len() {
            let fisher_info = self.compute_quantum_fisher_information(input, i)?;
            attribution[i] = fisher_info;
        }

        // Normalize by quantum state properties
        let state_props = self.analyze_quantum_state(input, true, true, true)?;
        let normalization = state_props.entanglement_entropy + 1e-10;
        attribution = attribution / normalization;

        Ok(attribution)
    }

    /// Analyze circuit structure and contributions
    fn analyze_circuit(
        &self,
        input: &Array1<f64>,
        _include_measurements: bool,
        parameter_sensitivity: bool,
    ) -> Result<CircuitExplanation> {
        // `_include_measurements` is accepted for API completeness but not
        // yet used by this simplified analysis.

        // Compute parameter importance
        let param_importance = if parameter_sensitivity {
            self.compute_parameter_sensitivity(input)?
        } else {
            Array1::ones(self.model.parameters.len())
        };

        // Analyze each layer
        let mut layer_analysis = Vec::new();
        for (i, layer) in self.model.layers.iter().enumerate() {
            let analysis = self.analyze_layer(input, layer, i)?;
            layer_analysis.push(analysis);
        }

        // Create gate contributions (simplified)
        let gate_contributions = self.analyze_gates(input)?;

        // Find critical path
        let critical_path = self.find_critical_path(&param_importance)?;

        Ok(CircuitExplanation {
            parameter_importance: param_importance,
            gate_contributions,
            layer_analysis,
            critical_path,
        })
    }

    /// Analyze quantum state properties
    fn analyze_quantum_state(
        &self,
        input: &Array1<f64>,
        entanglement_measures: bool,
        coherence_analysis: bool,
        superposition_analysis: bool,
    ) -> Result<QuantumStateProperties> {
        // Get quantum state representation (simplified)
        let state_vector = self.get_state_vector(input)?;

        // Compute entanglement entropy
        let entanglement_entropy = if entanglement_measures {
            self.compute_entanglement_entropy(&state_vector)?
        } else {
            0.0
        };

        // Compute coherence measures
        let coherence_measures = if coherence_analysis {
            self.compute_coherence_measures(&state_vector)?
        } else {
            HashMap::new()
        };

        // Analyze superposition
        let superposition_components = if superposition_analysis {
            self.analyze_superposition(&state_vector)?
        } else {
            Array1::zeros(state_vector.len())
        };

        // Measurement probabilities
        let measurement_probabilities = state_vector.mapv(|x| x * x);

        // State fidelities with computational basis states
        let state_fidelities = self.compute_state_fidelities(&state_vector)?;

        Ok(QuantumStateProperties {
            entanglement_entropy,
            coherence_measures,
            superposition_components,
            measurement_probabilities,
            state_fidelities,
        })
    }

    /// Compute saliency map through perturbations
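    ///
    /// Builds an `input_dim x output_dim` map by repeatedly perturbing the
    /// input and accumulating `|Δx_j · Δy_k|` for every input/output pair,
    /// so entry `(j, k)` grows when feature `j` and output `k` tend to move
    /// together under perturbation.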
    fn compute_saliency_map(
        &self,
        input: &Array1<f64>,
        perturbation_method: &PerturbationMethod,
        aggregation: &AggregationMethod,
    ) -> Result<Array2<f64>> {
        let num_perturbations = 50;

        let baseline_output = self.model.forward(input)?;
        // Size the map by the actual output dimension; a square
        // (input x input) allocation would go out of bounds whenever the
        // model emits more outputs than it has input features.
        let mut saliency_scores = Array2::zeros((input.len(), baseline_output.len()));

        for _ in 0..num_perturbations {
            let perturbed_input = self.apply_perturbation(input, perturbation_method)?;
            let perturbed_output = self.model.forward(&perturbed_input)?;

            let output_diff = &perturbed_output - &baseline_output;
            let input_diff = &perturbed_input - input;

            // Update saliency map based on input-output correlations
            for j in 0..input.len() {
                for k in 0..output_diff.len() {
                    let correlation = input_diff[j] * output_diff[k];
                    saliency_scores[[j, k]] += correlation.abs();
                }
            }
        }

        // Apply aggregation method
        match aggregation {
            AggregationMethod::Mean => {
                saliency_scores = saliency_scores / num_perturbations as f64;
            }
            AggregationMethod::MaxMagnitude => {
                // Placeholder: would keep the maximum magnitude across perturbations
            }
            AggregationMethod::Variance => {
                // Placeholder: would compute the variance of saliency scores
            }
            AggregationMethod::CoherenceWeighted => {
                let coherence_weight = self.compute_coherence_weight(input)?;
                saliency_scores = saliency_scores * coherence_weight;
            }
        }

        Ok(saliency_scores)
    }

    /// LIME explanation for quantum models
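    ///
    /// Follows the standard LIME recipe: sample perturbations around
    /// `input`, weight each by the exponential kernel
    /// `w = exp(-d² / kernel_width²)` (with `d` the Euclidean distance to
    /// the input), then fit a weighted local surrogate model whose
    /// coefficients serve as attributions.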
    fn explain_with_lime(
        &self,
        input: &Array1<f64>,
        num_perturbations: usize,
        kernel_width: f64,
        local_model: &LocalModelType,
    ) -> Result<Array1<f64>> {
        let mut perturbations = Vec::new();
        let mut outputs = Vec::new();
        let mut weights = Vec::new();

        // Generate perturbations around the input
        for _ in 0..num_perturbations {
            let perturbed = self.generate_lime_perturbation(input)?;
            let output = self.model.forward(&perturbed)?;
            let distance = (&perturbed - input).mapv(|x| x * x).sum().sqrt();
            let weight = (-distance * distance / (kernel_width * kernel_width)).exp();

            perturbations.push(perturbed);
            outputs.push(output);
            weights.push(weight);
        }

        // Fit local model
        let attributions = match local_model {
            LocalModelType::LinearRegression => {
                self.fit_linear_model(&perturbations, &outputs, &weights)?
            }
            LocalModelType::DecisionTree => {
                self.fit_decision_tree(&perturbations, &outputs, &weights)?
            }
            LocalModelType::QuantumLinear => {
                self.fit_quantum_linear_model(&perturbations, &outputs, &weights)?
            }
        };

        Ok(attributions)
    }

    /// Compute SHAP values for quantum model
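    ///
    /// Monte-Carlo Shapley estimate: for each sampled coalition `S` (drawn
    /// with feature `i` held out), the marginal contribution
    /// `v(S ∪ {i}) - v(S)` is accumulated and averaged over
    /// `num_coalitions` draws; features absent from a coalition are filled
    /// in from background samples.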
    fn compute_shap_values(
        &self,
        input: &Array1<f64>,
        num_coalitions: usize,
        background_samples: usize,
    ) -> Result<Array1<f64>> {
        let mut shap_values = Array1::zeros(input.len());

        if let Some(ref background) = self.background_data {
            // Sample background instances
            let bg_indices: Vec<usize> = (0..background_samples)
                .map(|_| fastrand::usize(0..background.nrows()))
                .collect();

            for _ in 0..num_coalitions {
                // Generate random coalition
                let coalition: Vec<bool> = (0..input.len()).map(|_| fastrand::bool()).collect();

                for i in 0..input.len() {
                    // Force feature i out of the base coalition so that the
                    // marginal contribution v(S ∪ {i}) - v(S) is well
                    // defined even when the random draw already included i
                    let mut base_coalition = coalition.clone();
                    base_coalition[i] = false;

                    let with_i =
                        self.compute_coalition_value(input, &base_coalition, Some(i), &bg_indices)?;
                    let without_i =
                        self.compute_coalition_value(input, &base_coalition, None, &bg_indices)?;

                    shap_values[i] += with_i - without_i;
                }
            }

            shap_values = shap_values / num_coalitions as f64;
        }

        Ok(shap_values)
    }

    /// Layer-wise relevance propagation
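    ///
    /// Mirrors classical LRP: relevance starts at the model output and is
    /// pushed backwards layer by layer. Classically the epsilon rule reads
    /// `R_j = Σ_k (a_j w_jk / (Σ_j' a_j' w_j'k + ε)) R_k`; the simplified
    /// propagation below keeps only the ε-stabilized redistribution over
    /// layer activations.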
    fn layer_wise_relevance_propagation(
        &self,
        input: &Array1<f64>,
        rule: &LRPRule,
        epsilon: f64,
    ) -> Result<Array1<f64>> {
        // Get layer activations
        let layer_activations = self.compute_layer_activations(input)?;

        // Start with output relevance
        let output = self.model.forward(input)?;
        let mut relevance = output.clone();

        // Propagate relevance backwards through layers
        for (i, layer) in self.model.layers.iter().enumerate().rev() {
            relevance = self.propagate_relevance_through_layer(
                &relevance,
                &layer_activations[i],
                layer,
                rule,
                epsilon,
            )?;
        }

        Ok(relevance)
    }

    /// Compute concept activations
    fn compute_concept_activations(
        &self,
        input: &Array1<f64>,
        concept_datasets: &[String],
        activation_threshold: f64,
    ) -> Result<HashMap<String, f64>> {
        let mut activations = HashMap::new();

        // Get internal representations
        let internal_repr = self.get_internal_representation(input)?;

        for concept_name in concept_datasets {
            if let Some(concept_vector) = self.concept_vectors.get(concept_name) {
                // Compute dot product with concept vector
                let activation = internal_repr
                    .iter()
                    .zip(concept_vector.iter())
                    .map(|(&a, &c)| a * c)
                    .sum::<f64>();

                // Zero out activations below the threshold
                let thresholded_activation = if activation > activation_threshold {
                    activation
                } else {
                    0.0
                };

                activations.insert(concept_name.clone(), thresholded_activation);
            }
        }

        Ok(activations)
    }

    /// Generate textual explanation from results
    fn generate_textual_explanation(&self, result: &ExplanationResult) -> Result<String> {
        let mut explanation = String::new();

        explanation.push_str("Quantum Model Explanation:\n\n");

        // Feature attribution explanation
        if let Some(ref attributions) = result.feature_attributions {
            explanation.push_str("Feature Attributions:\n");
            for (i, &attr) in attributions.iter().enumerate() {
                if attr.abs() > 0.1 {
                    explanation.push_str(&format!(
                        "- Feature {}: {:.3} ({})\n",
                        i,
                        attr,
                        if attr > 0.0 {
                            "positive influence"
                        } else {
                            "negative influence"
                        }
                    ));
                }
            }
            explanation.push('\n');
        }

        // Circuit explanation
        if let Some(ref circuit) = result.circuit_explanation {
            explanation.push_str("Circuit Analysis:\n");
            let max_importance = circuit
                .parameter_importance
                .iter()
                .cloned()
                .fold(f64::NEG_INFINITY, f64::max);

            explanation.push_str(&format!(
                "- Most important parameter has influence: {:.3}\n",
                max_importance
            ));

            explanation.push_str(&format!(
                "- Circuit has {} layers with varying contributions\n",
                circuit.layer_analysis.len()
            ));

            explanation.push('\n');
        }

        // Quantum state properties
        if let Some(ref state) = result.state_properties {
            explanation.push_str("Quantum State Properties:\n");
            explanation.push_str(&format!(
                "- Entanglement entropy: {:.3}\n",
                state.entanglement_entropy
            ));

            let max_prob = state
                .measurement_probabilities
                .iter()
                .cloned()
                .fold(f64::NEG_INFINITY, f64::max);

            explanation.push_str(&format!(
                "- Maximum measurement probability: {:.3}\n",
                max_prob
            ));

            explanation.push('\n');
        }

        // Concept activations
        if let Some(ref concepts) = result.concept_activations {
            explanation.push_str("Concept Activations:\n");
            for (concept, &activation) in concepts {
                if activation > 0.1 {
                    explanation.push_str(&format!("- {}: {:.3}\n", concept, activation));
                }
            }
        }

        Ok(explanation)
    }

    /// Compute confidence scores for explanations
    fn compute_confidence_scores(
        &self,
        result: &ExplanationResult,
    ) -> Result<HashMap<String, f64>> {
        let mut confidence = HashMap::new();

        // Feature attribution confidence: how concentrated the attribution
        // mass is on the single strongest feature
        if let Some(ref attributions) = result.feature_attributions {
            let total_magnitude = attributions.iter().map(|x| x.abs()).sum::<f64>();
            let max_magnitude = attributions
                .iter()
                .map(|x| x.abs())
                .fold(f64::NEG_INFINITY, f64::max);

            let attribution_confidence = if total_magnitude > 0.0 {
                max_magnitude / total_magnitude
            } else {
                0.0
            };

            confidence.insert("feature_attribution".to_string(), attribution_confidence);
        }

        // Circuit explanation confidence
        if let Some(ref circuit) = result.circuit_explanation {
            let param_variance = self.compute_variance(&circuit.parameter_importance);
            let circuit_confidence = param_variance / (param_variance + 1.0);
            confidence.insert("circuit_explanation".to_string(), circuit_confidence);
        }

        // State analysis confidence
        if let Some(ref state) = result.state_properties {
            let state_confidence = state.entanglement_entropy / (state.entanglement_entropy + 1.0);
            confidence.insert("state_analysis".to_string(), state_confidence);
        }

        Ok(confidence)
    }

    // Helper methods

    /// Compute the gradient of a scalar surrogate loss (the squared norm of
    /// the model output) w.r.t. the input, via forward finite differences.
    /// A full implementation would use automatic differentiation or
    /// parameter-shift rules instead.
    fn compute_gradient(&self, input: &Array1<f64>) -> Result<Array1<f64>> {
        let mut gradient = Array1::zeros(input.len());
        let h = 1e-5;

        let baseline_output = self.model.forward(input)?;
        let baseline_loss = baseline_output.iter().map(|x| x * x).sum::<f64>();

        for i in 0..input.len() {
            let mut perturbed_input = input.clone();
            perturbed_input[i] += h;

            let perturbed_output = self.model.forward(&perturbed_input)?;
            let perturbed_loss = perturbed_output.iter().map(|x| x * x).sum::<f64>();

            gradient[i] = (perturbed_loss - baseline_loss) / h;
        }

        Ok(gradient)
    }

    /// Compute quantum Fisher information
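    ///
    /// Proxy estimate: the central-difference derivative of the output
    /// vector w.r.t. one input feature, squared and summed,
    /// `F_i ≈ Σ_k ((y_k(x + h·e_i) - y_k(x - h·e_i)) / 2h)²`. A full QFI
    /// computation would require the quantum state and its derivatives.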
    fn compute_quantum_fisher_information(
        &self,
        input: &Array1<f64>,
        feature_idx: usize,
    ) -> Result<f64> {
        // Simplified quantum Fisher information computation
        let h = 1e-4;

        let mut input_plus = input.clone();
        let mut input_minus = input.clone();
        input_plus[feature_idx] += h;
        input_minus[feature_idx] -= h;

        let output_plus = self.model.forward(&input_plus)?;
        let output_minus = self.model.forward(&input_minus)?;

        let derivative = (&output_plus - &output_minus) / (2.0 * h);
        let fisher_info = derivative.mapv(|x| x * x).sum();

        Ok(fisher_info)
    }

    /// Get quantum state vector representation
    fn get_state_vector(&self, input: &Array1<f64>) -> Result<Array1<f64>> {
        // Simplified state vector computation
        let output = self.model.forward(input)?;
        let state_dim = 1 << self.model.num_qubits; // 2^n for n qubits

        // Create normalized state vector
        let mut state_vector = Array1::zeros(state_dim);
        for i in 0..output.len().min(state_dim) {
            state_vector[i] = output[i];
        }

        // Normalize
        let norm = state_vector.mapv(|x| x * x).sum().sqrt();
        if norm > 1e-10 {
            state_vector = state_vector / norm;
        }

        Ok(state_vector)
    }

    /// Compute entanglement entropy
    fn compute_entanglement_entropy(&self, state_vector: &Array1<f64>) -> Result<f64> {
        // Simplified entanglement entropy computation
        let num_qubits = (state_vector.len() as f64).log2() as usize;

        if num_qubits < 2 {
            return Ok(0.0);
        }

        // Diagonal approximation of the reduced density matrix of the first
        // qubit: sum the probabilities of all basis states in which the
        // first qubit is |0> (first half of the amplitudes) or |1> (second
        // half). Off-diagonal coherences are ignored, so this upper-bounds
        // the true von Neumann entropy of the reduced state.
        let half_size = state_vector.len() / 2;
        let p0: f64 = state_vector.iter().take(half_size).map(|x| x * x).sum();
        let p1: f64 = state_vector.iter().skip(half_size).map(|x| x * x).sum();

        let mut entropy = 0.0;
        if p0 > 1e-10 {
            entropy -= p0 * p0.ln();
        }
        if p1 > 1e-10 {
            entropy -= p1 * p1.ln();
        }

        Ok(entropy)
    }

    /// Compute coherence measures
    fn compute_coherence_measures(
        &self,
        state_vector: &Array1<f64>,
    ) -> Result<HashMap<String, f64>> {
        let mut measures = HashMap::new();

        // L1-norm coherence proxy: sum of amplitude magnitudes outside the
        // first basis component (the full measure sums |ρ_ij| over all
        // off-diagonal density-matrix entries)
        let l1_coherence = state_vector
            .iter()
            .enumerate()
            .filter(|(i, _)| *i > 0)
            .map(|(_, &x)| x.abs())
            .sum::<f64>();

        measures.insert("l1_coherence".to_string(), l1_coherence);

        // Relative-entropy coherence proxy: KL divergence between the
        // measurement distribution and the uniform distribution
        let uniform_prob = 1.0 / state_vector.len() as f64;
        let rel_entropy = state_vector
            .iter()
            .map(|&a| {
                let p = a * a;
                if p > 1e-10 {
                    p * (p / uniform_prob).ln()
                } else {
                    0.0
                }
            })
            .sum::<f64>();

        measures.insert("relative_entropy_coherence".to_string(), rel_entropy);

        Ok(measures)
    }

    /// Analyze superposition components
    fn analyze_superposition(&self, state_vector: &Array1<f64>) -> Result<Array1<f64>> {
        // Return magnitude of each basis state component
        Ok(state_vector.mapv(|x| x.abs()))
    }

    /// Compute state fidelities
    fn compute_state_fidelities(&self, state_vector: &Array1<f64>) -> Result<HashMap<String, f64>> {
        let mut fidelities = HashMap::new();

        // Fidelity with the first 8 computational basis states,
        // F = |<i|psi>|^2 for a pure state
        for i in 0..state_vector.len().min(8) {
            let fidelity = state_vector[i].powi(2);
            fidelities.insert(format!("basis_state_{}", i), fidelity);
        }

        Ok(fidelities)
    }

    /// Apply perturbation to input
    fn apply_perturbation(
        &self,
        input: &Array1<f64>,
        method: &PerturbationMethod,
    ) -> Result<Array1<f64>> {
        match method {
            PerturbationMethod::Gaussian { sigma } => {
                // Box-Muller transform so the noise is actually Gaussian,
                // matching the variant's name
                let noise = Array1::from_shape_fn(input.len(), |_| {
                    let u1 = fastrand::f64().max(1e-12);
                    let u2 = fastrand::f64();
                    sigma * (-2.0 * u1.ln()).sqrt() * (2.0 * PI * u2).cos()
                });
                Ok(input + &noise)
            }

            PerturbationMethod::QuantumPhase { magnitude } => {
                let mut perturbed = input.clone();
                for i in 0..perturbed.len() {
                    let phase_shift = magnitude * (2.0 * PI * fastrand::f64() - PI);
                    perturbed[i] = (perturbed[i] + phase_shift).rem_euclid(2.0 * PI);
                }
                Ok(perturbed)
            }

            PerturbationMethod::FeatureMasking => {
                let mut perturbed = input.clone();
                let mask_idx = fastrand::usize(0..input.len());
                perturbed[mask_idx] = 0.0;
                Ok(perturbed)
            }

            PerturbationMethod::ParameterPerturbation { strength } => {
                let noise =
                    Array1::from_shape_fn(input.len(), |_| strength * (fastrand::f64() - 0.5));
                Ok(input + &noise)
            }
        }
    }

    /// Compute parameter sensitivity
    fn compute_parameter_sensitivity(&self, _input: &Array1<f64>) -> Result<Array1<f64>> {
        // Placeholder: a real implementation would perturb each circuit
        // parameter and measure the resulting output change, which requires
        // mutable access to the model's parameters.
        Ok(Array1::ones(self.model.parameters.len()))
    }

    /// Analyze individual layer
    fn analyze_layer(
        &self,
        input: &Array1<f64>,
        layer: &QNNLayerType,
        _layer_idx: usize,
    ) -> Result<LayerAnalysis> {
        // Simplified layer analysis with placeholder statistics
        let information_gain = 0.5 + 0.3 * fastrand::f64();
        let entanglement_generated = match layer {
            QNNLayerType::EntanglementLayer { .. } => 0.8 + 0.2 * fastrand::f64(),
            _ => 0.1 * fastrand::f64(),
        };

        let feature_dim = input.len();
        let feature_transformations =
            Array2::from_shape_fn((feature_dim, feature_dim), |(i, j)| {
                if i == j {
                    1.0
                } else {
                    0.1 * fastrand::f64()
                }
            });

        let activation_patterns = Array1::from_shape_fn(feature_dim, |_| fastrand::f64());

        Ok(LayerAnalysis {
            layer_type: layer.clone(),
            information_gain,
            entanglement_generated,
            feature_transformations,
            activation_patterns,
        })
    }

    /// Analyze gates in circuit
    fn analyze_gates(&self, _input: &Array1<f64>) -> Result<Vec<GateContribution>> {
        // Simplified gate analysis over a fixed-size demo circuit
        let mut contributions = Vec::new();

        for i in 0..10 {
            // Assume 10 gates for demo: RX, RY, CNOT repeating
            let is_two_qubit = i % 3 == 2;
            let contribution = GateContribution {
                gate_index: i,
                gate_type: if i % 3 == 0 {
                    "RX".to_string()
                } else if i % 3 == 1 {
                    "RY".to_string()
                } else {
                    "CNOT".to_string()
                },
                contribution: 0.1 + 0.8 * fastrand::f64(),
                // Single-qubit rotations touch one qubit; CNOT touches two
                qubits: if is_two_qubit {
                    vec![i % self.model.num_qubits, (i + 1) % self.model.num_qubits]
                } else {
                    vec![i % self.model.num_qubits]
                },
                parameters: if !is_two_qubit {
                    Some(Array1::from_vec(vec![PI * fastrand::f64()]))
                } else {
                    None
                },
            };
            contributions.push(contribution);
        }

        Ok(contributions)
    }

    /// Find critical path through circuit
    fn find_critical_path(&self, param_importance: &Array1<f64>) -> Result<Vec<usize>> {
        // Find indices of most important parameters
        let mut indexed_importance: Vec<(usize, f64)> = param_importance
            .iter()
            .enumerate()
            .map(|(i, &val)| (i, val))
            .collect();

        // Sort descending; treat NaN comparisons as equal instead of panicking
        indexed_importance
            .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));

        // Return top 5 parameter indices as critical path
        Ok(indexed_importance
            .into_iter()
            .take(5)
            .map(|(i, _)| i)
            .collect())
    }

    /// Compute variance of array
    fn compute_variance(&self, arr: &Array1<f64>) -> f64 {
        if arr.is_empty() {
            return 0.0;
        }
        let mean = arr.mean().unwrap_or(0.0);
        arr.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / arr.len() as f64
    }

    /// Generate LIME perturbation
    fn generate_lime_perturbation(&self, input: &Array1<f64>) -> Result<Array1<f64>> {
        let mut perturbed = input.clone();

        // Randomly mask some features (30% chance each)
        for i in 0..input.len() {
            if fastrand::f64() < 0.3 {
                perturbed[i] = 0.0;
            }
        }

        Ok(perturbed)
    }

    /// Fit linear model for LIME
    fn fit_linear_model(
        &self,
        perturbations: &[Array1<f64>],
        _outputs: &[Array1<f64>],
        _weights: &[f64],
    ) -> Result<Array1<f64>> {
        // Placeholder: a real implementation would solve the weighted
        // least-squares problem over the perturbation/output pairs
        let feature_dim = perturbations[0].len();
        Ok(Array1::from_shape_fn(feature_dim, |_| {
            0.1 + 0.8 * fastrand::f64()
        }))
    }

    /// Fit decision tree for LIME
    fn fit_decision_tree(
        &self,
        perturbations: &[Array1<f64>],
        _outputs: &[Array1<f64>],
        _weights: &[f64],
    ) -> Result<Array1<f64>> {
        // Placeholder: alternating importances stand in for a fitted tree
        let feature_dim = perturbations[0].len();
        Ok(Array1::from_shape_fn(feature_dim, |i| {
            if i % 2 == 0 {
                0.8
            } else {
                0.2
            }
        }))
    }

    /// Fit quantum linear model for LIME
    fn fit_quantum_linear_model(
        &self,
        perturbations: &[Array1<f64>],
        _outputs: &[Array1<f64>],
        _weights: &[f64],
    ) -> Result<Array1<f64>> {
        // Placeholder: a smooth deterministic profile stands in for a
        // fitted quantum linear model
        let feature_dim = perturbations[0].len();
        Ok(Array1::from_shape_fn(feature_dim, |i| {
            (i as f64 * 0.3).sin().abs()
        }))
    }

    /// Compute coalition value for SHAP
    fn compute_coalition_value(
        &self,
        input: &Array1<f64>,
        coalition: &[bool],
        additional_feature: Option<usize>,
        background_indices: &[usize],
    ) -> Result<f64> {
        if let Some(ref background) = self.background_data {
            let mut coalition_input = Array1::zeros(input.len());

            // Set coalition features from input, others from background
            let bg_idx = background_indices[fastrand::usize(0..background_indices.len())];
            let background_sample = background.row(bg_idx);

            for i in 0..input.len() {
                let in_coalition = coalition[i] || (additional_feature == Some(i));
                coalition_input[i] = if in_coalition {
                    input[i]
                } else {
                    background_sample[i]
                };
            }

            let output = self.model.forward(&coalition_input)?;
            Ok(output.sum()) // Simplified value function
        } else {
            Ok(0.0)
        }
    }

    /// Compute layer activations
    fn compute_layer_activations(&self, input: &Array1<f64>) -> Result<Vec<Array1<f64>>> {
        // Simplified layer activation computation
        let mut activations = Vec::new();
        let mut current_activation = input.clone();

        for _ in &self.model.layers {
            // Simplified transformation
            current_activation = current_activation.mapv(|x| x.tanh());
            activations.push(current_activation.clone());
        }

        Ok(activations)
    }

    /// Propagate relevance through layer
    fn propagate_relevance_through_layer(
        &self,
        relevance: &Array1<f64>,
        activation: &Array1<f64>,
        _layer: &QNNLayerType,
        rule: &LRPRule,
        epsilon: f64,
    ) -> Result<Array1<f64>> {
        // Simplified LRP propagation
        match rule {
            LRPRule::Epsilon => {
                // Redistribute the total incoming relevance over this
                // layer's activations, stabilized by epsilon; this keeps
                // the relevance vector aligned with the layer's dimension
                // even when it differs from the output dimension.
                let total_relevance = relevance.sum();
                let stabilized = activation.mapv(|x| x + epsilon);
                let norm = stabilized.sum();
                if norm.abs() < 1e-12 {
                    return Ok(Array1::zeros(activation.len()));
                }
                Ok(stabilized.mapv(|x| total_relevance * x / norm))
            }
            _ => Ok(relevance.clone()),
        }
    }

    /// Get internal representation
    fn get_internal_representation(&self, input: &Array1<f64>) -> Result<Array1<f64>> {
        // Use the model output as a stand-in for an intermediate-layer
        // representation (simplified)
        self.model.forward(input)
    }

    /// Compute coherence weight
    fn compute_coherence_weight(&self, input: &Array1<f64>) -> Result<f64> {
        let state_props = self.analyze_quantum_state(input, false, true, false)?;
        let coherence = state_props
            .coherence_measures
            .get("l1_coherence")
            .unwrap_or(&1.0);
        Ok(*coherence)
    }
}

/// Helper function to create default explainable AI configuration
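///
/// A minimal end-to-end sketch (assuming a constructed model, background
/// data, and input vector are available):
/// ```ignore
/// let mut xai = QuantumExplainableAI::new(model, create_default_xai_config());
/// xai.set_background_data(background); // enables SHAP-style baselines
/// let result = xai.explain(&input)?;
/// ```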
pub fn create_default_xai_config() -> Vec<ExplanationMethod> {
    vec![
        ExplanationMethod::QuantumFeatureAttribution {
            method: AttributionMethod::IntegratedGradients,
            num_samples: 50,
            baseline: None,
        },
        ExplanationMethod::CircuitVisualization {
            include_measurements: true,
            parameter_sensitivity: true,
        },
        ExplanationMethod::StateAnalysis {
            entanglement_measures: true,
            coherence_analysis: true,
            superposition_analysis: true,
        },
        ExplanationMethod::SaliencyMapping {
            perturbation_method: PerturbationMethod::Gaussian { sigma: 0.1 },
            aggregation: AggregationMethod::Mean,
        },
    ]
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::qnn::QNNLayerType;

    #[test]
    fn test_xai_creation() {
        let layers = vec![
            QNNLayerType::EncodingLayer { num_features: 4 },
            QNNLayerType::VariationalLayer { num_params: 8 },
            QNNLayerType::MeasurementLayer {
                measurement_basis: "computational".to_string(),
            },
        ];

        let model = QuantumNeuralNetwork::new(layers, 4, 4, 2).unwrap();
        let methods = create_default_xai_config();
        let xai = QuantumExplainableAI::new(model, methods);

        assert_eq!(xai.methods.len(), 4);
    }

    #[test]
    fn test_explanation_result() {
        let result = ExplanationResult {
            feature_attributions: Some(Array1::from_vec(vec![0.1, 0.5, -0.2, 0.8])),
            saliency_map: None,
            circuit_explanation: None,
            state_properties: None,
            concept_activations: None,
            textual_explanation: "Test explanation".to_string(),
            confidence_scores: HashMap::new(),
        };

        assert!(result.feature_attributions.is_some());
        assert_eq!(result.textual_explanation, "Test explanation");
    }

    #[test]
    fn test_circuit_explanation() {
        let explanation = CircuitExplanation {
            parameter_importance: Array1::from_vec(vec![0.8, 0.3, 0.9, 0.1]),
            gate_contributions: Vec::new(),
            layer_analysis: Vec::new(),
            critical_path: vec![2, 0, 1],
        };

        assert_eq!(explanation.parameter_importance.len(), 4);
        assert_eq!(explanation.critical_path, vec![2, 0, 1]);
    }

    #[test]
    fn test_quantum_state_properties() {
        let mut coherence_measures = HashMap::new();
        coherence_measures.insert("l1_coherence".to_string(), 0.7);

        let mut state_fidelities = HashMap::new();
        state_fidelities.insert("basis_state_0".to_string(), 0.9);

        let properties = QuantumStateProperties {
            entanglement_entropy: 1.2,
            coherence_measures,
            superposition_components: Array1::from_vec(vec![0.7, 0.5, 0.1, 0.2]),
            measurement_probabilities: Array1::from_vec(vec![0.49, 0.25, 0.01, 0.04]),
            state_fidelities,
        };

        assert_eq!(properties.entanglement_entropy, 1.2);
        assert_eq!(properties.superposition_components.len(), 4);
    }
}