use serde::{Deserialize, Serialize};
use std::collections::HashMap;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ThoughtNode {
pub id: usize,
pub parent: Option<usize>,
pub children: Vec<usize>,
pub thought: String,
pub score: f32,
pub depth: usize,
pub is_terminal: bool,
pub state: ThoughtState,
}
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ThoughtState {
pub reasoning_path: Vec<String>,
pub partial_results: HashMap<String, String>,
pub is_solved: bool,
pub solution: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToTResult {
pub best_path: Vec<ThoughtNode>,
pub solution: Option<String>,
pub score: f32,
pub explored_paths: usize,
pub nodes_generated: usize,
pub nodes_pruned: usize,
pub stats: ToTStats,
}
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ToTStats {
pub avg_branching_factor: f32,
pub avg_node_score: f32,
pub max_depth_reached: usize,
pub backtrack_count: usize,
pub generation_time_ms: u64,
pub evaluation_time_ms: u64,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToTConfig {
pub branching_factor: usize,
pub max_depth: usize,
pub search_strategy: SearchStrategy,
pub prune_threshold: f32,
pub max_nodes: usize,
pub beam_width: usize,
pub use_value_function: bool,
pub temperature: f32,
}
impl Default for ToTConfig {
fn default() -> Self {
Self {
branching_factor: 3,
max_depth: 5,
search_strategy: SearchStrategy::BreadthFirst,
prune_threshold: 0.3,
max_nodes: 100,
beam_width: 5,
use_value_function: true,
temperature: 0.7,
}
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum SearchStrategy {
BreadthFirst,
DepthFirst,
BeamSearch,
BestFirst,
MCTS,
}
#[derive(Debug)]
pub struct TreeOfThoughts {
pub config: ToTConfig,
nodes: Vec<ThoughtNode>,
next_id: usize,
}
impl TreeOfThoughts {
pub fn new(config: ToTConfig) -> Self {
Self {
config,
nodes: Vec::new(),
next_id: 0,
}
}
pub fn create_root(&mut self, problem: &str) -> usize {
let id = self.next_id;
self.next_id += 1;
let node = ThoughtNode {
id,
parent: None,
children: Vec::new(),
thought: problem.to_string(),
score: 1.0,
depth: 0,
is_terminal: false,
state: ThoughtState::default(),
};
self.nodes.push(node);
id
}
pub fn add_child(
&mut self,
parent_id: usize,
thought: String,
score: f32,
state: ThoughtState,
) -> usize {
let id = self.next_id;
self.next_id += 1;
let parent_depth = self.nodes[parent_id].depth;
let node = ThoughtNode {
id,
parent: Some(parent_id),
children: Vec::new(),
thought,
score,
depth: parent_depth + 1,
is_terminal: state.is_solved,
state,
};
self.nodes.push(node);
self.nodes[parent_id].children.push(id);
id
}
pub fn get_node(&self, id: usize) -> Option<&ThoughtNode> {
self.nodes.get(id)
}
pub fn get_node_mut(&mut self, id: usize) -> Option<&mut ThoughtNode> {
self.nodes.get_mut(id)
}
pub fn get_path(&self, node_id: usize) -> Vec<&ThoughtNode> {
let mut path = Vec::new();
let mut current = Some(node_id);
while let Some(id) = current {
if let Some(node) = self.get_node(id) {
path.push(node);
current = node.parent;
} else {
break;
}
}
path.reverse();
path
}
pub fn prune(&mut self) -> usize {
let threshold = self.config.prune_threshold;
let mut pruned = 0;
for node in &mut self.nodes {
if node.score < threshold && !node.is_terminal {
node.is_terminal = true;
pruned += 1;
}
}
pruned
}
pub fn get_frontier(&self) -> Vec<usize> {
self.nodes
.iter()
.filter(|n| {
!n.is_terminal
&& n.children.is_empty()
&& n.depth < self.config.max_depth
&& n.score >= self.config.prune_threshold
})
.map(|n| n.id)
.collect()
}
pub fn bfs_step(&self) -> Vec<usize> {
self.get_frontier()
}
pub fn dfs_step(&self) -> Vec<usize> {
let frontier = self.get_frontier();
if let Some(best) = frontier.iter().max_by_key(|&&id| self.nodes[id].depth) {
vec![*best]
} else {
vec![]
}
}
pub fn beam_step(&self) -> Vec<usize> {
let frontier = self.get_frontier();
let mut scored: Vec<_> = frontier
.iter()
.map(|&id| (id, self.nodes[id].score))
.collect();
scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
scored
.into_iter()
.take(self.config.beam_width)
.map(|(id, _)| id)
.collect()
}
pub fn best_first_step(&self) -> Vec<usize> {
let frontier = self.get_frontier();
if let Some(&best) = frontier.iter().max_by(|&&a, &&b| {
self.nodes[a]
.score
.partial_cmp(&self.nodes[b].score)
.unwrap_or(std::cmp::Ordering::Equal)
}) {
vec![best]
} else {
vec![]
}
}
pub fn get_expansion_candidates(&self) -> Vec<usize> {
match self.config.search_strategy {
SearchStrategy::BreadthFirst => self.bfs_step(),
SearchStrategy::DepthFirst => self.dfs_step(),
SearchStrategy::BeamSearch => self.beam_step(),
SearchStrategy::BestFirst => self.best_first_step(),
SearchStrategy::MCTS => self.best_first_step(), }
}
pub fn find_best_solution(&self) -> Option<&ThoughtNode> {
self.nodes
.iter()
.filter(|n| n.is_terminal && n.state.is_solved)
.max_by(|a, b| {
a.score
.partial_cmp(&b.score)
.unwrap_or(std::cmp::Ordering::Equal)
})
}
pub fn build_result(&self) -> ToTResult {
let best_node = self.find_best_solution();
let (best_path, solution, score) = if let Some(node) = best_node {
let path = self.get_path(node.id);
(
path.into_iter().cloned().collect(),
node.state.solution.clone(),
node.score,
)
} else {
let best_leaf = self
.nodes
.iter()
.filter(|n| n.children.is_empty())
.max_by(|a, b| {
a.score
.partial_cmp(&b.score)
.unwrap_or(std::cmp::Ordering::Equal)
});
if let Some(node) = best_leaf {
let path = self.get_path(node.id);
(path.into_iter().cloned().collect(), None, node.score)
} else {
(vec![], None, 0.0)
}
};
let nodes_pruned = self
.nodes
.iter()
.filter(|n| n.score < self.config.prune_threshold)
.count();
let max_depth = self.nodes.iter().map(|n| n.depth).max().unwrap_or(0);
let avg_score = if !self.nodes.is_empty() {
self.nodes.iter().map(|n| n.score).sum::<f32>() / self.nodes.len() as f32
} else {
0.0
};
let avg_branching = if self.nodes.len() > 1 {
let non_leaf = self.nodes.iter().filter(|n| !n.children.is_empty()).count();
if non_leaf > 0 {
self.nodes.iter().map(|n| n.children.len()).sum::<usize>() as f32 / non_leaf as f32
} else {
0.0
}
} else {
0.0
};
ToTResult {
best_path,
solution,
score,
explored_paths: self.nodes.iter().filter(|n| n.children.is_empty()).count(),
nodes_generated: self.nodes.len(),
nodes_pruned,
stats: ToTStats {
avg_branching_factor: avg_branching,
avg_node_score: avg_score,
max_depth_reached: max_depth,
backtrack_count: 0,
generation_time_ms: 0,
evaluation_time_ms: 0,
},
}
}
pub fn reset(&mut self) {
self.nodes.clear();
self.next_id = 0;
}
}
pub struct ThoughtPrompts;
impl ThoughtPrompts {
pub fn math_thoughts(problem: &str, current_state: &str, n: usize) -> String {
format!(
r#"You are solving a math problem step by step.
PROBLEM: {problem}
CURRENT STATE:
{current_state}
Generate exactly {n} different possible next steps to make progress on this problem.
Each step should be a distinct approach or continuation.
Format each thought as:
THOUGHT 1: [your first possible step]
THOUGHT 2: [your second possible step]
THOUGHT 3: [etc...]
Be creative and explore different angles. Some thoughts might:
- Apply a formula directly
- Break down into sub-problems
- Use a different variable
- Try a numerical approach
- Look for patterns"#,
problem = problem,
current_state = current_state,
n = n
)
}
pub fn evaluate_thought(problem: &str, thought: &str, context: &str) -> String {
format!(
r#"Evaluate how promising this thought is for solving the problem.
PROBLEM: {problem}
CONTEXT/PRIOR STEPS:
{context}
THOUGHT TO EVALUATE:
{thought}
Rate on a scale of 0.0 to 1.0:
- 1.0: Definitely leads to solution
- 0.7-0.9: Very promising direction
- 0.4-0.6: Reasonable but uncertain
- 0.1-0.3: Unlikely to help
- 0.0: Definitely wrong or counterproductive
Consider:
1. Is the logic correct?
2. Does it make progress toward the answer?
3. Is it a reasonable next step given the context?
4. Could it lead to the final solution?
Respond with only a JSON object:
{{"score": 0.0-1.0, "reasoning": "brief explanation"}}"#,
problem = problem,
context = context,
thought = thought
)
}
pub fn check_terminal(problem: &str, current_state: &str) -> String {
format!(
r#"Determine if this problem has been solved.
PROBLEM: {problem}
CURRENT STATE/REASONING:
{current_state}
Answer with a JSON object:
{{
"is_solved": true/false,
"solution": "the answer if solved, null otherwise",
"confidence": 0.0-1.0
}}"#,
problem = problem,
current_state = current_state
)
}
pub fn creative_thoughts(problem: &str, current_state: &str, n: usize) -> String {
format!(
r#"You are exploring creative solutions to a problem.
PROBLEM: {problem}
CURRENT EXPLORATION:
{current_state}
Generate {n} diverse and creative next thoughts. Think unconventionally.
Format as:
THOUGHT 1: [first creative direction]
THOUGHT 2: [second creative direction]
...
Consider:
- Analogy to other domains
- Inverting the problem
- Combining ideas
- Extreme cases
- Different perspectives"#,
problem = problem,
current_state = current_state,
n = n
)
}
}
pub fn parse_thoughts(output: &str, expected: usize) -> Vec<String> {
let mut thoughts = Vec::new();
for i in 1..=expected + 5 {
let marker = format!("THOUGHT {}:", i);
if let Some(pos) = output.to_uppercase().find(&marker.to_uppercase()) {
let start = pos + marker.len();
let rest = &output[start..];
let end = rest
.to_uppercase()
.find("THOUGHT ")
.unwrap_or(rest.len())
.min(rest.len());
let thought = rest[..end].trim().to_string();
if !thought.is_empty() {
thoughts.push(thought);
}
}
}
if thoughts.is_empty() {
for line in output.lines() {
let trimmed = line.trim();
if trimmed.starts_with(|c: char| c.is_ascii_digit()) {
let text: String = trimmed
.chars()
.skip_while(|c| c.is_ascii_digit() || *c == '.' || *c == ')' || *c == ':')
.collect();
let text = text.trim();
if !text.is_empty() {
thoughts.push(text.to_string());
}
}
}
}
thoughts.truncate(expected);
thoughts
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_tree_creation() {
let mut tot = TreeOfThoughts::new(ToTConfig::default());
let root = tot.create_root("What is 2 + 2?");
assert_eq!(root, 0);
assert!(tot.get_node(0).is_some());
assert_eq!(tot.get_node(0).unwrap().depth, 0);
}
#[test]
fn test_add_children() {
let mut tot = TreeOfThoughts::new(ToTConfig::default());
let root = tot.create_root("Problem");
let child1 = tot.add_child(root, "Approach 1".into(), 0.8, ThoughtState::default());
let child2 = tot.add_child(root, "Approach 2".into(), 0.6, ThoughtState::default());
assert_eq!(tot.get_node(root).unwrap().children.len(), 2);
assert_eq!(tot.get_node(child1).unwrap().depth, 1);
assert_eq!(tot.get_node(child2).unwrap().parent, Some(root));
}
#[test]
fn test_get_path() {
let mut tot = TreeOfThoughts::new(ToTConfig::default());
let root = tot.create_root("Problem");
let child = tot.add_child(root, "Step 1".into(), 0.8, ThoughtState::default());
let grandchild = tot.add_child(child, "Step 2".into(), 0.7, ThoughtState::default());
let path = tot.get_path(grandchild);
assert_eq!(path.len(), 3);
assert_eq!(path[0].id, root);
assert_eq!(path[2].id, grandchild);
}
#[test]
fn test_pruning() {
let mut tot = TreeOfThoughts::new(ToTConfig {
prune_threshold: 0.5,
..Default::default()
});
let root = tot.create_root("Problem");
tot.add_child(root, "Good".into(), 0.8, ThoughtState::default());
tot.add_child(root, "Bad".into(), 0.2, ThoughtState::default());
let pruned = tot.prune();
assert_eq!(pruned, 1);
}
#[test]
fn test_parse_thoughts() {
let output = r#"
THOUGHT 1: First approach is to use algebra
THOUGHT 2: Second approach uses geometry
THOUGHT 3: Third uses numerical methods
"#;
let thoughts = parse_thoughts(output, 3);
assert_eq!(thoughts.len(), 3);
assert!(thoughts[0].contains("algebra"));
assert!(thoughts[1].contains("geometry"));
}
#[test]
fn test_beam_search() {
let mut tot = TreeOfThoughts::new(ToTConfig {
beam_width: 2,
search_strategy: SearchStrategy::BeamSearch,
..Default::default()
});
let root = tot.create_root("Problem");
tot.add_child(root, "Low".into(), 0.3, ThoughtState::default());
tot.add_child(root, "High".into(), 0.9, ThoughtState::default());
tot.add_child(root, "Medium".into(), 0.6, ThoughtState::default());
let candidates = tot.beam_step();
assert_eq!(candidates.len(), 2);
let scores: Vec<f32> = candidates
.iter()
.map(|&id| tot.get_node(id).unwrap().score)
.collect();
assert!(scores[0] >= 0.6);
}
}