use std::sync::Arc;
use nalgebra::{DMatrix, DVector};
use rand::Rng;
use rand_distr::{Distribution, WeightedIndex};
use rayon::prelude::*;
use serde_derive::{Deserialize, Serialize};
use crate::dataset::{Dataset, MaskedSample};
use crate::ppca_model::{self, InferredMasked, PPCAModel};
use crate::Prior;
/// Computes the log-softmax of `data` in a numerically robust way: the
/// maximum entry is subtracted before exponentiating so the sum cannot
/// overflow, and added back inside the log-normalizer.
fn robust_log_softmax(data: DVector<f64>) -> DVector<f64> {
    let shift = data.max();
    let total: f64 = data.iter().map(|&value| (value - shift).exp()).sum();
    let log_total = total.ln();
    data.map(|value| value - shift - log_total)
}
/// Computes `ln(sum(exp(data)))` — the log-normalizer of a softmax — using
/// the same max-subtraction trick as [`robust_log_softmax`] for numerical
/// stability.
fn robust_log_softnorm(data: DVector<f64>) -> f64 {
    let shift = data.max();
    let sum_exp: f64 = data.iter().map(|&value| (value - shift).exp()).sum();
    shift + sum_exp.ln()
}
/// Shared, immutable state behind a [`PPCAMix`] handle.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct PPCAMixInner {
/// Dimensionality of the observation space, common to all component models.
output_size: usize,
/// The component PPCA models of the mixture.
models: Vec<PPCAModel>,
/// Normalized log mixture weights (log-softmax of the weights supplied at
/// construction); one entry per model.
log_weights: DVector<f64>,
}
/// A mixture of PPCA models sharing the same output size. Cheap to clone:
/// the inner state is shared behind an [`Arc`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PPCAMix(Arc<PPCAMixInner>);
impl PPCAMix {
pub fn new(models: Vec<PPCAModel>, log_weights: DVector<f64>) -> PPCAMix {
assert!(models.len() > 0);
assert_eq!(models.len(), log_weights.len());
let output_sizes = models
.iter()
.map(PPCAModel::output_size)
.collect::<Vec<_>>();
let mut unique_sizes = output_sizes.clone();
unique_sizes.dedup();
assert_eq!(
unique_sizes.len(),
1,
"Model output sizes are not the same: {output_sizes:?}"
);
PPCAMix(Arc::new(PPCAMixInner {
output_size: unique_sizes[0],
models,
log_weights: robust_log_softmax(log_weights),
}))
}
pub fn init(n_models: usize, state_size: usize, dataset: &Dataset) -> PPCAMix {
PPCAMix::new(
(0..n_models)
.map(|_| PPCAModel::init(state_size, dataset))
.collect(),
vec![0.0; n_models].into(),
)
}
pub fn output_size(&self) -> usize {
self.0.output_size
}
pub fn state_sizes(&self) -> Vec<usize> {
self.0.models.iter().map(PPCAModel::state_size).collect()
}
pub fn n_parameters(&self) -> usize {
self.0
.models
.iter()
.map(PPCAModel::n_parameters)
.sum::<usize>()
+ self.0.models.len()
- 1
}
pub fn models(&self) -> &[PPCAModel] {
&self.0.models
}
pub fn log_weights(&self) -> &DVector<f64> {
&self.0.log_weights
}
pub fn weights(&self) -> DVector<f64> {
self.0.log_weights.map(f64::exp)
}
pub fn sample(&self, dataset_size: usize, mask_probability: f64) -> Dataset {
let index = WeightedIndex::new(self.0.log_weights.iter().copied().map(f64::exp))
.expect("can create WeighedIndex from distribution");
(0..dataset_size)
.into_par_iter()
.map(|_| {
let model_idx = index.sample(&mut rand::thread_rng());
self.0.models[model_idx].sample_one(mask_probability)
})
.collect()
}
pub(crate) fn llks_one(&self, sample: &MaskedSample) -> DVector<f64> {
self.0
.models
.iter()
.map(|model| model.llk_one(sample))
.collect::<Vec<_>>()
.into()
}
pub fn llk_one(&self, sample: &MaskedSample) -> f64 {
robust_log_softnorm(self.llks_one(sample) + &self.0.log_weights)
}
pub fn llks(&self, dataset: &Dataset) -> DVector<f64> {
dataset
.data
.par_iter()
.map(|sample| self.llk_one(sample))
.collect::<Vec<_>>()
.into()
}
pub fn llk(&self, dataset: &Dataset) -> f64 {
if dataset.is_empty() {
return 0.0;
}
dataset
.data
.par_iter()
.zip(&dataset.weights)
.map(|(sample, &weight)| weight * self.llk_one(sample))
.sum::<f64>()
}
pub fn infer_cluster(&self, dataset: &Dataset) -> DMatrix<f64> {
let rows: Vec<_> = dataset
.data
.par_iter()
.map(|sample| {
robust_log_softmax(self.llks_one(sample) + &self.0.log_weights).transpose()
})
.collect();
DMatrix::from_rows(&*rows)
}
pub fn uninferred(&self) -> InferredMaskedMix {
InferredMaskedMix {
log_posterior: self.log_weights().clone(),
inferred: self
.models()
.iter()
.map(|model| model.uninferred())
.collect::<Vec<_>>(),
}
}
pub fn infer_one(&self, sample: &MaskedSample) -> InferredMaskedMix {
InferredMaskedMix {
log_posterior: robust_log_softmax(self.llks_one(sample) + &self.0.log_weights),
inferred: self
.0
.models
.iter()
.map(|model| model.infer_one(sample))
.collect::<Vec<_>>(),
}
}
pub fn inferred_one(
&self,
log_posterior: DVector<f64>,
inferred: Vec<InferredMasked>,
) -> InferredMaskedMix {
InferredMaskedMix {
log_posterior,
inferred,
}
}
pub fn infer(&self, dataset: &Dataset) -> Vec<InferredMaskedMix> {
dataset
.data
.par_iter()
.map(|sample| self.infer_one(sample))
.collect()
}
pub fn smooth_one(&self, sample: &MaskedSample) -> MaskedSample {
MaskedSample::unmasked(self.infer_one(sample).smoothed(self))
}
pub fn smooth(&self, dataset: &Dataset) -> Dataset {
dataset
.data
.par_iter()
.map(|sample| self.smooth_one(sample))
.collect()
}
pub fn extrapolate_one(&self, sample: &MaskedSample) -> MaskedSample {
MaskedSample::unmasked(self.infer_one(sample).extrapolated(self, sample))
}
pub fn extrapolate(&self, dataset: &Dataset) -> Dataset {
dataset
.data
.par_iter()
.map(|sample| self.extrapolate_one(sample))
.collect()
}
#[must_use]
pub fn iterate(&self, dataset: &Dataset) -> PPCAMix {
self.iterate_with_prior(dataset, &Prior::default())
}
#[must_use]
pub fn iterate_with_prior(&self, dataset: &Dataset, prior: &Prior) -> PPCAMix {
let llks = self
.0
.models
.iter()
.map(|model| model.llks(dataset))
.collect::<Vec<_>>();
let log_posteriors = (0..dataset.len())
.into_par_iter()
.map(|idx| {
let llk: DVector<f64> = llks.iter().map(|llk| llk[idx]).collect::<Vec<_>>().into();
robust_log_softmax(llk + &self.0.log_weights)
})
.collect::<Vec<_>>();
let (iterated_models, log_weights): (Vec<_>, Vec<f64>) = self
.0
.models
.iter()
.enumerate()
.map(|(i, model)| {
let log_posteriors: Vec<_> = log_posteriors
.par_iter()
.zip(&dataset.weights)
.filter(|&(_, &wi)| wi > 0.0)
.map(|(lp, &wi)| wi.ln() + lp[i])
.collect();
let max_posterior: f64 = log_posteriors
.par_iter()
.filter_map(|&xi| ordered_float::NotNan::new(xi).ok())
.max()
.expect("dataset not empty")
.into();
let unnorm_posteriors: Vec<_> = log_posteriors
.par_iter()
.map(|&p| f64::exp(p - max_posterior))
.collect();
let logsum_posteriors =
unnorm_posteriors.iter().copied().sum::<f64>().ln() + max_posterior;
let dataset = dataset.with_weights(unnorm_posteriors);
(model.iterate_with_prior(&dataset, prior), logsum_posteriors)
})
.unzip();
PPCAMix(Arc::new(PPCAMixInner {
output_size: self.0.output_size,
models: iterated_models,
log_weights: robust_log_softmax(log_weights.into()),
}))
}
pub fn to_canonical(&self) -> PPCAMix {
PPCAMix(Arc::new(PPCAMixInner {
output_size: self.0.output_size,
models: self.0.models.iter().map(PPCAModel::to_canonical).collect(),
log_weights: self.0.log_weights.clone(),
}))
}
}
/// The result of running a [`PPCAMix`] inference on a masked sample: a
/// posterior over mixture components together with each component's own
/// inference result.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InferredMaskedMix {
/// Log-posterior over the mixture components for this sample.
log_posterior: DVector<f64>,
/// Per-component inference results, in the same order as the mixture's
/// models.
inferred: Vec<InferredMasked>,
}
impl InferredMaskedMix {
    /// The log-posterior over the mixture components for this sample.
    pub fn log_posterior(&self) -> &DVector<f64> {
        &self.log_posterior
    }

    /// The posterior over the mixture components (exponentiated
    /// log-posterior).
    pub fn posterior(&self) -> DVector<f64> {
        self.log_posterior.map(f64::exp)
    }

    /// The per-component inference results.
    pub fn sub_states(&self) -> &[InferredMasked] {
        &self.inferred
    }

    /// The posterior-weighted mean latent state across components.
    pub fn state(&self) -> DVector<f64> {
        // Weight by posterior *probabilities*, not log-posteriors: every
        // other aggregator here (`covariance`, `smoothed`, `extrapolated`)
        // uses `posterior()`, and `covariance` relies on this being the
        // posterior-weighted mean.
        let posterior = self.posterior();
        posterior
            .iter()
            .zip(&self.inferred)
            .map(|(&pi, inferred)| pi * inferred.state())
            .sum()
    }

    /// The covariance of the latent state under the mixture posterior:
    /// within-component covariance plus the between-component spread around
    /// the mixture mean (law of total covariance).
    pub fn covariance(&self) -> DMatrix<f64> {
        let mean = self.state();
        self.inferred
            .iter()
            .zip(&self.posterior())
            .map(|(inferred, &weight)| {
                weight
                    * (inferred.covariance()
                        + (inferred.state() - &mean) * (inferred.state() - &mean).transpose())
            })
            .sum::<DMatrix<f64>>()
    }

    /// The posterior-weighted smoothed (reconstructed) output.
    pub fn smoothed(&self, mix: &PPCAMix) -> DVector<f64> {
        self.inferred
            .iter()
            .zip(&self.posterior())
            .zip(&mix.0.models)
            .map(|((inferred, &weight), ppca)| weight * inferred.smoothed(ppca))
            .sum::<DVector<f64>>()
    }

    /// The posterior-weighted extrapolation of `sample` (masked dimensions
    /// filled in).
    pub fn extrapolated(&self, mix: &PPCAMix, sample: &MaskedSample) -> DVector<f64> {
        self.inferred
            .iter()
            .zip(&self.posterior())
            .zip(&mix.0.models)
            .map(|((inferred, &weight), ppca)| weight * inferred.extrapolated(ppca, sample))
            .sum::<DVector<f64>>()
    }

    /// The covariance of the smoothed output under the mixture posterior
    /// (law of total covariance around the mixture-smoothed mean).
    pub fn smoothed_covariance(&self, mix: &PPCAMix) -> DMatrix<f64> {
        let mean = self.smoothed(mix);
        self.inferred
            .iter()
            .zip(&self.posterior())
            .zip(&mix.0.models)
            .map(|((inferred, &weight), ppca)| {
                weight
                    * (inferred.smoothed_covariance(ppca)
                        + (inferred.smoothed(ppca) - &mean)
                            * (inferred.smoothed(ppca) - &mean).transpose())
            })
            .sum::<DMatrix<f64>>()
    }

    /// The diagonal of [`InferredMaskedMix::smoothed_covariance`], computed
    /// without materializing the full matrix.
    pub fn smoothed_covariance_diagonal(&self, mix: &PPCAMix) -> DVector<f64> {
        let mean = self.smoothed(mix);
        self.inferred
            .iter()
            .zip(&self.posterior())
            .zip(&mix.0.models)
            .map(|((inferred, &weight), ppca)| {
                weight
                    * (inferred.smoothed_covariance_diagonal(ppca)
                        + (inferred.smoothed(ppca) - &mean).map(|v| v.powi(2)))
            })
            .sum()
    }

    /// The covariance of the extrapolated output under the mixture
    /// posterior.
    pub fn extrapolated_covariance(&self, mix: &PPCAMix, sample: &MaskedSample) -> DMatrix<f64> {
        let mean = self.extrapolated(mix, sample);
        self.inferred
            .iter()
            .zip(&self.posterior())
            .zip(&mix.0.models)
            .map(|((inferred, &weight), ppca)| {
                // NOTE(review): the per-component term uses
                // `smoothed_covariance` while the diagonal variant below uses
                // `extrapolated_covariance_diagonal` — confirm this asymmetry
                // is intended.
                weight
                    * (inferred.smoothed_covariance(ppca)
                        + (inferred.extrapolated(ppca, sample) - &mean)
                            * (inferred.extrapolated(ppca, sample) - &mean).transpose())
            })
            .sum::<DMatrix<f64>>()
    }

    /// The diagonal of the extrapolated-output covariance under the mixture
    /// posterior, computed without materializing the full matrix.
    pub fn extrapolated_covariance_diagonal(
        &self,
        mix: &PPCAMix,
        sample: &MaskedSample,
    ) -> DVector<f64> {
        let mean = self.extrapolated(mix, sample);
        self.inferred
            .iter()
            .zip(&self.posterior())
            .zip(&mix.0.models)
            .map(|((inferred, &weight), ppca)| {
                weight
                    * (inferred.extrapolated_covariance_diagonal(ppca, sample)
                        + (inferred.extrapolated(ppca, sample) - &mean).map(|v| v.powi(2)))
            })
            .sum()
    }

    /// Builds a sampler that draws latent states from the full mixture
    /// posterior: first a component is picked by posterior weight, then a
    /// state is drawn from that component's posterior.
    pub fn posterior_sampler(&self) -> PosteriorSamplerMix {
        let index = WeightedIndex::new(self.posterior().iter().copied())
            .expect("failed to create WeightedIndex for posterior");
        let posteriors = self
            .inferred
            .iter()
            .map(InferredMasked::posterior_sampler)
            .collect::<Vec<_>>();
        PosteriorSamplerMix { index, posteriors }
    }
}
/// Samples latent states from the posterior of a PPCA mixture: a weighted
/// choice over components plus each component's own posterior sampler.
pub struct PosteriorSamplerMix {
/// Weighted choice of a component according to its posterior probability.
index: WeightedIndex<f64>,
/// One posterior sampler per mixture component, in model order.
posteriors: Vec<ppca_model::PosteriorSampler>,
}
impl Distribution<DVector<f64>> for PosteriorSamplerMix {
    /// Draws one latent state: first picks a component according to the
    /// posterior weights, then samples from that component's posterior.
    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> DVector<f64> {
        let component = self.index.sample(rng);
        self.posteriors[component].sample(rng)
    }
}