// quantrs2_device/aws.rs

1use quantrs2_circuit::prelude::Circuit;
2use std::collections::HashMap;
3#[cfg(feature = "aws")]
4use std::sync::Arc;
5#[cfg(feature = "aws")]
6use std::thread::sleep;
7#[cfg(feature = "aws")]
8use std::time::Duration;
9
10#[cfg(feature = "aws")]
11use reqwest::{header, Client};
12#[cfg(feature = "aws")]
13use serde::{Deserialize, Serialize};
14#[cfg(feature = "aws")]
15use serde_json;
16use thiserror::Error;
17
18use crate::DeviceError;
19use crate::DeviceResult;
20
/// Region used when the caller does not pass one to `AWSBraketClient::new`.
#[cfg(feature = "aws")]
const DEFAULT_REGION: &str = "us-east-1";

/// Endpoint template for the Braket API; the `{region}` placeholder is replaced
/// with the client's region when the client is constructed.
#[cfg(feature = "aws")]
const AWS_BRAKET_API_URL: &str = "https://braket.{region}.amazonaws.com";

/// Default polling budget, in seconds, for `wait_for_task`.
#[cfg(feature = "aws")]
const DEFAULT_TIMEOUT_SECS: u64 = 120;
27
/// Description of a single device (QPU or simulator) offered through AWS Braket.
///
/// With the `aws` feature enabled the struct derives `Deserialize` and
/// `device_capabilities` carries raw JSON; without it that field is `()`.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "aws", derive(serde::Deserialize))]
pub struct AWSDevice {
    /// ARN uniquely identifying the device.
    pub device_arn: String,
    /// Human-readable device name.
    pub name: String,
    /// Device category, e.g. `QPU` or `SIMULATOR`.
    pub device_type: String,
    /// Name of the organisation providing the device.
    pub provider_name: String,
    /// Availability status as reported for the device.
    pub status: String,
    /// Number of qubits the device exposes.
    pub num_qubits: usize,
    /// Capability document, kept as untyped JSON.
    #[cfg(feature = "aws")]
    pub device_capabilities: serde_json::Value,
    /// Placeholder for the capability document when the `aws` feature is off.
    #[cfg(not(feature = "aws"))]
    pub device_capabilities: (),
}
51
/// Everything needed to submit one circuit as an AWS Braket quantum task.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "aws", derive(Serialize))]
pub struct AWSCircuitConfig {
    /// Name attached to the submitted task.
    pub name: String,
    /// Serialized circuit program, in the format named by `ir_type`.
    pub ir: String,
    /// IR format identifier, e.g. `"OPENQASM"` or `"BRAKET"`.
    pub ir_type: String,
    /// How many shots to execute.
    pub shots: usize,
    /// S3 bucket that receives the task output.
    pub s3_bucket: String,
    /// Key prefix inside `s3_bucket` for the task output.
    pub s3_key_prefix: String,
    /// Optional device-specific parameters as raw JSON.
    #[cfg(feature = "aws")]
    pub device_parameters: Option<serde_json::Value>,
    /// Placeholder for device parameters when the `aws` feature is off.
    #[cfg(not(feature = "aws"))]
    pub device_parameters: Option<()>,
}
75
/// Lifecycle state of an AWS Braket quantum task.
///
/// With the `aws` feature every variant deserializes from its upper-case wire
/// name, e.g. `"QUEUED"` becomes [`AWSTaskStatus::Queued`] and `"CANCELLING"`
/// becomes [`AWSTaskStatus::Cancelling`].
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "aws", derive(Deserialize))]
#[cfg_attr(feature = "aws", serde(rename_all = "UPPERCASE"))]
pub enum AWSTaskStatus {
    /// Wire name `CREATED`.
    Created,
    /// Wire name `QUEUED`.
    Queued,
    /// Wire name `RUNNING`.
    Running,
    /// Wire name `COMPLETED`.
    Completed,
    /// Wire name `FAILED`.
    Failed,
    /// Wire name `CANCELLING`.
    Cancelling,
    /// Wire name `CANCELLED`.
    Cancelled,
}
95
/// Task record decoded from the Braket task endpoints (`aws` feature build).
///
/// NOTE(review): fields deserialize under their Rust (snake_case) names with no
/// renaming; confirm the service payload uses exactly these keys.
#[cfg(feature = "aws")]
#[derive(Debug, Deserialize)]
pub struct AWSTaskResponse {
    /// ARN identifying the quantum task.
    pub quantum_task_arn: String,
    /// Current lifecycle state of the task.
    pub status: AWSTaskStatus,
    /// Creation timestamp, kept as the raw string from the response.
    pub creation_time: String,
    /// ARN of the device the task targets.
    pub device_arn: String,
    /// S3 bucket holding the task output.
    pub output_s3_bucket: String,
    /// Key prefix of the task output inside `output_s3_bucket`.
    pub output_s3_key_prefix: String,
    /// Number of shots requested for the task.
    pub shots: usize,
}
115
116#[cfg(not(feature = "aws"))]
117#[derive(Debug)]
118pub struct AWSTaskResponse {
119    /// Task ARN
120    pub quantum_task_arn: String,
121    /// Status of the task
122    pub status: AWSTaskStatus,
123}
124
/// Outcome of a finished task (`aws` feature build).
#[cfg(feature = "aws")]
#[derive(Debug, Deserialize)]
pub struct AWSTaskResult {
    /// Count of each observed measurement outcome, keyed by outcome string.
    pub measurements: HashMap<String, usize>,
    /// Probability of each measurement outcome, keyed by outcome string.
    pub measurement_probabilities: HashMap<String, f64>,
    /// Number of shots the result covers.
    pub shots: usize,
    /// Task-level metadata as untyped JSON values.
    pub task_metadata: HashMap<String, serde_json::Value>,
    /// Any further metadata as untyped JSON values.
    pub additional_metadata: HashMap<String, serde_json::Value>,
}
140
/// Outcome of a finished task when the `aws` feature is disabled (no metadata).
#[cfg(not(feature = "aws"))]
#[derive(Debug)]
pub struct AWSTaskResult {
    /// Count of each observed measurement outcome, keyed by outcome string.
    pub measurements: HashMap<String, usize>,
    /// Probability of each measurement outcome, keyed by outcome string.
    pub measurement_probabilities: HashMap<String, f64>,
    /// Number of shots the result covers.
    pub shots: usize,
}
151
152/// Errors specific to AWS Braket
153#[derive(Error, Debug)]
154pub enum AWSBraketError {
155    #[error("Authentication error: {0}")]
156    Authentication(String),
157
158    #[error("API error: {0}")]
159    API(String),
160
161    #[error("Device not available: {0}")]
162    DeviceUnavailable(String),
163
164    #[error("Circuit conversion error: {0}")]
165    CircuitConversion(String),
166
167    #[error("Task submission error: {0}")]
168    TaskSubmission(String),
169
170    #[error("Timeout waiting for task completion")]
171    Timeout,
172
173    #[error("S3 error: {0}")]
174    S3Error(String),
175}
176
/// HTTP client for the AWS Braket REST API (`aws` feature build).
///
/// NOTE: this struct holds credentials, so it deliberately derives only `Clone`.
#[cfg(feature = "aws")]
#[derive(Clone)]
pub struct AWSBraketClient {
    /// Underlying reqwest client used for every API call.
    client: Client,
    /// Braket endpoint for `region`, with no trailing slash.
    api_url: String,
    /// AWS region the client talks to.
    region: String,
    /// AWS access key id used for request signing.
    access_key: String,
    /// AWS secret access key used for request signing.
    secret_key: String,
    /// S3 bucket for task results, as passed to `new`.
    s3_bucket: String,
    /// S3 key prefix for task results.
    s3_key_prefix: String,
}
196
/// Stand-in client compiled when the `aws` feature is disabled; it carries no
/// state and its methods only report that AWS Braket support is unavailable.
#[cfg(not(feature = "aws"))]
#[derive(Clone)]
pub struct AWSBraketClient;
200
#[cfg(feature = "aws")]
impl AWSBraketClient {
    /// Create a new AWS Braket client with the given credentials.
    ///
    /// * `access_key` / `secret_key` – credentials used to sign every request.
    /// * `region` – AWS region; `DEFAULT_REGION` (`us-east-1`) when `None`.
    /// * `s3_bucket` / `s3_key_prefix` – output location that
    ///   [`submit_circuit`](Self::submit_circuit) falls back to when the submitted
    ///   [`AWSCircuitConfig`] leaves its own bucket / prefix empty. The prefix
    ///   defaults to `"quantrs"`.
    ///
    /// # Errors
    ///
    /// Returns `DeviceError::Connection` if the HTTP client cannot be built, or if
    /// `region` cannot be embedded in the Braket `Host` header value.
    pub fn new(
        access_key: &str,
        secret_key: &str,
        region: Option<&str>,
        s3_bucket: &str,
        s3_key_prefix: Option<&str>,
    ) -> DeviceResult<Self> {
        let mut headers = header::HeaderMap::new();
        headers.insert(
            header::CONTENT_TYPE,
            header::HeaderValue::from_static("application/json"),
        );

        let client = Client::builder()
            .default_headers(headers)
            .timeout(Duration::from_secs(30))
            .build()
            .map_err(|e| DeviceError::Connection(e.to_string()))?;

        let region = region.unwrap_or(DEFAULT_REGION).to_string();

        // The region is interpolated into the `Host` header of every signed request.
        // Rejecting a bad value here (instead of panicking later while signing)
        // makes "region is a valid header value" an invariant of the struct.
        header::HeaderValue::from_str(&format!("braket.{}.amazonaws.com", region)).map_err(
            |e| DeviceError::Connection(format!("Invalid AWS region {:?}: {}", region, e)),
        )?;

        let api_url = AWS_BRAKET_API_URL.replace("{region}", &region);
        let s3_key_prefix = s3_key_prefix.unwrap_or("quantrs").to_string();

        Ok(Self {
            client,
            api_url,
            region,
            access_key: access_key.to_string(),
            secret_key: secret_key.to_string(),
            s3_bucket: s3_bucket.to_string(),
            s3_key_prefix,
        })
    }

    /// Build the SigV4-signed header set for a request to the Braket API.
    ///
    /// `body` is passed to the signer, so it must be exactly the payload that is
    /// sent on the wire: `""` for requests that carry no body.
    fn generate_aws_v4_signature(
        &self,
        request_method: &str,
        path: &str,
        body: &str,
    ) -> reqwest::header::HeaderMap {
        use crate::aws_auth::{AwsRegion, AwsSignatureV4};
        use chrono::Utc;

        let mut headers = reqwest::header::HeaderMap::new();

        headers.insert(
            reqwest::header::CONTENT_TYPE,
            reqwest::header::HeaderValue::from_static("application/json"),
        );

        let host = format!("braket.{}.amazonaws.com", self.region);
        headers.insert(
            reqwest::header::HOST,
            reqwest::header::HeaderValue::from_str(&host)
                .expect("AWS region is validated as a header value in AWSBraketClient::new"),
        );

        let now = Utc::now();
        headers.insert(
            reqwest::header::HeaderName::from_static("x-amz-date"),
            reqwest::header::HeaderValue::from_str(&now.format("%Y%m%dT%H%M%SZ").to_string())
                .expect("Date format produces valid header value"),
        );

        let region = AwsRegion {
            name: self.region.clone(),
            service: "braket".to_string(),
        };

        AwsSignatureV4::sign_request(
            request_method,
            path,
            "", // No query string
            &mut headers,
            body.as_bytes(),
            &self.access_key,
            &self.secret_key,
            &region,
            &now,
        );

        headers
    }

    /// Sign and send a request, returning the response only if its HTTP status
    /// indicates success.
    ///
    /// `body` is sent verbatim (no body at all when it is empty), so the bytes on
    /// the wire are exactly the bytes that were signed. A non-success response is
    /// converted with `on_error`, using the status line and response text as the
    /// message. Transport failures become `DeviceError::Connection`.
    async fn send_signed(
        &self,
        method: reqwest::Method,
        path: &str,
        body: String,
        on_error: fn(String) -> DeviceError,
    ) -> DeviceResult<reqwest::Response> {
        let url = format!("{}{}", self.api_url, path);
        let headers = self.generate_aws_v4_signature(method.as_str(), path, &body);

        let mut request = self.client.request(method, &url).headers(headers);
        if !body.is_empty() {
            request = request.body(body);
        }

        let response = request
            .send()
            .await
            .map_err(|e| DeviceError::Connection(e.to_string()))?;

        let status = response.status();
        if !status.is_success() {
            let error_msg = response
                .text()
                .await
                .unwrap_or_else(|_| "Unknown error".to_string());
            return Err(on_error(format!("HTTP {}: {}", status, error_msg)));
        }

        Ok(response)
    }

    /// Decode a successful response body as JSON into `T`.
    ///
    /// Decoding failures become `DeviceError::Deserialization`.
    async fn parse_json<T: serde::de::DeserializeOwned>(
        response: reqwest::Response,
    ) -> DeviceResult<T> {
        response
            .json()
            .await
            .map_err(|e| DeviceError::Deserialization(e.to_string()))
    }

    /// List all available devices.
    ///
    /// # Errors
    ///
    /// `DeviceError::Connection` on transport failure, `DeviceError::APIError` on a
    /// non-success status, `DeviceError::Deserialization` if the body is not a JSON
    /// array of devices.
    pub async fn list_devices(&self) -> DeviceResult<Vec<AWSDevice>> {
        // NOTE(review): AWS's public API reference documents device search as
        // `SearchDevices` (`POST /devices`, response `{"devices": [...]}`); confirm
        // this GET and the bare-array decode match the endpoint being targeted.
        let response = self
            .send_signed(
                reqwest::Method::GET,
                "/devices",
                String::new(),
                DeviceError::APIError,
            )
            .await?;
        Self::parse_json(response).await
    }

    /// Get details about a specific device.
    ///
    /// # Errors
    ///
    /// Same error mapping as [`list_devices`](Self::list_devices).
    pub async fn get_device(&self, device_arn: &str) -> DeviceResult<AWSDevice> {
        // NOTE(review): device ARNs contain ':' and '/'; confirm whether the
        // service (and the signer) expect the ARN percent-encoded in this segment.
        let path = format!("/device/{}", device_arn);
        let response = self
            .send_signed(
                reqwest::Method::GET,
                &path,
                String::new(),
                DeviceError::APIError,
            )
            .await?;
        Self::parse_json(response).await
    }

    /// Submit a circuit to be executed on an AWS Braket device and return the
    /// new task's ARN.
    ///
    /// An empty `config.s3_bucket` / `config.s3_key_prefix` falls back to the
    /// bucket / prefix the client was constructed with.
    ///
    /// # Errors
    ///
    /// `DeviceError::Connection` on transport failure, `DeviceError::JobSubmission`
    /// on a non-success status, `DeviceError::Deserialization` if the response
    /// cannot be decoded as an [`AWSTaskResponse`].
    pub async fn submit_circuit(
        &self,
        device_arn: &str,
        config: AWSCircuitConfig,
    ) -> DeviceResult<String> {
        use serde_json::json;

        let s3_bucket = if config.s3_bucket.is_empty() {
            &self.s3_bucket
        } else {
            &config.s3_bucket
        };
        let s3_key_prefix = if config.s3_key_prefix.is_empty() {
            &self.s3_key_prefix
        } else {
            &config.s3_key_prefix
        };

        let payload = json!({
            "action": config.ir,
            "deviceArn": device_arn,
            "shots": config.shots,
            "outputS3Bucket": s3_bucket,
            "outputS3KeyPrefix": s3_key_prefix,
            "deviceParameters": config.device_parameters,
            "name": config.name,
            "irType": config.ir_type
        });

        // Serialize once: the same string is both signed and sent.
        let body = payload.to_string();
        let response = self
            .send_signed(
                reqwest::Method::POST,
                "/quantum-task",
                body,
                DeviceError::JobSubmission,
            )
            .await?;

        let task_response: AWSTaskResponse = Self::parse_json(response).await?;
        Ok(task_response.quantum_task_arn)
    }

    /// Get the status of a task.
    ///
    /// # Errors
    ///
    /// Same error mapping as [`list_devices`](Self::list_devices).
    pub async fn get_task_status(&self, task_arn: &str) -> DeviceResult<AWSTaskStatus> {
        let path = format!("/quantum-task/{}", task_arn);
        let response = self
            .send_signed(
                reqwest::Method::GET,
                &path,
                String::new(),
                DeviceError::APIError,
            )
            .await?;

        let task: AWSTaskResponse = Self::parse_json(response).await?;
        Ok(task.status)
    }

    /// Get the results of a completed task.
    ///
    /// Placeholder: after confirming the task is `Completed`, this currently
    /// returns empty counts, probabilities and metadata with `shots: 0`. No data
    /// is fetched from S3 yet.
    ///
    /// # Errors
    ///
    /// `DeviceError::JobExecution` if the task is not `Completed`, plus any error
    /// from [`get_task_status`](Self::get_task_status).
    pub async fn get_task_result(&self, task_arn: &str) -> DeviceResult<AWSTaskResult> {
        let status = self.get_task_status(task_arn).await?;

        if status != AWSTaskStatus::Completed {
            return Err(DeviceError::JobExecution(format!(
                "Task {} is not completed, current status: {:?}",
                task_arn, status
            )));
        }

        // TODO: fetch the result document from the task's S3 output location.
        Ok(AWSTaskResult {
            measurements: HashMap::new(),
            measurement_probabilities: HashMap::new(),
            shots: 0,
            task_metadata: HashMap::new(),
            additional_metadata: HashMap::new(),
        })
    }

    /// Poll a task until it reaches a terminal state or the timeout elapses.
    ///
    /// The status is checked every 5 seconds. `timeout_secs` defaults to
    /// `DEFAULT_TIMEOUT_SECS` (120). The timeout is wall-clock time, so the
    /// latency of the status requests counts against it, and the final sleep is
    /// clipped so the deadline is not overshot.
    ///
    /// # Errors
    ///
    /// * `DeviceError::JobExecution` if the task fails or is cancelled.
    /// * `DeviceError::Timeout` if the task is still in progress when time runs out.
    /// * Any error from [`get_task_status`](Self::get_task_status) or
    ///   [`get_task_result`](Self::get_task_result).
    pub async fn wait_for_task(
        &self,
        task_arn: &str,
        timeout_secs: Option<u64>,
    ) -> DeviceResult<AWSTaskResult> {
        let timeout = Duration::from_secs(timeout_secs.unwrap_or(DEFAULT_TIMEOUT_SECS));
        let poll_interval = Duration::from_secs(5);
        let started = std::time::Instant::now();

        while started.elapsed() < timeout {
            match self.get_task_status(task_arn).await? {
                AWSTaskStatus::Completed => return self.get_task_result(task_arn).await,
                AWSTaskStatus::Failed => {
                    return Err(DeviceError::JobExecution(format!(
                        "Task {} failed",
                        task_arn
                    )));
                }
                AWSTaskStatus::Cancelled => {
                    return Err(DeviceError::JobExecution(format!(
                        "Task {} was cancelled",
                        task_arn
                    )));
                }
                AWSTaskStatus::Created
                | AWSTaskStatus::Queued
                | AWSTaskStatus::Running
                | AWSTaskStatus::Cancelling => {
                    // Still in progress. Use the async timer: a blocking
                    // `std::thread::sleep` here would stall the executor thread and
                    // every other future scheduled on it.
                    let remaining = timeout.saturating_sub(started.elapsed());
                    tokio::time::sleep(poll_interval.min(remaining)).await;
                }
            }
        }

        Err(DeviceError::Timeout(format!(
            "Timed out waiting for task {} to complete",
            task_arn
        )))
    }

    /// Submit multiple circuits concurrently and return their task ARNs in the
    /// same order as `configs`.
    ///
    /// Each submission runs as its own Tokio task. If any submission fails, that
    /// error is returned. Submissions that were already spawned are not cancelled,
    /// and their ARNs are not reported.
    ///
    /// # Errors
    ///
    /// The first error (in `configs` order) from
    /// [`submit_circuit`](Self::submit_circuit), or `DeviceError::JobSubmission` if
    /// a spawned task panics or is cancelled.
    pub async fn submit_circuits_parallel(
        &self,
        device_arn: &str,
        configs: Vec<AWSCircuitConfig>,
    ) -> DeviceResult<Vec<String>> {
        use tokio::task;

        let client = Arc::new(self.clone());

        let handles: Vec<_> = configs
            .into_iter()
            .map(|config| {
                let client = Arc::clone(&client);
                let device_arn = device_arn.to_string();
                task::spawn(async move { client.submit_circuit(&device_arn, config).await })
            })
            .collect();

        let mut task_arns = Vec::with_capacity(handles.len());
        for handle in handles {
            // Outer `?`: join failure; inner `?`: the submission's own error.
            let task_arn = handle
                .await
                .map_err(|e| DeviceError::JobSubmission(format!("Failed to join task: {}", e)))??;
            task_arns.push(task_arn);
        }

        Ok(task_arns)
    }

    /// Convert a Quantrs circuit to Braket IR JSON.
    pub fn circuit_to_braket_ir<const N: usize>(circuit: &Circuit<N>) -> DeviceResult<String> {
        use crate::aws_conversion;
        aws_conversion::circuit_to_braket_ir(circuit)
    }

    /// Convert a Quantrs circuit to OpenQASM.
    pub fn circuit_to_qasm<const N: usize>(circuit: &Circuit<N>) -> DeviceResult<String> {
        use crate::aws_conversion;
        aws_conversion::circuit_to_qasm(circuit)
    }
}
573
574#[cfg(not(feature = "aws"))]
575impl AWSBraketClient {
576    pub fn new(
577        _access_key: &str,
578        _secret_key: &str,
579        _region: Option<&str>,
580        _s3_bucket: &str,
581        _s3_key_prefix: Option<&str>,
582    ) -> DeviceResult<Self> {
583        Err(DeviceError::UnsupportedDevice(
584            "AWS Braket support not enabled. Recompile with the 'aws' feature.".to_string(),
585        ))
586    }
587
588    pub async fn list_devices(&self) -> DeviceResult<Vec<AWSDevice>> {
589        Err(DeviceError::UnsupportedDevice(
590            "AWS Braket support not enabled".to_string(),
591        ))
592    }
593
594    pub async fn get_device(&self, _device_arn: &str) -> DeviceResult<AWSDevice> {
595        Err(DeviceError::UnsupportedDevice(
596            "AWS Braket support not enabled".to_string(),
597        ))
598    }
599
600    pub async fn submit_circuit(
601        &self,
602        _device_arn: &str,
603        _config: AWSCircuitConfig,
604    ) -> DeviceResult<String> {
605        Err(DeviceError::UnsupportedDevice(
606            "AWS Braket support not enabled".to_string(),
607        ))
608    }
609
610    pub async fn get_task_status(&self, _task_arn: &str) -> DeviceResult<AWSTaskStatus> {
611        Err(DeviceError::UnsupportedDevice(
612            "AWS Braket support not enabled".to_string(),
613        ))
614    }
615
616    pub async fn get_task_result(&self, _task_arn: &str) -> DeviceResult<AWSTaskResult> {
617        Err(DeviceError::UnsupportedDevice(
618            "AWS Braket support not enabled".to_string(),
619        ))
620    }
621
622    pub async fn wait_for_task(
623        &self,
624        _task_arn: &str,
625        _timeout_secs: Option<u64>,
626    ) -> DeviceResult<AWSTaskResult> {
627        Err(DeviceError::UnsupportedDevice(
628            "AWS Braket support not enabled".to_string(),
629        ))
630    }
631
632    pub async fn submit_circuits_parallel(
633        &self,
634        _device_arn: &str,
635        _configs: Vec<AWSCircuitConfig>,
636    ) -> DeviceResult<Vec<String>> {
637        Err(DeviceError::UnsupportedDevice(
638            "AWS Braket support not enabled".to_string(),
639        ))
640    }
641
642    pub fn circuit_to_braket_ir<const N: usize>(_circuit: &Circuit<N>) -> DeviceResult<String> {
643        Err(DeviceError::UnsupportedDevice(
644            "AWS Braket support not enabled".to_string(),
645        ))
646    }
647
648    pub fn circuit_to_qasm<const N: usize>(_circuit: &Circuit<N>) -> DeviceResult<String> {
649        Err(DeviceError::UnsupportedDevice(
650            "AWS Braket support not enabled".to_string(),
651        ))
652    }
653}