use petgraph::graph::{DiGraph, NodeIndex};
use rayon::prelude::*;
/// A single policy stored as a node of a [`PolicySearchTree`] graph.
///
/// Each node carries a human-readable `name` and a scalar score. The score is
/// summed across all nodes by `PolicySearchTree::evaluate_tree_parallel`.
#[derive(Debug, Clone, PartialEq)]
pub struct PolicyNode {
    /// Label of the policy. Not required to be unique within a graph.
    pub name: String,
    /// Scalar score attached to this policy.
    ///
    /// NOTE(review): presumably the active-inference "expected free energy"
    /// (lower = preferred). Units and sign convention are not established
    /// here; confirm with callers.
    pub expected_free_energy: f64,
}
/// A directed graph of [`PolicyNode`]s connected by unweighted transitions.
///
/// The name suggests a search tree, but the underlying [`DiGraph`] does not
/// enforce a tree shape: cycles, self-loops and parallel edges are all allowed.
///
/// `Default` yields the same empty graph as `PolicySearchTree::new()`.
#[derive(Debug, Clone, Default)]
pub struct PolicySearchTree {
    /// Underlying petgraph storage. Edges carry no payload (`()`).
    pub graph: DiGraph<PolicyNode, ()>,
}
impl PolicySearchTree {
    /// Creates an empty graph with no nodes and no transitions.
    #[must_use]
    pub fn new() -> Self {
        Self {
            graph: DiGraph::new(),
        }
    }

    /// Adds a policy node and returns its index in the graph.
    ///
    /// `name` is copied into an owned `String`, and `efe` is stored as the
    /// node's `expected_free_energy`. No uniqueness check is done on `name`,
    /// so adding the same name twice creates two distinct nodes.
    pub fn add_node(&mut self, name: &str, efe: f64) -> NodeIndex {
        self.graph.add_node(PolicyNode {
            name: name.to_string(),
            expected_free_energy: efe,
        })
    }

    /// Adds a directed, payload-free transition edge `from -> to`.
    ///
    /// Nothing here enforces a tree shape. Cycles, self-loops and duplicate
    /// edges are all accepted by the underlying [`DiGraph`].
    ///
    /// # Panics
    ///
    /// Panics if either index does not refer to a node in this graph, or if the
    /// graph's edge index type is exhausted. Both panics come from
    /// [`petgraph::graph::Graph::add_edge`].
    pub fn add_transition(&mut self, from: NodeIndex, to: NodeIndex) {
        self.graph.add_edge(from, to, ());
    }

    /// Sums `expected_free_energy` over every node in the graph, in parallel.
    ///
    /// Edges are ignored: every node contributes exactly once, whether or not
    /// it is reachable from any other node. An empty graph yields zero.
    ///
    /// Rayon does not specify the order in which partial sums are combined,
    /// and f64 addition is not associative. The result may therefore differ
    /// in the least-significant bits between runs or thread counts.
    #[must_use]
    pub fn evaluate_tree_parallel(&self) -> f64 {
        // `raw_nodes` exposes node storage as a contiguous slice, which rayon
        // can split directly. No intermediate `Vec` of references is needed.
        self.graph
            .raw_nodes()
            .par_iter()
            .map(|node| node.weight.expected_free_energy)
            .sum()
    }
}