quantrs2_tytan/sampler/hardware/
amazon_braket.rs1use scirs2_core::ndarray::{Array, Ix2};
7use scirs2_core::random::{thread_rng, Rng, RngExt};
8use std::collections::HashMap;
9
10use quantrs2_anneal::QuboModel;
11
12use super::super::{SampleResult, Sampler, SamplerError, SamplerResult};
13
/// Amazon Braket backends that [`AmazonBraketSampler`] can target.
///
/// When a task is submitted, each variant is mapped to a Braket device ARN. The
/// sampler also rejects problems whose variable count exceeds the per-variant
/// limit noted below.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum BraketDevice {
    /// Local simulation target.
    ///
    /// Despite the name, submissions use the managed SV1 state-vector simulator
    /// ARN. Limit: 34 variables.
    LocalSimulator,
    /// Managed SV1 state-vector simulator. Limit: 34 variables.
    StateVectorSimulator,
    /// Managed TN1 tensor-network simulator. Limit: 50 variables.
    TensorNetworkSimulator,
    /// IonQ QPU. Limit: 29 variables.
    IonQDevice,
    /// Rigetti QPU. The contained string is sent verbatim as the `deviceArn`.
    /// Limit: 40 variables.
    RigettiDevice(String),
    /// OQC QPU (Lucy, `eu-west-2`). Limit: 8 variables.
    OQCDevice,
    /// D-Wave Advantage annealer. Limit: 5000 variables.
    DWaveAdvantage,
    /// D-Wave 2000Q annealer. Limit: 2000 variables.
    DWave2000Q,
}
35
36#[derive(Debug, Clone)]
38pub struct AmazonBraketConfig {
39 pub region: String,
41 pub s3_bucket: String,
43 pub s3_prefix: String,
45 pub device: BraketDevice,
47 pub max_parallel: usize,
49 pub poll_interval: u64,
51}
52
53impl Default for AmazonBraketConfig {
54 fn default() -> Self {
55 Self {
56 region: "us-east-1".to_string(),
57 s3_bucket: String::new(),
58 s3_prefix: "braket-results".to_string(),
59 device: BraketDevice::LocalSimulator,
60 max_parallel: 10,
61 poll_interval: 5,
62 }
63 }
64}
65
66pub struct AmazonBraketSampler {
71 config: AmazonBraketConfig,
72}
73
74impl AmazonBraketSampler {
75 #[must_use]
81 pub const fn new(config: AmazonBraketConfig) -> Self {
82 Self { config }
83 }
84
85 #[must_use]
92 pub fn with_s3(s3_bucket: &str, region: &str) -> Self {
93 Self {
94 config: AmazonBraketConfig {
95 s3_bucket: s3_bucket.to_string(),
96 region: region.to_string(),
97 ..Default::default()
98 },
99 }
100 }
101
102 #[must_use]
104 pub fn with_device(mut self, device: BraketDevice) -> Self {
105 self.config.device = device;
106 self
107 }
108
109 #[must_use]
111 pub const fn with_max_parallel(mut self, max_parallel: usize) -> Self {
112 self.config.max_parallel = max_parallel;
113 self
114 }
115
116 #[must_use]
118 pub const fn with_poll_interval(mut self, interval: u64) -> Self {
119 self.config.poll_interval = interval;
120 self
121 }
122}
123
124impl Sampler for AmazonBraketSampler {
125 fn run_qubo(
126 &self,
127 qubo: &(Array<f64, Ix2>, HashMap<String, usize>),
128 shots: usize,
129 ) -> SamplerResult<Vec<SampleResult>> {
130 let (matrix, var_map) = qubo;
132
133 let n_vars = var_map.len();
135
136 match &self.config.device {
138 BraketDevice::LocalSimulator | BraketDevice::StateVectorSimulator => {
139 if n_vars > 34 {
140 return Err(SamplerError::InvalidParameter(
141 "State vector simulators support up to 34 qubits".to_string(),
142 ));
143 }
144 }
145 BraketDevice::TensorNetworkSimulator => {
146 if n_vars > 50 {
147 return Err(SamplerError::InvalidParameter(
148 "Tensor network simulator supports up to 50 qubits".to_string(),
149 ));
150 }
151 }
152 BraketDevice::IonQDevice => {
153 if n_vars > 29 {
154 return Err(SamplerError::InvalidParameter(
155 "IonQ device supports up to 29 qubits".to_string(),
156 ));
157 }
158 }
159 BraketDevice::RigettiDevice(_) => {
160 if n_vars > 40 {
161 return Err(SamplerError::InvalidParameter(
162 "Rigetti devices support up to 40 qubits".to_string(),
163 ));
164 }
165 }
166 BraketDevice::OQCDevice => {
167 if n_vars > 8 {
168 return Err(SamplerError::InvalidParameter(
169 "OQC device supports up to 8 qubits".to_string(),
170 ));
171 }
172 }
173 BraketDevice::DWaveAdvantage => {
174 if n_vars > 5000 {
175 return Err(SamplerError::InvalidParameter(
176 "D-Wave Advantage supports up to 5000 variables".to_string(),
177 ));
178 }
179 }
180 BraketDevice::DWave2000Q => {
181 if n_vars > 2000 {
182 return Err(SamplerError::InvalidParameter(
183 "D-Wave 2000Q supports up to 2000 variables".to_string(),
184 ));
185 }
186 }
187 }
188
189 let idx_to_var: HashMap<usize, String> = var_map
191 .iter()
192 .map(|(var, &idx)| (idx, var.clone()))
193 .collect();
194
195 let mut qubo_model = QuboModel::new(n_vars);
197
198 for i in 0..n_vars {
200 if matrix[[i, i]] != 0.0 {
201 qubo_model.set_linear(i, matrix[[i, i]])?;
202 }
203
204 for j in (i + 1)..n_vars {
205 if matrix[[i, j]] != 0.0 {
206 qubo_model.set_quadratic(i, j, matrix[[i, j]])?;
207 }
208 }
209 }
210
211 #[cfg(feature = "amazon_braket")]
213 {
214 if self.config.s3_bucket.is_empty() {
216 return Err(SamplerError::ApiError(
217 "Amazon Braket S3 bucket not configured. Call with_s3() to set credentials."
218 .to_string(),
219 ));
220 }
221
222 let linear_terms: serde_json::Value = (0..n_vars)
224 .filter_map(|i| {
225 let v = matrix[[i, i]];
226 if v != 0.0 {
227 Some((i.to_string(), v))
228 } else {
229 None
230 }
231 })
232 .map(|(k, v)| (k, serde_json::Value::from(v)))
233 .collect::<serde_json::Map<_, _>>()
234 .into();
235
236 let mut quadratic_map = serde_json::Map::new();
237 for i in 0..n_vars {
238 for j in (i + 1)..n_vars {
239 let v = matrix[[i, j]];
240 if v != 0.0 {
241 quadratic_map.insert(format!("{i},{j}"), serde_json::json!(v));
242 }
243 }
244 }
245
246 let device_arn = match &self.config.device {
247 BraketDevice::LocalSimulator => {
248 "arn:aws:braket:::device/quantum-simulator/amazon/sv1"
249 }
250 BraketDevice::StateVectorSimulator => {
251 "arn:aws:braket:::device/quantum-simulator/amazon/sv1"
252 }
253 BraketDevice::TensorNetworkSimulator => {
254 "arn:aws:braket:::device/quantum-simulator/amazon/tn1"
255 }
256 BraketDevice::IonQDevice => "arn:aws:braket:us-east-1::device/qpu/ionq/ionQdevice",
257 BraketDevice::RigettiDevice(name) => name.as_str(),
258 BraketDevice::OQCDevice => "arn:aws:braket:eu-west-2::device/qpu/oqc/Lucy",
259 BraketDevice::DWaveAdvantage => {
260 "arn:aws:braket:::device/qpu/d-wave/Advantage_system4"
261 }
262 BraketDevice::DWave2000Q => "arn:aws:braket:::device/qpu/d-wave/DW_2000Q_6",
263 };
264
265 let payload = serde_json::json!({
266 "deviceArn": device_arn,
267 "outputS3Bucket": self.config.s3_bucket,
268 "outputS3KeyPrefix": self.config.s3_prefix,
269 "shots": shots,
270 "action": {
271 "actionType": "OPENQASM",
272 "problem": {
273 "type": "QUBO",
274 "linear": linear_terms,
275 "quadratic": quadratic_map
276 }
277 }
278 });
279
280 let endpoint = format!(
281 "https://braket.{}.amazonaws.com/quantum-task",
282 self.config.region
283 );
284
285 let client = reqwest::blocking::Client::builder()
287 .timeout(std::time::Duration::from_secs(30))
288 .build()
289 .map_err(|e| SamplerError::ApiError(format!("Failed to build HTTP client: {e}")))?;
290
291 let response = client
292 .post(&endpoint)
293 .header("Content-Type", "application/json")
294 .json(&payload)
295 .send()
296 .map_err(|e| {
297 SamplerError::ApiError(format!(
298 "Failed to submit Amazon Braket task (endpoint: {endpoint}): {e}. \
299 Check AWS credentials and network connectivity."
300 ))
301 })?;
302
303 if !response.status().is_success() {
304 let status = response.status();
305 let body = response
306 .text()
307 .unwrap_or_else(|_| "<unreadable>".to_string());
308 return Err(SamplerError::ApiError(format!(
309 "Amazon Braket task submission failed (HTTP {status}): {body}"
310 )));
311 }
312
313 let task_response: serde_json::Value = response.json().map_err(|e| {
314 SamplerError::ApiError(format!("Failed to parse Braket response: {e}"))
315 })?;
316
317 let task_arn = task_response["quantumTaskArn"]
318 .as_str()
319 .ok_or_else(|| {
320 SamplerError::ApiError("Missing quantumTaskArn in response".to_string())
321 })?
322 .to_string();
323
324 let status_endpoint = format!(
326 "https://braket.{}.amazonaws.com/quantum-task/{task_arn}",
327 self.config.region
328 );
329 let max_polls = 360u64; let mut poll_count = 0u64;
331 loop {
332 if poll_count >= max_polls {
333 return Err(SamplerError::ApiError(format!(
334 "Amazon Braket task {task_arn} timed out after {} polls",
335 max_polls
336 )));
337 }
338 poll_count += 1;
339 std::thread::sleep(std::time::Duration::from_secs(self.config.poll_interval));
340
341 let status_resp = client.get(&status_endpoint).send().map_err(|e| {
342 SamplerError::ApiError(format!("Failed to poll task status: {e}"))
343 })?;
344
345 let status_json: serde_json::Value = status_resp.json().map_err(|e| {
346 SamplerError::ApiError(format!("Failed to parse status response: {e}"))
347 })?;
348
349 match status_json["status"].as_str() {
350 Some("COMPLETED") => break,
351 Some("FAILED") => {
352 let reason = status_json["failureReason"]
353 .as_str()
354 .unwrap_or("unknown reason");
355 return Err(SamplerError::ApiError(format!(
356 "Amazon Braket task failed: {reason}"
357 )));
358 }
359 Some("CANCELLED") => {
360 return Err(SamplerError::ApiError(
361 "Amazon Braket task was cancelled".to_string(),
362 ));
363 }
364 _ => continue, }
366 }
367
368 let result_s3_uri = task_response["outputS3Directory"]
370 .as_str()
371 .unwrap_or("")
372 .to_string();
373
374 let _ = result_s3_uri; }
379
380 let mut results = Vec::new();
382 let mut rng = thread_rng();
383
384 let unique_solutions = match &self.config.device {
386 BraketDevice::DWaveAdvantage | BraketDevice::DWave2000Q => {
387 shots.min(1000)
389 }
390 BraketDevice::LocalSimulator | BraketDevice::StateVectorSimulator => {
391 shots.min(500)
393 }
394 BraketDevice::TensorNetworkSimulator => shots.min(300),
395 _ => {
396 shots.min(100)
398 }
399 };
400
401 for _ in 0..unique_solutions {
402 let assignments: HashMap<String, bool> = idx_to_var
403 .values()
404 .map(|name| (name.clone(), rng.random::<bool>()))
405 .collect();
406
407 let mut energy = 0.0;
409 for (var_name, &val) in &assignments {
410 let i = var_map[var_name];
411 if val {
412 energy += matrix[[i, i]];
413 for (other_var, &other_val) in &assignments {
414 let j = var_map[other_var];
415 if i < j && other_val {
416 energy += matrix[[i, j]];
417 }
418 }
419 }
420 }
421
422 let occurrences = match &self.config.device {
424 BraketDevice::DWaveAdvantage | BraketDevice::DWave2000Q => {
425 rng.random_range(1..=(shots / unique_solutions + 20))
427 }
428 _ => {
429 rng.random_range(1..=(shots / unique_solutions + 5))
431 }
432 };
433
434 results.push(SampleResult {
435 assignments,
436 energy,
437 occurrences,
438 });
439 }
440
441 results.sort_by(|a, b| {
443 a.energy
444 .partial_cmp(&b.energy)
445 .unwrap_or(std::cmp::Ordering::Equal)
446 });
447
448 results.truncate(shots.min(100));
450
451 Ok(results)
452 }
453
454 fn run_hobo(
455 &self,
456 hobo: &(
457 Array<f64, scirs2_core::ndarray::IxDyn>,
458 HashMap<String, usize>,
459 ),
460 shots: usize,
461 ) -> SamplerResult<Vec<SampleResult>> {
462 use scirs2_core::ndarray::Ix2;
463
464 if hobo.0.ndim() <= 2 {
466 let qubo_matrix = hobo.0.clone().into_dimensionality::<Ix2>().map_err(|e| {
468 SamplerError::InvalidParameter(format!(
469 "Failed to convert HOBO to QUBO dimensionality: {e}"
470 ))
471 })?;
472 let qubo = (qubo_matrix, hobo.1.clone());
473 self.run_qubo(&qubo, shots)
474 } else {
475 Err(SamplerError::InvalidParameter(
477 "Amazon Braket doesn't support HOBO problems directly. Use a quadratization technique first.".to_string()
478 ))
479 }
480 }
481}
482
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an all-zero QUBO over `n` variables named `x0..x{n-1}`.
    fn zero_qubo(n: usize) -> (Array<f64, Ix2>, HashMap<String, usize>) {
        let var_map: HashMap<String, usize> = (0..n).map(|i| (format!("x{i}"), i)).collect();
        (Array::zeros((n, n)), var_map)
    }

    #[test]
    fn test_amazon_braket_config() {
        let config = AmazonBraketConfig::default();
        assert_eq!(config.region, "us-east-1");
        assert!(config.s3_bucket.is_empty());
        assert_eq!(config.s3_prefix, "braket-results");
        assert_eq!(config.max_parallel, 10);
        assert_eq!(config.poll_interval, 5);
        assert!(matches!(config.device, BraketDevice::LocalSimulator));
    }

    #[test]
    fn test_amazon_braket_sampler_creation() {
        let sampler = AmazonBraketSampler::with_s3("my-bucket", "us-west-2")
            .with_device(BraketDevice::IonQDevice)
            .with_max_parallel(20)
            .with_poll_interval(10);

        assert_eq!(sampler.config.s3_bucket, "my-bucket");
        assert_eq!(sampler.config.region, "us-west-2");
        assert_eq!(sampler.config.max_parallel, 20);
        assert_eq!(sampler.config.poll_interval, 10);
        assert!(matches!(sampler.config.device, BraketDevice::IonQDevice));
    }

    #[test]
    fn test_braket_device_types() {
        let devices = [
            BraketDevice::LocalSimulator,
            BraketDevice::StateVectorSimulator,
            BraketDevice::TensorNetworkSimulator,
            BraketDevice::IonQDevice,
            BraketDevice::RigettiDevice("Aspen-M-3".to_string()),
            BraketDevice::OQCDevice,
            BraketDevice::DWaveAdvantage,
            BraketDevice::DWave2000Q,
        ];

        assert_eq!(devices.len(), 8);
    }

    /// Problems one variable over each device's limit must be rejected before any sampling.
    #[test]
    fn test_braket_device_limits() {
        let cases = [
            (BraketDevice::StateVectorSimulator, 34),
            (BraketDevice::TensorNetworkSimulator, 50),
            (BraketDevice::IonQDevice, 29),
            (BraketDevice::RigettiDevice("Aspen-M-3".to_string()), 40),
            (BraketDevice::OQCDevice, 8),
        ];

        for (device, limit) in cases {
            let sampler = AmazonBraketSampler::new(AmazonBraketConfig {
                device,
                ..Default::default()
            });
            let result = sampler.run_qubo(&zero_qubo(limit + 1), 10);
            assert!(matches!(result, Err(SamplerError::InvalidParameter(_))));
        }
    }

    /// Synthetic samples are sorted by energy, and each energy matches its assignment.
    #[cfg(not(feature = "amazon_braket"))]
    #[test]
    fn test_simulated_samples_are_scored_and_sorted() {
        let mut matrix = Array::<f64, Ix2>::zeros((3, 3));
        matrix[[0, 0]] = -1.0;
        matrix[[1, 1]] = 2.0;
        matrix[[2, 2]] = -0.5;
        matrix[[0, 1]] = 3.0;
        matrix[[1, 2]] = -4.0;
        let var_map: HashMap<String, usize> = [("a", 0), ("b", 1), ("c", 2)]
            .iter()
            .map(|&(name, idx)| (name.to_string(), idx))
            .collect();

        let sampler = AmazonBraketSampler::new(AmazonBraketConfig::default());
        let results = sampler
            .run_qubo(&(matrix, var_map), 50)
            .unwrap_or_else(|_| panic!("local simulation should succeed"));

        assert_eq!(results.len(), 50);
        for pair in results.windows(2) {
            assert!(pair[0].energy <= pair[1].energy);
        }
        for sample in &results {
            assert_eq!(sample.assignments.len(), 3);
            assert!(sample.occurrences >= 1);
            let x = |name: &str| if sample.assignments[name] { 1.0 } else { 0.0 };
            let expected = -1.0 * x("a") + 2.0 * x("b") - 0.5 * x("c")
                + 3.0 * x("a") * x("b")
                - 4.0 * x("b") * x("c");
            assert!((sample.energy - expected).abs() < 1e-9);
        }
    }

    #[test]
    fn test_run_hobo_rejects_higher_order_tensor() {
        let tensor = Array::<f64, scirs2_core::ndarray::IxDyn>::zeros(
            scirs2_core::ndarray::IxDyn(&[2, 2, 2]),
        );
        let var_map: HashMap<String, usize> =
            [("a".to_string(), 0), ("b".to_string(), 1)].iter().cloned().collect();

        let sampler = AmazonBraketSampler::new(AmazonBraketConfig::default());
        let result = sampler.run_hobo(&(tensor, var_map), 10);
        assert!(matches!(result, Err(SamplerError::InvalidParameter(_))));
    }
}