//! dda-rs 0.2.0
//!
//! Pure Rust Delay Differential Analysis engine.
use crate::error::{DDAError, Result};
use rayon::prelude::*;

use super::model::ModelSpec;
use super::{NormalizationMode, PureRustOptions, PARALLEL_BATCH_MIN_LEN};

/// A DDA input window after NaN-masking of repeated-value runs, derivative
/// estimation, and per-column normalization/trimming.
///
/// Constructed by [`PreparedWindow::from_raw`].
#[derive(Debug, Clone)]
pub(crate) struct PreparedWindow {
    /// Samples aligned with the derivative centres, stored row-major as
    /// `shifted[row][col]` (`raw_rows - 2*dm` rows). Values may be normalized
    /// per column, depending on the normalization mode.
    pub(crate) shifted: Vec<Vec<f64>>,
    /// Derivative estimates with the first `max_delay` samples dropped, stored
    /// column-major as `deriv[col][row]` (`raw_rows - 2*dm - max_delay` per column).
    pub(crate) deriv: Vec<Vec<f64>>,
    /// The model's `max_delay`, i.e. how many leading derivative samples were trimmed.
    pub(crate) max_delay: usize,
}

impl PreparedWindow {
    /// Builds a prepared window from raw samples indexed `raw_window[row][col]`.
    ///
    /// Pipeline:
    /// 1. copy the window and blank runs of repeated values (`options.nr_exclude`);
    /// 2. estimate the centered derivative (half-width `model.dm`, `options.derivative_step`);
    /// 3. align, trim by `model.max_delay`, and normalize (`options.normalization_mode`).
    ///
    /// Every row must contain the same number of columns.
    ///
    /// # Errors
    /// Returns [`DDAError::InvalidParameter`] when the window is empty, when its rows
    /// have differing lengths, or when the derivative/trim steps reject its dimensions.
    pub(crate) fn from_raw(
        raw_window: &[Vec<f64>],
        model: &ModelSpec,
        options: &PureRustOptions,
    ) -> Result<Self> {
        let rows = raw_window.len();
        let cols = raw_window
            .first()
            .map(|row| row.len())
            .ok_or_else(|| DDAError::InvalidParameter("Raw DDA window is empty".to_string()))?;
        // The helpers below take the column count from the first row and index
        // `data[row][col]` directly; a ragged window would panic there, so reject it here.
        if let Some((index, row)) = raw_window
            .iter()
            .enumerate()
            .find(|(_, row)| row.len() != cols)
        {
            return Err(DDAError::InvalidParameter(format!(
                "Raw DDA window row {} has {} columns, expected {}",
                index,
                row.len(),
                cols
            )));
        }
        let mut data = raw_window.to_vec();
        apply_nan_runs(&mut data, options.nr_exclude);
        let derivative = deriv_all_2d(&data, model.dm, options.derivative_step)?;
        let (shifted, deriv) = normalize_window(
            &data,
            &derivative,
            rows,
            cols,
            model.dm,
            model.max_delay,
            options.normalization_mode,
        )?;
        Ok(Self {
            shifted,
            deriv,
            max_delay: model.max_delay,
        })
    }
}

/// Replaces runs of repeated values with NaN, independently for each column.
///
/// A run is a maximal stretch of consecutive rows in one column whose values
/// compare equal with `==`. NaN never equals anything, so it never extends a run.
/// Every run whose length is `>= nr_exclude` is overwritten with `f64::NAN`.
/// The rule applies uniformly, including to a run that ends on the last row.
///
/// `nr_exclude == 0` disables the filter. The column count is taken from the
/// first row.
fn apply_nan_runs(data: &mut [Vec<f64>], nr_exclude: usize) {
    if nr_exclude == 0 || data.is_empty() {
        return;
    }
    let rows = data.len();
    let cols = data[0].len();
    for col in 0..cols {
        let mut start = 0;
        while start < rows {
            // For non-NaN f64, `==` is transitive, so comparing against the run's
            // first sample finds the same runs as comparing consecutive samples.
            let value = data[start][col];
            let mut end = start + 1;
            while end < rows && data[end][col] == value {
                end += 1;
            }
            if end - start >= nr_exclude {
                for row in &mut data[start..end] {
                    row[col] = f64::NAN;
                }
            }
            start = end;
        }
    }
}

/// Estimates the derivative of every column with a symmetric difference stencil.
///
/// Let `taps = dm / max(step, 1)`. For each centre row `t` in `dm..rows - dm`, the
/// estimate is the mean over `k = 1..=taps` of
/// `(x[t + k*step] - x[t - k*step]) / k`. It is NaN when the centre sample or any
/// stencil sample is NaN.
///
/// The result is column-major: `result[col][t - dm]`, with `rows - 2*dm` entries per
/// column. Columns are filled in parallel via rayon when `cols >= PARALLEL_BATCH_MIN_LEN`.
///
/// # Errors
/// `DDAError::InvalidParameter` if `data` is empty, if `rows <= 2*dm`, or if `taps == 0`.
fn deriv_all_2d(data: &[Vec<f64>], dm: usize, step: usize) -> Result<Vec<Vec<f64>>> {
    if data.is_empty() {
        return Err(DDAError::InvalidParameter(
            "Cannot derive an empty DDA window".to_string(),
        ));
    }
    let (rows, cols) = (data.len(), data[0].len());
    let margin = 2 * dm;
    if rows <= margin {
        return Err(DDAError::InvalidParameter(format!(
            "Need more than 2*dm={} rows for derivative computation, got {}",
            margin, rows
        )));
    }
    let step = step.max(1);
    let taps = dm / step;
    if taps == 0 {
        return Err(DDAError::InvalidParameter(format!(
            "Invalid derivative_step={} for dm={}",
            step, dm
        )));
    }
    let taps_f = taps as f64;

    // Derivative estimate for one column at one centre row; NaN if any input is NaN.
    let estimate_at = |col: usize, center: usize| -> f64 {
        if data[center][col].is_nan() {
            return f64::NAN;
        }
        (1..=taps)
            .try_fold(0.0, |acc, k| {
                let reach = k * step;
                let ahead = data[center + reach][col];
                let behind = data[center - reach][col];
                if ahead.is_nan() || behind.is_nan() {
                    None
                } else {
                    Some(acc + (ahead - behind) / (k as f64))
                }
            })
            .map_or(f64::NAN, |total| total / taps_f)
    };

    // Each output column has one slot per centre row; slot `i` is centre `i + dm`.
    let fill = |(col, column): (usize, &mut Vec<f64>)| {
        for (index, slot) in column.iter_mut().enumerate() {
            *slot = estimate_at(col, index + dm);
        }
    };

    let mut derivative = vec![vec![f64::NAN; rows - margin]; cols];
    if cols >= PARALLEL_BATCH_MIN_LEN {
        derivative.par_iter_mut().enumerate().for_each(fill);
    } else {
        derivative.iter_mut().enumerate().for_each(fill);
    }
    Ok(derivative)
}

/// Aligns raw samples with the derivative, trims by `max_delay`, and normalizes per column.
///
/// Returns `(shifted, trimmed_deriv)`:
/// * `shifted` is row-major (`rows - 2*dm` rows × `cols`), holding `raw[dm..rows - dm]`,
///   i.e. the samples at the same centres as `derivative`.
/// * `trimmed_deriv` is column-major (`cols` × `rows - 2*dm - max_delay`), holding
///   `derivative[col][max_delay..]`.
///
/// Per-column normalization:
/// * `Raw`: both outputs are copied unchanged.
/// * `MinMax`: `shifted` becomes `(x - min) / (max - min)`; the derivative is divided by `max - min`.
/// * `ZScore`: `shifted` becomes `(x - mean) / std` (sample std, `n - 1` denominator);
///   the derivative is divided by `std`.
///
/// NaN samples are skipped when computing statistics. A column whose scale is zero or
/// non-finite (or, for `ZScore`, has fewer than two non-NaN samples) keeps its raw
/// `shifted` values, and its derivative is left all-NaN.
///
/// # Errors
/// `DDAError::InvalidParameter` if `rows < 2*dm` or `rows - 2*dm < max_delay`.
fn normalize_window(
    raw: &[Vec<f64>],
    derivative: &[Vec<f64>],
    rows: usize,
    cols: usize,
    dm: usize,
    max_delay: usize,
    mode: NormalizationMode,
) -> Result<(Vec<Vec<f64>>, Vec<Vec<f64>>)> {
    // `(min, max - min)` over the non-NaN entries of column `col`, or `None` if the range
    // is zero or non-finite.
    fn min_max_scaling(table: &[Vec<f64>], col: usize) -> Option<(f64, f64)> {
        let (low, high) = table
            .iter()
            .map(|row| row[col])
            .filter(|value| !value.is_nan())
            .fold((f64::INFINITY, f64::NEG_INFINITY), |(low, high), value| {
                (low.min(value), high.max(value))
            });
        let range = high - low;
        if range.is_finite() && range != 0.0 {
            Some((low, range))
        } else {
            None
        }
    }

    // `(mean, sample standard deviation)` over the non-NaN entries of column `col`, or
    // `None` with fewer than two samples or a zero/non-finite deviation.
    fn z_score_scaling(table: &[Vec<f64>], col: usize) -> Option<(f64, f64)> {
        let samples: Vec<f64> = table
            .iter()
            .map(|row| row[col])
            .filter(|value| !value.is_nan())
            .collect();
        if samples.len() < 2 {
            return None;
        }
        let mean = samples.iter().sum::<f64>() / (samples.len() as f64);
        let variance = samples
            .iter()
            .map(|value| (value - mean).powi(2))
            .sum::<f64>()
            / ((samples.len() - 1) as f64);
        let deviation = variance.sqrt();
        if deviation.is_finite() && deviation != 0.0 {
            Some((mean, deviation))
        } else {
            None
        }
    }

    let shifted_rows = rows
        .checked_sub(2 * dm)
        .ok_or_else(|| DDAError::InvalidParameter("Invalid shifted row count".to_string()))?;
    let window_length = shifted_rows.checked_sub(max_delay).ok_or_else(|| {
        DDAError::InvalidParameter("Window length became negative after max(TAU) trim".to_string())
    })?;

    // Drop the `dm` margin on each side so `shifted` lines up with the derivative centres.
    let mut shifted: Vec<Vec<f64>> = (0..shifted_rows)
        .map(|row| (0..cols).map(|col| raw[row + dm][col]).collect())
        .collect();
    let mut trimmed_deriv = vec![vec![f64::NAN; window_length]; cols];

    for (col, deriv_out) in trimmed_deriv.iter_mut().enumerate() {
        // `None` means "copy unchanged"; a degenerate column skips both updates.
        let scaling = match mode {
            NormalizationMode::Raw => None,
            NormalizationMode::MinMax => match min_max_scaling(&shifted, col) {
                Some(pair) => Some(pair),
                None => continue,
            },
            NormalizationMode::ZScore => match z_score_scaling(&shifted, col) {
                Some(pair) => Some(pair),
                None => continue,
            },
        };
        let source = &derivative[col][max_delay..max_delay + window_length];
        match scaling {
            None => deriv_out.copy_from_slice(source),
            Some((center, scale)) => {
                for row in shifted.iter_mut() {
                    row[col] = (row[col] - center) / scale;
                }
                for (out, &value) in deriv_out.iter_mut().zip(source) {
                    *out = value / scale;
                }
            }
        }
    }

    Ok((shifted, trimmed_deriv))
}