use rand::{
    distributions::{Distribution, Uniform},
    Rng,
};
use std::cmp::Ordering;
use std::collections::BinaryHeap;
use std::f64;
use std::sync::{Arc, RwLock, RwLockReadGuard, Weak};
use {Strategy, Term, TRS};
/// A best-first exploration of the rewrites of a term under a TRS: candidate
/// nodes are expanded in order of decreasing `log_p`.
pub struct Trace<'a> {
    /// The term rewriting system used to rewrite terms.
    trs: &'a TRS,
    /// The starting node, holding the initial term.
    root: TraceNode,
    /// Priority queue of nodes not yet expanded, ordered by `log_p`.
    unobserved: BinaryHeap<TraceNode>,
    /// Its log is added to a node's `log_p` each time the node is rewritten.
    p_observe: f64,
    /// Exponent applied to the term-similarity score in `rewrites_to`.
    noise_level: f64,
    /// Nodes whose term exceeds this size are marked `TooBig`; `None` = no limit.
    max_term_size: Option<usize>,
    /// Rewriting strategy passed through to `TRS::rewrite`.
    strategy: Strategy,
}
impl<'a> Trace<'a> {
    /// Create a new trace rooted at `term`.
    ///
    /// The root starts `Unobserved` with `log_p == 0.0` (probability 1) and
    /// is immediately queued for expansion.
    pub fn new(
        trs: &'a TRS,
        term: &Term,
        p_observe: f64,
        noise_level: f64,
        max_term_size: Option<usize>,
        strategy: Strategy,
    ) -> Trace<'a> {
        let root = TraceNode::new(term.clone(), TraceState::Unobserved, 0.0, 0, None, vec![]);
        let mut unobserved = BinaryHeap::new();
        unobserved.push(root.clone());
        Trace {
            trs,
            root,
            unobserved,
            p_observe,
            noise_level,
            max_term_size,
            strategy,
        }
    }
    /// Create a child node and queue it for later expansion.
    fn new_node(
        &mut self,
        term: Term,
        parent: Option<&TraceNode>,
        depth: usize,
        state: TraceState,
        log_p: f64,
    ) -> TraceNode {
        let node = TraceNode::new(term, state, log_p, depth, parent, vec![]);
        self.unobserved.push(node.clone());
        node
    }
    /// The root node of the trace.
    pub fn root(&self) -> &TraceNode {
        &self.root
    }
    /// The maximum depth of any node currently in the trace.
    pub fn depth(&self) -> usize {
        let mut deepest = 0;
        self.root.walk(|n| {
            deepest = deepest.max(n.depth());
        });
        deepest
    }
    /// The total number of nodes currently in the trace.
    pub fn size(&self) -> usize {
        let mut count = 0;
        self.root.walk(|_| count += 1);
        count
    }
    /// The probability mass explored so far: 1 minus the mass still sitting
    /// in unobserved leaves.
    pub fn mass(&self) -> f64 {
        let mut masses = self.root.accumulate(|n| {
            // Direct equality instead of a one-element-array `contains`.
            if n.state() == TraceState::Unobserved && n.is_leaf() {
                Some(n.log_p())
            } else {
                None
            }
        });
        // NEG_INFINITY adds zero mass but keeps logsumexp well-defined even
        // when there are no unobserved leaves.
        masses.push(f64::NEG_INFINITY);
        1.0 - logsumexp(masses.as_slice()).exp()
    }
    /// Rewrite up to `max_steps` and return every leaf in normal form.
    pub fn outcomes(&mut self, max_steps: usize) -> Vec<TraceNode> {
        self.rewrite(max_steps);
        self.root.leaves(&[TraceState::Normal])
    }
    /// Sample a normal-form leaf with probability proportional to
    /// `exp(log_p)`.
    ///
    /// # Panics
    /// Panics if there are no `Normal` leaves (the sampling weights are
    /// empty or sum to zero).
    pub fn sample<R: Rng>(&self, rng: &mut R) -> TraceNode {
        let leaves = self.root.leaves(&[TraceState::Normal]);
        let ws = leaves.iter().map(|x| x.log_p().exp()).collect::<Vec<f64>>();
        weighted_sample(rng, &leaves, &ws).clone()
    }
    /// Grow the trace until it holds roughly `max_steps` nodes, or no
    /// unobserved nodes remain.
    pub fn rewrite(&mut self, max_steps: usize) {
        if max_steps > self.size() {
            let n_steps = max_steps - self.size();
            // Drive the iterator forward; each `next` expands one node.
            // NOTE(review): `nth(n)` advances n + 1 times — confirm the
            // extra expansion step is intended.
            self.nth(n_steps);
        }
    }
    /// Log-probability that rewriting (up to `max_steps`) produces `term`,
    /// summed over every node in the trace.
    pub fn rewrites_to(&mut self, max_steps: usize, term: &Term) -> f64 {
        self.rewrite(max_steps);
        let lps = self.root.accumulate(|n| {
            let n_r = n.0.read().expect("poisoned TraceNode");
            let ln_p = n_r.log_p;
            // Structural overlap between `term` and this node's term,
            // sharpened by `noise_level` before moving to the log domain.
            // NOTE(review): assumes shared_score is in [0, 1] — confirm.
            let score = Term::shared_score(term, &n_r.term);
            let ln_adj_score = score.powf(self.noise_level).ln();
            Some(ln_p + ln_adj_score)
        });
        if lps.is_empty() {
            f64::NEG_INFINITY
        } else {
            logsumexp(&lps)
        }
    }
}
impl<'a> Iterator for Trace<'a> {
    type Item = TraceNode;
    /// Expand the unobserved node with the highest `log_p` by one rewrite
    /// step, returning it (whatever its new state), or `None` when the
    /// queue is exhausted.
    fn next(&mut self) -> Option<TraceNode> {
        // The heap orders TraceNodes by log_p (see `Ord for TraceNode`).
        if let Some(node) = self.unobserved.pop() {
            {
                // Hold the write lock only inside this inner scope so the
                // node can be returned unlocked below.
                let mut node_w = node.0.write().expect("poisoned TraceNode");
                match self.max_term_size {
                    // Oversized terms are dead ends: mark, don't expand.
                    Some(max_term_size) if node_w.term.size() > max_term_size => {
                        node_w.state = TraceState::TooBig;
                    }
                    _ => match self.trs.rewrite(&node_w.term, self.strategy) {
                        // No rewrites apply: the term is in normal form.
                        None => node_w.state = TraceState::Normal,
                        Some(ref rewrites) if rewrites.is_empty() => {
                            node_w.state = TraceState::Normal
                        }
                        Some(rewrites) => {
                            // Mass is split uniformly across the rewrites:
                            // each child gets log(1/len) on top of the
                            // parent's (observation-discounted) log_p.
                            let term_selection_p = -(rewrites.len() as f64).ln();
                            node_w.log_p += self.p_observe.ln();
                            node_w.state = TraceState::Rewritten;
                            for term in rewrites {
                                let new_p = node_w.log_p + term_selection_p;
                                // new_node also queues the child as
                                // unobserved on the heap.
                                let new_node = self.new_node(
                                    term,
                                    Some(&node),
                                    node_w.depth + 1,
                                    TraceState::Unobserved,
                                    new_p,
                                );
                                node_w.children.push(new_node);
                            }
                        }
                    },
                }
            }
            Some(node)
        } else {
            None
        }
    }
}
/// Shared interior data of a `TraceNode`, kept behind `Arc<RwLock<_>>`.
#[derive(Debug, Clone)]
struct TraceNodeStore {
    /// The term at this point in the trace.
    term: Term,
    /// Where this node stands in the search (see `TraceState`).
    state: TraceState,
    /// Log-probability mass assigned to this node.
    log_p: f64,
    /// Number of rewrite steps from the root (root is 0).
    depth: usize,
    /// Back-pointer to the parent; `None` for the root. `Weak` avoids a
    /// parent/child `Arc` reference cycle.
    parent: Option<Weak<RwLock<TraceNodeStore>>>,
    /// Child nodes produced by rewriting this node's term.
    children: Vec<TraceNode>,
}
/// The lifecycle state of a node in a `Trace`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TraceState {
    /// Queued for expansion but not yet rewritten.
    Unobserved,
    /// The term exceeded the trace's `max_term_size`; never expanded.
    TooBig,
    /// No rewrites applied: the term is in normal form.
    Normal,
    /// The node was expanded; its rewrites live in `children`.
    Rewritten,
}
/// A cheaply clonable handle to one node of a `Trace`; clones share the same
/// underlying state through an `Arc<RwLock<_>>`.
#[derive(Debug, Clone)]
pub struct TraceNode(Arc<RwLock<TraceNodeStore>>);
impl TraceNode {
fn new(
term: Term,
state: TraceState,
log_p: f64,
depth: usize,
parent: Option<&TraceNode>,
children: Vec<TraceNode>,
) -> TraceNode {
let parent = parent.map(|n| Arc::downgrade(&n.0));
TraceNode(Arc::new(RwLock::new(TraceNodeStore {
term,
state,
log_p,
depth,
parent,
children,
})))
}
pub fn state(&self) -> TraceState {
self.0.read().expect("poisoned TraceNode").state
}
pub fn term(&self) -> Term {
self.0.read().expect("poisoned TraceNode").term.clone()
}
pub fn log_p(&self) -> f64 {
self.0.read().expect("poisoned TraceNode").log_p
}
pub fn depth(&self) -> usize {
self.0.read().expect("poisoned TraceNode").depth
}
pub fn parent(&self) -> Option<TraceNode> {
self.0
.read()
.expect("poisoned TraceNode")
.parent
.as_ref()
.and_then(Weak::upgrade)
.map(TraceNode)
}
pub fn children(&self) -> Vec<TraceNode> {
self.0.read().expect("poisoned TraceNode").children.clone()
}
pub fn is_leaf(&self) -> bool {
self.0
.read()
.expect("poisoned TraceNode")
.children
.is_empty()
}
fn accumulate<T, F>(&self, condition: F) -> Vec<T>
where
F: Fn(&TraceNode) -> Option<T>,
{
let mut values = Vec::new();
self.walk(|n| {
if let Some(v) = condition(n) {
values.push(v)
}
});
values
}
fn walk<F>(&self, mut callback: F)
where
F: FnMut(&TraceNode),
{
self.walk_helper(&mut callback)
}
fn walk_helper<F>(&self, callback: &mut F)
where
F: FnMut(&TraceNode),
{
callback(self);
for child in &self.0.read().expect("poisoned TraceNode").children {
child.walk_helper(callback)
}
}
pub fn iter(&self) -> TraceNodeIter {
TraceNodeIter::new(self.clone())
}
pub fn progeny(&self, states: &[TraceState]) -> Vec<TraceNode> {
self.accumulate(|n| {
if states.contains(&n.state()) {
Some(n.clone())
} else {
None
}
})
}
pub fn leaves(&self, states: &[TraceState]) -> Vec<TraceNode> {
self.accumulate(|n| {
if states.contains(&n.state()) && n.is_leaf() {
Some(n.clone())
} else {
None
}
})
}
pub fn leaf_terms(&self, states: &[TraceState]) -> Vec<Term> {
self.accumulate(|n| {
if states.contains(&n.state()) && n.is_leaf() {
Some(n.term())
} else {
None
}
})
}
}
impl PartialEq for TraceNode {
    /// Identity equality: two handles are equal iff they point at the same
    /// underlying node, regardless of contents.
    fn eq(&self, other: &TraceNode) -> bool {
        Arc::ptr_eq(&self.0, &other.0)
    }
}
impl Eq for TraceNode {}
impl PartialOrd for TraceNode {
    /// Order nodes by `log_p`; `None` when either value is `NaN`.
    fn partial_cmp(&self, other: &TraceNode) -> Option<Ordering> {
        self.log_p().partial_cmp(&other.log_p())
    }
}
impl Ord for TraceNode {
    /// Total order by `log_p`, treating incomparable (`NaN`) values as
    /// `Equal` so `TraceNode` can live in a `BinaryHeap`.
    ///
    /// NOTE(review): this ordering is by `log_p` while `PartialEq`/`Eq` use
    /// pointer identity, so `a == b` does not imply `a.cmp(&b) == Equal`.
    /// `BinaryHeap` only relies on `Ord`, but other consumers should beware.
    fn cmp(&self, other: &TraceNode) -> Ordering {
        // `unwrap_or` replaces the manual `if let`; pass `other` directly
        // instead of the double-reference `&other`.
        self.partial_cmp(other).unwrap_or(Ordering::Equal)
    }
}
impl<'a> IntoIterator for &'a TraceNode {
    type Item = TraceNode;
    type IntoIter = TraceNodeIter;
    /// Iterate over this node and its descendants; handles are cloned, so a
    /// borrowed node still yields owned `TraceNode` items.
    fn into_iter(self) -> TraceNodeIter {
        self.iter()
    }
}
impl IntoIterator for TraceNode {
    type Item = TraceNode;
    type IntoIter = TraceNodeIter;
    /// Iterate over this node and its descendants.
    fn into_iter(self) -> TraceNodeIter {
        TraceNodeIter::new(self)
    }
}
/// Iterator over a subtree of `TraceNode`s, driven by an explicit stack.
pub struct TraceNodeIter {
    /// Pending nodes; the next item is popped from the back.
    queue: Vec<TraceNode>,
}
impl TraceNodeIter {
    /// Start a traversal rooted at `root`.
    fn new(root: TraceNode) -> TraceNodeIter {
        TraceNodeIter { queue: vec![root] }
    }
}
impl Iterator for TraceNodeIter {
    type Item = TraceNode;
    /// Yield the most recently queued node and queue its children, giving a
    /// depth-first traversal that visits the last child's subtree first.
    fn next(&mut self) -> Option<TraceNode> {
        match self.queue.pop() {
            None => None,
            Some(node) => {
                for child in node.children() {
                    self.queue.push(child);
                }
                Some(node)
            }
        }
    }
}
/// Numerically stable `log(sum(exp(lps)))` using the max-shift trick.
///
/// Returns `NEG_INFINITY` for an empty slice or when every input is
/// `NEG_INFINITY` (zero total mass).
fn logsumexp(lps: &[f64]) -> f64 {
    // Find the largest finite shift; start from -inf so empty input is
    // handled uniformly.
    let mut max_lp = f64::NEG_INFINITY;
    for &lp in lps {
        if lp > max_lp {
            max_lp = lp;
        }
    }
    if max_lp == f64::NEG_INFINITY {
        return f64::NEG_INFINITY;
    }
    // Shift by the max so the largest exponent is exactly 0.
    let total: f64 = lps.iter().map(|lp| (lp - max_lp).exp()).sum();
    max_lp + total.ln()
}
/// Sample an element of `xs` with probability proportional to its weight in
/// `ws`.
///
/// # Panics
/// Panics when the slices differ in length, or when the total weight is not
/// positive (which includes the empty case). Previously a non-positive total
/// surfaced as an opaque `low >= high` panic inside `Uniform::new`; the
/// explicit assert makes the failure mode clear at this call site.
fn weighted_sample<'a, T, R: Rng>(rng: &mut R, xs: &'a [T], ws: &[f64]) -> &'a T {
    assert_eq!(xs.len(), ws.len(), "weighted sample given invalid inputs");
    let total: f64 = ws.iter().sum();
    assert!(total > 0.0, "weighted sample given non-positive total weight");
    // Uniform is half-open: threshold < total.
    let threshold: f64 = Uniform::new(0f64, total).sample(rng);
    let mut cum = 0f64;
    for (wp, x) in ws.iter().zip(xs) {
        cum += *wp;
        if threshold <= cum {
            return x;
        }
    }
    // `cum` is accumulated in the same order as `total`, so the final `cum`
    // equals `total` exactly and threshold < total guarantees a return above.
    unreachable!()
}