use serde::{Deserialize, Serialize};
/// Weights and thresholds for the chain-of-thought distillation loss.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReasoningConfig {
    /// Weight of the per-step hidden-state similarity term.
    pub lambda_step: f32,
    /// Weight of the step-to-step transition (trajectory) term.
    pub lambda_trajectory: f32,
    /// Weight of the final-answer agreement term.
    pub lambda_answer: f32,
    /// Cosine-similarity floor; steps below it are counted as degraded.
    pub step_similarity_threshold: f32,
    /// Number of hidden-state elements to sample — NOTE(review): not read by
    /// any code visible in this file; presumably consumed elsewhere — confirm.
    pub hidden_sample_size: usize,
    /// Whether hidden states should be L2-normalized — NOTE(review): not read
    /// by any code visible in this file; confirm against callers.
    pub normalize_hidden: bool,
}
impl Default for ReasoningConfig {
fn default() -> Self {
Self {
lambda_step: 0.3,
lambda_trajectory: 0.5,
lambda_answer: 0.2,
step_similarity_threshold: 0.9,
hidden_sample_size: 256,
normalize_hidden: true,
}
}
}
impl ReasoningConfig {
pub fn math_reasoning() -> Self {
Self {
lambda_step: 0.4,
lambda_trajectory: 0.4,
lambda_answer: 0.4, step_similarity_threshold: 0.85,
hidden_sample_size: 256,
normalize_hidden: true,
}
}
pub fn code_generation() -> Self {
Self {
lambda_step: 0.2,
lambda_trajectory: 0.6, lambda_answer: 0.3,
step_similarity_threshold: 0.92,
hidden_sample_size: 512,
normalize_hidden: true,
}
}
}
/// One step of a reasoning chain, pairing the teacher's and student's
/// hidden-state vectors at the same point in the chain.
#[derive(Debug, Clone)]
pub struct ReasoningStep {
    /// Identifier for this step.
    pub id: String,
    /// Index of the step within its chain (defaults to 0 in `new`).
    pub position: usize,
    /// Teacher-model hidden state for this step.
    pub teacher_hidden: Vec<f32>,
    /// Student-model hidden state for this step.
    pub student_hidden: Vec<f32>,
    /// Semantic category of the step; `None` when unclassified.
    pub step_type: Option<StepType>,
}
/// Semantic category of a reasoning step. `Conclusion`/`Answer` and
/// `Calculation` steps receive extra weight in the step loss.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum StepType {
    /// Problem setup / restatement.
    Setup,
    /// Numeric or symbolic calculation (weighted 1.5x in the step loss).
    Calculation,
    /// Intermediate logical inference.
    Inference,
    /// Concluding step (weighted 2.0x in the step loss).
    Conclusion,
    /// Final-answer step (weighted 2.0x in the step loss).
    Answer,
}
impl ReasoningStep {
pub fn new(id: &str, teacher: Vec<f32>, student: Vec<f32>) -> Self {
Self {
id: id.to_string(),
position: 0,
teacher_hidden: teacher,
student_hidden: student,
step_type: None,
}
}
pub fn with_position(mut self, pos: usize) -> Self {
self.position = pos;
self
}
pub fn with_type(mut self, step_type: StepType) -> Self {
self.step_type = Some(step_type);
self
}
pub fn cosine_similarity(&self) -> f32 {
cosine_similarity(&self.teacher_hidden, &self.student_hidden)
}
pub fn mse(&self) -> f32 {
if self.teacher_hidden.len() != self.student_hidden.len() {
return f32::MAX;
}
let n = self.teacher_hidden.len() as f32;
self.teacher_hidden
.iter()
.zip(&self.student_hidden)
.map(|(t, s)| (t - s).powi(2))
.sum::<f32>()
/ n
}
}
/// Running aggregates accumulated across `ChainOfThoughtLoss::compute` calls.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ReasoningMetrics {
    /// Number of non-empty chains evaluated.
    pub chains_evaluated: usize,
    /// Running mean of the per-chain average step cosine similarity.
    pub avg_step_similarity: f64,
    /// Running mean of the trajectory loss across chains.
    pub avg_trajectory_loss: f64,
    /// Running rate of exact (trimmed) teacher/student answer agreement.
    pub answer_agreement_rate: f64,
    /// Total count of steps whose similarity fell below the configured
    /// threshold, across all chains.
    pub degraded_steps: usize,
}
/// Distillation loss over chain-of-thought reasoning, combining per-step
/// hidden-state alignment, trajectory alignment, and final-answer agreement.
pub struct ChainOfThoughtLoss {
    /// Loss weights and thresholds.
    config: ReasoningConfig,
    /// Aggregates updated on every non-empty `compute` call.
    metrics: ReasoningMetrics,
}
impl ChainOfThoughtLoss {
    /// Creates a loss function with the given config and zeroed metrics.
    pub fn new(config: ReasoningConfig) -> Self {
        Self {
            config,
            metrics: ReasoningMetrics::default(),
        }
    }

    /// Returns the active configuration.
    pub fn config(&self) -> &ReasoningConfig {
        &self.config
    }

    /// Returns the metrics accumulated so far.
    pub fn metrics(&self) -> &ReasoningMetrics {
        &self.metrics
    }

    /// Computes the combined loss for one reasoning chain:
    /// `lambda_step * step + lambda_trajectory * trajectory + lambda_answer * answer`.
    ///
    /// An empty chain returns 0.0 and leaves the metrics untouched.
    /// Updates the running metrics as a side effect.
    pub fn compute(
        &mut self,
        steps: &[ReasoningStep],
        teacher_answer: Option<&str>,
        student_answer: Option<&str>,
    ) -> f32 {
        if steps.is_empty() {
            return 0.0;
        }
        let step_loss = self.compute_step_loss(steps);
        let trajectory_loss = self.compute_trajectory_loss(steps);
        let answer_loss = self.compute_answer_loss(teacher_answer, student_answer);
        // Pass trajectory_loss through so update_metrics can record it;
        // previously `avg_trajectory_loss` was never written anywhere.
        self.update_metrics(steps, trajectory_loss, teacher_answer, student_answer);
        self.config.lambda_step * step_loss
            + self.config.lambda_trajectory * trajectory_loss
            + self.config.lambda_answer * answer_loss
    }

    /// Mean weighted `(1 - cosine_similarity)` over the steps.
    /// Conclusion/Answer steps weigh 2.0, Calculation 1.5, others 1.0.
    fn compute_step_loss(&self, steps: &[ReasoningStep]) -> f32 {
        let total: f32 = steps
            .iter()
            .map(|step| {
                let weight = match step.step_type {
                    Some(StepType::Conclusion) | Some(StepType::Answer) => 2.0,
                    Some(StepType::Calculation) => 1.5,
                    _ => 1.0,
                };
                weight * (1.0 - step.cosine_similarity())
            })
            .sum();
        total / steps.len() as f32
    }

    /// Mean `(1 - cosine_similarity)` between consecutive teacher hidden-state
    /// deltas and the matching student deltas; 0.0 for chains of < 2 steps.
    fn compute_trajectory_loss(&self, steps: &[ReasoningStep]) -> f32 {
        if steps.len() < 2 {
            return 0.0;
        }
        let total: f32 = steps
            .windows(2)
            .map(|pair| {
                let teacher_delta =
                    vector_diff(&pair[1].teacher_hidden, &pair[0].teacher_hidden);
                let student_delta =
                    vector_diff(&pair[1].student_hidden, &pair[0].student_hidden);
                1.0 - cosine_similarity(&teacher_delta, &student_delta)
            })
            .sum();
        total / (steps.len() - 1) as f32
    }

    /// Answer loss: 0.0 for an exact trimmed match, 0.5 when one non-empty
    /// answer contains the other, 1.0 otherwise. A missing answer on either
    /// side scores 0.0 (no signal, not a mismatch).
    fn compute_answer_loss(&self, teacher: Option<&str>, student: Option<&str>) -> f32 {
        match (teacher, student) {
            (Some(t), Some(s)) => {
                let t = t.trim();
                let s = s.trim();
                if t == s {
                    0.0
                } else if !t.is_empty() && !s.is_empty() && (t.contains(s) || s.contains(t)) {
                    // Emptiness guard: `x.contains("")` is always true, so an
                    // empty answer used to receive 0.5 partial credit.
                    0.5
                } else {
                    1.0
                }
            }
            _ => 0.0,
        }
    }

    /// Folds one chain's statistics into the running metrics.
    fn update_metrics(
        &mut self,
        steps: &[ReasoningStep],
        trajectory_loss: f32,
        teacher_answer: Option<&str>,
        student_answer: Option<&str>,
    ) {
        // `compute` filters empty chains already; guard anyway so a future
        // caller cannot produce a NaN average via division by zero.
        if steps.is_empty() {
            return;
        }
        self.metrics.chains_evaluated += 1;
        let n = self.metrics.chains_evaluated as f64;
        let avg_sim: f64 = steps
            .iter()
            .map(|s| s.cosine_similarity() as f64)
            .sum::<f64>()
            / steps.len() as f64;
        self.metrics.avg_step_similarity =
            (self.metrics.avg_step_similarity * (n - 1.0) + avg_sim) / n;
        self.metrics.avg_trajectory_loss =
            (self.metrics.avg_trajectory_loss * (n - 1.0) + trajectory_loss as f64) / n;
        self.metrics.degraded_steps += steps
            .iter()
            .filter(|s| s.cosine_similarity() < self.config.step_similarity_threshold)
            .count();
        if let (Some(t), Some(s)) = (teacher_answer, student_answer) {
            // NOTE(review): the denominator is the total chain count, so
            // chains evaluated without answers dilute this rate; a true
            // per-answered-chain rate would need a dedicated counter in
            // `ReasoningMetrics`.
            let agrees = if t.trim() == s.trim() { 1.0 } else { 0.0 };
            self.metrics.answer_agreement_rate =
                (self.metrics.answer_agreement_rate * (n - 1.0) + agrees) / n;
        }
    }

    /// Resets all running metrics to their zero defaults.
    pub fn reset_metrics(&mut self) {
        self.metrics = ReasoningMetrics::default();
    }

    /// Scalar quality score: 60% average step similarity plus 40% answer
    /// agreement rate.
    pub fn evaluate_quality(&self) -> f32 {
        let sim_score = self.metrics.avg_step_similarity as f32;
        let ans_score = self.metrics.answer_agreement_rate as f32;
        0.6 * sim_score + 0.4 * ans_score
    }
}
/// Cosine similarity of two equal-length vectors, clamped to `[-1, 1]`.
///
/// Returns 0.0 on length mismatch, empty input, or when either vector's
/// L2 norm is below 1e-10.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() || a.is_empty() {
        return 0.0;
    }
    // Accumulate dot product and both squared norms in a single pass.
    let mut dot = 0.0f32;
    let mut sq_a = 0.0f32;
    let mut sq_b = 0.0f32;
    for (&x, &y) in a.iter().zip(b) {
        dot += x * y;
        sq_a += x * x;
        sq_b += y * y;
    }
    let norm_a = sq_a.sqrt();
    let norm_b = sq_b.sqrt();
    if norm_a < 1e-10 || norm_b < 1e-10 {
        0.0
    } else {
        (dot / (norm_a * norm_b)).clamp(-1.0, 1.0)
    }
}
/// Element-wise difference `a - b`, truncated to the shorter input's length.
fn vector_diff(a: &[f32], b: &[f32]) -> Vec<f32> {
    let mut out = Vec::with_capacity(a.len().min(b.len()));
    for (x, y) in a.iter().zip(b.iter()) {
        out.push(x - y);
    }
    out
}
/// Scales `v` to unit L2 norm in place. Vectors whose norm is at or below
/// 1e-10 are left unchanged to avoid division blow-up.
#[allow(dead_code)]
fn normalize_inplace(v: &mut [f32]) {
    let norm = v.iter().fold(0.0f32, |acc, x| acc + x * x).sqrt();
    if norm > 1e-10 {
        v.iter_mut().for_each(|x| *x /= norm);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_reasoning_config() {
        // Default weights must be strictly positive; the math preset puts at
        // least as much weight on the final answer as the default does.
        let base = ReasoningConfig::default();
        assert!(base.lambda_step > 0.0);
        assert!(base.lambda_trajectory > 0.0);
        let math = ReasoningConfig::math_reasoning();
        assert!(math.lambda_answer >= base.lambda_answer);
    }

    #[test]
    fn test_reasoning_step() {
        // Identical teacher/student vectors: similarity 1, MSE ~0.
        let hidden = vec![1.0, 0.0, 0.0];
        let step = ReasoningStep::new("step1", hidden.clone(), hidden);
        assert!((step.cosine_similarity() - 1.0).abs() < 1e-5);
        assert!(step.mse() < 1e-5);
    }

    #[test]
    fn test_cosine_similarity() {
        let x_axis = vec![1.0, 0.0, 0.0];
        let y_axis = vec![0.0, 1.0, 0.0];
        let neg_x = vec![-1.0, 0.0, 0.0];
        // Identical -> 1, orthogonal -> 0, opposite -> -1.
        assert!((cosine_similarity(&x_axis, &x_axis) - 1.0).abs() < 1e-5);
        assert!(cosine_similarity(&x_axis, &y_axis).abs() < 1e-5);
        assert!((cosine_similarity(&x_axis, &neg_x) + 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_cot_loss() {
        let mut loss_fn = ChainOfThoughtLoss::new(ReasoningConfig::default());
        let steps = [
            ReasoningStep::new("step1", vec![1.0, 0.0], vec![0.9, 0.1]),
            ReasoningStep::new("step2", vec![0.0, 1.0], vec![0.1, 0.9]),
        ];
        let loss = loss_fn.compute(&steps, Some("42"), Some("42"));
        // Near-aligned steps with matching answers: small, non-negative loss.
        assert!(loss >= 0.0);
        assert!(loss < 1.0);
    }

    #[test]
    fn test_answer_loss() {
        let loss_fn = ChainOfThoughtLoss::new(ReasoningConfig::default());
        assert_eq!(loss_fn.compute_answer_loss(Some("42"), Some("42")), 0.0);
        assert_eq!(loss_fn.compute_answer_loss(Some("42"), Some("24")), 1.0);
        assert_eq!(loss_fn.compute_answer_loss(None, None), 0.0);
    }

    #[test]
    fn test_trajectory_loss() {
        let loss_fn = ChainOfThoughtLoss::new(ReasoningConfig::default());
        // Teacher and student move identically, so transition loss is ~0.
        let steps = [
            ReasoningStep::new("s1", vec![0.0, 0.0], vec![0.0, 0.0]),
            ReasoningStep::new("s2", vec![1.0, 0.0], vec![1.0, 0.0]),
        ];
        assert!(loss_fn.compute_trajectory_loss(&steps) < 0.1);
    }

    #[test]
    fn test_step_types() {
        let step = ReasoningStep::new("s1", vec![1.0], vec![1.0])
            .with_position(0)
            .with_type(StepType::Conclusion);
        assert_eq!(step.position, 0);
        assert_eq!(step.step_type, Some(StepType::Conclusion));
    }
}