quantrs2_circuit/verifier/
invariant_checker.rs

1//! Invariant checker for circuit invariants
2use super::config::VerifierConfig;
3use super::types::*;
4use crate::builder::Circuit;
5use crate::scirs2_integration::SciRS2CircuitAnalyzer;
6use quantrs2_core::error::QuantRS2Result;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::time::{Duration, Instant};
/// Evaluates a set of [`CircuitInvariant`]s against a [`Circuit`] with `N` qubits.
///
/// Invariants are registered with `add_invariant` and evaluated together by
/// `check_all_invariants`, which returns one [`InvariantCheckResult`] per
/// registered invariant, in registration order.
pub struct InvariantChecker<const N: usize> {
    /// Invariants to check, in the order they were added.
    invariants: Vec<CircuitInvariant<N>>,
    /// Cache of invariant check results, presumably keyed by invariant name.
    // NOTE(review): nothing in this module writes to this map; `check_all_invariants`
    // takes `&self` and returns its results instead of storing them here.
    check_results: HashMap<String, InvariantCheckResult>,
    /// `SciRS2` circuit analyzer.
    // NOTE(review): not read by any check in this module.
    analyzer: SciRS2CircuitAnalyzer,
}
/// An invariant that a circuit with `N` qubits is expected to satisfy.
///
/// The range variants (`GateCountBounds`, `DepthBounds`) use inclusive bounds.
/// The resource variants (`MemoryBounds`, `TimeBounds`) compare an estimate
/// against an inclusive upper limit.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum CircuitInvariant<const N: usize> {
    /// Total probability is conserved within `tolerance`.
    // NOTE(review): the check for this variant in this module is currently a
    // placeholder that always reports it as satisfied and does not read `tolerance`.
    ProbabilityConservation { tolerance: f64 },
    /// The circuit's qubit count (the const parameter `N`) equals `expected_count`.
    QubitCount { expected_count: usize },
    /// The number of gates lies within `min_gates..=max_gates`.
    GateCountBounds { min_gates: usize, max_gates: usize },
    /// The circuit depth lies within `min_depth..=max_depth`.
    DepthBounds { min_depth: usize, max_depth: usize },
    /// The estimated memory footprint is at most `max_memory_bytes`.
    MemoryBounds { max_memory_bytes: usize },
    /// The estimated execution time (derived from the gate count) is at most
    /// `max_execution_time`.
    TimeBounds { max_execution_time: Duration },
    /// A user-defined invariant.
    Custom {
        /// Name reported as the check result's `invariant_name`.
        name: String,
        /// Human-readable description; not consulted when checking.
        description: String,
        /// Parameters describing how to evaluate the invariant.
        checker: CustomInvariantChecker<N>,
    },
}
/// Configuration for a [`CircuitInvariant::Custom`] invariant.
///
/// The const parameter `N` is not used by any field; presumably it ties the
/// checker to the qubit count of the circuit it is meant for.
// NOTE(review): the custom-invariant check in this module currently reads only
// `expected_value`; `function_name`, `parameters` and `tolerance` are not consulted.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CustomInvariantChecker<const N: usize> {
    /// Name of the function that evaluates the invariant.
    // NOTE(review): how this name is resolved to a function is not visible here.
    pub function_name: String,
    /// Named numeric parameters for the checker function.
    pub parameters: HashMap<String, f64>,
    /// Value the invariant is expected to evaluate to.
    pub expected_value: f64,
    /// Tolerance for comparing the measured value against `expected_value`.
    pub tolerance: f64,
}
/// Outcome of checking a single [`CircuitInvariant`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InvariantCheckResult {
    /// Display name of the invariant (e.g. "Gate Count Bounds"), or the
    /// user-supplied name for custom invariants.
    pub invariant_name: String,
    /// Outcome of the check.
    pub result: VerificationOutcome,
    /// Value measured or estimated from the circuit. The unit depends on the
    /// invariant: bytes for memory bounds, seconds for time bounds, counts otherwise.
    pub measured_value: f64,
    /// Reference value the measurement is compared with. This is the expected
    /// count, the upper limit (memory/time), or the midpoint of the inclusive
    /// bounds (gate count/depth).
    pub expected_value: f64,
    /// Severity of the violation; `None` when the invariant is satisfied.
    pub violation_severity: Option<ViolationSeverity>,
    /// Wall-clock time spent evaluating this invariant.
    pub check_time: Duration,
}
69impl<const N: usize> InvariantChecker<N> {
70    /// Create new invariant checker
71    #[must_use]
72    pub fn new() -> Self {
73        Self {
74            invariants: Vec::new(),
75            check_results: HashMap::new(),
76            analyzer: SciRS2CircuitAnalyzer::new(),
77        }
78    }
79    /// Add invariant to check
80    pub fn add_invariant(&mut self, invariant: CircuitInvariant<N>) {
81        self.invariants.push(invariant);
82    }
83    /// Check all invariants
84    pub fn check_all_invariants(
85        &self,
86        circuit: &Circuit<N>,
87        config: &VerifierConfig,
88    ) -> QuantRS2Result<Vec<InvariantCheckResult>> {
89        let mut results = Vec::new();
90        for invariant in &self.invariants {
91            let result = self.check_invariant(invariant, circuit, config)?;
92            results.push(result);
93        }
94        Ok(results)
95    }
96    /// Check single invariant
97    fn check_invariant(
98        &self,
99        invariant: &CircuitInvariant<N>,
100        circuit: &Circuit<N>,
101        config: &VerifierConfig,
102    ) -> QuantRS2Result<InvariantCheckResult> {
103        let start_time = Instant::now();
104        let (invariant_name, result, measured_value, expected_value, violation_severity) =
105            match invariant {
106                CircuitInvariant::ProbabilityConservation { tolerance } => {
107                    Self::check_probability_conservation(circuit, *tolerance)?
108                }
109                CircuitInvariant::QubitCount { expected_count } => {
110                    Self::check_qubit_count(circuit, *expected_count)?
111                }
112                CircuitInvariant::GateCountBounds {
113                    min_gates,
114                    max_gates,
115                } => Self::check_gate_count_bounds(circuit, *min_gates, *max_gates)?,
116                CircuitInvariant::DepthBounds {
117                    min_depth,
118                    max_depth,
119                } => Self::check_depth_bounds(circuit, *min_depth, *max_depth)?,
120                CircuitInvariant::MemoryBounds { max_memory_bytes } => {
121                    Self::check_memory_bounds(circuit, *max_memory_bytes)?
122                }
123                CircuitInvariant::TimeBounds { max_execution_time } => {
124                    Self::check_time_bounds(circuit, *max_execution_time)?
125                }
126                CircuitInvariant::Custom {
127                    name,
128                    description: _,
129                    checker,
130                } => Self::check_custom_invariant(circuit, name, checker)?,
131            };
132        Ok(InvariantCheckResult {
133            invariant_name,
134            result,
135            measured_value,
136            expected_value,
137            violation_severity,
138            check_time: start_time.elapsed(),
139        })
140    }
141    fn check_probability_conservation(
142        circuit: &Circuit<N>,
143        tolerance: f64,
144    ) -> QuantRS2Result<(
145        String,
146        VerificationOutcome,
147        f64,
148        f64,
149        Option<ViolationSeverity>,
150    )> {
151        Ok((
152            "Probability Conservation".to_string(),
153            VerificationOutcome::Satisfied,
154            1.0,
155            1.0,
156            None,
157        ))
158    }
159    fn check_qubit_count(
160        circuit: &Circuit<N>,
161        expected_count: usize,
162    ) -> QuantRS2Result<(
163        String,
164        VerificationOutcome,
165        f64,
166        f64,
167        Option<ViolationSeverity>,
168    )> {
169        let measured_value = N as f64;
170        let expected_value = expected_count as f64;
171        let result = if N == expected_count {
172            VerificationOutcome::Satisfied
173        } else {
174            VerificationOutcome::Violated
175        };
176        let violation_severity = if result == VerificationOutcome::Violated {
177            Some(ViolationSeverity::Major)
178        } else {
179            None
180        };
181        Ok((
182            "Qubit Count".to_string(),
183            result,
184            measured_value,
185            expected_value,
186            violation_severity,
187        ))
188    }
189    fn check_gate_count_bounds(
190        circuit: &Circuit<N>,
191        min_gates: usize,
192        max_gates: usize,
193    ) -> QuantRS2Result<(
194        String,
195        VerificationOutcome,
196        f64,
197        f64,
198        Option<ViolationSeverity>,
199    )> {
200        let gate_count = circuit.num_gates();
201        let measured_value = gate_count as f64;
202        let expected_value = usize::midpoint(min_gates, max_gates) as f64;
203        let result = if gate_count >= min_gates && gate_count <= max_gates {
204            VerificationOutcome::Satisfied
205        } else {
206            VerificationOutcome::Violated
207        };
208        let violation_severity = if result == VerificationOutcome::Violated {
209            Some(ViolationSeverity::Moderate)
210        } else {
211            None
212        };
213        Ok((
214            "Gate Count Bounds".to_string(),
215            result,
216            measured_value,
217            expected_value,
218            violation_severity,
219        ))
220    }
221    fn check_depth_bounds(
222        circuit: &Circuit<N>,
223        min_depth: usize,
224        max_depth: usize,
225    ) -> QuantRS2Result<(
226        String,
227        VerificationOutcome,
228        f64,
229        f64,
230        Option<ViolationSeverity>,
231    )> {
232        let circuit_depth = circuit.calculate_depth();
233        let measured_value = circuit_depth as f64;
234        let expected_value = usize::midpoint(min_depth, max_depth) as f64;
235        let result = if circuit_depth >= min_depth && circuit_depth <= max_depth {
236            VerificationOutcome::Satisfied
237        } else {
238            VerificationOutcome::Violated
239        };
240        let violation_severity = if result == VerificationOutcome::Violated {
241            Some(ViolationSeverity::Moderate)
242        } else {
243            None
244        };
245        Ok((
246            "Depth Bounds".to_string(),
247            result,
248            measured_value,
249            expected_value,
250            violation_severity,
251        ))
252    }
253    fn check_memory_bounds(
254        circuit: &Circuit<N>,
255        max_memory_bytes: usize,
256    ) -> QuantRS2Result<(
257        String,
258        VerificationOutcome,
259        f64,
260        f64,
261        Option<ViolationSeverity>,
262    )> {
263        let estimated_memory = std::mem::size_of::<Circuit<N>>();
264        let measured_value = estimated_memory as f64;
265        let expected_value = max_memory_bytes as f64;
266        let result = if estimated_memory <= max_memory_bytes {
267            VerificationOutcome::Satisfied
268        } else {
269            VerificationOutcome::Violated
270        };
271        let violation_severity = if result == VerificationOutcome::Violated {
272            Some(ViolationSeverity::High)
273        } else {
274            None
275        };
276        Ok((
277            "Memory Bounds".to_string(),
278            result,
279            measured_value,
280            expected_value,
281            violation_severity,
282        ))
283    }
284    fn check_time_bounds(
285        circuit: &Circuit<N>,
286        max_execution_time: Duration,
287    ) -> QuantRS2Result<(
288        String,
289        VerificationOutcome,
290        f64,
291        f64,
292        Option<ViolationSeverity>,
293    )> {
294        let estimated_time = Duration::from_millis(circuit.num_gates() as u64);
295        let measured_value = estimated_time.as_secs_f64();
296        let expected_value = max_execution_time.as_secs_f64();
297        let result = if estimated_time <= max_execution_time {
298            VerificationOutcome::Satisfied
299        } else {
300            VerificationOutcome::Violated
301        };
302        let violation_severity = if result == VerificationOutcome::Violated {
303            Some(ViolationSeverity::High)
304        } else {
305            None
306        };
307        Ok((
308            "Time Bounds".to_string(),
309            result,
310            measured_value,
311            expected_value,
312            violation_severity,
313        ))
314    }
315    fn check_custom_invariant(
316        circuit: &Circuit<N>,
317        name: &str,
318        checker: &CustomInvariantChecker<N>,
319    ) -> QuantRS2Result<(
320        String,
321        VerificationOutcome,
322        f64,
323        f64,
324        Option<ViolationSeverity>,
325    )> {
326        Ok((
327            name.to_string(),
328            VerificationOutcome::Satisfied,
329            1.0,
330            checker.expected_value,
331            None,
332        ))
333    }
334}
335impl<const N: usize> Default for InvariantChecker<N> {
336    fn default() -> Self {
337        Self::new()
338    }
339}