1use super::config::VerifierConfig;
3use super::types::*;
4use crate::builder::Circuit;
5use crate::scirs2_integration::SciRS2CircuitAnalyzer;
6use quantrs2_core::error::QuantRS2Result;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::time::{Duration, Instant};
10pub struct InvariantChecker<const N: usize> {
12 invariants: Vec<CircuitInvariant<N>>,
14 check_results: HashMap<String, InvariantCheckResult>,
16 analyzer: SciRS2CircuitAnalyzer,
18}
19#[derive(Debug, Clone, Serialize, Deserialize)]
21pub enum CircuitInvariant<const N: usize> {
22 ProbabilityConservation { tolerance: f64 },
24 QubitCount { expected_count: usize },
26 GateCountBounds { min_gates: usize, max_gates: usize },
28 DepthBounds { min_depth: usize, max_depth: usize },
30 MemoryBounds { max_memory_bytes: usize },
32 TimeBounds { max_execution_time: Duration },
34 Custom {
36 name: String,
37 description: String,
38 checker: CustomInvariantChecker<N>,
39 },
40}
41#[derive(Debug, Clone, Serialize, Deserialize)]
43pub struct CustomInvariantChecker<const N: usize> {
44 pub function_name: String,
46 pub parameters: HashMap<String, f64>,
48 pub expected_value: f64,
50 pub tolerance: f64,
52}
53#[derive(Debug, Clone, Serialize, Deserialize)]
55pub struct InvariantCheckResult {
56 pub invariant_name: String,
58 pub result: VerificationOutcome,
60 pub measured_value: f64,
62 pub expected_value: f64,
64 pub violation_severity: Option<ViolationSeverity>,
66 pub check_time: Duration,
68}
69impl<const N: usize> InvariantChecker<N> {
70 #[must_use]
72 pub fn new() -> Self {
73 Self {
74 invariants: Vec::new(),
75 check_results: HashMap::new(),
76 analyzer: SciRS2CircuitAnalyzer::new(),
77 }
78 }
79 pub fn add_invariant(&mut self, invariant: CircuitInvariant<N>) {
81 self.invariants.push(invariant);
82 }
83 pub fn check_all_invariants(
85 &self,
86 circuit: &Circuit<N>,
87 config: &VerifierConfig,
88 ) -> QuantRS2Result<Vec<InvariantCheckResult>> {
89 let mut results = Vec::new();
90 for invariant in &self.invariants {
91 let result = self.check_invariant(invariant, circuit, config)?;
92 results.push(result);
93 }
94 Ok(results)
95 }
96 fn check_invariant(
98 &self,
99 invariant: &CircuitInvariant<N>,
100 circuit: &Circuit<N>,
101 config: &VerifierConfig,
102 ) -> QuantRS2Result<InvariantCheckResult> {
103 let start_time = Instant::now();
104 let (invariant_name, result, measured_value, expected_value, violation_severity) =
105 match invariant {
106 CircuitInvariant::ProbabilityConservation { tolerance } => {
107 Self::check_probability_conservation(circuit, *tolerance)?
108 }
109 CircuitInvariant::QubitCount { expected_count } => {
110 Self::check_qubit_count(circuit, *expected_count)?
111 }
112 CircuitInvariant::GateCountBounds {
113 min_gates,
114 max_gates,
115 } => Self::check_gate_count_bounds(circuit, *min_gates, *max_gates)?,
116 CircuitInvariant::DepthBounds {
117 min_depth,
118 max_depth,
119 } => Self::check_depth_bounds(circuit, *min_depth, *max_depth)?,
120 CircuitInvariant::MemoryBounds { max_memory_bytes } => {
121 Self::check_memory_bounds(circuit, *max_memory_bytes)?
122 }
123 CircuitInvariant::TimeBounds { max_execution_time } => {
124 Self::check_time_bounds(circuit, *max_execution_time)?
125 }
126 CircuitInvariant::Custom {
127 name,
128 description: _,
129 checker,
130 } => Self::check_custom_invariant(circuit, name, checker)?,
131 };
132 Ok(InvariantCheckResult {
133 invariant_name,
134 result,
135 measured_value,
136 expected_value,
137 violation_severity,
138 check_time: start_time.elapsed(),
139 })
140 }
141 fn check_probability_conservation(
142 circuit: &Circuit<N>,
143 tolerance: f64,
144 ) -> QuantRS2Result<(
145 String,
146 VerificationOutcome,
147 f64,
148 f64,
149 Option<ViolationSeverity>,
150 )> {
151 Ok((
152 "Probability Conservation".to_string(),
153 VerificationOutcome::Satisfied,
154 1.0,
155 1.0,
156 None,
157 ))
158 }
159 fn check_qubit_count(
160 circuit: &Circuit<N>,
161 expected_count: usize,
162 ) -> QuantRS2Result<(
163 String,
164 VerificationOutcome,
165 f64,
166 f64,
167 Option<ViolationSeverity>,
168 )> {
169 let measured_value = N as f64;
170 let expected_value = expected_count as f64;
171 let result = if N == expected_count {
172 VerificationOutcome::Satisfied
173 } else {
174 VerificationOutcome::Violated
175 };
176 let violation_severity = if result == VerificationOutcome::Violated {
177 Some(ViolationSeverity::Major)
178 } else {
179 None
180 };
181 Ok((
182 "Qubit Count".to_string(),
183 result,
184 measured_value,
185 expected_value,
186 violation_severity,
187 ))
188 }
189 fn check_gate_count_bounds(
190 circuit: &Circuit<N>,
191 min_gates: usize,
192 max_gates: usize,
193 ) -> QuantRS2Result<(
194 String,
195 VerificationOutcome,
196 f64,
197 f64,
198 Option<ViolationSeverity>,
199 )> {
200 let gate_count = circuit.num_gates();
201 let measured_value = gate_count as f64;
202 let expected_value = usize::midpoint(min_gates, max_gates) as f64;
203 let result = if gate_count >= min_gates && gate_count <= max_gates {
204 VerificationOutcome::Satisfied
205 } else {
206 VerificationOutcome::Violated
207 };
208 let violation_severity = if result == VerificationOutcome::Violated {
209 Some(ViolationSeverity::Moderate)
210 } else {
211 None
212 };
213 Ok((
214 "Gate Count Bounds".to_string(),
215 result,
216 measured_value,
217 expected_value,
218 violation_severity,
219 ))
220 }
221 fn check_depth_bounds(
222 circuit: &Circuit<N>,
223 min_depth: usize,
224 max_depth: usize,
225 ) -> QuantRS2Result<(
226 String,
227 VerificationOutcome,
228 f64,
229 f64,
230 Option<ViolationSeverity>,
231 )> {
232 let circuit_depth = circuit.calculate_depth();
233 let measured_value = circuit_depth as f64;
234 let expected_value = usize::midpoint(min_depth, max_depth) as f64;
235 let result = if circuit_depth >= min_depth && circuit_depth <= max_depth {
236 VerificationOutcome::Satisfied
237 } else {
238 VerificationOutcome::Violated
239 };
240 let violation_severity = if result == VerificationOutcome::Violated {
241 Some(ViolationSeverity::Moderate)
242 } else {
243 None
244 };
245 Ok((
246 "Depth Bounds".to_string(),
247 result,
248 measured_value,
249 expected_value,
250 violation_severity,
251 ))
252 }
253 fn check_memory_bounds(
254 circuit: &Circuit<N>,
255 max_memory_bytes: usize,
256 ) -> QuantRS2Result<(
257 String,
258 VerificationOutcome,
259 f64,
260 f64,
261 Option<ViolationSeverity>,
262 )> {
263 let estimated_memory = std::mem::size_of::<Circuit<N>>();
264 let measured_value = estimated_memory as f64;
265 let expected_value = max_memory_bytes as f64;
266 let result = if estimated_memory <= max_memory_bytes {
267 VerificationOutcome::Satisfied
268 } else {
269 VerificationOutcome::Violated
270 };
271 let violation_severity = if result == VerificationOutcome::Violated {
272 Some(ViolationSeverity::High)
273 } else {
274 None
275 };
276 Ok((
277 "Memory Bounds".to_string(),
278 result,
279 measured_value,
280 expected_value,
281 violation_severity,
282 ))
283 }
284 fn check_time_bounds(
285 circuit: &Circuit<N>,
286 max_execution_time: Duration,
287 ) -> QuantRS2Result<(
288 String,
289 VerificationOutcome,
290 f64,
291 f64,
292 Option<ViolationSeverity>,
293 )> {
294 let estimated_time = Duration::from_millis(circuit.num_gates() as u64);
295 let measured_value = estimated_time.as_secs_f64();
296 let expected_value = max_execution_time.as_secs_f64();
297 let result = if estimated_time <= max_execution_time {
298 VerificationOutcome::Satisfied
299 } else {
300 VerificationOutcome::Violated
301 };
302 let violation_severity = if result == VerificationOutcome::Violated {
303 Some(ViolationSeverity::High)
304 } else {
305 None
306 };
307 Ok((
308 "Time Bounds".to_string(),
309 result,
310 measured_value,
311 expected_value,
312 violation_severity,
313 ))
314 }
315 fn check_custom_invariant(
316 circuit: &Circuit<N>,
317 name: &str,
318 checker: &CustomInvariantChecker<N>,
319 ) -> QuantRS2Result<(
320 String,
321 VerificationOutcome,
322 f64,
323 f64,
324 Option<ViolationSeverity>,
325 )> {
326 Ok((
327 name.to_string(),
328 VerificationOutcome::Satisfied,
329 1.0,
330 checker.expected_value,
331 None,
332 ))
333 }
334}
335impl<const N: usize> Default for InvariantChecker<N> {
336 fn default() -> Self {
337 Self::new()
338 }
339}