1use crate::builder::Circuit;
6use crate::optimization::gate_properties::{get_gate_properties, GateCost, GateError};
7use quantrs2_core::gate::GateOp;
8use std::collections::HashMap;
9
10pub trait CostModel: Send + Sync {
12 fn gate_cost(&self, gate: &dyn GateOp) -> f64;
14
15 fn circuit_cost_from_gates(&self, gates: &[Box<dyn GateOp>]) -> f64;
17
18 fn gates_cost(&self, gates: &[Box<dyn GateOp>]) -> f64 {
20 self.circuit_cost_from_gates(gates)
21 }
22
23 fn weights(&self) -> CostWeights;
25
26 fn is_native(&self, gate: &dyn GateOp) -> bool;
28}
29
30pub trait CircuitCostExt<const N: usize> {
32 fn circuit_cost(&self, circuit: &Circuit<N>) -> f64;
33}
34
35impl<T: CostModel + ?Sized, const N: usize> CircuitCostExt<N> for T {
36 fn circuit_cost(&self, circuit: &Circuit<N>) -> f64 {
37 let mut total_cost = 0.0;
38
39 for gate in circuit.gates() {
41 total_cost += self.gate_cost(gate.as_ref());
42 }
43
44 let depth = circuit.calculate_depth();
46 total_cost += depth as f64 * 0.1; let two_qubit_gates = circuit.count_two_qubit_gates();
50 total_cost += two_qubit_gates as f64 * 5.0; total_cost
53 }
54}
55
/// Relative weights used to combine the components of a cost estimate.
///
/// `PartialEq` is derived so configurations can be compared (e.g. in tests or
/// when checking which preset a model was built with).
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct CostWeights {
    /// Weight for the gate-count term (passed as the second weight to `GateCost::total_cost`).
    pub gate_count: f64,
    /// Weight for the execution-time term (passed as the first weight to `GateCost::total_cost`).
    pub execution_time: f64,
    /// Weight for the error term; hardware models also scale measured gate error by it.
    pub error_rate: f64,
    /// Weight intended for circuit depth.
    ///
    /// NOTE(review): not consulted by the per-gate or per-circuit cost
    /// computations in this module; confirm intended use.
    pub circuit_depth: f64,
}
64
65impl Default for CostWeights {
66 fn default() -> Self {
67 Self {
68 gate_count: 1.0,
69 execution_time: 1.0,
70 error_rate: 10.0,
71 circuit_depth: 0.5,
72 }
73 }
74}
75
76pub struct AbstractCostModel {
78 weights: CostWeights,
79 native_gates: Vec<String>,
80}
81
82impl AbstractCostModel {
83 #[must_use]
85 pub fn new(weights: CostWeights) -> Self {
86 Self {
87 weights,
88 native_gates: vec!["H", "X", "Y", "Z", "S", "T", "RX", "RY", "RZ", "CNOT", "CZ"]
89 .into_iter()
90 .map(std::string::ToString::to_string)
91 .collect(),
92 }
93 }
94}
95
96impl Default for AbstractCostModel {
97 fn default() -> Self {
98 Self::new(CostWeights::default())
99 }
100}
101
102impl CostModel for AbstractCostModel {
103 fn gate_cost(&self, gate: &dyn GateOp) -> f64 {
104 let props = get_gate_properties(gate);
105
106 props.cost.total_cost(
107 self.weights.execution_time,
108 self.weights.gate_count,
109 self.weights.error_rate,
110 )
111 }
112
113 fn circuit_cost_from_gates(&self, gates: &[Box<dyn GateOp>]) -> f64 {
114 gates.iter().map(|g| self.gate_cost(g.as_ref())).sum()
115 }
116
117 fn weights(&self) -> CostWeights {
118 self.weights
119 }
120
121 fn is_native(&self, gate: &dyn GateOp) -> bool {
122 self.native_gates.contains(&gate.name().to_string())
123 }
124}
125
126pub struct HardwareCostModel {
128 backend_name: String,
129 weights: CostWeights,
130 gate_costs: HashMap<String, GateCost>,
131 gate_errors: HashMap<String, GateError>,
132 native_gates: Vec<String>,
133}
134
135impl HardwareCostModel {
136 #[must_use]
138 pub fn for_backend(backend: &str) -> Self {
139 let (weights, gate_costs, gate_errors, native_gates) = match backend {
140 "ibm" => Self::ibm_config(),
141 "google" => Self::google_config(),
142 "aws" => Self::aws_config(),
143 _ => Self::default_config(),
144 };
145
146 Self {
147 backend_name: backend.to_string(),
148 weights,
149 gate_costs,
150 gate_errors,
151 native_gates,
152 }
153 }
154
155 fn ibm_config() -> (
156 CostWeights,
157 HashMap<String, GateCost>,
158 HashMap<String, GateError>,
159 Vec<String>,
160 ) {
161 let weights = CostWeights {
162 gate_count: 0.5,
163 execution_time: 1.5,
164 error_rate: 20.0,
165 circuit_depth: 1.0,
166 };
167
168 let mut gate_costs = HashMap::new();
169 gate_costs.insert("X".to_string(), GateCost::new(35.0, 1, 1.0));
170 gate_costs.insert("Y".to_string(), GateCost::new(35.0, 1, 1.0));
171 gate_costs.insert("Z".to_string(), GateCost::new(0.0, 0, 0.0)); gate_costs.insert("H".to_string(), GateCost::new(35.0, 1, 1.0));
173 gate_costs.insert("S".to_string(), GateCost::new(35.0, 1, 1.0));
174 gate_costs.insert("T".to_string(), GateCost::new(35.0, 1, 1.0));
175 gate_costs.insert("RZ".to_string(), GateCost::new(0.0, 0, 0.0)); gate_costs.insert("CNOT".to_string(), GateCost::new(300.0, 1, 3.0));
177 gate_costs.insert("CZ".to_string(), GateCost::new(300.0, 1, 3.0));
178
179 let mut gate_errors = HashMap::new();
180 gate_errors.insert("X".to_string(), GateError::new(0.99975, 0.00025, 0.00002));
181 gate_errors.insert("CNOT".to_string(), GateError::new(0.9985, 0.0015, 0.0001));
182
183 let native_gates = vec!["X", "Y", "Z", "H", "S", "T", "RZ", "CNOT", "CZ"]
184 .into_iter()
185 .map(std::string::ToString::to_string)
186 .collect();
187
188 (weights, gate_costs, gate_errors, native_gates)
189 }
190
191 fn google_config() -> (
192 CostWeights,
193 HashMap<String, GateCost>,
194 HashMap<String, GateError>,
195 Vec<String>,
196 ) {
197 let weights = CostWeights {
198 gate_count: 0.8,
199 execution_time: 1.0,
200 error_rate: 15.0,
201 circuit_depth: 0.8,
202 };
203
204 let mut gate_costs = HashMap::new();
205 gate_costs.insert("X".to_string(), GateCost::new(25.0, 1, 1.0));
206 gate_costs.insert("Y".to_string(), GateCost::new(25.0, 1, 1.0));
207 gate_costs.insert("Z".to_string(), GateCost::new(0.0, 0, 0.0)); gate_costs.insert("H".to_string(), GateCost::new(25.0, 1, 1.0));
209 gate_costs.insert("RZ".to_string(), GateCost::new(0.0, 0, 0.0)); gate_costs.insert("SQRT_X".to_string(), GateCost::new(25.0, 1, 1.0));
211 gate_costs.insert("CZ".to_string(), GateCost::new(30.0, 1, 2.0));
212
213 let mut gate_errors = HashMap::new();
214 gate_errors.insert("X".to_string(), GateError::new(0.9998, 0.0002, 0.00001));
215 gate_errors.insert("CZ".to_string(), GateError::new(0.994, 0.006, 0.0003));
216
217 let native_gates = vec!["X", "Y", "Z", "H", "RZ", "SQRT_X", "CZ"]
218 .into_iter()
219 .map(std::string::ToString::to_string)
220 .collect();
221
222 (weights, gate_costs, gate_errors, native_gates)
223 }
224
225 fn aws_config() -> (
226 CostWeights,
227 HashMap<String, GateCost>,
228 HashMap<String, GateError>,
229 Vec<String>,
230 ) {
231 let weights = CostWeights {
232 gate_count: 1.0,
233 execution_time: 1.2,
234 error_rate: 10.0,
235 circuit_depth: 0.6,
236 };
237
238 let mut gate_costs = HashMap::new();
239 gate_costs.insert("X".to_string(), GateCost::new(50.0, 1, 1.0));
240 gate_costs.insert("Y".to_string(), GateCost::new(50.0, 1, 1.0));
241 gate_costs.insert("Z".to_string(), GateCost::new(50.0, 1, 1.0));
242 gate_costs.insert("H".to_string(), GateCost::new(50.0, 1, 1.0));
243 gate_costs.insert("RX".to_string(), GateCost::new(50.0, 1, 1.2));
244 gate_costs.insert("RY".to_string(), GateCost::new(50.0, 1, 1.2));
245 gate_costs.insert("RZ".to_string(), GateCost::new(50.0, 1, 1.2));
246 gate_costs.insert("CNOT".to_string(), GateCost::new(500.0, 1, 4.0));
247
248 let mut gate_errors = HashMap::new();
249 gate_errors.insert("X".to_string(), GateError::new(0.9997, 0.0003, 0.00002));
250 gate_errors.insert("CNOT".to_string(), GateError::new(0.997, 0.003, 0.0002));
251
252 let native_gates = vec!["X", "Y", "Z", "H", "RX", "RY", "RZ", "CNOT", "CZ"]
253 .into_iter()
254 .map(std::string::ToString::to_string)
255 .collect();
256
257 (weights, gate_costs, gate_errors, native_gates)
258 }
259
260 fn default_config() -> (
261 CostWeights,
262 HashMap<String, GateCost>,
263 HashMap<String, GateError>,
264 Vec<String>,
265 ) {
266 let weights = CostWeights::default();
267 let gate_costs = HashMap::new();
268 let gate_errors = HashMap::new();
269 let native_gates = vec!["H", "X", "Y", "Z", "S", "T", "RX", "RY", "RZ", "CNOT", "CZ"]
270 .into_iter()
271 .map(std::string::ToString::to_string)
272 .collect();
273
274 (weights, gate_costs, gate_errors, native_gates)
275 }
276}
277
278impl CostModel for HardwareCostModel {
279 fn gate_cost(&self, gate: &dyn GateOp) -> f64 {
280 let gate_name = gate.name().to_string();
281
282 if let Some(cost) = self.gate_costs.get(&gate_name) {
284 let error_cost = if let Some(error) = self.gate_errors.get(&gate_name) {
285 error.total_error() * self.weights.error_rate
286 } else {
287 0.0
288 };
289
290 cost.total_cost(
291 self.weights.execution_time,
292 self.weights.gate_count,
293 0.0, ) + error_cost
295 } else {
296 let props = get_gate_properties(gate);
298 props.cost.total_cost(
299 self.weights.execution_time,
300 self.weights.gate_count,
301 self.weights.error_rate,
302 )
303 }
304 }
305
306 fn circuit_cost_from_gates(&self, gates: &[Box<dyn GateOp>]) -> f64 {
307 gates.iter().map(|g| self.gate_cost(g.as_ref())).sum()
308 }
309
310 fn weights(&self) -> CostWeights {
311 self.weights
312 }
313
314 fn is_native(&self, gate: &dyn GateOp) -> bool {
315 self.native_gates.contains(&gate.name().to_string())
316 }
317}