use crate::HarmonyError;
use ndarray::{Array1, ArrayView2};
/// Stacked one-hot encoding of `n_cov` categorical covariates over `n` cells.
///
/// Every covariate's levels are laid out one after another in a single index space
/// `0..b`. When built by [`Phi::from_codes`], each cell gets exactly one level per
/// covariate. Conceptually this is the sparse form of a `b x n` indicator matrix with
/// `n_cov` ones per column.
#[derive(Debug, Clone)]
pub struct Phi {
/// Total number of levels across all covariates; equals `offset[n_cov]`.
pub b: usize,
/// Number of cells (rows of the label matrix).
pub n: usize,
/// Number of covariates (columns of the label matrix).
pub n_cov: usize,
/// Prefix offsets of length `n_cov + 1`. Covariate `c` owns rows `offset[c]..offset[c + 1]`.
pub offset: Vec<usize>,
/// Column-major, length `n_cov * n`. Entry `c * n + i` is the level row (in `0..b`)
/// of cell `i` under covariate `c`.
pub row_of_cell: Vec<u32>,
/// Per-covariate level names. `level_names[c]` has one entry per row of covariate `c`;
/// [`Phi::from_codes`] synthesizes them as `c{c}_l{l}`.
pub level_names: Vec<Vec<String>>,
}
impl Phi {
    /// Builds a [`Phi`] from an `n x n_cov` matrix of integer category codes.
    ///
    /// Column `c` of `labels` holds the code of covariate `c` for every cell. Codes are
    /// treated as dense level indices:
    /// - Covariate `c` gets `max(code) + 1` levels. Levels that never occur are kept and
    ///   simply have a count of zero.
    /// - Those levels occupy rows `offset[c]..offset[c + 1]` of the stacked level space.
    /// - Level names are synthesized as `c{c}_l{l}`.
    ///
    /// Because levels are dense, memory for `level_names` grows with the largest code,
    /// not with the number of distinct codes. Callers with sparse codes should remap them
    /// to `0..k` first.
    ///
    /// # Errors
    ///
    /// Returns [`HarmonyError::ShapeMismatch`] in two cases:
    /// - `labels` has zero rows or zero columns.
    /// - The stacked level count is so large that a row index in `0..b` would not fit in
    ///   the `u32` storage of `row_of_cell`.
    pub fn from_codes(labels: ArrayView2<'_, u32>) -> Result<Self, HarmonyError> {
        let (n, n_cov) = labels.dim();
        if n == 0 || n_cov == 0 {
            return Err(HarmonyError::ShapeMismatch(format!(
                "labels is {n}x{n_cov}; must be non-empty in both dims"
            )));
        }

        // Compute and validate every offset before allocating anything proportional to
        // the number of levels. An out-of-range code then fails fast, instead of first
        // building billions of level names and overflowing afterwards.
        let mut offset = Vec::with_capacity(n_cov + 1);
        offset.push(0usize);
        for c in 0..n_cov {
            // `n > 0`, so every column is non-empty and `max()` is always `Some`.
            let max_code = labels.column(c).iter().copied().max().unwrap_or(0) as usize;
            // The largest row index of this covariate is `end - 1`. It must fit in `u32`
            // because `row_of_cell` stores rows as `u32`. The checked adds also guard
            // against `usize` overflow on 32-bit targets.
            let end = max_code
                .checked_add(1)
                .and_then(|n_levels| offset[c].checked_add(n_levels))
                .filter(|&end| end - 1 <= u32::MAX as usize)
                .ok_or_else(|| {
                    HarmonyError::ShapeMismatch(format!(
                        "covariate {c} has max code {max_code}; total level count exceeds the u32 row-index range"
                    ))
                })?;
            offset.push(end);
        }
        let b = offset[n_cov];

        let level_names: Vec<Vec<String>> = offset
            .windows(2)
            .enumerate()
            .map(|(c, bounds)| {
                let n_levels = bounds[1] - bounds[0];
                (0..n_levels).map(|l| format!("c{c}_l{l}")).collect()
            })
            .collect();

        // Column-major layout: entry `c * n + i` is the row of cell `i` under covariate `c`.
        let mut row_of_cell = Vec::with_capacity(n_cov * n);
        for c in 0..n_cov {
            // Validated above: `offset[c] + code <= b - 1 <= u32::MAX`, so neither the
            // cast nor the addition can overflow.
            let base = offset[c] as u32;
            row_of_cell.extend(labels.column(c).iter().map(|&code| base + code));
        }

        Ok(Self {
            b,
            n,
            n_cov,
            offset,
            row_of_cell,
            level_names,
        })
    }

    /// Returns the number of cells assigned to each of the `b` levels.
    ///
    /// Only the first `n_cov * n` entries of `row_of_cell` are read.
    ///
    /// # Panics
    ///
    /// Panics if `row_of_cell` is shorter than `n_cov * n`, or holds a row index `>= b`.
    /// Neither can happen for a value built by [`Phi::from_codes`].
    pub fn n_b(&self) -> Array1<f64> {
        let mut out = Array1::<f64>::zeros(self.b);
        for &row in &self.row_of_cell[..self.n_cov * self.n] {
            out[row as usize] += 1.0;
        }
        out
    }

    /// Returns the fraction of the `n` cells that fall in each level, i.e. `n_b() / n`.
    ///
    /// When built by [`Phi::from_codes`], the entries belonging to any single covariate
    /// sum to 1.
    pub fn pr_b(&self) -> Array1<f64> {
        self.n_b() / (self.n as f64)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// One covariate: three levels, counts follow code frequencies.
    #[test]
    fn single_covariate_three_levels() {
        let codes = array![[0u32], [1], [2], [0], [1]];
        let phi = Phi::from_codes(codes.view()).unwrap();

        assert_eq!((phi.b, phi.n, phi.n_cov), (3, 5, 1));
        assert_eq!(phi.offset, [0, 3]);
        assert_eq!(phi.row_of_cell, [0, 1, 2, 0, 1]);
        assert_eq!(phi.n_b().to_vec(), [2.0, 2.0, 1.0]);
    }

    /// Two covariates: the second covariate's rows start after the first's.
    #[test]
    fn two_covariates_offsets_stack() {
        let codes = array![[0u32, 0], [0, 1], [1, 0], [1, 1]];
        let phi = Phi::from_codes(codes.view()).unwrap();

        assert_eq!((phi.b, phi.n, phi.n_cov), (4, 4, 2));
        assert_eq!(phi.offset, [0, 2, 4]);
        assert_eq!(phi.row_of_cell, [0, 0, 1, 1, 2, 3, 2, 3]);
    }

    /// A label matrix with no cells must be refused.
    #[test]
    fn empty_labels_rejected() {
        let codes = ndarray::Array2::<u32>::zeros((0, 1));
        let outcome = Phi::from_codes(codes.view());
        assert!(outcome.is_err());
    }
}