reasonkit/thinktool/
prm.rs

1//! # Process Reward Model (PRM) for Step-by-Step Verification
2//!
3//! Implements step-level verification based on Math-Shepherd research
4//! achieving +6.2% GSM8K improvement through granular reasoning validation.
5//!
6//! ## Scientific Foundation
7//!
8//! Based on:
9//! - Math-Shepherd (Wang et al., 2024): Process reward models for math reasoning
10//! - Let's Verify Step by Step (Lightman et al., 2023): Step-level human verification
11//!
12//! ## Key Concepts
13//!
14//! - **Outcome Reward Model (ORM)**: Scores only final answer correctness
15//! - **Process Reward Model (PRM)**: Scores each reasoning step independently
16//!
17//! PRM advantages:
18//! 1. Better credit assignment - identifies WHERE reasoning went wrong
19//! 2. More training signal - learns from partial success
20//! 3. Improved calibration - confidence per step
21//!
22//! ## Usage
23//!
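//! ```rust,ignore
//! use reasonkit::thinktool::prm::{PrmResult, StepParser, StepScore, VerificationPrompts};
//!
//! let steps = StepParser::parse_auto("Step 1: Given x + 2 = 5\nStep 2: x = 5 - 2 = 3");
//! // For each step, send `VerificationPrompts::math_step(step, context, problem)` to an
//! // LLM verifier and build one `StepScore` from its JSON response (caller-provided).
//! let step_scores: Vec<StepScore> = verify_steps_with_llm(&steps).await?;
//! let result = PrmResult::compute(step_scores);
//! ```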
31
32use serde::{Deserialize, Serialize};
33
34/// Individual step score from PRM
35#[derive(Debug, Clone, Serialize, Deserialize)]
36pub struct StepScore {
37    /// Step index (0-based)
38    pub step_index: usize,
39    /// Step content
40    pub step_content: String,
41    /// Correctness probability (0.0 - 1.0)
42    pub correctness: f32,
43    /// Logical validity score
44    pub logical_validity: f32,
45    /// Relevance to problem score
46    pub relevance: f32,
47    /// Identified issues (if any)
48    pub issues: Vec<StepIssue>,
49    /// Whether this step should be revised
50    pub needs_revision: bool,
51}
52
53/// Issue identified in a reasoning step
54#[derive(Debug, Clone, Serialize, Deserialize)]
55pub struct StepIssue {
56    pub issue_type: IssueType,
57    pub description: String,
58    pub severity: Severity,
59    pub suggested_fix: Option<String>,
60}
61
62#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
63pub enum IssueType {
64    ArithmeticError,
65    LogicalFallacy,
66    MissingJustification,
67    InvalidAssumption,
68    Irrelevant,
69    SkippedStep,
70    CircularReasoning,
71    Contradiction,
72}
73
74#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
75pub enum Severity {
76    Low,      // Minor issue, doesn't affect correctness
77    Medium,   // Could lead to errors downstream
78    High,     // Likely causes incorrect final answer
79    Critical, // Definitely invalidates the reasoning
80}
81
82/// Result of PRM evaluation
83#[derive(Debug, Clone, Serialize, Deserialize)]
84pub struct PrmResult {
85    /// Scores for each step
86    pub step_scores: Vec<StepScore>,
87    /// Overall process score (product of step scores)
88    pub overall_score: f32,
89    /// First problematic step index (if any)
90    pub first_error_step: Option<usize>,
91    /// Confidence in the final answer
92    pub final_answer_confidence: f32,
93    /// Whether the reasoning chain is sound
94    pub is_sound: bool,
95    /// Aggregated metrics
96    pub metrics: PrmMetrics,
97}
98
99#[derive(Debug, Clone, Default, Serialize, Deserialize)]
100pub struct PrmMetrics {
101    pub total_steps: usize,
102    pub correct_steps: usize,
103    pub avg_correctness: f32,
104    pub avg_logical_validity: f32,
105    pub avg_relevance: f32,
106    pub critical_issues: usize,
107}
108
109impl PrmResult {
110    pub fn compute(step_scores: Vec<StepScore>) -> Self {
111        if step_scores.is_empty() {
112            return Self {
113                step_scores: vec![],
114                overall_score: 0.0,
115                first_error_step: None,
116                final_answer_confidence: 0.0,
117                is_sound: false,
118                metrics: PrmMetrics::default(),
119            };
120        }
121
122        // Find first problematic step
123        let first_error_step = step_scores
124            .iter()
125            .position(|s| s.needs_revision || s.correctness < 0.5);
126
127        // Overall score = product of correctness (with floor)
128        let overall_score = step_scores
129            .iter()
130            .map(|s| s.correctness.max(0.01))
131            .product::<f32>();
132
133        // Is sound if no critical issues and all steps >= 0.6 correctness
134        let critical_issues = step_scores
135            .iter()
136            .flat_map(|s| s.issues.iter())
137            .filter(|i| i.severity == Severity::Critical)
138            .count();
139
140        let is_sound = critical_issues == 0 && step_scores.iter().all(|s| s.correctness >= 0.6);
141
142        // Final answer confidence considers path dependency
143        let final_answer_confidence = if is_sound {
144            step_scores.last().map(|s| s.correctness).unwrap_or(0.0) * overall_score.sqrt()
145        } else {
146            overall_score * 0.5 // Penalize unsound chains
147        };
148
149        let total_steps = step_scores.len();
150        let correct_steps = step_scores.iter().filter(|s| s.correctness >= 0.7).count();
151
152        let metrics = PrmMetrics {
153            total_steps,
154            correct_steps,
155            avg_correctness: step_scores.iter().map(|s| s.correctness).sum::<f32>()
156                / total_steps as f32,
157            avg_logical_validity: step_scores.iter().map(|s| s.logical_validity).sum::<f32>()
158                / total_steps as f32,
159            avg_relevance: step_scores.iter().map(|s| s.relevance).sum::<f32>()
160                / total_steps as f32,
161            critical_issues,
162        };
163
164        Self {
165            step_scores,
166            overall_score,
167            first_error_step,
168            final_answer_confidence,
169            is_sound,
170            metrics,
171        }
172    }
173}
174
175/// Process Reward Model configuration
176#[derive(Debug, Clone, Serialize, Deserialize)]
177pub struct PrmConfig {
178    /// Minimum step correctness to continue
179    pub min_step_correctness: f32,
180    /// Whether to halt on critical issues
181    pub halt_on_critical: bool,
182    /// Maximum steps to evaluate
183    pub max_steps: usize,
184    /// Verification strategy
185    pub strategy: VerificationStrategy,
186}
187
188impl Default for PrmConfig {
189    fn default() -> Self {
190        Self {
191            min_step_correctness: 0.5,
192            halt_on_critical: true,
193            max_steps: 50,
194            strategy: VerificationStrategy::Sequential,
195        }
196    }
197}
198
199#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
200pub enum VerificationStrategy {
201    /// Verify steps one by one
202    Sequential,
203    /// Verify all steps in parallel
204    Parallel,
205    /// Verify in batches
206    Batched { batch_size: usize },
207    /// Only verify final step (ORM fallback)
208    FinalOnly,
209}
210
/// Verification prompt templates for different step types.
///
/// Every prompt ends with the same JSON response template, whose field and enum
/// names mirror [`StepIssue`], [`IssueType`] and [`Severity`].
pub struct VerificationPrompts;

impl VerificationPrompts {
    /// JSON response template shared by all verification prompts.
    ///
    /// `issue_type` and `severity` list the exact variant names of [`IssueType`] and
    /// [`Severity`], so each reported issue can be mapped onto a [`StepIssue`].
    const RESPONSE_FORMAT: &'static str = r#"Respond in JSON:
{
    "correctness": 0.0-1.0,
    "logical_validity": 0.0-1.0,
    "relevance": 0.0-1.0,
    "issues": [
        {
            "issue_type": "ArithmeticError|LogicalFallacy|MissingJustification|InvalidAssumption|Irrelevant|SkippedStep|CircularReasoning|Contradiction",
            "description": "...",
            "severity": "Low|Medium|High|Critical",
            "suggested_fix": "..." or null
        }
    ],
    "needs_revision": true/false
}"#;

    /// Generates the verification prompt for one step of a math solution.
    ///
    /// * `step` - the step to verify.
    /// * `context` - the reasoning that precedes the step.
    /// * `problem` - the original problem statement.
    pub fn math_step(step: &str, context: &str, problem: &str) -> String {
        format!(
            r#"You are a mathematical reasoning verifier. Evaluate the following reasoning step.

PROBLEM: {problem}

PREVIOUS CONTEXT:
{context}

STEP TO VERIFY:
{step}

Evaluate this step on three dimensions (0.0-1.0):
1. CORRECTNESS: Is the mathematical operation/statement correct?
2. LOGICAL_VALIDITY: Does it follow logically from the previous steps?
3. RELEVANCE: Does it contribute to solving the problem?

Identify any issues:
- Arithmetic errors
- Invalid assumptions
- Missing justifications
- Logical fallacies

{response_format}"#,
            problem = problem,
            context = context,
            step = step,
            response_format = Self::RESPONSE_FORMAT
        )
    }

    /// Generates the verification prompt for one step of a logical argument.
    ///
    /// * `step` - the step to verify.
    /// * `context` - the reasoning that precedes the step.
    /// * `claim` - the claim being argued for or against.
    ///
    /// Uses the same response template as [`VerificationPrompts::math_step`], so the
    /// model is told the exact `issue_type`/`severity` vocabulary instead of an
    /// unspecified `"issues": [...]`.
    pub fn logic_step(step: &str, context: &str, claim: &str) -> String {
        format!(
            r#"You are a logical reasoning verifier using formal logic principles.

CLAIM BEING ANALYZED: {claim}

PRIOR REASONING:
{context}

STEP TO VERIFY:
{step}

Evaluate using Toulmin model components:
- Does it provide valid GROUNDS (evidence)?
- Does it provide valid WARRANT (logical connection)?
- Are there unstated but necessary BACKING assumptions?
- What REBUTTALS might apply?

Rate on three dimensions (0.0-1.0):
1. CORRECTNESS: Is the logical step valid?
2. LOGICAL_VALIDITY: Is the inference sound?
3. RELEVANCE: Does it support or refute the claim?

{response_format}"#,
            claim = claim,
            context = context,
            step = step,
            response_format = Self::RESPONSE_FORMAT
        )
    }
}
298
/// Step parser to extract reasoning steps from LLM output.
pub struct StepParser;

impl StepParser {
    /// Sentences of this many bytes or fewer (after trimming) are dropped as noise.
    const MIN_SENTENCE_LEN: usize = 10;

    /// Parses steps introduced by explicit markers, one marker per line.
    ///
    /// A trimmed line starts a new step when it begins with:
    /// - a number followed by `.`, `)` or `:` and then whitespace or end of line
    ///   (`1. …`, `2) …`, `3: …`);
    /// - `step ` in any ASCII case (`Step 1: …`);
    /// - a `- ` or `* ` bullet.
    ///
    /// Every other non-empty line is appended, space-separated, to the current step.
    /// As a result, wrapped lines and continuation lines such as `20 apples plus 5 apples`
    /// stay with the step they belong to. Text before the first marker becomes its own
    /// step, and blank lines are ignored.
    pub fn parse_numbered(text: &str) -> Vec<String> {
        let mut steps = Vec::new();
        let mut current_step = String::new();

        for line in text.lines() {
            let trimmed = line.trim();
            if trimmed.is_empty() {
                continue;
            }

            if Self::is_step_marker(trimmed) && !current_step.is_empty() {
                steps.push(std::mem::take(&mut current_step));
            }

            if !current_step.is_empty() {
                current_step.push(' ');
            }
            current_step.push_str(trimmed);
        }

        if !current_step.is_empty() {
            steps.push(current_step);
        }

        steps
    }

    /// Returns `true` if an already-trimmed line opens a new step.
    fn is_step_marker(line: &str) -> bool {
        let digits = line.bytes().take_while(u8::is_ascii_digit).count();
        if digits > 0 {
            // Require `N.`, `N)` or `N:` followed by whitespace/end of line, so lines like
            // `20 apples …` or `3.14 is …` are not mistaken for new steps.
            let mut rest = line[digits..].chars();
            return matches!(rest.next(), Some('.') | Some(')') | Some(':'))
                && rest.next().map_or(true, char::is_whitespace);
        }

        line.get(..5)
            .map_or(false, |prefix| prefix.eq_ignore_ascii_case("step "))
            || line.starts_with("- ")
            || line.starts_with("* ")
    }

    /// Splits text into sentences ending in `.`, `!` or `?`.
    ///
    /// A terminator only ends a sentence when it is followed by whitespace or the end
    /// of the text. Decimals (`2.5`), ellipses and dotted numbers therefore stay whole
    /// instead of being cut into fragments. Trailing text without a terminator forms a
    /// final sentence. Sentences of 10 bytes or fewer after trimming are discarded.
    pub fn parse_sentences(text: &str) -> Vec<String> {
        let mut sentences = Vec::new();
        let mut current = String::new();
        let mut chars = text.chars().peekable();

        while let Some(c) = chars.next() {
            current.push(c);

            let ends_sentence = matches!(c, '.' | '!' | '?')
                && chars.peek().map_or(true, |next| next.is_whitespace());
            if ends_sentence {
                Self::push_sentence(&mut sentences, &current);
                current.clear();
            }
        }

        Self::push_sentence(&mut sentences, &current);
        sentences
    }

    /// Pushes the trimmed `candidate` if it is longer than `MIN_SENTENCE_LEN` bytes.
    fn push_sentence(sentences: &mut Vec<String>, candidate: &str) {
        let trimmed = candidate.trim();
        if trimmed.len() > Self::MIN_SENTENCE_LEN {
            sentences.push(trimmed.to_string());
        }
    }

    /// Detects the format: uses marker-based steps when that yields at least two,
    /// otherwise falls back to sentence splitting.
    pub fn parse_auto(text: &str) -> Vec<String> {
        let numbered = Self::parse_numbered(text);
        if numbered.len() >= 2 {
            return numbered;
        }

        Self::parse_sentences(text)
    }
}
374
/// Best-of-N with PRM reranking.
#[derive(Debug, Clone)]
pub struct PrmReranker {
    /// Number of candidate solutions to generate.
    pub n_candidates: usize,
    /// How per-step scores are combined into a single chain score.
    pub aggregation: ScoreAggregation,
}

/// Strategy for combining per-step scores into one chain score.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ScoreAggregation {
    /// Product of step scores.
    Product,
    /// Minimum step score (the weakest link).
    Minimum,
    /// Weighted average where step `i` (0-based) has weight `i + 1`, so later
    /// steps count more.
    WeightedAverage,
    /// Geometric mean, with each score floored at 0.001.
    GeometricMean,
}

impl Default for PrmReranker {
    /// Five candidates, aggregated with [`ScoreAggregation::Product`].
    fn default() -> Self {
        Self {
            n_candidates: 5,
            aggregation: ScoreAggregation::Product,
        }
    }
}

impl PrmReranker {
    /// Floor applied to each score by [`ScoreAggregation::GeometricMean`], so a single
    /// zero score does not force the mean to exactly 0.
    const GEOMETRIC_FLOOR: f32 = 0.001;

    /// Creates a reranker for `n_candidates` candidates with the default aggregation.
    pub fn new(n_candidates: usize) -> Self {
        Self {
            n_candidates,
            ..Default::default()
        }
    }

    /// Calculates the aggregate score for a reasoning chain using `self.aggregation`.
    ///
    /// Returns `0.0` for an empty slice. NaN inputs never panic. `Minimum` skips NaN
    /// values; a slice containing only NaN yields NaN. `GeometricMean` floors NaN to
    /// 0.001 like any other low score.
    pub fn aggregate_score(&self, step_scores: &[f32]) -> f32 {
        if step_scores.is_empty() {
            return 0.0;
        }

        match self.aggregation {
            ScoreAggregation::Product => step_scores.iter().product(),
            // `f32::min` returns the non-NaN operand, whereas the previous
            // `partial_cmp(..).unwrap()` panicked on NaN.
            ScoreAggregation::Minimum => step_scores
                .iter()
                .copied()
                .reduce(f32::min)
                .unwrap_or(0.0),
            ScoreAggregation::WeightedAverage => {
                // Weights 1, 2, ..., n computed on the fly instead of collected into a Vec.
                let weight_sum: f32 = (1..=step_scores.len()).map(|w| w as f32).sum();
                let weighted: f32 = step_scores
                    .iter()
                    .zip(1usize..)
                    .map(|(s, w)| s * w as f32)
                    .sum();
                weighted / weight_sum
            }
            ScoreAggregation::GeometricMean => {
                // Average in log space: a raw product of many small scores underflows
                // f32 (0.1^50 == 0.0), which would wrongly report a mean of 0.
                let n = step_scores.len() as f32;
                let log_sum: f32 = step_scores
                    .iter()
                    .map(|s| s.max(Self::GEOMETRIC_FLOOR).ln())
                    .sum();
                (log_sum / n).exp()
            }
        }
    }

    /// Reranks solutions by PRM score, best first.
    ///
    /// The sort is stable, so equally scored solutions keep their relative order.
    /// NaN scores rank below every real score. Treating NaN as "equal to everything"
    /// would violate the total order that `sort_by` requires.
    pub fn rerank<T>(&self, solutions: &mut [(T, f32)]) {
        fn rank_key(score: f32) -> f32 {
            if score.is_nan() {
                f32::NEG_INFINITY
            } else {
                score
            }
        }

        solutions.sort_by(|a, b| {
            rank_key(b.1)
                .partial_cmp(&rank_key(a.1))
                .unwrap_or(std::cmp::Ordering::Equal)
        });
    }
}
455
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an issue-free step score that does not need revision.
    fn clean_step(
        index: usize,
        content: &str,
        correctness: f32,
        logical_validity: f32,
        relevance: f32,
    ) -> StepScore {
        StepScore {
            step_index: index,
            step_content: content.to_string(),
            correctness,
            logical_validity,
            relevance,
            issues: Vec::new(),
            needs_revision: false,
        }
    }

    #[test]
    fn test_step_parser_numbered() {
        let text = "\n1. First, identify the given information\n2. Next, set up the equation\n3. Solve for x\n4. Verify the answer\n";

        let parsed = StepParser::parse_numbered(text);

        assert_eq!(parsed.len(), 4);
        assert!(parsed[0].contains("identify"));
        assert!(parsed[2].contains("Solve"));
    }

    #[test]
    fn test_prm_result_computation() {
        let chain = vec![
            clean_step(0, "Step 1", 0.9, 0.95, 0.9),
            clean_step(1, "Step 2", 0.85, 0.9, 0.85),
            clean_step(2, "Step 3", 0.8, 0.85, 0.9),
        ];

        let outcome = PrmResult::compute(chain);

        assert!(outcome.is_sound);
        assert_eq!(outcome.first_error_step, None);
        assert!(outcome.overall_score > 0.5);
        assert_eq!(outcome.metrics.total_steps, 3);
        assert_eq!(outcome.metrics.correct_steps, 3);
    }

    #[test]
    fn test_prm_detects_errors() {
        let arithmetic_slip = StepIssue {
            issue_type: IssueType::ArithmeticError,
            description: "2 + 2 != 5".into(),
            severity: Severity::Critical,
            suggested_fix: Some("2 + 2 = 4".into()),
        };
        let bad_step = StepScore {
            issues: vec![arithmetic_slip],
            needs_revision: true,
            ..clean_step(1, "Bad step", 0.3, 0.4, 0.5)
        };
        let chain = vec![clean_step(0, "Good step", 0.9, 0.9, 0.9), bad_step];

        let outcome = PrmResult::compute(chain);

        assert!(!outcome.is_sound);
        assert_eq!(outcome.first_error_step, Some(1));
        assert_eq!(outcome.metrics.critical_issues, 1);
    }

    #[test]
    fn test_prm_reranker() {
        let reranker = PrmReranker::default();
        let mut candidates = vec![
            ("Solution A", 0.7),
            ("Solution B", 0.9),
            ("Solution C", 0.5),
        ];

        reranker.rerank(&mut candidates);

        let ranked: Vec<&str> = candidates.iter().map(|(name, _)| *name).collect();
        assert_eq!(ranked, ["Solution B", "Solution A", "Solution C"]);
    }

    #[test]
    fn test_score_aggregation() {
        let reranker = PrmReranker {
            n_candidates: 5,
            aggregation: ScoreAggregation::GeometricMean,
        };

        let combined = reranker.aggregate_score(&[0.9, 0.8, 0.7]);

        assert!((combined - 0.797).abs() < 0.01);
    }
}