reasonkit/thinktool/
self_refine.rs

1//! # Self-Refine Module
2//!
3//! Implements iterative self-refinement based on Madaan et al. (2023)
4//! "Self-Refine: Iterative Refinement with Self-Feedback"
5//!
6//! ## Scientific Foundation
7//!
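//! - arXiv:2303.17651: Self-Refine improves task performance by ~20% (absolute) on average
//!   across seven tasks, including dialogue response, code optimization and math reasoning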
9//! - Works through GENERATE → FEEDBACK → REFINE loop
10//! - No additional training required (prompt-based)
11//!
12//! ## Usage
13//!
14//! ```rust,ignore
15//! use reasonkit::thinktool::self_refine::{SelfRefineEngine, RefineConfig};
16//!
17//! let engine = SelfRefineEngine::new(RefineConfig::default());
18//! let result = engine.refine(initial_output, problem).await?;
19//! ```
20
21use serde::{Deserialize, Serialize};
22
23/// Self-Refine configuration
24#[derive(Debug, Clone, Serialize, Deserialize)]
25pub struct RefineConfig {
26    /// Maximum refinement iterations
27    pub max_iterations: usize,
28    /// Stop if quality improvement is below this threshold
29    pub min_improvement_threshold: f32,
30    /// Stop if quality reaches this level
31    pub target_quality: f32,
32    /// Feedback dimensions to evaluate
33    pub feedback_dimensions: Vec<FeedbackDimension>,
34    /// Whether to preserve reasoning chain
35    pub preserve_reasoning: bool,
36}
37
38impl Default for RefineConfig {
39    fn default() -> Self {
40        Self {
41            max_iterations: 3,
42            min_improvement_threshold: 0.05,
43            target_quality: 0.90,
44            feedback_dimensions: vec![
45                FeedbackDimension::Correctness,
46                FeedbackDimension::Completeness,
47                FeedbackDimension::Clarity,
48                FeedbackDimension::Coherence,
49            ],
50            preserve_reasoning: true,
51        }
52    }
53}
54
55impl RefineConfig {
56    /// Config for BrutalHonesty (adversarial critique)
57    pub fn brutal_honesty() -> Self {
58        Self {
59            max_iterations: 5,
60            min_improvement_threshold: 0.03,
61            target_quality: 0.95,
62            feedback_dimensions: vec![
63                FeedbackDimension::Correctness,
64                FeedbackDimension::Honesty,
65                FeedbackDimension::Completeness,
66                FeedbackDimension::BiasDetection,
67                FeedbackDimension::WeaknessIdentification,
68            ],
69            preserve_reasoning: true,
70        }
71    }
72
73    /// Config for code refinement
74    pub fn code() -> Self {
75        Self {
76            max_iterations: 4,
77            feedback_dimensions: vec![
78                FeedbackDimension::Correctness,
79                FeedbackDimension::Efficiency,
80                FeedbackDimension::Readability,
81                FeedbackDimension::EdgeCases,
82            ],
83            ..Default::default()
84        }
85    }
86}
87
/// Dimensions along which the reviewer evaluates content.
///
/// Each variant has a matching reviewer question (see `prompt_question`).
/// With the derived serde impls, a variant is (de)serialized as its bare name,
/// e.g. `"BiasDetection"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum FeedbackDimension {
    /// Is the content factually correct?
    Correctness,
    /// Is all relevant information included?
    Completeness,
    /// Is it clearly written/explained?
    Clarity,
    /// Does it flow logically?
    Coherence,
    /// Is it truthful without exaggeration?
    Honesty,
    /// Are biases identified and addressed?
    BiasDetection,
    /// Are weaknesses/limitations acknowledged?
    WeaknessIdentification,
    /// Are edge cases handled?
    EdgeCases,
    /// Is it efficient (for code)?
    Efficiency,
    /// Is it readable (for code)?
    Readability,
    /// Custom dimension. It carries no payload; its prompt asks for a general
    /// overall-quality evaluation.
    Custom,
}
114
115impl FeedbackDimension {
116    pub fn prompt_question(&self) -> &'static str {
117        match self {
118            Self::Correctness => "Is the content factually correct? Identify any errors.",
119            Self::Completeness => "Is all relevant information included? What's missing?",
120            Self::Clarity => "Is the explanation clear? What's confusing?",
121            Self::Coherence => "Does the reasoning flow logically? Any gaps?",
122            Self::Honesty => "Is the assessment honest without exaggeration or false modesty?",
123            Self::BiasDetection => "Are there any hidden biases or assumptions? Identify them.",
124            Self::WeaknessIdentification => "What weaknesses or limitations exist? Be specific.",
125            Self::EdgeCases => "Are edge cases and exceptions handled properly?",
126            Self::Efficiency => "Is the solution efficient? How can it be optimized?",
127            Self::Readability => "Is the code/text readable? How can it be improved?",
128            Self::Custom => "Evaluate the overall quality and suggest improvements.",
129        }
130    }
131}
132
133/// Feedback for a single dimension
134#[derive(Debug, Clone, Serialize, Deserialize)]
135pub struct DimensionFeedback {
136    pub dimension: FeedbackDimension,
137    /// Quality score (0.0 - 1.0)
138    pub score: f32,
139    /// Specific issues found
140    pub issues: Vec<String>,
141    /// Suggestions for improvement
142    pub suggestions: Vec<String>,
143}
144
/// Complete feedback from one iteration: per-dimension scores plus an overall
/// score and a combined improvement plan.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IterationFeedback {
    /// Iteration number this feedback belongs to (presumably 1-based — confirm
    /// against the refinement loop).
    pub iteration: usize,
    /// Feedback for each evaluated dimension.
    pub dimension_feedback: Vec<DimensionFeedback>,
    /// Overall quality score. `compute_overall_score` sets it to the mean of the
    /// per-dimension scores, or 0.0 when there are none.
    pub overall_score: f32,
    /// Combined improvement suggestions; inserted verbatim as the
    /// "IMPROVEMENT PLAN" section of the refine prompt.
    pub improvement_plan: String,
    /// Whether refinement should continue
    pub should_continue: bool,
}
157
158impl IterationFeedback {
159    pub fn compute_overall_score(&mut self) {
160        if self.dimension_feedback.is_empty() {
161            self.overall_score = 0.0;
162            return;
163        }
164
165        self.overall_score = self.dimension_feedback.iter().map(|f| f.score).sum::<f32>()
166            / self.dimension_feedback.len() as f32;
167    }
168
169    pub fn has_critical_issues(&self) -> bool {
170        self.dimension_feedback.iter().any(|f| f.score < 0.5)
171    }
172
173    pub fn worst_dimension(&self) -> Option<&DimensionFeedback> {
174        self.dimension_feedback.iter().min_by(|a, b| {
175            a.score
176                .partial_cmp(&b.score)
177                .unwrap_or(std::cmp::Ordering::Equal)
178        })
179    }
180}
181
/// Record of a single refinement iteration: the content going in, the content
/// coming out, and the feedback that drove the rewrite.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RefineIteration {
    /// Iteration number (presumably matches `feedback.iteration` — confirm).
    pub iteration: usize,
    /// Content before this iteration
    pub input: String,
    /// Content after this iteration
    pub output: String,
    /// Feedback that guided this iteration
    pub feedback: IterationFeedback,
    /// Improvement from previous iteration. Presumably a difference of
    /// 0.0 - 1.0 quality scores, so it may be negative — confirm with the producer.
    pub quality_delta: f32,
}
195
/// Result of the self-refine process: original and final content, their
/// quality scores, the per-iteration history and why the loop stopped.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RefineResult {
    /// Original input
    pub original: String,
    /// Final refined output
    pub refined: String,
    /// Quality of original (0.0 - 1.0)
    pub original_quality: f32,
    /// Quality of refined (0.0 - 1.0)
    pub refined_quality: f32,
    /// Total improvement. Presumably `refined_quality - original_quality`
    /// (confirm against the producer); `format_summary` shows it multiplied by 100.
    pub improvement: f32,
    /// All iterations
    pub iterations: Vec<RefineIteration>,
    /// Why refinement stopped
    pub stop_reason: StopReason,
    /// Total tokens used
    pub total_tokens: usize,
}
216
/// Why a self-refine run stopped iterating.
///
/// `RefineResult::format_summary` prints it by variant name via `{:?}`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum StopReason {
    /// Reached target quality (cf. `RefineConfig::target_quality`)
    TargetReached,
    /// Max iterations hit (cf. `RefineConfig::max_iterations`)
    MaxIterations,
    /// Improvement below threshold (cf. `RefineConfig::min_improvement_threshold`)
    DiminishingReturns,
    /// No issues found
    NoIssuesFound,
    /// Error during refinement
    Error,
}
230
231impl RefineResult {
232    pub fn improvement_percentage(&self) -> f32 {
233        if self.original_quality > 0.0 {
234            ((self.refined_quality - self.original_quality) / self.original_quality) * 100.0
235        } else {
236            0.0
237        }
238    }
239
240    pub fn format_summary(&self) -> String {
241        format!(
242            "Self-Refine: {} iterations, {:.1}% → {:.1}% (+{:.1}%), stopped: {:?}",
243            self.iterations.len(),
244            self.original_quality * 100.0,
245            self.refined_quality * 100.0,
246            self.improvement * 100.0,
247            self.stop_reason
248        )
249    }
250}
251
/// Prompt templates for self-refine.
///
/// Stateless namespace: every template is an associated function that builds
/// a prompt `String` from its arguments.
pub struct RefinePrompts;
254
255impl RefinePrompts {
256    /// Generate feedback prompt
257    pub fn feedback(content: &str, problem: &str, dimensions: &[FeedbackDimension]) -> String {
258        let dimension_prompts: String = dimensions
259            .iter()
260            .enumerate()
261            .map(|(i, d)| format!("{}. {}", i + 1, d.prompt_question()))
262            .collect::<Vec<_>>()
263            .join("\n");
264
265        format!(
266            r#"You are a critical reviewer evaluating the following content.
267
268ORIGINAL PROBLEM/TASK:
269{problem}
270
271CONTENT TO REVIEW:
272{content}
273
274Evaluate on these dimensions:
275{dimension_prompts}
276
277For each dimension, provide:
2781. Score (0.0 - 1.0, where 1.0 is perfect)
2792. Specific issues found
2803. Concrete suggestions for improvement
281
282Then provide an overall improvement plan.
283
284Respond in JSON:
285{{
286    "dimensions": [
287        {{
288            "dimension": "dimension_name",
289            "score": 0.0-1.0,
290            "issues": ["issue1", "issue2"],
291            "suggestions": ["suggestion1", "suggestion2"]
292        }}
293    ],
294    "overall_score": 0.0-1.0,
295    "improvement_plan": "Detailed plan to address the issues..."
296}}"#,
297            problem = problem,
298            content = content,
299            dimension_prompts = dimension_prompts
300        )
301    }
302
303    /// Generate refinement prompt
304    pub fn refine(content: &str, problem: &str, feedback: &IterationFeedback) -> String {
305        let issues: Vec<String> = feedback
306            .dimension_feedback
307            .iter()
308            .flat_map(|f| f.issues.clone())
309            .collect();
310
311        let suggestions: Vec<String> = feedback
312            .dimension_feedback
313            .iter()
314            .flat_map(|f| f.suggestions.clone())
315            .collect();
316
317        format!(
318            r#"Refine the following content based on the feedback provided.
319
320ORIGINAL PROBLEM/TASK:
321{problem}
322
323CONTENT TO REFINE:
324{content}
325
326ISSUES IDENTIFIED:
327{issues}
328
329SUGGESTIONS FOR IMPROVEMENT:
330{suggestions}
331
332IMPROVEMENT PLAN:
333{plan}
334
335Provide the refined version that addresses ALL the issues and incorporates ALL the suggestions.
336Maintain the same format and structure, but improve the quality."#,
337            problem = problem,
338            content = content,
339            issues = issues
340                .iter()
341                .map(|i| format!("- {}", i))
342                .collect::<Vec<_>>()
343                .join("\n"),
344            suggestions = suggestions
345                .iter()
346                .map(|s| format!("- {}", s))
347                .collect::<Vec<_>>()
348                .join("\n"),
349            plan = feedback.improvement_plan
350        )
351    }
352
353    /// BrutalHonesty-specific feedback prompt
354    pub fn brutal_honesty_feedback(content: &str, claim: &str) -> String {
355        format!(
356            r#"You are the BRUTAL HONESTY reviewer. Your job is to find EVERY flaw.
357
358CLAIM/ARGUMENT BEING ANALYZED:
359{claim}
360
361CURRENT ANALYSIS:
362{content}
363
364Be RUTHLESSLY CRITICAL. Evaluate:
365
3661. HONESTY: Is the analysis truthful without exaggeration or false modesty?
367   - Are strengths overstated?
368   - Are weaknesses downplayed?
369   - Is uncertainty properly communicated?
370
3712. COMPLETENESS: What is MISSING from the analysis?
372   - What perspectives weren't considered?
373   - What evidence was overlooked?
374   - What counterarguments weren't addressed?
375
3763. BIAS DETECTION: What BIASES are present?
377   - Confirmation bias (only seeing supporting evidence)?
378   - Authority bias (over-relying on sources)?
379   - Recency bias (ignoring historical context)?
380
3814. WEAKNESS IDENTIFICATION: What are the WEAKNESSES?
382   - In the reasoning?
383   - In the evidence?
384   - In the conclusions?
385
3865. DEVIL'S ADVOCATE: Argue the OPPOSITE position
387   - What would a critic say?
388   - How could this be wrong?
389
390Score each dimension 0.0-1.0 and provide specific improvements.
391Respond in JSON format."#,
392            claim = claim,
393            content = content
394        )
395    }
396}
397
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds feedback for `dimension` with one issue and one suggestion, both
    /// labelled with the dimension's name.
    fn dim(dimension: FeedbackDimension, score: f32) -> DimensionFeedback {
        DimensionFeedback {
            dimension,
            score,
            issues: vec![format!("{:?} issue", dimension)],
            suggestions: vec![format!("{:?} fix", dimension)],
        }
    }

    /// Wraps `dimension_feedback` in an iteration whose overall score has not
    /// been computed yet.
    fn iteration(dimension_feedback: Vec<DimensionFeedback>) -> IterationFeedback {
        IterationFeedback {
            iteration: 1,
            dimension_feedback,
            overall_score: 0.0,
            improvement_plan: "Fix issues".into(),
            should_continue: true,
        }
    }

    #[test]
    fn test_refine_config_default() {
        let config = RefineConfig::default();
        assert_eq!(config.max_iterations, 3);
        assert!(!config.feedback_dimensions.is_empty());
    }

    #[test]
    fn test_brutal_honesty_config() {
        let config = RefineConfig::brutal_honesty();
        assert_eq!(config.max_iterations, 5);
        assert!(config
            .feedback_dimensions
            .contains(&FeedbackDimension::Honesty));
        assert!(config
            .feedback_dimensions
            .contains(&FeedbackDimension::BiasDetection));
    }

    #[test]
    fn test_code_config_inherits_defaults() {
        let config = RefineConfig::code();
        let default = RefineConfig::default();
        assert_eq!(config.max_iterations, 4);
        assert_eq!(config.target_quality, default.target_quality);
        assert_eq!(
            config.min_improvement_threshold,
            default.min_improvement_threshold
        );
        assert_eq!(config.preserve_reasoning, default.preserve_reasoning);
        assert!(config
            .feedback_dimensions
            .contains(&FeedbackDimension::EdgeCases));
    }

    #[test]
    fn test_iteration_feedback() {
        let mut feedback = IterationFeedback {
            iteration: 1,
            dimension_feedback: vec![
                DimensionFeedback {
                    dimension: FeedbackDimension::Correctness,
                    score: 0.8,
                    issues: vec!["Minor error".into()],
                    suggestions: vec!["Fix error".into()],
                },
                DimensionFeedback {
                    dimension: FeedbackDimension::Clarity,
                    score: 0.6,
                    issues: vec!["Unclear section".into()],
                    suggestions: vec!["Rewrite section".into()],
                },
            ],
            overall_score: 0.0,
            improvement_plan: "Fix issues".into(),
            should_continue: true,
        };

        feedback.compute_overall_score();
        assert!((feedback.overall_score - 0.7).abs() < 0.01);
        assert!(!feedback.has_critical_issues());
    }

    #[test]
    fn test_empty_iteration_feedback() {
        let mut feedback = iteration(vec![]);
        // Start from a non-zero value to prove the empty case overwrites it.
        feedback.overall_score = 0.5;
        feedback.compute_overall_score();
        assert_eq!(feedback.overall_score, 0.0);
        assert!(!feedback.has_critical_issues());
        assert!(feedback.worst_dimension().is_none());
    }

    #[test]
    fn test_worst_dimension_and_critical_issues() {
        let feedback = iteration(vec![
            dim(FeedbackDimension::Correctness, 0.9),
            dim(FeedbackDimension::Clarity, 0.3),
            dim(FeedbackDimension::Coherence, 0.7),
        ]);
        let worst = feedback
            .worst_dimension()
            .expect("non-empty feedback always has a worst dimension");
        assert_eq!(worst.dimension, FeedbackDimension::Clarity);
        assert!(feedback.has_critical_issues());
    }

    #[test]
    fn test_dimension_prompts() {
        let q = FeedbackDimension::Honesty.prompt_question();
        assert!(q.contains("honest"));

        let q = FeedbackDimension::BiasDetection.prompt_question();
        assert!(q.contains("bias"));
    }

    #[test]
    fn test_feedback_prompt_includes_inputs_and_questions() {
        let prompt = RefinePrompts::feedback(
            "draft answer",
            "the task",
            &[FeedbackDimension::Correctness, FeedbackDimension::Clarity],
        );
        assert!(prompt.contains("draft answer"));
        assert!(prompt.contains("the task"));
        assert!(prompt.contains(FeedbackDimension::Correctness.prompt_question()));
        assert!(prompt.contains(FeedbackDimension::Clarity.prompt_question()));
    }

    #[test]
    fn test_refine_prompt_includes_feedback() {
        let feedback = iteration(vec![dim(FeedbackDimension::Correctness, 0.4)]);
        let prompt = RefinePrompts::refine("draft answer", "the task", &feedback);
        assert!(prompt.contains("draft answer"));
        assert!(prompt.contains("- Correctness issue"));
        assert!(prompt.contains("- Correctness fix"));
        assert!(prompt.contains("Fix issues"));
    }

    #[test]
    fn test_refine_result_summary() {
        let result = RefineResult {
            original: "Original".into(),
            refined: "Refined".into(),
            original_quality: 0.6,
            refined_quality: 0.85,
            improvement: 0.25,
            iterations: vec![],
            stop_reason: StopReason::TargetReached,
            total_tokens: 1000,
        };

        let summary = result.format_summary();
        assert!(summary.contains("TargetReached"));
        assert!(summary.contains("(+25.0%)"));
        assert!(result.improvement_percentage() > 40.0);
    }

    #[test]
    fn test_improvement_percentage_zero_baseline() {
        let result = RefineResult {
            original: "Original".into(),
            refined: "Refined".into(),
            original_quality: 0.0,
            refined_quality: 0.5,
            improvement: 0.5,
            iterations: vec![],
            stop_reason: StopReason::MaxIterations,
            total_tokens: 0,
        };

        // A relative change from a zero baseline is undefined and reported as 0.
        assert_eq!(result.improvement_percentage(), 0.0);
    }
}
475}