use crate::error::TransformError;
/// Tunable parameters for a [`PHATE`] run.
///
/// Start from [`PHATEParams::default`] and override individual fields, or use
/// the `with_*` builder methods on [`PHATE`].
#[derive(Debug, Clone)]
pub struct PHATEParams {
/// Requested number of embedding dimensions.
pub n_components: usize,
/// Neighbour rank used to pick each sample's kernel bandwidth (distance to
/// its `k`-th nearest neighbour). `fit_transform` clamps it to `1..=n-1`.
pub k: usize,
/// NOTE(review): not read by any code visible in this file; presumably
/// reserved for a landmark-based approximation — confirm before relying on it.
pub n_landmark: usize,
/// Number of diffusion steps, i.e. the power the row-normalised kernel is
/// raised to. Values of 0 and 1 both leave the kernel unchanged.
pub t: usize,
/// Constant set through `PHATE::with_gamma`; presumably the PHATE
/// informational-distance parameter — confirm how it is consumed against
/// the `PHATE` implementation.
pub gamma: f64,
}
impl Default for PHATEParams {
fn default() -> Self {
Self {
n_components: 2,
k: 5,
n_landmark: 2000,
t: 1,
gamma: 1.0,
}
}
}
/// Output of `PHATE::fit_transform`.
#[derive(Debug, Clone)]
pub struct PHATEResult {
/// Embedding coordinates, one row per input sample.
pub embedding: Vec<Vec<f64>>,
/// Pairwise potential-distance matrix between samples (`n x n`); for a
/// single-sample input it is the 1x1 matrix `[[0.0]]`.
pub potential: Vec<Vec<f64>>,
/// The diffusion time `t` taken from the parameters used for the run.
pub diffusion_time: usize,
}
/// PHATE embedding estimator; holds only its configuration and is driven
/// through `fit_transform`.
pub struct PHATE {
/// Parameters applied by every call to `fit_transform`.
pub params: PHATEParams,
}
impl PHATE {
pub fn new(n_components: usize) -> Self {
let mut params = PHATEParams::default();
params.n_components = n_components;
Self { params }
}
pub fn with_k(mut self, k: usize) -> Self {
self.params.k = k;
self
}
pub fn with_t(mut self, t: usize) -> Self {
self.params.t = t;
self
}
pub fn with_gamma(mut self, gamma: f64) -> Self {
self.params.gamma = gamma;
self
}
pub fn fit_transform(&self, data: &[Vec<f64>]) -> Result<PHATEResult, TransformError> {
let n = data.len();
if n == 0 {
return Err(TransformError::InvalidInput(
"PHATE requires at least one sample".into(),
));
}
if n == 1 {
return Ok(PHATEResult {
embedding: vec![vec![0.0; self.params.n_components]],
potential: vec![vec![0.0]],
diffusion_time: self.params.t,
});
}
let k = self.params.k.min(n - 1).max(1);
let kernel = self.compute_markov_kernel(data, k);
let diffused = self.diffuse_kernel(&kernel, n, self.params.t);
let potential = self.compute_potential_distances(&diffused, n);
let embedding = self.classical_mds(&potential, n)?;
Ok(PHATEResult {
embedding,
potential,
diffusion_time: self.params.t,
})
}
fn compute_markov_kernel(&self, data: &[Vec<f64>], k: usize) -> Vec<Vec<f64>> {
let n = data.len();
let sq_dists: Vec<Vec<f64>> = (0..n)
.map(|i| {
(0..n)
.map(|j| {
data[i]
.iter()
.zip(data[j].iter())
.map(|(a, b)| (a - b).powi(2))
.sum::<f64>()
})
.collect()
})
.collect();
let bandwidths: Vec<f64> = (0..n)
.map(|i| {
let mut sorted_dists: Vec<f64> = (0..n)
.filter(|&j| j != i)
.map(|j| sq_dists[i][j].sqrt())
.collect();
sorted_dists.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
sorted_dists
.get(k.saturating_sub(1))
.copied()
.unwrap_or(1.0)
.max(1e-10)
})
.collect();
let mut kernel: Vec<Vec<f64>> = (0..n)
.map(|i| {
(0..n)
.map(|j| {
let denom = bandwidths[i] * bandwidths[j];
(-sq_dists[i][j] / denom.max(1e-20)).exp()
})
.collect()
})
.collect();
for row in &mut kernel {
let s: f64 = row.iter().sum::<f64>().max(1e-15);
for v in row.iter_mut() {
*v /= s;
}
}
kernel
}
fn diffuse_kernel(&self, p: &[Vec<f64>], n: usize, t: usize) -> Vec<Vec<f64>> {
if t <= 1 {
return p.to_vec();
}
let mut result = p.to_vec();
for _ in 1..t {
let new_result = mat_mul(&result, p, n);
result = new_result;
}
result
}
fn compute_potential_distances(&self, diff: &[Vec<f64>], n: usize) -> Vec<Vec<f64>> {
let eps = 1e-7;
let u: Vec<Vec<f64>> = diff
.iter()
.map(|row| row.iter().map(|&p| -(p + eps).ln()).collect())
.collect();
(0..n)
.map(|i| {
(0..n)
.map(|j| {
u[i].iter()
.zip(u[j].iter())
.map(|(a, b)| (a - b).powi(2))
.sum::<f64>()
.sqrt()
})
.collect()
})
.collect()
}
fn classical_mds(
&self,
distances: &[Vec<f64>],
n: usize,
) -> Result<Vec<Vec<f64>>, TransformError> {
let k = self.params.n_components.min(n - 1).max(1);
let d2: Vec<Vec<f64>> = distances
.iter()
.map(|row| row.iter().map(|x| x * x).collect())
.collect();
let row_means: Vec<f64> = d2
.iter()
.map(|row| row.iter().sum::<f64>() / n as f64)
.collect();
let grand_mean: f64 = row_means.iter().sum::<f64>() / n as f64;
let mut b: Vec<Vec<f64>> = (0..n)
.map(|i| {
(0..n)
.map(|j| {
-0.5 * (d2[i][j] - row_means[i] - row_means[j] + grand_mean)
})
.collect()
})
.collect();
let mut embedding: Vec<Vec<f64>> = vec![vec![0.0; k]; n];
for comp in 0..k {
let mut v: Vec<f64> = (0..n).map(|i| (i as f64 + 1.0) / n as f64).collect();
normalize_inplace(&mut v);
let mut eigenval = 0.0f64;
for _iter in 0..300 {
let bv = mat_vec_mul(&b, &v, n);
let new_ev: f64 = v.iter().zip(bv.iter()).map(|(vi, bvi)| vi * bvi).sum();
let new_norm = bv
.iter()
.map(|x| x * x)
.sum::<f64>()
.sqrt()
.max(1e-15);
let new_v: Vec<f64> = bv.iter().map(|x| x / new_norm).collect();
let delta = (new_ev - eigenval).abs();
eigenval = new_ev;
v = new_v;
if delta < 1e-12 {
break;
}
}
let scale = eigenval.abs().sqrt();
for i in 0..n {
embedding[i][comp] = scale * v[i];
}
for i in 0..n {
for j in 0..n {
b[i][j] -= eigenval * v[i] * v[j];
}
}
}
Ok(embedding)
}
}
/// Multiplies the leading `n x n` parts of `a` and `b`, returning `a * b`.
///
/// Entries of `a` whose magnitude is below `1e-15` are treated as zero and
/// skipped, so their contribution is dropped regardless of the size of the
/// matching row of `b`.
///
/// Panics if either matrix is smaller than `n x n` along an accessed index.
fn mat_mul(a: &[Vec<f64>], b: &[Vec<f64>], n: usize) -> Vec<Vec<f64>> {
    let mut product = vec![vec![0.0f64; n]; n];
    for (row_idx, out_row) in product.iter_mut().enumerate() {
        let lhs_row = &a[row_idx];
        for inner in 0..n {
            let weight = lhs_row[inner];
            // Written as a `<` test so a NaN weight is not skipped.
            if weight.abs() < 1e-15 {
                continue;
            }
            let rhs_row = &b[inner][..n];
            for (acc, &rhs) in out_row.iter_mut().zip(rhs_row.iter()) {
                *acc += weight * rhs;
            }
        }
    }
    product
}
/// Returns the product of the first `n` rows of `a` with the vector `x`.
///
/// Each output entry is the dot product of a full row with `x`; if the two
/// differ in length the longer one is truncated by `zip`.
///
/// Panics if `a` has fewer than `n` rows.
fn mat_vec_mul(a: &[Vec<f64>], x: &[f64], n: usize) -> Vec<f64> {
    let mut product = Vec::with_capacity(n);
    for row in &a[..n] {
        let dot: f64 = row.iter().zip(x).map(|(coef, xj)| coef * xj).sum();
        product.push(dot);
    }
    product
}
/// Scales `v` in place to unit Euclidean length.
///
/// The norm is floored at `1e-15`, so an all-zero input stays all-zero
/// instead of producing NaN. Takes a slice so it works on vectors, arrays and
/// sub-slices alike; existing `&mut Vec<f64>` call sites coerce automatically.
fn normalize_inplace(v: &mut [f64]) {
    let norm = v.iter().map(|x| x * x).sum::<f64>().sqrt().max(1e-15);
    for x in v.iter_mut() {
        *x /= norm;
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `n` samples on a helix: unit circle in the first two
    /// coordinates, linear ramp (`0.05` per sample) in the third.
    fn make_data(n: usize) -> Vec<Vec<f64>> {
        let mut samples = Vec::with_capacity(n);
        for i in 0..n {
            let angle = std::f64::consts::TAU * i as f64 / n as f64;
            samples.push(vec![angle.cos(), angle.sin(), (i as f64) * 0.05]);
        }
        samples
    }

    /// The embedding has one row per sample and the requested width.
    #[test]
    fn test_phate_output_shape() {
        let samples = make_data(10);
        let model = PHATE::new(2).with_k(3);
        let result = model.fit_transform(&samples).expect("PHATE fit_transform");
        assert_eq!(result.embedding.len(), 10, "wrong number of samples");
        for row in &result.embedding {
            assert!(row.len() == 2, "every row should have 2 dimensions");
        }
    }

    /// The potential matrix is square with side `n`.
    #[test]
    fn test_phate_potential_shape() {
        let samples = make_data(8);
        let model = PHATE::new(2).with_k(2);
        let result = model.fit_transform(&samples).expect("PHATE");
        assert_eq!(result.potential.len(), 8);
        for row in &result.potential {
            assert!(row.len() == 8);
        }
    }

    /// No NaN or infinity leaks into the embedding.
    #[test]
    fn test_phate_embedding_finite() {
        let samples = make_data(12);
        let model = PHATE::new(2).with_k(3);
        let result = model.fit_transform(&samples).expect("PHATE");
        for &v in result.embedding.iter().flatten() {
            assert!(v.is_finite(), "embedding contains non-finite value: {v}");
        }
    }

    /// Potential distances are never negative.
    #[test]
    fn test_phate_potential_nonneg() {
        let samples = make_data(8);
        let model = PHATE::new(2).with_k(2);
        let result = model.fit_transform(&samples).expect("PHATE");
        for &v in result.potential.iter().flatten() {
            assert!(v >= 0.0, "negative potential distance: {v}");
        }
    }

    /// Requesting three components yields three columns.
    #[test]
    fn test_phate_3d_output() {
        let samples = make_data(10);
        let model = PHATE::new(3).with_k(3);
        let result = model.fit_transform(&samples).expect("PHATE 3d");
        assert_eq!(result.embedding.len(), 10);
        for row in &result.embedding {
            assert!(row.len() == 3);
        }
    }

    /// The configured diffusion time is echoed back in the result.
    #[test]
    fn test_phate_diffusion_time() {
        let samples = make_data(10);
        let model = PHATE::new(2).with_k(3).with_t(3);
        let result = model.fit_transform(&samples).expect("PHATE t=3");
        assert_eq!(result.diffusion_time, 3);
    }

    /// An empty data set is an error rather than a panic.
    #[test]
    fn test_phate_empty_input() {
        let outcome = PHATE::new(2).fit_transform(&[]);
        assert!(outcome.is_err(), "empty input should return Err");
    }

    /// A single sample is embedded as one row of the requested width.
    #[test]
    fn test_phate_single_sample() {
        let samples = vec![vec![1.0, 2.0]];
        let result = PHATE::new(2).fit_transform(&samples).expect("single sample");
        assert_eq!(result.embedding.len(), 1);
        assert_eq!(result.embedding[0].len(), 2);
    }
}