// prism_q/sim/probability.rs

/// One factor of a product-form probability distribution.
///
/// `mask` selects which bits (qubits) of a global basis-state index this block
/// covers. `probs` is indexed by those bits packed into a dense local index:
/// the lowest set bit of `mask` becomes bit 0 of the local index, the next set
/// bit becomes bit 1, and so on. Presumably `probs.len() == 1 << mask.count_ones()`
/// and masks of sibling blocks are disjoint — neither is enforced here.
#[derive(Debug, Clone)]
pub struct FactoredBlock {
    /// Marginal probabilities of this block, indexed by the packed local index.
    pub probs: Vec<f64>,
    /// Bit mask over the global index selecting the qubits owned by this block.
    pub mask: u64,
}

14#[derive(Debug, Clone)]
21pub enum Probabilities {
22 Dense(Vec<f64>),
24 Factored {
26 blocks: Vec<FactoredBlock>,
28 total_qubits: usize,
30 },
31}
32
33impl Probabilities {
34 pub fn len(&self) -> usize {
36 match self {
37 Probabilities::Dense(v) => v.len(),
38 Probabilities::Factored { total_qubits, .. } => 1 << total_qubits,
39 }
40 }
41
42 pub fn is_empty(&self) -> bool {
44 false
45 }
46
47 pub fn get(&self, index: usize) -> f64 {
53 match self {
54 Probabilities::Dense(v) => v[index],
55 Probabilities::Factored { blocks, .. } => {
56 let mut p = 1.0;
57 for block in blocks {
58 let local = extract_block_bits(index, block.mask);
59 p *= block.probs[local];
60 }
61 p
62 }
63 }
64 }
65
66 pub fn iter(&self) -> ProbabilitiesIter<'_> {
71 match self {
72 Probabilities::Dense(v) => ProbabilitiesIter {
73 inner: ProbabilitiesIterInner::Dense(v.iter().copied()),
74 },
75 Probabilities::Factored {
76 blocks,
77 total_qubits,
78 } => ProbabilitiesIter {
79 inner: ProbabilitiesIterInner::Factored {
80 blocks,
81 next: 0,
82 len: 1usize << total_qubits,
83 },
84 },
85 }
86 }
87
88 pub fn to_vec(&self) -> Vec<f64> {
91 match self {
92 Probabilities::Dense(v) => v.clone(),
93 Probabilities::Factored {
94 blocks,
95 total_qubits,
96 } => {
97 let n = 1usize << total_qubits;
98 let mut result = vec![0.0f64; n];
99 #[cfg(feature = "parallel")]
100 {
101 const MIN_PAR_STATES: usize = 1 << 14;
102 if n >= MIN_PAR_STATES {
103 use rayon::prelude::*;
104 crate::backend::init_thread_pool();
105 result.par_iter_mut().enumerate().for_each(|(i, slot)| {
106 let mut p = 1.0;
107 for block in blocks {
108 let local = extract_block_bits(i, block.mask);
109 p *= block.probs[local];
110 }
111 *slot = p;
112 });
113 return result;
114 }
115 }
116 for (i, slot) in result.iter_mut().enumerate() {
117 let mut p = 1.0;
118 for block in blocks {
119 let local = extract_block_bits(i, block.mask);
120 p *= block.probs[local];
121 }
122 *slot = p;
123 }
124 result
125 }
126 }
127 }
128}
129
130impl std::ops::Index<usize> for Probabilities {
131 type Output = f64;
132
133 fn index(&self, index: usize) -> &f64 {
139 match self {
140 Probabilities::Dense(v) => &v[index],
141 Probabilities::Factored { .. } => {
142 panic!("cannot index Factored probabilities; use .get(i) or .to_vec()")
143 }
144 }
145 }
146}
147
148pub struct ProbabilitiesIter<'a> {
150 inner: ProbabilitiesIterInner<'a>,
151}
152
153enum ProbabilitiesIterInner<'a> {
154 Dense(std::iter::Copied<std::slice::Iter<'a, f64>>),
155 Factored {
156 blocks: &'a [FactoredBlock],
157 next: usize,
158 len: usize,
159 },
160}
161
162impl Iterator for ProbabilitiesIter<'_> {
163 type Item = f64;
164
165 fn next(&mut self) -> Option<Self::Item> {
166 match &mut self.inner {
167 ProbabilitiesIterInner::Dense(iter) => iter.next(),
168 ProbabilitiesIterInner::Factored { blocks, next, len } => {
169 if *next >= *len {
170 return None;
171 }
172 let index = *next;
173 *next += 1;
174 let mut p = 1.0;
175 for block in *blocks {
176 let local = extract_block_bits(index, block.mask);
177 p *= block.probs[local];
178 }
179 Some(p)
180 }
181 }
182 }
183
184 fn size_hint(&self) -> (usize, Option<usize>) {
185 match &self.inner {
186 ProbabilitiesIterInner::Dense(iter) => iter.size_hint(),
187 ProbabilitiesIterInner::Factored { next, len, .. } => {
188 let remaining = len.saturating_sub(*next);
189 (remaining, Some(remaining))
190 }
191 }
192 }
193}
194
195impl ExactSizeIterator for ProbabilitiesIter<'_> {}
196
/// Gathers the bits of `global_index` selected by `mask` into a dense local index.
///
/// The set bits of `mask` are visited from least to most significant. The bit of
/// `global_index` at the k-th such position becomes bit k of the result. This is
/// the operation of the x86 BMI2 `PEXT` instruction, which is used when the CPU
/// reports BMI2 support at runtime; otherwise a portable bit loop is used.
#[inline]
fn extract_block_bits(global_index: usize, mask: u64) -> usize {
    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("bmi2") {
            // SAFETY: `_pext_u64` requires the BMI2 target feature, which the
            // runtime check directly above has confirmed is available.
            return unsafe { core::arch::x86_64::_pext_u64(global_index as u64, mask) as usize };
        }
    }

    let mut packed = 0usize;
    let mut out_bit = 0u32;
    let mut pending = mask;
    while pending != 0 {
        let src_bit = pending.trailing_zeros();
        if (global_index >> src_bit) & 1 == 1 {
            packed |= 1usize << out_bit;
        }
        out_bit += 1;
        // Clear the lowest set bit so the next iteration sees the next mask position.
        pending &= pending.wrapping_sub(1);
    }
    packed
}