Skip to main content

dsfb_semiconductor/
preprocessing.rs

1use crate::config::PipelineConfig;
2use crate::dataset::secom::SecomDataset;
3use crate::error::{DsfbSemiconductorError, Result};
4use chrono::NaiveDateTime;
5use serde::Serialize;
6
7#[derive(Debug, Clone, Serialize)]
8pub struct DatasetSummary {
9    pub run_count: usize,
10    pub feature_count: usize,
11    pub pass_count: usize,
12    pub fail_count: usize,
13    pub dataset_missing_fraction: f64,
14    pub healthy_pass_runs_requested: usize,
15    pub healthy_pass_runs_found: usize,
16}
17
18#[derive(Debug, Clone, Serialize)]
19pub struct PreparedDataset {
20    pub feature_names: Vec<String>,
21    pub labels: Vec<i8>,
22    pub timestamps: Vec<NaiveDateTime>,
23    pub raw_values: Vec<Vec<Option<f64>>>,
24    pub healthy_pass_indices: Vec<usize>,
25    pub per_feature_missing_fraction: Vec<f64>,
26    pub summary: DatasetSummary,
27}
28
29pub fn prepare_secom(dataset: &SecomDataset, config: &PipelineConfig) -> Result<PreparedDataset> {
30    let mut runs = dataset.runs.clone();
31    runs.sort_by_key(|run| (run.timestamp, run.index));
32
33    let run_count = runs.len();
34    let feature_count = dataset.feature_names.len();
35    let pass_count = runs.iter().filter(|run| run.label == -1).count();
36    let fail_count = runs.iter().filter(|run| run.label == 1).count();
37
38    let healthy_pass_indices = runs
39        .iter()
40        .enumerate()
41        .filter_map(|(index, run)| (run.label == -1).then_some(index))
42        .take(config.healthy_pass_runs)
43        .collect::<Vec<_>>();
44    let healthy_pass_runs_found = healthy_pass_indices.len();
45
46    if healthy_pass_runs_found < config.minimum_healthy_observations {
47        return Err(DsfbSemiconductorError::DatasetFormat(format!(
48            "SECOM does not provide enough passing runs for a healthy window: found {}, need at least {}",
49            healthy_pass_runs_found,
50            config.minimum_healthy_observations
51        )));
52    }
53
54    let raw_values = runs
55        .iter()
56        .map(|run| run.features.clone())
57        .collect::<Vec<_>>();
58
59    let timestamps = runs.iter().map(|run| run.timestamp).collect::<Vec<_>>();
60    let labels = runs.iter().map(|run| run.label).collect::<Vec<_>>();
61
62    let total_cells = (run_count * feature_count) as f64;
63    let mut missing_cells = 0usize;
64    let mut per_feature_missing = vec![0usize; feature_count];
65    for row in &raw_values {
66        for (feature_index, value) in row.iter().enumerate() {
67            if value.is_none() {
68                missing_cells += 1;
69                per_feature_missing[feature_index] += 1;
70            }
71        }
72    }
73
74    let per_feature_missing_fraction = per_feature_missing
75        .into_iter()
76        .map(|missing| missing as f64 / run_count as f64)
77        .collect::<Vec<_>>();
78
79    Ok(PreparedDataset {
80        feature_names: dataset.feature_names.clone(),
81        labels,
82        timestamps,
83        raw_values,
84        healthy_pass_indices,
85        per_feature_missing_fraction,
86        summary: DatasetSummary {
87            run_count,
88            feature_count,
89            pass_count,
90            fail_count,
91            dataset_missing_fraction: if total_cells > 0.0 {
92                missing_cells as f64 / total_cells
93            } else {
94                0.0
95            },
96            healthy_pass_runs_requested: config.healthy_pass_runs,
97            healthy_pass_runs_found,
98        },
99    })
100}