use super::signals::{
BehavioralSignal, ContextBudgetSignal, MessageContentSignal, StructuralSignal, VisionSignal,
};
use super::types::ScoringWeights;
/// Logistic curve `1 / (1 + e^(-k * (x - center)))`.
///
/// `center` is the input at which the output is exactly `0.5`, and `k` sets how
/// steep the transition is. When the scaled argument `k * (x - center)` is above
/// `500.0` the result is short-circuited to exactly `1.0`. When it is below
/// `-500.0` the result is short-circuited to exactly `0.0`. `exp` is only
/// evaluated inside that window. A `NaN` input yields `NaN`.
pub fn sigmoid(x: f64, center: f64, k: f64) -> f64 {
    const SATURATION: f64 = 500.0;
    match k * (x - center) {
        scaled if scaled > SATURATION => 1.0,
        scaled if scaled < -SATURATION => 0.0,
        scaled => (1.0 + (-scaled).exp()).recip(),
    }
}
/// Linearly interpolates from `a` towards `b` by the blend factor `t`.
///
/// `t` is clamped to `[0.0, 1.0]` first, so out-of-range factors are pinned to the
/// nearest endpoint instead of extrapolating. A `NaN` factor yields `NaN`.
pub fn lerp(a: f64, b: f64, t: f64) -> f64 {
    let fraction = t.clamp(0.0, 1.0);
    let delta = b - a;
    a + delta * fraction
}
/// Blends every routing signal into a single score clamped to `[0.0, 1.0]`.
///
/// Each signal's `normalized()` value is sharpened with [`sigmoid`]. The sharpened
/// values are then combined as a weighted mean using the matching fields of
/// `weights`. The vision signal uses a lower centre (`0.3`) and a steeper slope
/// (`8.0`) than the other signals (`0.5` / `4.0`).
///
/// An absent `vision` or `message` signal is scored as a raw value of `0.0`. Because
/// of the sigmoid, that still contributes a non-zero sharpened value to the mean
/// unless the corresponding weight is `0.0`.
///
/// Returns `0.5` when the weights do not sum to a positive number (this includes a
/// `NaN` sum).
pub fn compute_score(
    structural: &StructuralSignal,
    behavioral: &BehavioralSignal,
    budget: &ContextBudgetSignal,
    vision: Option<&VisionSignal>,
    message: Option<&MessageContentSignal>,
    weights: &ScoringWeights,
) -> f64 {
    // (weight, sharpened signal) pairs, listed in the order the terms are summed.
    let terms: [(f64, f64); 5] = [
        (weights.structural, sigmoid(structural.normalized(), 0.5, 4.0)),
        (weights.behavioral, sigmoid(behavioral.normalized(), 0.5, 4.0)),
        (weights.context_budget, sigmoid(budget.normalized(), 0.5, 4.0)),
        (
            weights.vision,
            sigmoid(vision.map_or(0.0, |v| v.normalized()), 0.3, 8.0),
        ),
        (
            weights.message,
            sigmoid(message.map_or(0.0, |m| m.normalized()), 0.5, 4.0),
        ),
    ];

    // Accumulate the weighted sum and the weight total left to right in one pass.
    let (weighted_sum, weight_total) = terms
        .iter()
        .fold((0.0_f64, 0.0_f64), |(sum, total), &(weight, value)| {
            (sum + weight * value, total + weight)
        });

    if weight_total > 0.0 {
        (weighted_sum / weight_total).clamp(0.0, 1.0)
    } else {
        // Nothing meaningful to average over: fall back to a neutral score.
        0.5
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // ---- sigmoid -------------------------------------------------------------

    #[test]
    fn sigmoid_center() {
        let val = sigmoid(0.5, 0.5, 4.0);
        assert!((val - 0.5).abs() < 1e-6);
    }
    #[test]
    fn sigmoid_high() {
        let val = sigmoid(1.0, 0.5, 4.0);
        assert!(val > 0.8, "sigmoid(1.0) = {}", val);
    }
    #[test]
    fn sigmoid_low() {
        let val = sigmoid(0.0, 0.5, 4.0);
        assert!(val < 0.2, "sigmoid(0.0) = {}", val);
    }
    #[test]
    fn sigmoid_overflow_guard() {
        assert!((sigmoid(1e6, 0.5, 1.0) - 1.0).abs() < 1e-6);
        assert!(sigmoid(-1e6, 0.5, 1.0).abs() < 1e-6);
    }
    /// The logistic curve is point-symmetric about its centre:
    /// `sigmoid(c + d) + sigmoid(c - d) == 1`.
    #[test]
    fn sigmoid_symmetric_about_center() {
        for d in [0.1, 0.25, 0.4] {
            let sum = sigmoid(0.5 + d, 0.5, 4.0) + sigmoid(0.5 - d, 0.5, 4.0);
            assert!((sum - 1.0).abs() < 1e-9, "asymmetric at d={}: {}", d, sum);
        }
    }

    // ---- lerp ----------------------------------------------------------------

    #[test]
    fn lerp_basic() {
        assert!((lerp(0.0, 1.0, 0.5) - 0.5).abs() < 1e-6);
        assert!((lerp(0.0, 10.0, 0.3) - 3.0).abs() < 1e-6);
    }
    #[test]
    fn lerp_clamped() {
        assert!((lerp(0.0, 1.0, 1.5) - 1.0).abs() < 1e-6);
        assert!((lerp(0.0, 1.0, -0.5) - 0.0).abs() < 1e-6);
    }
    /// `t = 0` and `t = 1` land on the endpoints for ordinary inputs.
    #[test]
    fn lerp_endpoints() {
        assert!((lerp(2.0, 6.0, 0.0) - 2.0).abs() < 1e-12);
        assert!((lerp(2.0, 6.0, 1.0) - 6.0).abs() < 1e-12);
        assert!((lerp(5.0, 5.0, 0.7) - 5.0).abs() < 1e-12);
    }

    // ---- compute_score -------------------------------------------------------

    #[test]
    fn compute_score_default_weights() {
        let structural = StructuralSignal::default();
        let behavioral = BehavioralSignal::default();
        let budget = ContextBudgetSignal::default();
        let weights = ScoringWeights::default();
        let score = compute_score(&structural, &behavioral, &budget, None, None, &weights);
        assert!(
            (0.0..=1.0).contains(&score),
            "score out of range: {}",
            score
        );
    }
    #[test]
    fn compute_score_clamped() {
        let structural = StructuralSignal {
            message_count: 100,
            tool_call_count: 50,
            tool_result_count: 50,
            estimated_tokens: 500_000,
            user_message_count: 20,
        };
        let behavioral = BehavioralSignal::default();
        let budget = ContextBudgetSignal {
            estimated_tokens: 500_000,
            accumulated_cost: 10.0,
            budget_limit: Some(5.0),
            context_upgrade_threshold: Some(50_000),
        };
        let weights = ScoringWeights::default();
        let score = compute_score(&structural, &behavioral, &budget, None, None, &weights);
        assert!(
            (0.0..=1.0).contains(&score),
            "score out of range: {}",
            score
        );
    }
    /// With every weight at zero there is nothing to average over, so the
    /// neutral fallback of `0.5` must be returned (covers the `else` branch).
    #[test]
    fn compute_score_zero_total_weight_returns_neutral() {
        let mut weights = ScoringWeights::default();
        weights.structural = 0.0;
        weights.behavioral = 0.0;
        weights.context_budget = 0.0;
        weights.vision = 0.0;
        weights.message = 0.0;
        let score = compute_score(
            &StructuralSignal::default(),
            &BehavioralSignal::default(),
            &ContextBudgetSignal::default(),
            None,
            None,
            &weights,
        );
        assert!(
            (score - 0.5).abs() < 1e-12,
            "zero weights should yield neutral 0.5, got {}",
            score
        );
    }
    #[test]
    fn compute_score_vision_increases_score() {
        let structural = StructuralSignal::default();
        let behavioral = BehavioralSignal::default();
        let budget = ContextBudgetSignal::default();
        let weights = ScoringWeights::default();
        let without_vision = compute_score(&structural, &behavioral, &budget, None, None, &weights);
        let vision = VisionSignal {
            recent_image_count: 2,
            has_image_in_latest_turn: true,
            image_producing_tools: vec!["browse".to_string()],
        };
        let with_vision = compute_score(
            &structural,
            &behavioral,
            &budget,
            Some(&vision),
            None,
            &weights,
        );
        assert!(
            with_vision > without_vision,
            "vision should increase score: {} vs {}",
            with_vision,
            without_vision
        );
    }
    #[test]
    fn compute_score_vision_zero_weight_no_effect() {
        let mut weights = ScoringWeights::default();
        weights.vision = 0.0;
        let structural = StructuralSignal::default();
        let behavioral = BehavioralSignal::default();
        let budget = ContextBudgetSignal::default();
        let vision = VisionSignal {
            recent_image_count: 5,
            has_image_in_latest_turn: true,
            image_producing_tools: vec![],
        };
        let with = compute_score(
            &structural,
            &behavioral,
            &budget,
            Some(&vision),
            None,
            &weights,
        );
        let without = compute_score(&structural, &behavioral, &budget, None, None, &weights);
        assert!(
            (with - without).abs() < 1e-6,
            "vision=0.0 should have no effect: {} vs {}",
            with,
            without
        );
    }
    #[test]
    fn compute_score_message_increases_score() {
        let structural = StructuralSignal::default();
        let behavioral = BehavioralSignal::default();
        let budget = ContextBudgetSignal::default();
        let weights = ScoringWeights::default();
        let without_message =
            compute_score(&structural, &behavioral, &budget, None, None, &weights);
        let msg = MessageContentSignal::from_text(&format!(
            "Debug this:\n```rust\nfn main() {{ panic!() }}\n```\n{}",
            "x".repeat(300)
        ));
        let with_message = compute_score(
            &structural,
            &behavioral,
            &budget,
            None,
            Some(&msg),
            &weights,
        );
        assert!(
            with_message > without_message,
            "message signal should increase score: {} vs {}",
            with_message,
            without_message
        );
    }
}