/// One factor of a [`Probabilities::Factored`] distribution.
///
/// The probability contributed by this block for a global basis index `i` is
/// `probs[extract_block_bits(i, mask)]`, i.e. the bits of `i` selected by
/// `mask`, packed into the low bits in ascending order, index into `probs`.
#[derive(Debug, Clone)]
pub struct FactoredBlock {
    /// Local probability table indexed by the packed masked bits.
    /// NOTE(review): presumably has `2^mask.count_ones()` entries; the length
    /// is not checked here, so a shorter table panics on lookup — confirm.
    pub probs: Vec<f64>,
    /// Selects which bits of a global basis index belong to this block.
    /// NOTE(review): masks of different blocks are presumably disjoint; this is
    /// not enforced here — confirm against the constructor.
    pub mask: u64,
}
/// Probabilities over basis-state indices, stored in one of two layouts.
#[derive(Debug, Clone)]
pub enum Probabilities {
    /// One explicit probability per basis index.
    Dense(Vec<f64>),
    /// Probabilities computed on demand: the value for a global index is the
    /// product of each block's local probability (see `Probabilities::get`).
    Factored {
        /// The factors whose local probabilities are multiplied together.
        blocks: Vec<FactoredBlock>,
        /// Number of index bits; the represented length is `1 << total_qubits`.
        total_qubits: usize,
    },
}
impl Probabilities {
pub fn len(&self) -> usize {
match self {
Probabilities::Dense(v) => v.len(),
Probabilities::Factored { total_qubits, .. } => 1 << total_qubits,
}
}
pub fn is_empty(&self) -> bool {
false
}
pub fn get(&self, index: usize) -> f64 {
match self {
Probabilities::Dense(v) => v[index],
Probabilities::Factored { blocks, .. } => {
let mut p = 1.0;
for block in blocks {
let local = extract_block_bits(index, block.mask);
p *= block.probs[local];
}
p
}
}
}
pub fn iter(&self) -> ProbabilitiesIter<'_> {
match self {
Probabilities::Dense(v) => ProbabilitiesIter {
inner: ProbabilitiesIterInner::Dense(v.iter().copied()),
},
Probabilities::Factored {
blocks,
total_qubits,
} => ProbabilitiesIter {
inner: ProbabilitiesIterInner::Factored {
blocks,
next: 0,
len: 1usize << total_qubits,
},
},
}
}
pub fn to_vec(&self) -> Vec<f64> {
match self {
Probabilities::Dense(v) => v.clone(),
Probabilities::Factored {
blocks,
total_qubits,
} => {
let n = 1usize << total_qubits;
let mut result = vec![0.0f64; n];
#[cfg(feature = "parallel")]
{
const MIN_PAR_STATES: usize = 1 << 14;
if n >= MIN_PAR_STATES {
use rayon::prelude::*;
crate::backend::init_thread_pool();
result.par_iter_mut().enumerate().for_each(|(i, slot)| {
let mut p = 1.0;
for block in blocks {
let local = extract_block_bits(i, block.mask);
p *= block.probs[local];
}
*slot = p;
});
return result;
}
}
for (i, slot) in result.iter_mut().enumerate() {
let mut p = 1.0;
for block in blocks {
let local = extract_block_bits(i, block.mask);
p *= block.probs[local];
}
*slot = p;
}
result
}
}
}
}
impl std::ops::Index<usize> for Probabilities {
type Output = f64;
fn index(&self, index: usize) -> &f64 {
match self {
Probabilities::Dense(v) => &v[index],
Probabilities::Factored { .. } => {
panic!("cannot index Factored probabilities; use .get(i) or .to_vec()")
}
}
}
}
/// Iterator over the values of a `Probabilities`, in ascending index order.
///
/// Created by `Probabilities::iter`. For the factored layout each value is
/// computed lazily when yielded, without materialising the full vector.
#[must_use = "iterators are lazy and do nothing unless consumed"]
pub struct ProbabilitiesIter<'a> {
    inner: ProbabilitiesIterInner<'a>,
}
/// Backing state for `ProbabilitiesIter`, one variant per storage layout.
enum ProbabilitiesIterInner<'a> {
    /// Copies values straight out of the dense vector.
    Dense(std::iter::Copied<std::slice::Iter<'a, f64>>),
    /// Computes each value on demand from the factored blocks.
    Factored {
        /// Blocks whose local probabilities are multiplied together.
        blocks: &'a [FactoredBlock],
        /// Next global index to yield.
        next: usize,
        /// One past the last index to yield (`1 << total_qubits` when built by `iter`).
        len: usize,
    },
}
impl Iterator for ProbabilitiesIter<'_> {
    type Item = f64;

    /// Yields the probability of the next index in ascending order.
    fn next(&mut self) -> Option<Self::Item> {
        match &mut self.inner {
            ProbabilitiesIterInner::Dense(iter) => iter.next(),
            ProbabilitiesIterInner::Factored { blocks, next, len } => {
                if *next >= *len {
                    return None;
                }
                let index = *next;
                *next += 1;
                // Joint value = product of every block's local probability.
                let mut p = 1.0;
                for block in *blocks {
                    let local = extract_block_bits(index, block.mask);
                    p *= block.probs[local];
                }
                Some(p)
            }
        }
    }

    /// Skips `n` values and yields the one after, in O(1) for both layouts.
    ///
    /// The default `nth` would call `next` `n` times, computing and discarding
    /// a full block product for every skipped factored index.
    fn nth(&mut self, n: usize) -> Option<Self::Item> {
        match &mut self.inner {
            ProbabilitiesIterInner::Dense(iter) => iter.nth(n),
            ProbabilitiesIterInner::Factored { next, len, .. } => {
                // Clamp so the cursor never passes `len` (keeps size_hint exact).
                *next = (*next).saturating_add(n).min(*len);
                self.next()
            }
        }
    }

    /// Exact remaining count for both layouts.
    fn size_hint(&self) -> (usize, Option<usize>) {
        match &self.inner {
            ProbabilitiesIterInner::Dense(iter) => iter.size_hint(),
            ProbabilitiesIterInner::Factored { next, len, .. } => {
                let remaining = len.saturating_sub(*next);
                (remaining, Some(remaining))
            }
        }
    }
}
impl ExactSizeIterator for ProbabilitiesIter<'_> {}
/// Gathers the bits of `global_index` selected by `mask` and packs them, in
/// ascending bit order, into the low bits of the result (a parallel bit
/// extract).
///
/// Example: with `mask = 0b1010`, bit 1 of `global_index` becomes bit 0 of the
/// result and bit 3 becomes bit 1.
///
/// On x86_64 CPUs that report BMI2 at runtime this uses the `pext` instruction;
/// otherwise it uses a portable loop over the set bits of `mask`.
#[inline]
fn extract_block_bits(global_index: usize, mask: u64) -> usize {
    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("bmi2") {
            // SAFETY: `_pext_u64` requires the BMI2 target feature, whose
            // presence on the running CPU was verified immediately above.
            return unsafe { core::arch::x86_64::_pext_u64(global_index as u64, mask) as usize };
        }
    }
    extract_block_bits_portable(global_index, mask)
}

/// Portable software fallback for `extract_block_bits`.
///
/// All bit arithmetic is done in `u64`, so mask bits at positions at or above
/// `usize::BITS` (possible on 32-bit targets) cannot overflow a shift; such
/// bits read as zero from `global_index`.
#[inline]
fn extract_block_bits_portable(global_index: usize, mask: u64) -> usize {
    let index = global_index as u64;
    let mut result = 0u64;
    let mut out_bit = 0u32;
    let mut remaining = mask;
    while remaining != 0 {
        let pos = remaining.trailing_zeros();
        result |= ((index >> pos) & 1) << out_bit;
        out_bit += 1;
        // Clear the lowest set bit; `remaining != 0`, so this cannot underflow.
        remaining &= remaining - 1;
    }
    result as usize
}