use scirs2_core::ndarray::{Array1, Array2, Array4};
use scirs2_core::Complex64;
use super::super::core::SciRS2ProcessTomographer;
use super::super::results::{ExperimentalData, ReconstructionQuality};
use super::utils::calculate_reconstruction_quality;
use crate::DeviceResult;
#[cfg(feature = "scirs2")]
use scirs2_linalg::{inv, matrix_norm, svd};
#[cfg(not(feature = "scirs2"))]
use super::super::fallback::{inv, matrix_norm, svd};
pub fn reconstruct_linear_inversion(
tomographer: &SciRS2ProcessTomographer,
experimental_data: &ExperimentalData,
) -> DeviceResult<(Array4<Complex64>, ReconstructionQuality)> {
let measurement_matrix = tomographer.build_measurement_matrix(experimental_data)?;
let measurement_vector = Array1::from_vec(experimental_data.measurement_results.clone());
let process_vector = solve_linear_system(&measurement_matrix, &measurement_vector)?;
let process_matrix = reshape_to_process_matrix(&process_vector)?;
let log_likelihood = calculate_log_likelihood(&process_matrix, experimental_data, tomographer)?;
let reconstruction_quality =
calculate_reconstruction_quality(&process_matrix, experimental_data, log_likelihood);
Ok((process_matrix, reconstruction_quality))
}
fn solve_linear_system(
measurement_matrix: &Array2<f64>,
measurement_vector: &Array1<f64>,
) -> DeviceResult<Array1<Complex64>> {
#[cfg(feature = "scirs2")]
{
if let Ok((u, s, vt)) = svd(&measurement_matrix.view(), true, None) {
let mut s_pinv = Array1::zeros(s.len());
let tolerance = 1e-12;
for (i, &singular_value) in s.iter().enumerate() {
if singular_value > tolerance {
s_pinv[i] = 1.0 / singular_value;
}
}
let mut solution = Array1::zeros(measurement_matrix.ncols());
for i in 0..measurement_matrix.ncols() {
let mut sum = 0.0;
for j in 0..measurement_matrix.nrows() {
let mut inner_sum = 0.0;
for k in 0..s.len() {
inner_sum += vt[[k, i]] * s_pinv[k] * u[[j, k]];
}
sum += inner_sum * measurement_vector[j];
}
solution[i] = sum;
}
let complex_solution = solution.mapv(|x| Complex64::new(x, 0.0));
Ok(complex_solution)
} else {
simple_least_squares(measurement_matrix, measurement_vector)
}
}
#[cfg(not(feature = "scirs2"))]
{
simple_least_squares(measurement_matrix, measurement_vector)
}
}
fn simple_least_squares(
measurement_matrix: &Array2<f64>,
measurement_vector: &Array1<f64>,
) -> DeviceResult<Array1<Complex64>> {
let n = measurement_matrix.ncols();
let m = measurement_matrix.nrows();
let mut ata = Array2::zeros((n, n));
let mut atb = Array1::zeros(n);
for i in 0..n {
for j in 0..n {
let mut sum = 0.0;
for k in 0..m {
sum += measurement_matrix[[k, i]] * measurement_matrix[[k, j]];
}
ata[[i, j]] = sum;
}
}
for i in 0..n {
let mut sum = 0.0;
for k in 0..m {
sum += measurement_matrix[[k, i]] * measurement_vector[k];
}
atb[i] = sum;
}
#[cfg(feature = "scirs2")]
{
if let Ok(ata_inv) = inv(&ata.view(), None) {
let solution = ata_inv.dot(&atb);
Ok(solution.mapv(|x| Complex64::new(x, 0.0)))
} else {
for i in 0..n {
ata[[i, i]] += 1e-8;
}
if let Ok(ata_inv) = inv(&ata.view(), None) {
let solution = ata_inv.dot(&atb);
Ok(solution.mapv(|x| Complex64::new(x, 0.0)))
} else {
Ok(Array1::<Complex64>::zeros(n))
}
}
}
#[cfg(not(feature = "scirs2"))]
{
Ok(Array1::zeros(n).mapv(|_| Complex64::new(0.0, 0.0)))
}
}
fn reshape_to_process_matrix(
process_vector: &Array1<Complex64>,
) -> DeviceResult<Array4<Complex64>> {
let total_elements = process_vector.len();
let dim = ((total_elements as f64).powf(0.25).round() as usize).max(2);
let mut process_matrix = Array4::zeros((dim, dim, dim, dim));
let mut idx = 0;
for i in 0..dim {
for j in 0..dim {
for k in 0..dim {
for l in 0..dim {
if idx < process_vector.len() {
process_matrix[[i, j, k, l]] = process_vector[idx];
idx += 1;
}
}
}
}
}
Ok(process_matrix)
}
fn calculate_log_likelihood(
process_matrix: &Array4<Complex64>,
experimental_data: &ExperimentalData,
tomographer: &SciRS2ProcessTomographer,
) -> DeviceResult<f64> {
let mut log_likelihood = 0.0;
for (m_idx, (&observed, &uncertainty)) in experimental_data
.measurement_results
.iter()
.zip(experimental_data.measurement_uncertainties.iter())
.enumerate()
{
let input_idx = m_idx / experimental_data.measurement_operators.len();
let meas_idx = m_idx % experimental_data.measurement_operators.len();
if input_idx < experimental_data.input_states.len()
&& meas_idx < experimental_data.measurement_operators.len()
{
let predicted = predict_measurement_outcome(
process_matrix,
&experimental_data.input_states[input_idx],
&experimental_data.measurement_operators[meas_idx],
)?;
let diff = observed - predicted;
let variance = uncertainty * uncertainty;
log_likelihood -=
0.5 * (diff * diff / variance + (2.0 * std::f64::consts::PI * variance).ln());
}
}
Ok(log_likelihood)
}
fn predict_measurement_outcome(
process_matrix: &Array4<Complex64>,
input_state: &Array2<Complex64>,
measurement: &Array2<Complex64>,
) -> DeviceResult<f64> {
let dim = process_matrix.dim().0;
let mut result = Complex64::new(0.0, 0.0);
for i in 0..dim {
for j in 0..dim {
for k in 0..dim {
for l in 0..dim {
result += measurement[[i, j]]
* process_matrix[[i, j, k, l]]
* input_state[[k, l]].conj();
}
}
}
}
Ok(result.re)
}