use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::time::Duration;
use tracing::{debug, info};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChainOfThoughtConfig {
pub max_steps: usize,
pub min_step_confidence: f32,
pub enable_self_consistency: bool,
pub num_consistency_paths: usize,
pub enable_verification: bool,
pub explain_steps: bool,
}
impl Default for ChainOfThoughtConfig {
fn default() -> Self {
Self {
max_steps: 10,
min_step_confidence: 0.7,
enable_self_consistency: true,
num_consistency_paths: 3,
enable_verification: true,
explain_steps: true,
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ThoughtStep {
pub step_number: usize,
pub description: String,
pub reasoning: String,
pub input: String,
pub output: String,
pub confidence: f32,
pub evidence: Vec<String>,
pub assumptions: Vec<String>,
pub step_type: StepType,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum StepType {
Decomposition,
Retrieval,
Inference,
Calculation,
Synthesis,
Verification,
Conclusion,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChainOfThought {
pub id: String,
pub query: String,
pub steps: Vec<ThoughtStep>,
pub answer: String,
pub confidence: f32,
pub reasoning_time: Duration,
pub alternative_paths: Vec<AlternativePath>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AlternativePath {
pub path_id: usize,
pub steps: Vec<ThoughtStep>,
pub answer: String,
pub confidence: f32,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SelfConsistencyResult {
pub consensus_answer: String,
pub consensus_confidence: f32,
pub answer_distribution: HashMap<String, usize>,
pub all_paths: Vec<AlternativePath>,
}
pub struct ChainOfThoughtEngine {
config: ChainOfThoughtConfig,
}
impl ChainOfThoughtEngine {
pub fn new(config: ChainOfThoughtConfig) -> Self {
info!("Initialized Chain-of-Thought reasoning engine");
Self { config }
}
pub async fn reason(&self, query: &str, context: &str) -> Result<ChainOfThought> {
let start_time = std::time::Instant::now();
info!("Starting Chain-of-Thought reasoning for query: {}", query);
let mut steps = Vec::new();
let decomposition_step = self.decompose_problem(query, context)?;
steps.push(decomposition_step);
let retrieval_step = self.retrieve_information(query, context, &steps)?;
steps.push(retrieval_step);
let mut current_step = 3;
while current_step <= self.config.max_steps {
if let Some(inference_step) = self.apply_inference(&steps, context)? {
steps.push(inference_step);
if steps
.last()
.expect("collection validated to be non-empty")
.step_type
== StepType::Conclusion
{
break;
}
} else {
break;
}
current_step += 1;
}
let answer = self.extract_answer(&steps)?;
let confidence = self.calculate_overall_confidence(&steps);
let alternative_paths = if self.config.enable_self_consistency {
self.generate_alternative_paths(query, context).await?
} else {
Vec::new()
};
let reasoning_time = start_time.elapsed();
debug!(
"Chain-of-Thought reasoning completed in {:?} with {} steps",
reasoning_time,
steps.len()
);
Ok(ChainOfThought {
id: uuid::Uuid::new_v4().to_string(),
query: query.to_string(),
steps,
answer,
confidence,
reasoning_time,
alternative_paths,
})
}
fn decompose_problem(&self, query: &str, _context: &str) -> Result<ThoughtStep> {
debug!("Decomposing problem: {}", query);
let components = self.identify_components(query);
let description = "Breaking down the query into manageable components".to_string();
let reasoning = format!(
"Identified {} key components: {}",
components.len(),
components.join(", ")
);
Ok(ThoughtStep {
step_number: 1,
description,
reasoning,
input: query.to_string(),
output: components.join("; "),
confidence: 0.9,
evidence: vec![],
assumptions: vec!["Query is well-formed".to_string()],
step_type: StepType::Decomposition,
})
}
fn retrieve_information(
&self,
_query: &str,
context: &str,
_previous_steps: &[ThoughtStep],
) -> Result<ThoughtStep> {
debug!("Retrieving relevant information from context");
let relevant_facts = self.extract_relevant_facts(context);
Ok(ThoughtStep {
step_number: 2,
description: "Gathering relevant information from knowledge base".to_string(),
reasoning: format!("Found {} relevant facts from context", relevant_facts.len()),
input: context.to_string(),
output: relevant_facts.join("; "),
confidence: 0.85,
evidence: relevant_facts.clone(),
assumptions: vec![],
step_type: StepType::Retrieval,
})
}
fn apply_inference(
&self,
previous_steps: &[ThoughtStep],
_context: &str,
) -> Result<Option<ThoughtStep>> {
if previous_steps.is_empty() {
return Ok(None);
}
let step_number = previous_steps.len() + 1;
if previous_steps.len() >= 2 {
let conclusion = self.synthesize_conclusion(previous_steps)?;
return Ok(Some(conclusion));
}
let last_output = &previous_steps
.last()
.expect("collection validated to be non-empty")
.output;
let inference = format!("Based on previous analysis: {}", last_output);
Ok(Some(ThoughtStep {
step_number,
description: "Applying logical inference".to_string(),
reasoning: "Connecting retrieved facts to query components".to_string(),
input: last_output.clone(),
output: inference.clone(),
confidence: 0.8,
evidence: vec![],
assumptions: vec![],
step_type: StepType::Inference,
}))
}
fn synthesize_conclusion(&self, steps: &[ThoughtStep]) -> Result<ThoughtStep> {
let step_number = steps.len() + 1;
let all_outputs: Vec<String> = steps.iter().map(|s| s.output.clone()).collect();
let conclusion = format!(
"Based on the reasoning chain, the answer is derived from: {}",
all_outputs.join("; ")
);
Ok(ThoughtStep {
step_number,
description: "Synthesizing final conclusion".to_string(),
reasoning: "Combining all reasoning steps into final answer".to_string(),
input: all_outputs.join("\n"),
output: conclusion.clone(),
confidence: 0.85,
evidence: all_outputs,
assumptions: vec![],
step_type: StepType::Conclusion,
})
}
fn extract_answer(&self, steps: &[ThoughtStep]) -> Result<String> {
if let Some(conclusion_step) = steps.iter().find(|s| s.step_type == StepType::Conclusion) {
Ok(conclusion_step.output.clone())
} else if let Some(last_step) = steps.last() {
Ok(last_step.output.clone())
} else {
Ok("Unable to determine answer".to_string())
}
}
fn calculate_overall_confidence(&self, steps: &[ThoughtStep]) -> f32 {
if steps.is_empty() {
return 0.0;
}
let total_weight: f32 = (1..=steps.len()).map(|i| i as f32).sum();
let weighted_sum: f32 = steps
.iter()
.enumerate()
.map(|(i, step)| (i + 1) as f32 * step.confidence)
.sum();
weighted_sum / total_weight
}
async fn generate_alternative_paths(
&self,
query: &str,
context: &str,
) -> Result<Vec<AlternativePath>> {
let mut paths = Vec::new();
for path_id in 0..self.config.num_consistency_paths {
let steps = self.generate_alternative_steps(query, context, path_id)?;
let answer = self.extract_answer(&steps)?;
let confidence = self.calculate_overall_confidence(&steps);
paths.push(AlternativePath {
path_id,
steps,
answer,
confidence,
});
}
Ok(paths)
}
fn generate_alternative_steps(
&self,
query: &str,
context: &str,
path_id: usize,
) -> Result<Vec<ThoughtStep>> {
let mut steps = Vec::new();
let decomposition = self.decompose_problem(query, context)?;
let mut varied_decomposition = decomposition.clone();
varied_decomposition.reasoning = format!(
"Alternative approach {}: {}",
path_id, decomposition.reasoning
);
steps.push(varied_decomposition);
let retrieval = self.retrieve_information(query, context, &steps)?;
steps.push(retrieval);
if let Some(inference) = self.apply_inference(&steps, context)? {
steps.push(inference);
}
let conclusion = self.synthesize_conclusion(&steps)?;
steps.push(conclusion);
Ok(steps)
}
pub fn check_self_consistency(&self, paths: &[AlternativePath]) -> SelfConsistencyResult {
let mut answer_distribution: HashMap<String, usize> = HashMap::new();
for path in paths {
*answer_distribution.entry(path.answer.clone()).or_insert(0) += 1;
}
let (consensus_answer, count) = answer_distribution
.iter()
.max_by_key(|(_, &count)| count)
.map(|(answer, &count)| (answer.clone(), count))
.unwrap_or_else(|| ("Unknown".to_string(), 0));
let consensus_confidence = count as f32 / paths.len() as f32;
SelfConsistencyResult {
consensus_answer,
consensus_confidence,
answer_distribution,
all_paths: paths.to_vec(),
}
}
fn identify_components(&self, query: &str) -> Vec<String> {
query
.split_whitespace()
.filter(|word| word.len() > 3)
.take(5)
.map(|s| s.to_string())
.collect()
}
fn extract_relevant_facts(&self, context: &str) -> Vec<String> {
context
.split(['.', '!', '?'])
.filter(|s| !s.trim().is_empty())
.take(3)
.map(|s| s.trim().to_string())
.collect()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_chain_of_thought_reasoning() {
let config = ChainOfThoughtConfig::default();
let engine = ChainOfThoughtEngine::new(config);
let query = "What movies were directed by Christopher Nolan in 2020?";
let context = "Christopher Nolan directed Tenet in 2020. Tenet is a science fiction film.";
let result = engine.reason(query, context).await.expect("should succeed");
assert!(!result.steps.is_empty());
assert!(!result.answer.is_empty());
assert!(result.confidence > 0.0);
}
#[test]
fn test_decompose_problem() {
let config = ChainOfThoughtConfig::default();
let engine = ChainOfThoughtEngine::new(config);
let query = "Find all proteins related to cancer";
let step = engine.decompose_problem(query, "").expect("should succeed");
assert_eq!(step.step_number, 1);
assert_eq!(step.step_type, StepType::Decomposition);
assert!(step.confidence > 0.0);
}
#[test]
fn test_confidence_calculation() {
let config = ChainOfThoughtConfig::default();
let engine = ChainOfThoughtEngine::new(config);
let steps = vec![
ThoughtStep {
step_number: 1,
description: "Step 1".to_string(),
reasoning: "Reasoning 1".to_string(),
input: "Input 1".to_string(),
output: "Output 1".to_string(),
confidence: 0.9,
evidence: vec![],
assumptions: vec![],
step_type: StepType::Decomposition,
},
ThoughtStep {
step_number: 2,
description: "Step 2".to_string(),
reasoning: "Reasoning 2".to_string(),
input: "Input 2".to_string(),
output: "Output 2".to_string(),
confidence: 0.8,
evidence: vec![],
assumptions: vec![],
step_type: StepType::Conclusion,
},
];
let confidence = engine.calculate_overall_confidence(&steps);
assert!(confidence > 0.8 && confidence < 0.9);
}
}