#![allow(dead_code)]
#[cfg(feature = "dwave")]
use quantrs2_symengine_pure::Expression as SymEngineExpression;
#[cfg(feature = "dwave")]
use scirs2_core::ndarray::Array;
use std::fmt::Write;
use thiserror::Error;
/// Errors produced while building symbolic variables from a shape and a
/// name template (format string).
#[derive(Error, Debug)]
pub enum SymbolError {
    /// The number of `{}` placeholders in the format string differs from the
    /// number of dimensions (or bits) requested.
    #[error("Format string must have same number of placeholders as dimensions")]
    FormatMismatch,
    /// The format string contains the substring `}{`, i.e. two placeholders
    /// written back-to-back. Carries the offending format string.
    #[error("Format string must separate placeholders, got {0}")]
    InvalidPlaceholders(String),
    /// More than five dimensions were requested.
    #[error("Currently only up to 5 dimensions are supported")]
    TooManyDimensions,
    /// A symbol could not be created; carries a human-readable reason.
    #[error("Failed to create symbol: {0}")]
    SymbolCreationError(String),
}

/// Shorthand for a `Result` whose error type is [`SymbolError`].
pub type SymbolResult<T> = std::result::Result<T, SymbolError>;
/// Creates a single symbolic variable named `name`.
///
/// Anything viewable as a `&str` (for example `&str` or `String`) is accepted.
#[cfg(feature = "dwave")]
pub fn symbols<T: AsRef<str>>(name: T) -> SymEngineExpression {
    let label: &str = name.as_ref();
    SymEngineExpression::symbol(label)
}

/// Stand-in used when the `dwave` feature is disabled.
///
/// # Panics
///
/// Always panics: symbolic expressions are only available with the `dwave`
/// feature enabled. The argument is accepted solely so the signature mirrors
/// the feature-enabled version.
#[cfg(not(feature = "dwave"))]
#[doc(hidden)]
pub fn symbols<T: AsRef<str>>(_name: T) {
    panic!("The dwave feature is required to use symbolic functionality")
}
/// Builds an n-dimensional array of symbolic variables.
///
/// The element at multi-index `(i0, i1, ...)` is a symbol whose name is
/// `format_txt` with its `{}` placeholders replaced, left to right, by
/// `i0`, `i1`, ...
///
/// # Errors
///
/// Checked in this order; only the first failure is reported:
/// * [`SymbolError::FormatMismatch`] if the number of `{}` placeholders
///   differs from the number of dimensions in `shape`.
/// * [`SymbolError::InvalidPlaceholders`] if `format_txt` contains `}{`
///   (adjacent placeholders).
/// * [`SymbolError::TooManyDimensions`] if `shape` has more than 5 dimensions.
#[cfg(feature = "dwave")]
pub fn symbols_list<T>(
    shape: T,
    format_txt: &str,
) -> SymbolResult<Array<SymEngineExpression, scirs2_core::ndarray::IxDyn>>
where
    T: Into<Vec<usize>>,
{
    let dims: Vec<usize> = shape.into();
    let rank = dims.len();

    let placeholders = format_txt.matches("{}").count();
    if placeholders != rank {
        return Err(SymbolError::FormatMismatch);
    } else if format_txt.contains("}{") {
        return Err(SymbolError::InvalidPlaceholders(format_txt.to_string()));
    } else if rank > 5 {
        return Err(SymbolError::TooManyDimensions);
    }

    // Start from a dummy value in every slot; `fill_symbol_array` then walks
    // each multi-index and stores the named symbol there.
    let mut grid = Array::from_elem(
        scirs2_core::ndarray::IxDyn(&dims),
        SymEngineExpression::int(0),
    );
    let mut cursor = vec![0usize; rank];
    fill_symbol_array(&mut grid, &mut cursor, 0, rank, format_txt)?;
    Ok(grid)
}
/// Recursively visits every multi-index of `array` and stores a freshly
/// named symbol at each position.
///
/// * `indices` — scratch buffer holding the multi-index being built; entries
///   `0..level` have already been fixed by the enclosing recursion frames.
/// * `level` — the axis currently being iterated.
/// * `max_level` — number of axes; once `level` reaches it, `indices` is a
///   complete multi-index and a symbol is written.
/// * `format_txt` — name template whose `{}` placeholders are replaced, left
///   to right, by the entries of `indices`.
///
/// # Errors
///
/// Always returns `Ok(())`; the `Result` return type lets callers use `?`.
///
/// # Panics
///
/// Panics if `indices` is shorter than `max_level` or does not index into
/// `array` (out-of-bounds indexing).
#[cfg(feature = "dwave")]
fn fill_symbol_array(
    array: &mut Array<SymEngineExpression, scirs2_core::ndarray::IxDyn>,
    indices: &mut [usize],
    level: usize,
    max_level: usize,
    format_txt: &str,
) -> SymbolResult<()> {
    if level == max_level {
        // Substitute each index into the next remaining `{}`; no throwaway
        // Vec<String> of formatted indices is built per element.
        let mut symbol_name = format_txt.to_string();
        for index in indices.iter() {
            symbol_name = symbol_name.replacen("{}", &index.to_string(), 1);
        }
        let position = scirs2_core::ndarray::IxDyn(indices);
        array[position] = SymEngineExpression::symbol(&symbol_name);
        return Ok(());
    }
    for i in 0..array.shape()[level] {
        indices[level] = i;
        fill_symbol_array(array, indices, level + 1, max_level, format_txt)?;
    }
    Ok(())
}
/// Generates source text declaring one symbol per element of `shape`.
///
/// For every multi-index the output gets one line of the form
/// `NAME = symbols("NAME");`, where `NAME` is `format_txt` with its `{}`
/// placeholders replaced, left to right, by the index values. Lines are
/// emitted with the last axis varying fastest.
///
/// # Errors
///
/// Checked in this order; only the first failure is reported:
/// * [`SymbolError::FormatMismatch`] if the number of `{}` placeholders
///   differs from the number of dimensions in `shape`.
/// * [`SymbolError::InvalidPlaceholders`] if `format_txt` contains `}{`
///   (adjacent placeholders).
/// * [`SymbolError::TooManyDimensions`] if `shape` has more than 5 dimensions.
#[cfg(feature = "dwave")]
pub fn symbols_define<T>(shape: T, format_txt: &str) -> SymbolResult<String>
where
    T: Into<Vec<usize>>,
{
    let dims: Vec<usize> = shape.into();
    let rank = dims.len();

    let placeholders = format_txt.matches("{}").count();
    if placeholders != rank {
        return Err(SymbolError::FormatMismatch);
    } else if format_txt.contains("}{") {
        return Err(SymbolError::InvalidPlaceholders(format_txt.to_string()));
    } else if rank > 5 {
        return Err(SymbolError::TooManyDimensions);
    }

    let mut script = String::new();
    let mut cursor = vec![0usize; rank];
    generate_symbol_commands(&mut script, &mut cursor, 0, &dims, format_txt)?;
    Ok(script)
}
/// Recursively appends one `NAME = symbols("NAME");` line to `commands` for
/// every multi-index of `shape`, with the last axis varying fastest.
///
/// * `indices` — scratch buffer for the multi-index being built; entries
///   `0..level` have already been fixed by the enclosing recursion frames.
/// * `level` — axis currently being iterated; a line is emitted once it
///   equals `shape.len()`.
/// * `format_txt` — name template whose `{}` placeholders are replaced, left
///   to right, by the entries of `indices`.
///
/// # Errors
///
/// Always returns `Ok(())`; the `Result` return type lets callers use `?`.
///
/// # Panics
///
/// Panics if `indices` is shorter than `shape` (out-of-bounds indexing).
#[allow(dead_code)]
fn generate_symbol_commands(
    commands: &mut String,
    indices: &mut [usize],
    level: usize,
    shape: &[usize],
    format_txt: &str,
) -> SymbolResult<()> {
    if level == shape.len() {
        // Every axis is fixed: substitute each index into the next remaining
        // `{}` without first collecting the indices into a Vec<String>.
        let mut symbol_name = format_txt.to_string();
        for index in indices.iter() {
            symbol_name = symbol_name.replacen("{}", &index.to_string(), 1);
        }
        writeln!(commands, "{symbol_name} = symbols(\"{symbol_name}\");")
            .expect("Writing to string should not fail");
        return Ok(());
    }
    for i in 0..shape[level] {
        indices[level] = i;
        generate_symbol_commands(commands, indices, level + 1, shape, format_txt)?;
    }
    Ok(())
}
/// Encodes a bounded number as a weighted sum of `num` binary symbols.
///
/// The returned expression is
///
/// ```text
/// start + (stop - start) * sum_{n=0}^{num-1} x_n / 2^(n+1)
/// ```
///
/// where `x_n` is the symbol named `format_txt` with its single `{}`
/// placeholder replaced by `n`. Bit 0 carries the largest weight,
/// `(stop - start) / 2`, and each subsequent bit half of the previous one.
/// With binary `x_n`, the expression therefore spans `[start, stop)` on a grid
/// of step `(stop - start) / 2^num`.
///
/// # Errors
///
/// * [`SymbolError::FormatMismatch`] if `format_txt` does not contain exactly
///   one `{}` placeholder.
/// * [`SymbolError::SymbolCreationError`] if `stop < start`, or if `start`
///   does not fit in an `i64`.
#[cfg(feature = "dwave")]
pub fn symbols_nbit(
    start: u64,
    stop: u64,
    format_txt: &str,
    num: usize,
) -> SymbolResult<SymEngineExpression> {
    if format_txt.matches("{}").count() != 1 {
        return Err(SymbolError::FormatMismatch);
    }
    // A plain `stop - start` on u64 panics in debug builds and wraps to an
    // enormous range in release builds when the bounds are reversed.
    let span = stop.checked_sub(start).ok_or_else(|| {
        SymbolError::SymbolCreationError(format!(
            "stop ({stop}) must not be less than start ({start})"
        ))
    })?;
    // `start as i64` would silently wrap to a negative offset above i64::MAX.
    let offset = i64::try_from(start).map_err(|_| {
        SymbolError::SymbolCreationError(format!("start ({start}) does not fit in i64"))
    })?;

    let range = span as f64;
    let mut result = SymEngineExpression::int(0);
    // Bit n has weight range / 2^(n+1). Halving a running weight is exact in
    // binary floating point, and unlike computing 2^num up front it cannot
    // overflow to infinity (which turned every weight into NaN for large num).
    let mut weight = range / 2.0;
    for n in 0..num {
        let bit_name = format_txt.replacen("{}", &n.to_string(), 1);
        let bit = SymEngineExpression::symbol(&bit_name);
        result = result + (SymEngineExpression::from(weight) * bit);
        weight /= 2.0;
    }
    if offset > 0 {
        result = result + SymEngineExpression::int(offset);
    }
    Ok(result)
}
/// Symbolic variable type when the `dwave` feature is enabled: an alias of
/// the SymEngine expression type.
#[cfg(feature = "dwave")]
pub type Symbol = SymEngineExpression;

/// General symbolic expression type when the `dwave` feature is enabled; it
/// is the same type as [`Symbol`].
#[cfg(feature = "dwave")]
pub type Expression = Symbol;
/// Placeholder symbol type used when the `dwave` feature is disabled.
///
/// It only records the symbol's name.
#[cfg(not(feature = "dwave"))]
#[derive(Debug, Clone)]
pub struct Symbol {
    name: String,
}

#[cfg(not(feature = "dwave"))]
impl Symbol {
    /// Creates a placeholder symbol holding an owned copy of `name`.
    pub fn new(name: &str) -> Self {
        let name = String::from(name);
        Self { name }
    }
}
/// Placeholder expression type used when the `dwave` feature is disabled.
///
/// It only records the expression as text.
#[cfg(not(feature = "dwave"))]
#[derive(Debug, Clone)]
pub struct Expression {
    value: String,
}

#[cfg(not(feature = "dwave"))]
impl Expression {
    /// Creates a placeholder expression storing an owned copy of `value`.
    pub fn new(value: &str) -> Self {
        let value = value.to_owned();
        Self { value }
    }
}