use crate::BoxError;
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct DensityRegion {
params: Vec<f32>,
dim: usize,
}
impl DensityRegion {
    /// Builds a region from `dim` complex amplitudes stored as interleaved `f32` pairs.
    ///
    /// `params` is laid out as `[re_0, im_0, re_1, im_1, ...]` and must therefore hold exactly
    /// `2 * dim` values. The amplitudes need not be normalized.
    ///
    /// # Errors
    ///
    /// * [`BoxError::DimensionMismatch`] if `params.len() != 2 * dim`.
    /// * [`BoxError::InvalidBounds`] if any value is NaN or infinite; `dim` is the index of the
    ///   offending complex amplitude (`i / 2`) and `min`/`max` both carry the bad value.
    /// * [`BoxError::InvalidBounds`] with `dim: 0, min: 0.0, max: 0.0` if the squared norm is
    ///   below `1e-12` (this includes `dim == 0`), since such a vector cannot be normalized.
    pub fn new(params: Vec<f32>, dim: usize) -> Result<Self, BoxError> {
        // Plain `2 * dim` panics in debug builds (and wraps in release) for huge `dim`;
        // a saturated `usize::MAX` can never equal the length of a real `Vec<f32>`.
        let expected = dim.saturating_mul(2);
        if params.len() != expected {
            return Err(BoxError::DimensionMismatch {
                expected,
                actual: params.len(),
            });
        }
        for (i, &p) in params.iter().enumerate() {
            if !p.is_finite() {
                return Err(BoxError::InvalidBounds {
                    dim: i / 2,
                    min: f64::from(p),
                    max: f64::from(p),
                });
            }
        }
        if squared_norm(&params) < 1e-12 {
            return Err(BoxError::InvalidBounds {
                dim: 0,
                min: 0.0,
                max: 0.0,
            });
        }
        Ok(Self { params, dim })
    }

    /// Number of complex amplitudes.
    #[must_use]
    pub fn dim(&self) -> usize {
        self.dim
    }

    /// Raw interleaved `[re, im, ...]` parameters; length `2 * dim()`.
    #[must_use]
    pub fn params(&self) -> &[f32] {
        &self.params
    }

    /// Squared Euclidean norm `Σ p²` of the raw, unnormalized parameters.
    #[must_use]
    pub fn squared_norm(&self) -> f32 {
        squared_norm(&self.params)
    }

    /// Trace of the represented state; always `1.0` because the amplitudes are treated as
    /// the normalized vector `ψ / ‖ψ‖`.
    #[must_use]
    pub fn trace(&self) -> f32 {
        1.0
    }

    /// Returns `-ln(dim)`.
    ///
    /// This depends only on the dimension, not on the amplitudes, and is finite because
    /// construction rejects `dim == 0`.
    #[must_use]
    pub fn log_volume(&self) -> f32 {
        -(self.dim as f32).ln()
    }
}
/// Squared Euclidean norm of a flat parameter slice: the sum of each entry squared.
fn squared_norm(params: &[f32]) -> f32 {
    params
        .iter()
        .copied()
        .map(|value| value * value)
        .sum::<f32>()
}
/// Hermitian inner product `⟨a|b⟩ = Σ_k conj(a_k) · b_k` of two interleaved complex vectors.
///
/// Both slices store complex numbers as `[re_0, im_0, re_1, im_1, ...]`; the result is
/// returned as `(re, im)`. Callers must pass slices of equal, even length. This is checked
/// only in debug builds; in release builds, pairs beyond the shorter slice and a trailing
/// odd element are ignored.
fn complex_inner_product(a: &[f32], b: &[f32]) -> (f32, f32) {
    debug_assert_eq!(a.len(), b.len());
    debug_assert_eq!(a.len() % 2, 0, "interleaved complex slice must have even length");
    // `chunks_exact(2)` yields one complex number per step without per-element bounds checks.
    a.chunks_exact(2)
        .zip(b.chunks_exact(2))
        .fold((0.0f32, 0.0f32), |(re, im), (za, zb)| {
            let (a_re, a_im) = (za[0], za[1]);
            let (b_re, b_im) = (zb[0], zb[1]);
            // conj(a) * b = (a_re*b_re + a_im*b_im) + i*(a_re*b_im - a_im*b_re)
            (
                re + (a_re * b_re + a_im * b_im),
                im + (a_re * b_im - a_im * b_re),
            )
        })
}
/// Fidelity `F = |⟨a|b⟩|² / (‖a‖² · ‖b‖²)` between the pure states represented by `a` and `b`.
///
/// `F` has the following properties:
/// * It is symmetric.
/// * It is invariant to rescaling or re-phasing either argument.
/// * It lies in `[0, 1]`: `1` when the amplitude vectors are parallel, `0` when they are
///   orthogonal.
///
/// # Errors
///
/// Returns [`BoxError::DimensionMismatch`] (with `expected = a.dim()`, `actual = b.dim()`)
/// if the two regions have different dimensions.
pub fn fidelity(a: &DensityRegion, b: &DensityRegion) -> Result<f32, BoxError> {
    if a.dim != b.dim {
        return Err(BoxError::DimensionMismatch {
            expected: a.dim,
            actual: b.dim,
        });
    }
    let (re, im) = complex_inner_product(&a.params, &b.params);
    // Divide the inner product by ‖a‖·‖b‖ *before* squaring. Forming ‖a‖²·‖b‖² and
    // |⟨a|b⟩|² directly overflows f32 well before the amplitudes do (roughly 1e10 per
    // entry), which turns the ratio into inf / inf = NaN.
    let scale = a.squared_norm().sqrt() * b.squared_norm().sqrt();
    let (cos_re, cos_im) = (re / scale, im / scale);
    // Cauchy–Schwarz bounds the exact value by 1, but f32 rounding can overshoot slightly.
    // `clamp` keeps a NaN as NaN rather than masking it.
    Ok((cos_re * cos_re + cos_im * cos_im).clamp(0.0, 1.0))
}
/// Penalty for `child` not being contained in `parent`: `(1 − F)²`, where `F` is [`fidelity`].
///
/// The loss is `0` when the two states coincide up to scale and phase, and `1` when they are
/// orthogonal. The deficit `1 − F` is floored at zero. Because `f32::max` ignores NaN, a NaN
/// fidelity produces a loss of `0.0`.
///
/// # Errors
///
/// Propagates [`BoxError::DimensionMismatch`] from [`fidelity`].
pub fn subsumption_loss(child: &DensityRegion, parent: &DensityRegion) -> Result<f32, BoxError> {
    fidelity(child, parent).map(|overlap| {
        let gap = f32::max(1.0 - overlap, 0.0);
        gap * gap
    })
}
/// Penalty for overlap between `a` and `b`: `F²`, where `F` is [`fidelity`].
///
/// The loss is `0` for orthogonal states and `1` for identical ones.
///
/// # Errors
///
/// Propagates [`BoxError::DimensionMismatch`] from [`fidelity`].
pub fn disjointness_loss(a: &DensityRegion, b: &DensityRegion) -> Result<f32, BoxError> {
    fidelity(a, b).map(|overlap| overlap * overlap)
}
/// Squared Bures distance `D_B² = 2 · (1 − √F)`, where `F` is [`fidelity`].
///
/// The distance ranges over `[0, 2]`: `0` for identical states, `2` for orthogonal ones.
///
/// # Errors
///
/// Propagates [`BoxError::DimensionMismatch`] from [`fidelity`].
pub fn bures_distance_sq(a: &DensityRegion, b: &DensityRegion) -> Result<f32, BoxError> {
    let f = fidelity(a, b)?;
    // Rounding can push `f` marginally above 1, which would make the squared distance
    // negative. Clamp it, mirroring the `max(0.0)` guard in `trace_distance`.
    // `clamp` leaves a NaN as NaN.
    Ok(2.0 * (1.0 - f.clamp(0.0, 1.0).sqrt()))
}
/// Trace distance between pure states, `√(1 − F)`, where `F` is [`fidelity`].
///
/// The distance ranges over `[0, 1]`. The radicand is floored at zero so rounding cannot
/// produce a NaN from `sqrt`. Because `f32::max` ignores NaN, a NaN fidelity yields `0.0`.
///
/// # Errors
///
/// Propagates [`BoxError::DimensionMismatch`] from [`fidelity`].
pub fn trace_distance(a: &DensityRegion, b: &DensityRegion) -> Result<f32, BoxError> {
    fidelity(a, b).map(|overlap| f32::max(1.0 - overlap, 0.0).sqrt())
}
#[cfg(test)]
mod tests {
    //! Hand-computed unit tests for the fidelity-derived metrics and constructor validation.
    use super::*;

    /// Region whose amplitudes are the given reals (every imaginary part is zero).
    fn make_real(v: &[f32]) -> DensityRegion {
        let interleaved: Vec<f32> = v.iter().flat_map(|&x| [x, 0.0]).collect();
        DensityRegion::new(interleaved, v.len()).unwrap()
    }

    /// Region built from explicit `(re, im)` amplitude pairs.
    fn make_complex(pairs: &[(f32, f32)]) -> DensityRegion {
        let interleaved: Vec<f32> = pairs.iter().flat_map(|&(re, im)| [re, im]).collect();
        DensityRegion::new(interleaved, pairs.len()).unwrap()
    }

    #[test]
    fn fidelity_identical() {
        let e0 = make_real(&[1.0, 0.0, 0.0]);
        let f = fidelity(&e0, &e0).unwrap();
        assert!((f - 1.0).abs() < 1e-6, "fidelity(a, a) = {f}, expected 1.0");
    }

    #[test]
    fn fidelity_orthogonal() {
        let (e0, e1) = (make_real(&[1.0, 0.0, 0.0]), make_real(&[0.0, 1.0, 0.0]));
        let f = fidelity(&e0, &e1).unwrap();
        assert!(f.abs() < 1e-6, "fidelity(orthogonal) = {f}, expected 0.0");
    }

    #[test]
    fn fidelity_scale_invariant() {
        // Same direction, three times the length.
        let unit = make_real(&[1.0, 0.0]);
        let scaled = make_real(&[3.0, 0.0]);
        let f = fidelity(&unit, &scaled).unwrap();
        assert!(
            (f - 1.0).abs() < 1e-6,
            "fidelity should be scale-invariant: {f}"
        );
    }

    #[test]
    fn fidelity_complex() {
        // conj(1, i) . (1, -i) = 1 + (-i)(-i) = 0, so these are orthogonal.
        let plus = make_complex(&[(1.0, 0.0), (0.0, 1.0)]);
        let minus = make_complex(&[(1.0, 0.0), (0.0, -1.0)]);
        let f = fidelity(&plus, &minus).unwrap();
        assert!(f.abs() < 1e-6, "fidelity((1,i),(1,-i)) = {f}, expected 0.0");
    }

    #[test]
    fn fidelity_partial_overlap() {
        // cos^2(45 degrees) = 0.5
        let x_axis = make_real(&[1.0, 0.0]);
        let diagonal = make_real(&[1.0, 1.0]);
        let f = fidelity(&x_axis, &diagonal).unwrap();
        assert!(
            (f - 0.5).abs() < 1e-6,
            "fidelity partial overlap = {f}, expected 0.5"
        );
    }

    #[test]
    fn subsumption_loss_identical_is_zero() {
        let v = make_real(&[1.0, 2.0, 3.0]);
        let loss = subsumption_loss(&v, &v).unwrap();
        assert!(loss.abs() < 1e-6, "subsumption_loss(a, a) = {loss}");
    }

    #[test]
    fn subsumption_loss_orthogonal_is_one() {
        let (e0, e1) = (make_real(&[1.0, 0.0]), make_real(&[0.0, 1.0]));
        let loss = subsumption_loss(&e0, &e1).unwrap();
        assert!(
            (loss - 1.0).abs() < 1e-6,
            "subsumption_loss(orthogonal) = {loss}, expected 1.0"
        );
    }

    #[test]
    fn disjointness_loss_orthogonal_is_zero() {
        let (e0, e1) = (make_real(&[1.0, 0.0, 0.0]), make_real(&[0.0, 1.0, 0.0]));
        let loss = disjointness_loss(&e0, &e1).unwrap();
        assert!(loss.abs() < 1e-6, "disjointness(orthogonal) = {loss}");
    }

    #[test]
    fn disjointness_loss_identical_is_one() {
        let e0 = make_real(&[1.0, 0.0]);
        let loss = disjointness_loss(&e0, &e0).unwrap();
        assert!(
            (loss - 1.0).abs() < 1e-6,
            "disjointness(a, a) = {loss}, expected 1.0"
        );
    }

    #[test]
    fn bures_distance_identical_is_zero() {
        let v = make_real(&[1.0, 2.0]);
        let d = bures_distance_sq(&v, &v).unwrap();
        assert!(d.abs() < 1e-6, "bures_distance_sq(a, a) = {d}");
    }

    #[test]
    fn bures_distance_orthogonal_is_two() {
        let (e0, e1) = (make_real(&[1.0, 0.0]), make_real(&[0.0, 1.0]));
        let d = bures_distance_sq(&e0, &e1).unwrap();
        assert!(
            (d - 2.0).abs() < 1e-6,
            "bures_distance_sq(orthogonal) = {d}, expected 2.0"
        );
    }

    #[test]
    fn trace_distance_bounds() {
        let (e0, e1) = (make_real(&[1.0, 0.0]), make_real(&[0.0, 1.0]));
        let t = trace_distance(&e0, &e1).unwrap();
        assert!(
            (t - 1.0).abs() < 1e-6,
            "trace_distance(orthogonal) = {t}, expected 1.0"
        );
        let t_self = trace_distance(&e0, &e0).unwrap();
        assert!(t_self.abs() < 1e-6, "trace_distance(a, a) = {t_self}");
    }

    #[test]
    fn rejects_zero_vector() {
        assert!(DensityRegion::new(vec![0.0; 6], 3).is_err());
    }

    #[test]
    fn rejects_dimension_mismatch() {
        // Three values cannot describe three complex amplitudes; six are required.
        assert!(DensityRegion::new(vec![1.0, 0.0, 0.0], 3).is_err());
    }

    #[test]
    fn rejects_non_finite() {
        for bad in [f32::NAN, f32::INFINITY] {
            assert!(DensityRegion::new(vec![bad, 0.0, 1.0, 0.0], 2).is_err());
        }
    }

    #[test]
    fn trace_is_one() {
        assert_eq!(make_real(&[3.0, 4.0, 5.0]).trace(), 1.0);
    }

    #[test]
    fn dimension_mismatch_error() {
        let two = make_real(&[1.0, 0.0]);
        let three = make_real(&[1.0, 0.0, 0.0]);
        assert!(fidelity(&two, &three).is_err());
        assert!(subsumption_loss(&two, &three).is_err());
        assert!(disjointness_loss(&two, &three).is_err());
        assert!(bures_distance_sq(&two, &three).is_err());
        assert!(trace_distance(&two, &three).is_err());
    }
}
#[cfg(test)]
mod proptests {
    //! Property-based checks that the metrics respect their mathematical invariants
    //! (value ranges, symmetry, self-fidelity) on randomly generated regions.
    use super::*;
    use proptest::prelude::*;

    /// Random region of dimension `dim` with every parameter drawn from `-10.0..10.0`.
    ///
    /// Samples rejected by `DensityRegion::new` are filtered out. The values are always finite
    /// and the length is always `2 * dim`, so only the near-zero-norm check can reject a sample.
    fn arb_density(dim: usize) -> impl Strategy<Value = DensityRegion> {
        prop::collection::vec(-10.0f32..10.0, 2 * dim)
            .prop_filter_map("non-zero vector", move |params| {
                DensityRegion::new(params, dim).ok()
            })
    }

    /// Two independently generated regions sharing the same dimension.
    fn arb_density_pair(dim: usize) -> impl Strategy<Value = (DensityRegion, DensityRegion)> {
        (arb_density(dim), arb_density(dim))
    }

    // The small tolerances below absorb f32 rounding in the norm and inner-product sums.
    proptest! {
        #[test]
        fn prop_fidelity_in_unit_interval(
            (a, b) in arb_density_pair(4)
        ) {
            let f = fidelity(&a, &b).unwrap();
            prop_assert!(f >= -1e-6, "fidelity should be >= 0, got {f}");
            prop_assert!(f <= 1.0 + 1e-6, "fidelity should be <= 1, got {f}");
        }
        #[test]
        fn prop_self_fidelity_is_one(
            a in arb_density(4)
        ) {
            let f = fidelity(&a, &a).unwrap();
            prop_assert!((f - 1.0).abs() < 1e-5, "fidelity(a, a) = {f}, expected 1.0");
        }
        #[test]
        fn prop_fidelity_symmetric(
            (a, b) in arb_density_pair(4)
        ) {
            // Swapping the arguments conjugates the inner product; its modulus is unchanged.
            let f_ab = fidelity(&a, &b).unwrap();
            let f_ba = fidelity(&b, &a).unwrap();
            prop_assert!(
                (f_ab - f_ba).abs() < 1e-5,
                "fidelity should be symmetric: {f_ab} != {f_ba}"
            );
        }
        #[test]
        fn prop_subsumption_loss_nonneg(
            (a, b) in arb_density_pair(4)
        ) {
            let loss = subsumption_loss(&a, &b).unwrap();
            prop_assert!(loss >= -1e-6, "subsumption_loss should be >= 0, got {loss}");
        }
        #[test]
        fn prop_disjointness_loss_nonneg(
            (a, b) in arb_density_pair(4)
        ) {
            let loss = disjointness_loss(&a, &b).unwrap();
            prop_assert!(loss >= -1e-6, "disjointness_loss should be >= 0, got {loss}");
        }
        #[test]
        fn prop_bures_distance_nonneg(
            (a, b) in arb_density_pair(4)
        ) {
            // Looser tolerance: 2 * (1 - sqrt(F)) can dip slightly below zero if rounding
            // pushes F above 1.
            let d = bures_distance_sq(&a, &b).unwrap();
            prop_assert!(d >= -1e-5, "bures_distance_sq should be >= 0, got {d}");
        }
        #[test]
        fn prop_trace_distance_in_unit_interval(
            (a, b) in arb_density_pair(4)
        ) {
            let t = trace_distance(&a, &b).unwrap();
            prop_assert!(t >= -1e-6, "trace_distance should be >= 0, got {t}");
            prop_assert!(t <= 1.0 + 1e-6, "trace_distance should be <= 1, got {t}");
        }
        #[test]
        fn prop_bures_distance_symmetric(
            (a, b) in arb_density_pair(4)
        ) {
            let d_ab = bures_distance_sq(&a, &b).unwrap();
            let d_ba = bures_distance_sq(&b, &a).unwrap();
            prop_assert!(
                (d_ab - d_ba).abs() < 1e-5,
                "bures_distance should be symmetric: {d_ab} != {d_ba}"
            );
        }
    }
}