reasonkit/thinktool/
consistency.rs

1//! Self-Consistency Module
2//!
//! Implements the self-consistency voting mechanism (sample several reasoning paths,
//! then keep the answer most of them agree on) from Wang et al. (2023),
//! "Self-Consistency Improves Chain of Thought Reasoning in Language Models".
//!
//! Reported accuracy gains over greedy chain-of-thought decoding (absolute points):
//! - GSM8K: +17.9 percentage points
//! - SVAMP: +11.0 percentage points
//! - AQuA: +12.2 percentage points
10//!
11//! Reference: <https://arxiv.org/abs/2203.11171>
12
13use serde::{Deserialize, Serialize};
14use std::collections::HashMap;
15
16use super::step::{StepOutput, StepResult, TokenUsage};
17
18/// Self-Consistency configuration
19#[derive(Debug, Clone, Serialize, Deserialize)]
20pub struct SelfConsistencyConfig {
21    /// Number of reasoning paths to sample (default: 5)
22    /// Research shows diminishing returns after ~10-15 samples
23    pub num_samples: usize,
24
25    /// Voting method to use
26    pub voting_method: VotingMethod,
27
28    /// Temperature variance for diverse sampling
29    /// Higher values = more diverse reasoning paths
30    pub temperature_base: f64,
31
32    /// Temperature increment per sample (for diversity)
33    pub temperature_variance: f64,
34
35    /// Minimum confidence threshold for a sample to be included in voting
36    pub min_sample_confidence: f64,
37
38    /// Enable CISC (Confidence-Informed Self-Consistency)
39    /// Reduces required samples by ~40% (arXiv:2502.06233)
40    pub use_cisc: bool,
41
42    /// Early stopping if consensus reached
43    pub early_stopping: bool,
44
45    /// Consensus threshold for early stopping (e.g., 0.8 = 80% agreement)
46    pub consensus_threshold: f64,
47}
48
49impl Default for SelfConsistencyConfig {
50    fn default() -> Self {
51        Self {
52            num_samples: 5,
53            voting_method: VotingMethod::MajorityVote,
54            temperature_base: 0.7,
55            temperature_variance: 0.1,
56            min_sample_confidence: 0.5,
57            use_cisc: true, // Enable by default for cost efficiency
58            early_stopping: true,
59            consensus_threshold: 0.8,
60        }
61    }
62}
63
64impl SelfConsistencyConfig {
65    /// Create a fast config (fewer samples, early stopping)
66    pub fn fast() -> Self {
67        Self {
68            num_samples: 3,
69            early_stopping: true,
70            consensus_threshold: 0.7,
71            ..Default::default()
72        }
73    }
74
75    /// Create a thorough config (more samples, no early stopping)
76    pub fn thorough() -> Self {
77        Self {
78            num_samples: 10,
79            early_stopping: false,
80            ..Default::default()
81        }
82    }
83
84    /// Create a paranoid config (maximum samples)
85    pub fn paranoid() -> Self {
86        Self {
87            num_samples: 15,
88            early_stopping: false,
89            min_sample_confidence: 0.6,
90            ..Default::default()
91        }
92    }
93
94    /// Get temperature for a specific sample index
95    pub fn temperature_for_sample(&self, index: usize) -> f64 {
96        self.temperature_base + (index as f64 * self.temperature_variance)
97    }
98}
99
100/// Voting methods for self-consistency
101#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
102#[serde(rename_all = "snake_case")]
103pub enum VotingMethod {
104    /// Simple majority voting (original self-consistency)
105    #[default]
106    MajorityVote,
107
108    /// Weighted by confidence scores (CISC)
109    ConfidenceWeighted,
110
111    /// Weighted by semantic similarity clustering
112    ClusterWeighted,
113
114    /// Unanimous agreement required
115    Unanimous,
116}
117
118/// A single sampled reasoning path
119#[derive(Debug, Clone, Serialize, Deserialize)]
120pub struct ReasoningPath {
121    /// The final answer/conclusion extracted
122    pub answer: String,
123
124    /// The full reasoning trace
125    pub reasoning: String,
126
127    /// Confidence score for this path
128    pub confidence: f64,
129
130    /// Token usage for this sample
131    pub tokens: TokenUsage,
132
133    /// Temperature used for this sample
134    pub temperature: f64,
135
136    /// Sample index
137    pub sample_index: usize,
138}
139
140/// Result of self-consistency voting
141#[derive(Debug, Clone, Serialize, Deserialize)]
142pub struct ConsistencyResult {
143    /// The winning answer after voting
144    pub answer: String,
145
146    /// Aggregated confidence (voting strength)
147    pub confidence: f64,
148
149    /// Number of votes for winning answer
150    pub vote_count: usize,
151
152    /// Total number of samples
153    pub total_samples: usize,
154
155    /// Agreement ratio (votes / total)
156    pub agreement_ratio: f64,
157
158    /// All reasoning paths sampled
159    pub paths: Vec<ReasoningPath>,
160
161    /// Vote distribution (answer -> count)
162    pub vote_distribution: HashMap<String, usize>,
163
164    /// Whether early stopping was triggered
165    pub early_stopped: bool,
166
167    /// Total token usage across all samples
168    pub total_tokens: TokenUsage,
169}
170
171impl ConsistencyResult {
172    /// Check if result meets confidence threshold
173    pub fn meets_threshold(&self, threshold: f64) -> bool {
174        self.confidence >= threshold && self.agreement_ratio >= 0.5
175    }
176
177    /// Get the dissenting paths (those that disagreed with winner)
178    pub fn dissenting_paths(&self) -> Vec<&ReasoningPath> {
179        self.paths
180            .iter()
181            .filter(|p| p.answer != self.answer)
182            .collect()
183    }
184
185    /// Get reasoning diversity score (0-1, higher = more diverse)
186    pub fn diversity_score(&self) -> f64 {
187        let unique_answers = self.vote_distribution.len();
188        if self.total_samples <= 1 {
189            0.0
190        } else {
191            (unique_answers - 1) as f64 / (self.total_samples - 1) as f64
192        }
193    }
194}
195
196/// Self-Consistency Engine
197pub struct SelfConsistencyEngine {
198    config: SelfConsistencyConfig,
199}
200
201impl SelfConsistencyEngine {
202    /// Create a new self-consistency engine
203    pub fn new(config: SelfConsistencyConfig) -> Self {
204        Self { config }
205    }
206
207    /// Create with default config
208    pub fn default_engine() -> Self {
209        Self::new(SelfConsistencyConfig::default())
210    }
211
212    /// Aggregate multiple step results using self-consistency voting
213    pub fn vote(&self, results: Vec<StepResult>) -> ConsistencyResult {
214        let paths: Vec<ReasoningPath> = results
215            .into_iter()
216            .enumerate()
217            .filter_map(|(idx, result)| self.extract_path(result, idx))
218            .collect();
219
220        self.aggregate_paths(paths)
221    }
222
223    /// Extract a reasoning path from a step result
224    fn extract_path(&self, result: StepResult, index: usize) -> Option<ReasoningPath> {
225        if !result.success || result.confidence < self.config.min_sample_confidence {
226            return None;
227        }
228
229        let (answer, reasoning) = match &result.output {
230            StepOutput::Text { content } => {
231                // Extract answer from text (look for common patterns)
232                let answer = self.extract_answer_from_text(content);
233                (answer, content.clone())
234            }
235            StepOutput::Structured { data } => {
236                // Look for answer field in structured output
237                let answer = data
238                    .get("answer")
239                    .or_else(|| data.get("conclusion"))
240                    .or_else(|| data.get("result"))
241                    .and_then(|v| v.as_str())
242                    .map(|s| s.to_string())
243                    .unwrap_or_else(|| format!("{:?}", data));
244                let reasoning = serde_json::to_string_pretty(&data).unwrap_or_default();
245                (answer, reasoning)
246            }
247            StepOutput::Boolean { value, reason } => {
248                let answer = if *value { "true" } else { "false" }.to_string();
249                let reasoning = reason.clone().unwrap_or_default();
250                (answer, reasoning)
251            }
252            StepOutput::Score { value } => (format!("{:.2}", value), String::new()),
253            StepOutput::List { items } => {
254                let answer = items
255                    .iter()
256                    .map(|i| i.content.clone())
257                    .collect::<Vec<_>>()
258                    .join("; ");
259                (answer.clone(), answer)
260            }
261            StepOutput::Empty => return None,
262        };
263
264        Some(ReasoningPath {
265            answer: self.normalize_answer(&answer),
266            reasoning,
267            confidence: result.confidence,
268            tokens: result.tokens,
269            temperature: self.config.temperature_for_sample(index),
270            sample_index: index,
271        })
272    }
273
274    /// Extract answer from free-form text
275    fn extract_answer_from_text(&self, text: &str) -> String {
276        // Look for common answer patterns
277        let patterns = [
278            "the answer is",
279            "therefore,",
280            "in conclusion,",
281            "final answer:",
282            "result:",
283            "answer:",
284        ];
285
286        for pattern in patterns {
287            if let Some(pos) = text.to_lowercase().find(pattern) {
288                let start = pos + pattern.len();
289                let remainder = &text[start..];
290                // Take until end of sentence or newline
291                let end = remainder
292                    .find(['.', '\n', '!', '?'])
293                    .unwrap_or(remainder.len().min(200));
294                return remainder[..end].trim().to_string();
295            }
296        }
297
298        // Fallback: use last sentence
299        text.split(['.', '\n'])
300            .rfind(|s| !s.trim().is_empty())
301            .map(|s| s.trim().to_string())
302            .unwrap_or_else(|| text.chars().take(200).collect())
303    }
304
305    /// Normalize answer for comparison (lowercase, trim, etc.)
306    fn normalize_answer(&self, answer: &str) -> String {
307        answer
308            .to_lowercase()
309            .trim()
310            .replace([',', '.', '!', '?', '"', '\''], "")
311            .split_whitespace()
312            .collect::<Vec<_>>()
313            .join(" ")
314    }
315
316    /// Aggregate reasoning paths using configured voting method
317    fn aggregate_paths(&self, paths: Vec<ReasoningPath>) -> ConsistencyResult {
318        if paths.is_empty() {
319            return ConsistencyResult {
320                answer: String::new(),
321                confidence: 0.0,
322                vote_count: 0,
323                total_samples: 0,
324                agreement_ratio: 0.0,
325                paths: Vec::new(),
326                vote_distribution: HashMap::new(),
327                early_stopped: false,
328                total_tokens: TokenUsage::default(),
329            };
330        }
331
332        // Count votes and calculate weights
333        let mut vote_counts: HashMap<String, usize> = HashMap::new();
334        let mut vote_weights: HashMap<String, f64> = HashMap::new();
335        let mut total_tokens = TokenUsage::default();
336
337        for path in &paths {
338            *vote_counts.entry(path.answer.clone()).or_insert(0) += 1;
339
340            let weight = match self.config.voting_method {
341                VotingMethod::MajorityVote => 1.0,
342                VotingMethod::ConfidenceWeighted => path.confidence,
343                VotingMethod::ClusterWeighted => path.confidence, // Simplified
344                VotingMethod::Unanimous => 1.0,
345            };
346
347            *vote_weights.entry(path.answer.clone()).or_insert(0.0) += weight;
348            total_tokens.add(&path.tokens);
349        }
350
351        // Find winner - using safe comparison that handles NaN gracefully
352        let (winner, vote_count) = match self.config.voting_method {
353            VotingMethod::Unanimous => {
354                // All must agree
355                if vote_counts.len() == 1 {
356                    // SAFETY: We checked vote_counts.len() == 1, so there's exactly one entry
357                    // Using unwrap_or_default as a defensive fallback
358                    vote_counts.into_iter().next().unwrap_or_default()
359                } else {
360                    // No consensus - return most common with low confidence
361                    vote_counts
362                        .into_iter()
363                        .max_by_key(|(_, count)| *count)
364                        .unwrap_or_default()
365                }
366            }
367            _ => {
368                // Find by weight - use total_cmp for safe f64 comparison (handles NaN)
369                vote_weights
370                    .iter()
371                    .max_by(|a, b| a.1.total_cmp(b.1))
372                    .map(|(answer, _)| {
373                        let count = vote_counts.get(answer).copied().unwrap_or(0);
374                        (answer.clone(), count)
375                    })
376                    .unwrap_or_default()
377            }
378        };
379
380        let total_samples = paths.len();
381        let agreement_ratio = vote_count as f64 / total_samples as f64;
382
383        // Calculate aggregated confidence
384        let confidence = if self.config.use_cisc {
385            // CISC: Weight confidence by agreement
386            let winner_paths: Vec<_> = paths.iter().filter(|p| p.answer == winner).collect();
387            if winner_paths.is_empty() {
388                0.0
389            } else {
390                let avg_confidence: f64 = winner_paths.iter().map(|p| p.confidence).sum::<f64>()
391                    / winner_paths.len() as f64;
392                avg_confidence * agreement_ratio
393            }
394        } else {
395            // Simple: Just use agreement ratio
396            agreement_ratio
397        };
398
399        // Rebuild vote distribution with original counts
400        let mut final_distribution = HashMap::new();
401        for path in &paths {
402            *final_distribution.entry(path.answer.clone()).or_insert(0) += 1;
403        }
404
405        ConsistencyResult {
406            answer: winner,
407            confidence,
408            vote_count,
409            total_samples,
410            agreement_ratio,
411            paths,
412            vote_distribution: final_distribution,
413            early_stopped: false,
414            total_tokens,
415        }
416    }
417
418    /// Check if early stopping should be triggered
419    pub fn should_early_stop(&self, current_results: &[StepResult]) -> bool {
420        if !self.config.early_stopping || current_results.len() < 3 {
421            return false;
422        }
423
424        let paths: Vec<ReasoningPath> = current_results
425            .iter()
426            .enumerate()
427            .filter_map(|(idx, result)| self.extract_path(result.clone(), idx))
428            .collect();
429
430        if paths.is_empty() {
431            return false;
432        }
433
434        // Count current votes
435        let mut vote_counts: HashMap<String, usize> = HashMap::new();
436        for path in &paths {
437            *vote_counts.entry(path.answer.clone()).or_insert(0) += 1;
438        }
439
440        // Check if any answer has reached consensus threshold
441        let max_votes = vote_counts.values().max().copied().unwrap_or(0);
442        let current_ratio = max_votes as f64 / paths.len() as f64;
443
444        current_ratio >= self.config.consensus_threshold
445    }
446}
447
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a step result whose text contains `content`, with the given confidence.
    fn text_result(content: &str, confidence: f64) -> StepResult {
        StepResult::success(
            "test",
            StepOutput::Text {
                content: content.to_string(),
            },
            confidence,
        )
    }

    /// A three-sample result where "42" beat "43" two votes to one.
    fn two_answer_result() -> ConsistencyResult {
        ConsistencyResult {
            answer: "42".to_string(),
            confidence: 0.8,
            vote_count: 2,
            total_samples: 3,
            agreement_ratio: 0.67,
            paths: Vec::new(),
            vote_distribution: HashMap::from([("42".to_string(), 2), ("43".to_string(), 1)]),
            early_stopped: false,
            total_tokens: TokenUsage::default(),
        }
    }

    #[test]
    fn test_config_defaults() {
        let config = SelfConsistencyConfig::default();
        assert_eq!(config.num_samples, 5);
        assert!(config.use_cisc);
        assert!(config.early_stopping);
    }

    #[test]
    fn test_presets() {
        let fast = SelfConsistencyConfig::fast();
        assert_eq!(fast.num_samples, 3);
        assert!(fast.early_stopping);

        let thorough = SelfConsistencyConfig::thorough();
        assert_eq!(thorough.num_samples, 10);
        assert!(!thorough.early_stopping);

        let paranoid = SelfConsistencyConfig::paranoid();
        assert_eq!(paranoid.num_samples, 15);
        assert!(!paranoid.early_stopping);
        assert!((paranoid.min_sample_confidence - 0.6).abs() < f64::EPSILON);
    }

    #[test]
    fn test_temperature_variance() {
        let config = SelfConsistencyConfig::default();
        assert!((config.temperature_for_sample(0) - 0.7).abs() < 0.01);
        assert!((config.temperature_for_sample(1) - 0.8).abs() < 0.01);
        assert!((config.temperature_for_sample(2) - 0.9).abs() < 0.01);
    }

    #[test]
    fn test_majority_voting() {
        let engine = SelfConsistencyEngine::default_engine();

        let results = vec![
            text_result("The answer is 42.", 0.8),
            text_result("The answer is 42.", 0.85),
            text_result("The answer is 43.", 0.75),
        ];

        let result = engine.vote(results);

        assert_eq!(result.answer, "42");
        assert_eq!(result.vote_count, 2);
        assert_eq!(result.total_samples, 3);
    }

    #[test]
    fn test_dissenting_paths() {
        let engine = SelfConsistencyEngine::default_engine();

        let results = vec![
            text_result("The answer is 42.", 0.8),
            text_result("The answer is 42.", 0.8),
            text_result("The answer is 43.", 0.8),
        ];

        let result = engine.vote(results);
        let dissent = result.dissenting_paths();

        assert_eq!(dissent.len(), 1);
        assert_eq!(dissent[0].answer, "43");
        assert_eq!(dissent[0].sample_index, 2);
    }

    #[test]
    fn test_normalize_answer() {
        let engine = SelfConsistencyEngine::default_engine();

        assert_eq!(engine.normalize_answer("  HELLO, World!  "), "hello world");
        assert_eq!(engine.normalize_answer("42."), "42");
    }

    #[test]
    fn test_meets_threshold() {
        let result = two_answer_result();

        assert!(result.meets_threshold(0.7));
        assert!(!result.meets_threshold(0.9));
    }

    #[test]
    fn test_diversity_score() {
        let result = two_answer_result();

        // 2 unique answers out of 3 samples = diversity 0.5
        assert!((result.diversity_score() - 0.5).abs() < 0.01);
    }

    #[test]
    fn test_diversity_score_single_sample() {
        let result = ConsistencyResult {
            total_samples: 1,
            vote_distribution: HashMap::from([("42".to_string(), 1)]),
            ..two_answer_result()
        };

        assert_eq!(result.diversity_score(), 0.0);
    }

    #[test]
    fn test_early_stopping() {
        let config = SelfConsistencyConfig {
            consensus_threshold: 0.7,
            early_stopping: true,
            ..Default::default()
        };
        let engine = SelfConsistencyEngine::new(config);

        // 3 out of 4 agree = 75% > 70% threshold
        let results: Vec<StepResult> = (0..4)
            .map(|i| {
                let answer = if i < 3 { "42" } else { "43" };
                text_result(&format!("The answer is {}.", answer), 0.8)
            })
            .collect();

        assert!(engine.should_early_stop(&results));
    }

    #[test]
    fn test_empty_paths_handling() {
        let engine = SelfConsistencyEngine::default_engine();
        let result = engine.aggregate_paths(vec![]);

        assert!(result.answer.is_empty());
        assert_eq!(result.confidence, 0.0);
        assert_eq!(result.total_samples, 0);
    }

    #[test]
    fn test_nan_handling_in_vote_weights() {
        let engine = SelfConsistencyEngine::new(SelfConsistencyConfig {
            voting_method: VotingMethod::ConfidenceWeighted,
            ..Default::default()
        });

        let path = |answer: &str, confidence: f64, index: usize| ReasoningPath {
            answer: answer.to_string(),
            reasoning: String::new(),
            confidence,
            tokens: TokenUsage::default(),
            temperature: 0.7,
            sample_index: index,
        };

        // A NaN weight must not make the winner comparison panic. Which answer wins
        // depends on the NaN's sign bit, so only the set of possible winners is pinned.
        let result = engine.aggregate_paths(vec![
            path("a", f64::NAN, 0),
            path("b", 0.9, 1),
            path("b", 0.9, 2),
        ]);

        assert!(matches!(result.answer.as_str(), "a" | "b"));
        assert_eq!(result.total_samples, 3);
        assert_eq!(result.vote_distribution.get("b"), Some(&2));

        // Empty input is still handled.
        assert!(engine.aggregate_paths(vec![]).answer.is_empty());
    }
}