#[cfg(not(feature = "std"))]
use alloc::{collections::BTreeMap, vec::Vec};
#[cfg(feature = "std")]
use std::collections::BTreeMap;
/// Value domain that the variables of a model or sample are drawn from.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum Domain {
    /// Binary variables (presumably taking values in {0, 1} — confirm).
    Binary,
    /// Spin variables (presumably taking values in {-1, +1} — confirm).
    Spin,
    /// Discrete variables. NOTE(review): the meaning of the `i64` payload
    /// (e.g. number of levels) is not established here — confirm.
    Discrete(i64),
}
/// A sparse model with linear and pairwise (quadratic) integer coefficients.
#[derive(Clone, Debug, PartialEq)]
pub struct XqmxModel {
    /// Domain the model's variables are drawn from.
    pub domain: Domain,
    /// Number of variables; `energy` requires samples of exactly this length.
    pub size: usize,
    /// Linear coefficients keyed by variable index. The public setters
    /// never store a zero coefficient.
    pub(crate) linear: BTreeMap<usize, i64>,
    /// Quadratic coefficients keyed by `(i, j)` with `i <= j`, as normalized
    /// by the public accessors. The setters never store a zero coefficient.
    pub(crate) quadratic: BTreeMap<(usize, usize), i64>,
    /// Row count metadata; `0` after construction via `new`.
    /// NOTE(review): presumably a grid/matrix shape hint — confirm.
    pub rows: usize,
    /// Column count metadata; `0` after construction via `new`.
    /// NOTE(review): presumably a grid/matrix shape hint — confirm.
    pub cols: usize,
}
impl XqmxModel {
pub fn new(domain: Domain, size: usize) -> Self {
Self {
domain,
size,
linear: BTreeMap::new(),
quadratic: BTreeMap::new(),
rows: 0,
cols: 0,
}
}
pub fn get_linear(&self, i: usize) -> i64 {
self.linear.get(&i).copied().unwrap_or(0)
}
pub fn set_linear(&mut self, i: usize, val: i64) {
if val == 0 {
let _ = self.linear.remove(&i);
} else {
let _ = self.linear.insert(i, val);
}
}
pub fn add_linear(&mut self, i: usize, delta: i64) {
let v = self.linear.entry(i).or_insert(0);
*v += delta;
if *v == 0 {
let _ = self.linear.remove(&i);
}
}
pub fn get_quad(&self, i: usize, j: usize) -> i64 {
let key = if i <= j { (i, j) } else { (j, i) };
self.quadratic.get(&key).copied().unwrap_or(0)
}
pub fn set_quad(&mut self, i: usize, j: usize, val: i64) {
let key = if i <= j { (i, j) } else { (j, i) };
if val == 0 {
let _ = self.quadratic.remove(&key);
} else {
let _ = self.quadratic.insert(key, val);
}
}
pub fn add_quad(&mut self, i: usize, j: usize, delta: i64) {
let key = if i <= j { (i, j) } else { (j, i) };
let v = self.quadratic.entry(key).or_insert(0);
*v += delta;
if *v == 0 {
let _ = self.quadratic.remove(&key);
}
}
pub fn linear_len(&self) -> usize {
self.linear.len()
}
pub fn quadratic_len(&self) -> usize {
self.quadratic.len()
}
pub fn iter_linear(&self) -> impl Iterator<Item = (usize, i64)> + '_ {
self.linear.iter().map(|(&i, &v)| (i, v))
}
pub fn iter_quadratic(&self) -> impl Iterator<Item = (usize, usize, i64)> + '_ {
self.quadratic.iter().map(|(&(i, j), &v)| (i, j, v))
}
pub fn energy(&self, sample: &[i64]) -> Result<i64, crate::Error> {
if sample.len() != self.size {
return Err(crate::Error::SizeMismatch {
model_size: self.size,
sample_len: sample.len(),
});
}
let mut h: i64 = 0;
for (&i, &coeff) in &self.linear {
let xi = sample.get(i).copied().ok_or(crate::Error::SizeMismatch {
model_size: self.size,
sample_len: i.saturating_add(1),
})?;
h = h.wrapping_add(coeff.wrapping_mul(xi));
}
for (&(i, j), &coeff) in &self.quadratic {
let xi = sample.get(i).copied().ok_or(crate::Error::SizeMismatch {
model_size: self.size,
sample_len: i.saturating_add(1),
})?;
let xj = sample.get(j).copied().ok_or(crate::Error::SizeMismatch {
model_size: self.size,
sample_len: j.saturating_add(1),
})?;
h = h.wrapping_add(coeff.wrapping_mul(xi).wrapping_mul(xj));
}
Ok(h)
}
}
/// A concrete assignment of integer values to variables, tagged with a domain.
#[derive(Clone, Debug, PartialEq)]
pub struct XqmxSample {
    /// Domain the values are meant to be drawn from.
    pub domain: Domain,
    /// One value per variable, indexed by variable position.
    pub values: Vec<i64>,
    /// Row count metadata; `0` after construction via `new`.
    /// NOTE(review): presumably a grid/matrix shape hint — confirm.
    pub rows: usize,
    /// Column count metadata; `0` after construction via `new`.
    /// NOTE(review): presumably a grid/matrix shape hint — confirm.
    pub cols: usize,
}
impl XqmxSample {
pub fn new(domain: Domain, values: Vec<i64>) -> Self {
Self {
domain,
values,
rows: 0,
cols: 0,
}
}
}