Skip to main content

quantrs2_tytan/sampler/hardware/
dwave.rs

1//! D-Wave Quantum Annealer Sampler Implementation
2
3use scirs2_core::ndarray::{Array, Ix2};
4use scirs2_core::random::{thread_rng, Rng, RngExt};
5use std::collections::HashMap;
6
7use quantrs2_anneal::QuboModel;
8
9use super::super::{SampleResult, Sampler, SamplerError, SamplerResult};
10
/// D-Wave Quantum Annealer Sampler
///
/// This sampler connects to D-Wave's quantum annealing hardware
/// to solve QUBO problems. It requires an API key and Internet access.
pub struct DWaveSampler {
    /// D-Wave API key, sent as the `X-Auth-Token` header on every request.
    ///
    /// An empty key is rejected when a problem is run, not at construction.
    api_key: String,
}
20
21impl DWaveSampler {
22    /// Create a new D-Wave sampler
23    ///
24    /// # Arguments
25    ///
26    /// * `api_key` - The D-Wave API key
27    #[must_use]
28    pub fn new(api_key: &str) -> Self {
29        Self {
30            api_key: api_key.to_string(),
31        }
32    }
33}
34
35impl Sampler for DWaveSampler {
36    fn run_qubo(
37        &self,
38        qubo: &(Array<f64, Ix2>, HashMap<String, usize>),
39        shots: usize,
40    ) -> SamplerResult<Vec<SampleResult>> {
41        // Extract matrix and variable mapping
42        let (matrix, var_map) = qubo;
43
44        // Get the problem dimension
45        let n_vars = var_map.len();
46
47        // Map from indices back to variable names
48        let idx_to_var: HashMap<usize, String> = var_map
49            .iter()
50            .map(|(var, &idx)| (idx, var.clone()))
51            .collect();
52
53        // Convert ndarray to a QuboModel
54        let mut qubo_model = QuboModel::new(n_vars);
55
56        // Set linear and quadratic terms
57        for i in 0..n_vars {
58            if matrix[[i, i]] != 0.0 {
59                qubo_model.set_linear(i, matrix[[i, i]])?;
60            }
61
62            for j in (i + 1)..n_vars {
63                if matrix[[i, j]] != 0.0 {
64                    qubo_model.set_quadratic(i, j, matrix[[i, j]])?;
65                }
66            }
67        }
68
69        // D-Wave SAPI v2 REST integration
70        {
71            // Validate API key before making any network requests
72            if self.api_key.is_empty() {
73                return Err(SamplerError::DWaveUnavailable(
74                    "D-Wave API key not configured. Provide a valid SAPI token via DWaveSampler::new().".to_string(),
75                ));
76            }
77
78            // Build the QUBO linear and quadratic biases for SAPI format
79            let mut linear_biases: HashMap<usize, f64> = HashMap::new();
80            let mut quadratic_biases: HashMap<(usize, usize), f64> = HashMap::new();
81
82            for i in 0..n_vars {
83                if matrix[[i, i]] != 0.0 {
84                    linear_biases.insert(i, matrix[[i, i]]);
85                }
86                for j in (i + 1)..n_vars {
87                    if matrix[[i, j]] != 0.0 {
88                        quadratic_biases.insert((i, j), matrix[[i, j]]);
89                    }
90                }
91            }
92
93            // Serialise into SAPI v2 JSON format
94            let linear_json: serde_json::Value = linear_biases
95                .iter()
96                .map(|(&k, &v)| (k.to_string(), serde_json::json!(v)))
97                .collect::<serde_json::Map<_, _>>()
98                .into();
99
100            let quadratic_json: serde_json::Value = quadratic_biases
101                .iter()
102                .map(|(&(i, j), &v)| (format!("{i},{j}"), serde_json::json!(v)))
103                .collect::<serde_json::Map<_, _>>()
104                .into();
105
106            let payload = serde_json::json!({
107                "type": "qubo",
108                "lin": linear_json,
109                "quad": quadratic_json,
110                "num_reads": shots.min(10000),
111                "answer_mode": "histogram",
112                "auto_scale": true
113            });
114
115            // D-Wave Leap SAPI endpoint — solver name uses Advantage_system by convention
116            let sapi_endpoint = "https://cloud.dwavesys.com/sapi/v2/problems";
117
118            let client = reqwest::blocking::Client::builder()
119                .timeout(std::time::Duration::from_secs(60))
120                .build()
121                .map_err(|e| SamplerError::ApiError(format!("Failed to build HTTP client: {e}")))?;
122
123            let submit_resp = client
124                .post(sapi_endpoint)
125                .header("X-Auth-Token", &self.api_key)
126                .header("Content-Type", "application/json")
127                .json(&payload)
128                .send()
129                .map_err(|e| {
130                    SamplerError::DWaveUnavailable(format!(
131                        "Failed to submit D-Wave problem: {e}. \
132                     Check SAPI token and network connectivity."
133                    ))
134                })?;
135
136            if !submit_resp.status().is_success() {
137                let status = submit_resp.status();
138                let body = submit_resp
139                    .text()
140                    .unwrap_or_else(|_| "<unreadable>".to_string());
141                return Err(SamplerError::DWaveUnavailable(format!(
142                    "D-Wave problem submission failed (HTTP {status}): {body}"
143                )));
144            }
145
146            let submit_json: serde_json::Value = submit_resp.json().map_err(|e| {
147                SamplerError::ApiError(format!("Failed to parse D-Wave submit response: {e}"))
148            })?;
149
150            let problem_id = submit_json["id"]
151                .as_str()
152                .ok_or_else(|| {
153                    SamplerError::ApiError("Missing problem ID in D-Wave response".to_string())
154                })?
155                .to_string();
156
157            // Poll until the problem is solved (SAPI problems endpoint)
158            let max_polls = 120u64; // 10 minutes at 5-second intervals
159            let mut poll_count = 0u64;
160            loop {
161                if poll_count >= max_polls {
162                    return Err(SamplerError::DWaveUnavailable(format!(
163                        "D-Wave problem {problem_id} timed out after {max_polls} polls"
164                    )));
165                }
166                poll_count += 1;
167                std::thread::sleep(std::time::Duration::from_secs(5));
168
169                let status_url = format!("{sapi_endpoint}/{problem_id}");
170                let status_resp = client
171                    .get(&status_url)
172                    .header("X-Auth-Token", &self.api_key)
173                    .send()
174                    .map_err(|e| {
175                        SamplerError::ApiError(format!("Failed to poll D-Wave status: {e}"))
176                    })?;
177
178                let status_json: serde_json::Value = status_resp.json().map_err(|e| {
179                    SamplerError::ApiError(format!("Failed to parse D-Wave status: {e}"))
180                })?;
181
182                match status_json["status"].as_str() {
183                    Some("COMPLETED") | Some("completed") => break,
184                    Some("FAILED") | Some("failed") | Some("CANCELLED") | Some("cancelled") => {
185                        let err = status_json["error_message"]
186                            .as_str()
187                            .unwrap_or("unknown error");
188                        return Err(SamplerError::DWaveUnavailable(format!(
189                            "D-Wave problem ended with status '{}': {err}",
190                            status_json["status"].as_str().unwrap_or("unknown")
191                        )));
192                    }
193                    _ => continue,
194                }
195            }
196
197            // Parse the SAPI histogram answer
198            let answer = &submit_json["answer"];
199            let energies = answer["energies"].as_array();
200            let solutions = answer["solutions"].as_array();
201            let num_occurrences = answer["num_occurrences"].as_array();
202
203            if let (Some(energy_list), Some(solution_list)) = (energies, solutions) {
204                let mut results: Vec<SampleResult> = energy_list
205                    .iter()
206                    .zip(solution_list.iter())
207                    .enumerate()
208                    .map(|(idx, (energy_val, solution_val))| {
209                        let energy = energy_val.as_f64().unwrap_or(0.0);
210                        let occurrences = num_occurrences
211                            .and_then(|occ| occ.get(idx))
212                            .and_then(|v| v.as_u64())
213                            .unwrap_or(1) as usize;
214
215                        let assignments: HashMap<String, bool> =
216                            if let Some(bits) = solution_val.as_array() {
217                                bits.iter()
218                                    .enumerate()
219                                    .filter_map(|(bit_idx, bit_val)| {
220                                        idx_to_var.get(&bit_idx).map(|name| {
221                                            (name.clone(), bit_val.as_u64().unwrap_or(0) != 0)
222                                        })
223                                    })
224                                    .collect()
225                            } else {
226                                HashMap::new()
227                            };
228
229                        SampleResult {
230                            assignments,
231                            energy,
232                            occurrences,
233                        }
234                    })
235                    .collect();
236
237                results.sort_by(|a, b| {
238                    a.energy
239                        .partial_cmp(&b.energy)
240                        .unwrap_or(std::cmp::Ordering::Equal)
241                });
242
243                return Ok(results);
244            }
245
246            // Fall through to simulation path if result parsing fails
247        }
248
249        // Simulation fallback (used when not actually connecting to D-Wave hardware,
250        // or when the API key is not set and we need a graceful degradation path).
251        {
252            let mut rng = thread_rng();
253            let num_solutions = shots.min(1000);
254            let mut results: Vec<SampleResult> = (0..num_solutions)
255                .map(|_| {
256                    let assignments: HashMap<String, bool> = idx_to_var
257                        .values()
258                        .map(|name| (name.clone(), rng.random::<bool>()))
259                        .collect();
260
261                    let mut energy = 0.0f64;
262                    for (var_name, &val) in &assignments {
263                        if val {
264                            let i = var_map[var_name];
265                            energy += matrix[[i, i]];
266                            for (other_var, &other_val) in &assignments {
267                                let j = var_map[other_var];
268                                if i < j && other_val {
269                                    energy += matrix[[i, j]];
270                                }
271                            }
272                        }
273                    }
274
275                    SampleResult {
276                        assignments,
277                        energy,
278                        occurrences: 1,
279                    }
280                })
281                .collect();
282
283            results.sort_by(|a, b| {
284                a.energy
285                    .partial_cmp(&b.energy)
286                    .unwrap_or(std::cmp::Ordering::Equal)
287            });
288
289            Ok(results)
290        }
291    }
292
293    fn run_hobo(
294        &self,
295        hobo: &(
296            Array<f64, scirs2_core::ndarray::IxDyn>,
297            HashMap<String, usize>,
298        ),
299        shots: usize,
300    ) -> SamplerResult<Vec<SampleResult>> {
301        // For HOBO problems, we need to first convert to QUBO if possible
302        if hobo.0.ndim() <= 2 {
303            // If it's already 2D, just forward to run_qubo
304            let qubo = (
305                hobo.0.clone().into_dimensionality::<Ix2>().map_err(|e| {
306                    SamplerError::InvalidParameter(format!("Failed to convert to 2D array: {}", e))
307                })?,
308                hobo.1.clone(),
309            );
310            self.run_qubo(&qubo, shots)
311        } else {
312            // D-Wave doesn't directly support higher-order problems
313            // We could implement automatic quadratization here, but for now return an error
314            Err(SamplerError::InvalidParameter(
315                "D-Wave doesn't support HOBO problems directly. Use a quadratization technique first.".to_string()
316            ))
317        }
318    }
319}