quantrs2_device/
aws_device.rs

1#[cfg(feature = "aws")]
2use serde_json;
3#[cfg(feature = "aws")]
4use std::collections::HashMap;
5#[cfg(feature = "aws")]
6use std::sync::Arc;
7#[cfg(feature = "aws")]
8use std::time::Duration;
9
10#[cfg(feature = "aws")]
11use async_trait::async_trait;
12#[cfg(feature = "aws")]
13use tokio::sync::RwLock;
14
15#[cfg(feature = "aws")]
16use crate::aws::{AWSBraketClient, AWSCircuitConfig, AWSDevice};
17use crate::DeviceError;
18use crate::DeviceResult;
19#[cfg(feature = "aws")]
20use crate::{CircuitExecutor, CircuitResult, QuantumDevice};
21#[cfg(feature = "aws")]
22use quantrs2_circuit::prelude::Circuit;
23
/// Settings that control how circuits are submitted to an AWS Braket device.
///
/// Obtain sensible defaults via `AWSDeviceConfig::default()` and override individual
/// fields with struct-update syntax.
#[derive(Clone, Debug)]
pub struct AWSDeviceConfig {
    /// Shot count this configuration suggests for each circuit run.
    pub default_shots: usize,
    /// Name of the intermediate representation to emit: `"OPENQASM"` or `"BRAKET"`.
    pub ir_type: String,
    /// Optional device-specific parameters, forwarded verbatim with each task.
    #[cfg(feature = "aws")]
    pub device_parameters: Option<serde_json::Value>,
    /// Placeholder kept so the struct has the same field set when the `aws` feature is off.
    #[cfg(not(feature = "aws"))]
    pub device_parameters: Option<()>,
    /// Optional limit, in seconds, on how long to wait for a task to finish.
    pub timeout_secs: Option<u64>,
}
39
40impl Default for AWSDeviceConfig {
41    fn default() -> Self {
42        Self {
43            default_shots: 1000,
44            ir_type: "BRAKET".to_string(),
45            device_parameters: None,
46            timeout_secs: None,
47        }
48    }
49}
50
/// Handle to a single AWS Braket device, identified by its ARN.
///
/// The device description returned by the Braket client is kept in a shared,
/// async-lock-protected cache so repeated queries do not always hit the network.
#[cfg(feature = "aws")]
pub struct AWSBraketDevice {
    /// Shared Braket API client used for every request made by this handle.
    client: Arc<AWSBraketClient>,
    /// ARN of the device this handle targets.
    device_arn: String,
    /// Submission settings (IR type, timeout, device parameters, ...).
    config: AWSDeviceConfig,
    /// Last device description fetched from the client, if any.
    device_cache: Arc<RwLock<Option<AWSDevice>>>,
}
63
#[cfg(feature = "aws")]
impl AWSBraketDevice {
    /// Creates a handle for the device at `device_arn`.
    ///
    /// The device description is fetched immediately. This validates the ARN and primes
    /// the local cache. When `config` is `None`, `AWSDeviceConfig::default()` is used.
    ///
    /// # Errors
    ///
    /// Returns whatever error `AWSBraketClient::get_device` reports for `device_arn`.
    pub async fn new(
        client: AWSBraketClient,
        device_arn: &str,
        config: Option<AWSDeviceConfig>,
    ) -> DeviceResult<Self> {
        let shared_client = Arc::new(client);
        let fetched = shared_client.get_device(device_arn).await?;

        Ok(Self {
            client: shared_client,
            device_arn: device_arn.to_string(),
            config: config.unwrap_or_default(),
            // Nobody else can observe this lock yet, so seed it directly instead of
            // acquiring a write guard on it.
            device_cache: Arc::new(RwLock::new(Some(fetched))),
        })
    }

    /// Returns the cached device description, fetching and caching it on a miss.
    ///
    /// No lock is held across the network call on a miss. Concurrent misses may therefore
    /// each fetch, and the last writer wins.
    async fn get_device(&self) -> DeviceResult<AWSDevice> {
        {
            let guard = self.device_cache.read().await;
            if let Some(cached) = &*guard {
                return Ok(cached.clone());
            }
        } // read guard released here, before any network I/O

        let fresh = self.client.get_device(&self.device_arn).await?;
        *self.device_cache.write().await = Some(fresh.clone());
        Ok(fresh)
    }
}
110
#[cfg(feature = "aws")]
#[async_trait]
impl QuantumDevice for AWSBraketDevice {
    /// Reports whether the device status is `"ONLINE"`.
    ///
    /// The device cache is primed once in `new` and nothing in this file ever refreshes
    /// it. Reading the status from it would therefore report the construction-time status
    /// forever. This method always re-fetches the device description and writes the fresh
    /// copy back to the cache.
    ///
    /// # Errors
    ///
    /// Propagates any error from `AWSBraketClient::get_device`.
    async fn is_available(&self) -> DeviceResult<bool> {
        let device = self.client.get_device(&self.device_arn).await?;
        let online = device.status == "ONLINE";
        *self.device_cache.write().await = Some(device);
        Ok(online)
    }

    /// Number of qubits reported by the (cached) device description.
    async fn qubit_count(&self) -> DeviceResult<usize> {
        let device = self.get_device().await?;
        Ok(device.num_qubits)
    }

    /// Flattens the top-level entries of the device capabilities into strings.
    ///
    /// Each value is rendered with `serde_json::Value::to_string`, i.e. as its JSON text.
    /// Nested objects and arrays stay JSON-encoded, and string values keep their quotes.
    /// If the capabilities are not a JSON object, the map is empty.
    async fn properties(&self) -> DeviceResult<HashMap<String, String>> {
        let device = self.get_device().await?;

        // The whole impl is already gated on `feature = "aws"`, so no inner cfg is needed.
        let properties: HashMap<String, String> = match &device.device_capabilities {
            serde_json::Value::Object(caps) => caps
                .iter()
                .map(|(key, value)| (key.clone(), value.to_string()))
                .collect(),
            _ => HashMap::new(),
        };

        Ok(properties)
    }

    /// Whether the (cached) device type is `"SIMULATOR"`.
    async fn is_simulator(&self) -> DeviceResult<bool> {
        let device = self.get_device().await?;
        Ok(device.device_type == "SIMULATOR")
    }
}
147
/// Private helpers shared by the `CircuitExecutor` methods below.
#[cfg(feature = "aws")]
impl AWSBraketDevice {
    /// S3 bucket that task results are written to.
    ///
    /// NOTE(review): this is a hard-coded placeholder carried over from the original code.
    /// It would be the client's own bucket in reality and should become configurable.
    const TASK_S3_BUCKET: &'static str = "amazon-braket-examples";

    /// Returns a process-wide, monotonically increasing sequence number.
    ///
    /// This is appended to task names so that tasks created within the same second
    /// (the timestamp resolution used) still get distinct names and S3 key prefixes.
    /// `Relaxed` ordering suffices because only uniqueness of the returned values matters.
    fn next_task_sequence() -> usize {
        static TASK_SEQUENCE: std::sync::atomic::AtomicUsize =
            std::sync::atomic::AtomicUsize::new(0);
        TASK_SEQUENCE.fetch_add(1, std::sync::atomic::Ordering::Relaxed)
    }

    /// Converts `circuit` to the configured IR and wraps it in a task submission config.
    ///
    /// # Errors
    ///
    /// Returns `DeviceError::CircuitConversion` for an unsupported `ir_type`. Conversion
    /// errors from the client are propagated unchanged.
    fn build_task_config<const N: usize>(
        &self,
        circuit: &Circuit<N>,
        job_name: String,
        shots: usize,
    ) -> DeviceResult<AWSCircuitConfig> {
        let ir = match self.config.ir_type.as_str() {
            "OPENQASM" => AWSBraketClient::circuit_to_qasm(circuit)?,
            "BRAKET" => AWSBraketClient::circuit_to_braket_ir(circuit)?,
            other => {
                return Err(DeviceError::CircuitConversion(format!(
                    "Unsupported IR type: {}",
                    other
                )))
            }
        };

        let s3_key_prefix = format!("quantrs-tasks/{}", job_name);

        Ok(AWSCircuitConfig {
            name: job_name,
            ir,
            ir_type: self.config.ir_type.clone(),
            shots,
            s3_bucket: Self::TASK_S3_BUCKET.to_string(),
            s3_key_prefix,
            device_parameters: self.config.device_parameters.clone(),
        })
    }

    /// Metadata attached to every `CircuitResult` produced by this device.
    fn task_metadata(&self, task_arn: String) -> HashMap<String, String> {
        let mut metadata = HashMap::with_capacity(2);
        metadata.insert("task_arn".to_string(), task_arn);
        metadata.insert("device_arn".to_string(), self.device_arn.clone());
        metadata
    }
}

#[cfg(feature = "aws")]
#[async_trait]
impl CircuitExecutor for AWSBraketDevice {
    /// Submits one circuit, waits for it to finish and returns its measurement counts.
    ///
    /// `shots` is clamped to at least 1. The returned `CircuitResult::shots` is the
    /// clamped value that was actually submitted.
    ///
    /// # Errors
    ///
    /// - `DeviceError::CircuitConversion` if `can_execute_circuit` rejects the circuit.
    /// - Any error from IR conversion, task submission or waiting for the task.
    async fn execute_circuit<const N: usize>(
        &self,
        circuit: &Circuit<N>,
        shots: usize,
    ) -> DeviceResult<CircuitResult> {
        if !self.can_execute_circuit(circuit).await? {
            return Err(DeviceError::CircuitConversion(
                "Circuit cannot be executed on this device".to_string(),
            ));
        }

        // Report the same shot count that is submitted (previously 0 was submitted as 1
        // but reported as 0).
        let shots = shots.max(1);
        let job_name = format!(
            "quantrs_task_{}_{}",
            chrono::Utc::now().timestamp(),
            Self::next_task_sequence()
        );
        let config = self.build_task_config(circuit, job_name, shots)?;

        let task_arn = self.client.submit_circuit(&self.device_arn, config).await?;
        let result = self
            .client
            .wait_for_task(&task_arn, self.config.timeout_secs)
            .await?;

        Ok(CircuitResult {
            counts: result.measurements.into_iter().collect(),
            shots,
            metadata: self.task_metadata(task_arn),
        })
    }

    /// Submits a batch of circuits together and collects their results in submission order.
    ///
    /// Every circuit goes through the same `can_execute_circuit` pre-flight as
    /// `execute_circuit`. All circuits are validated and converted before anything is
    /// submitted, so an invalid circuit aborts the batch without creating any task. An
    /// empty batch returns an empty vector without contacting the service. `shots` is
    /// clamped to at least 1 and reported as submitted.
    ///
    /// # Errors
    ///
    /// - `DeviceError::CircuitConversion` naming the index of the first rejected circuit.
    /// - Any error from conversion, submission or waiting; the first failure aborts.
    async fn execute_circuits<const N: usize>(
        &self,
        circuits: Vec<&Circuit<N>>,
        shots: usize,
    ) -> DeviceResult<Vec<CircuitResult>> {
        if circuits.is_empty() {
            return Ok(Vec::new());
        }

        let shots = shots.max(1);
        let batch_id = format!(
            "{}_{}",
            chrono::Utc::now().timestamp(),
            Self::next_task_sequence()
        );

        let mut configs = Vec::with_capacity(circuits.len());
        for (idx, &circuit) in circuits.iter().enumerate() {
            if !self.can_execute_circuit(circuit).await? {
                return Err(DeviceError::CircuitConversion(format!(
                    "Circuit {} cannot be executed on this device",
                    idx
                )));
            }
            let job_name = format!("quantrs_batch_{}_task_{}", batch_id, idx);
            configs.push(self.build_task_config(circuit, job_name, shots)?);
        }

        let task_arns = self
            .client
            .submit_circuits_parallel(&self.device_arn, configs)
            .await?;

        // Tasks are awaited one after another, each with the configured timeout.
        let mut results = Vec::with_capacity(task_arns.len());
        for task_arn in task_arns {
            let result = self
                .client
                .wait_for_task(&task_arn, self.config.timeout_secs)
                .await?;

            results.push(CircuitResult {
                counts: result.measurements.into_iter().collect(),
                shots,
                metadata: self.task_metadata(task_arn),
            });
        }

        Ok(results)
    }

    /// Returns `true` if the circuit fits on the device and converts to the configured IR.
    ///
    /// An unknown `ir_type` yields `Ok(false)`. Only the qubit-count lookup can produce
    /// an `Err`.
    async fn can_execute_circuit<const N: usize>(
        &self,
        circuit: &Circuit<N>,
    ) -> DeviceResult<bool> {
        if N > self.qubit_count().await? {
            return Ok(false);
        }

        let convertible = match self.config.ir_type.as_str() {
            "OPENQASM" => AWSBraketClient::circuit_to_qasm(circuit).is_ok(),
            "BRAKET" => AWSBraketClient::circuit_to_braket_ir(circuit).is_ok(),
            _ => false,
        };
        Ok(convertible)
    }

    /// Heuristic queue-time estimate, since the Braket API exposes none.
    ///
    /// Returns 30 seconds for simulators and 10 minutes for hardware devices.
    async fn estimated_queue_time<const N: usize>(
        &self,
        _circuit: &Circuit<N>,
    ) -> DeviceResult<Duration> {
        let secs = if self.is_simulator().await? { 30 } else { 600 };
        Ok(Duration::from_secs(secs))
    }
}
335
/// Zero-sized stand-in compiled when the `aws` feature is disabled.
///
/// Its `new` constructor always returns an error, so no usable value is ever produced.
#[cfg(not(feature = "aws"))]
pub struct AWSBraketDevice;
338
339#[cfg(not(feature = "aws"))]
340impl AWSBraketDevice {
341    pub async fn new(
342        _client: crate::aws::AWSBraketClient,
343        _device_arn: &str,
344        _config: Option<AWSDeviceConfig>,
345    ) -> DeviceResult<Self> {
346        Err(DeviceError::UnsupportedDevice(
347            "AWS Braket support not enabled. Recompile with the 'aws' feature.".to_string(),
348        ))
349    }
350}