1use serde::{Deserialize, Serialize};
33
34#[derive(Debug, Clone, Serialize, Deserialize)]
36pub struct StepScore {
37 pub step_index: usize,
39 pub step_content: String,
41 pub correctness: f32,
43 pub logical_validity: f32,
45 pub relevance: f32,
47 pub issues: Vec<StepIssue>,
49 pub needs_revision: bool,
51}
52
53#[derive(Debug, Clone, Serialize, Deserialize)]
55pub struct StepIssue {
56 pub issue_type: IssueType,
57 pub description: String,
58 pub severity: Severity,
59 pub suggested_fix: Option<String>,
60}
61
62#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
63pub enum IssueType {
64 ArithmeticError,
65 LogicalFallacy,
66 MissingJustification,
67 InvalidAssumption,
68 Irrelevant,
69 SkippedStep,
70 CircularReasoning,
71 Contradiction,
72}
73
74#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
75pub enum Severity {
76 Low, Medium, High, Critical, }
81
82#[derive(Debug, Clone, Serialize, Deserialize)]
84pub struct PrmResult {
85 pub step_scores: Vec<StepScore>,
87 pub overall_score: f32,
89 pub first_error_step: Option<usize>,
91 pub final_answer_confidence: f32,
93 pub is_sound: bool,
95 pub metrics: PrmMetrics,
97}
98
99#[derive(Debug, Clone, Default, Serialize, Deserialize)]
100pub struct PrmMetrics {
101 pub total_steps: usize,
102 pub correct_steps: usize,
103 pub avg_correctness: f32,
104 pub avg_logical_validity: f32,
105 pub avg_relevance: f32,
106 pub critical_issues: usize,
107}
108
109impl PrmResult {
110 pub fn compute(step_scores: Vec<StepScore>) -> Self {
111 if step_scores.is_empty() {
112 return Self {
113 step_scores: vec![],
114 overall_score: 0.0,
115 first_error_step: None,
116 final_answer_confidence: 0.0,
117 is_sound: false,
118 metrics: PrmMetrics::default(),
119 };
120 }
121
122 let first_error_step = step_scores
124 .iter()
125 .position(|s| s.needs_revision || s.correctness < 0.5);
126
127 let overall_score = step_scores
129 .iter()
130 .map(|s| s.correctness.max(0.01))
131 .product::<f32>();
132
133 let critical_issues = step_scores
135 .iter()
136 .flat_map(|s| s.issues.iter())
137 .filter(|i| i.severity == Severity::Critical)
138 .count();
139
140 let is_sound = critical_issues == 0 && step_scores.iter().all(|s| s.correctness >= 0.6);
141
142 let final_answer_confidence = if is_sound {
144 step_scores.last().map(|s| s.correctness).unwrap_or(0.0) * overall_score.sqrt()
145 } else {
146 overall_score * 0.5 };
148
149 let total_steps = step_scores.len();
150 let correct_steps = step_scores.iter().filter(|s| s.correctness >= 0.7).count();
151
152 let metrics = PrmMetrics {
153 total_steps,
154 correct_steps,
155 avg_correctness: step_scores.iter().map(|s| s.correctness).sum::<f32>()
156 / total_steps as f32,
157 avg_logical_validity: step_scores.iter().map(|s| s.logical_validity).sum::<f32>()
158 / total_steps as f32,
159 avg_relevance: step_scores.iter().map(|s| s.relevance).sum::<f32>()
160 / total_steps as f32,
161 critical_issues,
162 };
163
164 Self {
165 step_scores,
166 overall_score,
167 first_error_step,
168 final_answer_confidence,
169 is_sound,
170 metrics,
171 }
172 }
173}
174
175#[derive(Debug, Clone, Serialize, Deserialize)]
177pub struct PrmConfig {
178 pub min_step_correctness: f32,
180 pub halt_on_critical: bool,
182 pub max_steps: usize,
184 pub strategy: VerificationStrategy,
186}
187
188impl Default for PrmConfig {
189 fn default() -> Self {
190 Self {
191 min_step_correctness: 0.5,
192 halt_on_critical: true,
193 max_steps: 50,
194 strategy: VerificationStrategy::Sequential,
195 }
196 }
197}
198
199#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
200pub enum VerificationStrategy {
201 Sequential,
203 Parallel,
205 Batched { batch_size: usize },
207 FinalOnly,
209}
210
/// Prompt templates that ask a model to rate one reasoning step and reply in JSON.
pub struct VerificationPrompts;

impl VerificationPrompts {
    /// Builds the prompt for verifying one step of a mathematical solution.
    ///
    /// * `step` - the step to be rated.
    /// * `context` - the reasoning that precedes the step.
    /// * `problem` - the problem being solved.
    ///
    /// The arguments are inserted verbatim. The requested JSON carries the three
    /// ratings, a list of issues (with the accepted `issue_type` and `severity`
    /// names spelled out) and a `needs_revision` flag.
    pub fn math_step(step: &str, context: &str, problem: &str) -> String {
        // `{{` / `}}` are format-string escapes for literal braces.
        format!(
            r#"You are a mathematical reasoning verifier. Evaluate the following reasoning step.

PROBLEM: {problem}

PREVIOUS CONTEXT:
{context}

STEP TO VERIFY:
{step}

Evaluate this step on three dimensions (0.0-1.0):
1. CORRECTNESS: Is the mathematical operation/statement correct?
2. LOGICAL_VALIDITY: Does it follow logically from the previous steps?
3. RELEVANCE: Does it contribute to solving the problem?

Identify any issues:
- Arithmetic errors
- Invalid assumptions
- Missing justifications
- Logical fallacies

Respond in JSON:
{{
 "correctness": 0.0-1.0,
 "logical_validity": 0.0-1.0,
 "relevance": 0.0-1.0,
 "issues": [
 {{
 "issue_type": "ArithmeticError|LogicalFallacy|MissingJustification|InvalidAssumption|Irrelevant|SkippedStep|CircularReasoning|Contradiction",
 "description": "...",
 "severity": "Low|Medium|High|Critical",
 "suggested_fix": "..." or null
 }}
 ],
 "needs_revision": true/false
}}"#,
            problem = problem,
            context = context,
            step = step
        )
    }

    /// Builds the prompt for verifying one step of an argument about a claim.
    ///
    /// * `step` - the step to be rated.
    /// * `context` - the reasoning that precedes the step.
    /// * `claim` - the claim under analysis.
    ///
    /// The arguments are inserted verbatim. Unlike [`Self::math_step`], the
    /// requested JSON leaves the shape of `issues` unspecified (`[...]`).
    pub fn logic_step(step: &str, context: &str, claim: &str) -> String {
        // `{{` / `}}` are format-string escapes for literal braces.
        format!(
            r#"You are a logical reasoning verifier using formal logic principles.

CLAIM BEING ANALYZED: {claim}

PRIOR REASONING:
{context}

STEP TO VERIFY:
{step}

Evaluate using Toulmin model components:
- Does it provide valid GROUNDS (evidence)?
- Does it provide valid WARRANT (logical connection)?
- Are there unstated but necessary BACKING assumptions?
- What REBUTTALS might apply?

Rate on three dimensions (0.0-1.0):
1. CORRECTNESS: Is the logical step valid?
2. LOGICAL_VALIDITY: Is the inference sound?
3. RELEVANCE: Does it support or refute the claim?

Respond in JSON:
{{
 "correctness": 0.0-1.0,
 "logical_validity": 0.0-1.0,
 "relevance": 0.0-1.0,
 "issues": [...],
 "needs_revision": true/false
}}"#,
            claim = claim,
            context = context,
            step = step
        )
    }
}
298
/// Splits free-form reasoning text into individual steps.
pub struct StepParser;

impl StepParser {
    /// Sentence fragments with this many characters or fewer are discarded by
    /// [`Self::parse_sentences`].
    const MIN_SENTENCE_CHARS: usize = 10;

    /// Splits text laid out as a list into one string per list item.
    ///
    /// A line starts a new step when, after trimming, it begins with
    ///
    /// * a list number: ASCII digits followed by `.`, `)` or `:` that is not
    ///   itself followed by a digit (`1.`, `2)`, `3:`), so that lines such as
    ///   `2 + 2 = 4`, `3.14 is pi` or `12:30` stay part of the current step;
    /// * the word `step` and a space, in any letter case; or
    /// * a `- ` or `* ` bullet.
    ///
    /// Any other non-empty line is appended to the current step, separated by
    /// a single space. Blank lines are ignored.
    pub fn parse_numbered(text: &str) -> Vec<String> {
        let mut steps = Vec::new();
        let mut current = String::new();

        for line in text.lines().map(str::trim).filter(|l| !l.is_empty()) {
            if Self::starts_new_step(line) && !current.is_empty() {
                steps.push(std::mem::take(&mut current));
            }
            if !current.is_empty() {
                current.push(' ');
            }
            current.push_str(line);
        }

        if !current.is_empty() {
            steps.push(current);
        }

        steps
    }

    /// Splits text into sentences ending in `.`, `!` or `?`.
    ///
    /// A `.` with a digit on both sides is treated as a decimal point and does
    /// not end the sentence. Each sentence is trimmed, and those of
    /// `MIN_SENTENCE_CHARS` (10) characters or fewer are dropped; the same
    /// applies to any unterminated text at the end.
    pub fn parse_sentences(text: &str) -> Vec<String> {
        let mut steps = Vec::new();
        let mut current = String::new();
        let mut previous: Option<char> = None;
        let mut chars = text.chars().peekable();

        // `peek` is needed for the decimal-point look-ahead, hence no `for`.
        while let Some(c) = chars.next() {
            current.push(c);

            let is_decimal_point = c == '.'
                && previous.map_or(false, |p| p.is_ascii_digit())
                && chars.peek().map_or(false, |n| n.is_ascii_digit());

            if matches!(c, '.' | '!' | '?') && !is_decimal_point {
                Self::push_sentence(&mut steps, &current);
                current.clear();
            }
            previous = Some(c);
        }

        Self::push_sentence(&mut steps, &current);
        steps
    }

    /// Uses [`Self::parse_numbered`] when it finds at least two steps and
    /// falls back to [`Self::parse_sentences`] otherwise.
    pub fn parse_auto(text: &str) -> Vec<String> {
        let numbered = Self::parse_numbered(text);
        if numbered.len() >= 2 {
            numbered
        } else {
            Self::parse_sentences(text)
        }
    }

    /// Reports whether a trimmed line opens a new list item.
    fn starts_new_step(line: &str) -> bool {
        Self::has_list_number(line)
            || line
                .get(..5)
                .map_or(false, |prefix| prefix.eq_ignore_ascii_case("step "))
            || line.starts_with("- ")
            || line.starts_with("* ")
    }

    /// Reports whether a line begins with digits followed by a list delimiter
    /// (`.`, `)` or `:`) that is not the start of a decimal or a time.
    fn has_list_number(line: &str) -> bool {
        let rest = line.trim_start_matches(|c: char| c.is_ascii_digit());
        if rest.len() == line.len() {
            // No leading digit at all.
            return false;
        }
        let mut tail = rest.chars();
        matches!(tail.next(), Some('.' | ')' | ':'))
            && !tail.next().map_or(false, |c| c.is_ascii_digit())
    }

    /// Trims a fragment and keeps it only if it is long enough to be a sentence.
    fn push_sentence(steps: &mut Vec<String>, fragment: &str) {
        let sentence = fragment.trim();
        if sentence.chars().count() > Self::MIN_SENTENCE_CHARS {
            steps.push(sentence.to_string());
        }
    }
}
374
/// Ranks candidate solutions by a score aggregated from their step scores.
#[derive(Debug, Clone)]
pub struct PrmReranker {
    /// Number of candidate solutions. Stored for callers; no method in this
    /// type reads it.
    pub n_candidates: usize,
    /// How [`PrmReranker::aggregate_score`] folds step scores into one value.
    pub aggregation: ScoreAggregation,
}

/// Rule for combining per-step scores into a single solution score.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ScoreAggregation {
    /// Product of all step scores.
    Product,
    /// Smallest step score.
    Minimum,
    /// Average weighted by 1-based step position, so later steps weigh more.
    WeightedAverage,
    /// Geometric mean of the step scores, each floored at 0.001.
    GeometricMean,
}

impl Default for PrmReranker {
    /// Five candidates, aggregated with [`ScoreAggregation::Product`].
    fn default() -> Self {
        Self {
            n_candidates: 5,
            aggregation: ScoreAggregation::Product,
        }
    }
}

impl PrmReranker {
    /// Floor applied to each score before taking the geometric mean, so a
    /// single zero does not force the result to zero.
    const GEOMETRIC_FLOOR: f32 = 0.001;

    /// Creates a reranker for `n_candidates` with the default aggregation.
    pub fn new(n_candidates: usize) -> Self {
        Self {
            n_candidates,
            ..Default::default()
        }
    }

    /// Combines `step_scores` into one score using `self.aggregation`.
    ///
    /// Returns `0.0` for an empty slice. `Minimum` skips NaN entries (and
    /// yields NaN only when every entry is NaN) instead of panicking.
    /// `GeometricMean` is evaluated in log space so that long chains do not
    /// underflow to zero.
    pub fn aggregate_score(&self, step_scores: &[f32]) -> f32 {
        if step_scores.is_empty() {
            return 0.0;
        }

        match self.aggregation {
            ScoreAggregation::Product => step_scores.iter().product(),
            // `f32::min` returns the non-NaN operand, so no comparison can fail.
            ScoreAggregation::Minimum => step_scores
                .iter()
                .copied()
                .reduce(f32::min)
                .unwrap_or(0.0),
            ScoreAggregation::WeightedAverage => {
                let (weighted_sum, weight_sum) = step_scores.iter().enumerate().fold(
                    (0.0_f32, 0.0_f32),
                    |(weighted, total), (index, &score)| {
                        let weight = (index + 1) as f32;
                        (weighted + score * weight, total + weight)
                    },
                );
                weighted_sum / weight_sum
            }
            ScoreAggregation::GeometricMean => {
                // exp(mean(ln s)) equals the n-th root of the product but stays
                // in range where the raw f32 product would underflow.
                let log_sum: f32 = step_scores
                    .iter()
                    .map(|&s| s.max(Self::GEOMETRIC_FLOOR).ln())
                    .sum();
                (log_sum / step_scores.len() as f32).exp()
            }
        }
    }

    /// Sorts `solutions` in place by score, highest first.
    ///
    /// The sort is stable, so equal scores keep their relative order. NaN
    /// scores are placed after every real score; this keeps the comparison a
    /// total order, which `sort_by` requires.
    pub fn rerank<T>(&self, solutions: &mut [(T, f32)])
    where
        T: Clone,
    {
        solutions.sort_by(|a, b| match (a.1.is_nan(), b.1.is_nan()) {
            (false, false) => b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal),
            (true, true) => std::cmp::Ordering::Equal,
            (true, false) => std::cmp::Ordering::Greater,
            (false, true) => std::cmp::Ordering::Less,
        });
    }
}
455
#[cfg(test)]
mod tests {
    use super::*;

    /// Four numbered lines are split into four steps, in order.
    #[test]
    fn test_step_parser_numbered() {
        let text = r#"
1. First, identify the given information
2. Next, set up the equation
3. Solve for x
4. Verify the answer
"#;

        let steps = StepParser::parse_numbered(text);
        assert_eq!(steps.len(), 4);
        assert!(steps[0].contains("identify"));
        assert!(steps[2].contains("Solve"));
    }

    /// A chain of high-scoring steps without issues is sound and error-free.
    #[test]
    fn test_prm_result_computation() {
        let scores = vec![
            StepScore {
                step_index: 0,
                step_content: "Step 1".into(),
                correctness: 0.9,
                logical_validity: 0.95,
                relevance: 0.9,
                issues: vec![],
                needs_revision: false,
            },
            StepScore {
                step_index: 1,
                step_content: "Step 2".into(),
                correctness: 0.85,
                logical_validity: 0.9,
                relevance: 0.85,
                issues: vec![],
                needs_revision: false,
            },
            StepScore {
                step_index: 2,
                step_content: "Step 3".into(),
                correctness: 0.8,
                logical_validity: 0.85,
                relevance: 0.9,
                issues: vec![],
                needs_revision: false,
            },
        ];

        let result = PrmResult::compute(scores);

        assert!(result.is_sound);
        assert!(result.first_error_step.is_none());
        assert!(result.overall_score > 0.5);
        assert_eq!(result.metrics.total_steps, 3);
        assert_eq!(result.metrics.correct_steps, 3);
    }

    /// A low-scoring step with a critical issue makes the chain unsound and is
    /// reported as the first error.
    #[test]
    fn test_prm_detects_errors() {
        let scores = vec![
            StepScore {
                step_index: 0,
                step_content: "Good step".into(),
                correctness: 0.9,
                logical_validity: 0.9,
                relevance: 0.9,
                issues: vec![],
                needs_revision: false,
            },
            StepScore {
                step_index: 1,
                step_content: "Bad step".into(),
                correctness: 0.3,
                logical_validity: 0.4,
                relevance: 0.5,
                issues: vec![StepIssue {
                    issue_type: IssueType::ArithmeticError,
                    description: "2 + 2 != 5".into(),
                    severity: Severity::Critical,
                    suggested_fix: Some("2 + 2 = 4".into()),
                }],
                needs_revision: true,
            },
        ];

        let result = PrmResult::compute(scores);

        assert!(!result.is_sound);
        assert_eq!(result.first_error_step, Some(1));
        assert_eq!(result.metrics.critical_issues, 1);
    }

    /// Solutions are reordered from highest to lowest score.
    #[test]
    fn test_prm_reranker() {
        let reranker = PrmReranker::default();

        let mut solutions = vec![
            ("Solution A", 0.7),
            ("Solution B", 0.9),
            ("Solution C", 0.5),
        ];

        reranker.rerank(&mut solutions);

        assert_eq!(solutions[0].0, "Solution B");
        assert_eq!(solutions[1].0, "Solution A");
        assert_eq!(solutions[2].0, "Solution C");
    }

    /// The geometric mean of 0.9, 0.8 and 0.7 is roughly 0.797.
    #[test]
    fn test_score_aggregation() {
        let reranker = PrmReranker {
            n_candidates: 5,
            aggregation: ScoreAggregation::GeometricMean,
        };

        let scores = vec![0.9, 0.8, 0.7];
        let agg = reranker.aggregate_score(&scores);

        assert!((agg - 0.797).abs() < 0.01);
    }
}