use std::cmp::Eq;
use std::collections::HashMap;
use std::hash::Hash;
use std::marker::PhantomData;
use itertools::Itertools;
use ordered_float::NotNan;
use crate::stats::LogProb;
/// Mapping from each event to a log probability.
///
/// NOTE(review): by its name this presumably holds joint probabilities of event and data; it is
/// not used by the code in this part of the file — confirm against callers.
pub type JointProbUniverse<Event> = HashMap<Event, LogProb>;
/// Log likelihood of observed data under a given event.
///
/// `Payload` is an extra mutable value threaded through every likelihood evaluation of a single
/// `Model::compute` / `Model::compute_from_marginal` call (each of those starts from
/// `Payload::default()`), so implementations can share state between evaluations.
pub trait Likelihood<Payload = ()> {
    /// Event the data is conditioned on.
    type Event;
    /// Observed data.
    type Data;
    /// Return the log likelihood of `data` given `event`; may read and update `payload`.
    fn compute(&self, event: &Self::Event, data: &Self::Data, payload: &mut Payload) -> LogProb;
}
/// Log prior probability of an event.
///
/// In `Model`, this value is added to the likelihood to form the joint log probability.
pub trait Prior {
    /// Event the prior is defined over.
    type Event;
    /// Return the log prior probability of `event`.
    fn compute(&self, event: &Self::Event) -> LogProb;
}
/// Unnormalized posterior log probability of an event expressed via base events.
///
/// Implementations obtain the joint probability of whatever base events they need through the
/// `joint_prob` callback passed to `compute`.
pub trait Posterior {
    /// Event whose posterior probability is computed.
    type Event;
    /// Event type passed to the `joint_prob` callback (in `Model`, the event type shared by the
    /// prior and the likelihood).
    type BaseEvent;
    /// Observed data.
    type Data;
    /// Compute the log probability of `event` given `data`.
    ///
    /// `joint_prob(base_event, data)` returns the joint log probability of a base event and the
    /// data (in `Model`: prior plus likelihood, recorded in the resulting `ModelInstance`).
    /// The value returned here is not normalized; `ModelInstance::posterior` subtracts the
    /// marginal afterwards.
    fn compute<F: FnMut(&Self::BaseEvent, &Self::Data) -> LogProb>(
        &self,
        event: &Self::Event,
        data: &Self::Data,
        joint_prob: &mut F,
    ) -> LogProb;
}
/// A Bayesian model assembled from a likelihood, a prior and a posterior.
///
/// The component fields get public getters and mutable getters through the
/// `Getters`/`MutGetters` derives (presumably the `getset` crate — confirm at the crate root).
/// `Payload` is only a type marker here; a fresh `Payload::default()` is created for each
/// computation.
#[derive(
    Default,
    Getters,
    MutGetters,
    Copy,
    Clone,
    Eq,
    PartialEq,
    Ord,
    PartialOrd,
    Hash,
    Debug,
    Serialize,
    Deserialize,
)]
pub struct Model<L, Pr, Po, Payload = ()>
where
    L: Likelihood<Payload>,
    Pr: Prior,
    Po: Posterior,
    Payload: Default,
{
    /// Likelihood of the data given a base event.
    #[get = "pub"]
    #[get_mut = "pub"]
    likelihood: L,
    /// Prior probability of a base event.
    #[get = "pub"]
    #[get_mut = "pub"]
    prior: Pr,
    /// Posterior over (possibly composite) posterior events.
    #[get = "pub"]
    #[get_mut = "pub"]
    posterior: Po,
    // Ties the `Payload` type parameter to the struct without storing a value.
    payload: PhantomData<Payload>,
}
impl<Event, PosteriorEvent, Data, L, Pr, Po, Payload> Model<L, Pr, Po, Payload>
where
    Payload: Default,
    Event: Hash + Eq + Clone,
    PosteriorEvent: Hash + Eq + Clone,
    L: Likelihood<Payload, Event = Event, Data = Data>,
    Pr: Prior<Event = Event>,
    Po: Posterior<BaseEvent = Event, Event = PosteriorEvent, Data = Data>,
{
    /// Build a model from its likelihood, prior and posterior components.
    pub fn new(likelihood: L, prior: Pr, posterior: Po) -> Self {
        Self {
            likelihood,
            prior,
            posterior,
            payload: PhantomData,
        }
    }

    /// Joint log probability of `event` and `data`: log prior plus log likelihood.
    fn joint_prob(&self, event: &Event, data: &Data, payload: &mut Payload) -> LogProb {
        // Prior first, then likelihood — same evaluation order as `a + b`.
        let log_prior = self.prior.compute(event);
        let log_likelihood = self.likelihood.compute(event, data, payload);
        log_prior + log_likelihood
    }

    /// Evaluate the model on `data` for every posterior event in `universe`.
    ///
    /// Returns a `ModelInstance` holding:
    /// - the joint log probability of every base event the posterior asked for (a base event
    ///   requested more than once is re-evaluated and its entry overwritten),
    /// - the posterior log probability of each event in `universe` (a repeated event keeps its
    ///   last value),
    /// - the marginal, i.e. the log-sum-exp of those posterior probabilities.
    ///
    /// A fresh `Payload::default()` is shared by all likelihood evaluations of this call.
    pub fn compute<U: IntoIterator<Item = PosteriorEvent>>(
        &self,
        universe: U,
        data: &Data,
    ) -> ModelInstance<Event, PosteriorEvent> {
        let mut joint_probs = HashMap::new();
        let mut payload = Payload::default();
        let mut posterior_probs: HashMap<PosteriorEvent, LogProb> = HashMap::new();

        {
            // Callback handed to the posterior: evaluates and records each base-event joint prob.
            let mut record_joint = |base_event: &Event, observed: &Data| {
                let joint = self.joint_prob(base_event, observed, &mut payload);
                joint_probs.insert(base_event.clone(), joint);
                joint
            };
            for event in universe {
                let prob = self.posterior.compute(&event, data, &mut record_joint);
                posterior_probs.insert(event, prob);
            }
        }

        let posterior_values: Vec<LogProb> = posterior_probs.values().copied().collect();
        let marginal = LogProb::ln_sum_exp(&posterior_values);

        ModelInstance {
            joint_probs,
            posterior_probs,
            marginal,
        }
    }

    /// Evaluate the model on `data`, letting `marginal` decide which posterior events to visit.
    ///
    /// `marginal` receives a callback that computes (and records) the posterior log probability
    /// of a posterior event; that in turn computes (and records) the joint probability of each
    /// base event it needs. The value returned by `marginal` becomes the instance's marginal.
    pub fn compute_from_marginal<M>(
        &self,
        marginal: &M,
        data: &Data,
    ) -> ModelInstance<Event, PosteriorEvent>
    where
        M: Marginal<Data = Data, Event = PosteriorEvent, BaseEvent = Event>,
    {
        let mut joint_probs = HashMap::new();
        let mut posterior_probs = HashMap::new();
        let mut payload = Payload::default();

        let marginal_prob = {
            let mut record_joint = |base_event: &Event, observed: &Data| {
                let joint = self.joint_prob(base_event, observed, &mut payload);
                joint_probs.insert(base_event.clone(), joint);
                joint
            };
            let mut record_posterior = |event: &PosteriorEvent, observed: &Data| {
                let prob = self.posterior.compute(event, observed, &mut record_joint);
                posterior_probs.insert(event.clone(), prob);
                prob
            };
            marginal.compute(data, &mut record_posterior)
        };

        ModelInstance {
            joint_probs,
            posterior_probs,
            marginal: marginal_prob,
        }
    }
}
/// Marginal log probability of the data, computed from posterior-event probabilities.
pub trait Marginal {
    /// Posterior event type the callback is evaluated on.
    type Event;
    /// Base event type (in `Model::compute_from_marginal`, the prior/likelihood event type).
    type BaseEvent;
    /// Observed data.
    type Data;
    /// Compute the marginal log probability of `data`.
    ///
    /// `joint_prob(event, data)` returns the unnormalized posterior log probability of a
    /// posterior event (`Model::compute_from_marginal` records each value it returns).
    /// Despite its name, the callback takes `Self::Event`, not base events.
    fn compute<F: FnMut(&Self::Event, &Self::Data) -> LogProb>(
        &self,
        data: &Self::Data,
        joint_prob: &mut F,
    ) -> LogProb;
}
/// Result of evaluating a `Model` on some data.
#[derive(Default, Clone, PartialEq, Debug, Serialize, Deserialize)]
pub struct ModelInstance<Event, PosteriorEvent>
where
    Event: Hash + Eq,
    PosteriorEvent: Hash + Eq,
{
    /// Joint log probability (prior + likelihood) of every base event that was evaluated.
    joint_probs: HashMap<Event, LogProb>,
    /// Unnormalized posterior log probability of every posterior event that was evaluated.
    posterior_probs: HashMap<PosteriorEvent, LogProb>,
    /// Marginal log probability used to normalize the values above.
    marginal: LogProb,
}
impl<Event, PosteriorEvent> ModelInstance<Event, PosteriorEvent>
where
Event: Hash + Eq,
PosteriorEvent: Hash + Eq,
{
pub fn posterior(&self, event: &PosteriorEvent) -> Option<LogProb> {
self.posterior_probs.get(event).map(|p| p - self.marginal)
}
pub fn marginal(&self) -> LogProb {
self.marginal
}
pub fn maximum_posterior(&self) -> Option<&Event> {
self.joint_probs
.iter()
.max_by_key(|(_, prob)| NotNan::new(***prob).unwrap())
.map(|(event, _)| event)
}
pub fn event_posteriors(&self) -> impl Iterator<Item = (&Event, LogProb)> {
self.joint_probs
.iter()
.map(|(event, prob)| (event, prob - self.marginal))
.sorted_by_key(|(_, prob)| -NotNan::new(**prob).unwrap())
}
}
impl<PosteriorEvent> ModelInstance<NotNan<f64>, PosteriorEvent>
where
    PosteriorEvent: Hash + Eq,
{
    /// Posterior expected value of a numeric base event.
    ///
    /// Computes `Σ event · exp(joint(event) − marginal)` over all evaluated base events, i.e.
    /// each event is weighted by its normalized posterior probability (the same normalization
    /// used by `event_posteriors`).
    ///
    /// # Panics
    /// Panics if a weight is NaN (e.g. joint and marginal are both log-zero). `NotNan`
    /// arithmetic may also panic if a product or sum becomes NaN (e.g. an infinite event
    /// times a zero weight).
    pub fn expected_value(&self) -> NotNan<f64> {
        self.joint_probs
            .iter()
            .map(|(event, joint)| {
                // Weights must be linear-space probabilities normalized by the marginal;
                // a raw log probability is not a valid weight for an expectation.
                let weight = f64::exp(*(joint - self.marginal));
                *event * NotNan::new(weight).expect("posterior probability must not be NaN")
            })
            .fold(NotNan::default(), |sum, term| sum + term)
    }
}
// Unit tests for this module; compiled only for `cargo test` builds.
#[cfg(test)]
mod tests {}