quantrs2_tytan/sampler/hardware/
amazon_braket.rs1use scirs2_core::ndarray::{Array, Ix2};
7use scirs2_core::random::{thread_rng, Rng};
8use std::collections::HashMap;
9
10use quantrs2_anneal::QuboModel;
11
12use super::super::{SampleResult, Sampler, SamplerError, SamplerResult};
13
/// Target device or simulator a Braket sampler is configured for.
///
/// The sampler's `run_qubo` selects a per-device problem-size limit and
/// sampling profile by matching on this value.
///
/// Devices compare by value (including the `RigettiDevice` payload) and can be
/// used as keys in hashed collections.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum BraketDevice {
    /// Simulator running locally.
    LocalSimulator,
    /// State-vector simulator.
    StateVectorSimulator,
    /// Tensor-network simulator.
    TensorNetworkSimulator,
    /// IonQ hardware.
    IonQDevice,
    /// Rigetti hardware. The payload is presumably the device name (the unit
    /// tests use `"Aspen-M-3"`); the sampler itself never inspects it.
    RigettiDevice(String),
    /// OQC hardware.
    OQCDevice,
    /// D-Wave Advantage annealer.
    DWaveAdvantage,
    /// D-Wave 2000Q annealer.
    DWave2000Q,
}
34
35#[derive(Debug, Clone)]
37pub struct AmazonBraketConfig {
38 pub region: String,
40 pub s3_bucket: String,
42 pub s3_prefix: String,
44 pub device: BraketDevice,
46 pub max_parallel: usize,
48 pub poll_interval: u64,
50}
51
52impl Default for AmazonBraketConfig {
53 fn default() -> Self {
54 Self {
55 region: "us-east-1".to_string(),
56 s3_bucket: String::new(),
57 s3_prefix: "braket-results".to_string(),
58 device: BraketDevice::LocalSimulator,
59 max_parallel: 10,
60 poll_interval: 5,
61 }
62 }
63}
64
65pub struct AmazonBraketSampler {
70 config: AmazonBraketConfig,
71}
72
73impl AmazonBraketSampler {
74 #[must_use]
80 pub const fn new(config: AmazonBraketConfig) -> Self {
81 Self { config }
82 }
83
84 #[must_use]
91 pub fn with_s3(s3_bucket: &str, region: &str) -> Self {
92 Self {
93 config: AmazonBraketConfig {
94 s3_bucket: s3_bucket.to_string(),
95 region: region.to_string(),
96 ..Default::default()
97 },
98 }
99 }
100
101 #[must_use]
103 pub fn with_device(mut self, device: BraketDevice) -> Self {
104 self.config.device = device;
105 self
106 }
107
108 #[must_use]
110 pub const fn with_max_parallel(mut self, max_parallel: usize) -> Self {
111 self.config.max_parallel = max_parallel;
112 self
113 }
114
115 #[must_use]
117 pub const fn with_poll_interval(mut self, interval: u64) -> Self {
118 self.config.poll_interval = interval;
119 self
120 }
121}
122
123impl Sampler for AmazonBraketSampler {
124 fn run_qubo(
125 &self,
126 qubo: &(Array<f64, Ix2>, HashMap<String, usize>),
127 shots: usize,
128 ) -> SamplerResult<Vec<SampleResult>> {
129 let (matrix, var_map) = qubo;
131
132 let n_vars = var_map.len();
134
135 match &self.config.device {
137 BraketDevice::LocalSimulator | BraketDevice::StateVectorSimulator => {
138 if n_vars > 34 {
139 return Err(SamplerError::InvalidParameter(
140 "State vector simulators support up to 34 qubits".to_string(),
141 ));
142 }
143 }
144 BraketDevice::TensorNetworkSimulator => {
145 if n_vars > 50 {
146 return Err(SamplerError::InvalidParameter(
147 "Tensor network simulator supports up to 50 qubits".to_string(),
148 ));
149 }
150 }
151 BraketDevice::IonQDevice => {
152 if n_vars > 29 {
153 return Err(SamplerError::InvalidParameter(
154 "IonQ device supports up to 29 qubits".to_string(),
155 ));
156 }
157 }
158 BraketDevice::RigettiDevice(_) => {
159 if n_vars > 40 {
160 return Err(SamplerError::InvalidParameter(
161 "Rigetti devices support up to 40 qubits".to_string(),
162 ));
163 }
164 }
165 BraketDevice::OQCDevice => {
166 if n_vars > 8 {
167 return Err(SamplerError::InvalidParameter(
168 "OQC device supports up to 8 qubits".to_string(),
169 ));
170 }
171 }
172 BraketDevice::DWaveAdvantage => {
173 if n_vars > 5000 {
174 return Err(SamplerError::InvalidParameter(
175 "D-Wave Advantage supports up to 5000 variables".to_string(),
176 ));
177 }
178 }
179 BraketDevice::DWave2000Q => {
180 if n_vars > 2000 {
181 return Err(SamplerError::InvalidParameter(
182 "D-Wave 2000Q supports up to 2000 variables".to_string(),
183 ));
184 }
185 }
186 }
187
188 let idx_to_var: HashMap<usize, String> = var_map
190 .iter()
191 .map(|(var, &idx)| (idx, var.clone()))
192 .collect();
193
194 let mut qubo_model = QuboModel::new(n_vars);
196
197 for i in 0..n_vars {
199 if matrix[[i, i]] != 0.0 {
200 qubo_model.set_linear(i, matrix[[i, i]])?;
201 }
202
203 for j in (i + 1)..n_vars {
204 if matrix[[i, j]] != 0.0 {
205 qubo_model.set_quadratic(i, j, matrix[[i, j]])?;
206 }
207 }
208 }
209
210 #[cfg(feature = "amazon_braket")]
212 {
213 let _braket_result = "placeholder";
221 }
222
223 let mut results = Vec::new();
225 let mut rng = thread_rng();
226
227 let unique_solutions = match &self.config.device {
229 BraketDevice::DWaveAdvantage | BraketDevice::DWave2000Q => {
230 shots.min(1000)
232 }
233 BraketDevice::LocalSimulator | BraketDevice::StateVectorSimulator => {
234 shots.min(500)
236 }
237 BraketDevice::TensorNetworkSimulator => shots.min(300),
238 _ => {
239 shots.min(100)
241 }
242 };
243
244 for _ in 0..unique_solutions {
245 let assignments: HashMap<String, bool> = idx_to_var
246 .values()
247 .map(|name| (name.clone(), rng.gen::<bool>()))
248 .collect();
249
250 let mut energy = 0.0;
252 for (var_name, &val) in &assignments {
253 let i = var_map[var_name];
254 if val {
255 energy += matrix[[i, i]];
256 for (other_var, &other_val) in &assignments {
257 let j = var_map[other_var];
258 if i < j && other_val {
259 energy += matrix[[i, j]];
260 }
261 }
262 }
263 }
264
265 let occurrences = match &self.config.device {
267 BraketDevice::DWaveAdvantage | BraketDevice::DWave2000Q => {
268 rng.gen_range(1..=(shots / unique_solutions + 20))
270 }
271 _ => {
272 rng.gen_range(1..=(shots / unique_solutions + 5))
274 }
275 };
276
277 results.push(SampleResult {
278 assignments,
279 energy,
280 occurrences,
281 });
282 }
283
284 results.sort_by(|a, b| {
286 a.energy
287 .partial_cmp(&b.energy)
288 .unwrap_or(std::cmp::Ordering::Equal)
289 });
290
291 results.truncate(shots.min(100));
293
294 Ok(results)
295 }
296
297 fn run_hobo(
298 &self,
299 hobo: &(
300 Array<f64, scirs2_core::ndarray::IxDyn>,
301 HashMap<String, usize>,
302 ),
303 shots: usize,
304 ) -> SamplerResult<Vec<SampleResult>> {
305 use scirs2_core::ndarray::Ix2;
306
307 if hobo.0.ndim() <= 2 {
309 let qubo_matrix = hobo.0.clone().into_dimensionality::<Ix2>().map_err(|e| {
311 SamplerError::InvalidParameter(format!(
312 "Failed to convert HOBO to QUBO dimensionality: {e}"
313 ))
314 })?;
315 let qubo = (qubo_matrix, hobo.1.clone());
316 self.run_qubo(&qubo, shots)
317 } else {
318 Err(SamplerError::InvalidParameter(
320 "Amazon Braket doesn't support HOBO problems directly. Use a quadratization technique first.".to_string()
321 ))
322 }
323 }
324}
325
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an all-zero `n x n` QUBO whose variables `x0..` map to the
    /// matrix indices `0..n`.
    fn zero_qubo(n: usize) -> (Array<f64, Ix2>, HashMap<String, usize>) {
        let matrix = Array::<f64, Ix2>::zeros((n, n));
        let var_map = (0..n).map(|i| (format!("x{i}"), i)).collect();
        (matrix, var_map)
    }

    /// The default configuration carries the documented baseline values.
    #[test]
    fn test_amazon_braket_config() {
        let config = AmazonBraketConfig::default();
        assert_eq!(config.region, "us-east-1");
        assert_eq!(config.s3_prefix, "braket-results");
        assert_eq!(config.max_parallel, 10);
        assert!(matches!(config.device, BraketDevice::LocalSimulator));
    }

    /// Builder methods override exactly the field they are named after.
    #[test]
    fn test_amazon_braket_sampler_creation() {
        let sampler = AmazonBraketSampler::with_s3("my-bucket", "us-west-2")
            .with_device(BraketDevice::IonQDevice)
            .with_max_parallel(20)
            .with_poll_interval(10);

        assert_eq!(sampler.config.s3_bucket, "my-bucket");
        assert_eq!(sampler.config.region, "us-west-2");
        assert_eq!(sampler.config.max_parallel, 20);
        assert_eq!(sampler.config.poll_interval, 10);
        assert!(matches!(sampler.config.device, BraketDevice::IonQDevice));
    }

    /// Every device variant can be constructed.
    #[test]
    fn test_braket_device_types() {
        let devices = [
            BraketDevice::LocalSimulator,
            BraketDevice::StateVectorSimulator,
            BraketDevice::TensorNetworkSimulator,
            BraketDevice::IonQDevice,
            BraketDevice::RigettiDevice("Aspen-M-3".to_string()),
            BraketDevice::OQCDevice,
            BraketDevice::DWaveAdvantage,
            BraketDevice::DWave2000Q,
        ];

        assert_eq!(devices.len(), 8);
    }

    /// Device variants are distinguishable, and `run_qubo` enforces the
    /// per-device size limit (exercised here with the 8-variable OQC limit).
    #[test]
    fn test_braket_device_limits() {
        let sv_device = BraketDevice::StateVectorSimulator;
        let tn_device = BraketDevice::TensorNetworkSimulator;
        let dwave_device = BraketDevice::DWaveAdvantage;

        assert!(matches!(sv_device, BraketDevice::StateVectorSimulator));
        assert!(matches!(tn_device, BraketDevice::TensorNetworkSimulator));
        assert!(matches!(dwave_device, BraketDevice::DWaveAdvantage));

        let sampler = AmazonBraketSampler::with_s3("my-bucket", "us-east-1")
            .with_device(BraketDevice::OQCDevice);

        // One variable above the limit must be rejected.
        let too_large = sampler.run_qubo(&zero_qubo(9), 10);
        assert!(matches!(
            too_large,
            Err(SamplerError::InvalidParameter(_))
        ));

        // Exactly at the limit is accepted; 10 shots yield 10 samples, each
        // assigning all 8 variables.
        match sampler.run_qubo(&zero_qubo(8), 10) {
            Ok(samples) => {
                assert_eq!(samples.len(), 10);
                assert!(samples.iter().all(|s| s.assignments.len() == 8));
            }
            Err(_) => panic!("8 variables should be accepted by the OQC device"),
        }
    }
}