// quantrs2_circuit/optimization/cost_model.rs

//! Cost models for circuit optimization.
//!
//! This module defines the cost models used to evaluate and optimize quantum circuits.

5use crate::builder::Circuit;
6use crate::optimization::gate_properties::{get_gate_properties, GateCost, GateError};
7use quantrs2_core::gate::GateOp;
8use std::collections::HashMap;
9
10/// Trait for cost models
11pub trait CostModel: Send + Sync {
12    /// Calculate the cost of a single gate
13    fn gate_cost(&self, gate: &dyn GateOp) -> f64;
14
15    /// Calculate the total cost of a circuit (using gate list)
16    fn circuit_cost_from_gates(&self, gates: &[Box<dyn GateOp>]) -> f64;
17
18    /// Calculate the total cost of a list of gates (alias for `circuit_cost_from_gates`)
19    fn gates_cost(&self, gates: &[Box<dyn GateOp>]) -> f64 {
20        self.circuit_cost_from_gates(gates)
21    }
22
23    /// Get the weights used in cost calculation
24    fn weights(&self) -> CostWeights;
25
26    /// Check if a gate is native on the target hardware
27    fn is_native(&self, gate: &dyn GateOp) -> bool;
28}
29
30/// Extension trait for circuit cost calculation
31pub trait CircuitCostExt<const N: usize> {
32    fn circuit_cost(&self, circuit: &Circuit<N>) -> f64;
33}
34
35impl<T: CostModel + ?Sized, const N: usize> CircuitCostExt<N> for T {
36    fn circuit_cost(&self, circuit: &Circuit<N>) -> f64 {
37        let mut total_cost = 0.0;
38
39        // Calculate cost for each gate
40        for gate in circuit.gates() {
41            total_cost += self.gate_cost(gate.as_ref());
42        }
43
44        // Add depth penalty for deep circuits
45        let depth = circuit.calculate_depth();
46        total_cost += depth as f64 * 0.1; // Small penalty per depth unit
47
48        // Add two-qubit gate penalty (they're expensive)
49        let two_qubit_gates = circuit.count_two_qubit_gates();
50        total_cost += two_qubit_gates as f64 * 5.0; // Extra cost for two-qubit gates
51
52        total_cost
53    }
54}
55
/// Relative weights for the components that make up a cost.
///
/// `Copy` and cheap to pass by value; `PartialEq` allows configurations to be compared.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct CostWeights {
    /// Weight for the gate-count component of a gate's cost.
    pub gate_count: f64,
    /// Weight for the execution-time component of a gate's cost.
    pub execution_time: f64,
    /// Weight for the error-rate component of a gate's cost.
    pub error_rate: f64,
    /// Weight for circuit depth.
    ///
    /// NOTE: none of the cost models in this module currently read this field; the
    /// circuit-level depth penalty in `CircuitCostExt::circuit_cost` is a fixed constant.
    pub circuit_depth: f64,
}

impl Default for CostWeights {
    /// Hardware-agnostic defaults.
    ///
    /// `gate_count = 1.0`, `execution_time = 1.0`, `error_rate = 10.0` and
    /// `circuit_depth = 0.5`, so errors dominate the weighting.
    fn default() -> Self {
        Self {
            gate_count: 1.0,
            execution_time: 1.0,
            error_rate: 10.0,
            circuit_depth: 0.5,
        }
    }
}

76/// Abstract cost model (hardware-agnostic)
77pub struct AbstractCostModel {
78    weights: CostWeights,
79    native_gates: Vec<String>,
80}
81
82impl AbstractCostModel {
83    /// Create a new abstract cost model
84    #[must_use]
85    pub fn new(weights: CostWeights) -> Self {
86        Self {
87            weights,
88            native_gates: vec!["H", "X", "Y", "Z", "S", "T", "RX", "RY", "RZ", "CNOT", "CZ"]
89                .into_iter()
90                .map(std::string::ToString::to_string)
91                .collect(),
92        }
93    }
94}
95
96impl Default for AbstractCostModel {
97    fn default() -> Self {
98        Self::new(CostWeights::default())
99    }
100}
101
102impl CostModel for AbstractCostModel {
103    fn gate_cost(&self, gate: &dyn GateOp) -> f64 {
104        let props = get_gate_properties(gate);
105
106        props.cost.total_cost(
107            self.weights.execution_time,
108            self.weights.gate_count,
109            self.weights.error_rate,
110        )
111    }
112
113    fn circuit_cost_from_gates(&self, gates: &[Box<dyn GateOp>]) -> f64 {
114        gates.iter().map(|g| self.gate_cost(g.as_ref())).sum()
115    }
116
117    fn weights(&self) -> CostWeights {
118        self.weights
119    }
120
121    fn is_native(&self, gate: &dyn GateOp) -> bool {
122        self.native_gates.contains(&gate.name().to_string())
123    }
124}
125
126/// Hardware-specific cost model
127pub struct HardwareCostModel {
128    backend_name: String,
129    weights: CostWeights,
130    gate_costs: HashMap<String, GateCost>,
131    gate_errors: HashMap<String, GateError>,
132    native_gates: Vec<String>,
133}
134
135impl HardwareCostModel {
136    /// Create a cost model for a specific backend
137    #[must_use]
138    pub fn for_backend(backend: &str) -> Self {
139        let (weights, gate_costs, gate_errors, native_gates) = match backend {
140            "ibm" => Self::ibm_config(),
141            "google" => Self::google_config(),
142            "aws" => Self::aws_config(),
143            _ => Self::default_config(),
144        };
145
146        Self {
147            backend_name: backend.to_string(),
148            weights,
149            gate_costs,
150            gate_errors,
151            native_gates,
152        }
153    }
154
155    fn ibm_config() -> (
156        CostWeights,
157        HashMap<String, GateCost>,
158        HashMap<String, GateError>,
159        Vec<String>,
160    ) {
161        let weights = CostWeights {
162            gate_count: 0.5,
163            execution_time: 1.5,
164            error_rate: 20.0,
165            circuit_depth: 1.0,
166        };
167
168        let mut gate_costs = HashMap::new();
169        gate_costs.insert("X".to_string(), GateCost::new(35.0, 1, 1.0));
170        gate_costs.insert("Y".to_string(), GateCost::new(35.0, 1, 1.0));
171        gate_costs.insert("Z".to_string(), GateCost::new(0.0, 0, 0.0)); // Virtual Z
172        gate_costs.insert("H".to_string(), GateCost::new(35.0, 1, 1.0));
173        gate_costs.insert("S".to_string(), GateCost::new(35.0, 1, 1.0));
174        gate_costs.insert("T".to_string(), GateCost::new(35.0, 1, 1.0));
175        gate_costs.insert("RZ".to_string(), GateCost::new(0.0, 0, 0.0)); // Virtual RZ
176        gate_costs.insert("CNOT".to_string(), GateCost::new(300.0, 1, 3.0));
177        gate_costs.insert("CZ".to_string(), GateCost::new(300.0, 1, 3.0));
178
179        let mut gate_errors = HashMap::new();
180        gate_errors.insert("X".to_string(), GateError::new(0.99975, 0.00025, 0.00002));
181        gate_errors.insert("CNOT".to_string(), GateError::new(0.9985, 0.0015, 0.0001));
182
183        let native_gates = vec!["X", "Y", "Z", "H", "S", "T", "RZ", "CNOT", "CZ"]
184            .into_iter()
185            .map(std::string::ToString::to_string)
186            .collect();
187
188        (weights, gate_costs, gate_errors, native_gates)
189    }
190
191    fn google_config() -> (
192        CostWeights,
193        HashMap<String, GateCost>,
194        HashMap<String, GateError>,
195        Vec<String>,
196    ) {
197        let weights = CostWeights {
198            gate_count: 0.8,
199            execution_time: 1.0,
200            error_rate: 15.0,
201            circuit_depth: 0.8,
202        };
203
204        let mut gate_costs = HashMap::new();
205        gate_costs.insert("X".to_string(), GateCost::new(25.0, 1, 1.0));
206        gate_costs.insert("Y".to_string(), GateCost::new(25.0, 1, 1.0));
207        gate_costs.insert("Z".to_string(), GateCost::new(0.0, 0, 0.0)); // Virtual
208        gate_costs.insert("H".to_string(), GateCost::new(25.0, 1, 1.0));
209        gate_costs.insert("RZ".to_string(), GateCost::new(0.0, 0, 0.0)); // Virtual
210        gate_costs.insert("SQRT_X".to_string(), GateCost::new(25.0, 1, 1.0));
211        gate_costs.insert("CZ".to_string(), GateCost::new(30.0, 1, 2.0));
212
213        let mut gate_errors = HashMap::new();
214        gate_errors.insert("X".to_string(), GateError::new(0.9998, 0.0002, 0.00001));
215        gate_errors.insert("CZ".to_string(), GateError::new(0.994, 0.006, 0.0003));
216
217        let native_gates = vec!["X", "Y", "Z", "H", "RZ", "SQRT_X", "CZ"]
218            .into_iter()
219            .map(std::string::ToString::to_string)
220            .collect();
221
222        (weights, gate_costs, gate_errors, native_gates)
223    }
224
225    fn aws_config() -> (
226        CostWeights,
227        HashMap<String, GateCost>,
228        HashMap<String, GateError>,
229        Vec<String>,
230    ) {
231        let weights = CostWeights {
232            gate_count: 1.0,
233            execution_time: 1.2,
234            error_rate: 10.0,
235            circuit_depth: 0.6,
236        };
237
238        let mut gate_costs = HashMap::new();
239        gate_costs.insert("X".to_string(), GateCost::new(50.0, 1, 1.0));
240        gate_costs.insert("Y".to_string(), GateCost::new(50.0, 1, 1.0));
241        gate_costs.insert("Z".to_string(), GateCost::new(50.0, 1, 1.0));
242        gate_costs.insert("H".to_string(), GateCost::new(50.0, 1, 1.0));
243        gate_costs.insert("RX".to_string(), GateCost::new(50.0, 1, 1.2));
244        gate_costs.insert("RY".to_string(), GateCost::new(50.0, 1, 1.2));
245        gate_costs.insert("RZ".to_string(), GateCost::new(50.0, 1, 1.2));
246        gate_costs.insert("CNOT".to_string(), GateCost::new(500.0, 1, 4.0));
247
248        let mut gate_errors = HashMap::new();
249        gate_errors.insert("X".to_string(), GateError::new(0.9997, 0.0003, 0.00002));
250        gate_errors.insert("CNOT".to_string(), GateError::new(0.997, 0.003, 0.0002));
251
252        let native_gates = vec!["X", "Y", "Z", "H", "RX", "RY", "RZ", "CNOT", "CZ"]
253            .into_iter()
254            .map(std::string::ToString::to_string)
255            .collect();
256
257        (weights, gate_costs, gate_errors, native_gates)
258    }
259
260    fn default_config() -> (
261        CostWeights,
262        HashMap<String, GateCost>,
263        HashMap<String, GateError>,
264        Vec<String>,
265    ) {
266        let weights = CostWeights::default();
267        let gate_costs = HashMap::new();
268        let gate_errors = HashMap::new();
269        let native_gates = vec!["H", "X", "Y", "Z", "S", "T", "RX", "RY", "RZ", "CNOT", "CZ"]
270            .into_iter()
271            .map(std::string::ToString::to_string)
272            .collect();
273
274        (weights, gate_costs, gate_errors, native_gates)
275    }
276}
277
278impl CostModel for HardwareCostModel {
279    fn gate_cost(&self, gate: &dyn GateOp) -> f64 {
280        let gate_name = gate.name().to_string();
281
282        // Use hardware-specific cost if available
283        if let Some(cost) = self.gate_costs.get(&gate_name) {
284            let error_cost = if let Some(error) = self.gate_errors.get(&gate_name) {
285                error.total_error() * self.weights.error_rate
286            } else {
287                0.0
288            };
289
290            cost.total_cost(
291                self.weights.execution_time,
292                self.weights.gate_count,
293                0.0, // Use error_cost separately
294            ) + error_cost
295        } else {
296            // Fall back to generic properties
297            let props = get_gate_properties(gate);
298            props.cost.total_cost(
299                self.weights.execution_time,
300                self.weights.gate_count,
301                self.weights.error_rate,
302            )
303        }
304    }
305
306    fn circuit_cost_from_gates(&self, gates: &[Box<dyn GateOp>]) -> f64 {
307        gates.iter().map(|g| self.gate_cost(g.as_ref())).sum()
308    }
309
310    fn weights(&self) -> CostWeights {
311        self.weights
312    }
313
314    fn is_native(&self, gate: &dyn GateOp) -> bool {
315        self.native_gates.contains(&gate.name().to_string())
316    }
317}