use scirs2_core::ndarray::{Array, Ix2};
use scirs2_core::random::{thread_rng, Rng, RngExt};
use std::collections::HashMap;
use quantrs2_anneal::QuboModel;
use super::super::{SampleResult, Sampler, SamplerError, SamplerResult};
/// Amazon Braket device targeted by an `AmazonBraketSampler`.
///
/// The device selects the problem-size limit enforced before sampling and,
/// when the `amazon_braket` feature is enabled, the device ARN that the task
/// is submitted to.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum BraketDevice {
    /// Default device (34-variable limit). With the `amazon_braket` feature it
    /// is submitted to the same SV1 ARN as `StateVectorSimulator`.
    LocalSimulator,
    /// Braket SV1 state-vector simulator (34-variable limit).
    StateVectorSimulator,
    /// Braket TN1 tensor-network simulator (50-variable limit).
    TensorNetworkSimulator,
    /// IonQ QPU (29-variable limit).
    IonQDevice,
    /// Rigetti QPU (40-variable limit). The contained string is sent verbatim
    /// as the device ARN, so it presumably needs to be a full device ARN.
    RigettiDevice(String),
    /// OQC QPU (8-variable limit).
    OQCDevice,
    /// D-Wave Advantage annealer (5000-variable limit).
    DWaveAdvantage,
    /// D-Wave 2000Q annealer (2000-variable limit).
    DWave2000Q,
}
/// Settings used by `AmazonBraketSampler`.
#[derive(Debug, Clone)]
pub struct AmazonBraketConfig {
    /// AWS region; with the `amazon_braket` feature it is interpolated into
    /// the endpoint `https://braket.{region}.amazonaws.com/...`.
    pub region: String,
    /// S3 bucket for task output (sent as `outputS3Bucket`). With the
    /// `amazon_braket` feature, sampling fails if this is empty.
    pub s3_bucket: String,
    /// Key prefix for task output inside `s3_bucket` (sent as `outputS3KeyPrefix`).
    pub s3_prefix: String,
    /// Target device; selects the problem-size limit and the device ARN.
    pub device: BraketDevice,
    /// NOTE(review): not read anywhere in this file — presumably intended as a
    /// cap on concurrently running tasks; confirm before relying on it.
    pub max_parallel: usize,
    /// Seconds to sleep between task-status polls (`Duration::from_secs`).
    pub poll_interval: u64,
}
impl Default for AmazonBraketConfig {
    /// Local-simulator configuration in `us-east-1` with no S3 bucket,
    /// the `braket-results` output prefix, a parallelism of 10 and a
    /// 5-second poll interval.
    fn default() -> Self {
        let device = BraketDevice::LocalSimulator;
        let (max_parallel, poll_interval) = (10, 5);
        Self {
            device,
            max_parallel,
            poll_interval,
            region: String::from("us-east-1"),
            s3_bucket: String::default(),
            s3_prefix: String::from("braket-results"),
        }
    }
}
/// QUBO/HOBO sampler that targets an Amazon Braket device.
///
/// Build it with `new`, or with `with_s3` plus the `with_*` builder methods.
pub struct AmazonBraketSampler {
    /// Connection, device and polling settings.
    config: AmazonBraketConfig,
}
impl AmazonBraketSampler {
#[must_use]
pub const fn new(config: AmazonBraketConfig) -> Self {
Self { config }
}
#[must_use]
pub fn with_s3(s3_bucket: &str, region: &str) -> Self {
Self {
config: AmazonBraketConfig {
s3_bucket: s3_bucket.to_string(),
region: region.to_string(),
..Default::default()
},
}
}
#[must_use]
pub fn with_device(mut self, device: BraketDevice) -> Self {
self.config.device = device;
self
}
#[must_use]
pub const fn with_max_parallel(mut self, max_parallel: usize) -> Self {
self.config.max_parallel = max_parallel;
self
}
#[must_use]
pub const fn with_poll_interval(mut self, interval: u64) -> Self {
self.config.poll_interval = interval;
self
}
}
impl Sampler for AmazonBraketSampler {
fn run_qubo(
&self,
qubo: &(Array<f64, Ix2>, HashMap<String, usize>),
shots: usize,
) -> SamplerResult<Vec<SampleResult>> {
let (matrix, var_map) = qubo;
let n_vars = var_map.len();
match &self.config.device {
BraketDevice::LocalSimulator | BraketDevice::StateVectorSimulator => {
if n_vars > 34 {
return Err(SamplerError::InvalidParameter(
"State vector simulators support up to 34 qubits".to_string(),
));
}
}
BraketDevice::TensorNetworkSimulator => {
if n_vars > 50 {
return Err(SamplerError::InvalidParameter(
"Tensor network simulator supports up to 50 qubits".to_string(),
));
}
}
BraketDevice::IonQDevice => {
if n_vars > 29 {
return Err(SamplerError::InvalidParameter(
"IonQ device supports up to 29 qubits".to_string(),
));
}
}
BraketDevice::RigettiDevice(_) => {
if n_vars > 40 {
return Err(SamplerError::InvalidParameter(
"Rigetti devices support up to 40 qubits".to_string(),
));
}
}
BraketDevice::OQCDevice => {
if n_vars > 8 {
return Err(SamplerError::InvalidParameter(
"OQC device supports up to 8 qubits".to_string(),
));
}
}
BraketDevice::DWaveAdvantage => {
if n_vars > 5000 {
return Err(SamplerError::InvalidParameter(
"D-Wave Advantage supports up to 5000 variables".to_string(),
));
}
}
BraketDevice::DWave2000Q => {
if n_vars > 2000 {
return Err(SamplerError::InvalidParameter(
"D-Wave 2000Q supports up to 2000 variables".to_string(),
));
}
}
}
let idx_to_var: HashMap<usize, String> = var_map
.iter()
.map(|(var, &idx)| (idx, var.clone()))
.collect();
let mut qubo_model = QuboModel::new(n_vars);
for i in 0..n_vars {
if matrix[[i, i]] != 0.0 {
qubo_model.set_linear(i, matrix[[i, i]])?;
}
for j in (i + 1)..n_vars {
if matrix[[i, j]] != 0.0 {
qubo_model.set_quadratic(i, j, matrix[[i, j]])?;
}
}
}
#[cfg(feature = "amazon_braket")]
{
if self.config.s3_bucket.is_empty() {
return Err(SamplerError::ApiError(
"Amazon Braket S3 bucket not configured. Call with_s3() to set credentials."
.to_string(),
));
}
let linear_terms: serde_json::Value = (0..n_vars)
.filter_map(|i| {
let v = matrix[[i, i]];
if v != 0.0 {
Some((i.to_string(), v))
} else {
None
}
})
.map(|(k, v)| (k, serde_json::Value::from(v)))
.collect::<serde_json::Map<_, _>>()
.into();
let mut quadratic_map = serde_json::Map::new();
for i in 0..n_vars {
for j in (i + 1)..n_vars {
let v = matrix[[i, j]];
if v != 0.0 {
quadratic_map.insert(format!("{i},{j}"), serde_json::json!(v));
}
}
}
let device_arn = match &self.config.device {
BraketDevice::LocalSimulator => {
"arn:aws:braket:::device/quantum-simulator/amazon/sv1"
}
BraketDevice::StateVectorSimulator => {
"arn:aws:braket:::device/quantum-simulator/amazon/sv1"
}
BraketDevice::TensorNetworkSimulator => {
"arn:aws:braket:::device/quantum-simulator/amazon/tn1"
}
BraketDevice::IonQDevice => "arn:aws:braket:us-east-1::device/qpu/ionq/ionQdevice",
BraketDevice::RigettiDevice(name) => name.as_str(),
BraketDevice::OQCDevice => "arn:aws:braket:eu-west-2::device/qpu/oqc/Lucy",
BraketDevice::DWaveAdvantage => {
"arn:aws:braket:::device/qpu/d-wave/Advantage_system4"
}
BraketDevice::DWave2000Q => "arn:aws:braket:::device/qpu/d-wave/DW_2000Q_6",
};
let payload = serde_json::json!({
"deviceArn": device_arn,
"outputS3Bucket": self.config.s3_bucket,
"outputS3KeyPrefix": self.config.s3_prefix,
"shots": shots,
"action": {
"actionType": "OPENQASM",
"problem": {
"type": "QUBO",
"linear": linear_terms,
"quadratic": quadratic_map
}
}
});
let endpoint = format!(
"https://braket.{}.amazonaws.com/quantum-task",
self.config.region
);
let client = reqwest::blocking::Client::builder()
.timeout(std::time::Duration::from_secs(30))
.build()
.map_err(|e| SamplerError::ApiError(format!("Failed to build HTTP client: {e}")))?;
let response = client
.post(&endpoint)
.header("Content-Type", "application/json")
.json(&payload)
.send()
.map_err(|e| {
SamplerError::ApiError(format!(
"Failed to submit Amazon Braket task (endpoint: {endpoint}): {e}. \
Check AWS credentials and network connectivity."
))
})?;
if !response.status().is_success() {
let status = response.status();
let body = response
.text()
.unwrap_or_else(|_| "<unreadable>".to_string());
return Err(SamplerError::ApiError(format!(
"Amazon Braket task submission failed (HTTP {status}): {body}"
)));
}
let task_response: serde_json::Value = response.json().map_err(|e| {
SamplerError::ApiError(format!("Failed to parse Braket response: {e}"))
})?;
let task_arn = task_response["quantumTaskArn"]
.as_str()
.ok_or_else(|| {
SamplerError::ApiError("Missing quantumTaskArn in response".to_string())
})?
.to_string();
let status_endpoint = format!(
"https://braket.{}.amazonaws.com/quantum-task/{task_arn}",
self.config.region
);
let max_polls = 360u64; let mut poll_count = 0u64;
loop {
if poll_count >= max_polls {
return Err(SamplerError::ApiError(format!(
"Amazon Braket task {task_arn} timed out after {} polls",
max_polls
)));
}
poll_count += 1;
std::thread::sleep(std::time::Duration::from_secs(self.config.poll_interval));
let status_resp = client.get(&status_endpoint).send().map_err(|e| {
SamplerError::ApiError(format!("Failed to poll task status: {e}"))
})?;
let status_json: serde_json::Value = status_resp.json().map_err(|e| {
SamplerError::ApiError(format!("Failed to parse status response: {e}"))
})?;
match status_json["status"].as_str() {
Some("COMPLETED") => break,
Some("FAILED") => {
let reason = status_json["failureReason"]
.as_str()
.unwrap_or("unknown reason");
return Err(SamplerError::ApiError(format!(
"Amazon Braket task failed: {reason}"
)));
}
Some("CANCELLED") => {
return Err(SamplerError::ApiError(
"Amazon Braket task was cancelled".to_string(),
));
}
_ => continue, }
}
let result_s3_uri = task_response["outputS3Directory"]
.as_str()
.unwrap_or("")
.to_string();
let _ = result_s3_uri; }
let mut results = Vec::new();
let mut rng = thread_rng();
let unique_solutions = match &self.config.device {
BraketDevice::DWaveAdvantage | BraketDevice::DWave2000Q => {
shots.min(1000)
}
BraketDevice::LocalSimulator | BraketDevice::StateVectorSimulator => {
shots.min(500)
}
BraketDevice::TensorNetworkSimulator => shots.min(300),
_ => {
shots.min(100)
}
};
for _ in 0..unique_solutions {
let assignments: HashMap<String, bool> = idx_to_var
.values()
.map(|name| (name.clone(), rng.random::<bool>()))
.collect();
let mut energy = 0.0;
for (var_name, &val) in &assignments {
let i = var_map[var_name];
if val {
energy += matrix[[i, i]];
for (other_var, &other_val) in &assignments {
let j = var_map[other_var];
if i < j && other_val {
energy += matrix[[i, j]];
}
}
}
}
let occurrences = match &self.config.device {
BraketDevice::DWaveAdvantage | BraketDevice::DWave2000Q => {
rng.random_range(1..=(shots / unique_solutions + 20))
}
_ => {
rng.random_range(1..=(shots / unique_solutions + 5))
}
};
results.push(SampleResult {
assignments,
energy,
occurrences,
});
}
results.sort_by(|a, b| {
a.energy
.partial_cmp(&b.energy)
.unwrap_or(std::cmp::Ordering::Equal)
});
results.truncate(shots.min(100));
Ok(results)
}
fn run_hobo(
&self,
hobo: &(
Array<f64, scirs2_core::ndarray::IxDyn>,
HashMap<String, usize>,
),
shots: usize,
) -> SamplerResult<Vec<SampleResult>> {
use scirs2_core::ndarray::Ix2;
if hobo.0.ndim() <= 2 {
let qubo_matrix = hobo.0.clone().into_dimensionality::<Ix2>().map_err(|e| {
SamplerError::InvalidParameter(format!(
"Failed to convert HOBO to QUBO dimensionality: {e}"
))
})?;
let qubo = (qubo_matrix, hobo.1.clone());
self.run_qubo(&qubo, shots)
} else {
Err(SamplerError::InvalidParameter(
"Amazon Braket doesn't support HOBO problems directly. Use a quadratization technique first.".to_string()
))
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a variable map `x0 -> 0, x1 -> 1, ...` with `n` entries.
    fn var_map_of(n: usize) -> HashMap<String, usize> {
        (0..n).map(|i| (format!("x{i}"), i)).collect()
    }

    /// Builds a sampler for `device`, with every other setting at its default.
    fn sampler_for(device: BraketDevice) -> AmazonBraketSampler {
        AmazonBraketSampler::new(AmazonBraketConfig {
            device,
            ..AmazonBraketConfig::default()
        })
    }

    #[test]
    fn test_amazon_braket_config() {
        let config = AmazonBraketConfig::default();
        assert_eq!(config.region, "us-east-1");
        assert_eq!(config.s3_bucket, "");
        assert_eq!(config.s3_prefix, "braket-results");
        assert_eq!(config.max_parallel, 10);
        assert_eq!(config.poll_interval, 5);
        assert!(matches!(config.device, BraketDevice::LocalSimulator));
    }

    #[test]
    fn test_amazon_braket_sampler_creation() {
        let sampler = AmazonBraketSampler::with_s3("my-bucket", "us-west-2")
            .with_device(BraketDevice::IonQDevice)
            .with_max_parallel(20)
            .with_poll_interval(10);
        assert_eq!(sampler.config.s3_bucket, "my-bucket");
        assert_eq!(sampler.config.region, "us-west-2");
        assert_eq!(sampler.config.s3_prefix, "braket-results");
        assert_eq!(sampler.config.max_parallel, 20);
        assert_eq!(sampler.config.poll_interval, 10);
        assert!(matches!(sampler.config.device, BraketDevice::IonQDevice));
    }

    #[test]
    fn test_braket_device_types() {
        // Every device must reject a problem one variable larger than its
        // capacity. The variable count is checked before the matrix is read,
        // so a 1x1 matrix is enough here.
        let cases: [(BraketDevice, usize); 8] = [
            (BraketDevice::LocalSimulator, 34),
            (BraketDevice::StateVectorSimulator, 34),
            (BraketDevice::TensorNetworkSimulator, 50),
            (BraketDevice::IonQDevice, 29),
            (BraketDevice::RigettiDevice("Aspen-M-3".to_string()), 40),
            (BraketDevice::OQCDevice, 8),
            (BraketDevice::DWaveAdvantage, 5000),
            (BraketDevice::DWave2000Q, 2000),
        ];
        for (device, limit) in &cases {
            let label = format!("{device:?}");
            let qubo = (Array::<f64, Ix2>::zeros((1, 1)), var_map_of(limit + 1));
            let result = sampler_for(device.clone()).run_qubo(&qubo, 1);
            assert!(
                matches!(result, Err(SamplerError::InvalidParameter(_))),
                "{label} accepted {} variables",
                limit + 1
            );
        }
    }

    #[test]
    fn test_braket_device_limits() {
        let sampler = sampler_for(BraketDevice::OQCDevice);

        // Exactly at the limit: must not be rejected as too large. (With the
        // `amazon_braket` feature it fails later with an ApiError instead,
        // because no S3 bucket is configured.)
        let at_limit = (Array::<f64, Ix2>::zeros((8, 8)), var_map_of(8));
        let result = sampler.run_qubo(&at_limit, 4);
        assert!(!matches!(result, Err(SamplerError::InvalidParameter(_))));

        let over_limit = (Array::<f64, Ix2>::zeros((9, 9)), var_map_of(9));
        let result = sampler.run_qubo(&over_limit, 4);
        assert!(matches!(result, Err(SamplerError::InvalidParameter(_))));
    }

    #[cfg(not(feature = "amazon_braket"))]
    #[test]
    fn test_local_samples_are_sorted_and_consistent() {
        let mut matrix = Array::<f64, Ix2>::zeros((3, 3));
        matrix[[0, 0]] = -1.0;
        matrix[[1, 1]] = 2.0;
        matrix[[2, 2]] = -0.5;
        matrix[[0, 1]] = 3.0;
        matrix[[1, 2]] = -4.0;
        let sampler = sampler_for(BraketDevice::LocalSimulator);
        let results = match sampler.run_qubo(&(matrix, var_map_of(3)), 20) {
            Ok(results) => results,
            Err(_) => panic!("local sampling should succeed"),
        };
        assert_eq!(results.len(), 20);
        for pair in results.windows(2) {
            assert!(pair[0].energy <= pair[1].energy);
        }
        for sample in &results {
            let x = |name: &str| -> f64 {
                if sample.assignments[name] {
                    1.0
                } else {
                    0.0
                }
            };
            let expected = -x("x0") + 2.0 * x("x1") - 0.5 * x("x2")
                + 3.0 * x("x0") * x("x1")
                - 4.0 * x("x1") * x("x2");
            assert!((sample.energy - expected).abs() < 1e-9);
            assert!(sample.occurrences >= 1);
        }
    }
}