1use serde::{Deserialize, Serialize};
14use std::collections::HashMap;
15
16use super::step::{StepOutput, StepResult, TokenUsage};
17
18#[derive(Debug, Clone, Serialize, Deserialize)]
20pub struct SelfConsistencyConfig {
21 pub num_samples: usize,
24
25 pub voting_method: VotingMethod,
27
28 pub temperature_base: f64,
31
32 pub temperature_variance: f64,
34
35 pub min_sample_confidence: f64,
37
38 pub use_cisc: bool,
41
42 pub early_stopping: bool,
44
45 pub consensus_threshold: f64,
47}
48
49impl Default for SelfConsistencyConfig {
50 fn default() -> Self {
51 Self {
52 num_samples: 5,
53 voting_method: VotingMethod::MajorityVote,
54 temperature_base: 0.7,
55 temperature_variance: 0.1,
56 min_sample_confidence: 0.5,
57 use_cisc: true, early_stopping: true,
59 consensus_threshold: 0.8,
60 }
61 }
62}
63
64impl SelfConsistencyConfig {
65 pub fn fast() -> Self {
67 Self {
68 num_samples: 3,
69 early_stopping: true,
70 consensus_threshold: 0.7,
71 ..Default::default()
72 }
73 }
74
75 pub fn thorough() -> Self {
77 Self {
78 num_samples: 10,
79 early_stopping: false,
80 ..Default::default()
81 }
82 }
83
84 pub fn paranoid() -> Self {
86 Self {
87 num_samples: 15,
88 early_stopping: false,
89 min_sample_confidence: 0.6,
90 ..Default::default()
91 }
92 }
93
94 pub fn temperature_for_sample(&self, index: usize) -> f64 {
96 self.temperature_base + (index as f64 * self.temperature_variance)
97 }
98}
99
100#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
102#[serde(rename_all = "snake_case")]
103pub enum VotingMethod {
104 #[default]
106 MajorityVote,
107
108 ConfidenceWeighted,
110
111 ClusterWeighted,
113
114 Unanimous,
116}
117
118#[derive(Debug, Clone, Serialize, Deserialize)]
120pub struct ReasoningPath {
121 pub answer: String,
123
124 pub reasoning: String,
126
127 pub confidence: f64,
129
130 pub tokens: TokenUsage,
132
133 pub temperature: f64,
135
136 pub sample_index: usize,
138}
139
140#[derive(Debug, Clone, Serialize, Deserialize)]
142pub struct ConsistencyResult {
143 pub answer: String,
145
146 pub confidence: f64,
148
149 pub vote_count: usize,
151
152 pub total_samples: usize,
154
155 pub agreement_ratio: f64,
157
158 pub paths: Vec<ReasoningPath>,
160
161 pub vote_distribution: HashMap<String, usize>,
163
164 pub early_stopped: bool,
166
167 pub total_tokens: TokenUsage,
169}
170
171impl ConsistencyResult {
172 pub fn meets_threshold(&self, threshold: f64) -> bool {
174 self.confidence >= threshold && self.agreement_ratio >= 0.5
175 }
176
177 pub fn dissenting_paths(&self) -> Vec<&ReasoningPath> {
179 self.paths
180 .iter()
181 .filter(|p| p.answer != self.answer)
182 .collect()
183 }
184
185 pub fn diversity_score(&self) -> f64 {
187 let unique_answers = self.vote_distribution.len();
188 if self.total_samples <= 1 {
189 0.0
190 } else {
191 (unique_answers - 1) as f64 / (self.total_samples - 1) as f64
192 }
193 }
194}
195
196pub struct SelfConsistencyEngine {
198 config: SelfConsistencyConfig,
199}
200
201impl SelfConsistencyEngine {
202 pub fn new(config: SelfConsistencyConfig) -> Self {
204 Self { config }
205 }
206
207 pub fn default_engine() -> Self {
209 Self::new(SelfConsistencyConfig::default())
210 }
211
212 pub fn vote(&self, results: Vec<StepResult>) -> ConsistencyResult {
214 let paths: Vec<ReasoningPath> = results
215 .into_iter()
216 .enumerate()
217 .filter_map(|(idx, result)| self.extract_path(result, idx))
218 .collect();
219
220 self.aggregate_paths(paths)
221 }
222
223 fn extract_path(&self, result: StepResult, index: usize) -> Option<ReasoningPath> {
225 if !result.success || result.confidence < self.config.min_sample_confidence {
226 return None;
227 }
228
229 let (answer, reasoning) = match &result.output {
230 StepOutput::Text { content } => {
231 let answer = self.extract_answer_from_text(content);
233 (answer, content.clone())
234 }
235 StepOutput::Structured { data } => {
236 let answer = data
238 .get("answer")
239 .or_else(|| data.get("conclusion"))
240 .or_else(|| data.get("result"))
241 .and_then(|v| v.as_str())
242 .map(|s| s.to_string())
243 .unwrap_or_else(|| format!("{:?}", data));
244 let reasoning = serde_json::to_string_pretty(&data).unwrap_or_default();
245 (answer, reasoning)
246 }
247 StepOutput::Boolean { value, reason } => {
248 let answer = if *value { "true" } else { "false" }.to_string();
249 let reasoning = reason.clone().unwrap_or_default();
250 (answer, reasoning)
251 }
252 StepOutput::Score { value } => (format!("{:.2}", value), String::new()),
253 StepOutput::List { items } => {
254 let answer = items
255 .iter()
256 .map(|i| i.content.clone())
257 .collect::<Vec<_>>()
258 .join("; ");
259 (answer.clone(), answer)
260 }
261 StepOutput::Empty => return None,
262 };
263
264 Some(ReasoningPath {
265 answer: self.normalize_answer(&answer),
266 reasoning,
267 confidence: result.confidence,
268 tokens: result.tokens,
269 temperature: self.config.temperature_for_sample(index),
270 sample_index: index,
271 })
272 }
273
274 fn extract_answer_from_text(&self, text: &str) -> String {
276 let patterns = [
278 "the answer is",
279 "therefore,",
280 "in conclusion,",
281 "final answer:",
282 "result:",
283 "answer:",
284 ];
285
286 for pattern in patterns {
287 if let Some(pos) = text.to_lowercase().find(pattern) {
288 let start = pos + pattern.len();
289 let remainder = &text[start..];
290 let end = remainder
292 .find(['.', '\n', '!', '?'])
293 .unwrap_or(remainder.len().min(200));
294 return remainder[..end].trim().to_string();
295 }
296 }
297
298 text.split(['.', '\n'])
300 .rfind(|s| !s.trim().is_empty())
301 .map(|s| s.trim().to_string())
302 .unwrap_or_else(|| text.chars().take(200).collect())
303 }
304
305 fn normalize_answer(&self, answer: &str) -> String {
307 answer
308 .to_lowercase()
309 .trim()
310 .replace([',', '.', '!', '?', '"', '\''], "")
311 .split_whitespace()
312 .collect::<Vec<_>>()
313 .join(" ")
314 }
315
316 fn aggregate_paths(&self, paths: Vec<ReasoningPath>) -> ConsistencyResult {
318 if paths.is_empty() {
319 return ConsistencyResult {
320 answer: String::new(),
321 confidence: 0.0,
322 vote_count: 0,
323 total_samples: 0,
324 agreement_ratio: 0.0,
325 paths: Vec::new(),
326 vote_distribution: HashMap::new(),
327 early_stopped: false,
328 total_tokens: TokenUsage::default(),
329 };
330 }
331
332 let mut vote_counts: HashMap<String, usize> = HashMap::new();
334 let mut vote_weights: HashMap<String, f64> = HashMap::new();
335 let mut total_tokens = TokenUsage::default();
336
337 for path in &paths {
338 *vote_counts.entry(path.answer.clone()).or_insert(0) += 1;
339
340 let weight = match self.config.voting_method {
341 VotingMethod::MajorityVote => 1.0,
342 VotingMethod::ConfidenceWeighted => path.confidence,
343 VotingMethod::ClusterWeighted => path.confidence, VotingMethod::Unanimous => 1.0,
345 };
346
347 *vote_weights.entry(path.answer.clone()).or_insert(0.0) += weight;
348 total_tokens.add(&path.tokens);
349 }
350
351 let (winner, vote_count) = match self.config.voting_method {
353 VotingMethod::Unanimous => {
354 if vote_counts.len() == 1 {
356 vote_counts.into_iter().next().unwrap_or_default()
359 } else {
360 vote_counts
362 .into_iter()
363 .max_by_key(|(_, count)| *count)
364 .unwrap_or_default()
365 }
366 }
367 _ => {
368 vote_weights
370 .iter()
371 .max_by(|a, b| a.1.total_cmp(b.1))
372 .map(|(answer, _)| {
373 let count = vote_counts.get(answer).copied().unwrap_or(0);
374 (answer.clone(), count)
375 })
376 .unwrap_or_default()
377 }
378 };
379
380 let total_samples = paths.len();
381 let agreement_ratio = vote_count as f64 / total_samples as f64;
382
383 let confidence = if self.config.use_cisc {
385 let winner_paths: Vec<_> = paths.iter().filter(|p| p.answer == winner).collect();
387 if winner_paths.is_empty() {
388 0.0
389 } else {
390 let avg_confidence: f64 = winner_paths.iter().map(|p| p.confidence).sum::<f64>()
391 / winner_paths.len() as f64;
392 avg_confidence * agreement_ratio
393 }
394 } else {
395 agreement_ratio
397 };
398
399 let mut final_distribution = HashMap::new();
401 for path in &paths {
402 *final_distribution.entry(path.answer.clone()).or_insert(0) += 1;
403 }
404
405 ConsistencyResult {
406 answer: winner,
407 confidence,
408 vote_count,
409 total_samples,
410 agreement_ratio,
411 paths,
412 vote_distribution: final_distribution,
413 early_stopped: false,
414 total_tokens,
415 }
416 }
417
418 pub fn should_early_stop(&self, current_results: &[StepResult]) -> bool {
420 if !self.config.early_stopping || current_results.len() < 3 {
421 return false;
422 }
423
424 let paths: Vec<ReasoningPath> = current_results
425 .iter()
426 .enumerate()
427 .filter_map(|(idx, result)| self.extract_path(result.clone(), idx))
428 .collect();
429
430 if paths.is_empty() {
431 return false;
432 }
433
434 let mut vote_counts: HashMap<String, usize> = HashMap::new();
436 for path in &paths {
437 *vote_counts.entry(path.answer.clone()).or_insert(0) += 1;
438 }
439
440 let max_votes = vote_counts.values().max().copied().unwrap_or(0);
442 let current_ratio = max_votes as f64 / paths.len() as f64;
443
444 current_ratio >= self.config.consensus_threshold
445 }
446}
447
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration uses five samples with CISC scoring and
    /// early stopping enabled.
    #[test]
    fn test_config_defaults() {
        let config = SelfConsistencyConfig::default();
        assert_eq!(config.num_samples, 5);
        assert!(config.use_cisc);
        assert!(config.early_stopping);
    }

    /// Temperature increases by `temperature_variance` per sample index.
    #[test]
    fn test_temperature_variance() {
        let config = SelfConsistencyConfig::default();
        assert!((config.temperature_for_sample(0) - 0.7).abs() < 0.01);
        assert!((config.temperature_for_sample(1) - 0.8).abs() < 0.01);
        assert!((config.temperature_for_sample(2) - 0.9).abs() < 0.01);
    }

    /// Two of three samples agree on "42", so it wins with two votes.
    #[test]
    fn test_majority_voting() {
        let engine = SelfConsistencyEngine::default_engine();

        let results = vec![
            StepResult::success(
                "test",
                StepOutput::Text {
                    content: "The answer is 42.".to_string(),
                },
                0.8,
            ),
            StepResult::success(
                "test",
                StepOutput::Text {
                    content: "The answer is 42.".to_string(),
                },
                0.85,
            ),
            StepResult::success(
                "test",
                StepOutput::Text {
                    content: "The answer is 43.".to_string(),
                },
                0.75,
            ),
        ];

        let result = engine.vote(results);

        assert_eq!(result.answer, "42");
        assert_eq!(result.vote_count, 2);
        assert_eq!(result.total_samples, 3);
    }

    /// Normalization lowercases, strips punctuation and collapses whitespace.
    #[test]
    fn test_normalize_answer() {
        let engine = SelfConsistencyEngine::default_engine();

        assert_eq!(engine.normalize_answer("  HELLO, World!  "), "hello world");
        assert_eq!(engine.normalize_answer("42."), "42");
    }

    /// Two distinct answers over three samples give (2 - 1) / (3 - 1) = 0.5.
    #[test]
    fn test_diversity_score() {
        let result = ConsistencyResult {
            answer: "42".to_string(),
            confidence: 0.8,
            vote_count: 2,
            total_samples: 3,
            agreement_ratio: 0.67,
            paths: Vec::new(),
            vote_distribution: HashMap::from([("42".to_string(), 2), ("43".to_string(), 1)]),
            early_stopped: false,
            total_tokens: TokenUsage::default(),
        };

        assert!((result.diversity_score() - 0.5).abs() < 0.01);
    }

    /// Three of four samples agree (0.75), which clears a 0.7 threshold.
    #[test]
    fn test_early_stopping() {
        let config = SelfConsistencyConfig {
            consensus_threshold: 0.7,
            early_stopping: true,
            ..Default::default()
        };
        let engine = SelfConsistencyEngine::new(config);

        let results: Vec<StepResult> = (0..4)
            .map(|i| {
                let answer = if i < 3 { "42" } else { "43" };
                StepResult::success(
                    "test",
                    StepOutput::Text {
                        content: format!("The answer is {}.", answer),
                    },
                    0.8,
                )
            })
            .collect();

        assert!(engine.should_early_stop(&results));
    }

    /// Aggregating nothing yields an empty answer and zeroed statistics.
    #[test]
    fn test_empty_paths_handling() {
        let engine = SelfConsistencyEngine::default_engine();
        let result = engine.aggregate_paths(vec![]);

        assert!(result.answer.is_empty());
        assert_eq!(result.confidence, 0.0);
        assert_eq!(result.total_samples, 0);
    }

    /// A NaN vote weight must not panic the winner selection, and the
    /// aggregate must still report one of the submitted answers.
    #[test]
    fn test_nan_handling_in_vote_weights() {
        let engine = SelfConsistencyEngine::new(SelfConsistencyConfig {
            voting_method: VotingMethod::ConfidenceWeighted,
            ..Default::default()
        });

        // No paths at all: nothing to compare, empty result.
        let empty = engine.aggregate_paths(vec![]);
        assert!(empty.answer.is_empty());

        // One path carries a NaN confidence, which becomes a NaN vote weight.
        let make_path = |answer: &str, confidence: f64, sample_index: usize| ReasoningPath {
            answer: answer.to_string(),
            reasoning: String::new(),
            confidence,
            tokens: TokenUsage::default(),
            temperature: 0.7,
            sample_index,
        };
        let result =
            engine.aggregate_paths(vec![make_path("a", f64::NAN, 0), make_path("b", 0.9, 1)]);

        assert_eq!(result.total_samples, 2);
        assert_eq!(result.vote_count, 1);
        assert!(result.answer == "a" || result.answer == "b");
    }
}