1use serde::{Deserialize, Serialize};
36use std::collections::HashMap;
37
38#[derive(Debug, Clone, Serialize, Deserialize)]
40pub struct ThoughtNode {
41 pub id: usize,
43 pub parent: Option<usize>,
45 pub children: Vec<usize>,
47 pub thought: String,
49 pub score: f32,
51 pub depth: usize,
53 pub is_terminal: bool,
55 pub state: ThoughtState,
57}
58
59#[derive(Debug, Clone, Default, Serialize, Deserialize)]
61pub struct ThoughtState {
62 pub reasoning_path: Vec<String>,
64 pub partial_results: HashMap<String, String>,
66 pub is_solved: bool,
68 pub solution: Option<String>,
70}
71
72#[derive(Debug, Clone, Serialize, Deserialize)]
74pub struct ToTResult {
75 pub best_path: Vec<ThoughtNode>,
77 pub solution: Option<String>,
79 pub score: f32,
81 pub explored_paths: usize,
83 pub nodes_generated: usize,
85 pub nodes_pruned: usize,
87 pub stats: ToTStats,
89}
90
91#[derive(Debug, Clone, Default, Serialize, Deserialize)]
92pub struct ToTStats {
93 pub avg_branching_factor: f32,
95 pub avg_node_score: f32,
97 pub max_depth_reached: usize,
99 pub backtrack_count: usize,
101 pub generation_time_ms: u64,
103 pub evaluation_time_ms: u64,
105}
106
107#[derive(Debug, Clone, Serialize, Deserialize)]
109pub struct ToTConfig {
110 pub branching_factor: usize,
112 pub max_depth: usize,
114 pub search_strategy: SearchStrategy,
116 pub prune_threshold: f32,
118 pub max_nodes: usize,
120 pub beam_width: usize,
122 pub use_value_function: bool,
124 pub temperature: f32,
126}
127
128impl Default for ToTConfig {
129 fn default() -> Self {
130 Self {
131 branching_factor: 3,
132 max_depth: 5,
133 search_strategy: SearchStrategy::BreadthFirst,
134 prune_threshold: 0.3,
135 max_nodes: 100,
136 beam_width: 5,
137 use_value_function: true,
138 temperature: 0.7,
139 }
140 }
141}
142
143#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
145pub enum SearchStrategy {
146 BreadthFirst,
148 DepthFirst,
150 BeamSearch,
152 BestFirst,
154 MCTS,
156}
157
158#[derive(Debug)]
160pub struct TreeOfThoughts {
161 pub config: ToTConfig,
162 nodes: Vec<ThoughtNode>,
164 next_id: usize,
166}
167
168impl TreeOfThoughts {
169 pub fn new(config: ToTConfig) -> Self {
170 Self {
171 config,
172 nodes: Vec::new(),
173 next_id: 0,
174 }
175 }
176
177 pub fn create_root(&mut self, problem: &str) -> usize {
179 let id = self.next_id;
180 self.next_id += 1;
181
182 let node = ThoughtNode {
183 id,
184 parent: None,
185 children: Vec::new(),
186 thought: problem.to_string(),
187 score: 1.0,
188 depth: 0,
189 is_terminal: false,
190 state: ThoughtState::default(),
191 };
192
193 self.nodes.push(node);
194 id
195 }
196
197 pub fn add_child(
199 &mut self,
200 parent_id: usize,
201 thought: String,
202 score: f32,
203 state: ThoughtState,
204 ) -> usize {
205 let id = self.next_id;
206 self.next_id += 1;
207
208 let parent_depth = self.nodes[parent_id].depth;
209
210 let node = ThoughtNode {
211 id,
212 parent: Some(parent_id),
213 children: Vec::new(),
214 thought,
215 score,
216 depth: parent_depth + 1,
217 is_terminal: state.is_solved,
218 state,
219 };
220
221 self.nodes.push(node);
222 self.nodes[parent_id].children.push(id);
223
224 id
225 }
226
227 pub fn get_node(&self, id: usize) -> Option<&ThoughtNode> {
229 self.nodes.get(id)
230 }
231
232 pub fn get_node_mut(&mut self, id: usize) -> Option<&mut ThoughtNode> {
234 self.nodes.get_mut(id)
235 }
236
237 pub fn get_path(&self, node_id: usize) -> Vec<&ThoughtNode> {
239 let mut path = Vec::new();
240 let mut current = Some(node_id);
241
242 while let Some(id) = current {
243 if let Some(node) = self.get_node(id) {
244 path.push(node);
245 current = node.parent;
246 } else {
247 break;
248 }
249 }
250
251 path.reverse();
252 path
253 }
254
255 pub fn prune(&mut self) -> usize {
257 let threshold = self.config.prune_threshold;
258 let mut pruned = 0;
259
260 for node in &mut self.nodes {
261 if node.score < threshold && !node.is_terminal {
262 node.is_terminal = true;
264 pruned += 1;
265 }
266 }
267
268 pruned
269 }
270
271 pub fn get_frontier(&self) -> Vec<usize> {
273 self.nodes
274 .iter()
275 .filter(|n| {
276 !n.is_terminal
277 && n.children.is_empty()
278 && n.depth < self.config.max_depth
279 && n.score >= self.config.prune_threshold
280 })
281 .map(|n| n.id)
282 .collect()
283 }
284
285 pub fn bfs_step(&self) -> Vec<usize> {
287 self.get_frontier()
289 }
290
291 pub fn dfs_step(&self) -> Vec<usize> {
293 let frontier = self.get_frontier();
294
295 if let Some(best) = frontier.iter().max_by_key(|&&id| self.nodes[id].depth) {
297 vec![*best]
298 } else {
299 vec![]
300 }
301 }
302
303 pub fn beam_step(&self) -> Vec<usize> {
305 let frontier = self.get_frontier();
306
307 let mut scored: Vec<_> = frontier
309 .iter()
310 .map(|&id| (id, self.nodes[id].score))
311 .collect();
312 scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
313
314 scored
315 .into_iter()
316 .take(self.config.beam_width)
317 .map(|(id, _)| id)
318 .collect()
319 }
320
321 pub fn best_first_step(&self) -> Vec<usize> {
323 let frontier = self.get_frontier();
324
325 if let Some(&best) = frontier.iter().max_by(|&&a, &&b| {
327 self.nodes[a]
328 .score
329 .partial_cmp(&self.nodes[b].score)
330 .unwrap_or(std::cmp::Ordering::Equal)
331 }) {
332 vec![best]
333 } else {
334 vec![]
335 }
336 }
337
338 pub fn get_expansion_candidates(&self) -> Vec<usize> {
340 match self.config.search_strategy {
341 SearchStrategy::BreadthFirst => self.bfs_step(),
342 SearchStrategy::DepthFirst => self.dfs_step(),
343 SearchStrategy::BeamSearch => self.beam_step(),
344 SearchStrategy::BestFirst => self.best_first_step(),
345 SearchStrategy::MCTS => self.best_first_step(), }
347 }
348
349 pub fn find_best_solution(&self) -> Option<&ThoughtNode> {
351 self.nodes
352 .iter()
353 .filter(|n| n.is_terminal && n.state.is_solved)
354 .max_by(|a, b| {
355 a.score
356 .partial_cmp(&b.score)
357 .unwrap_or(std::cmp::Ordering::Equal)
358 })
359 }
360
361 pub fn build_result(&self) -> ToTResult {
363 let best_node = self.find_best_solution();
364
365 let (best_path, solution, score) = if let Some(node) = best_node {
366 let path = self.get_path(node.id);
367 (
368 path.into_iter().cloned().collect(),
369 node.state.solution.clone(),
370 node.score,
371 )
372 } else {
373 let best_leaf = self
375 .nodes
376 .iter()
377 .filter(|n| n.children.is_empty())
378 .max_by(|a, b| {
379 a.score
380 .partial_cmp(&b.score)
381 .unwrap_or(std::cmp::Ordering::Equal)
382 });
383
384 if let Some(node) = best_leaf {
385 let path = self.get_path(node.id);
386 (path.into_iter().cloned().collect(), None, node.score)
387 } else {
388 (vec![], None, 0.0)
389 }
390 };
391
392 let nodes_pruned = self
393 .nodes
394 .iter()
395 .filter(|n| n.score < self.config.prune_threshold)
396 .count();
397
398 let max_depth = self.nodes.iter().map(|n| n.depth).max().unwrap_or(0);
399
400 let avg_score = if !self.nodes.is_empty() {
401 self.nodes.iter().map(|n| n.score).sum::<f32>() / self.nodes.len() as f32
402 } else {
403 0.0
404 };
405
406 let avg_branching = if self.nodes.len() > 1 {
407 let non_leaf = self.nodes.iter().filter(|n| !n.children.is_empty()).count();
408 if non_leaf > 0 {
409 self.nodes.iter().map(|n| n.children.len()).sum::<usize>() as f32 / non_leaf as f32
410 } else {
411 0.0
412 }
413 } else {
414 0.0
415 };
416
417 ToTResult {
418 best_path,
419 solution,
420 score,
421 explored_paths: self.nodes.iter().filter(|n| n.children.is_empty()).count(),
422 nodes_generated: self.nodes.len(),
423 nodes_pruned,
424 stats: ToTStats {
425 avg_branching_factor: avg_branching,
426 avg_node_score: avg_score,
427 max_depth_reached: max_depth,
428 backtrack_count: 0,
429 generation_time_ms: 0,
430 evaluation_time_ms: 0,
431 },
432 }
433 }
434
435 pub fn reset(&mut self) {
437 self.nodes.clear();
438 self.next_id = 0;
439 }
440}
441
/// Prompt templates for generating, scoring and checking thoughts with an LLM.
///
/// Every method is a pure string builder: it interpolates its arguments into a
/// fixed template and returns the resulting prompt text.
pub struct ThoughtPrompts;

impl ThoughtPrompts {
    /// Prompt asking for exactly `n` alternative next steps on a math problem,
    /// one per `THOUGHT <k>:` line.
    pub fn math_thoughts(problem: &str, current_state: &str, n: usize) -> String {
        format!(
            r#"You are solving a math problem step by step.

PROBLEM: {problem}

CURRENT STATE:
{current_state}

Generate exactly {n} different possible next steps to make progress on this problem.
Each step should be a distinct approach or continuation.

Format each thought as:
THOUGHT 1: [your first possible step]
THOUGHT 2: [your second possible step]
THOUGHT 3: [etc...]

Be creative and explore different angles. Some thoughts might:
- Apply a formula directly
- Break down into sub-problems
- Use a different variable
- Try a numerical approach
- Look for patterns"#
        )
    }

    /// Prompt asking the model to rate `thought` (given the prior `context`)
    /// on a 0.0–1.0 scale and answer with a `{"score", "reasoning"}` JSON object.
    pub fn evaluate_thought(problem: &str, thought: &str, context: &str) -> String {
        format!(
            r#"Evaluate how promising this thought is for solving the problem.

PROBLEM: {problem}

CONTEXT/PRIOR STEPS:
{context}

THOUGHT TO EVALUATE:
{thought}

Rate on a scale of 0.0 to 1.0:
- 1.0: Definitely leads to solution
- 0.7-0.9: Very promising direction
- 0.4-0.6: Reasonable but uncertain
- 0.1-0.3: Unlikely to help
- 0.0: Definitely wrong or counterproductive

Consider:
1. Is the logic correct?
2. Does it make progress toward the answer?
3. Is it a reasonable next step given the context?
4. Could it lead to the final solution?

Respond with only a JSON object:
{{"score": 0.0-1.0, "reasoning": "brief explanation"}}"#
        )
    }

    /// Prompt asking whether `current_state` solves `problem`, expecting a JSON
    /// object with `is_solved`, `solution` and `confidence` fields.
    pub fn check_terminal(problem: &str, current_state: &str) -> String {
        format!(
            r#"Determine if this problem has been solved.

PROBLEM: {problem}

CURRENT STATE/REASONING:
{current_state}

Answer with a JSON object:
{{
 "is_solved": true/false,
 "solution": "the answer if solved, null otherwise",
 "confidence": 0.0-1.0
}}"#
        )
    }

    /// Prompt asking for `n` diverse, unconventional next thoughts, one per
    /// `THOUGHT <k>:` line.
    pub fn creative_thoughts(problem: &str, current_state: &str, n: usize) -> String {
        format!(
            r#"You are exploring creative solutions to a problem.

PROBLEM: {problem}

CURRENT EXPLORATION:
{current_state}

Generate {n} diverse and creative next thoughts. Think unconventionally.

Format as:
THOUGHT 1: [first creative direction]
THOUGHT 2: [second creative direction]
...

Consider:
- Analogy to other domains
- Inverting the problem
- Combining ideas
- Extreme cases
- Different perspectives"#
        )
    }
}
560
/// Extracts up to `expected` thoughts from raw LLM output.
///
/// Primary format: `THOUGHT <k>: <text>`, as requested by the generation
/// prompts in [`ThoughtPrompts`]. Markers are matched ASCII case-insensitively
/// for `k` in `1..=expected + 5` (so a few skipped numbers are tolerated), only
/// the first occurrence of each marker is used, and results come back in `k`
/// order. Each thought runs until the next `THOUGHT <digits>:` marker or the
/// end of `output` and is trimmed; empty thoughts are dropped.
///
/// Fallback (only when no marker produced a thought): every line whose trimmed
/// form starts with an ASCII digit, such as `1. foo` or `2) bar`, becomes a
/// thought with its leading run of digits, `.`, `)` and `:` removed.
///
/// The result is truncated to at most `expected` entries.
pub fn parse_thoughts(output: &str, expected: usize) -> Vec<String> {
    const PREFIX: &str = "THOUGHT ";

    /// Byte offset of the first `THOUGHT <digits>:` marker in `upper`
    /// (which must already be ASCII-uppercased), if any.
    fn next_marker(upper: &str) -> Option<usize> {
        let bytes = upper.as_bytes();
        let mut from = 0;
        while let Some(rel) = upper[from..].find(PREFIX) {
            let start = from + rel;
            let digits_at = start + PREFIX.len();
            let digits = bytes[digits_at..]
                .iter()
                .take_while(|b| b.is_ascii_digit())
                .count();
            if digits > 0 && bytes.get(digits_at + digits) == Some(&b':') {
                return Some(start);
            }
            from = digits_at;
        }
        None
    }

    // `to_ascii_uppercase` rewrites only ASCII bytes, so `upper` has exactly the
    // same byte layout as `output`: offsets found in one are valid char
    // boundaries in the other. Unicode `to_uppercase` can change byte lengths
    // (e.g. the ligature U+FB01 becomes "FI"), which shifted offsets and could
    // panic when slicing. Doing it once also avoids re-uppercasing per marker.
    let upper = output.to_ascii_uppercase();
    let mut thoughts = Vec::new();

    for i in 1..=expected + 5 {
        let marker = format!("THOUGHT {}:", i);
        if let Some(pos) = upper.find(&marker) {
            let start = pos + marker.len();
            let rest = &output[start..];
            // Stop only at the next *numbered* marker; stopping at any
            // "THOUGHT " cut off thoughts that merely use the word
            // (e.g. "My thought is ...").
            let end = next_marker(&upper[start..]).unwrap_or(rest.len());
            let thought = rest[..end].trim();
            if !thought.is_empty() {
                thoughts.push(thought.to_string());
            }
        }
    }

    if thoughts.is_empty() {
        for line in output.lines() {
            let trimmed = line.trim();
            if trimmed.starts_with(|c: char| c.is_ascii_digit()) {
                let text = trimmed
                    .trim_start_matches(|c: char| c.is_ascii_digit() || matches!(c, '.' | ')' | ':'))
                    .trim();
                if !text.is_empty() {
                    thoughts.push(text.to_string());
                }
            }
        }
    }

    thoughts.truncate(expected);
    thoughts
}
607
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh tree using the default configuration.
    fn default_tree() -> TreeOfThoughts {
        TreeOfThoughts::new(ToTConfig::default())
    }

    #[test]
    fn test_tree_creation() {
        let mut tree = default_tree();
        let root_id = tree.create_root("What is 2 + 2?");

        assert_eq!(root_id, 0);
        let root = tree.get_node(0).expect("root node should exist");
        assert_eq!(root.depth, 0);
    }

    #[test]
    fn test_add_children() {
        let mut tree = default_tree();
        let root_id = tree.create_root("Problem");

        let first = tree.add_child(root_id, "Approach 1".into(), 0.8, ThoughtState::default());
        let second = tree.add_child(root_id, "Approach 2".into(), 0.6, ThoughtState::default());

        let root = tree.get_node(root_id).unwrap();
        assert_eq!(root.children.len(), 2);
        assert_eq!(tree.get_node(first).unwrap().depth, 1);
        assert_eq!(tree.get_node(second).unwrap().parent, Some(root_id));
    }

    #[test]
    fn test_get_path() {
        let mut tree = default_tree();
        let root_id = tree.create_root("Problem");
        let step1 = tree.add_child(root_id, "Step 1".into(), 0.8, ThoughtState::default());
        let step2 = tree.add_child(step1, "Step 2".into(), 0.7, ThoughtState::default());

        let ids: Vec<usize> = tree.get_path(step2).iter().map(|node| node.id).collect();
        assert_eq!(ids.len(), 3);
        assert_eq!(ids.first(), Some(&root_id));
        assert_eq!(ids.last(), Some(&step2));
    }

    #[test]
    fn test_pruning() {
        let config = ToTConfig {
            prune_threshold: 0.5,
            ..Default::default()
        };
        let mut tree = TreeOfThoughts::new(config);
        let root_id = tree.create_root("Problem");
        tree.add_child(root_id, "Good".into(), 0.8, ThoughtState::default());
        tree.add_child(root_id, "Bad".into(), 0.2, ThoughtState::default());

        assert_eq!(tree.prune(), 1);
    }

    #[test]
    fn test_parse_thoughts() {
        let output = r#"
THOUGHT 1: First approach is to use algebra
THOUGHT 2: Second approach uses geometry
THOUGHT 3: Third uses numerical methods
"#;

        let thoughts = parse_thoughts(output, 3);
        assert_eq!(thoughts.len(), 3);
        assert!(thoughts[0].contains("algebra"));
        assert!(thoughts[1].contains("geometry"));
    }

    #[test]
    fn test_beam_search() {
        let config = ToTConfig {
            beam_width: 2,
            search_strategy: SearchStrategy::BeamSearch,
            ..Default::default()
        };
        let mut tree = TreeOfThoughts::new(config);
        let root_id = tree.create_root("Problem");
        for (label, score) in [("Low", 0.3), ("High", 0.9), ("Medium", 0.6)] {
            tree.add_child(root_id, label.into(), score, ThoughtState::default());
        }

        let candidates = tree.beam_step();
        assert_eq!(candidates.len(), 2);

        let top_score = tree.get_node(candidates[0]).unwrap().score;
        assert!(top_score >= 0.6);
    }
}