use crate::complex::Complex64;
use crate::error::{Error, Result};
/// Dense density matrix for an `n`-qubit register.
///
/// The `dim × dim` entries (with `dim = 1 << n`) are stored row-major in one flat
/// vector: entry `(i, j)` lives at index `i * dim + j`.
#[derive(Debug, Clone, PartialEq)]
pub struct DensityMatrix {
    /// Row-major matrix entries, `dim * dim` of them.
    data: Vec<Complex64>,
    /// Number of qubits; the matrix dimension is `1 << n`.
    n: usize,
}
impl DensityMatrix {
#[must_use]
pub fn zero(n: usize) -> Self {
assert!(
n < 32,
"n={n} would require a {}-element matrix",
1usize << (2 * n)
);
let dim = 1usize << n;
let mut data = vec![Complex64::ZERO; dim * dim];
data[0] = Complex64::ONE; Self { data, n }
}
pub fn from_statevector(amps: &[Complex64]) -> Result<Self> {
let dim = amps.len();
if !dim.is_power_of_two() {
return Err(Error::DimensionMismatch {
len: dim,
expected: dim.next_power_of_two(),
});
}
let n = dim.trailing_zeros() as usize;
let mut data = vec![Complex64::ZERO; dim * dim];
for i in 0..dim {
for j in 0..dim {
data[i * dim + j] = amps[i] * amps[j].conj();
}
}
Ok(Self { data, n })
}
#[inline]
#[must_use]
pub fn num_qubits(&self) -> usize {
self.n
}
#[inline]
#[must_use]
pub fn dim(&self) -> usize {
1usize << self.n
}
#[inline]
#[must_use]
pub fn data(&self) -> &[Complex64] {
&self.data
}
#[inline]
pub(crate) fn data_mut(&mut self) -> &mut [Complex64] {
&mut self.data
}
#[inline]
#[must_use]
pub fn get(&self, i: usize, j: usize) -> Complex64 {
self.data[i * self.dim() + j]
}
#[must_use]
pub fn trace(&self) -> f64 {
let dim = self.dim();
(0..dim).map(|i| self.data[i * dim + i].re).sum()
}
#[must_use]
pub fn purity(&self) -> f64 {
let dim = self.dim();
let mut s = 0.0;
for i in 0..dim {
for j in 0..dim {
s += self.data[i * dim + j].norm_sqr();
}
}
s
}
#[must_use]
pub fn probability(&self, basis: usize) -> f64 {
let dim = self.dim();
assert!(
basis < dim,
"basis state {basis} out of range for {}-qubit register",
self.n
);
self.data[basis * dim + basis].re
}
pub(crate) fn apply_unitary(&mut self, u: &[Complex64], targets: &[usize]) {
let m = targets.len();
let gate_dim = 1usize << m;
debug_assert_eq!(u.len(), gate_dim * gate_dim, "gate matrix size mismatch");
debug_assert!(
targets.iter().all(|&t| t < self.n),
"target qubit out of range"
);
debug_assert!(
(0..targets.len()).all(|i| (i + 1..targets.len()).all(|j| targets[i] != targets[j])),
"duplicate target qubit"
);
let full_dim = self.dim();
let mut out = vec![Complex64::ZERO; full_dim * full_dim];
left_multiply(&self.data, &mut out, u, targets, self.n);
right_multiply_adjoint(&out, &mut self.data, u, targets, self.n);
}
pub(crate) fn apply_channel(&mut self, kraus: &[KrausOp], target: usize) {
debug_assert!(
target < self.n,
"channel target qubit {target} out of range"
);
debug_assert!(!kraus.is_empty(), "Kraus operator list must be non-empty");
let full_dim = self.dim();
let mut acc = vec![Complex64::ZERO; full_dim * full_dim];
for k in kraus {
let u = &k.0;
let mut tmp = vec![Complex64::ZERO; full_dim * full_dim];
left_multiply(&self.data, &mut tmp, u, &[target], self.n);
right_multiply_adjoint_add(&tmp, &mut acc, u, &[target], self.n);
}
self.data = acc;
}
pub(crate) fn renormalize(&mut self) {
let t = self.trace();
if t > 0.0 {
let inv = 1.0 / t;
for x in &mut self.data {
*x *= inv;
}
}
}
#[must_use]
pub fn expectation_pauli(&self, qubit: usize, observable: char) -> f64 {
assert!(qubit < self.n, "qubit {qubit} out of range");
let dim = self.dim();
let bit = 1usize << qubit;
let mut result = Complex64::ZERO;
for i in 0..dim {
for j in 0..dim {
let i_other = i & !bit;
let j_other = j & !bit;
if i_other != j_other {
continue;
}
let iq = (i & bit) >> qubit; let jq = (j & bit) >> qubit;
let p = pauli_element(observable, iq, jq);
if p.re != 0.0 || p.im != 0.0 {
result += p * self.data[j * dim + i]; }
}
}
result.re
}
#[must_use]
pub fn partial_trace_1q(&self, qubit: usize) -> [Complex64; 4] {
assert!(qubit < self.n, "qubit {qubit} out of range");
let dim = self.dim();
let bit = 1usize << qubit;
let mut reduced = [Complex64::ZERO; 4];
for i in 0..dim {
for j in 0..dim {
if (i & !bit) != (j & !bit) {
continue;
}
let iq = (i & bit) >> qubit;
let jq = (j & bit) >> qubit;
reduced[iq * 2 + jq] += self.data[i * dim + j];
}
}
reduced
}
}
/// Returns entry `(i, j)` of the 2×2 Pauli matrix named by `op`.
///
/// For `'I'` and `'X'`, only whether `i == j` matters. For `'Y'` and `'Z'`, only
/// the index pairs with a non-zero entry are listed; every other pair yields zero.
///
/// # Panics
///
/// Panics if `op` is not one of `'I'`, `'X'`, `'Y'`, `'Z'`.
fn pauli_element(op: char, i: usize, j: usize) -> Complex64 {
    let on_diagonal = i == j;
    match (op, i, j) {
        ('I', ..) if on_diagonal => Complex64::ONE,
        ('X', ..) if !on_diagonal => Complex64::ONE,
        ('I' | 'X', ..) => Complex64::ZERO,
        ('Y', 0, 1) => Complex64::new(0.0, -1.0),
        ('Y', 1, 0) => Complex64::new(0.0, 1.0),
        ('Z', 0, 0) => Complex64::ONE,
        ('Z', 1, 1) => Complex64::new(-1.0, 0.0),
        ('Y' | 'Z', ..) => Complex64::ZERO,
        (other, ..) => panic!("unknown Pauli '{other}'; expected I, X, Y, or Z"),
    }
}
/// Computes `out = U · rho`, where `U` acts on the qubits in `targets` and is the
/// identity on all others.
///
/// `rho` and `out` are row-major `2^n × 2^n` matrices. `u` is a row-major
/// `2^m × 2^m` matrix with `m = targets.len()`. Every entry of `out` is overwritten,
/// because each row belongs to exactly one index group and every column is visited.
fn left_multiply(
    rho: &[Complex64],
    out: &mut [Complex64],
    u: &[Complex64],
    targets: &[usize],
    n: usize,
) {
    let full_dim = 1usize << n;
    let gate_dim = 1usize << targets.len();
    let row_groups = full_dim / gate_dim;
    // One scratch buffer for the gathered column slice, reused for every
    // (group, column) pair instead of allocating a Vec each time.
    let mut inputs = vec![Complex64::ZERO; gate_dim];
    for g in 0..row_groups {
        // A group's row indices depend only on `g`. Compute them once per group
        // rather than once per column (`index_group` sorts and allocates).
        let rows = index_group(targets, n, g);
        for col in 0..full_dim {
            for (slot, &row) in inputs.iter_mut().zip(&rows) {
                *slot = rho[row * full_dim + col];
            }
            for (r, &row) in rows.iter().enumerate() {
                let gate_row = &u[r * gate_dim..(r + 1) * gate_dim];
                let mut acc = Complex64::ZERO;
                for (&weight, &x) in gate_row.iter().zip(&inputs) {
                    acc += weight * x;
                }
                out[row * full_dim + col] = acc;
            }
        }
    }
}
/// Computes `rho_out = tmp · U†`, where `U` acts on the qubits in `targets` and is
/// the identity on all others.
///
/// All matrices are row-major. Every entry of `rho_out` is overwritten.
fn right_multiply_adjoint(
    tmp: &[Complex64],
    rho_out: &mut [Complex64],
    u: &[Complex64],
    targets: &[usize],
    n: usize,
) {
    right_multiply_adjoint_with(tmp, rho_out, u, targets, n, |dst, v| *dst = v);
}

/// Computes `acc += tmp · U†`, accumulating into the existing entries of `acc`.
///
/// Used to sum the Kraus terms of a channel.
fn right_multiply_adjoint_add(
    tmp: &[Complex64],
    acc: &mut [Complex64],
    u: &[Complex64],
    targets: &[usize],
    n: usize,
) {
    right_multiply_adjoint_with(tmp, acc, u, targets, n, |dst, v| *dst += v);
}

/// Shared kernel for the two functions above. It computes each entry of
/// `tmp · U†` and hands it to `store`, which decides whether to assign or accumulate.
fn right_multiply_adjoint_with(
    tmp: &[Complex64],
    out: &mut [Complex64],
    u: &[Complex64],
    targets: &[usize],
    n: usize,
    store: impl Fn(&mut Complex64, Complex64),
) {
    let full_dim = 1usize << n;
    let gate_dim = 1usize << targets.len();
    let col_groups = full_dim / gate_dim;
    // Scratch buffer for the gathered row slice, reused across all iterations.
    let mut inputs = vec![Complex64::ZERO; gate_dim];
    for g in 0..col_groups {
        // Column indices depend only on the group, so compute them once per group.
        let cols = index_group(targets, n, g);
        for row in 0..full_dim {
            let row_base = row * full_dim;
            for (slot, &c) in inputs.iter_mut().zip(&cols) {
                *slot = tmp[row_base + c];
            }
            for (c_idx, &col) in cols.iter().enumerate() {
                // (tmp · U†)[row, col] = Σ_k tmp[row, k] · conj(U[col, k])
                let gate_row = &u[c_idx * gate_dim..(c_idx + 1) * gate_dim];
                let mut val = Complex64::ZERO;
                for (&x, &w) in inputs.iter().zip(gate_row) {
                    val += x * w.conj();
                }
                store(&mut out[row_base + col], val);
            }
        }
    }
}
/// Returns the `2^m` full-register indices (`m = targets.len()`) that share the
/// `g`-th assignment of the non-target bits.
///
/// Element `k` of the result has target qubit `targets[b]` set exactly when bit `b`
/// of `k` is set. `_n` is currently unused.
fn index_group(targets: &[usize], _n: usize, g: usize) -> Vec<usize> {
    // Spread the bits of `g` over the non-target positions by inserting a zero bit at
    // each target position. Insert at the lowest position first, so that later
    // insertions see bits that have already been shifted up.
    let mut ascending: Vec<usize> = targets.to_vec();
    ascending.sort_unstable();
    let anchor = ascending.iter().fold(g, |bits, &t| {
        let below = bits & ((1usize << t) - 1);
        let above = (bits >> t) << (t + 1);
        above | below
    });
    // Enumerate every setting of the target bits, following the caller's order of
    // `targets` (not the sorted order).
    (0..1usize << targets.len())
        .map(|local| {
            targets
                .iter()
                .enumerate()
                .filter(|&(b, _)| (local >> b) & 1 == 1)
                .fold(anchor, |idx, (_, &t)| idx | (1usize << t))
        })
        .collect()
}
/// A single-qubit Kraus operator, stored as a row-major 2×2 matrix
/// `[k00, k01, k10, k11]`.
#[derive(Debug, Clone)]
pub struct KrausOp(pub [Complex64; 4]);
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn zero_state_has_trace_one() {
        for n in 1..=4 {
            let rho = DensityMatrix::zero(n);
            assert!((rho.trace() - 1.0).abs() < 1e-12, "n={n}");
        }
    }

    #[test]
    fn from_statevector_matches_zero_state() {
        let amps = vec![Complex64::ONE, Complex64::ZERO];
        // Previously referenced an undefined `s`, so the test module did not compile.
        let rho = DensityMatrix::from_statevector(&amps).unwrap();
        assert!((rho.get(0, 0) - Complex64::ONE).norm() < 1e-12);
        assert!(rho.get(0, 1).norm() < 1e-12);
        assert!(rho.get(1, 0).norm() < 1e-12);
        assert!(rho.get(1, 1).norm() < 1e-12);
    }

    #[test]
    fn from_statevector_rejects_non_power_of_two_lengths() {
        assert!(DensityMatrix::from_statevector(&[Complex64::ONE; 3]).is_err());
        assert!(DensityMatrix::from_statevector(&[]).is_err());
    }

    #[test]
    fn purity_of_pure_state_is_one() {
        let rho = DensityMatrix::zero(2);
        assert!((rho.purity() - 1.0).abs() < 1e-12);
    }

    #[test]
    fn apply_unitary_x_gate_flips_zero_state() {
        let mut rho = DensityMatrix::zero(1);
        let x = [
            Complex64::ZERO,
            Complex64::ONE,
            Complex64::ONE,
            Complex64::ZERO,
        ];
        rho.apply_unitary(&x, &[0]);
        assert!(rho.get(0, 0).norm() < 1e-12);
        assert!((rho.get(1, 1) - Complex64::ONE).norm() < 1e-12);
    }

    #[test]
    fn apply_channel_with_identity_kraus_is_noop() {
        let mut rho = DensityMatrix::zero(2);
        let before = rho.clone();
        let identity = KrausOp([
            Complex64::ONE,
            Complex64::ZERO,
            Complex64::ZERO,
            Complex64::ONE,
        ]);
        rho.apply_channel(&[identity], 1);
        for (a, b) in rho.data().iter().zip(before.data()) {
            assert!((*a - *b).norm() < 1e-12);
        }
    }

    #[test]
    fn pauli_expectations_of_zero_state() {
        let rho = DensityMatrix::zero(2);
        assert!((rho.expectation_pauli(0, 'Z') - 1.0).abs() < 1e-12);
        assert!((rho.expectation_pauli(1, 'Z') - 1.0).abs() < 1e-12);
        assert!((rho.expectation_pauli(1, 'I') - 1.0).abs() < 1e-12);
        assert!(rho.expectation_pauli(0, 'X').abs() < 1e-12);
        assert!(rho.expectation_pauli(0, 'Y').abs() < 1e-12);
    }

    #[test]
    fn partial_trace_of_product_state() {
        use crate::gate::Gate1;
        use crate::kernel::apply_1q;
        let mut amps = vec![
            Complex64::ONE,
            Complex64::ZERO,
            Complex64::ZERO,
            Complex64::ZERO,
        ];
        apply_1q(&mut amps, 0, &Gate1::h());
        // Previously referenced an undefined `s`, so the test module did not compile.
        let rho = DensityMatrix::from_statevector(&amps).unwrap();
        let reduced = rho.partial_trace_1q(0);
        assert!((reduced[0].re - 0.5).abs() < 1e-12, "r00={}", reduced[0]);
        assert!((reduced[3].re - 0.5).abs() < 1e-12, "r11={}", reduced[3]);
        assert!((reduced[1].re - 0.5).abs() < 1e-12, "r01={}", reduced[1]);
    }

    #[test]
    fn index_group_2q_covers_all() {
        let mut seen = [false; 8];
        for g in 0..2 {
            for &idx in &index_group(&[0, 1], 3, g) {
                assert!(!seen[idx], "index {idx} seen twice");
                seen[idx] = true;
            }
        }
        assert!(seen.iter().all(|&b| b), "not all indices covered");
    }
}