gsdesign 0.1.0

Group sequential design
Documentation
/// Error type reported by numerical integration routines.
///
/// Every variant stores the offending input(s) so that the
/// [`std::fmt::Display`] message can echo them back to the caller.
#[derive(Debug, Clone, PartialEq)]
pub enum IntegrationError {
    /// The grid-density parameter ``r`` must be at least two.
    InvalidStencil {
        /// Grid-density parameter that was supplied. It is not itself a point
        /// count: `gridpts` derives ``6 * r - 1`` odd-numbered nodes from it
        /// before trimming them to the integration interval.
        r: usize,
    },
    /// The interval bounds must satisfy ``a < b``; NaN bounds are rejected too.
    NonIncreasingBounds {
        /// Lower truncation limit.
        a: f64,
        /// Upper truncation limit.
        b: f64,
    },
    /// Fisher information must be strictly positive (NaN is rejected too).
    NonPositiveInformation {
        /// Fisher information supplied for the analysis.
        info: f64,
    },
    /// Previous Fisher information must be strictly positive (NaN is rejected too).
    NonPositivePreviousInformation {
        /// Fisher information at the previous analysis.
        info_prev: f64,
    },
    /// Current Fisher information must strictly exceed the previous analysis value.
    NonIncreasingInformation {
        /// Fisher information from the prior analysis.
        info_prev: f64,
        /// Fisher information proposed for the current analysis.
        info: f64,
    },
    /// The vectors stored in a grid do not all have the same length.
    ShapeMismatch {
        /// Length of the grid-point vector.
        points: usize,
        /// Length of the Simpson-weight vector.
        weights: usize,
        /// Length of the density-value vector, or `None` when the grid being
        /// checked carries no density values.
        values: Option<usize>,
    },
}

impl std::fmt::Display for IntegrationError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Every field is `Copy`, so the match binds plain values.
        match *self {
            Self::InvalidStencil { r } => {
                write!(f, "r must be at least 2 for Simpson integration (got {r})")
            }
            Self::NonIncreasingBounds { a, b } => {
                write!(f, "Lower bound a must satisfy a < b (got a={a}, b={b})")
            }
            Self::NonPositiveInformation { info } => {
                write!(f, "Information must be positive (got {info})")
            }
            Self::NonPositivePreviousInformation { info_prev } => {
                write!(f, "Previous information must be positive (got {info_prev})")
            }
            Self::NonIncreasingInformation { info_prev, info } => write!(
                f,
                "Current information ({info}) must exceed previous information ({info_prev})"
            ),
            Self::ShapeMismatch {
                points,
                weights,
                values: Some(values),
            } => write!(
                f,
                "Grid metadata mismatch: points={points}, weights={weights}, values={values}"
            ),
            Self::ShapeMismatch {
                points,
                weights,
                values: None,
            } => write!(
                f,
                "Grid metadata mismatch: points={points}, weights={weights}"
            ),
        }
    }
}

impl std::error::Error for IntegrationError {}

/// Simpson's-rule integration grid on the canonical normal scale.
///
/// `points[i]` and `weights[i]` describe the same node. Grids returned by
/// `gridpts` interleave odd-numbered nodes with the midpoints between them.
#[derive(Debug, Clone, PartialEq)]
pub struct IntegrationGrid {
    /// Node locations (odd- and even-numbered Simpson nodes).
    pub points: Vec<f64>,
    /// Simpson weight of each node, already divided by 6.
    pub weights: Vec<f64>,
}

impl IntegrationGrid {
    /// Check that every node has exactly one weight.
    ///
    /// Returns [`IntegrationError::ShapeMismatch`] with `values: None` when the
    /// two vectors differ in length.
    fn validate(&self) -> Result<(), IntegrationError> {
        let (points, weights) = (self.points.len(), self.weights.len());
        if points == weights {
            Ok(())
        } else {
            Err(IntegrationError::ShapeMismatch {
                points,
                weights,
                values: None,
            })
        }
    }
}

/// Density grid carried from one interim analysis to the next by the recursive
/// convolution in `h1` / `hupdate`.
///
/// All three vectors are indexed by node. As built by `h1` and `hupdate`,
/// `values[i]` is a density evaluated at `points[i]` already multiplied by
/// `weights[i]`, so summing `values` gives the Simpson approximation of that
/// density's integral over the grid.
#[derive(Debug, Clone, PartialEq)]
pub struct DensityGrid {
    /// Simpson grid points.
    pub points: Vec<f64>,
    /// Simpson weight of each point.
    pub weights: Vec<f64>,
    /// Weight-multiplied density values, one per point.
    pub values: Vec<f64>,
}

impl DensityGrid {
    /// Check that points, weights and values all have the same length.
    ///
    /// Returns [`IntegrationError::ShapeMismatch`] (with `values: Some(len)`)
    /// otherwise.
    fn validate(&self) -> Result<(), IntegrationError> {
        let points = self.points.len();
        let weights = self.weights.len();
        let values = self.values.len();
        match (points == weights, points == values) {
            (true, true) => Ok(()),
            _ => Err(IntegrationError::ShapeMismatch {
                points,
                weights,
                values: Some(values),
            }),
        }
    }
}

/// ``sqrt(2 * pi)``: normalizing constant of the standard normal density.
const SQRT_2PI: f64 = 2.506_628_274_631_000_453_559_f64;

/// Standard normal probability density evaluated at ``x``.
fn normal_pdf(x: f64) -> f64 {
    let exponent = -0.5 * x * x;
    exponent.exp() / SQRT_2PI
}

/// Construct Simpson's rule grid points for canonical normal integration.
///
/// The grid is built in three steps:
///
/// 1. ``6 * r - 1`` odd-numbered nodes are laid out around ``mu``:
///    * ``r - 1`` logarithmically spaced nodes below ``mu - 3``;
///    * ``4 * r + 1`` evenly spaced nodes (spacing ``3 / (2r)``) covering
///      ``[mu - 3, mu + 3]``;
///    * ``r - 1`` logarithmically spaced nodes above ``mu + 3``.
/// 2. If some node lies below ``a`` (above ``b``), every node at or beyond that
///    bound is dropped and the bound itself becomes the first (last) node.
/// 3. A midpoint (even-numbered node) is inserted between each pair of
///    consecutive odd nodes. Simpson weights (divided by 6) are attached to
///    every node, so the weights sum to the width of the trimmed node range.
///
/// If trimming leaves a single node, that node is returned with weight ``1.0``.
///
/// # Errors
///
/// Returns [`IntegrationError::InvalidStencil`] when ``r < 2`` and
/// [`IntegrationError::NonIncreasingBounds`] when the truncation interval does
/// not satisfy ``a < b`` (including when either bound is NaN).
///
/// # Examples
///
/// ```
/// use gsdesign::{gridpts, IntegrationError};
///
/// # fn main() -> Result<(), IntegrationError> {
/// let grid = gridpts(5, 0.5, -2.0, 2.0)?;
/// assert_eq!(grid.points.len(), grid.weights.len());
/// assert!(grid.weights.iter().all(|w| *w >= 0.0));
/// # Ok(())
/// # }
/// ```
pub fn gridpts(r: usize, mu: f64, a: f64, b: f64) -> Result<IntegrationGrid, IntegrationError> {
    if r < 2 {
        return Err(IntegrationError::InvalidStencil { r });
    }
    // Written without a negated float comparison; NaN bounds are rejected explicitly.
    if a.is_nan() || b.is_nan() || a >= b {
        return Err(IntegrationError::NonIncreasingBounds { a, b });
    }

    let r_f64 = r as f64;
    let two_r = 2.0 * r_f64;
    // Distance from `mu` of the k-th tail node (k = 1 is the outermost).
    let tail = |k: usize| 3.0 + 4.0 * (r_f64 / k as f64).ln();

    // Odd-numbered nodes in ascending order: lower tail, central core, upper tail.
    let mut nodes: Vec<f64> = Vec::with_capacity(6 * r - 1);
    nodes.extend((1..r).map(|k| mu - tail(k)));
    nodes.extend((0..=4 * r).map(|k| mu - 3.0 + 3.0 * (k as f64 / two_r)));
    nodes.extend((1..r).rev().map(|k| mu + tail(k)));

    // Replace nodes beyond each truncation bound by the bound itself.
    let lowest = nodes.iter().copied().fold(f64::INFINITY, f64::min);
    if lowest < a {
        nodes.retain(|&x| x > a);
        nodes.insert(0, a);
    }
    let highest = nodes.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    if highest > b {
        nodes.retain(|&x| x < b);
        nodes.push(b);
    }

    let (first, second) = match nodes.as_slice() {
        // The node list starts with 6r - 1 >= 11 entries and each trimming step
        // re-inserts the bound it trims to, so it can never become empty.
        [] => unreachable!("trimming always re-inserts the bound it trims to"),
        [only] => {
            return Ok(IntegrationGrid {
                points: vec![*only],
                weights: vec![1.0],
            });
        }
        [first, second, ..] => (*first, *second),
    };

    let count = 2 * nodes.len() - 1;
    let mut points = Vec::with_capacity(count);
    let mut weights = Vec::with_capacity(count);

    // The first odd node only borders the interval to its right.
    points.push(first);
    weights.push((second - first) / 6.0);

    for (i, pair) in nodes.windows(2).enumerate() {
        let (left, right) = (pair[0], pair[1]);

        // Even-numbered node: interval midpoint with Simpson weight 4h/6.
        points.push(0.5 * (left + right));
        weights.push(4.0 * (right - left) / 6.0);

        // Odd-numbered node `right`: collects h/6 from each adjacent interval.
        let span = match nodes.get(i + 2) {
            Some(&next) => next - left,
            None => right - left,
        };
        points.push(right);
        weights.push(span / 6.0);
    }

    let grid = IntegrationGrid { points, weights };
    grid.validate()?;
    Ok(grid)
}

/// Initialize the density grid for the first group sequential analysis.
///
/// The drift at the first analysis is ``theta * sqrt(info)``. The grid comes
/// from [`gridpts`] centred on that drift and truncated to ``[a, b]``. Each
/// node's value is its Simpson weight times the standard normal density of the
/// node's deviation from the drift.
///
/// # Errors
///
/// Returns [`IntegrationError::NonPositiveInformation`] when ``info`` is not
/// strictly positive (including NaN) or propagates any [`IntegrationError`]
/// from [`gridpts`].
///
/// # Examples
///
/// ```
/// use gsdesign::{h1, IntegrationError};
///
/// # fn main() -> Result<(), IntegrationError> {
/// let grid = h1(5, 0.5, 2.0, -2.0, 2.0)?;
/// assert_eq!(grid.points.len(), grid.values.len());
/// assert!(grid.values.iter().all(|v| *v >= 0.0));
/// # Ok(())
/// # }
/// ```
pub fn h1(
    r: usize,
    theta: f64,
    info: f64,
    a: f64,
    b: f64,
) -> Result<DensityGrid, IntegrationError> {
    // NaN fails every comparison, so it has to be rejected explicitly.
    if info.is_nan() || info <= 0.0 {
        return Err(IntegrationError::NonPositiveInformation { info });
    }

    let drift = theta * info.sqrt();
    let IntegrationGrid { points, weights } = gridpts(r, drift, a, b)?;

    let mut values = Vec::with_capacity(points.len());
    for (z, w) in points.iter().zip(&weights) {
        values.push(w * normal_pdf(z - drift));
    }

    let density = DensityGrid {
        points,
        weights,
        values,
    };
    density.validate()?;
    Ok(density)
}

/// Update the density grid for a subsequent group sequential analysis.
///
/// The update performs one convolution step using the canonical drift
/// parameters and information levels at two successive analyses.
/// `theta_prev` / `info_prev` describe the analysis that produced `prev`, as in
/// the example below.
///
/// With ``delta = info - info_prev``, ``scale = sqrt(info / delta)`` and
/// ``mu = theta * info - theta_prev * info_prev``, the value at new node
/// ``z_j`` (Simpson weight ``w_j``) is
///
/// ``w_j * scale * sum_i prev.values[i] * phi(z_j * scale - (prev.points[i] * sqrt(info_prev) + mu) / sqrt(delta))``
///
/// where ``phi`` is the standard normal density. The new grid comes from
/// [`gridpts`] centred on ``theta * sqrt(info)`` and truncated to ``[a, b]``.
///
/// # Errors
///
/// Returns
/// - [`IntegrationError::NonPositivePreviousInformation`] when ``info_prev`` is
///   not strictly positive (including NaN);
/// - [`IntegrationError::NonIncreasingInformation`] when ``info`` does not
///   strictly exceed ``info_prev`` (including NaN);
/// - [`IntegrationError::ShapeMismatch`] when the vectors stored in `prev` have
///   different lengths;
/// - or any [`IntegrationError`] produced by [`gridpts`].
///
/// # Examples
///
/// ```
/// use gsdesign::{h1, hupdate, IntegrationError};
///
/// # fn main() -> Result<(), IntegrationError> {
/// let first = h1(5, 0.3, 1.5, -2.0, 2.0)?;
/// let second = hupdate(5, 0.5, 2.5, -2.0, 2.0, 0.3, 1.5, &first)?;
/// assert_eq!(second.points.len(), second.values.len());
/// # Ok(())
/// # }
/// ```
// The positional signature is public API; bundling the arguments into a struct
// would break existing callers.
#[allow(clippy::too_many_arguments)]
pub fn hupdate(
    r: usize,
    theta: f64,
    info: f64,
    a: f64,
    b: f64,
    theta_prev: f64,
    info_prev: f64,
    prev: &DensityGrid,
) -> Result<DensityGrid, IntegrationError> {
    // NaN fails every comparison, so it has to be rejected explicitly.
    if info_prev.is_nan() || info_prev <= 0.0 {
        return Err(IntegrationError::NonPositivePreviousInformation { info_prev });
    }
    if info.is_nan() || info <= info_prev {
        return Err(IntegrationError::NonIncreasingInformation { info_prev, info });
    }
    // Guarantees prev.points and prev.values line up for the zip below.
    prev.validate()?;

    let rt_info = info.sqrt();
    let rt_info_prev = info_prev.sqrt();
    let rt_delta = (info - info_prev).sqrt();

    let IntegrationGrid { points, weights } = gridpts(r, theta * rt_info, a, b)?;

    let mu = theta * info - theta_prev * info_prev;
    let scale = rt_info / rt_delta;

    // Previous nodes mapped onto the increment scale. They do not depend on the
    // new node, so they are computed once, outside the double loop.
    let shifted: Vec<f64> = prev
        .points
        .iter()
        .map(|&z| (z * rt_info_prev + mu) / rt_delta)
        .collect();

    let values = points
        .iter()
        .zip(&weights)
        .map(|(&z, &w)| {
            let base = z * scale;
            // Sequential left-to-right accumulation starting at +0.0.
            let density = prev
                .values
                .iter()
                .zip(&shifted)
                .fold(0.0, |acc, (&h, &s)| acc + h * normal_pdf(base - s));
            density * (w * scale)
        })
        .collect();

    let updated = DensityGrid {
        points,
        weights,
        values,
    };
    updated.validate()?;
    Ok(updated)
}