use crate::error::FdarError;
use crate::matrix::FdMatrix;
use super::phase::{spm_monitor, spm_phase1, SpmChart, SpmConfig};
/// Settings for [`spm_phase1_iterative`].
///
/// `spm_phase1_iterative` validates these values on entry and returns
/// `FdarError::InvalidParameter` when a documented constraint is violated.
#[derive(Debug, Clone, PartialEq)]
pub struct IterativePhase1Config {
/// Configuration handed unchanged to `spm_phase1` for every fit.
/// Its `alpha` field must lie strictly between 0 and 1.
pub spm: SpmConfig,
/// Upper bound on the number of fit / monitor / remove rounds.
/// Must be at least 1.
pub max_iterations: usize,
/// When `true`, observations whose `t2_alarm` flag is set by `spm_monitor`
/// are candidates for removal.
pub remove_t2_outliers: bool,
/// When `true`, observations whose `spe_alarm` flag is set by `spm_monitor`
/// are candidates for removal.
pub remove_spe_outliers: bool,
/// Largest allowed ratio of cumulatively removed observations to the
/// original observation count. A round that would exceed it is not applied
/// and iteration stops. Must lie in `(0, 1]`.
pub max_removal_fraction: f64,
}
impl Default for IterativePhase1Config {
fn default() -> Self {
Self {
spm: SpmConfig::default(),
max_iterations: 10,
remove_t2_outliers: true,
remove_spe_outliers: true,
max_removal_fraction: 0.3,
}
}
}
/// Outcome of [`spm_phase1_iterative`].
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct IterativePhase1Result {
/// Chart from the last fit: either the fit at which iteration stopped, or a
/// refit on the remaining observations when every allowed round removed
/// something.
pub chart: SpmChart,
/// Number of rounds in which observations were actually removed
/// (equal to `removal_history.len()`).
pub n_iterations: usize,
/// Row indices (relative to the input data) of every removed observation,
/// in the order they were removed.
pub removed_indices: Vec<usize>,
/// Number of observations that were not removed.
pub n_remaining: usize,
/// For each removing round, the original row indices removed in that round.
pub removal_history: Vec<Vec<usize>>,
/// For each removing round, the removed count divided by the number of
/// observations present at the start of that round.
pub removal_rates: Vec<f64>,
}
/// Builds a Phase I chart by repeatedly fitting, monitoring the fitting data
/// against its own chart, and discarding the observations that raise alarms.
///
/// Each round:
/// 1. fits a chart with `spm_phase1` on the rows still retained,
/// 2. runs `spm_monitor` on those same rows,
/// 3. flags a row when `remove_t2_outliers` is set and its `t2_alarm` is
///    raised, or when `remove_spe_outliers` is set and its `spe_alarm` is
///    raised,
/// 4. removes the flagged rows and starts the next round.
///
/// Iteration stops, keeping the chart of the current round and *not* applying
/// the pending removal, as soon as one of the following holds:
/// * no row is flagged,
/// * the cumulative removed fraction (relative to the original row count)
///   would exceed `config.max_removal_fraction`,
/// * fewer than 4 rows would remain.
///
/// If all `config.max_iterations` rounds removed rows, a final chart is fitted
/// on the remaining rows; that chart is not monitored again.
///
/// # Arguments
/// * `data` - observations, one per row.
/// * `argvals` - passed through to `spm_phase1` and `spm_monitor`.
/// * `config` - iteration settings; see [`IterativePhase1Config`].
///
/// # Errors
/// Returns `FdarError::InvalidParameter` when
/// * `config.spm.alpha` is NaN or outside `(0, 1)`,
/// * `config.max_iterations` is 0,
/// * `config.max_removal_fraction` is NaN or outside `(0, 1]`.
///
/// Any error produced by `spm_phase1` or `spm_monitor` is propagated.
#[must_use = "expensive computation whose result should not be discarded"]
pub fn spm_phase1_iterative(
    data: &FdMatrix,
    argvals: &[f64],
    config: &IterativePhase1Config,
) -> Result<IterativePhase1Result, FdarError> {
    // Fewest rows that must survive a removal round.
    const MIN_REMAINING: usize = 4;

    // NaN compares false against everything, so it has to be rejected
    // explicitly; otherwise it slips through the range checks below.
    let alpha = config.spm.alpha;
    if alpha.is_nan() || alpha <= 0.0 || alpha >= 1.0 {
        return Err(FdarError::InvalidParameter {
            parameter: "alpha",
            message: format!("alpha must be in (0, 1), got {}", alpha),
        });
    }
    if config.max_iterations < 1 {
        return Err(FdarError::InvalidParameter {
            parameter: "max_iterations",
            message: format!(
                "max_iterations must be at least 1, got {}",
                config.max_iterations
            ),
        });
    }
    // A NaN budget would make the "fraction exceeded" test below always
    // false, silently disabling the removal limit.
    let max_fraction = config.max_removal_fraction;
    if max_fraction.is_nan() || max_fraction <= 0.0 || max_fraction > 1.0 {
        return Err(FdarError::InvalidParameter {
            parameter: "max_removal_fraction",
            message: format!(
                "max_removal_fraction must be in (0, 1], got {}",
                max_fraction
            ),
        });
    }

    let n_original = data.nrows();
    // Original row indices still in play; positions in this vector are the
    // "local" indices used by the per-round chart and monitor output.
    let mut remaining_indices: Vec<usize> = (0..n_original).collect();
    let mut all_removed: Vec<usize> = Vec::new();
    let mut removal_history: Vec<Vec<usize>> = Vec::new();
    let mut removal_rates: Vec<f64> = Vec::new();
    let mut chart = None;

    for _ in 0..config.max_iterations {
        let current_data = crate::cv::subset_rows(data, &remaining_indices);
        let current_chart = spm_phase1(&current_data, argvals, &config.spm)?;
        let monitor = spm_monitor(&current_chart, &current_data, argvals)?;

        let n_current = remaining_indices.len();
        let flagged_mask: Vec<bool> = (0..n_current)
            .map(|i| {
                (config.remove_t2_outliers && monitor.t2_alarm[i])
                    || (config.remove_spe_outliers && monitor.spe_alarm[i])
            })
            .collect();
        let n_flagged = flagged_mask.iter().filter(|&&flagged| flagged).count();

        // Stop without applying the pending removal when nothing is flagged,
        // the removal budget would be exceeded, or too few rows would remain.
        let removed_fraction = (all_removed.len() + n_flagged) as f64 / n_original as f64;
        let should_stop = n_flagged == 0
            || removed_fraction > max_fraction
            || n_current - n_flagged < MIN_REMAINING;
        if should_stop {
            chart = Some(current_chart);
            break;
        }

        // Split the retained rows into removed / kept in one pass; both
        // halves preserve the ascending order of `remaining_indices`.
        let mut flagged_original: Vec<usize> = Vec::with_capacity(n_flagged);
        let mut kept: Vec<usize> = Vec::with_capacity(n_current - n_flagged);
        for (&orig_i, &flagged) in remaining_indices.iter().zip(&flagged_mask) {
            if flagged {
                flagged_original.push(orig_i);
            } else {
                kept.push(orig_i);
            }
        }
        remaining_indices = kept;

        removal_rates.push(n_flagged as f64 / n_current as f64);
        all_removed.extend_from_slice(&flagged_original);
        removal_history.push(flagged_original);
    }

    // Every allowed round removed rows: refit once on what is left so the
    // returned chart matches `remaining_indices`.
    let final_chart = match chart {
        Some(c) => c,
        None => {
            let final_data = crate::cv::subset_rows(data, &remaining_indices);
            spm_phase1(&final_data, argvals, &config.spm)?
        }
    };

    Ok(IterativePhase1Result {
        chart: final_chart,
        n_iterations: removal_history.len(),
        removed_indices: all_removed,
        n_remaining: remaining_indices.len(),
        removal_history,
        removal_rates,
    })
}