use std::collections::BTreeMap;
use crate::error::{InferustError, Result};
use crate::glm::{Logistic, LogisticResult, Poisson, PoissonResult};
use crate::regression::{Ols, OlsResult, Wls};
/// A parsed model formula of the form `response ~ term1 + term2 + ...`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Formula {
    /// Name of the response (left-hand side) column.
    pub response: String,
    /// Predictor column names, in the order they appear after `~`.
    pub predictors: Vec<String>,
}
impl Formula {
    /// Parses a Wilkinson-style formula string, e.g. `"y ~ x1 + x2"`.
    ///
    /// The response and each predictor term are trimmed of surrounding
    /// whitespace; empty terms produced by stray `+` signs are ignored.
    ///
    /// # Errors
    ///
    /// Returns `InferustError::InvalidInput` when:
    /// - the input contains no `~` separator,
    /// - the response (left-hand side) is empty,
    /// - the right-hand side contains a second `~`,
    /// - no predictor terms remain after trimming,
    /// - a term is the explicit intercept `1` (the intercept is added by the
    ///   model builders, not the formula), or
    /// - the same predictor appears more than once.
    pub fn parse(input: &str) -> Result<Self> {
        let (lhs, rhs) = input
            .split_once('~')
            .ok_or_else(|| InferustError::InvalidInput("formula must contain `~`".into()))?;
        let response = lhs.trim();
        if response.is_empty() {
            return Err(InferustError::InvalidInput(
                "formula response cannot be empty".into(),
            ));
        }
        // A second `~` would otherwise be silently absorbed into a predictor
        // name and only surface much later as an "unknown column" lookup
        // error; fail fast with a clear message instead.
        if rhs.contains('~') {
            return Err(InferustError::InvalidInput(
                "formula must contain exactly one `~`".into(),
            ));
        }
        let predictors = rhs
            .split('+')
            .map(str::trim)
            .filter(|term| !term.is_empty())
            .map(ToString::to_string)
            .collect::<Vec<_>>();
        if predictors.is_empty() {
            return Err(InferustError::InvalidInput(
                "formula must contain at least one predictor".into(),
            ));
        }
        if predictors.iter().any(|term| term == "1") {
            return Err(InferustError::InvalidInput(
                "explicit intercept terms are not supported yet; intercept is added by model builders"
                    .into(),
            ));
        }
        // A repeated term would duplicate a design-matrix column and fail
        // downstream as a confusing singular-matrix error; reject it here.
        for (idx, term) in predictors.iter().enumerate() {
            if predictors[..idx].contains(term) {
                return Err(InferustError::InvalidInput(format!(
                    "predictor `{term}` appears more than once in formula"
                )));
            }
        }
        Ok(Self {
            response: response.to_string(),
            predictors,
        })
    }
}
/// Row-major design matrices built from a [`Formula`] and a [`DataFrame`].
#[derive(Debug, Clone)]
pub struct DesignMatrices {
    /// Predictor rows: `x[i]` holds the predictor values for observation `i`.
    /// No intercept column is included; model builders add the intercept.
    pub x: Vec<Vec<f64>>,
    /// Response values, one per observation.
    pub y: Vec<f64>,
    /// Names of the columns of `x`, in matching order.
    pub predictor_names: Vec<String>,
}
/// A minimal column-oriented table of `f64` data keyed by column name.
#[derive(Debug, Clone, Default)]
pub struct DataFrame {
    // BTreeMap keeps column ordering deterministic by name.
    columns: BTreeMap<String, Vec<f64>>,
    // Row count fixed by the first column added; `None` while the frame is empty.
    nrows: Option<usize>,
}
impl DataFrame {
    /// Creates an empty frame; populate it with [`DataFrame::add_column`] or
    /// the chaining [`DataFrame::with_column`].
    pub fn new() -> Self {
        Self::default()
    }

    /// Chaining variant of [`DataFrame::add_column`] that consumes and returns
    /// the frame, for builder-style construction.
    pub fn with_column(mut self, name: impl Into<String>, values: Vec<f64>) -> Result<Self> {
        self.add_column(name, values)?;
        Ok(self)
    }

    /// Adds (or replaces) a named column.
    ///
    /// The first column inserted fixes the frame's row count; every later
    /// column must have the same length.
    ///
    /// # Errors
    ///
    /// - `InvalidInput` when `name` is empty or whitespace-only.
    /// - `InsufficientData` when `values` is empty.
    /// - `DimensionMismatch` when `values.len()` differs from the established
    ///   row count.
    pub fn add_column(&mut self, name: impl Into<String>, values: Vec<f64>) -> Result<()> {
        let name = name.into();
        if name.trim().is_empty() {
            return Err(InferustError::InvalidInput(
                "column name cannot be empty".into(),
            ));
        }
        if values.is_empty() {
            return Err(InferustError::InsufficientData { needed: 1, got: 0 });
        }
        if let Some(nrows) = self.nrows {
            if values.len() != nrows {
                return Err(InferustError::DimensionMismatch {
                    x_rows: values.len(),
                    y_len: nrows,
                });
            }
        } else {
            // First column establishes the row count for the whole frame.
            self.nrows = Some(values.len());
        }
        self.columns.insert(name, values);
        Ok(())
    }

    /// Number of rows, or 0 when no column has been added yet.
    pub fn nrows(&self) -> usize {
        self.nrows.unwrap_or(0)
    }

    /// Borrows a column by name.
    ///
    /// # Errors
    ///
    /// `InvalidInput` when no column with that name exists.
    pub fn column(&self, name: &str) -> Result<&[f64]> {
        self.columns
            .get(name)
            .map(Vec::as_slice)
            .ok_or_else(|| InferustError::InvalidInput(format!("unknown column `{name}`")))
    }

    /// Builds the response vector and (intercept-free) predictor matrix for a
    /// formula over all-numeric columns.
    ///
    /// This is exactly [`DataFrame::design_matrices_with_categorical`] with an
    /// empty categorical list; delegating keeps a single expansion code path
    /// instead of duplicating the row-assembly loop.
    ///
    /// # Errors
    ///
    /// Propagates formula-parse errors and unknown-column lookups.
    pub fn design_matrices(&self, formula: &str) -> Result<DesignMatrices> {
        self.design_matrices_with_categorical(formula, &[])
    }

    /// Builds design matrices, expanding each column named in `categorical`
    /// into treatment-coded 0/1 dummy variables.
    ///
    /// Levels are the distinct values of the column, sorted ascending. The
    /// smallest level is the baseline (dropped), and each remaining level `v`
    /// becomes a column named `name[T.v]` that is 1.0 where the row equals `v`.
    ///
    /// NOTE(review): NaN values in a categorical column are not meaningfully
    /// handled (they sort as "equal" and never match a level) — confirm
    /// callers pre-filter NaNs.
    ///
    /// # Errors
    ///
    /// Propagates formula-parse errors and unknown-column lookups.
    pub fn design_matrices_with_categorical(
        &self,
        formula: &str,
        categorical: &[&str],
    ) -> Result<DesignMatrices> {
        let formula = Formula::parse(formula)?;
        let y = self.column(&formula.response)?.to_vec();
        let nrows = self.nrows();
        let mut x = vec![Vec::new(); nrows];
        let mut predictor_names = Vec::new();
        for name in &formula.predictors {
            let column = self.column(name)?;
            if categorical.iter().any(|candidate| candidate == name) {
                // Distinct sorted levels of the categorical column.
                let mut levels = column.to_vec();
                levels.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
                levels.dedup_by(|a, b| (*a - *b).abs() < f64::EPSILON);
                // Treatment coding: skip the first (baseline) level.
                for level in levels.iter().skip(1) {
                    predictor_names.push(format!("{name}[T.{level}]"));
                    for row_idx in 0..nrows {
                        x[row_idx]
                            .push(f64::from((column[row_idx] - *level).abs() < f64::EPSILON));
                    }
                }
            } else {
                // Numeric column: copy values through unchanged.
                predictor_names.push(name.clone());
                for row_idx in 0..nrows {
                    x[row_idx].push(column[row_idx]);
                }
            }
        }
        Ok(DesignMatrices {
            x,
            y,
            predictor_names,
        })
    }

    /// Fits ordinary least squares for `formula`, treatment-coding the
    /// columns named in `categorical`.
    pub fn ols_with_categorical(&self, formula: &str, categorical: &[&str]) -> Result<OlsResult> {
        let design = self.design_matrices_with_categorical(formula, categorical)?;
        Ols::new()
            .with_feature_names(design.predictor_names)
            .fit(&design.x, &design.y)
    }

    /// Fits ordinary least squares for `formula`.
    pub fn ols(&self, formula: &str) -> Result<OlsResult> {
        let design = self.design_matrices(formula)?;
        Ols::new()
            .with_feature_names(design.predictor_names)
            .fit(&design.x, &design.y)
    }

    /// Fits weighted least squares; `weights` names the column of per-row
    /// weights.
    pub fn wls(&self, formula: &str, weights: &str) -> Result<OlsResult> {
        let design = self.design_matrices(formula)?;
        let weights = self.column(weights)?;
        Wls::new()
            .with_feature_names(design.predictor_names)
            .fit(&design.x, &design.y, weights)
    }

    /// Fits logistic regression for `formula`.
    pub fn logistic(&self, formula: &str) -> Result<LogisticResult> {
        let design = self.design_matrices(formula)?;
        Logistic::new()
            .with_feature_names(design.predictor_names)
            .fit(&design.x, &design.y)
    }

    /// Fits Poisson regression for `formula`.
    pub fn poisson(&self, formula: &str) -> Result<PoissonResult> {
        let design = self.design_matrices(formula)?;
        Poisson::new()
            .with_feature_names(design.predictor_names)
            .fit(&design.x, &design.y)
    }
}
#[cfg(test)]
mod tests {
    use super::{DataFrame, Formula};

    // Asserts `actual` is within `tolerance` of `expected`, with a readable
    // failure message.
    fn assert_close(actual: f64, expected: f64, tolerance: f64) {
        assert!(
            (actual - expected).abs() <= tolerance,
            "actual {actual} differed from expected {expected} by more than {tolerance}"
        );
    }

    // Shared six-row fixture: two predictors, a response, and a weight column.
    fn frame() -> DataFrame {
        DataFrame::new()
            .with_column("x1", vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
            .unwrap()
            .with_column("x2", vec![2.0, 1.0, 4.0, 3.0, 5.0, 7.0])
            .unwrap()
            .with_column("y", vec![5.1, 5.9, 10.2, 10.8, 14.9, 19.1])
            .unwrap()
            .with_column("weights", vec![1.0, 0.8, 1.2, 1.5, 0.9, 1.1])
            .unwrap()
    }

    #[test]
    fn parses_basic_formula() {
        let formula = Formula::parse("y ~ x1 + x2").unwrap();
        assert_eq!(formula.response, "y");
        assert_eq!(formula.predictors, vec!["x1", "x2"]);
    }

    #[test]
    fn builds_design_matrices_from_named_columns() {
        let design = frame().design_matrices("y ~ x1 + x2").unwrap();
        assert_eq!(design.predictor_names, vec!["x1", "x2"]);
        assert_eq!(design.y[0], 5.1);
        // Rows are observation-major: [x1, x2] per row.
        assert_eq!(design.x[0], vec![1.0, 2.0]);
        assert_eq!(design.x[5], vec![6.0, 7.0]);
    }

    #[test]
    fn categorical_formula_expands_treatment_dummies() {
        let frame = DataFrame::new()
            .with_column("group", vec![1.0, 1.0, 2.0, 2.0, 3.0, 3.0])
            .unwrap()
            .with_column("x", vec![0.0, 1.0, 0.0, 1.0, 0.0, 1.0])
            .unwrap()
            .with_column("y", vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
            .unwrap();
        let design = frame
            .design_matrices_with_categorical("y ~ group + x", &["group"])
            .unwrap();
        // Level 1 is the baseline; levels 2 and 3 become dummy columns.
        assert_eq!(
            design.predictor_names,
            vec!["group[T.2]", "group[T.3]", "x"]
        );
        // Row 2 is group==2, x==0.
        assert_eq!(design.x[2], vec![1.0, 0.0, 0.0]);
    }

    #[test]
    fn formula_ols_matches_matrix_ols_reference() {
        // Reference coefficients computed from the matrix-based OLS API;
        // coefficient 0 is the intercept added by the model builder.
        let result = frame().ols("y ~ x1 + x2").unwrap();
        assert_close(result.coefficients[0], 1.1666007905138316, 1e-10);
        assert_close(result.coefficients[1], 1.656126482213441, 1e-10);
        assert_close(result.coefficients[2], 1.100988142292489, 1e-10);
        assert_eq!(result.feature_names, vec!["const", "x1", "x2"]);
    }

    #[test]
    fn formula_wls_matches_matrix_wls_reference() {
        // Reference coefficients computed from the matrix-based WLS API.
        let result = frame().wls("y ~ x1 + x2", "weights").unwrap();
        assert_close(result.coefficients[0], 1.0910621653414276, 1e-10);
        assert_close(result.coefficients[1], 1.6265313140792843, 1e-10);
        assert_close(result.coefficients[2], 1.139502728692733, 1e-10);
        assert_eq!(result.feature_names, vec!["const", "x1", "x2"]);
    }

    #[test]
    fn formula_poisson_fits_named_columns() {
        let frame = DataFrame::new()
            .with_column(
                "x1",
                vec![0.2, 0.8, 1.2, 1.9, 2.4, 2.9, 3.4, 3.9, 4.5, 5.0, 5.5, 6.0],
            )
            .unwrap()
            .with_column(
                "x2",
                vec![1.0, 1.4, 1.1, 1.7, 2.2, 2.0, 2.8, 3.1, 3.5, 3.8, 4.0, 4.4],
            )
            .unwrap()
            .with_column(
                "y",
                vec![
                    1.0, 2.0, 1.0, 3.0, 4.0, 3.0, 6.0, 7.0, 8.0, 11.0, 12.0, 15.0,
                ],
            )
            .unwrap();
        // Only the intercept is pinned; iterative fits get a looser tolerance.
        let result = frame.poisson("y ~ x1 + x2").unwrap();
        assert_close(result.coefficients[0], -0.2951503394477173, 1e-8);
        assert_eq!(result.feature_names, vec!["const", "x1", "x2"]);
    }

    #[test]
    fn formula_logistic_fits_named_columns() {
        let frame = DataFrame::new()
            .with_column(
                "x1",
                vec![0.2, 1.1, 1.8, 2.4, 3.0, 3.7, 4.1, 4.8, 5.2, 5.9, 2.2, 4.6],
            )
            .unwrap()
            .with_column(
                "x2",
                vec![1.0, 0.9, 1.5, 1.9, 2.5, 2.9, 3.4, 3.8, 4.2, 4.8, 3.6, 1.2],
            )
            .unwrap()
            .with_column(
                "y",
                vec![0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 0.0, 1.0],
            )
            .unwrap();
        // Only the intercept is pinned; iterative fits get a looser tolerance.
        let result = frame.logistic("y ~ x1 + x2").unwrap();
        assert_close(result.coefficients[0], -1.7689272112231273, 1e-8);
        assert_eq!(result.feature_names, vec!["const", "x1", "x2"]);
    }
}