use crate::types::Complex;
use rand::Rng;
use std::collections::HashMap;
/// Per-device calibration data from which per-gate noise models are derived
/// (see [`EnhancedNoiseModel::from_calibration`]).
#[derive(Debug, Clone)]
pub struct DeviceCalibration {
    /// Per-qubit T1 times, indexed by qubit number. `from_calibration` divides gate times by
    /// these values, so they must use the same time unit as `gate_times`.
    pub qubit_t1: Vec<f64>,
    /// Per-qubit T2 times, indexed by qubit number (same time unit as `qubit_t1`).
    pub qubit_t2: Vec<f64>,
    /// Per-qubit readout error pairs, indexed by qubit number. `from_calibration` copies the
    /// entry for the requested qubit unchanged into the model. Presumably `(p01, p10)` as
    /// consumed by `apply_readout_error` / `ReadoutCorrector` — TODO confirm.
    pub readout_error: Vec<(f64, f64)>,
    /// Error rate per gate name; `from_calibration` uses it as the depolarizing rate
    /// (0.0 when the gate is missing).
    pub gate_errors: HashMap<String, f64>,
    /// Duration per gate name; a missing or non-positive entry disables the T1/T2-derived
    /// channels in `from_calibration`.
    pub gate_times: HashMap<String, f64>,
    /// Qubit connectivity pairs. Not read by `from_calibration`; presumably
    /// (control, target) or undirected edges — verify against callers.
    pub coupling_map: Vec<(u32, u32)>,
}
/// Parameters of a combined T1/T2 relaxation channel acting for one gate duration; the
/// matching Kraus operators come from `thermal_relaxation_kraus(t1, t2, gate_time)`.
/// All three values share one time unit.
#[derive(Debug, Clone, Copy)]
pub struct ThermalRelaxation {
    /// Energy-relaxation time T1.
    pub t1: f64,
    /// Coherence time T2.
    pub t2: f64,
    /// Duration of the gate during which relaxation acts.
    pub gate_time: f64,
}
/// A collection of noise-channel parameters; `None` / `0.0` means the channel is disabled
/// (this is what `Default` produces).
#[derive(Debug, Clone)]
pub struct EnhancedNoiseModel {
    /// Single-qubit depolarizing probability (presumably fed to `depolarizing_kraus`).
    pub depolarizing_rate: f64,
    /// Two-qubit depolarizing probability; `from_calibration` always sets this to 0.0.
    pub two_qubit_depolarizing_rate: f64,
    /// Amplitude-damping strength γ (presumably fed to `amplitude_damping_kraus`).
    pub amplitude_damping_gamma: Option<f64>,
    /// Phase-damping strength λ (presumably fed to `phase_damping_kraus`).
    pub phase_damping_lambda: Option<f64>,
    /// Readout error pair for this qubit; presumably `(p01, p10)` — confirm with consumers.
    pub readout_error: Option<(f64, f64)>,
    /// Combined T1/T2 relaxation parameters.
    pub thermal_relaxation: Option<ThermalRelaxation>,
    /// ZZ crosstalk strength; units/semantics are not visible in this module.
    /// `from_calibration` always leaves it `None`.
    pub crosstalk_zz: Option<f64>,
}
/// The noiseless model: both depolarizing rates are zero and every optional channel is off.
impl Default for EnhancedNoiseModel {
    fn default() -> Self {
        Self {
            crosstalk_zz: None,
            thermal_relaxation: None,
            readout_error: None,
            phase_damping_lambda: None,
            amplitude_damping_gamma: None,
            two_qubit_depolarizing_rate: 0.0,
            depolarizing_rate: 0.0,
        }
    }
}
impl EnhancedNoiseModel {
pub fn from_calibration(cal: &DeviceCalibration, gate_name: &str, qubit: u32) -> Self {
let idx = qubit as usize;
let depolarizing_rate = cal
.gate_errors
.get(gate_name)
.copied()
.unwrap_or(0.0);
let gate_time = cal
.gate_times
.get(gate_name)
.copied()
.unwrap_or(0.0);
let t1 = cal.qubit_t1.get(idx).copied().unwrap_or(f64::INFINITY);
let t2 = cal.qubit_t2.get(idx).copied().unwrap_or(f64::INFINITY);
let amplitude_damping_gamma = if t1.is_finite() && t1 > 0.0 && gate_time > 0.0 {
Some(1.0 - (-gate_time / t1).exp())
} else {
None
};
let phase_damping_lambda = if t2.is_finite() && t2 > 0.0 && gate_time > 0.0 {
let inv_t_phi = (1.0 / t2) - (1.0 / (2.0 * t1));
if inv_t_phi > 0.0 {
Some(1.0 - (-gate_time * inv_t_phi).exp())
} else {
None
}
} else {
None
};
let readout_error = cal.readout_error.get(idx).copied();
let thermal_relaxation =
if t1.is_finite() && t2.is_finite() && t1 > 0.0 && t2 > 0.0 && gate_time > 0.0 {
Some(ThermalRelaxation {
t1,
t2,
gate_time,
})
} else {
None
};
Self {
depolarizing_rate,
two_qubit_depolarizing_rate: 0.0,
amplitude_damping_gamma,
phase_damping_lambda,
readout_error,
thermal_relaxation,
crosstalk_zz: None,
}
}
}
/// The 2×2 identity matrix. `thermal_relaxation_kraus` returns it as the sole Kraus operator
/// of the identity (no-op) channel.
const IDENTITY: [[Complex; 2]; 2] = [
    [Complex::ONE, Complex::ZERO],
    [Complex::ZERO, Complex::ONE],
];
/// Kraus operators of the single-qubit depolarizing channel
/// `ρ ↦ (1 − p)·ρ + (p/3)·(XρX + YρY + ZρZ)`.
///
/// Returns `[√(1−p)·I, √(p/3)·X, √(p/3)·Y, √(p/3)·Z]`, in that order. `p` is clamped to
/// `[0, 1]` (NaN is treated as 0), so the returned set always satisfies Σ K†K = I.
pub fn depolarizing_kraus(p: f64) -> Vec<[[Complex; 2]; 2]> {
    // Outside [0, 1] the operators above are not trace-preserving. `f64::max` returns the
    // non-NaN operand, so NaN collapses to 0 (the identity channel).
    let p = p.max(0.0).min(1.0);
    let s0 = (1.0 - p).sqrt();
    let sp = (p / 3.0).sqrt();
    let c = |v: f64| Complex::new(v, 0.0);
    let k0 = [
        [c(s0), Complex::ZERO],
        [Complex::ZERO, c(s0)],
    ];
    // √(p/3)·X
    let k1 = [
        [Complex::ZERO, c(sp)],
        [c(sp), Complex::ZERO],
    ];
    // √(p/3)·Y, with Y = [[0, −i], [i, 0]]
    let k2 = [
        [Complex::ZERO, Complex::new(0.0, -sp)],
        [Complex::new(0.0, sp), Complex::ZERO],
    ];
    // √(p/3)·Z
    let k3 = [
        [c(sp), Complex::ZERO],
        [Complex::ZERO, c(-sp)],
    ];
    vec![k0, k1, k2, k3]
}
/// Kraus operators of the amplitude-damping channel with decay probability `gamma`:
/// `K0 = [[1, 0], [0, √(1−γ)]]` and `K1 = [[0, √γ], [0, 0]]`. `K1` maps |1⟩ to |0⟩.
///
/// `gamma` is clamped to `[0, 1]` once (NaN is treated as 0), so both operators see the
/// same value and the set is always trace-preserving.
pub fn amplitude_damping_kraus(gamma: f64) -> Vec<[[Complex; 2]; 2]> {
    // Clamping only inside one sqrt left γ < 0 producing √(1−γ) > 1 (not trace-preserving).
    let gamma = gamma.max(0.0).min(1.0);
    let sg = gamma.sqrt();
    let s1g = (1.0 - gamma).sqrt();
    let c = |v: f64| Complex::new(v, 0.0);
    let k0 = [
        [Complex::ONE, Complex::ZERO],
        [Complex::ZERO, c(s1g)],
    ];
    let k1 = [
        [Complex::ZERO, c(sg)],
        [Complex::ZERO, Complex::ZERO],
    ];
    vec![k0, k1]
}
/// Kraus operators of the phase-damping channel with strength `lambda`:
/// `K0 = [[1, 0], [0, √(1−λ)]]` and `K1 = [[0, 0], [0, √λ]]`. Populations are unchanged;
/// off-diagonal coherences are scaled by `√(1−λ)`.
///
/// `lambda` is clamped to `[0, 1]` once (NaN is treated as 0), so the set is always
/// trace-preserving.
pub fn phase_damping_kraus(lambda: f64) -> Vec<[[Complex; 2]; 2]> {
    // Clamp once so that √λ and √(1−λ) are computed from the same value.
    let lambda = lambda.max(0.0).min(1.0);
    let sl = lambda.sqrt();
    let s1l = (1.0 - lambda).sqrt();
    let c = |v: f64| Complex::new(v, 0.0);
    let k0 = [
        [Complex::ONE, Complex::ZERO],
        [Complex::ZERO, c(s1l)],
    ];
    let k1 = [
        [Complex::ZERO, Complex::ZERO],
        [Complex::ZERO, c(sl)],
    ];
    vec![k0, k1]
}
/// Kraus operators for combined T1/T2 thermal relaxation during a gate of length `gate_time`.
///
/// The channel is amplitude damping with `γ = 1 − exp(−t/T1)` composed with phase damping
/// whose `λ` is chosen so that coherences decay by exactly `exp(−t/T2)` in total. T2 is capped
/// at the physical limit `2·T1`; a non-positive T2 adds no dephasing.
///
/// Returns `[IDENTITY]` when `gate_time` or `t1` is not strictly positive (including NaN).
/// Otherwise it returns the four products `ad · pd` over every amplitude-damping operator `ad`
/// and phase-damping operator `pd`.
pub fn thermal_relaxation_kraus(t1: f64, t2: f64, gate_time: f64) -> Vec<[[Complex; 2]; 2]> {
    // Positive comparisons (rather than `<= 0.0`) so that NaN also selects the identity.
    let has_relaxation = gate_time > 0.0 && t1 > 0.0;
    if !has_relaxation {
        return vec![IDENTITY];
    }
    let gamma = 1.0 - (-gate_time / t1).exp();
    let t2_eff = t2.min(2.0 * t1);
    // Pure-dephasing rate from 1/T2 = 1/(2·T1) + 1/T_phi.
    let inv_t_phi = if t2_eff > 0.0 {
        (1.0 / t2_eff) - (1.0 / (2.0 * t1))
    } else {
        0.0
    };
    // Amplitude damping already scales coherences by √(1−γ) = exp(−t/(2·T1)). Phase damping
    // scales them by √(1−λ), which must equal exp(−t/T_phi), so λ = 1 − exp(−2·t/T_phi).
    let lambda = if inv_t_phi > 0.0 {
        1.0 - (-2.0 * gate_time * inv_t_phi).exp()
    } else {
        0.0
    };
    let ad_ops = amplitude_damping_kraus(gamma);
    let pd_ops = phase_damping_kraus(lambda);
    let mut combined = Vec::with_capacity(ad_ops.len() * pd_ops.len());
    for ad in &ad_ops {
        for pd in &pd_ops {
            combined.push(mat_mul_2x2(ad, pd));
        }
    }
    combined
}
/// Samples a noisy readout of `outcome`.
///
/// Exactly one uniform `f64` is drawn from `rng`. A true 0 is reported as 1 with probability
/// `p01`, and a true 1 is reported as 0 with probability `p10`. Returns the reported bit.
pub fn apply_readout_error(outcome: bool, p01: f64, p10: f64, rng: &mut impl Rng) -> bool {
    let draw: f64 = rng.gen();
    let flip_probability = if outcome { p10 } else { p01 };
    let flipped = draw < flip_probability;
    // XOR: a flip inverts the true outcome, no flip leaves it as is.
    outcome != flipped
}
/// Readout-error mitigation from per-qubit confusion rates.
#[derive(Debug, Clone)]
pub struct ReadoutCorrector {
    /// Per-qubit `(p01, p10)`: the probability of measuring 1 when the true bit is 0, and of
    /// measuring 0 when the true bit is 1 (see `build_confusion_matrix`). Index = qubit.
    readout_errors: Vec<(f64, f64)>,
    /// Number of qubits; set to `readout_errors.len()` by `new`.
    num_qubits: usize,
}
impl ReadoutCorrector {
    /// Creates a corrector from per-qubit `(p01, p10)` readout error rates (index = qubit).
    pub fn new(readout_errors: &[(f64, f64)]) -> Self {
        Self {
            readout_errors: readout_errors.to_vec(),
            num_qubits: readout_errors.len(),
        }
    }

    /// Mitigates readout errors in a histogram of measured bitstrings.
    ///
    /// Bit `q` of each key is the measured value of qubit `q`. The result maps bitstrings
    /// (length `num_qubits`) to estimated true counts. Negative estimates are clipped to 0,
    /// and values `<= 1e-10` are omitted.
    ///
    /// - With no qubits configured, the counts are returned unchanged (as `f64`).
    /// - Up to 12 qubits, the dense 2^n × 2^n confusion system is solved exactly.
    /// - Beyond 12 qubits, per-qubit 2×2 inverses are applied one qubit at a time over a
    ///   sparse map. Intermediate negatives are clipped, so this path is approximate.
    ///
    /// Keys shorter than `num_qubits` are treated as zero-padded and longer keys are truncated.
    /// Keys that coincide after this normalisation have their counts summed.
    pub fn correct_counts(
        &self,
        counts: &HashMap<Vec<bool>, usize>,
    ) -> HashMap<Vec<bool>, f64> {
        if self.num_qubits == 0 {
            return counts
                .iter()
                .map(|(k, &v)| (k.clone(), v as f64))
                .collect();
        }
        if self.num_qubits <= 12 {
            self.correct_full_matrix(counts)
        } else {
            self.correct_tensor_product(counts)
        }
    }

    /// Exact mitigation: solves `confusion · true = measured` over all 2^n basis states.
    fn correct_full_matrix(
        &self,
        counts: &HashMap<Vec<bool>, usize>,
    ) -> HashMap<Vec<bool>, f64> {
        let n = self.num_qubits;
        let dim = 1usize << n;
        let confusion = self.build_confusion_matrix(dim, n);
        let mut raw_vec = vec![0.0f64; dim];
        for (bits, &count) in counts {
            // Accumulate rather than overwrite: distinct keys (e.g. of different lengths)
            // can map to the same basis index.
            raw_vec[bits_to_index(bits, n)] += count as f64;
        }
        let corrected_vec = solve_linear_system(&confusion, &raw_vec, dim);
        corrected_vec
            .iter()
            .enumerate()
            .filter_map(|(i, &raw)| {
                let val = raw.max(0.0);
                (val > 1e-10).then(|| (index_to_bits(i, n), val))
            })
            .collect()
    }

    /// Approximate, sparse mitigation: applies each qubit's inverse 2×2 confusion matrix in turn.
    fn correct_tensor_product(
        &self,
        counts: &HashMap<Vec<bool>, usize>,
    ) -> HashMap<Vec<bool>, f64> {
        let n = self.num_qubits;
        let inv_matrices: Vec<[[f64; 2]; 2]> = self
            .readout_errors
            .iter()
            .map(|&(p01, p10)| invert_2x2_confusion(p01, p10))
            .collect();

        // Normalise keys to exactly `n` bits (pad with `false` / truncate), matching how
        // `bits_to_index` reads keys in the full-matrix path. This also prevents an
        // out-of-bounds panic on short keys below.
        let mut corrected: HashMap<Vec<bool>, f64> = HashMap::with_capacity(counts.len());
        for (bits, &count) in counts {
            let mut key = bits.clone();
            key.resize(n, false);
            *corrected.entry(key).or_insert(0.0) += count as f64;
        }

        for (q, inv) in inv_matrices.iter().enumerate() {
            // Group entries into (bit q = 0, bit q = 1) pairs, keyed by the string with bit q
            // cleared. A missing partner contributes 0. Keys are unique, so each slot is
            // written at most once.
            let mut pairs: HashMap<Vec<bool>, (f64, f64)> =
                HashMap::with_capacity(corrected.len());
            for (mut bits, val) in corrected {
                let is_one = bits[q];
                bits[q] = false;
                let slot = pairs.entry(bits).or_insert((0.0, 0.0));
                if is_one {
                    slot.1 = val;
                } else {
                    slot.0 = val;
                }
            }

            let mut next: HashMap<Vec<bool>, f64> = HashMap::with_capacity(2 * pairs.len());
            for (base, (val_0, val_1)) in pairs {
                let new_0 = inv[0][0] * val_0 + inv[0][1] * val_1;
                let new_1 = inv[1][0] * val_0 + inv[1][1] * val_1;
                // Clipping negatives and dropping near-zeros keeps the support from growing
                // toward 2^n entries, at the cost of exactness.
                if new_1.abs() > 1e-10 {
                    let mut bits_1 = base.clone();
                    bits_1[q] = true;
                    next.insert(bits_1, new_1.max(0.0));
                }
                if new_0.abs() > 1e-10 {
                    next.insert(base, new_0.max(0.0));
                }
            }
            corrected = next;
        }

        // Clipping at the last qubit can leave explicit 0.0 entries. Drop them so the output
        // matches the full-matrix path, which only reports values > 1e-10.
        corrected.retain(|_, v| *v > 1e-10);
        corrected
    }

    /// Builds the dense confusion matrix `C[measured][true] = Π_q P(measured bit q | true bit q)`.
    /// Basis index bit `q` corresponds to qubit `q`.
    fn build_confusion_matrix(&self, dim: usize, n: usize) -> Vec<Vec<f64>> {
        let mut confusion = vec![vec![0.0f64; dim]; dim];
        for true_state in 0..dim {
            for measured_state in 0..dim {
                let mut prob = 1.0;
                for q in 0..n {
                    let true_bit = (true_state >> q) & 1;
                    let meas_bit = (measured_state >> q) & 1;
                    let (p01, p10) = self.readout_errors[q];
                    prob *= match (true_bit, meas_bit) {
                        (0, 0) => 1.0 - p01,
                        (0, 1) => p01,
                        (1, 0) => p10,
                        (1, 1) => 1.0 - p10,
                        _ => unreachable!(),
                    };
                }
                confusion[measured_state][true_state] = prob;
            }
        }
        confusion
    }
}
/// Returns the 2×2 complex matrix product `a · b`.
fn mat_mul_2x2(
    a: &[[Complex; 2]; 2],
    b: &[[Complex; 2]; 2],
) -> [[Complex; 2]; 2] {
    // Entry (row, col) is the dot product of row `row` of `a` with column `col` of `b`.
    let dot = |row: usize, col: usize| a[row][0] * b[0][col] + a[row][1] * b[1][col];
    [[dot(0, 0), dot(0, 1)], [dot(1, 0), dot(1, 1)]]
}
/// Conjugate transpose (Hermitian adjoint) of a 2×2 complex matrix; used only by tests.
#[cfg(test)]
fn dagger_2x2(m: &[[Complex; 2]; 2]) -> [[Complex; 2]; 2] {
    // Entry (row, col) of m† is the conjugate of entry (col, row) of m.
    let adj = |row: usize, col: usize| m[col][row].conj();
    [[adj(0, 0), adj(0, 1)], [adj(1, 0), adj(1, 1)]]
}
/// Packs the first `min(n, bits.len())` bits into an index, with `bits[q]` as bit `q`
/// (little-endian). Missing bits count as 0; bits beyond `n` are ignored.
fn bits_to_index(bits: &[bool], n: usize) -> usize {
    bits.iter()
        .take(n)
        .enumerate()
        .filter(|&(_, &set)| set)
        .fold(0usize, |acc, (q, _)| acc | (1usize << q))
}
/// Unpacks the low `n` bits of `idx` into a vector where element `q` is bit `q`
/// (little-endian); inverse of `bits_to_index` for length-`n` inputs.
fn index_to_bits(idx: usize, n: usize) -> Vec<bool> {
    let mut bits = Vec::with_capacity(n);
    for q in 0..n {
        bits.push((idx >> q) & 1 == 1);
    }
    bits
}
/// Inverts one qubit's confusion matrix `M[measured][true]`:
///
/// ```text
/// [[1 - p01, p10    ],
///  [p01,     1 - p10]]
/// ```
///
/// Falls back to the identity when `|det| < 1e-15` (readout carries no information,
/// i.e. p01 + p10 ≈ 1).
fn invert_2x2_confusion(p01: f64, p10: f64) -> [[f64; 2]; 2] {
    let (m00, m01) = (1.0 - p01, p10);
    let (m10, m11) = (p01, 1.0 - p10);
    let det = m00 * m11 - m01 * m10;
    if det.abs() < 1e-15 {
        [[1.0, 0.0], [0.0, 1.0]]
    } else {
        let r = 1.0 / det;
        // Adjugate divided by the determinant.
        [[m11 * r, -m01 * r], [-m10 * r, m00 * r]]
    }
}
/// Solves `a · x = b` for the leading `dim × dim` system by Gaussian elimination with
/// partial pivoting.
///
/// Columns whose pivot magnitude is below `1e-15` are skipped during elimination. The
/// matching unknown is set to 0, so singular systems yield one particular solution
/// rather than panicking.
fn solve_linear_system(a: &[Vec<f64>], b: &[f64], dim: usize) -> Vec<f64> {
    const PIVOT_EPS: f64 = 1e-15;

    // Augmented matrix [A | b], one row per equation.
    let mut m: Vec<Vec<f64>> = (0..dim)
        .map(|r| {
            let mut row = Vec::with_capacity(dim + 1);
            row.extend_from_slice(&a[r]);
            row.push(b[r]);
            row
        })
        .collect();

    // Forward elimination.
    for k in 0..dim {
        // Pivot: the first row at or below k holding the strictly largest |entry| in column k.
        let (best, _) = ((k + 1)..dim).fold((k, m[k][k].abs()), |(best_row, best_val), r| {
            let v = m[r][k].abs();
            if v > best_val {
                (r, v)
            } else {
                (best_row, best_val)
            }
        });
        if best != k {
            m.swap(k, best);
        }
        let p = m[k][k];
        if p.abs() < PIVOT_EPS {
            continue;
        }
        let (upper, lower) = m.split_at_mut(k + 1);
        let pivot_row = &upper[k];
        for row in lower.iter_mut() {
            let f = row[k] / p;
            for (dst, &src) in row[k..=dim].iter_mut().zip(&pivot_row[k..=dim]) {
                *dst -= f * src;
            }
        }
    }

    // Back substitution.
    let mut x = vec![0.0f64; dim];
    for k in (0..dim).rev() {
        let p = m[k][k];
        if p.abs() < PIVOT_EPS {
            x[k] = 0.0;
            continue;
        }
        let rhs = ((k + 1)..dim).fold(m[k][dim], |acc, c| acc - m[k][c] * x[c]);
        x[k] = rhs / p;
    }
    x
}
#[cfg(test)]
mod tests {
    // Unit tests for the Kraus-operator builders, readout-error sampling and mitigation,
    // calibration-derived noise models, and the small linear-algebra helpers.
    use super::*;
    use rand::rngs::StdRng;
    use rand::SeedableRng;

    /// Asserts the completeness relation Σ_k K_k† K_k = I (element-wise, within `tol`),
    /// i.e. that the Kraus set describes a trace-preserving channel.
    fn assert_trace_preserving(ops: &[[[Complex; 2]; 2]], tol: f64) {
        let mut sum = [[Complex::ZERO; 2]; 2];
        for k in ops {
            let kdag = dagger_2x2(k);
            let prod = mat_mul_2x2(&kdag, k);
            for r in 0..2 {
                for c in 0..2 {
                    sum[r][c] = sum[r][c] + prod[r][c];
                }
            }
        }
        assert!(
            (sum[0][0].re - 1.0).abs() < tol,
            "sum[0][0] = {:?}, expected 1.0",
            sum[0][0]
        );
        assert!(
            sum[0][0].im.abs() < tol,
            "sum[0][0].im = {}, expected 0.0",
            sum[0][0].im
        );
        assert!(
            sum[0][1].re.abs() < tol && sum[0][1].im.abs() < tol,
            "sum[0][1] = {:?}, expected 0.0",
            sum[0][1]
        );
        assert!(
            sum[1][0].re.abs() < tol && sum[1][0].im.abs() < tol,
            "sum[1][0] = {:?}, expected 0.0",
            sum[1][0]
        );
        assert!(
            (sum[1][1].re - 1.0).abs() < tol,
            "sum[1][1] = {:?}, expected 1.0",
            sum[1][1]
        );
        assert!(
            sum[1][1].im.abs() < tol,
            "sum[1][1].im = {}, expected 0.0",
            sum[1][1].im
        );
    }

    #[test]
    fn depolarizing_kraus_trace_preserving() {
        for &p in &[0.0, 0.01, 0.1, 0.5, 1.0] {
            let ops = depolarizing_kraus(p);
            assert_trace_preserving(&ops, 1e-12);
        }
    }

    // p = 0: K0 must be the identity and the three Pauli terms must vanish.
    #[test]
    fn depolarizing_p0_is_identity() {
        let ops = depolarizing_kraus(0.0);
        assert_eq!(ops.len(), 4);
        let k0 = &ops[0];
        assert!((k0[0][0].re - 1.0).abs() < 1e-14);
        assert!((k0[1][1].re - 1.0).abs() < 1e-14);
        assert!(k0[0][1].norm_sq() < 1e-28);
        assert!(k0[1][0].norm_sq() < 1e-28);
        for k in &ops[1..] {
            for r in 0..2 {
                for c in 0..2 {
                    assert!(
                        k[r][c].norm_sq() < 1e-28,
                        "Non-zero element in zero Kraus op: {:?}",
                        k[r][c]
                    );
                }
            }
        }
    }

    #[test]
    fn amplitude_damping_kraus_trace_preserving() {
        for &gamma in &[0.0, 0.01, 0.1, 0.5, 0.99, 1.0] {
            let ops = amplitude_damping_kraus(gamma);
            assert_trace_preserving(&ops, 1e-12);
        }
    }

    // γ = 1: K1 maps |1⟩ fully onto |0⟩ and K0 annihilates |1⟩.
    #[test]
    fn amplitude_damping_gamma1_decays_one_to_zero() {
        let ops = amplitude_damping_kraus(1.0);
        assert_eq!(ops.len(), 2);
        assert!((ops[0][0][0].re - 1.0).abs() < 1e-14);
        assert!(ops[0][1][1].norm_sq() < 1e-28);
        assert!((ops[1][0][1].re - 1.0).abs() < 1e-14);
        assert!(ops[1][1][0].norm_sq() < 1e-28);
        assert!(ops[1][1][1].norm_sq() < 1e-28);
        let state_one = [Complex::ZERO, Complex::ONE];
        // K1 · |1⟩, computed by hand as a matrix-vector product.
        let k1_on_one = [
            ops[1][0][0] * state_one[0] + ops[1][0][1] * state_one[1],
            ops[1][1][0] * state_one[0] + ops[1][1][1] * state_one[1],
        ];
        assert!((k1_on_one[0].re - 1.0).abs() < 1e-14, "Expected |0> component = 1.0");
        assert!(k1_on_one[1].norm_sq() < 1e-28, "Expected |1> component = 0.0");
    }

    #[test]
    fn phase_damping_kraus_trace_preserving() {
        for &lambda in &[0.0, 0.01, 0.1, 0.5, 1.0] {
            let ops = phase_damping_kraus(lambda);
            assert_trace_preserving(&ops, 1e-12);
        }
    }

    // λ = 0: K0 is the identity and K1 is the zero matrix.
    #[test]
    fn phase_damping_lambda0_is_identity() {
        let ops = phase_damping_kraus(0.0);
        assert_eq!(ops.len(), 2);
        assert!((ops[0][0][0].re - 1.0).abs() < 1e-14);
        assert!((ops[0][1][1].re - 1.0).abs() < 1e-14);
        for r in 0..2 {
            for c in 0..2 {
                assert!(ops[1][r][c].norm_sq() < 1e-28);
            }
        }
    }

    #[test]
    fn thermal_relaxation_kraus_trace_preserving() {
        // (t1, t2, gate_time)
        let test_cases = [
            // T2 < 2·T1: both amplitude and phase damping active.
            (50.0, 30.0, 0.05),
            (50.0, 50.0, 0.05),
            // T2 = 2·T1: pure-dephasing rate is 0, so only amplitude damping acts.
            (50.0, 100.0, 0.05),
            // Long gate relative to T1/T2.
            (100.0, 80.0, 1.0),
            // Very short gate.
            (50.0, 30.0, 0.001),
        ];
        for &(t1, t2, gt) in &test_cases {
            let ops = thermal_relaxation_kraus(t1, t2, gt);
            assert_trace_preserving(&ops, 1e-10);
        }
    }

    // A zero-length gate short-circuits to the single identity operator.
    #[test]
    fn thermal_relaxation_zero_gate_time_is_identity() {
        let ops = thermal_relaxation_kraus(50.0, 30.0, 0.0);
        assert_eq!(ops.len(), 1);
        assert!((ops[0][0][0].re - 1.0).abs() < 1e-14);
        assert!((ops[0][1][1].re - 1.0).abs() < 1e-14);
    }

    #[test]
    fn readout_error_no_flip_when_rates_zero() {
        let mut rng = StdRng::seed_from_u64(42);
        for _ in 0..1000 {
            assert!(!apply_readout_error(false, 0.0, 0.0, &mut rng));
            assert!(apply_readout_error(true, 0.0, 0.0, &mut rng));
        }
    }

    #[test]
    fn readout_error_always_flips_when_rates_one() {
        let mut rng = StdRng::seed_from_u64(42);
        for _ in 0..1000 {
            assert!(apply_readout_error(false, 1.0, 0.0, &mut rng));
            assert!(!apply_readout_error(true, 0.0, 1.0, &mut rng));
        }
    }

    // Empirical flip frequencies over 100k seeded trials should match p01 / p10 within 0.01.
    #[test]
    fn readout_error_statistical_rates() {
        let mut rng = StdRng::seed_from_u64(12345);
        let p01 = 0.1;
        let p10 = 0.2;
        let trials = 100_000;
        let mut flips_01 = 0usize;
        let mut flips_10 = 0usize;
        for _ in 0..trials {
            if apply_readout_error(false, p01, p10, &mut rng) {
                flips_01 += 1;
            }
            if !apply_readout_error(true, p01, p10, &mut rng) {
                flips_10 += 1;
            }
        }
        let measured_p01 = flips_01 as f64 / trials as f64;
        let measured_p10 = flips_10 as f64 / trials as f64;
        assert!(
            (measured_p01 - p01).abs() < 0.01,
            "p01: expected ~{}, got {}",
            p01,
            measured_p01
        );
        assert!(
            (measured_p10 - p10).abs() < 0.01,
            "p10: expected ~{}, got {}",
            p10,
            measured_p10
        );
    }

    #[test]
    fn readout_corrector_identity_when_no_errors() {
        let corrector = ReadoutCorrector::new(&[(0.0, 0.0), (0.0, 0.0)]);
        let mut counts = HashMap::new();
        counts.insert(vec![false, false], 500);
        counts.insert(vec![true, true], 500);
        let corrected = corrector.correct_counts(&counts);
        assert!(
            (corrected.get(&vec![false, false]).copied().unwrap_or(0.0) - 500.0).abs() < 1e-6,
            "Expected 500.0 for |00>"
        );
        assert!(
            (corrected.get(&vec![true, true]).copied().unwrap_or(0.0) - 500.0).abs() < 1e-6,
            "Expected 500.0 for |11>"
        );
    }

    // True 700/300 split with p01 = 0.10, p10 = 0.05 measures as
    // 700·0.9 + 300·0.05 = 645 zeros and 355 ones; correction should recover 700/300.
    #[test]
    fn readout_corrector_corrects_known_bias() {
        let corrector = ReadoutCorrector::new(&[(0.10, 0.05)]);
        let mut counts = HashMap::new();
        counts.insert(vec![false], 645);
        counts.insert(vec![true], 355);
        let corrected = corrector.correct_counts(&counts);
        let c0 = corrected.get(&vec![false]).copied().unwrap_or(0.0);
        let c1 = corrected.get(&vec![true]).copied().unwrap_or(0.0);
        assert!(
            (c0 - 700.0).abs() < 1.0,
            "Expected ~700, got {}",
            c0
        );
        assert!(
            (c1 - 300.0).abs() < 1.0,
            "Expected ~300, got {}",
            c1
        );
    }

    // 1000 shots of true |00⟩ with p01 = 0.05 per qubit spread (rounded) into 903/47/48/2;
    // correction should concentrate them back on |00⟩.
    #[test]
    fn readout_corrector_two_qubit_correction() {
        let corrector = ReadoutCorrector::new(&[(0.05, 0.03), (0.05, 0.03)]);
        let mut counts = HashMap::new();
        counts.insert(vec![false, false], 903);
        counts.insert(vec![true, false], 47);
        counts.insert(vec![false, true], 48);
        counts.insert(vec![true, true], 2);
        let corrected = corrector.correct_counts(&counts);
        let c00 = corrected.get(&vec![false, false]).copied().unwrap_or(0.0);
        assert!(
            (c00 - 1000.0).abs() < 10.0,
            "Expected ~1000, got {}",
            c00
        );
    }

    #[test]
    fn from_calibration_produces_valid_model() {
        let mut gate_errors = HashMap::new();
        gate_errors.insert("sx_0".to_string(), 0.001);
        gate_errors.insert("cx_0_1".to_string(), 0.01);
        let mut gate_times = HashMap::new();
        gate_times.insert("sx_0".to_string(), 0.035);
        gate_times.insert("cx_0_1".to_string(), 0.3);
        let cal = DeviceCalibration {
            qubit_t1: vec![50.0, 60.0],
            qubit_t2: vec![30.0, 40.0],
            readout_error: vec![(0.02, 0.03), (0.01, 0.02)],
            gate_errors,
            gate_times,
            coupling_map: vec![(0, 1)],
        };
        let model = EnhancedNoiseModel::from_calibration(&cal, "sx_0", 0);
        assert!((model.depolarizing_rate - 0.001).abs() < 1e-10);
        assert!(model.amplitude_damping_gamma.is_some());
        let gamma = model.amplitude_damping_gamma.unwrap();
        let expected_gamma = 1.0 - (-0.035 / 50.0_f64).exp();
        assert!(
            (gamma - expected_gamma).abs() < 1e-10,
            "gamma: expected {}, got {}",
            expected_gamma,
            gamma
        );
        assert!(model.phase_damping_lambda.is_some());
        assert_eq!(model.readout_error, Some((0.02, 0.03)));
        assert!(model.thermal_relaxation.is_some());
        let tr = model.thermal_relaxation.unwrap();
        assert!((tr.t1 - 50.0).abs() < 1e-10);
        assert!((tr.t2 - 30.0).abs() < 1e-10);
        assert!((tr.gate_time - 0.035).abs() < 1e-10);
    }

    // Unknown gate: zero error rate and zero duration, so no T1/T2-derived channels,
    // but the per-qubit readout error is still picked up.
    #[test]
    fn from_calibration_missing_gate_defaults_to_zero() {
        let cal = DeviceCalibration {
            qubit_t1: vec![50.0],
            qubit_t2: vec![30.0],
            readout_error: vec![(0.02, 0.03)],
            gate_errors: HashMap::new(),
            gate_times: HashMap::new(),
            coupling_map: vec![],
        };
        let model = EnhancedNoiseModel::from_calibration(&cal, "nonexistent", 0);
        assert!((model.depolarizing_rate).abs() < 1e-10);
        assert!(model.amplitude_damping_gamma.is_none());
        assert!(model.phase_damping_lambda.is_none());
        assert_eq!(model.readout_error, Some((0.02, 0.03)));
    }

    #[test]
    fn from_calibration_qubit_out_of_range() {
        let cal = DeviceCalibration {
            qubit_t1: vec![50.0],
            qubit_t2: vec![30.0],
            readout_error: vec![(0.02, 0.03)],
            gate_errors: HashMap::new(),
            gate_times: HashMap::new(),
            coupling_map: vec![],
        };
        let model = EnhancedNoiseModel::from_calibration(&cal, "sx_5", 5);
        assert!(model.amplitude_damping_gamma.is_none());
        assert!(model.readout_error.is_none());
    }

    #[test]
    fn bits_to_index_roundtrip() {
        for n in 1..=6 {
            for idx in 0..(1usize << n) {
                let bits = index_to_bits(idx, n);
                assert_eq!(bits.len(), n);
                let recovered = bits_to_index(&bits, n);
                assert_eq!(recovered, idx, "Roundtrip failed for n={}, idx={}", n, idx);
            }
        }
    }

    #[test]
    fn mat_mul_identity() {
        let id = IDENTITY;
        let result = mat_mul_2x2(&id, &id);
        for r in 0..2 {
            for c in 0..2 {
                let expected = if r == c { 1.0 } else { 0.0 };
                assert!(
                    (result[r][c].re - expected).abs() < 1e-14,
                    "result[{}][{}] = {:?}",
                    r,
                    c,
                    result[r][c]
                );
                assert!(result[r][c].im.abs() < 1e-14);
            }
        }
    }

    // M · M⁻¹ must be the identity, with M laid out as M[measured][true].
    #[test]
    fn invert_2x2_confusion_roundtrip() {
        let p01 = 0.1;
        let p10 = 0.05;
        let inv = invert_2x2_confusion(p01, p10);
        let a = 1.0 - p01;
        let b = p10;
        let c = p01;
        let d = 1.0 - p10;
        let prod_00 = a * inv[0][0] + b * inv[1][0];
        let prod_01 = a * inv[0][1] + b * inv[1][1];
        let prod_10 = c * inv[0][0] + d * inv[1][0];
        let prod_11 = c * inv[0][1] + d * inv[1][1];
        assert!((prod_00 - 1.0).abs() < 1e-10);
        assert!(prod_01.abs() < 1e-10);
        assert!(prod_10.abs() < 1e-10);
        assert!((prod_11 - 1.0).abs() < 1e-10);
    }

    #[test]
    fn solve_linear_system_simple() {
        let a = vec![vec![2.0, 1.0], vec![1.0, 3.0]];
        let b = vec![5.0, 10.0];
        let x = solve_linear_system(&a, &b, 2);
        assert!((x[0] - 1.0).abs() < 1e-10, "x[0] = {}", x[0]);
        assert!((x[1] - 3.0).abs() < 1e-10, "x[1] = {}", x[1]);
    }
}