use serde::{Deserialize, Serialize};
/// Verifier scores for one reasoning step.
///
/// `correctness`, `logical_validity` and `relevance` are the three dimensions
/// that the prompts built by `VerificationPrompts` ask to be rated on a
/// 0.0-1.0 scale. Nothing in this type enforces that range.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StepScore {
/// Index of the step; presumably its position in the reasoning chain (TODO confirm against the caller).
pub step_index: usize,
/// Text of the step; presumably the content that was scored (TODO confirm against the caller).
pub step_content: String,
/// Correctness rating. `PrmResult::compute` multiplies these into the overall
/// score and compares them against its 0.5 / 0.6 / 0.7 thresholds.
pub correctness: f32,
/// Logical-validity rating; `PrmResult::compute` only averages it.
pub logical_validity: f32,
/// Relevance rating; `PrmResult::compute` only averages it.
pub relevance: f32,
/// Issues reported for this step; `PrmResult::compute` counts the `Severity::Critical` ones.
pub issues: Vec<StepIssue>,
/// When `true`, `PrmResult::compute` considers this step for `first_error_step`
/// independently of its correctness value.
pub needs_revision: bool,
}
/// One problem reported for a reasoning step.
///
/// The field names match the keys of the `issues` entries in the JSON schema
/// embedded in `VerificationPrompts::math_step`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StepIssue {
/// Category of the problem.
pub issue_type: IssueType,
/// Free-text description of the problem.
pub description: String,
/// How serious the problem is; `PrmResult::compute` only distinguishes `Severity::Critical`.
pub severity: Severity,
/// Optional proposed correction; the `math_step` prompt allows `null` here.
pub suggested_fix: Option<String>,
}
/// Category of a problem found in a reasoning step.
///
/// The variant names are exactly the `issue_type` strings offered in the JSON
/// schema of `VerificationPrompts::math_step`, so renaming a variant changes
/// what can be deserialized from a response that follows that schema.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum IssueType {
ArithmeticError,
LogicalFallacy,
MissingJustification,
InvalidAssumption,
Irrelevant,
SkippedStep,
CircularReasoning,
Contradiction,
}
/// Severity attached to a `StepIssue`.
///
/// The variant names match the `severity` strings offered in the
/// `VerificationPrompts::math_step` schema. The type derives equality only
/// (no ordering), and `PrmResult::compute` singles out `Critical` alone.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum Severity {
Low, Medium, High, Critical, }
/// Chain-level verdict over a sequence of step scores.
///
/// The field descriptions below state what `PrmResult::compute` stores; all
/// fields are public, so values built by other means need not follow them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PrmResult {
/// The per-step scores the result was computed from, in input order.
pub step_scores: Vec<StepScore>,
/// Product of every step's correctness, each floored at 0.01; 0.0 for empty input.
pub overall_score: f32,
/// Position of the first step that has `needs_revision` set or correctness below 0.5.
pub first_error_step: Option<usize>,
/// For a sound chain: last step's correctness times the square root of
/// `overall_score`; otherwise half of `overall_score`.
pub final_answer_confidence: f32,
/// `true` when no issue is `Severity::Critical` and every step's correctness is at least 0.6.
pub is_sound: bool,
/// Summary statistics over the steps.
pub metrics: PrmMetrics,
}
/// Summary statistics stored in `PrmResult::metrics`.
///
/// The derived `Default` (all zeros) is what `PrmResult::compute` returns for
/// an empty list of steps.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct PrmMetrics {
/// Number of steps scored.
pub total_steps: usize,
/// Number of steps whose correctness is at least 0.7.
pub correct_steps: usize,
/// Arithmetic mean of the steps' `correctness`.
pub avg_correctness: f32,
/// Arithmetic mean of the steps' `logical_validity`.
pub avg_logical_validity: f32,
/// Arithmetic mean of the steps' `relevance`.
pub avg_relevance: f32,
/// Number of issues, across all steps, with `Severity::Critical`.
pub critical_issues: usize,
}
impl PrmResult {
    /// Folds per-step scores into a chain-level [`PrmResult`].
    ///
    /// * `overall_score` is the product of all correctness values, each floored
    ///   at 0.01 so a single zero cannot erase the rest of the signal.
    /// * `first_error_step` is the first step flagged `needs_revision` or with
    ///   correctness below 0.5.
    /// * `is_sound` requires zero `Severity::Critical` issues and every
    ///   correctness to be at least 0.6.
    /// * `final_answer_confidence` is `last.correctness * sqrt(overall_score)`
    ///   for a sound chain and `overall_score * 0.5` otherwise.
    ///
    /// An empty input yields an all-zero, unsound result with default metrics.
    pub fn compute(step_scores: Vec<StepScore>) -> Self {
        // Thresholds applied to `StepScore::correctness`.
        const CORRECTNESS_FLOOR: f32 = 0.01;
        const ERROR_BELOW: f32 = 0.5;
        const SOUND_AT_LEAST: f32 = 0.6;
        const CORRECT_AT_LEAST: f32 = 0.7;

        let total_steps = step_scores.len();
        if total_steps == 0 {
            return Self {
                step_scores,
                overall_score: 0.0,
                first_error_step: None,
                final_answer_confidence: 0.0,
                is_sound: false,
                metrics: PrmMetrics::default(),
            };
        }

        let first_error_step = step_scores
            .iter()
            .enumerate()
            .find(|(_, step)| step.needs_revision || step.correctness < ERROR_BELOW)
            .map(|(index, _)| index);

        let overall_score = step_scores
            .iter()
            .fold(1.0_f32, |acc, step| acc * step.correctness.max(CORRECTNESS_FLOOR));

        let critical_issues: usize = step_scores
            .iter()
            .map(|step| {
                step.issues
                    .iter()
                    .filter(|issue| issue.severity == Severity::Critical)
                    .count()
            })
            .sum();

        // Written as `>=` inside `all` so that a NaN correctness counts as unsound.
        let is_sound = critical_issues == 0
            && step_scores
                .iter()
                .all(|step| step.correctness >= SOUND_AT_LEAST);

        let final_answer_confidence = match step_scores.last() {
            Some(last) if is_sound => last.correctness * overall_score.sqrt(),
            _ => overall_score * 0.5,
        };

        let metrics = PrmMetrics {
            total_steps,
            correct_steps: step_scores
                .iter()
                .filter(|step| step.correctness >= CORRECT_AT_LEAST)
                .count(),
            avg_correctness: Self::mean_of(&step_scores, |step| step.correctness),
            avg_logical_validity: Self::mean_of(&step_scores, |step| step.logical_validity),
            avg_relevance: Self::mean_of(&step_scores, |step| step.relevance),
            critical_issues,
        };

        Self {
            step_scores,
            overall_score,
            first_error_step,
            final_answer_confidence,
            is_sound,
            metrics,
        }
    }

    /// Arithmetic mean of one score dimension over `steps` (callers pass a non-empty slice).
    fn mean_of(steps: &[StepScore], dimension: impl Fn(&StepScore) -> f32) -> f32 {
        steps.iter().map(dimension).sum::<f32>() / steps.len() as f32
    }
}
/// Tunable settings for process-reward verification.
///
/// None of these fields is read by the code visible in this file; the
/// descriptions below are inferred from the names and defaults and should be
/// confirmed against the code that consumes the config.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PrmConfig {
/// Presumably the lowest acceptable per-step correctness (default 0.5) - TODO confirm.
pub min_step_correctness: f32,
/// Presumably stops verification at the first critical issue (default `true`) - TODO confirm.
pub halt_on_critical: bool,
/// Presumably an upper bound on the number of steps processed (default 50) - TODO confirm.
pub max_steps: usize,
/// How steps are scheduled for verification (default `VerificationStrategy::Sequential`).
pub strategy: VerificationStrategy,
}
impl Default for PrmConfig {
fn default() -> Self {
Self {
min_step_correctness: 0.5,
halt_on_critical: true,
max_steps: 50,
strategy: VerificationStrategy::Sequential,
}
}
}
/// Scheduling mode stored in `PrmConfig::strategy`.
///
/// No code visible in this file interprets the variants, so their exact
/// semantics (one step at a time, all at once, groups of `batch_size`, or only
/// the final step, as the names suggest) should be confirmed against the
/// consumer of `PrmConfig`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum VerificationStrategy {
Sequential,
Parallel,
Batched { batch_size: usize },
FinalOnly,
}
/// Namespace for the prompt templates used to ask a verifier to score a step.
pub struct VerificationPrompts;

impl VerificationPrompts {
    /// Builds the prompt for verifying one step of a mathematical solution.
    ///
    /// `problem`, `context` and `step` are inserted verbatim (no escaping) into
    /// the template. The requested JSON reply uses the same keys as
    /// `StepScore` / `StepIssue` for the scores, issues and revision flag.
    pub fn math_step(step: &str, context: &str, problem: &str) -> String {
        // Placeholders are captured directly from the identically named parameters.
        format!(
            r#"You are a mathematical reasoning verifier. Evaluate the following reasoning step.
PROBLEM: {problem}
PREVIOUS CONTEXT:
{context}
STEP TO VERIFY:
{step}
Evaluate this step on three dimensions (0.0-1.0):
1. CORRECTNESS: Is the mathematical operation/statement correct?
2. LOGICAL_VALIDITY: Does it follow logically from the previous steps?
3. RELEVANCE: Does it contribute to solving the problem?
Identify any issues:
- Arithmetic errors
- Invalid assumptions
- Missing justifications
- Logical fallacies
Respond in JSON:
{{
"correctness": 0.0-1.0,
"logical_validity": 0.0-1.0,
"relevance": 0.0-1.0,
"issues": [
{{
"issue_type": "ArithmeticError|LogicalFallacy|MissingJustification|InvalidAssumption|Irrelevant|SkippedStep|CircularReasoning|Contradiction",
"description": "...",
"severity": "Low|Medium|High|Critical",
"suggested_fix": "..." or null
}}
],
"needs_revision": true/false
}}"#
        )
    }

    /// Builds the prompt for verifying one step of an argument about `claim`.
    ///
    /// `claim`, `context` and `step` are inserted verbatim (no escaping). Unlike
    /// [`Self::math_step`], the reply schema leaves the shape of `issues`
    /// unspecified (`[...]`).
    pub fn logic_step(step: &str, context: &str, claim: &str) -> String {
        // Placeholders are captured directly from the identically named parameters.
        format!(
            r#"You are a logical reasoning verifier using formal logic principles.
CLAIM BEING ANALYZED: {claim}
PRIOR REASONING:
{context}
STEP TO VERIFY:
{step}
Evaluate using Toulmin model components:
- Does it provide valid GROUNDS (evidence)?
- Does it provide valid WARRANT (logical connection)?
- Are there unstated but necessary BACKING assumptions?
- What REBUTTALS might apply?
Rate on three dimensions (0.0-1.0):
1. CORRECTNESS: Is the logical step valid?
2. LOGICAL_VALIDITY: Is the inference sound?
3. RELEVANCE: Does it support or refute the claim?
Respond in JSON:
{{
"correctness": 0.0-1.0,
"logical_validity": 0.0-1.0,
"relevance": 0.0-1.0,
"issues": [...],
"needs_revision": true/false
}}"#
        )
    }
}
/// Splits free-form reasoning text into individual steps.
pub struct StepParser;

impl StepParser {
    /// Sentence fragments of this many bytes or fewer are discarded by
    /// [`Self::parse_sentences`].
    const MIN_SENTENCE_BYTES: usize = 10;

    /// Splits `text` into steps at list-style line starts.
    ///
    /// A trimmed line opens a new step when it begins with an ASCII digit,
    /// with `step ` (any letter case), or with a `- ` / `* ` bullet. Any other
    /// non-blank line is appended to the current step, separated by a single
    /// space. Blank lines are ignored. Text before the first marker forms a
    /// step of its own.
    pub fn parse_numbered(text: &str) -> Vec<String> {
        let mut steps = Vec::new();
        let mut current_step = String::new();
        for line in text.lines().map(str::trim).filter(|line| !line.is_empty()) {
            if Self::starts_new_step(line) && !current_step.is_empty() {
                // `current_step` is built from trimmed pieces joined by single
                // spaces, so it never carries surrounding whitespace.
                steps.push(std::mem::take(&mut current_step));
            }
            if !current_step.is_empty() {
                current_step.push(' ');
            }
            current_step.push_str(line);
        }
        if !current_step.is_empty() {
            steps.push(current_step);
        }
        steps
    }

    /// Splits `text` into sentences ending in `.`, `!` or `?`.
    ///
    /// A period that is immediately followed by an ASCII digit is treated as a
    /// decimal point or version separator rather than a sentence end, so
    /// `"x is 3.14 here."` stays one sentence. Each sentence is trimmed, and
    /// fragments of 10 bytes or fewer are dropped as noise. Trailing text
    /// without a terminator is kept under the same length rule.
    pub fn parse_sentences(text: &str) -> Vec<String> {
        let mut steps = Vec::new();
        let mut current = String::new();
        let mut chars = text.chars().peekable();
        while let Some(c) = chars.next() {
            current.push(c);
            let ends_sentence = match c {
                '!' | '?' => true,
                '.' => !matches!(chars.peek(), Some(next) if next.is_ascii_digit()),
                _ => false,
            };
            if ends_sentence {
                Self::push_sentence(&mut steps, &current);
                current.clear();
            }
        }
        Self::push_sentence(&mut steps, &current);
        steps
    }

    /// Uses [`Self::parse_numbered`] when it finds at least two steps and
    /// falls back to [`Self::parse_sentences`] otherwise.
    pub fn parse_auto(text: &str) -> Vec<String> {
        let numbered = Self::parse_numbered(text);
        if numbered.len() >= 2 {
            return numbered;
        }
        Self::parse_sentences(text)
    }

    /// Reports whether a trimmed line carries one of the step markers.
    fn starts_new_step(line: &str) -> bool {
        line.starts_with(|c: char| c.is_ascii_digit())
            // Case-insensitive prefix test without allocating a lowercase copy.
            || matches!(line.get(..5), Some(prefix) if prefix.eq_ignore_ascii_case("step "))
            || line.starts_with("- ")
            || line.starts_with("* ")
    }

    /// Appends the trimmed `fragment` to `steps` if it is long enough to keep.
    fn push_sentence(steps: &mut Vec<String>, fragment: &str) {
        let sentence = fragment.trim();
        if sentence.len() > Self::MIN_SENTENCE_BYTES {
            steps.push(sentence.to_string());
        }
    }
}
/// Ranks candidate solutions by their aggregated step scores.
#[derive(Debug, Clone)]
pub struct PrmReranker {
/// Number of candidates; not read by any method visible in this file
/// (presumably consumed by the code that generates candidates - TODO confirm).
pub n_candidates: usize,
/// How `aggregate_score` collapses per-step scores into one number.
pub aggregation: ScoreAggregation,
}
/// Strategy `PrmReranker::aggregate_score` uses to combine per-step scores.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ScoreAggregation {
/// Product of all step scores.
Product,
/// The smallest step score.
Minimum,
/// Average weighted by 1-based position, so later steps weigh more.
WeightedAverage,
/// Geometric mean, with each score floored at 0.001.
GeometricMean,
}
impl Default for PrmReranker {
fn default() -> Self {
Self {
n_candidates: 5,
aggregation: ScoreAggregation::Product,
}
}
}
impl PrmReranker {
    /// Creates a reranker for `n_candidates` candidates using the default
    /// aggregation strategy.
    pub fn new(n_candidates: usize) -> Self {
        Self {
            n_candidates,
            ..Default::default()
        }
    }

    /// Collapses per-step scores into a single score according to
    /// `self.aggregation`.
    ///
    /// Returns 0.0 for an empty slice. Never panics, including on NaN input:
    ///
    /// * `Product` multiplies the scores as given.
    /// * `Minimum` returns the smallest score, skipping NaN values (the result
    ///   is NaN only when every score is NaN).
    /// * `WeightedAverage` weights the score at 1-based position `i` by `i`.
    /// * `GeometricMean` floors each score at 0.001 (NaN is floored too) and
    ///   averages in the log domain, so long chains of small scores do not
    ///   underflow to zero.
    pub fn aggregate_score(&self, step_scores: &[f32]) -> f32 {
        if step_scores.is_empty() {
            return 0.0;
        }
        match self.aggregation {
            ScoreAggregation::Product => step_scores.iter().product(),
            // `f32::min` returns the non-NaN operand, so a stray NaN cannot panic
            // the way `partial_cmp(..).unwrap()` would.
            ScoreAggregation::Minimum => step_scores
                .iter()
                .copied()
                .reduce(f32::min)
                .unwrap_or(0.0),
            ScoreAggregation::WeightedAverage => {
                let weighted_sum: f32 = step_scores
                    .iter()
                    .enumerate()
                    .map(|(index, score)| score * (index + 1) as f32)
                    .sum();
                let weight_sum: f32 = (1..=step_scores.len()).map(|w| w as f32).sum();
                weighted_sum / weight_sum
            }
            ScoreAggregation::GeometricMean => {
                // exp(mean(ln x)) equals the n-th root of the product but keeps
                // the intermediate value well inside f32 range.
                let n = step_scores.len() as f32;
                let log_sum: f32 = step_scores.iter().map(|s| s.max(0.001).ln()).sum();
                (log_sum / n).exp()
            }
        }
    }

    /// Sorts `solutions` in place by score, best (highest) first.
    ///
    /// The sort is stable, so equal scores keep their relative order.
    /// Solutions whose score is NaN are placed after all others, which gives
    /// the sort a consistent ordering even for malformed scores.
    pub fn rerank<T>(&self, solutions: &mut [(T, f32)]) {
        use std::cmp::Ordering;

        solutions.sort_by(|a, b| match (a.1.is_nan(), b.1.is_nan()) {
            (true, true) => Ordering::Equal,
            (true, false) => Ordering::Greater,
            (false, true) => Ordering::Less,
            (false, false) => b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal),
        });
    }
}
// Unit tests for step parsing, result computation and reranking.
#[cfg(test)]
mod tests {
use super::*;
// Four numbered lines must come back as four separate steps, in order.
#[test]
fn test_step_parser_numbered() {
let text = r#"
1. First, identify the given information
2. Next, set up the equation
3. Solve for x
4. Verify the answer
"#;
let steps = StepParser::parse_numbered(text);
assert_eq!(steps.len(), 4);
assert!(steps[0].contains("identify"));
assert!(steps[2].contains("Solve"));
}
// A chain with no issues and every correctness >= 0.8 is sound and has no error step.
#[test]
fn test_prm_result_computation() {
let scores = vec![
StepScore {
step_index: 0,
step_content: "Step 1".into(),
correctness: 0.9,
logical_validity: 0.95,
relevance: 0.9,
issues: vec![],
needs_revision: false,
},
StepScore {
step_index: 1,
step_content: "Step 2".into(),
correctness: 0.85,
logical_validity: 0.9,
relevance: 0.85,
issues: vec![],
needs_revision: false,
},
StepScore {
step_index: 2,
step_content: "Step 3".into(),
correctness: 0.8,
logical_validity: 0.85,
relevance: 0.9,
issues: vec![],
needs_revision: false,
},
];
let result = PrmResult::compute(scores);
assert!(result.is_sound);
assert!(result.first_error_step.is_none());
// 0.9 * 0.85 * 0.8 = 0.612
assert!(result.overall_score > 0.5);
assert_eq!(result.metrics.total_steps, 3);
assert_eq!(result.metrics.correct_steps, 3);
}
// A low-scoring step carrying a Critical issue makes the chain unsound and is
// reported as the first error.
#[test]
fn test_prm_detects_errors() {
let scores = vec![
StepScore {
step_index: 0,
step_content: "Good step".into(),
correctness: 0.9,
logical_validity: 0.9,
relevance: 0.9,
issues: vec![],
needs_revision: false,
},
StepScore {
step_index: 1,
step_content: "Bad step".into(),
correctness: 0.3,
logical_validity: 0.4,
relevance: 0.5,
issues: vec![StepIssue {
issue_type: IssueType::ArithmeticError,
description: "2 + 2 != 5".into(),
severity: Severity::Critical,
suggested_fix: Some("2 + 2 = 4".into()),
}],
needs_revision: true,
},
];
let result = PrmResult::compute(scores);
assert!(!result.is_sound);
assert_eq!(result.first_error_step, Some(1));
assert_eq!(result.metrics.critical_issues, 1);
}
// Reranking orders solutions by descending score.
#[test]
fn test_prm_reranker() {
let reranker = PrmReranker::default();
let mut solutions = vec![
("Solution A", 0.7),
("Solution B", 0.9),
("Solution C", 0.5),
];
reranker.rerank(&mut solutions);
assert_eq!(solutions[0].0, "Solution B");
assert_eq!(solutions[1].0, "Solution A");
assert_eq!(solutions[2].0, "Solution C");
}
// Geometric mean of 0.9, 0.8, 0.7 is the cube root of 0.504, roughly 0.796.
#[test]
fn test_score_aggregation() {
let reranker = PrmReranker {
n_candidates: 5,
aggregation: ScoreAggregation::GeometricMean,
};
let scores = vec![0.9, 0.8, 0.7];
let agg = reranker.aggregate_score(&scores);
assert!((agg - 0.797).abs() < 0.01);
}
}