coreason_runtime_rust/
policy_tree.rs1use petgraph::graph::{DiGraph, NodeIndex};
5use rayon::prelude::*;
6
/// A single candidate policy stored as a node weight in a policy search graph.
///
/// `PartialEq` is derived so nodes can be compared directly (e.g. in tests).
/// `Eq`/`Hash` are not derivable because `expected_free_energy` is an `f64`.
#[derive(Debug, Clone, PartialEq)]
pub struct PolicyNode {
    /// Human-readable identifier for the policy.
    pub name: String,
    /// Expected free energy (EFE) score attached to this policy.
    ///
    /// NOTE(review): the sign convention / whether lower is preferred is not
    /// established by this type — confirm with the producer of these values.
    pub expected_free_energy: f64,
}
13
14pub struct PolicySearchTree {
16 pub graph: DiGraph<PolicyNode, ()>,
17}
18
19impl PolicySearchTree {
20 pub fn new() -> Self {
22 Self {
23 graph: DiGraph::new(),
24 }
25 }
26
27 pub fn add_node(&mut self, name: &str, efe: f64) -> NodeIndex {
29 self.graph.add_node(PolicyNode {
30 name: name.to_string(),
31 expected_free_energy: efe,
32 })
33 }
34
35 pub fn add_transition(&mut self, from: NodeIndex, to: NodeIndex) {
37 self.graph.add_edge(from, to, ());
38 }
39
40 pub fn evaluate_tree_parallel(&self) -> f64 {
43 let raw_nodes: Vec<&PolicyNode> = self.graph.node_weights().collect();
44 raw_nodes
45 .par_iter()
46 .map(|node| node.expected_free_energy)
47 .sum()
48 }
49}