// prism_q/sim/probability.rs

/// One independent subsystem inside a factored probability distribution.
///
/// `mask` selects which global qubit positions the subsystem owns. `probs` is
/// the marginal distribution over those `k` qubits, so it has `2^k` entries.
#[derive(Debug, Clone)]
pub struct FactoredBlock {
    /// Marginal distribution over this subsystem's basis states (`2^k` entries).
    pub probs: Vec<f64>,
    /// Set bits mark the global qubit positions that belong to this subsystem.
    pub mask: u64,
}
13
14/// Probability distribution over computational basis states.
15///
16/// For monolithic simulations this wraps a dense `Vec<f64>` of length 2^n.
17/// For decomposed simulations with independent subsystems, this stores
18/// per-block marginal distributions that are multiplied on demand,
19/// avoiding the O(2^N) Kronecker product unless explicitly requested.
20#[derive(Debug, Clone)]
21pub enum Probabilities {
22    /// Full probability vector of length 2^n.
23    Dense(Vec<f64>),
24    /// Lazy Kronecker product of independent block distributions.
25    Factored {
26        /// Per-block marginal probability vectors and bitmasks.
27        blocks: Vec<FactoredBlock>,
28        /// Total qubit count across all blocks.
29        total_qubits: usize,
30    },
31}
32
33impl Probabilities {
34    /// Number of basis states (2^n).
35    pub fn len(&self) -> usize {
36        match self {
37            Probabilities::Dense(v) => v.len(),
38            Probabilities::Factored { total_qubits, .. } => 1 << total_qubits,
39        }
40    }
41
42    /// Always false, a probability distribution has at least one state.
43    pub fn is_empty(&self) -> bool {
44        false
45    }
46
47    /// Probability of a single computational basis state. O(1) for dense,
48    /// O(K) for factored where K is the number of independent blocks.
49    ///
50    /// # Panics
51    /// Panics if `index >= self.len()`.
52    pub fn get(&self, index: usize) -> f64 {
53        match self {
54            Probabilities::Dense(v) => v[index],
55            Probabilities::Factored { blocks, .. } => {
56                let mut p = 1.0;
57                for block in blocks {
58                    let local = extract_block_bits(index, block.mask);
59                    p *= block.probs[local];
60                }
61                p
62            }
63        }
64    }
65
66    /// Iterate over all basis-state probabilities in order.
67    ///
68    /// For `Dense` this is a direct slice iteration. For `Factored` each
69    /// probability is computed on the fly in O(K) per element.
70    pub fn iter(&self) -> ProbabilitiesIter<'_> {
71        match self {
72            Probabilities::Dense(v) => ProbabilitiesIter {
73                inner: ProbabilitiesIterInner::Dense(v.iter().copied()),
74            },
75            Probabilities::Factored {
76                blocks,
77                total_qubits,
78            } => ProbabilitiesIter {
79                inner: ProbabilitiesIterInner::Factored {
80                    blocks,
81                    next: 0,
82                    len: 1usize << total_qubits,
83                },
84            },
85        }
86    }
87
88    /// Materialize the full probability vector. O(1) clone for dense,
89    /// O(K x 2^N) for factored. Prefer [`Probabilities::get`] for spot-checking.
90    pub fn to_vec(&self) -> Vec<f64> {
91        match self {
92            Probabilities::Dense(v) => v.clone(),
93            Probabilities::Factored {
94                blocks,
95                total_qubits,
96            } => {
97                let n = 1usize << total_qubits;
98                let mut result = vec![0.0f64; n];
99                #[cfg(feature = "parallel")]
100                {
101                    const MIN_PAR_STATES: usize = 1 << 14;
102                    if n >= MIN_PAR_STATES {
103                        use rayon::prelude::*;
104                        crate::backend::init_thread_pool();
105                        result.par_iter_mut().enumerate().for_each(|(i, slot)| {
106                            let mut p = 1.0;
107                            for block in blocks {
108                                let local = extract_block_bits(i, block.mask);
109                                p *= block.probs[local];
110                            }
111                            *slot = p;
112                        });
113                        return result;
114                    }
115                }
116                for (i, slot) in result.iter_mut().enumerate() {
117                    let mut p = 1.0;
118                    for block in blocks {
119                        let local = extract_block_bits(i, block.mask);
120                        p *= block.probs[local];
121                    }
122                    *slot = p;
123                }
124                result
125            }
126        }
127    }
128}
129
130impl std::ops::Index<usize> for Probabilities {
131    type Output = f64;
132
133    /// Index into a dense probability vector.
134    ///
135    /// Only works for `Dense`. Panics on `Factored` because `Index` must
136    /// return `&f64` and factored values are computed, not stored.
137    /// Use [`Probabilities::get`] or [`Probabilities::iter`] instead.
138    fn index(&self, index: usize) -> &f64 {
139        match self {
140            Probabilities::Dense(v) => &v[index],
141            Probabilities::Factored { .. } => {
142                panic!("cannot index Factored probabilities; use .get(i) or .to_vec()")
143            }
144        }
145    }
146}
147
148/// Concrete iterator for [`Probabilities::iter`].
149pub struct ProbabilitiesIter<'a> {
150    inner: ProbabilitiesIterInner<'a>,
151}
152
153enum ProbabilitiesIterInner<'a> {
154    Dense(std::iter::Copied<std::slice::Iter<'a, f64>>),
155    Factored {
156        blocks: &'a [FactoredBlock],
157        next: usize,
158        len: usize,
159    },
160}
161
162impl Iterator for ProbabilitiesIter<'_> {
163    type Item = f64;
164
165    fn next(&mut self) -> Option<Self::Item> {
166        match &mut self.inner {
167            ProbabilitiesIterInner::Dense(iter) => iter.next(),
168            ProbabilitiesIterInner::Factored { blocks, next, len } => {
169                if *next >= *len {
170                    return None;
171                }
172                let index = *next;
173                *next += 1;
174                let mut p = 1.0;
175                for block in *blocks {
176                    let local = extract_block_bits(index, block.mask);
177                    p *= block.probs[local];
178                }
179                Some(p)
180            }
181        }
182    }
183
184    fn size_hint(&self) -> (usize, Option<usize>) {
185        match &self.inner {
186            ProbabilitiesIterInner::Dense(iter) => iter.size_hint(),
187            ProbabilitiesIterInner::Factored { next, len, .. } => {
188                let remaining = len.saturating_sub(*next);
189                (remaining, Some(remaining))
190            }
191        }
192    }
193}
194
195impl ExactSizeIterator for ProbabilitiesIter<'_> {}
196
/// Extract the bits of `global_index` at the positions set in `mask`, packing
/// them into contiguous low bits (a software "parallel bit extract").
///
/// For example, with `mask = 0b1010` the result is `(bit3 << 1) | bit1` of
/// `global_index`. On x86_64 with BMI2 available at runtime, this uses the
/// `pext` instruction. Otherwise it falls back to a portable loop.
#[inline]
fn extract_block_bits(global_index: usize, mask: u64) -> usize {
    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("bmi2") {
            // SAFETY: BMI2 availability is checked immediately before this call.
            return unsafe { core::arch::x86_64::_pext_u64(global_index as u64, mask) as usize };
        }
    }
    extract_block_bits_portable(global_index, mask)
}

/// Portable fallback for [`extract_block_bits`].
///
/// All shifting is done in `u64`: mask positions can reach 63, and shifting a
/// `usize` by 32 or more would overflow on 32-bit targets.
#[inline]
fn extract_block_bits_portable(global_index: usize, mask: u64) -> usize {
    // usize is at most 64 bits on supported targets, so this cast only widens.
    let index = global_index as u64;
    let mut packed = 0u64;
    let mut out_bit = 0u32;
    let mut remaining = mask;
    while remaining != 0 {
        let pos = remaining.trailing_zeros();
        packed |= ((index >> pos) & 1) << out_bit;
        out_bit += 1;
        // Clear the lowest set bit (remaining != 0, so this cannot underflow).
        remaining &= remaining - 1;
    }
    packed as usize
}