use crate::complex::Complex;
/// Approximation of the golden ratio (about 1.618).
///
/// Used by `DensityMatrixN::free_energy` as the weight on the coherence
/// magnitude when deriving the effective temperature.
const PHI: f32 = 1.618033988;
/// Hashes `seed` with the SplitMix64 finalizer and maps the result to a
/// float in the half-open interval `[0, 1)`.
///
/// The same `seed` always yields the same value.
///
/// Only the top 24 bits of the mixed word are used. That is exactly the
/// precision of an `f32` mantissa, so the integer-to-float conversion is
/// exact and the largest possible result is `1 - 2^-24`. Converting the full
/// 64-bit word instead would round values close to `u64::MAX` up to `2^64`
/// and produce exactly `1.0`, which breaks `r < cumulative` style sampling.
fn splitmix64_f32(seed: u64) -> f32 {
    // Number of random bits that fit losslessly in an f32 mantissa.
    const MANTISSA_BITS: u32 = 24;

    let mut z = seed.wrapping_add(0x9e3779b97f4a7c15);
    z = (z ^ (z >> 30)).wrapping_mul(0xbf58476d1ce4e5b9);
    z = (z ^ (z >> 27)).wrapping_mul(0x94d049bb133111eb);
    z ^= z >> 31;

    // Both operands are exactly representable, so the division is exact.
    (z >> (64 - MANTISSA_BITS)) as f32 / (1u64 << MANTISSA_BITS) as f32
}
/// Square complex matrix of side `dim`, stored as a flat vector together with
/// preallocated scratch buffers.
///
/// Element `(i, j)` is stored at `entries[i * dim + j]`.
#[derive(Clone)]
pub struct DensityMatrixN {
/// Side length of the matrix.
pub dim: usize,
/// Flattened matrix elements; the constructors in this file allocate
/// `dim * dim` of them.
pub entries: Vec<Complex>,
/// Scratch buffer of `dim * dim` elements; `evolve` writes its intermediate
/// product here.
pub(crate) scratch_a: Vec<Complex>,
// Second scratch buffer of the same size. Nothing visible in this file reads
// or writes it after construction, hence the `dead_code` allowance.
#[allow(dead_code)]
pub(crate) scratch_b: Vec<Complex>,
}
impl DensityMatrixN {
    /// Wraps `entries` together with two zero-filled `dim * dim` scratch buffers.
    fn with_scratch(dim: usize, entries: Vec<Complex>) -> Self {
        let n2 = dim * dim;
        Self {
            dim,
            entries,
            scratch_a: vec![Complex::ZERO; n2],
            scratch_b: vec![Complex::ZERO; n2],
        }
    }

    /// Multiplies every off-diagonal entry by `retain`, limited to `[0, 1]`.
    ///
    /// Shared by [`dephase`](Self::dephase) and
    /// [`couple_dephase`](Self::couple_dephase). The upper limit keeps a
    /// negative rate from amplifying coherences instead of damping them.
    // `max` followed by `min` is deliberate: unlike `clamp`, it turns a NaN
    // factor into 0 (full dephasing) rather than spreading NaN into the matrix.
    #[allow(clippy::manual_clamp)]
    fn attenuate_off_diagonal(&mut self, retain: f32) {
        let retain = retain.max(0.0).min(1.0);
        let d = self.dim;
        for i in 0..d {
            for j in 0..d {
                if i != j {
                    let at = i * d + j;
                    self.entries[at] = self.entries[at].scale(retain);
                }
            }
        }
    }

    /// Diagonal matrix with every population equal to `1 / dim`.
    pub fn maximally_mixed(dim: usize) -> Self {
        let p = 1.0 / dim as f32;
        let mut entries = vec![Complex::ZERO; dim * dim];
        // In the flattened layout consecutive diagonal elements are `dim + 1` apart.
        for diagonal in entries.iter_mut().step_by(dim + 1) {
            *diagonal = Complex::new(p, 0.0);
        }
        Self::with_scratch(dim, entries)
    }

    /// Matrix whose only non-zero entry is a one at diagonal position `k`.
    ///
    /// # Panics
    ///
    /// Panics if `k >= dim`.
    pub fn pure_state(k: usize, dim: usize) -> Self {
        assert!(
            k < dim,
            "pure_state: basis index {k} is out of range for dimension {dim}"
        );
        let mut entries = vec![Complex::ZERO; dim * dim];
        entries[k * dim + k] = Complex::ONE;
        Self::with_scratch(dim, entries)
    }

    /// Matrix with every entry equal to `(1 / sqrt(dim))^2`.
    pub fn equal_superposition(dim: usize) -> Self {
        let amp = 1.0 / (dim as f32).sqrt();
        let val = amp * amp;
        let entries = vec![Complex::new(val, 0.0); dim * dim];
        Self::with_scratch(dim, entries)
    }

    /// Real parts of the diagonal entries, in index order.
    pub fn populations(&self) -> Vec<f32> {
        (0..self.dim)
            .map(|k| self.entries[k * self.dim + k].re)
            .collect()
    }

    /// Sum of the norms of the entries strictly above the diagonal.
    pub fn coherence_magnitude(&self) -> f32 {
        let d = self.dim;
        let mut sum = 0.0f32;
        for i in 0..d {
            for j in (i + 1)..d {
                sum += self.entries[i * d + j].norm();
            }
        }
        sum
    }

    /// Sum of the squared norms of all entries.
    pub fn purity(&self) -> f32 {
        let mut sum = 0.0f32;
        for e in &self.entries {
            sum += e.norm_sq();
        }
        sum
    }

    /// Entropy computed from the eigenvalues of the matrix.
    ///
    /// Works on a copy of `entries`, so the matrix itself is not modified.
    pub fn von_neumann_entropy(&self) -> f32 {
        let d = self.dim;
        let mut work = self.entries.clone();
        let mut eigenvalues = vec![0.0f32; d];
        // NOTE(review): `50` and `1e-8` are presumably the iteration cap and the
        // convergence tolerance of the solver; confirm against `dreamwell_math`.
        dreamwell_math::eigen::eigenvalues_hermitian(&mut work, &mut eigenvalues, d, 50, 1e-8);
        dreamwell_math::eigen::von_neumann_entropy(&eigenvalues)
    }

    /// Returns `<H> - T * S`.
    ///
    /// `<H>` is the population-weighted sum of `energies`, `S` is
    /// [`von_neumann_entropy`](Self::von_neumann_entropy), and
    /// `T = 1 / (1 + PHI * coherence_magnitude())`.
    ///
    /// `energies` is paired with the populations by position. If the lengths
    /// differ, the surplus elements of the longer sequence are ignored.
    pub fn free_energy(&self, energies: &[f32]) -> f32 {
        let pops = self.populations();
        let expected_h: f32 = pops.iter().zip(energies.iter()).map(|(p, e)| p * e).sum();
        let temperature = 1.0 / (1.0 + PHI * self.coherence_magnitude());
        let entropy = self.von_neumann_entropy();
        expected_h - temperature * entropy
    }

    /// Scales every off-diagonal entry by `1 - epsilon`, limited to `[0, 1]`.
    ///
    /// `epsilon >= 1` removes all coherence; `epsilon <= 0` leaves the matrix
    /// unchanged.
    pub fn dephase(&mut self, epsilon: f32) {
        self.attenuate_off_diagonal(1.0 - epsilon);
    }

    /// Dephases this matrix by an amount that shrinks as `other` becomes more
    /// coherent.
    ///
    /// The off-diagonal entries are scaled by
    /// `1 - strength * (1 - min(other.coherence_magnitude(), 1))`, limited to
    /// `[0, 1]`.
    pub fn couple_dephase(&mut self, other: &DensityMatrixN, strength: f32) {
        let other_coh = other.coherence_magnitude();
        self.attenuate_off_diagonal(1.0 - strength * (1.0 - other_coh.min(1.0)));
    }

    /// Replaces the matrix in place with `U * rho * U^H`.
    ///
    /// The product `U * rho` is staged in `scratch_a`, then multiplied by the
    /// conjugate transpose of `unitary` back into `entries`.
    /// NOTE(review): `unitary` is presumably a flattened `dim * dim` matrix in
    /// the same layout as `entries`; confirm against `dreamwell_math::linalg`.
    pub fn evolve(&mut self, unitary: &[Complex]) {
        let d = self.dim;
        dreamwell_math::linalg::cgemm(unitary, &self.entries, &mut self.scratch_a, d, d, d);
        dreamwell_math::linalg::cgemm_conj_transpose_b(
            &self.scratch_a,
            unitary,
            &mut self.entries,
            d,
            d,
            d,
        );
    }

    /// Builds the flattened `dim * dim` matrix that
    /// `dreamwell_math::matrix_exp::expm_skew_hermitian` produces from the
    /// real matrix `h` and the step `dt`.
    ///
    /// NOTE(review): presumably the time-evolution operator for Hamiltonian
    /// `h` over `dt`; the sign and scaling convention live in `dreamwell_math`.
    pub fn hamiltonian_unitary(h: &[f32], dim: usize, dt: f32) -> Vec<Complex> {
        let n2 = dim * dim;
        let mut u = vec![Complex::ZERO; n2];
        let mut sa = vec![Complex::ZERO; n2];
        let mut sb = vec![Complex::ZERO; n2];
        dreamwell_math::matrix_exp::expm_skew_hermitian(h, dt, &mut u, &mut sa, &mut sb, dim);
        u
    }

    /// Picks a basis index with probability given by the populations, using a
    /// pseudo-random value derived deterministically from `seed`.
    ///
    /// Negative diagonal real parts count as zero. If the running total never
    /// exceeds the random value (for example because rounding left the trace
    /// slightly below one), the last index with a positive population is
    /// returned, so an impossible outcome is never selected. When no
    /// population is positive, the final index is returned, or `0` when
    /// `dim == 0`.
    pub fn born_sample(&self, seed: u64) -> usize {
        let r = splitmix64_f32(seed);
        let mut cumulative = 0.0f32;
        let mut last_positive = None;
        for k in 0..self.dim {
            let p = self.entries[k * self.dim + k].re.max(0.0);
            if p > 0.0 {
                last_positive = Some(k);
            }
            cumulative += p;
            if r < cumulative {
                return k;
            }
        }
        // `saturating_sub` keeps an empty matrix from underflowing `dim - 1`.
        last_positive.unwrap_or_else(|| self.dim.saturating_sub(1))
    }

    /// Sum of the real parts of the diagonal entries.
    pub fn trace(&self) -> f32 {
        (0..self.dim).map(|k| self.entries[k * self.dim + k].re).sum()
    }
}
/// Locks the process-wide `GpuLinalgContext` slot, creating it on first use.
///
/// The slot is initialised once with `GpuLinalgContext::new(256)` and kept
/// behind a mutex for the rest of the program. Returns `None` when the mutex
/// is poisoned. Otherwise returns the guard, whose inner `Option` is whatever
/// the constructor produced.
/// NOTE(review): the inner `Option` is presumably `None` when no GPU context
/// could be created, and the meaning of `256` is defined by `dreamwell_math`;
/// confirm both there.
fn gpu_linalg_lock() -> Option<std::sync::MutexGuard<'static, Option<dreamwell_math::gpu_linalg::GpuLinalgContext>>> {
    use std::sync::{Mutex, OnceLock};

    type SharedContext = Mutex<Option<dreamwell_math::gpu_linalg::GpuLinalgContext>>;

    static GPU_CTX: OnceLock<SharedContext> = OnceLock::new();

    let shared = GPU_CTX
        .get_or_init(|| Mutex::new(dreamwell_math::gpu_linalg::GpuLinalgContext::new(256)));

    // A poisoned lock is reported as "unavailable" rather than propagated.
    match shared.lock() {
        Ok(guard) => Some(guard),
        Err(_) => None,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// True when `actual` is strictly within `tol` of `expected`.
    fn near(actual: f32, expected: f32, tol: f32) -> bool {
        (actual - expected).abs() < tol
    }

    /// A single-basis-state matrix has unit trace, unit purity and all of its
    /// population on the chosen index.
    #[test]
    fn pure_state_valid() {
        let rho = DensityMatrixN::pure_state(0, 5);
        assert!(near(rho.trace(), 1.0, 1e-6));
        assert!(near(rho.purity(), 1.0, 1e-6));
        assert_eq!(rho.populations()[0], 1.0);
    }

    /// The uniform superposition has unit trace and equal populations.
    #[test]
    fn equal_superposition_valid() {
        let rho = DensityMatrixN::equal_superposition(5);
        assert!(near(rho.trace(), 1.0, 1e-6));
        assert!(rho.populations().iter().all(|&p| near(p, 0.2, 1e-6)));
    }

    /// Dephasing must strictly lower the coherence of a coherent state.
    #[test]
    fn dephasing_reduces_coherence() {
        let mut rho = DensityMatrixN::equal_superposition(5);
        let coh_before = rho.coherence_magnitude();
        rho.dephase(0.5);
        let coh_after = rho.coherence_magnitude();
        assert!(coh_after < coh_before, "Dephasing should reduce coherence");
    }

    /// The maximally mixed state of dimension 5 has purity 1/5.
    #[test]
    fn maximally_mixed_min_purity() {
        let rho = DensityMatrixN::maximally_mixed(5);
        assert!(near(rho.purity(), 0.2, 1e-6));
    }

    /// The free energy of a well-formed state is a finite number.
    #[test]
    fn free_energy_computable() {
        let energies = [0.2f32, 0.3, 0.1, 0.15, 0.25];
        let rho = DensityMatrixN::equal_superposition(5);
        assert!(rho.free_energy(&energies).is_finite());
    }

    /// Sampling a pure state returns its index, repeatably for the same seed.
    #[test]
    fn born_sample_deterministic() {
        let rho = DensityMatrixN::pure_state(2, 5);
        for _ in 0..2 {
            assert_eq!(rho.born_sample(42), 2);
        }
    }

    /// Evolution under a small uniform Hamiltonian keeps the trace near one.
    #[test]
    fn evolve_preserves_trace() {
        let h = [0.1f32; 16];
        let u = DensityMatrixN::hamiltonian_unitary(&h, 4, 0.01);
        let mut rho = DensityMatrixN::equal_superposition(4);
        rho.evolve(&u);
        let tr = rho.trace();
        assert!(near(tr, 1.0, 0.1), "trace={}", tr);
    }

    /// A pure state has (approximately) zero entropy.
    #[test]
    fn pure_state_entropy_is_zero() {
        let s = DensityMatrixN::pure_state(0, 5).von_neumann_entropy();
        assert!(s.abs() < 0.1, "pure state entropy should be ~0, got {s}");
    }

    /// The maximally mixed state of dimension 5 has entropy close to ln(5).
    #[test]
    fn maximally_mixed_entropy_is_maximal() {
        let s = DensityMatrixN::maximally_mixed(5).von_neumann_entropy();
        let expected = (5.0f32).ln();
        assert!(
            near(s, expected, 0.2),
            "maximally mixed entropy should be ~{expected}, got {s}"
        );
    }

    /// Trace preservation at dimension 16, with a diagonal ramp and one
    /// symmetric off-diagonal coupling in the Hamiltonian.
    #[test]
    fn evolve_at_dim16_preserves_trace() {
        let dim = 16;
        let mut h = vec![0.0f32; dim * dim];
        for (i, row) in h.chunks_mut(dim).enumerate() {
            row[i] = (i as f32) * 0.1;
        }
        // Couple states 0 and 1 symmetrically: elements (0, 1) and (1, 0).
        h[1] = 0.2;
        h[dim] = 0.2;

        let u = DensityMatrixN::hamiltonian_unitary(&h, dim, 0.1);
        let mut rho = DensityMatrixN::equal_superposition(dim);
        rho.evolve(&u);
        let tr = rho.trace();
        assert!(
            near(tr, 1.0, 0.15),
            "dim=16 trace should be ~1.0, got {}",
            tr
        );
    }
}