use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use petgraph::graph::{DiGraph, NodeIndex};
use petgraph::visit::Dfs;
use serde::{Deserialize, Serialize};
use serde_json::json;
use thiserror::Error;
use super::{ThinkToolContext, ThinkToolModule, ThinkToolModuleConfig, ThinkToolOutput};
use crate::error::{Error, Result};
/// Errors specific to the Graph-of-Thought module.
///
/// Each variant renders a human-readable message via `thiserror`; the
/// `From` impl below this enum flattens them into the crate-wide `Error`.
#[derive(Error, Debug, Clone)]
pub enum GraphOfThoughtError {
/// The graph could not be built.
#[error("Graph construction failed: {reason}")]
GraphConstructionFailed { reason: String },
/// An operation referenced something invalid (e.g. a parent id that is not
/// present in the graph).
#[error("Invalid graph operation: {operation}")]
InvalidOperation { operation: String },
/// Expanding a node would create a thought deeper than the configured limit.
#[error("Maximum graph depth exceeded: {depth} > {max_depth}")]
MaxDepthExceeded { depth: usize, max_depth: usize },
/// No usable thoughts were produced.
#[error("No valid thoughts generated after {attempts} attempts")]
NoValidThoughts { attempts: usize },
/// Aggregating several thoughts into one failed (e.g. no sources given).
#[error("Aggregation failed: {reason}")]
AggregationFailed { reason: String },
/// A thought has already been refined the maximum number of times.
#[error("Refinement exceeded maximum rounds: {rounds} > {max_rounds}")]
RefinementExceeded { rounds: usize, max_rounds: usize },
/// The node chosen as a backtrack target is not in the graph.
#[error("Backtrack target not found: node {node_id}")]
BacktrackTargetNotFound { node_id: u64 },
/// The final confidence fell below a required threshold.
#[error("Final confidence too low: {confidence:.2} < {threshold:.2}")]
LowConfidence { confidence: f64, threshold: f64 },
/// The input query failed length validation.
#[error("Query validation failed: {reason}")]
QueryValidationFailed { reason: String },
}
impl From<GraphOfThoughtError> for Error {
fn from(err: GraphOfThoughtError) -> Self {
Error::ThinkToolExecutionError(err.to_string())
}
}
/// Tuning knobs for a Graph-of-Thought run.
///
/// Presets are available via `Default`, `GoTConfig::quick`, `GoTConfig::deep`
/// and `GoTConfig::paranoid`, or build one with `GoTBuilder`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GoTConfig {
/// Number of branches requested from the root thought in the first
/// generation step. Only eight distinct branch types exist, so values above
/// eight yield at most eight children per generation.
pub max_branches: usize,
/// Whether high-scoring frontier thoughts may be merged into one node.
pub enable_aggregation: bool,
/// Whether the best frontier thought may be refined in place.
pub enable_refinement: bool,
/// Limit on refinements, applied both per run and per individual thought.
pub max_refinement_rounds: usize,
/// Score below which a frontier thought counts as "low".
/// It is also used as the pruning threshold and the viability threshold.
pub backtrack_threshold: f64,
/// Reference confidence level. Scaled copies of it select aggregation
/// candidates (x0.8) and acceptable backtrack targets (x0.7).
pub min_confidence: f64,
/// Maximum thought depth. It also bounds the number of main-loop iterations.
pub max_depth: usize,
/// Whether low-scoring frontier thoughts are pruned each iteration.
pub auto_prune: bool,
/// Minimum number of frontier thoughts that pruning keeps, and the maximum
/// number of final thoughts returned.
pub keep_top_n: usize,
/// Intended to forbid cycles in the reasoning graph.
/// NOTE(review): not read by the reasoning code in this file.
pub prevent_cycles: bool,
/// Node budget. Expansion stops once the graph holds this many nodes.
pub max_nodes: usize,
/// Minimum accepted query length (checked by `validate_query`).
pub min_query_length: usize,
/// Maximum accepted query length (checked by `validate_query`).
pub max_query_length: usize,
/// Verbose diagnostics flag.
/// NOTE(review): not read by the reasoning code in this file.
pub verbose: bool,
}
impl Default for GoTConfig {
    /// Balanced preset: moderate branching, with aggregation and refinement enabled.
    fn default() -> Self {
        Self {
            // Graph shape and budget.
            max_branches: 4,
            max_depth: 10,
            max_nodes: 100,
            prevent_cycles: true,
            // Optional operations.
            enable_aggregation: true,
            enable_refinement: true,
            max_refinement_rounds: 3,
            // Scoring thresholds and pruning.
            backtrack_threshold: 0.3,
            min_confidence: 0.75,
            auto_prune: true,
            keep_top_n: 5,
            // Input validation and diagnostics.
            min_query_length: 10,
            max_query_length: 5000,
            verbose: false,
        }
    }
}
impl GoTConfig {
    /// Fast, shallow preset.
    ///
    /// Two branches, no aggregation or refinement, a small node budget, and
    /// shorter accepted queries.
    pub fn quick() -> Self {
        Self {
            // Graph shape and budget.
            max_branches: 2,
            max_depth: 5,
            max_nodes: 30,
            prevent_cycles: true,
            // Optional operations.
            enable_aggregation: false,
            enable_refinement: false,
            max_refinement_rounds: 1,
            // Scoring thresholds and pruning.
            backtrack_threshold: 0.4,
            min_confidence: 0.65,
            auto_prune: true,
            keep_top_n: 3,
            // Input validation and diagnostics.
            min_query_length: 5,
            max_query_length: 2000,
            verbose: false,
        }
    }

    /// Thorough preset: wider branching, more refinement and a larger budget.
    pub fn deep() -> Self {
        Self {
            // Graph shape and budget.
            max_branches: 6,
            max_depth: 15,
            max_nodes: 200,
            prevent_cycles: true,
            // Optional operations.
            enable_aggregation: true,
            enable_refinement: true,
            max_refinement_rounds: 5,
            // Scoring thresholds and pruning.
            backtrack_threshold: 0.25,
            min_confidence: 0.80,
            auto_prune: true,
            keep_top_n: 8,
            // Input validation and diagnostics.
            min_query_length: 10,
            max_query_length: 10000,
            verbose: true,
        }
    }

    /// Most exhaustive preset: maximum branching and the strictest confidence levels.
    pub fn paranoid() -> Self {
        Self {
            // Graph shape and budget.
            max_branches: 8,
            max_depth: 20,
            max_nodes: 500,
            prevent_cycles: true,
            // Optional operations.
            enable_aggregation: true,
            enable_refinement: true,
            max_refinement_rounds: 7,
            // Scoring thresholds and pruning.
            backtrack_threshold: 0.35,
            min_confidence: 0.90,
            auto_prune: true,
            keep_top_n: 10,
            // Input validation and diagnostics.
            min_query_length: 10,
            max_query_length: 10000,
            verbose: true,
        }
    }
}
/// A single thought (vertex) in the reasoning graph.
///
/// Parent and child links are stored as ids, mirroring the edges held in the
/// underlying petgraph graph.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ThoughtNode {
/// Unique id within one reasoning graph.
pub id: u64,
/// Free-form key/value state attached to the thought.
pub state: HashMap<String, serde_json::Value>,
/// Quality score, or `None` if not yet scored.
/// `ThoughtNode::with_score` clamps it to `[0.0, 1.0]`.
pub score: Option<f64>,
/// Validity verdict, or `None` if not yet validated.
pub valid: Option<bool>,
/// Text of this reasoning step.
pub reasoning_step: String,
/// Ids of thoughts derived from this one (deduplicated by `add_child`).
pub children: Vec<u64>,
/// Ids of thoughts this one was derived from. Aggregated thoughts have several.
pub parents: Vec<u64>,
/// Distance from the root (the root has depth 0).
pub depth: usize,
/// Labels such as the branch type, or "pruned".
pub tags: Vec<String>,
/// Bookkeeping about how and when the thought was produced.
pub metadata: ThoughtMetadata,
}
/// Provenance information for a `ThoughtNode`.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ThoughtMetadata {
/// Creation time in milliseconds since the Unix epoch (0 if the clock is
/// before the epoch).
pub created_at: u64,
/// Name of the operation that created the node (e.g. "Generate", "Aggregate").
pub created_by: Option<String>,
/// How many times the thought has been refined in place.
pub refinement_count: usize,
/// Source thought ids, for nodes produced by aggregation.
pub aggregated_from: Vec<u64>,
/// Optional processing time in milliseconds.
/// NOTE(review): not populated by the code in this file.
pub processing_time_ms: Option<u64>,
}
impl ThoughtNode {
    /// Milliseconds since the Unix epoch, or 0 when the system clock reads
    /// earlier than the epoch.
    fn now_millis() -> u64 {
        match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
            Ok(elapsed) => elapsed.as_millis() as u64,
            Err(_) => 0,
        }
    }

    /// Creates an unscored, unvalidated root-depth thought stamped with the current time.
    pub fn new(id: u64, reasoning_step: impl Into<String>) -> Self {
        let metadata = ThoughtMetadata {
            created_at: Self::now_millis(),
            ..Default::default()
        };
        Self {
            id,
            reasoning_step: reasoning_step.into(),
            state: HashMap::new(),
            score: None,
            valid: None,
            children: Vec::new(),
            parents: Vec::new(),
            depth: 0,
            tags: Vec::new(),
            metadata,
        }
    }

    /// Builder: stores `value` under `key`, replacing any previous value.
    pub fn with_state(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
        self.state.insert(key.into(), value);
        self
    }

    /// Builder: sets the score, clamped into `[0.0, 1.0]`.
    pub fn with_score(mut self, score: f64) -> Self {
        let bounded = score.clamp(0.0, 1.0);
        self.score = Some(bounded);
        self
    }

    /// Builder: sets the validity verdict.
    pub fn with_valid(mut self, valid: bool) -> Self {
        self.valid = Some(valid);
        self
    }

    /// Builder: sets the depth.
    pub fn with_depth(mut self, depth: usize) -> Self {
        self.depth = depth;
        self
    }

    /// Builder: records `parent_id` as a parent unless it is already present.
    pub fn with_parent(mut self, parent_id: u64) -> Self {
        let already_linked = self.parents.iter().any(|&p| p == parent_id);
        if !already_linked {
            self.parents.push(parent_id);
        }
        self
    }

    /// Builder: appends a tag (duplicates are allowed).
    pub fn with_tag(mut self, tag: impl Into<String>) -> Self {
        self.tags.push(tag.into());
        self
    }

    /// Records `child_id` as a child unless it is already present.
    pub fn add_child(&mut self, child_id: u64) {
        let already_linked = self.children.iter().any(|&c| c == child_id);
        if !already_linked {
            self.children.push(child_id);
        }
    }

    /// `true` once a score has been assigned.
    pub fn is_scored(&self) -> bool {
        matches!(self.score, Some(_))
    }

    /// `true` once a validity verdict has been assigned.
    pub fn is_validated(&self) -> bool {
        matches!(self.valid, Some(_))
    }

    /// A thought is viable unless it was explicitly marked invalid, and only
    /// when its effective score reaches `threshold`.
    pub fn is_viable(&self, threshold: f64) -> bool {
        self.valid != Some(false) && self.effective_score() >= threshold
    }

    /// Score used for ranking; unscored thoughts count as a neutral 0.5.
    pub fn effective_score(&self) -> f64 {
        match self.score {
            Some(s) => s,
            None => 0.5,
        }
    }
}
/// Operations the reasoning loop can apply to the thought graph.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum GraphOp {
/// Branch a thought into new children.
Generate {
/// Number of branches requested.
k: usize,
},
/// Merge several thoughts into a single synthesized thought.
Aggregate {
/// Ids of the thoughts being merged.
source_ids: Vec<u64>,
},
/// Improve a thought in place.
Refine {
/// Number of refinement iterations allowed for this operation.
max_iterations: usize,
},
/// Assign scores to all thoughts that do not have one yet.
Score,
/// Prune the frontier while keeping at least the best `n` thoughts.
KeepBestN {
/// Minimum number of frontier thoughts to keep.
n: usize,
},
/// Reset the frontier to an earlier thought.
Backtrack {
/// Id of the thought to resume from.
to_id: u64,
},
}
impl GraphOp {
pub fn description(&self) -> String {
match self {
Self::Generate { k } => format!("Generate {} new thought branches", k),
Self::Aggregate { source_ids } => {
format!("Aggregate {} thoughts into one", source_ids.len())
}
Self::Refine { max_iterations } => {
format!("Refine with up to {} iterations", max_iterations)
}
Self::Score => "Score all unscored thoughts".to_string(),
Self::KeepBestN { n } => format!("Keep best {} thoughts", n),
Self::Backtrack { to_id } => format!("Backtrack to node {}", to_id),
}
}
}
/// Log entry for one graph operation performed during a run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutedOperation {
/// The operation that was applied.
pub operation: GraphOp,
/// Whether the operation succeeded. The reasoning loop in this file only
/// records successful operations.
pub success: bool,
/// Ids of the nodes created or touched by the operation.
pub affected_nodes: Vec<u64>,
/// Wall-clock duration of the operation in milliseconds.
pub duration_ms: u64,
/// Free-form human-readable note.
pub notes: String,
}
/// Complete outcome of a Graph-of-Thought run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GoTResult {
/// The query that was reasoned about.
pub query: String,
/// Best viable frontier thoughts, sorted by descending score.
pub final_thoughts: Vec<ThoughtNode>,
/// Snapshot of the whole graph at the end of the run.
pub reasoning_graph: ReasoningGraphData,
/// Chronological log of the operations performed.
pub operations_executed: Vec<ExecutedOperation>,
/// Number of nodes ever added to the graph. Pruned nodes are included,
/// because pruning only detaches them from the frontier.
pub total_nodes_explored: usize,
/// Total number of frontier thoughts removed by pruning.
pub pruned_branches: usize,
/// Overall confidence in `[0.0, 1.0]`.
pub confidence: f64,
/// Run metadata and per-operation counters.
pub metadata: GoTMetadata,
}
/// Serializable snapshot of the reasoning graph.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReasoningGraphData {
/// Every node in the graph, including pruned ones (tagged "pruned").
pub nodes: Vec<ThoughtNode>,
/// Directed edges as `(from_id, to_id)` pairs, pointing from parent to child.
pub edges: Vec<(u64, u64)>,
/// Id of the root thought, or 0 if the graph has no root.
pub root_id: u64,
/// Ids of the current frontier (leaves still eligible for expansion).
pub leaf_ids: Vec<u64>,
/// Greatest depth of any node.
pub max_depth: usize,
}
/// Metadata describing how a run was performed.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GoTMetadata {
/// Module version string.
pub version: String,
/// Total wall-clock duration of the run in milliseconds.
pub duration_ms: u64,
/// The configuration that was used.
pub config: GoTConfig,
/// Number of generate operations performed.
pub generate_count: usize,
/// Number of successful aggregations.
pub aggregate_count: usize,
/// Number of successful refinements.
pub refine_count: usize,
/// Number of backtracks.
pub backtrack_count: usize,
}
/// Mutable, in-memory thought graph used while a run is in progress.
struct ReasoningGraph {
/// Directed graph of thoughts, with edges pointing from parent to child.
graph: DiGraph<ThoughtNode, ()>,
/// Maps live thought ids to graph indices. Pruned thoughts are removed
/// from this map but remain in `graph`.
id_to_index: HashMap<u64, NodeIndex>,
/// Next id to hand out (ids start at 1).
next_id: u64,
/// Index of the root thought, once added.
root: Option<NodeIndex>,
/// Leaves currently eligible for expansion.
frontier: Vec<NodeIndex>,
}
impl ReasoningGraph {
    /// Creates an empty graph. The first id handed out is 1.
    fn new() -> Self {
        Self {
            graph: DiGraph::new(),
            id_to_index: HashMap::new(),
            next_id: 1,
            root: None,
            frontier: Vec::new(),
        }
    }

    /// Adds a depth-0 root thought and makes it the sole frontier entry.
    fn add_root(&mut self, reasoning_step: impl Into<String>) -> u64 {
        let id = self.next_id;
        self.next_id += 1;
        let node = ThoughtNode::new(id, reasoning_step).with_depth(0);
        let index = self.graph.add_node(node);
        self.id_to_index.insert(id, index);
        self.root = Some(index);
        self.frontier = vec![index];
        id
    }

    /// Adds a child under `parent_id` and returns the new id.
    ///
    /// The parent leaves the frontier and the child joins it.
    ///
    /// # Errors
    /// `InvalidOperation` if `parent_id` is unknown or has been pruned.
    fn add_child(&mut self, parent_id: u64, reasoning_step: impl Into<String>) -> Result<u64> {
        let parent_index = *self.id_to_index.get(&parent_id).ok_or_else(|| {
            GraphOfThoughtError::InvalidOperation {
                operation: format!("Parent node {} not found", parent_id),
            }
        })?;
        let parent_depth = self.graph[parent_index].depth;
        let id = self.next_id;
        self.next_id += 1;
        let node = ThoughtNode::new(id, reasoning_step)
            .with_depth(parent_depth + 1)
            .with_parent(parent_id);
        let index = self.graph.add_node(node);
        self.id_to_index.insert(id, index);
        self.graph.add_edge(parent_index, index, ());
        self.graph[parent_index].add_child(id);
        self.frontier.retain(|&i| i != parent_index);
        self.frontier.push(index);
        Ok(id)
    }

    /// Merges the `source_ids` thoughts into a new node one level deeper than
    /// the deepest source, and returns the new id.
    ///
    /// Duplicate ids are ignored. The sources leave the frontier and the
    /// merged node joins it.
    ///
    /// # Errors
    /// `AggregationFailed` if `source_ids` is empty or names a thought that is
    /// unknown or pruned. This mirrors how `add_child` rejects unknown parents,
    /// instead of silently producing an orphan node with dangling parent ids.
    fn aggregate(&mut self, source_ids: &[u64], reasoning_step: impl Into<String>) -> Result<u64> {
        if source_ids.is_empty() {
            return Err(GraphOfThoughtError::AggregationFailed {
                reason: "No source nodes to aggregate".to_string(),
            }
            .into());
        }
        // Resolve every source up front so a bad id leaves the graph untouched.
        let mut sources: Vec<(u64, NodeIndex)> = Vec::with_capacity(source_ids.len());
        for &source_id in source_ids {
            let source_index = *self.id_to_index.get(&source_id).ok_or_else(|| {
                GraphOfThoughtError::AggregationFailed {
                    reason: format!("Source node {} not found", source_id),
                }
            })?;
            if !sources.iter().any(|&(seen, _)| seen == source_id) {
                sources.push((source_id, source_index));
            }
        }
        let max_depth = sources
            .iter()
            .map(|&(_, idx)| self.graph[idx].depth)
            .max()
            .unwrap_or(0);
        let source_list: Vec<u64> = sources.iter().map(|&(sid, _)| sid).collect();

        let id = self.next_id;
        self.next_id += 1;
        let mut node = ThoughtNode::new(id, reasoning_step).with_depth(max_depth + 1);
        node.metadata.aggregated_from = source_list.clone();
        node.metadata.created_by = Some("Aggregate".to_string());
        node.parents = source_list;
        let index = self.graph.add_node(node);
        self.id_to_index.insert(id, index);
        for &(_, source_index) in &sources {
            self.graph.add_edge(source_index, index, ());
            self.graph[source_index].add_child(id);
        }
        self.frontier
            .retain(|i| !sources.iter().any(|&(_, source_index)| source_index == *i));
        self.frontier.push(index);
        Ok(id)
    }

    /// Looks up a live (non-pruned) thought by id.
    fn get(&self, id: u64) -> Option<&ThoughtNode> {
        self.id_to_index.get(&id).map(|idx| &self.graph[*idx])
    }

    /// Mutable lookup of a live (non-pruned) thought by id.
    fn get_mut(&mut self, id: u64) -> Option<&mut ThoughtNode> {
        self.id_to_index
            .get(&id)
            .copied()
            .map(|idx| &mut self.graph[idx])
    }

    /// All thoughts in the graph, including pruned ones.
    fn all_nodes(&self) -> Vec<&ThoughtNode> {
        self.graph.node_weights().collect()
    }

    /// Ids of the current frontier, in frontier order.
    fn frontier_ids(&self) -> Vec<u64> {
        self.frontier
            .iter()
            .map(|idx| self.graph[*idx].id)
            .collect()
    }

    /// Greatest depth of any node (0 for an empty graph).
    fn max_depth(&self) -> usize {
        self.graph
            .node_weights()
            .map(|n| n.depth)
            .max()
            .unwrap_or(0)
    }

    /// Total number of nodes ever added (pruned nodes included).
    fn node_count(&self) -> usize {
        self.graph.node_count()
    }

    /// Removes weak thoughts from the frontier and returns how many were removed.
    ///
    /// A scored frontier thought survives if it ranks among the best
    /// `min_keep`, or if its score is at least `threshold`. Unscored frontier
    /// thoughts are always removed. Removed thoughts stay in the graph for
    /// reporting; they are marked invalid, tagged "pruned", and dropped from
    /// the id lookup.
    fn prune(&mut self, threshold: f64, min_keep: usize) -> usize {
        let mut scored: Vec<(NodeIndex, f64)> = self
            .frontier
            .iter()
            .filter_map(|&idx| self.graph[idx].score.map(|s| (idx, s)))
            .collect();
        scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        let keep: std::collections::HashSet<NodeIndex> = scored
            .iter()
            .enumerate()
            .filter(|(rank, (_, score))| *rank < min_keep || *score >= threshold)
            .map(|(_, (idx, _))| *idx)
            .collect();

        // Split the frontier in a single pass instead of one `retain` per removal.
        let (kept, removed): (Vec<NodeIndex>, Vec<NodeIndex>) =
            self.frontier.iter().copied().partition(|idx| keep.contains(idx));
        self.frontier = kept;
        for &idx in &removed {
            let id = self.graph[idx].id;
            self.id_to_index.remove(&id);
            let node = &mut self.graph[idx];
            node.valid = Some(false);
            node.tags.push("pruned".to_string());
        }
        removed.len()
    }

    // Only exercised by the unit tests; kept as a structural sanity check.
    #[allow(dead_code)]
    fn has_cycle(&self) -> bool {
        petgraph::algo::is_cyclic_directed(&self.graph)
    }

    /// Builds a serializable snapshot of the graph (pruned nodes included).
    fn to_data(&self) -> ReasoningGraphData {
        let nodes: Vec<ThoughtNode> = self.graph.node_weights().cloned().collect();
        let edges: Vec<(u64, u64)> = self
            .graph
            .edge_indices()
            .filter_map(|e| {
                let (a, b) = self.graph.edge_endpoints(e)?;
                Some((self.graph[a].id, self.graph[b].id))
            })
            .collect();
        let root_id = self.root.map(|idx| self.graph[idx].id).unwrap_or(0);
        ReasoningGraphData {
            nodes,
            edges,
            root_id,
            leaf_ids: self.frontier_ids(),
            max_depth: self.max_depth(),
        }
    }

    // Only exercised by the unit tests: depth-first ids reachable from the root.
    #[allow(dead_code)]
    fn dfs_collect(&self) -> Vec<u64> {
        let mut result = Vec::new();
        if let Some(root) = self.root {
            let mut dfs = Dfs::new(&self.graph, root);
            while let Some(idx) = dfs.next(&self.graph) {
                result.push(self.graph[idx].id);
            }
        }
        result
    }
}
/// Async extension of `ThinkToolModule`.
///
/// Implementors return a boxed, `Send` future that borrows both the module
/// and the context for lifetime `'a`.
pub trait AsyncThinkToolModule: ThinkToolModule {
/// Runs the module asynchronously, resolving to the same output type as
/// the synchronous `execute`.
fn execute_async<'a>(
&'a self,
context: &'a ThinkToolContext,
) -> Pin<Box<dyn Future<Output = Result<ThinkToolOutput>> + Send + 'a>>;
}
/// Graph-of-Thought reasoning module.
///
/// Explores a DAG of thoughts with branching, scoring, pruning, backtracking,
/// aggregation and refinement.
pub struct GraphOfThought {
/// Static module descriptor (name, version, description, weight).
module_config: ThinkToolModuleConfig,
/// Reasoning parameters.
config: GoTConfig,
}
impl Default for GraphOfThought {
fn default() -> Self {
Self::new()
}
}
impl GraphOfThought {
pub fn new() -> Self {
Self::with_config(GoTConfig::default())
}
pub fn with_config(config: GoTConfig) -> Self {
Self {
module_config: ThinkToolModuleConfig {
name: "GraphOfThought".to_string(),
version: "1.0.0".to_string(),
description:
"DAG-based reasoning with branching, aggregation, and refinement (Besta et al. 2023)"
.to_string(),
confidence_weight: 0.25,
},
config,
}
}
pub fn got_config(&self) -> &GoTConfig {
&self.config
}
fn validate_query(&self, query: &str) -> Result<()> {
let length = query.len();
if length < self.config.min_query_length {
return Err(GraphOfThoughtError::QueryValidationFailed {
reason: format!(
"Query too short: {} characters, minimum required is {}",
length, self.config.min_query_length
),
}
.into());
}
if length > self.config.max_query_length {
return Err(GraphOfThoughtError::QueryValidationFailed {
reason: format!(
"Query too long: {} characters, maximum allowed is {}",
length, self.config.max_query_length
),
}
.into());
}
Ok(())
}
fn run_reasoning(&self, query: &str) -> Result<GoTResult> {
let start = std::time::Instant::now();
let mut graph = ReasoningGraph::new();
let mut operations_executed: Vec<ExecutedOperation> = Vec::new();
let mut pruned_count = 0;
let mut generate_count = 0;
let mut aggregate_count = 0;
let mut refine_count = 0;
let mut backtrack_count = 0;
let root_id = graph.add_root(format!("Initial analysis of: {}", query));
let op_start = std::time::Instant::now();
let mut generated_ids =
self.generate_thoughts(&mut graph, root_id, self.config.max_branches)?;
operations_executed.push(ExecutedOperation {
operation: GraphOp::Generate {
k: self.config.max_branches,
},
success: true,
affected_nodes: generated_ids.clone(),
duration_ms: op_start.elapsed().as_millis() as u64,
notes: format!("Generated {} initial branches", generated_ids.len()),
});
generate_count += 1;
let op_start = std::time::Instant::now();
let scored_ids = self.score_thoughts(&mut graph)?;
operations_executed.push(ExecutedOperation {
operation: GraphOp::Score,
success: true,
affected_nodes: scored_ids.clone(),
duration_ms: op_start.elapsed().as_millis() as u64,
notes: format!("Scored {} thoughts", scored_ids.len()),
});
let mut iteration = 0;
let max_iterations = self.config.max_depth.min(5);
while iteration < max_iterations && graph.node_count() < self.config.max_nodes {
iteration += 1;
let frontier_ids = graph.frontier_ids();
let low_scorers: Vec<u64> = frontier_ids
.iter()
.filter(|&&id| {
graph
.get(id)
.map(|n| n.effective_score() < self.config.backtrack_threshold)
.unwrap_or(false)
})
.copied()
.collect();
if !low_scorers.is_empty() && backtrack_count < 3 {
if let Some(best_ancestor) = self.find_best_backtrack_target(&graph, &low_scorers) {
let op_start = std::time::Instant::now();
self.backtrack(&mut graph, best_ancestor)?;
operations_executed.push(ExecutedOperation {
operation: GraphOp::Backtrack {
to_id: best_ancestor,
},
success: true,
affected_nodes: vec![best_ancestor],
duration_ms: op_start.elapsed().as_millis() as u64,
notes: format!("Backtracked to node {} due to low scores", best_ancestor),
});
backtrack_count += 1;
continue;
}
}
if self.config.auto_prune {
let op_start = std::time::Instant::now();
let pruned = graph.prune(self.config.backtrack_threshold, self.config.keep_top_n);
if pruned > 0 {
operations_executed.push(ExecutedOperation {
operation: GraphOp::KeepBestN {
n: self.config.keep_top_n,
},
success: true,
affected_nodes: vec![],
duration_ms: op_start.elapsed().as_millis() as u64,
notes: format!("Pruned {} low-scoring branches", pruned),
});
pruned_count += pruned;
}
}
let top_frontier: Vec<u64> = graph
.frontier_ids()
.into_iter()
.filter(|&id| {
graph
.get(id)
.map(|n| n.is_viable(self.config.backtrack_threshold))
.unwrap_or(false)
})
.take(2)
.collect();
for parent_id in top_frontier {
if graph.node_count() >= self.config.max_nodes {
break;
}
let op_start = std::time::Instant::now();
let new_ids = self.generate_thoughts(&mut graph, parent_id, 2)?;
generated_ids.extend(new_ids.clone());
operations_executed.push(ExecutedOperation {
operation: GraphOp::Generate { k: 2 },
success: true,
affected_nodes: new_ids,
duration_ms: op_start.elapsed().as_millis() as u64,
notes: format!("Generated branches from node {}", parent_id),
});
generate_count += 1;
}
let op_start = std::time::Instant::now();
let newly_scored = self.score_thoughts(&mut graph)?;
if !newly_scored.is_empty() {
operations_executed.push(ExecutedOperation {
operation: GraphOp::Score,
success: true,
affected_nodes: newly_scored.clone(),
duration_ms: op_start.elapsed().as_millis() as u64,
notes: format!("Scored {} new thoughts", newly_scored.len()),
});
}
if self.config.enable_aggregation && iteration >= 2 {
let good_frontier: Vec<u64> = graph
.frontier_ids()
.into_iter()
.filter(|&id| {
graph
.get(id)
.map(|n| n.effective_score() >= self.config.min_confidence * 0.8)
.unwrap_or(false)
})
.take(3)
.collect();
if good_frontier.len() >= 2 {
let op_start = std::time::Instant::now();
if let Ok(agg_id) = self.aggregate_thoughts(&mut graph, &good_frontier) {
operations_executed.push(ExecutedOperation {
operation: GraphOp::Aggregate {
source_ids: good_frontier.clone(),
},
success: true,
affected_nodes: vec![agg_id],
duration_ms: op_start.elapsed().as_millis() as u64,
notes: format!("Aggregated {} thoughts", good_frontier.len()),
});
aggregate_count += 1;
let avg_source_score: f64 = good_frontier
.iter()
.filter_map(|&id| graph.get(id).and_then(|n| n.score))
.sum::<f64>()
/ good_frontier.len() as f64;
if let Some(node) = graph.get_mut(agg_id) {
node.score = Some((avg_source_score * 1.1).min(1.0));
node.valid = Some(true);
}
}
}
}
if self.config.enable_refinement && refine_count < self.config.max_refinement_rounds {
let top_thought = graph
.frontier_ids()
.into_iter()
.filter_map(|id| graph.get(id).map(|n| (id, n.effective_score())))
.max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
.map(|(id, _)| id);
if let Some(thought_id) = top_thought {
let op_start = std::time::Instant::now();
if self.refine_thought(&mut graph, thought_id).is_ok() {
operations_executed.push(ExecutedOperation {
operation: GraphOp::Refine { max_iterations: 1 },
success: true,
affected_nodes: vec![thought_id],
duration_ms: op_start.elapsed().as_millis() as u64,
notes: format!("Refined thought {}", thought_id),
});
refine_count += 1;
}
}
}
}
let final_thoughts = self.select_final_thoughts(&graph);
let confidence = self.calculate_confidence(&final_thoughts, &graph);
let duration_ms = start.elapsed().as_millis() as u64;
Ok(GoTResult {
query: query.to_string(),
final_thoughts,
reasoning_graph: graph.to_data(),
operations_executed,
total_nodes_explored: graph.node_count(),
pruned_branches: pruned_count,
confidence,
metadata: GoTMetadata {
version: self.module_config.version.clone(),
duration_ms,
config: self.config.clone(),
generate_count,
aggregate_count,
refine_count,
backtrack_count,
},
})
}
fn generate_thoughts(
&self,
graph: &mut ReasoningGraph,
parent_id: u64,
k: usize,
) -> Result<Vec<u64>> {
let parent = graph
.get(parent_id)
.ok_or_else(|| GraphOfThoughtError::InvalidOperation {
operation: format!("Parent node {} not found", parent_id),
})?;
let parent_content = parent.reasoning_step.clone();
let parent_depth = parent.depth;
if parent_depth >= self.config.max_depth {
return Err(GraphOfThoughtError::MaxDepthExceeded {
depth: parent_depth + 1,
max_depth: self.config.max_depth,
}
.into());
}
let mut generated_ids = Vec::new();
let branch_types = [
("analytical", "Breaking down the problem into components"),
("creative", "Exploring unconventional approaches"),
("critical", "Examining potential flaws and counterarguments"),
("practical", "Considering implementation and feasibility"),
("comparative", "Drawing parallels with similar problems"),
("systematic", "Following a structured methodology"),
("intuitive", "Applying domain expertise and patterns"),
("contrarian", "Challenging assumptions and conventions"),
];
for (i, (branch_type, description)) in branch_types.iter().take(k).enumerate() {
let reasoning = format!(
"Branch {}: {} approach\n\nBuilding on: {}\n\n{}",
i + 1,
branch_type,
parent_content,
description
);
let child_id = graph.add_child(parent_id, reasoning)?;
if let Some(node) = graph.get_mut(child_id) {
node.metadata.created_by = Some("Generate".to_string());
node.tags.push(branch_type.to_string());
}
generated_ids.push(child_id);
}
Ok(generated_ids)
}
fn score_thoughts(&self, graph: &mut ReasoningGraph) -> Result<Vec<u64>> {
let unscored: Vec<u64> = graph
.all_nodes()
.iter()
.filter(|n| n.score.is_none())
.map(|n| n.id)
.collect();
for id in &unscored {
if let Some(node) = graph.get_mut(*id) {
let depth_factor = 1.0 - (node.depth as f64 * 0.05).min(0.3);
let content_factor = (node.reasoning_step.len() as f64 / 500.0).min(1.0);
let tag_bonus = if node.tags.is_empty() { 0.0 } else { 0.1 };
let base_score = 0.5 + (depth_factor * 0.2) + (content_factor * 0.2) + tag_bonus;
let has_analysis = node.reasoning_step.to_lowercase().contains("because")
|| node.reasoning_step.to_lowercase().contains("therefore")
|| node.reasoning_step.to_lowercase().contains("consider");
let final_score = if has_analysis {
(base_score * 1.1).min(0.95)
} else {
base_score * 0.9
};
node.score = Some(final_score.clamp(0.3, 0.95));
node.valid = Some(final_score >= self.config.backtrack_threshold);
}
}
Ok(unscored)
}
fn aggregate_thoughts(&self, graph: &mut ReasoningGraph, source_ids: &[u64]) -> Result<u64> {
if source_ids.len() < 2 {
return Err(GraphOfThoughtError::AggregationFailed {
reason: "Need at least 2 thoughts to aggregate".to_string(),
}
.into());
}
let mut synthesis =
String::from("Synthesized conclusion from multiple reasoning paths:\n\n");
for (i, &id) in source_ids.iter().enumerate() {
if let Some(node) = graph.get(id) {
synthesis.push_str(&format!(
"Path {} (score: {:.2}): {}\n\n",
i + 1,
node.effective_score(),
node.reasoning_step.chars().take(200).collect::<String>()
));
}
}
synthesis.push_str("Key insights from aggregation:\n");
synthesis.push_str("- Multiple perspectives converge on common themes\n");
synthesis.push_str("- Complementary approaches strengthen the analysis\n");
synthesis.push_str("- Synthesis captures the best elements from each path\n");
graph.aggregate(source_ids, synthesis)
}
fn refine_thought(&self, graph: &mut ReasoningGraph, thought_id: u64) -> Result<()> {
let node =
graph
.get_mut(thought_id)
.ok_or_else(|| GraphOfThoughtError::InvalidOperation {
operation: format!("Thought {} not found for refinement", thought_id),
})?;
if node.metadata.refinement_count >= self.config.max_refinement_rounds {
return Err(GraphOfThoughtError::RefinementExceeded {
rounds: node.metadata.refinement_count,
max_rounds: self.config.max_refinement_rounds,
}
.into());
}
let enhanced = format!(
"{}\n\n[Refinement {}] Additional considerations:\n\
- Strengthening the logical connections\n\
- Addressing potential counterarguments\n\
- Clarifying assumptions and constraints",
node.reasoning_step,
node.metadata.refinement_count + 1
);
node.reasoning_step = enhanced;
node.metadata.refinement_count += 1;
if let Some(score) = node.score {
node.score = Some((score * 1.05).min(0.98));
}
Ok(())
}
fn find_best_backtrack_target(
&self,
graph: &ReasoningGraph,
low_scorers: &[u64],
) -> Option<u64> {
for &low_id in low_scorers {
if let Some(node) = graph.get(low_id) {
for &parent_id in &node.parents {
if let Some(parent) = graph.get(parent_id) {
if parent.effective_score() >= self.config.min_confidence * 0.7 {
return Some(parent_id);
}
}
}
}
}
None
}
fn backtrack(&self, graph: &mut ReasoningGraph, target_id: u64) -> Result<()> {
if graph.get(target_id).is_none() {
return Err(GraphOfThoughtError::BacktrackTargetNotFound { node_id: target_id }.into());
}
if let Some(&idx) = graph.id_to_index.get(&target_id) {
graph.frontier = vec![idx];
}
Ok(())
}
fn select_final_thoughts(&self, graph: &ReasoningGraph) -> Vec<ThoughtNode> {
let mut candidates: Vec<ThoughtNode> = graph
.frontier_ids()
.iter()
.filter_map(|&id| graph.get(id).cloned())
.filter(|n| n.is_viable(self.config.backtrack_threshold))
.collect();
candidates.sort_by(|a, b| {
b.effective_score()
.partial_cmp(&a.effective_score())
.unwrap_or(std::cmp::Ordering::Equal)
});
candidates.truncate(self.config.keep_top_n);
candidates
}
fn calculate_confidence(&self, final_thoughts: &[ThoughtNode], _graph: &ReasoningGraph) -> f64 {
if final_thoughts.is_empty() {
return 0.0;
}
let avg_score: f64 = final_thoughts
.iter()
.map(|n| n.effective_score())
.sum::<f64>()
/ final_thoughts.len() as f64;
let unique_tags: std::collections::HashSet<&str> = final_thoughts
.iter()
.flat_map(|n| n.tags.iter().map(|s| s.as_str()))
.collect();
let diversity_factor = (unique_tags.len() as f64 / 8.0).min(1.0);
(avg_score * 0.7 + diversity_factor * 0.3).clamp(0.0, 1.0)
}
}
impl ThinkToolModule for GraphOfThought {
    /// The module descriptor built in `GraphOfThought::with_config`.
    fn config(&self) -> &ThinkToolModuleConfig {
        &self.module_config
    }

    /// Validates the query, runs the reasoning loop, and renders a JSON report.
    ///
    /// The report contains these sections:
    /// - `final_thoughts`: reasoning text truncated to 500 characters,
    /// - `graph_summary`,
    /// - `operations`,
    /// - `statistics`,
    /// - `confidence`,
    /// - `metadata`.
    ///
    /// # Errors
    /// Propagates query-validation and reasoning errors.
    fn execute(&self, context: &ThinkToolContext) -> Result<ThinkToolOutput> {
        self.validate_query(&context.query)?;
        let result = self.run_reasoning(&context.query)?;

        let mut final_thoughts = Vec::with_capacity(result.final_thoughts.len());
        for thought in &result.final_thoughts {
            let preview: String = thought.reasoning_step.chars().take(500).collect();
            final_thoughts.push(json!({
                "id": thought.id,
                "reasoning_step": preview,
                "score": thought.score,
                "valid": thought.valid,
                "depth": thought.depth,
                "tags": thought.tags
            }));
        }

        let snapshot = &result.reasoning_graph;
        let graph_summary = json!({
            "total_nodes": snapshot.nodes.len(),
            "edges": snapshot.edges.len(),
            "max_depth": snapshot.max_depth,
            "leaf_count": snapshot.leaf_ids.len()
        });

        let mut operations = Vec::with_capacity(result.operations_executed.len());
        for executed in &result.operations_executed {
            operations.push(json!({
                "type": executed.operation.description(),
                "success": executed.success,
                "affected_nodes": executed.affected_nodes.len(),
                "duration_ms": executed.duration_ms,
                "notes": executed.notes
            }));
        }

        let meta = &result.metadata;
        let statistics = json!({
            "total_nodes_explored": result.total_nodes_explored,
            "pruned_branches": result.pruned_branches,
            "generate_operations": meta.generate_count,
            "aggregate_operations": meta.aggregate_count,
            "refine_operations": meta.refine_count,
            "backtrack_operations": meta.backtrack_count
        });
        let metadata = json!({
            "version": meta.version,
            "duration_ms": meta.duration_ms
        });

        let output = json!({
            "final_thoughts": final_thoughts,
            "graph_summary": graph_summary,
            "operations": operations,
            "statistics": statistics,
            "confidence": result.confidence,
            "metadata": metadata
        });

        Ok(ThinkToolOutput {
            module: self.module_config.name.clone(),
            confidence: result.confidence,
            output,
        })
    }
}
impl AsyncThinkToolModule for GraphOfThought {
    /// Wraps the synchronous `execute` in a boxed future.
    ///
    /// The async block is lazy: reasoning starts only when the future is first
    /// polled, and it then runs to completion without yielding, occupying the
    /// polling thread for the whole run.
    fn execute_async<'a>(
        &'a self,
        context: &'a ThinkToolContext,
    ) -> Pin<Box<dyn Future<Output = Result<ThinkToolOutput>> + Send + 'a>> {
        let deferred = async move { self.execute(context) };
        Box::pin(deferred)
    }
}
/// Fluent builder for a `GraphOfThought`, starting from `GoTConfig::default()`.
#[derive(Default)]
pub struct GoTBuilder {
/// Configuration accumulated so far.
config: GoTConfig,
}
impl GoTBuilder {
    /// Starts from the default configuration.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets how many branches the root is split into.
    pub fn max_branches(mut self, k: usize) -> Self {
        self.config.max_branches = k;
        self
    }

    /// Enables or disables aggregation of strong frontier thoughts.
    pub fn enable_aggregation(mut self, enabled: bool) -> Self {
        self.config.enable_aggregation = enabled;
        self
    }

    /// Enables or disables in-place refinement.
    pub fn enable_refinement(mut self, enabled: bool) -> Self {
        self.config.enable_refinement = enabled;
        self
    }

    /// Sets the refinement limit.
    pub fn max_refinement_rounds(mut self, rounds: usize) -> Self {
        self.config.max_refinement_rounds = rounds;
        self
    }

    /// Sets the low-score threshold, clamped into `[0.0, 1.0]`.
    pub fn backtrack_threshold(mut self, threshold: f64) -> Self {
        self.config.backtrack_threshold = threshold.clamp(0.0, 1.0);
        self
    }

    /// Sets the reference confidence level, clamped into `[0.0, 1.0]`.
    pub fn min_confidence(mut self, confidence: f64) -> Self {
        self.config.min_confidence = confidence.clamp(0.0, 1.0);
        self
    }

    /// Sets the maximum thought depth.
    pub fn max_depth(mut self, depth: usize) -> Self {
        self.config.max_depth = depth;
        self
    }

    /// Enables or disables automatic pruning.
    pub fn auto_prune(mut self, enabled: bool) -> Self {
        self.config.auto_prune = enabled;
        self
    }

    /// Sets how many frontier thoughts pruning keeps, which is also the cap on
    /// final thoughts. Zero means no final thoughts are returned.
    pub fn keep_top_n(mut self, n: usize) -> Self {
        self.config.keep_top_n = n;
        self
    }

    /// Sets the `prevent_cycles` flag.
    ///
    /// Previously this field was only reachable through a hand-built `GoTConfig`.
    pub fn prevent_cycles(mut self, enabled: bool) -> Self {
        self.config.prevent_cycles = enabled;
        self
    }

    /// Sets the node budget.
    pub fn max_nodes(mut self, max: usize) -> Self {
        self.config.max_nodes = max;
        self
    }

    /// Sets the minimum accepted query length.
    ///
    /// It is not checked against `max_query_length`; callers should keep
    /// `min <= max`, or every query will be rejected.
    pub fn min_query_length(mut self, length: usize) -> Self {
        self.config.min_query_length = length;
        self
    }

    /// Sets the maximum accepted query length.
    pub fn max_query_length(mut self, length: usize) -> Self {
        self.config.max_query_length = length;
        self
    }

    /// Sets the verbose flag.
    pub fn verbose(mut self, enabled: bool) -> Self {
        self.config.verbose = enabled;
        self
    }

    /// Finishes the builder and creates the module.
    pub fn build(self) -> GraphOfThought {
        GraphOfThought::with_config(self.config)
    }
}
// Unit tests for the Graph-of-Thought module: configuration presets, node
// builders, graph primitives (add/aggregate/prune/DFS), and end-to-end runs.
#[cfg(test)]
mod tests {
use super::*;
/// Module descriptor carries the expected name and version.
#[test]
fn test_got_creation() {
let got = GraphOfThought::new();
assert_eq!(got.config().name, "GraphOfThought");
assert_eq!(got.config().version, "1.0.0");
}
/// Spot-checks the default configuration values.
#[test]
fn test_got_config_default() {
let config = GoTConfig::default();
assert_eq!(config.max_branches, 4);
assert!(config.enable_aggregation);
assert!(config.enable_refinement);
assert_eq!(config.max_refinement_rounds, 3);
// Float fields are compared with a tolerance.
assert!((config.backtrack_threshold - 0.3).abs() < 0.01);
assert!((config.min_confidence - 0.75).abs() < 0.01);
}
/// Spot-checks the quick/deep/paranoid presets.
#[test]
fn test_got_config_presets() {
let quick = GoTConfig::quick();
assert_eq!(quick.max_branches, 2);
assert!(!quick.enable_aggregation);
let deep = GoTConfig::deep();
assert_eq!(deep.max_branches, 6);
assert!(deep.enable_aggregation);
let paranoid = GoTConfig::paranoid();
assert_eq!(paranoid.max_branches, 8);
assert!((paranoid.min_confidence - 0.90).abs() < 0.01);
}
/// Builder-style setters on `ThoughtNode` populate the expected fields.
#[test]
fn test_thought_node_creation() {
let node = ThoughtNode::new(1, "Test reasoning step")
.with_score(0.85)
.with_valid(true)
.with_depth(2)
.with_tag("analytical");
assert_eq!(node.id, 1);
assert_eq!(node.reasoning_step, "Test reasoning step");
assert_eq!(node.score, Some(0.85));
assert_eq!(node.valid, Some(true));
assert_eq!(node.depth, 2);
assert!(node.tags.contains(&"analytical".to_string()));
}
/// Viability is inclusive of the threshold and false for invalid nodes.
#[test]
fn test_thought_node_viability() {
let node = ThoughtNode::new(1, "Test").with_score(0.8).with_valid(true);
assert!(node.is_viable(0.5));
assert!(node.is_viable(0.8));
assert!(!node.is_viable(0.81));
let invalid_node = ThoughtNode::new(2, "Test")
.with_score(0.9)
.with_valid(false);
assert!(!invalid_node.is_viable(0.5));
}
/// Operation descriptions mention their numeric parameters.
#[test]
fn test_graph_op_descriptions() {
let gen_op = GraphOp::Generate { k: 4 };
assert!(gen_op.description().contains("4"));
let agg_op = GraphOp::Aggregate {
source_ids: vec![1, 2, 3],
};
assert!(agg_op.description().contains("3"));
let refine_op = GraphOp::Refine { max_iterations: 5 };
assert!(refine_op.description().contains("5"));
}
/// Root/child insertion assigns sequential ids and links both directions.
#[test]
fn test_reasoning_graph_basic() {
let mut graph = ReasoningGraph::new();
let root_id = graph.add_root("Root thought");
assert_eq!(root_id, 1);
assert_eq!(graph.node_count(), 1);
let child_id = graph.add_child(root_id, "Child thought").unwrap();
assert_eq!(child_id, 2);
assert_eq!(graph.node_count(), 2);
let root = graph.get(root_id).unwrap();
assert_eq!(root.children, vec![child_id]);
let child = graph.get(child_id).unwrap();
assert_eq!(child.parents, vec![root_id]);
assert_eq!(child.depth, 1);
}
/// An aggregated node records all of its sources as parents and provenance.
#[test]
fn test_reasoning_graph_aggregation() {
let mut graph = ReasoningGraph::new();
let root_id = graph.add_root("Root");
let child1 = graph.add_child(root_id, "Child 1").unwrap();
let child2 = graph.add_child(root_id, "Child 2").unwrap();
let agg_id = graph.aggregate(&[child1, child2], "Aggregated").unwrap();
let agg_node = graph.get(agg_id).unwrap();
assert_eq!(agg_node.parents.len(), 2);
assert!(agg_node.metadata.aggregated_from.contains(&child1));
assert!(agg_node.metadata.aggregated_from.contains(&child2));
}
/// Builder setters are reflected in the built module's configuration.
#[test]
fn test_got_builder() {
let got = GoTBuilder::new()
.max_branches(6)
.enable_aggregation(false)
.min_confidence(0.85)
.build();
assert_eq!(got.got_config().max_branches, 6);
assert!(!got.got_config().enable_aggregation);
assert!((got.got_config().min_confidence - 0.85).abs() < 0.01);
}
/// Queries shorter than the default minimum are rejected.
#[test]
fn test_query_validation() {
let got = GraphOfThought::new();
let short_result = got.validate_query("short");
assert!(short_result.is_err());
let valid_result =
got.validate_query("What is the optimal strategy for solving complex problems?");
assert!(valid_result.is_ok());
}
/// End-to-end synchronous run produces every top-level output section.
#[test]
fn test_got_execution() {
let got = GraphOfThought::with_config(GoTConfig {
max_nodes: 20,
max_depth: 3,
max_branches: 2,
..GoTConfig::quick()
});
let context = ThinkToolContext::new("What are the key considerations for system design?");
let result = got.execute(&context).unwrap();
assert_eq!(result.module, "GraphOfThought");
assert!(result.confidence > 0.0);
assert!(result.output.get("final_thoughts").is_some());
assert!(result.output.get("graph_summary").is_some());
assert!(result.output.get("operations").is_some());
assert!(result.output.get("statistics").is_some());
}
/// A simple parent-child chain is acyclic.
#[test]
fn test_reasoning_graph_no_cycles() {
let mut graph = ReasoningGraph::new();
let root_id = graph.add_root("Root");
let child1 = graph.add_child(root_id, "Child 1").unwrap();
let _child2 = graph.add_child(child1, "Child 2").unwrap();
assert!(!graph.has_cycle());
}
/// Pruning with a 0.5 threshold removes some of five children scored 0.2..0.8.
#[test]
fn test_reasoning_graph_pruning() {
let mut graph = ReasoningGraph::new();
let root_id = graph.add_root("Root");
for i in 0..5 {
let child_id = graph.add_child(root_id, format!("Child {}", i)).unwrap();
if let Some(node) = graph.get_mut(child_id) {
node.score = Some(0.2 + (i as f64 * 0.15));
}
}
let initial_frontier = graph.frontier_ids().len();
let pruned = graph.prune(0.5, 2);
assert!(pruned > 0);
assert!(graph.frontier_ids().len() < initial_frontier);
}
/// `run_reasoning` with the quick preset yields a non-empty result.
#[test]
fn test_got_result_structure() {
let got = GraphOfThought::with_config(GoTConfig::quick());
let context =
ThinkToolContext::new("Analyze the trade-offs between performance and maintainability");
let result = got.run_reasoning(&context.query).unwrap();
assert!(!result.final_thoughts.is_empty());
assert!(!result.reasoning_graph.nodes.is_empty());
assert!(result.total_nodes_explored > 0);
assert!(result.confidence > 0.0);
}
/// The async wrapper resolves to the same kind of output as `execute`.
#[tokio::test]
async fn test_async_execution() {
let got = GraphOfThought::with_config(GoTConfig::quick());
let context =
ThinkToolContext::new("What factors should be considered in technology selection?");
let result = got.execute_async(&context).await.unwrap();
assert_eq!(result.module, "GraphOfThought");
assert!(result.confidence > 0.0);
}
/// DFS from the root visits every node of a small tree exactly once.
#[test]
fn test_dfs_traversal() {
let mut graph = ReasoningGraph::new();
let root_id = graph.add_root("Root");
let child1 = graph.add_child(root_id, "Child 1").unwrap();
let child2 = graph.add_child(root_id, "Child 2").unwrap();
let grandchild = graph.add_child(child1, "Grandchild").unwrap();
let traversal = graph.dfs_collect();
assert_eq!(traversal.len(), 4);
assert!(traversal.contains(&root_id));
assert!(traversal.contains(&child1));
assert!(traversal.contains(&child2));
assert!(traversal.contains(&grandchild));
}
}