1use quantrs2_circuit::prelude::Circuit;
2use std::collections::HashMap;
3#[cfg(feature = "aws")]
4use std::sync::Arc;
5#[cfg(feature = "aws")]
6use std::thread::sleep;
7#[cfg(feature = "aws")]
8use std::time::Duration;
9
10#[cfg(feature = "aws")]
11use reqwest::{header, Client};
12#[cfg(feature = "aws")]
13use serde::{Deserialize, Serialize};
14#[cfg(feature = "aws")]
15use serde_json;
16use thiserror::Error;
17
18use crate::DeviceError;
19use crate::DeviceResult;
20
/// Endpoint template for the Braket API; `{region}` is replaced with the
/// client's region when a client is constructed.
#[cfg(feature = "aws")]
const AWS_BRAKET_API_URL: &str = "https://braket.{region}.amazonaws.com";
/// Number of seconds `wait_for_task` waits when the caller passes no timeout.
#[cfg(feature = "aws")]
const DEFAULT_TIMEOUT_SECS: u64 = 120;
/// Region used when the caller passes `None` to `AWSBraketClient::new`.
#[cfg(feature = "aws")]
const DEFAULT_REGION: &str = "us-east-1";
27
/// Description of a Braket device as returned by the device endpoints.
///
/// With the `aws` feature the type is deserializable and carries the raw
/// capabilities JSON; without it `device_capabilities` is the unit type so the
/// struct keeps the same field set.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "aws", derive(serde::Deserialize))]
pub struct AWSDevice {
    /// Amazon Resource Name of the device.
    pub device_arn: String,
    /// Device name.
    pub name: String,
    /// Device type string (stored verbatim; no validation is done here).
    pub device_type: String,
    /// Name of the device provider.
    pub provider_name: String,
    /// Device status string (stored verbatim; no validation is done here).
    pub status: String,
    /// Number of qubits reported for the device.
    pub num_qubits: usize,
    /// Raw capabilities document as untyped JSON.
    #[cfg(feature = "aws")]
    pub device_capabilities: serde_json::Value,
    /// Placeholder used when the `aws` feature is disabled.
    #[cfg(not(feature = "aws"))]
    pub device_capabilities: (),
}
51
/// Everything needed to submit one circuit as a Braket quantum task.
///
/// `submit_circuit` copies these fields into the JSON request body.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "aws", derive(Serialize))]
pub struct AWSCircuitConfig {
    /// Task name (sent as `name`).
    pub name: String,
    /// Serialized circuit program (sent as `action`).
    pub ir: String,
    /// Identifier of the IR format used by `ir` (sent as `irType`).
    pub ir_type: String,
    /// Number of shots to request.
    pub shots: usize,
    /// S3 bucket for task output (sent as `outputS3Bucket`).
    pub s3_bucket: String,
    /// S3 key prefix for task output (sent as `outputS3KeyPrefix`).
    pub s3_key_prefix: String,
    /// Optional device-specific parameters as untyped JSON.
    #[cfg(feature = "aws")]
    pub device_parameters: Option<serde_json::Value>,
    /// Placeholder used when the `aws` feature is disabled.
    #[cfg(not(feature = "aws"))]
    pub device_parameters: Option<()>,
}
75
/// Lifecycle state of a Braket quantum task.
///
/// With the `aws` feature each variant deserializes from the upper-case
/// string given in its `serde(rename = ...)` attribute.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "aws", derive(Deserialize))]
pub enum AWSTaskStatus {
    /// Deserialized from `"CREATED"`.
    #[cfg_attr(feature = "aws", serde(rename = "CREATED"))]
    Created,
    /// Deserialized from `"QUEUED"`.
    #[cfg_attr(feature = "aws", serde(rename = "QUEUED"))]
    Queued,
    /// Deserialized from `"RUNNING"`.
    #[cfg_attr(feature = "aws", serde(rename = "RUNNING"))]
    Running,
    /// Deserialized from `"COMPLETED"`; the only state in which results are returned.
    #[cfg_attr(feature = "aws", serde(rename = "COMPLETED"))]
    Completed,
    /// Deserialized from `"FAILED"`.
    #[cfg_attr(feature = "aws", serde(rename = "FAILED"))]
    Failed,
    /// Deserialized from `"CANCELLING"`.
    #[cfg_attr(feature = "aws", serde(rename = "CANCELLING"))]
    Cancelling,
    /// Deserialized from `"CANCELLED"`.
    #[cfg_attr(feature = "aws", serde(rename = "CANCELLED"))]
    Cancelled,
}
95
/// Task description deserialized from the quantum-task endpoints
/// (`aws` feature enabled).
///
/// NOTE(review): fields deserialize under their snake_case Rust names because
/// no `serde` rename is applied; confirm this matches the service's JSON keys.
#[cfg(feature = "aws")]
#[derive(Debug, Deserialize)]
pub struct AWSTaskResponse {
    /// ARN of the quantum task.
    pub quantum_task_arn: String,
    /// Current task state.
    pub status: AWSTaskStatus,
    /// Creation timestamp, kept as the raw string from the response.
    pub creation_time: String,
    /// ARN of the device the task targets.
    pub device_arn: String,
    /// S3 bucket holding the task output.
    pub output_s3_bucket: String,
    /// S3 key prefix of the task output.
    pub output_s3_key_prefix: String,
    /// Number of shots requested.
    pub shots: usize,
}
115
116#[cfg(not(feature = "aws"))]
117#[derive(Debug)]
118pub struct AWSTaskResponse {
119 pub quantum_task_arn: String,
121 pub status: AWSTaskStatus,
123}
124
/// Result of a completed quantum task (`aws` feature enabled).
#[cfg(feature = "aws")]
#[derive(Debug, Deserialize)]
pub struct AWSTaskResult {
    /// Count per measurement outcome — presumably keyed by bitstring; confirm
    /// against the producer of this map.
    pub measurements: HashMap<String, usize>,
    /// Probability per measurement outcome, keyed like `measurements`
    /// (assumption — TODO confirm).
    pub measurement_probabilities: HashMap<String, f64>,
    /// Number of shots the result covers.
    pub shots: usize,
    /// Task metadata as untyped JSON values.
    pub task_metadata: HashMap<String, serde_json::Value>,
    /// Additional metadata as untyped JSON values.
    pub additional_metadata: HashMap<String, serde_json::Value>,
}
140
/// Reduced task result used when the `aws` feature is disabled; it omits the
/// JSON metadata maps so that `serde_json` is not required.
#[cfg(not(feature = "aws"))]
#[derive(Debug)]
pub struct AWSTaskResult {
    /// Count per measurement outcome — presumably keyed by bitstring; confirm
    /// against the producer of this map.
    pub measurements: HashMap<String, usize>,
    /// Probability per measurement outcome, keyed like `measurements`
    /// (assumption — TODO confirm).
    pub measurement_probabilities: HashMap<String, f64>,
    /// Number of shots the result covers.
    pub shots: usize,
}
151
152#[derive(Error, Debug)]
154pub enum AWSBraketError {
155 #[error("Authentication error: {0}")]
156 Authentication(String),
157
158 #[error("API error: {0}")]
159 API(String),
160
161 #[error("Device not available: {0}")]
162 DeviceUnavailable(String),
163
164 #[error("Circuit conversion error: {0}")]
165 CircuitConversion(String),
166
167 #[error("Task submission error: {0}")]
168 TaskSubmission(String),
169
170 #[error("Timeout waiting for task completion")]
171 Timeout,
172
173 #[error("S3 error: {0}")]
174 S3Error(String),
175}
176
/// HTTP client for the AWS Braket API (`aws` feature enabled).
///
/// `Debug` is deliberately not derived: the struct stores the secret key.
#[cfg(feature = "aws")]
#[derive(Clone)]
pub struct AWSBraketClient {
    /// Underlying HTTP client used for every request.
    client: Client,
    /// Base URL of the regional Braket endpoint.
    api_url: String,
    /// AWS region the client talks to; also used when signing requests.
    region: String,
    /// AWS access key used for request signing.
    access_key: String,
    /// AWS secret key used for request signing.
    secret_key: String,
    /// S3 bucket supplied at construction.
    s3_bucket: String,
    /// S3 key prefix supplied at construction (defaults to `"quantrs"`).
    s3_key_prefix: String,
}
196
/// Stand-in client compiled when the `aws` feature is disabled.
///
/// It holds no state; every operation on it reports that AWS Braket support
/// is not enabled.
#[cfg(not(feature = "aws"))]
#[derive(Clone)]
pub struct AWSBraketClient;
200
#[cfg(feature = "aws")]
impl AWSBraketClient {
    /// Creates a client for the Braket endpoint of `region`.
    ///
    /// `region` falls back to `DEFAULT_REGION` and `s3_key_prefix` to
    /// `"quantrs"` when `None`. The HTTP client sends
    /// `Content-Type: application/json` by default and uses a 30 second
    /// per-request timeout.
    ///
    /// # Errors
    ///
    /// Returns `DeviceError::Connection` if `region` cannot be used in a
    /// `Host` header or if the HTTP client cannot be built.
    pub fn new(
        access_key: &str,
        secret_key: &str,
        region: Option<&str>,
        s3_bucket: &str,
        s3_key_prefix: Option<&str>,
    ) -> DeviceResult<Self> {
        let region = region.unwrap_or(DEFAULT_REGION).to_string();

        // Request signing later builds a `Host` header from the region and
        // treats failure as a broken invariant, so reject bad input here
        // instead of panicking on the first request.
        let host = format!("braket.{}.amazonaws.com", region);
        if header::HeaderValue::from_str(&host).is_err() {
            return Err(DeviceError::Connection(format!(
                "Invalid AWS region: {:?}",
                region
            )));
        }

        let mut headers = header::HeaderMap::new();
        headers.insert(
            header::CONTENT_TYPE,
            header::HeaderValue::from_static("application/json"),
        );

        let client = Client::builder()
            .default_headers(headers)
            .timeout(Duration::from_secs(30))
            .build()
            .map_err(|e| DeviceError::Connection(e.to_string()))?;

        let api_url = AWS_BRAKET_API_URL.replace("{region}", &region);
        let s3_key_prefix = s3_key_prefix.unwrap_or("quantrs").to_string();

        Ok(Self {
            client,
            api_url,
            region,
            access_key: access_key.to_string(),
            secret_key: secret_key.to_string(),
            s3_bucket: s3_bucket.to_string(),
            s3_key_prefix,
        })
    }

    /// Builds the headers for one request and passes them to
    /// `AwsSignatureV4::sign_request` for signing.
    ///
    /// The returned map contains `Content-Type`, `Host` and `x-amz-date`,
    /// plus whatever `sign_request` adds to it. `body` must be exactly the
    /// payload that will be sent, because the signature covers it.
    fn generate_aws_v4_signature(
        &self,
        request_method: &str,
        path: &str,
        body: &str,
    ) -> reqwest::header::HeaderMap {
        use crate::aws_auth::{AwsRegion, AwsSignatureV4};
        use chrono::Utc;

        let mut headers = header::HeaderMap::new();

        headers.insert(
            header::CONTENT_TYPE,
            header::HeaderValue::from_static("application/json"),
        );

        let host = format!("braket.{}.amazonaws.com", self.region);
        headers.insert(
            header::HOST,
            // `new` has already checked that this value is a valid header.
            header::HeaderValue::from_str(&host)
                .expect("AWS region contains invalid header characters"),
        );

        let now = Utc::now();
        headers.insert(
            header::HeaderName::from_static("x-amz-date"),
            header::HeaderValue::from_str(&now.format("%Y%m%dT%H%M%SZ").to_string())
                .expect("Date format produces valid header value"),
        );

        let region = AwsRegion {
            name: self.region.clone(),
            service: "braket".to_string(),
        };

        // The third argument is passed empty — presumably the query string;
        // confirm against `AwsSignatureV4::sign_request`.
        AwsSignatureV4::sign_request(
            request_method,
            path,
            "",
            &mut headers,
            body.as_bytes(),
            &self.access_key,
            &self.secret_key,
            &region,
            &now,
        );

        headers
    }

    /// Sends a signed GET for `path` and returns the response when the status
    /// is a success.
    ///
    /// No body is sent, so the request is signed over the empty payload; the
    /// signature would not match what is transmitted otherwise.
    ///
    /// # Errors
    ///
    /// `DeviceError::Connection` if the request cannot be sent,
    /// `DeviceError::APIError` (carrying the response text) on a non-success
    /// status.
    async fn signed_get(&self, path: &str) -> DeviceResult<reqwest::Response> {
        let url = format!("{}{}", self.api_url, path);
        let headers = self.generate_aws_v4_signature("GET", path, "");

        let response = self
            .client
            .get(&url)
            .headers(headers)
            .send()
            .await
            .map_err(|e| DeviceError::Connection(e.to_string()))?;

        if !response.status().is_success() {
            let error_msg = response
                .text()
                .await
                .unwrap_or_else(|_| "Unknown error".to_string());
            return Err(DeviceError::APIError(error_msg));
        }

        Ok(response)
    }

    /// Fetches the device list from `GET /devices`.
    ///
    /// # Errors
    ///
    /// `DeviceError::Connection`, `DeviceError::APIError`, or
    /// `DeviceError::Deserialization` if the body is not a JSON array of
    /// `AWSDevice`.
    pub async fn list_devices(&self) -> DeviceResult<Vec<AWSDevice>> {
        self.signed_get("/devices")
            .await?
            .json::<Vec<AWSDevice>>()
            .await
            .map_err(|e| DeviceError::Deserialization(e.to_string()))
    }

    /// Fetches one device from `GET /device/{device_arn}`.
    ///
    /// NOTE(review): `device_arn` is interpolated into the path verbatim;
    /// confirm whether the ARN must be percent-encoded for the URL and for
    /// the signature.
    ///
    /// # Errors
    ///
    /// `DeviceError::Connection`, `DeviceError::APIError`, or
    /// `DeviceError::Deserialization`.
    pub async fn get_device(&self, device_arn: &str) -> DeviceResult<AWSDevice> {
        let path = format!("/device/{}", device_arn);

        self.signed_get(&path)
            .await?
            .json::<AWSDevice>()
            .await
            .map_err(|e| DeviceError::Deserialization(e.to_string()))
    }

    /// Submits `config` to `device_arn` via `POST /quantum-task` and returns
    /// the ARN of the created task.
    ///
    /// # Errors
    ///
    /// `DeviceError::Connection` if the request cannot be sent,
    /// `DeviceError::JobSubmission` (carrying the response text) on a
    /// non-success status, `DeviceError::Deserialization` if the response is
    /// not an `AWSTaskResponse`.
    pub async fn submit_circuit(
        &self,
        device_arn: &str,
        config: AWSCircuitConfig,
    ) -> DeviceResult<String> {
        use serde_json::json;

        let path = "/quantum-task";
        let url = format!("{}{}", self.api_url, path);

        let payload = json!({
            "action": config.ir,
            "deviceArn": device_arn,
            "shots": config.shots,
            "outputS3Bucket": config.s3_bucket,
            "outputS3KeyPrefix": config.s3_key_prefix,
            "deviceParameters": config.device_parameters,
            "name": config.name,
            "irType": config.ir_type
        });

        // Serialize once and transmit exactly the bytes that were signed.
        let body = payload.to_string();
        let headers = self.generate_aws_v4_signature("POST", path, &body);

        let response = self
            .client
            .post(&url)
            .headers(headers)
            .body(body)
            .send()
            .await
            .map_err(|e| DeviceError::Connection(e.to_string()))?;

        if !response.status().is_success() {
            let error_msg = response
                .text()
                .await
                .unwrap_or_else(|_| "Unknown error".to_string());
            return Err(DeviceError::JobSubmission(error_msg));
        }

        let task_response: AWSTaskResponse = response
            .json()
            .await
            .map_err(|e| DeviceError::Deserialization(e.to_string()))?;

        Ok(task_response.quantum_task_arn)
    }

    /// Fetches the current status of a task from
    /// `GET /quantum-task/{task_arn}`.
    ///
    /// # Errors
    ///
    /// `DeviceError::Connection`, `DeviceError::APIError`, or
    /// `DeviceError::Deserialization`.
    pub async fn get_task_status(&self, task_arn: &str) -> DeviceResult<AWSTaskStatus> {
        let path = format!("/quantum-task/{}", task_arn);

        let task: AWSTaskResponse = self
            .signed_get(&path)
            .await?
            .json()
            .await
            .map_err(|e| DeviceError::Deserialization(e.to_string()))?;

        Ok(task.status)
    }

    /// Returns the result of a completed task.
    ///
    /// The status is fetched first; anything other than `Completed` yields
    /// `DeviceError::JobExecution`.
    ///
    /// This is currently a placeholder: nothing is downloaded, and the
    /// returned value always has empty maps and `shots == 0`.
    pub async fn get_task_result(&self, task_arn: &str) -> DeviceResult<AWSTaskResult> {
        let status = self.get_task_status(task_arn).await?;

        if status != AWSTaskStatus::Completed {
            return Err(DeviceError::JobExecution(format!(
                "Task {} is not completed, current status: {:?}",
                task_arn, status
            )));
        }

        let dummy_result = AWSTaskResult {
            measurements: HashMap::new(),
            measurement_probabilities: HashMap::new(),
            shots: 0,
            task_metadata: HashMap::new(),
            additional_metadata: HashMap::new(),
        };

        Ok(dummy_result)
    }

    /// Polls the task every five seconds until it reaches a terminal state or
    /// `timeout_secs` (default `DEFAULT_TIMEOUT_SECS`) of wall-clock time has
    /// passed.
    ///
    /// The wait between polls uses the async timer, so the executor thread is
    /// not blocked while waiting.
    ///
    /// # Errors
    ///
    /// `DeviceError::JobExecution` if the task failed or was cancelled,
    /// `DeviceError::Timeout` if the deadline passes first, plus any error
    /// from `get_task_status` / `get_task_result`.
    pub async fn wait_for_task(
        &self,
        task_arn: &str,
        timeout_secs: Option<u64>,
    ) -> DeviceResult<AWSTaskResult> {
        let timeout = Duration::from_secs(timeout_secs.unwrap_or(DEFAULT_TIMEOUT_SECS));
        let poll_interval = Duration::from_secs(5);
        let started = std::time::Instant::now();

        while started.elapsed() < timeout {
            match self.get_task_status(task_arn).await? {
                AWSTaskStatus::Completed => {
                    return self.get_task_result(task_arn).await;
                }
                AWSTaskStatus::Failed => {
                    return Err(DeviceError::JobExecution(format!(
                        "Task {} failed",
                        task_arn
                    )));
                }
                AWSTaskStatus::Cancelled => {
                    return Err(DeviceError::JobExecution(format!(
                        "Task {} was cancelled",
                        task_arn
                    )));
                }
                // Still in flight: wait, but never past the deadline.
                AWSTaskStatus::Created
                | AWSTaskStatus::Queued
                | AWSTaskStatus::Running
                | AWSTaskStatus::Cancelling => {
                    let remaining = timeout.saturating_sub(started.elapsed());
                    tokio::time::sleep(poll_interval.min(remaining)).await;
                }
            }
        }

        Err(DeviceError::Timeout(format!(
            "Timed out waiting for task {} to complete",
            task_arn
        )))
    }

    /// Submits every config to `device_arn` concurrently, one spawned task per
    /// config, and returns the task ARNs in the order of `configs`.
    ///
    /// # Errors
    ///
    /// Returns the first submission error encountered in input order, or
    /// `DeviceError::JobSubmission` if a spawned task cannot be joined.
    /// Submissions that were already spawned keep running in that case.
    pub async fn submit_circuits_parallel(
        &self,
        device_arn: &str,
        configs: Vec<AWSCircuitConfig>,
    ) -> DeviceResult<Vec<String>> {
        use tokio::task;

        let client = Arc::new(self.clone());

        // Spawn everything first so the submissions overlap.
        let handles: Vec<_> = configs
            .into_iter()
            .map(|config| {
                let client_clone = Arc::clone(&client);
                let device_arn = device_arn.to_string();
                task::spawn(async move { client_clone.submit_circuit(&device_arn, config).await })
            })
            .collect();

        let mut task_arns = Vec::with_capacity(handles.len());

        for handle in handles {
            let submitted = handle.await.map_err(|e| {
                DeviceError::JobSubmission(format!("Failed to join task: {}", e))
            })?;
            task_arns.push(submitted?);
        }

        Ok(task_arns)
    }

    /// Converts `circuit` to Braket IR by delegating to
    /// `crate::aws_conversion::circuit_to_braket_ir`.
    pub fn circuit_to_braket_ir<const N: usize>(circuit: &Circuit<N>) -> DeviceResult<String> {
        use crate::aws_conversion;
        aws_conversion::circuit_to_braket_ir(circuit)
    }

    /// Converts `circuit` to QASM by delegating to
    /// `crate::aws_conversion::circuit_to_qasm`.
    pub fn circuit_to_qasm<const N: usize>(circuit: &Circuit<N>) -> DeviceResult<String> {
        use crate::aws_conversion;
        aws_conversion::circuit_to_qasm(circuit)
    }
}
573
574#[cfg(not(feature = "aws"))]
575impl AWSBraketClient {
576 pub fn new(
577 _access_key: &str,
578 _secret_key: &str,
579 _region: Option<&str>,
580 _s3_bucket: &str,
581 _s3_key_prefix: Option<&str>,
582 ) -> DeviceResult<Self> {
583 Err(DeviceError::UnsupportedDevice(
584 "AWS Braket support not enabled. Recompile with the 'aws' feature.".to_string(),
585 ))
586 }
587
588 pub async fn list_devices(&self) -> DeviceResult<Vec<AWSDevice>> {
589 Err(DeviceError::UnsupportedDevice(
590 "AWS Braket support not enabled".to_string(),
591 ))
592 }
593
594 pub async fn get_device(&self, _device_arn: &str) -> DeviceResult<AWSDevice> {
595 Err(DeviceError::UnsupportedDevice(
596 "AWS Braket support not enabled".to_string(),
597 ))
598 }
599
600 pub async fn submit_circuit(
601 &self,
602 _device_arn: &str,
603 _config: AWSCircuitConfig,
604 ) -> DeviceResult<String> {
605 Err(DeviceError::UnsupportedDevice(
606 "AWS Braket support not enabled".to_string(),
607 ))
608 }
609
610 pub async fn get_task_status(&self, _task_arn: &str) -> DeviceResult<AWSTaskStatus> {
611 Err(DeviceError::UnsupportedDevice(
612 "AWS Braket support not enabled".to_string(),
613 ))
614 }
615
616 pub async fn get_task_result(&self, _task_arn: &str) -> DeviceResult<AWSTaskResult> {
617 Err(DeviceError::UnsupportedDevice(
618 "AWS Braket support not enabled".to_string(),
619 ))
620 }
621
622 pub async fn wait_for_task(
623 &self,
624 _task_arn: &str,
625 _timeout_secs: Option<u64>,
626 ) -> DeviceResult<AWSTaskResult> {
627 Err(DeviceError::UnsupportedDevice(
628 "AWS Braket support not enabled".to_string(),
629 ))
630 }
631
632 pub async fn submit_circuits_parallel(
633 &self,
634 _device_arn: &str,
635 _configs: Vec<AWSCircuitConfig>,
636 ) -> DeviceResult<Vec<String>> {
637 Err(DeviceError::UnsupportedDevice(
638 "AWS Braket support not enabled".to_string(),
639 ))
640 }
641
642 pub fn circuit_to_braket_ir<const N: usize>(_circuit: &Circuit<N>) -> DeviceResult<String> {
643 Err(DeviceError::UnsupportedDevice(
644 "AWS Braket support not enabled".to_string(),
645 ))
646 }
647
648 pub fn circuit_to_qasm<const N: usize>(_circuit: &Circuit<N>) -> DeviceResult<String> {
649 Err(DeviceError::UnsupportedDevice(
650 "AWS Braket support not enabled".to_string(),
651 ))
652 }
653}