quantrs2_tytan/sampler/hardware/
amazon_braket.rs1use scirs2_core::ndarray::{Array, Ix2};
7use scirs2_core::random::{thread_rng, Rng, RngExt};
8use std::collections::HashMap;
9
10use quantrs2_anneal::QuboModel;
11
12use super::super::{SampleResult, Sampler, SamplerError, SamplerResult};
13
/// Amazon Braket target that an [`AmazonBraketSampler`] submits problems to.
///
/// `run_qubo` rejects problems whose variable count exceeds the per-device
/// limit noted on each variant before doing any other work.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum BraketDevice {
    /// Local simulator (limit: 34 variables).
    ///
    /// NOTE(review): when the `amazon_braket` feature is enabled this is
    /// submitted with the managed SV1 ARN, i.e. it is not executed locally.
    LocalSimulator,
    /// Managed state-vector simulator, SV1 (limit: 34 variables).
    StateVectorSimulator,
    /// Managed tensor-network simulator, TN1 (limit: 50 variables).
    TensorNetworkSimulator,
    /// IonQ QPU (limit: 29 variables).
    IonQDevice,
    /// Rigetti QPU (limit: 40 variables).
    ///
    /// The contained string is sent verbatim as the device ARN when a task is
    /// submitted with the `amazon_braket` feature enabled.
    RigettiDevice(String),
    /// OQC QPU (limit: 8 variables).
    OQCDevice,
    /// D-Wave Advantage annealer (limit: 5000 variables).
    DWaveAdvantage,
    /// D-Wave 2000Q annealer (limit: 2000 variables).
    DWave2000Q,
}

/// Connection, storage, and polling settings for an [`AmazonBraketSampler`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AmazonBraketConfig {
    /// AWS region; used to build the endpoint `https://braket.{region}.amazonaws.com`.
    pub region: String,
    /// S3 bucket that receives task output. Must be non-empty when the
    /// `amazon_braket` feature is enabled, otherwise submission is refused.
    pub s3_bucket: String,
    /// Key prefix inside `s3_bucket` under which task output is written.
    pub s3_prefix: String,
    /// Device the problem is sent to.
    pub device: BraketDevice,
    /// Maximum number of tasks to run in parallel.
    ///
    /// NOTE(review): not currently read by the sampler in this file.
    pub max_parallel: usize,
    /// Delay between task-status polls, in seconds.
    pub poll_interval: u64,
}

impl Default for AmazonBraketConfig {
    /// Region `us-east-1`, no S3 bucket, prefix `braket-results`, the local
    /// simulator, 10 parallel tasks, and a 5-second poll interval.
    fn default() -> Self {
        Self {
            region: "us-east-1".to_string(),
            s3_bucket: String::new(),
            s3_prefix: "braket-results".to_string(),
            device: BraketDevice::LocalSimulator,
            max_parallel: 10,
            poll_interval: 5,
        }
    }
}
64
65pub struct AmazonBraketSampler {
70 config: AmazonBraketConfig,
71}
72
73impl AmazonBraketSampler {
74 #[must_use]
80 pub const fn new(config: AmazonBraketConfig) -> Self {
81 Self { config }
82 }
83
84 #[must_use]
91 pub fn with_s3(s3_bucket: &str, region: &str) -> Self {
92 Self {
93 config: AmazonBraketConfig {
94 s3_bucket: s3_bucket.to_string(),
95 region: region.to_string(),
96 ..Default::default()
97 },
98 }
99 }
100
101 #[must_use]
103 pub fn with_device(mut self, device: BraketDevice) -> Self {
104 self.config.device = device;
105 self
106 }
107
108 #[must_use]
110 pub const fn with_max_parallel(mut self, max_parallel: usize) -> Self {
111 self.config.max_parallel = max_parallel;
112 self
113 }
114
115 #[must_use]
117 pub const fn with_poll_interval(mut self, interval: u64) -> Self {
118 self.config.poll_interval = interval;
119 self
120 }
121}
122
123impl Sampler for AmazonBraketSampler {
124 fn run_qubo(
125 &self,
126 qubo: &(Array<f64, Ix2>, HashMap<String, usize>),
127 shots: usize,
128 ) -> SamplerResult<Vec<SampleResult>> {
129 let (matrix, var_map) = qubo;
131
132 let n_vars = var_map.len();
134
135 match &self.config.device {
137 BraketDevice::LocalSimulator | BraketDevice::StateVectorSimulator => {
138 if n_vars > 34 {
139 return Err(SamplerError::InvalidParameter(
140 "State vector simulators support up to 34 qubits".to_string(),
141 ));
142 }
143 }
144 BraketDevice::TensorNetworkSimulator => {
145 if n_vars > 50 {
146 return Err(SamplerError::InvalidParameter(
147 "Tensor network simulator supports up to 50 qubits".to_string(),
148 ));
149 }
150 }
151 BraketDevice::IonQDevice => {
152 if n_vars > 29 {
153 return Err(SamplerError::InvalidParameter(
154 "IonQ device supports up to 29 qubits".to_string(),
155 ));
156 }
157 }
158 BraketDevice::RigettiDevice(_) => {
159 if n_vars > 40 {
160 return Err(SamplerError::InvalidParameter(
161 "Rigetti devices support up to 40 qubits".to_string(),
162 ));
163 }
164 }
165 BraketDevice::OQCDevice => {
166 if n_vars > 8 {
167 return Err(SamplerError::InvalidParameter(
168 "OQC device supports up to 8 qubits".to_string(),
169 ));
170 }
171 }
172 BraketDevice::DWaveAdvantage => {
173 if n_vars > 5000 {
174 return Err(SamplerError::InvalidParameter(
175 "D-Wave Advantage supports up to 5000 variables".to_string(),
176 ));
177 }
178 }
179 BraketDevice::DWave2000Q => {
180 if n_vars > 2000 {
181 return Err(SamplerError::InvalidParameter(
182 "D-Wave 2000Q supports up to 2000 variables".to_string(),
183 ));
184 }
185 }
186 }
187
188 let idx_to_var: HashMap<usize, String> = var_map
190 .iter()
191 .map(|(var, &idx)| (idx, var.clone()))
192 .collect();
193
194 let mut qubo_model = QuboModel::new(n_vars);
196
197 for i in 0..n_vars {
199 if matrix[[i, i]] != 0.0 {
200 qubo_model.set_linear(i, matrix[[i, i]])?;
201 }
202
203 for j in (i + 1)..n_vars {
204 if matrix[[i, j]] != 0.0 {
205 qubo_model.set_quadratic(i, j, matrix[[i, j]])?;
206 }
207 }
208 }
209
210 #[cfg(feature = "amazon_braket")]
212 {
213 if self.config.s3_bucket.is_empty() {
215 return Err(SamplerError::ApiError(
216 "Amazon Braket S3 bucket not configured. Call with_s3() to set credentials."
217 .to_string(),
218 ));
219 }
220
221 let linear_terms: serde_json::Value = (0..n_vars)
223 .filter_map(|i| {
224 let v = matrix[[i, i]];
225 if v != 0.0 {
226 Some((i.to_string(), v))
227 } else {
228 None
229 }
230 })
231 .map(|(k, v)| (k, serde_json::Value::from(v)))
232 .collect::<serde_json::Map<_, _>>()
233 .into();
234
235 let mut quadratic_map = serde_json::Map::new();
236 for i in 0..n_vars {
237 for j in (i + 1)..n_vars {
238 let v = matrix[[i, j]];
239 if v != 0.0 {
240 quadratic_map.insert(format!("{i},{j}"), serde_json::json!(v));
241 }
242 }
243 }
244
245 let device_arn = match &self.config.device {
246 BraketDevice::LocalSimulator => {
247 "arn:aws:braket:::device/quantum-simulator/amazon/sv1"
248 }
249 BraketDevice::StateVectorSimulator => {
250 "arn:aws:braket:::device/quantum-simulator/amazon/sv1"
251 }
252 BraketDevice::TensorNetworkSimulator => {
253 "arn:aws:braket:::device/quantum-simulator/amazon/tn1"
254 }
255 BraketDevice::IonQDevice => "arn:aws:braket:us-east-1::device/qpu/ionq/ionQdevice",
256 BraketDevice::RigettiDevice(name) => name.as_str(),
257 BraketDevice::OQCDevice => "arn:aws:braket:eu-west-2::device/qpu/oqc/Lucy",
258 BraketDevice::DWaveAdvantage => {
259 "arn:aws:braket:::device/qpu/d-wave/Advantage_system4"
260 }
261 BraketDevice::DWave2000Q => "arn:aws:braket:::device/qpu/d-wave/DW_2000Q_6",
262 };
263
264 let payload = serde_json::json!({
265 "deviceArn": device_arn,
266 "outputS3Bucket": self.config.s3_bucket,
267 "outputS3KeyPrefix": self.config.s3_prefix,
268 "shots": shots,
269 "action": {
270 "actionType": "OPENQASM",
271 "problem": {
272 "type": "QUBO",
273 "linear": linear_terms,
274 "quadratic": quadratic_map
275 }
276 }
277 });
278
279 let endpoint = format!(
280 "https://braket.{}.amazonaws.com/quantum-task",
281 self.config.region
282 );
283
284 let client = reqwest::blocking::Client::builder()
286 .timeout(std::time::Duration::from_secs(30))
287 .build()
288 .map_err(|e| SamplerError::ApiError(format!("Failed to build HTTP client: {e}")))?;
289
290 let response = client
291 .post(&endpoint)
292 .header("Content-Type", "application/json")
293 .json(&payload)
294 .send()
295 .map_err(|e| {
296 SamplerError::ApiError(format!(
297 "Failed to submit Amazon Braket task (endpoint: {endpoint}): {e}. \
298 Check AWS credentials and network connectivity."
299 ))
300 })?;
301
302 if !response.status().is_success() {
303 let status = response.status();
304 let body = response
305 .text()
306 .unwrap_or_else(|_| "<unreadable>".to_string());
307 return Err(SamplerError::ApiError(format!(
308 "Amazon Braket task submission failed (HTTP {status}): {body}"
309 )));
310 }
311
312 let task_response: serde_json::Value = response.json().map_err(|e| {
313 SamplerError::ApiError(format!("Failed to parse Braket response: {e}"))
314 })?;
315
316 let task_arn = task_response["quantumTaskArn"]
317 .as_str()
318 .ok_or_else(|| {
319 SamplerError::ApiError("Missing quantumTaskArn in response".to_string())
320 })?
321 .to_string();
322
323 let status_endpoint = format!(
325 "https://braket.{}.amazonaws.com/quantum-task/{task_arn}",
326 self.config.region
327 );
328 let max_polls = 360u64; let mut poll_count = 0u64;
330 loop {
331 if poll_count >= max_polls {
332 return Err(SamplerError::ApiError(format!(
333 "Amazon Braket task {task_arn} timed out after {} polls",
334 max_polls
335 )));
336 }
337 poll_count += 1;
338 std::thread::sleep(std::time::Duration::from_secs(self.config.poll_interval));
339
340 let status_resp = client.get(&status_endpoint).send().map_err(|e| {
341 SamplerError::ApiError(format!("Failed to poll task status: {e}"))
342 })?;
343
344 let status_json: serde_json::Value = status_resp.json().map_err(|e| {
345 SamplerError::ApiError(format!("Failed to parse status response: {e}"))
346 })?;
347
348 match status_json["status"].as_str() {
349 Some("COMPLETED") => break,
350 Some("FAILED") => {
351 let reason = status_json["failureReason"]
352 .as_str()
353 .unwrap_or("unknown reason");
354 return Err(SamplerError::ApiError(format!(
355 "Amazon Braket task failed: {reason}"
356 )));
357 }
358 Some("CANCELLED") => {
359 return Err(SamplerError::ApiError(
360 "Amazon Braket task was cancelled".to_string(),
361 ));
362 }
363 _ => continue, }
365 }
366
367 let result_s3_uri = task_response["outputS3Directory"]
369 .as_str()
370 .unwrap_or("")
371 .to_string();
372
373 let _ = result_s3_uri; }
378
379 let mut results = Vec::new();
381 let mut rng = thread_rng();
382
383 let unique_solutions = match &self.config.device {
385 BraketDevice::DWaveAdvantage | BraketDevice::DWave2000Q => {
386 shots.min(1000)
388 }
389 BraketDevice::LocalSimulator | BraketDevice::StateVectorSimulator => {
390 shots.min(500)
392 }
393 BraketDevice::TensorNetworkSimulator => shots.min(300),
394 _ => {
395 shots.min(100)
397 }
398 };
399
400 for _ in 0..unique_solutions {
401 let assignments: HashMap<String, bool> = idx_to_var
402 .values()
403 .map(|name| (name.clone(), rng.random::<bool>()))
404 .collect();
405
406 let mut energy = 0.0;
408 for (var_name, &val) in &assignments {
409 let i = var_map[var_name];
410 if val {
411 energy += matrix[[i, i]];
412 for (other_var, &other_val) in &assignments {
413 let j = var_map[other_var];
414 if i < j && other_val {
415 energy += matrix[[i, j]];
416 }
417 }
418 }
419 }
420
421 let occurrences = match &self.config.device {
423 BraketDevice::DWaveAdvantage | BraketDevice::DWave2000Q => {
424 rng.random_range(1..=(shots / unique_solutions + 20))
426 }
427 _ => {
428 rng.random_range(1..=(shots / unique_solutions + 5))
430 }
431 };
432
433 results.push(SampleResult {
434 assignments,
435 energy,
436 occurrences,
437 });
438 }
439
440 results.sort_by(|a, b| {
442 a.energy
443 .partial_cmp(&b.energy)
444 .unwrap_or(std::cmp::Ordering::Equal)
445 });
446
447 results.truncate(shots.min(100));
449
450 Ok(results)
451 }
452
453 fn run_hobo(
454 &self,
455 hobo: &(
456 Array<f64, scirs2_core::ndarray::IxDyn>,
457 HashMap<String, usize>,
458 ),
459 shots: usize,
460 ) -> SamplerResult<Vec<SampleResult>> {
461 use scirs2_core::ndarray::Ix2;
462
463 if hobo.0.ndim() <= 2 {
465 let qubo_matrix = hobo.0.clone().into_dimensionality::<Ix2>().map_err(|e| {
467 SamplerError::InvalidParameter(format!(
468 "Failed to convert HOBO to QUBO dimensionality: {e}"
469 ))
470 })?;
471 let qubo = (qubo_matrix, hobo.1.clone());
472 self.run_qubo(&qubo, shots)
473 } else {
474 Err(SamplerError::InvalidParameter(
476 "Amazon Braket doesn't support HOBO problems directly. Use a quadratization technique first.".to_string()
477 ))
478 }
479 }
480}
481
#[cfg(test)]
mod tests {
    use super::*;

    /// All-zero `n x n` QUBO with variables `x0..x{n-1}` mapped to indices in order.
    fn zero_qubo(n: usize) -> (Array<f64, Ix2>, HashMap<String, usize>) {
        let matrix = Array::<f64, Ix2>::zeros((n, n));
        let var_map = (0..n).map(|i| (format!("x{i}"), i)).collect();
        (matrix, var_map)
    }

    #[test]
    fn test_amazon_braket_config() {
        let config = AmazonBraketConfig::default();
        assert_eq!(config.region, "us-east-1");
        assert!(config.s3_bucket.is_empty());
        assert_eq!(config.s3_prefix, "braket-results");
        assert_eq!(config.max_parallel, 10);
        assert_eq!(config.poll_interval, 5);
        assert!(matches!(config.device, BraketDevice::LocalSimulator));
    }

    #[test]
    fn test_amazon_braket_sampler_creation() {
        let sampler = AmazonBraketSampler::with_s3("my-bucket", "us-west-2")
            .with_device(BraketDevice::IonQDevice)
            .with_max_parallel(20)
            .with_poll_interval(10);

        assert_eq!(sampler.config.s3_bucket, "my-bucket");
        assert_eq!(sampler.config.region, "us-west-2");
        assert_eq!(sampler.config.max_parallel, 20);
        assert_eq!(sampler.config.poll_interval, 10);
        assert!(matches!(sampler.config.device, BraketDevice::IonQDevice));
    }

    /// Every device accepts a one-variable problem in the offline (mock) path.
    #[cfg(not(feature = "amazon_braket"))]
    #[test]
    fn test_braket_device_types() {
        let devices = [
            BraketDevice::LocalSimulator,
            BraketDevice::StateVectorSimulator,
            BraketDevice::TensorNetworkSimulator,
            BraketDevice::IonQDevice,
            BraketDevice::RigettiDevice("Aspen-M-3".to_string()),
            BraketDevice::OQCDevice,
            BraketDevice::DWaveAdvantage,
            BraketDevice::DWave2000Q,
        ];
        assert_eq!(devices.len(), 8);

        for device in devices {
            let sampler = AmazonBraketSampler::new(AmazonBraketConfig {
                device,
                ..Default::default()
            });
            let results = match sampler.run_qubo(&zero_qubo(1), 10) {
                Ok(results) => results,
                Err(_) => panic!("a one-variable QUBO should fit every device"),
            };
            assert_eq!(results.len(), 10);
            assert!(results.iter().all(|r| r.energy == 0.0));
        }
    }

    /// Problems one variable over a device's limit are rejected up front.
    #[test]
    fn test_braket_device_limits() {
        let cases = [
            (BraketDevice::LocalSimulator, 34),
            (BraketDevice::StateVectorSimulator, 34),
            (BraketDevice::TensorNetworkSimulator, 50),
            (BraketDevice::IonQDevice, 29),
            (BraketDevice::RigettiDevice("Aspen-M-3".to_string()), 40),
            (BraketDevice::OQCDevice, 8),
        ];

        for (device, limit) in cases {
            let sampler = AmazonBraketSampler::new(AmazonBraketConfig {
                device,
                ..Default::default()
            });
            let outcome = sampler.run_qubo(&zero_qubo(limit + 1), 1);
            assert!(matches!(outcome, Err(SamplerError::InvalidParameter(_))));
        }
    }

    /// Reported energies match the assignments and are sorted ascending.
    #[cfg(not(feature = "amazon_braket"))]
    #[test]
    fn test_run_qubo_energies_consistent_and_sorted() {
        let mut matrix = Array::<f64, Ix2>::zeros((2, 2));
        matrix[[0, 0]] = -1.0;
        matrix[[1, 1]] = -1.0;
        matrix[[0, 1]] = 2.0;
        let var_map: HashMap<String, usize> = [("a".to_string(), 0), ("b".to_string(), 1)]
            .into_iter()
            .collect();

        let sampler = AmazonBraketSampler::new(AmazonBraketConfig::default());
        let results = match sampler.run_qubo(&(matrix, var_map), 20) {
            Ok(results) => results,
            Err(_) => panic!("a small QUBO should sample successfully"),
        };

        assert!(!results.is_empty());
        for pair in results.windows(2) {
            assert!(pair[0].energy <= pair[1].energy);
        }
        for sample in &results {
            let a = if sample.assignments["a"] { 1.0 } else { 0.0 };
            let b = if sample.assignments["b"] { 1.0 } else { 0.0 };
            let expected = -a - b + 2.0 * a * b;
            assert!((sample.energy - expected).abs() < 1e-12);
        }
    }

    #[test]
    fn test_run_hobo_rejects_higher_order() {
        let tensor = Array::<f64, scirs2_core::ndarray::IxDyn>::zeros(vec![2, 2, 2]);
        let (_, var_map) = zero_qubo(2);
        let sampler = AmazonBraketSampler::new(AmazonBraketConfig::default());
        let outcome = sampler.run_hobo(&(tensor, var_map), 1);
        assert!(matches!(outcome, Err(SamplerError::InvalidParameter(_))));
    }
}