reasonkit/thinktool/
tot.rs

1//! # Tree-of-Thoughts (ToT) Parallel Exploration
2//!
//! Implements parallel (multi-branch) thought exploration. On the Game of 24
//! benchmark, Yao et al. (2023) report a 74% success rate for ToT versus 4% for
//! Chain-of-Thought prompting.
5//!
6//! ## Scientific Foundation
7//!
8//! Based on:
9//! - Yao et al. (2023): Tree of Thoughts: Deliberate Problem Solving with Large Language Models
10//! - Long (2023): Large Language Model Guided Tree-of-Thought
11//!
12//! ## Key Concepts
13//!
14//! - **Thought**: Coherent language chunk (sentence to paragraph)
15//! - **Decomposition**: Break problem into thought steps
16//! - **Generation**: Propose multiple candidates per step
17//! - **Evaluation**: Score thoughts for promise
18//! - **Search**: BFS/DFS with pruning
19//!
20//! ## Usage
21//!
22//! ```rust,ignore
//! use reasonkit::thinktool::tot::{SearchStrategy, ToTConfig, TreeOfThoughts};
24//!
25//! let tot = TreeOfThoughts::new(ToTConfig {
26//!     branching_factor: 3,
27//!     max_depth: 5,
28//!     search_strategy: SearchStrategy::BreadthFirst,
29//!     ..Default::default()
30//! });
31//!
32//! let result = tot.solve("Creative problem here").await?;
33//! ```
34
35use serde::{Deserialize, Serialize};
36use std::collections::HashMap;
37
38/// A single thought node in the tree
39#[derive(Debug, Clone, Serialize, Deserialize)]
40pub struct ThoughtNode {
41    /// Unique node ID
42    pub id: usize,
43    /// Parent node ID (None for root)
44    pub parent: Option<usize>,
45    /// Child node IDs
46    pub children: Vec<usize>,
47    /// The thought content
48    pub thought: String,
49    /// Evaluation score (0.0 - 1.0)
50    pub score: f32,
51    /// Depth in tree (root = 0)
52    pub depth: usize,
53    /// Whether this is a terminal/solution node
54    pub is_terminal: bool,
55    /// State representation after this thought
56    pub state: ThoughtState,
57}
58
59/// State after applying a thought
60#[derive(Debug, Clone, Default, Serialize, Deserialize)]
61pub struct ThoughtState {
62    /// Accumulated reasoning so far
63    pub reasoning_path: Vec<String>,
64    /// Intermediate results
65    pub partial_results: HashMap<String, String>,
66    /// Whether the problem is solved
67    pub is_solved: bool,
68    /// Solution if found
69    pub solution: Option<String>,
70}
71
72/// Result of ToT exploration
73#[derive(Debug, Clone, Serialize, Deserialize)]
74pub struct ToTResult {
75    /// Best solution path
76    pub best_path: Vec<ThoughtNode>,
77    /// Best solution
78    pub solution: Option<String>,
79    /// Final score
80    pub score: f32,
81    /// All explored paths (for debugging)
82    pub explored_paths: usize,
83    /// Total nodes generated
84    pub nodes_generated: usize,
85    /// Nodes pruned
86    pub nodes_pruned: usize,
87    /// Statistics
88    pub stats: ToTStats,
89}
90
91#[derive(Debug, Clone, Default, Serialize, Deserialize)]
92pub struct ToTStats {
93    /// Average branching factor observed
94    pub avg_branching_factor: f32,
95    /// Average node score
96    pub avg_node_score: f32,
97    /// Maximum depth reached
98    pub max_depth_reached: usize,
99    /// Number of backtrack operations
100    pub backtrack_count: usize,
101    /// Time spent in generation (ms)
102    pub generation_time_ms: u64,
103    /// Time spent in evaluation (ms)
104    pub evaluation_time_ms: u64,
105}
106
107/// Tree-of-Thoughts configuration
108#[derive(Debug, Clone, Serialize, Deserialize)]
109pub struct ToTConfig {
110    /// Number of thoughts to generate per step
111    pub branching_factor: usize,
112    /// Maximum tree depth
113    pub max_depth: usize,
114    /// Search strategy
115    pub search_strategy: SearchStrategy,
116    /// Pruning threshold (nodes below this are dropped)
117    pub prune_threshold: f32,
118    /// Maximum nodes to expand
119    pub max_nodes: usize,
120    /// Beam width for beam search
121    pub beam_width: usize,
122    /// Whether to use value function for evaluation
123    pub use_value_function: bool,
124    /// Temperature for thought generation
125    pub temperature: f32,
126}
127
128impl Default for ToTConfig {
129    fn default() -> Self {
130        Self {
131            branching_factor: 3,
132            max_depth: 5,
133            search_strategy: SearchStrategy::BreadthFirst,
134            prune_threshold: 0.3,
135            max_nodes: 100,
136            beam_width: 5,
137            use_value_function: true,
138            temperature: 0.7,
139        }
140    }
141}
142
143/// Search strategy for exploring the tree
144#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
145pub enum SearchStrategy {
146    /// Explore level by level
147    BreadthFirst,
148    /// Explore depth-first with backtracking
149    DepthFirst,
150    /// Keep top-k candidates per level
151    BeamSearch,
152    /// Best-first search using scores
153    BestFirst,
154    /// Monte Carlo Tree Search
155    MCTS,
156}
157
158/// The Tree-of-Thoughts engine
159#[derive(Debug)]
160pub struct TreeOfThoughts {
161    pub config: ToTConfig,
162    /// All nodes in the tree
163    nodes: Vec<ThoughtNode>,
164    /// Node ID counter
165    next_id: usize,
166}
167
168impl TreeOfThoughts {
169    pub fn new(config: ToTConfig) -> Self {
170        Self {
171            config,
172            nodes: Vec::new(),
173            next_id: 0,
174        }
175    }
176
177    /// Create a new root node
178    pub fn create_root(&mut self, problem: &str) -> usize {
179        let id = self.next_id;
180        self.next_id += 1;
181
182        let node = ThoughtNode {
183            id,
184            parent: None,
185            children: Vec::new(),
186            thought: problem.to_string(),
187            score: 1.0,
188            depth: 0,
189            is_terminal: false,
190            state: ThoughtState::default(),
191        };
192
193        self.nodes.push(node);
194        id
195    }
196
197    /// Add a child thought to a node
198    pub fn add_child(
199        &mut self,
200        parent_id: usize,
201        thought: String,
202        score: f32,
203        state: ThoughtState,
204    ) -> usize {
205        let id = self.next_id;
206        self.next_id += 1;
207
208        let parent_depth = self.nodes[parent_id].depth;
209
210        let node = ThoughtNode {
211            id,
212            parent: Some(parent_id),
213            children: Vec::new(),
214            thought,
215            score,
216            depth: parent_depth + 1,
217            is_terminal: state.is_solved,
218            state,
219        };
220
221        self.nodes.push(node);
222        self.nodes[parent_id].children.push(id);
223
224        id
225    }
226
227    /// Get a node by ID
228    pub fn get_node(&self, id: usize) -> Option<&ThoughtNode> {
229        self.nodes.get(id)
230    }
231
232    /// Get mutable node by ID
233    pub fn get_node_mut(&mut self, id: usize) -> Option<&mut ThoughtNode> {
234        self.nodes.get_mut(id)
235    }
236
237    /// Get path from root to node
238    pub fn get_path(&self, node_id: usize) -> Vec<&ThoughtNode> {
239        let mut path = Vec::new();
240        let mut current = Some(node_id);
241
242        while let Some(id) = current {
243            if let Some(node) = self.get_node(id) {
244                path.push(node);
245                current = node.parent;
246            } else {
247                break;
248            }
249        }
250
251        path.reverse();
252        path
253    }
254
255    /// Prune nodes below threshold
256    pub fn prune(&mut self) -> usize {
257        let threshold = self.config.prune_threshold;
258        let mut pruned = 0;
259
260        for node in &mut self.nodes {
261            if node.score < threshold && !node.is_terminal {
262                // Mark as terminal (effectively pruned)
263                node.is_terminal = true;
264                pruned += 1;
265            }
266        }
267
268        pruned
269    }
270
271    /// Get frontier nodes (expandable leaves)
272    pub fn get_frontier(&self) -> Vec<usize> {
273        self.nodes
274            .iter()
275            .filter(|n| {
276                !n.is_terminal
277                    && n.children.is_empty()
278                    && n.depth < self.config.max_depth
279                    && n.score >= self.config.prune_threshold
280            })
281            .map(|n| n.id)
282            .collect()
283    }
284
285    /// BFS exploration step
286    pub fn bfs_step(&self) -> Vec<usize> {
287        // In BFS, we process all frontier nodes
288        self.get_frontier()
289    }
290
291    /// DFS exploration step
292    pub fn dfs_step(&self) -> Vec<usize> {
293        let frontier = self.get_frontier();
294
295        // In DFS, we go deep first - pick highest depth node
296        if let Some(best) = frontier.iter().max_by_key(|&&id| self.nodes[id].depth) {
297            vec![*best]
298        } else {
299            vec![]
300        }
301    }
302
303    /// Beam search step
304    pub fn beam_step(&self) -> Vec<usize> {
305        let frontier = self.get_frontier();
306
307        // Keep top-k by score
308        let mut scored: Vec<_> = frontier
309            .iter()
310            .map(|&id| (id, self.nodes[id].score))
311            .collect();
312        scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
313
314        scored
315            .into_iter()
316            .take(self.config.beam_width)
317            .map(|(id, _)| id)
318            .collect()
319    }
320
321    /// Best-first step
322    pub fn best_first_step(&self) -> Vec<usize> {
323        let frontier = self.get_frontier();
324
325        // Pick single best node
326        if let Some(&best) = frontier.iter().max_by(|&&a, &&b| {
327            self.nodes[a]
328                .score
329                .partial_cmp(&self.nodes[b].score)
330                .unwrap_or(std::cmp::Ordering::Equal)
331        }) {
332            vec![best]
333        } else {
334            vec![]
335        }
336    }
337
338    /// Get nodes to expand based on search strategy
339    pub fn get_expansion_candidates(&self) -> Vec<usize> {
340        match self.config.search_strategy {
341            SearchStrategy::BreadthFirst => self.bfs_step(),
342            SearchStrategy::DepthFirst => self.dfs_step(),
343            SearchStrategy::BeamSearch => self.beam_step(),
344            SearchStrategy::BestFirst => self.best_first_step(),
345            SearchStrategy::MCTS => self.best_first_step(), // Simplified MCTS
346        }
347    }
348
349    /// Find the best terminal node (solution)
350    pub fn find_best_solution(&self) -> Option<&ThoughtNode> {
351        self.nodes
352            .iter()
353            .filter(|n| n.is_terminal && n.state.is_solved)
354            .max_by(|a, b| {
355                a.score
356                    .partial_cmp(&b.score)
357                    .unwrap_or(std::cmp::Ordering::Equal)
358            })
359    }
360
361    /// Build result from current tree state
362    pub fn build_result(&self) -> ToTResult {
363        let best_node = self.find_best_solution();
364
365        let (best_path, solution, score) = if let Some(node) = best_node {
366            let path = self.get_path(node.id);
367            (
368                path.into_iter().cloned().collect(),
369                node.state.solution.clone(),
370                node.score,
371            )
372        } else {
373            // Return best non-terminal path
374            let best_leaf = self
375                .nodes
376                .iter()
377                .filter(|n| n.children.is_empty())
378                .max_by(|a, b| {
379                    a.score
380                        .partial_cmp(&b.score)
381                        .unwrap_or(std::cmp::Ordering::Equal)
382                });
383
384            if let Some(node) = best_leaf {
385                let path = self.get_path(node.id);
386                (path.into_iter().cloned().collect(), None, node.score)
387            } else {
388                (vec![], None, 0.0)
389            }
390        };
391
392        let nodes_pruned = self
393            .nodes
394            .iter()
395            .filter(|n| n.score < self.config.prune_threshold)
396            .count();
397
398        let max_depth = self.nodes.iter().map(|n| n.depth).max().unwrap_or(0);
399
400        let avg_score = if !self.nodes.is_empty() {
401            self.nodes.iter().map(|n| n.score).sum::<f32>() / self.nodes.len() as f32
402        } else {
403            0.0
404        };
405
406        let avg_branching = if self.nodes.len() > 1 {
407            let non_leaf = self.nodes.iter().filter(|n| !n.children.is_empty()).count();
408            if non_leaf > 0 {
409                self.nodes.iter().map(|n| n.children.len()).sum::<usize>() as f32 / non_leaf as f32
410            } else {
411                0.0
412            }
413        } else {
414            0.0
415        };
416
417        ToTResult {
418            best_path,
419            solution,
420            score,
421            explored_paths: self.nodes.iter().filter(|n| n.children.is_empty()).count(),
422            nodes_generated: self.nodes.len(),
423            nodes_pruned,
424            stats: ToTStats {
425                avg_branching_factor: avg_branching,
426                avg_node_score: avg_score,
427                max_depth_reached: max_depth,
428                backtrack_count: 0,
429                generation_time_ms: 0,
430                evaluation_time_ms: 0,
431            },
432        }
433    }
434
435    /// Reset the tree for a new problem
436    pub fn reset(&mut self) {
437        self.nodes.clear();
438        self.next_id = 0;
439    }
440}
441
/// Prompt templates for generating, scoring and checking thoughts.
///
/// Every method only formats a `String`; no model is called here.
pub struct ThoughtPrompts;

impl ThoughtPrompts {
    /// Build a prompt asking for `n` candidate next steps on a math problem.
    ///
    /// The model is asked to answer in `THOUGHT <k>: ...` lines.
    ///
    /// * `problem` - the problem statement.
    /// * `current_state` - the reasoning accumulated so far.
    /// * `n` - how many distinct next steps to request.
    pub fn math_thoughts(problem: &str, current_state: &str, n: usize) -> String {
        format!(
            r#"You are solving a math problem step by step.

PROBLEM: {problem}

CURRENT STATE:
{current_state}

Generate exactly {n} different possible next steps to make progress on this problem.
Each step should be a distinct approach or continuation.

Format each thought as:
THOUGHT 1: [your first possible step]
THOUGHT 2: [your second possible step]
THOUGHT 3: [etc...]

Be creative and explore different angles. Some thoughts might:
- Apply a formula directly
- Break down into sub-problems
- Use a different variable
- Try a numerical approach
- Look for patterns"#,
            n = n,
            current_state = current_state,
            problem = problem,
        )
    }

    /// Build a prompt asking the model to score one `thought` in `0.0..=1.0`.
    ///
    /// * `problem` - the problem statement.
    /// * `thought` - the candidate thought to score.
    /// * `context` - the prior steps leading up to `thought`.
    ///
    /// The model is asked for a JSON object with `score` and `reasoning` keys.
    pub fn evaluate_thought(problem: &str, thought: &str, context: &str) -> String {
        format!(
            r#"Evaluate how promising this thought is for solving the problem.

PROBLEM: {problem}

CONTEXT/PRIOR STEPS:
{context}

THOUGHT TO EVALUATE:
{thought}

Rate on a scale of 0.0 to 1.0:
- 1.0: Definitely leads to solution
- 0.7-0.9: Very promising direction
- 0.4-0.6: Reasonable but uncertain
- 0.1-0.3: Unlikely to help
- 0.0: Definitely wrong or counterproductive

Consider:
1. Is the logic correct?
2. Does it make progress toward the answer?
3. Is it a reasonable next step given the context?
4. Could it lead to the final solution?

Respond with only a JSON object:
{{"score": 0.0-1.0, "reasoning": "brief explanation"}}"#,
            thought = thought,
            context = context,
            problem = problem,
        )
    }

    /// Build a prompt asking whether `current_state` already solves `problem`.
    ///
    /// The model is asked for a JSON object with `is_solved`, `solution` and
    /// `confidence` keys.
    pub fn check_terminal(problem: &str, current_state: &str) -> String {
        format!(
            r#"Determine if this problem has been solved.

PROBLEM: {problem}

CURRENT STATE/REASONING:
{current_state}

Answer with a JSON object:
{{
    "is_solved": true/false,
    "solution": "the answer if solved, null otherwise",
    "confidence": 0.0-1.0
}}"#,
            current_state = current_state,
            problem = problem,
        )
    }

    /// Build a prompt asking for `n` divergent, creative next thoughts.
    ///
    /// Like [`ThoughtPrompts::math_thoughts`], the reply is requested as
    /// `THOUGHT <k>: ...` lines.
    pub fn creative_thoughts(problem: &str, current_state: &str, n: usize) -> String {
        format!(
            r#"You are exploring creative solutions to a problem.

PROBLEM: {problem}

CURRENT EXPLORATION:
{current_state}

Generate {n} diverse and creative next thoughts. Think unconventionally.

Format as:
THOUGHT 1: [first creative direction]
THOUGHT 2: [second creative direction]
...

Consider:
- Analogy to other domains
- Inverting the problem
- Combining ideas
- Extreme cases
- Different perspectives"#,
            n = n,
            current_state = current_state,
            problem = problem,
        )
    }
}
560
/// Parse individual thoughts out of raw LLM output.
///
/// Two formats are recognised, tried in order:
///
/// 1. `THOUGHT N: text` markers, matched case-insensitively, for `N` in
///    `1..=expected + 5`. Each thought runs from its marker to the next
///    `THOUGHT <digits>:` marker, or to the end of `output`.
/// 2. Only if (1) produced nothing: every line starting with an ASCII digit is
///    treated as a numbered-list item, with its leading digits and any `.`,
///    `)` or `:` characters stripped.
///
/// Empty thoughts are skipped and at most `expected` thoughts are returned.
pub fn parse_thoughts(output: &str, expected: usize) -> Vec<String> {
    // ASCII-only case folding keeps every byte offset identical to `output`,
    // so positions found in `upper` are valid char boundaries in `output`.
    // `str::to_uppercase` can change byte lengths (e.g. `ﬁ` -> `FI`), which
    // would mis-slice or panic.
    let upper = output.to_ascii_uppercase();
    let mut thoughts = Vec::new();

    for i in 1..=expected.saturating_add(5) {
        let marker = format!("THOUGHT {}:", i);
        if let Some(pos) = upper.find(&marker) {
            let start = pos + marker.len();
            // Stop at the next real marker, not at the word "thought" appearing
            // inside the current thought's text.
            let end = find_thought_marker(&upper[start..])
                .map_or(output.len(), |offset| start + offset);
            let thought = output[start..end].trim();
            if !thought.is_empty() {
                thoughts.push(thought.to_string());
            }
        }
    }

    // Fallback: treat digit-led lines as a numbered list.
    if thoughts.is_empty() {
        thoughts = output
            .lines()
            .map(str::trim)
            .filter(|line| line.starts_with(|c: char| c.is_ascii_digit()))
            .map(|line| {
                line.trim_start_matches(|c: char| {
                    c.is_ascii_digit() || c == '.' || c == ')' || c == ':'
                })
                .trim()
            })
            .filter(|text| !text.is_empty())
            .map(str::to_string)
            .collect();
    }

    thoughts.truncate(expected);
    thoughts
}

/// Byte offset of the first `THOUGHT <digits>:` marker in `upper`, which must
/// already be ASCII-uppercased; `None` if there is no such marker.
fn find_thought_marker(upper: &str) -> Option<usize> {
    const PREFIX: &str = "THOUGHT ";
    let bytes = upper.as_bytes();
    let mut from = 0;
    while let Some(found) = upper[from..].find(PREFIX) {
        let at = from + found;
        let after = &bytes[at + PREFIX.len()..];
        let digits = after.iter().take_while(|b| b.is_ascii_digit()).count();
        if digits > 0 && after.get(digits) == Some(&b':') {
            return Some(at);
        }
        // `PREFIX` is ASCII, so this offset is always a char boundary.
        from = at + PREFIX.len();
    }
    None
}
607
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tree_creation() {
        let mut tot = TreeOfThoughts::new(ToTConfig::default());
        let root = tot.create_root("What is 2 + 2?");

        assert_eq!(root, 0);
        assert!(tot.get_node(0).is_some());
        assert_eq!(tot.get_node(0).unwrap().depth, 0);
    }

    #[test]
    fn test_add_children() {
        let mut tot = TreeOfThoughts::new(ToTConfig::default());
        let root = tot.create_root("Problem");

        let child1 = tot.add_child(root, "Approach 1".into(), 0.8, ThoughtState::default());
        let child2 = tot.add_child(root, "Approach 2".into(), 0.6, ThoughtState::default());

        assert_eq!(tot.get_node(root).unwrap().children, vec![child1, child2]);
        assert_eq!(tot.get_node(child1).unwrap().depth, 1);
        assert_eq!(tot.get_node(child2).unwrap().parent, Some(root));
    }

    #[test]
    fn test_get_path() {
        let mut tot = TreeOfThoughts::new(ToTConfig::default());
        let root = tot.create_root("Problem");
        let child = tot.add_child(root, "Step 1".into(), 0.8, ThoughtState::default());
        let grandchild = tot.add_child(child, "Step 2".into(), 0.7, ThoughtState::default());

        let ids: Vec<usize> = tot.get_path(grandchild).iter().map(|n| n.id).collect();
        assert_eq!(ids, vec![root, child, grandchild]);
    }

    #[test]
    fn test_pruning() {
        let mut tot = TreeOfThoughts::new(ToTConfig {
            prune_threshold: 0.5,
            ..Default::default()
        });
        let root = tot.create_root("Problem");
        let good = tot.add_child(root, "Good".into(), 0.8, ThoughtState::default());
        let bad = tot.add_child(root, "Bad".into(), 0.2, ThoughtState::default());

        assert_eq!(tot.prune(), 1);
        assert!(tot.get_node(bad).unwrap().is_terminal);
        assert!(!tot.get_node(good).unwrap().is_terminal);
        // Pruned nodes are no longer expandable, and re-pruning is a no-op.
        assert_eq!(tot.get_frontier(), vec![good]);
        assert_eq!(tot.prune(), 0);
    }

    #[test]
    fn test_parse_thoughts() {
        let output = r#"
THOUGHT 1: First approach is to use algebra
THOUGHT 2: Second approach uses geometry
THOUGHT 3: Third uses numerical methods
"#;

        let thoughts = parse_thoughts(output, 3);
        assert_eq!(thoughts.len(), 3);
        assert!(thoughts[0].contains("algebra"));
        assert!(thoughts[1].contains("geometry"));
        assert!(thoughts[2].contains("numerical"));
    }

    #[test]
    fn test_parse_thoughts_numbered_fallback() {
        let thoughts = parse_thoughts("1. alpha\n2) beta\n3: gamma", 2);
        assert_eq!(thoughts, vec!["alpha".to_string(), "beta".to_string()]);
    }

    #[test]
    fn test_beam_search() {
        let mut tot = TreeOfThoughts::new(ToTConfig {
            beam_width: 2,
            search_strategy: SearchStrategy::BeamSearch,
            ..Default::default()
        });
        let root = tot.create_root("Problem");
        tot.add_child(root, "Low".into(), 0.3, ThoughtState::default());
        let high = tot.add_child(root, "High".into(), 0.9, ThoughtState::default());
        let medium = tot.add_child(root, "Medium".into(), 0.6, ThoughtState::default());

        // The beam must hold exactly the two best nodes, best first.
        assert_eq!(tot.beam_step(), vec![high, medium]);
        assert_eq!(tot.get_expansion_candidates(), vec![high, medium]);
    }

    #[test]
    fn test_step_strategies() {
        let mut tot = TreeOfThoughts::new(ToTConfig::default());
        let root = tot.create_root("Problem");
        let a = tot.add_child(root, "A".into(), 0.8, ThoughtState::default());
        let b = tot.add_child(root, "B".into(), 0.9, ThoughtState::default());
        let a1 = tot.add_child(a, "A1".into(), 0.5, ThoughtState::default());

        assert_eq!(tot.bfs_step(), vec![b, a1]);
        assert_eq!(tot.dfs_step(), vec![a1]);
        assert_eq!(tot.best_first_step(), vec![b]);
    }

    #[test]
    fn test_build_result_prefers_solution() {
        let mut tot = TreeOfThoughts::new(ToTConfig::default());
        let root = tot.create_root("What is 2 + 2?");
        tot.add_child(root, "Guess 5".into(), 0.9, ThoughtState::default());
        let solved = ThoughtState {
            is_solved: true,
            solution: Some("4".into()),
            ..Default::default()
        };
        let answer = tot.add_child(root, "2 + 2 = 4".into(), 0.7, solved);

        let result = tot.build_result();
        assert_eq!(result.solution.as_deref(), Some("4"));
        assert_eq!(result.best_path.len(), 2);
        assert_eq!(result.best_path[1].id, answer);
        assert_eq!(result.nodes_generated, 3);
    }
}