use serde::{Deserialize, Serialize};
/// Weighted cost model that collapses three base-2 exponent figures (`tc`, `sc`, `rw`)
/// into a single scalar score; see [`ScoreFunction::evaluate`] for the exact formula.
///
/// NOTE(review): the prefixes presumably stand for time, space and read-write complexity
/// (suggested by the `time_optimized` / `space_optimized` constructors) — confirm against
/// callers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScoreFunction {
    /// Multiplier applied to `2^tc` in the score.
    pub tc_weight: f64,
    /// Multiplier applied to the space penalty `max(0, 2^sc - 2^sc_target)`.
    pub sc_weight: f64,
    /// Multiplier applied to `2^rw` in the score.
    pub rw_weight: f64,
    /// Exponent threshold: `sc` values strictly above it incur a penalty and make
    /// `exceeds_target` return `true`.
    pub sc_target: f64,
}
impl Default for ScoreFunction {
    /// Balanced configuration: unit weights on the `tc` term and on the space penalty,
    /// the `rw` term switched off, and a space target exponent of `20.0`.
    fn default() -> Self {
        // The read-write term is disabled unless a caller opts in.
        let (tc_weight, sc_weight, rw_weight) = (1.0, 1.0, 0.0);
        let sc_target = 20.0;
        Self {
            tc_weight,
            sc_weight,
            rw_weight,
            sc_target,
        }
    }
}
impl ScoreFunction {
    /// Creates a score function from explicit weights and a space target.
    ///
    /// * `tc_weight` — multiplier for the `2^tc` term of [`evaluate`](Self::evaluate).
    /// * `sc_weight` — multiplier for the space penalty `max(0, 2^sc - 2^sc_target)`.
    /// * `rw_weight` — multiplier for the `2^rw` term.
    /// * `sc_target` — exponent above which `sc` is penalized.
    pub fn new(tc_weight: f64, sc_weight: f64, rw_weight: f64, sc_target: f64) -> Self {
        Self {
            tc_weight,
            sc_weight,
            rw_weight,
            sc_target,
        }
    }

    /// Score function that counts only the `tc` term.
    ///
    /// `sc_target` is `f64::INFINITY`, so no finite `sc` exceeds the target.
    pub fn time_optimized() -> Self {
        Self {
            tc_weight: 1.0,
            sc_weight: 0.0,
            rw_weight: 0.0,
            sc_target: f64::INFINITY,
        }
    }

    /// Score function that only penalizes `sc` overshooting `sc_target`; the `tc` and
    /// `rw` terms are disabled.
    pub fn space_optimized(sc_target: f64) -> Self {
        Self {
            tc_weight: 0.0,
            sc_weight: 1.0,
            rw_weight: 0.0,
            sc_target,
        }
    }

    /// Computes the scalar score
    ///
    /// `tc_weight * 2^tc + rw_weight * 2^rw + sc_weight * max(0, 2^sc - 2^sc_target)`
    ///
    /// All three arguments are base-2 exponents (each is raised as a power of two).
    ///
    /// A term whose weight is exactly zero contributes `0.0` even when its power of two
    /// overflows to `f64::INFINITY` (e.g. `tc >= 1024`). Plain multiplication would give
    /// `0.0 * inf = NaN`, and a NaN score compares false against every other score.
    pub fn evaluate(&self, tc: f64, sc: f64, rw: f64) -> f64 {
        let tc_term = Self::weighted(self.tc_weight, 2_f64.powf(tc));
        let rw_term = Self::weighted(self.rw_weight, 2_f64.powf(rw));
        // `f64::max` returns the non-NaN operand, so `inf - inf` (both `sc` and the target
        // infinite) collapses to a zero penalty here.
        let sc_penalty = (2_f64.powf(sc) - 2_f64.powf(self.sc_target)).max(0.0);
        let sc_term = Self::weighted(self.sc_weight, sc_penalty);
        tc_term + rw_term + sc_term
    }

    /// Multiplies `value` by `weight`, treating a zero weight as "term disabled" so an
    /// infinite `value` cannot turn the product into NaN.
    #[inline]
    fn weighted(weight: f64, value: f64) -> f64 {
        if weight == 0.0 { 0.0 } else { weight * value }
    }

    /// Returns `true` when `sc` is strictly greater than `sc_target`
    /// (the same condition under which `evaluate` applies a positive space penalty).
    #[inline]
    pub fn exceeds_target(&self, sc: f64) -> bool {
        sc > self.sc_target
    }

    /// Returns this score function with `sc_target` replaced.
    #[must_use]
    pub fn with_sc_target(mut self, sc_target: f64) -> Self {
        self.sc_target = sc_target;
        self
    }

    /// Returns this score function with `tc_weight` replaced.
    #[must_use]
    pub fn with_tc_weight(mut self, tc_weight: f64) -> Self {
        self.tc_weight = tc_weight;
        self
    }

    /// Returns this score function with `sc_weight` replaced.
    #[must_use]
    pub fn with_sc_weight(mut self, sc_weight: f64) -> Self {
        self.sc_weight = sc_weight;
        self
    }

    /// Returns this score function with `rw_weight` replaced.
    #[must_use]
    pub fn with_rw_weight(mut self, rw_weight: f64) -> Self {
        self.rw_weight = rw_weight;
        self
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used for all floating-point score comparisons.
    const EPS: f64 = 1e-10;

    /// Asserts `actual` is within [`EPS`] of `expected`, reporting both values on failure.
    #[track_caller]
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < EPS,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    #[test]
    fn test_default_score() {
        let score = ScoreFunction::default();
        assert_eq!(score.tc_weight, 1.0);
        assert_eq!(score.sc_weight, 1.0);
        assert_eq!(score.rw_weight, 0.0);
        assert_eq!(score.sc_target, 20.0);
    }

    #[test]
    fn test_new_score() {
        let score = ScoreFunction::new(2.0, 3.0, 0.5, 15.0);
        assert_eq!(score.tc_weight, 2.0);
        assert_eq!(score.sc_weight, 3.0);
        assert_eq!(score.rw_weight, 0.5);
        assert_eq!(score.sc_target, 15.0);
    }

    #[test]
    fn test_evaluate_below_target() {
        let score = ScoreFunction::default();
        // 2^10 from tc; rw has zero weight and sc = 5 is below the target of 20.
        assert_close(score.evaluate(10.0, 5.0, 8.0), 1024.0);
    }

    #[test]
    fn test_evaluate_above_target() {
        let score = ScoreFunction::new(1.0, 1.0, 0.0, 10.0);
        // 2^10 from tc plus the overshoot 2^12 - 2^10 = 3072.
        assert_close(score.evaluate(10.0, 12.0, 0.0), 1024.0 + 3072.0);
    }

    #[test]
    fn test_evaluate_with_rw_weight() {
        let score = ScoreFunction::new(1.0, 0.0, 1.0, 100.0);
        // 2^10 from tc plus 2^8 from rw; the space term has zero weight.
        assert_close(score.evaluate(10.0, 5.0, 8.0), 1024.0 + 256.0);
    }

    #[test]
    fn test_time_optimized() {
        let score = ScoreFunction::time_optimized();
        assert_eq!(score.tc_weight, 1.0);
        assert_eq!(score.sc_weight, 0.0);
        assert_eq!(score.rw_weight, 0.0);
        assert!(score.sc_target.is_infinite());
        // Only the tc term counts, however large sc is.
        assert_close(score.evaluate(10.0, 100.0, 0.0), 1024.0);
    }

    #[test]
    fn test_space_optimized() {
        let score = ScoreFunction::space_optimized(10.0);
        assert_eq!(score.tc_weight, 0.0);
        assert_eq!(score.sc_weight, 1.0);
        assert_eq!(score.rw_weight, 0.0);
        assert_eq!(score.sc_target, 10.0);
        // Below the target nothing is penalized and the other terms are disabled.
        assert_close(score.evaluate(10.0, 5.0, 8.0), 0.0);
        // Above the target: 2^12 - 2^10 = 3072.
        assert_close(score.evaluate(10.0, 12.0, 8.0), 3072.0);
    }

    #[test]
    fn test_exceeds_target() {
        let score = ScoreFunction::default();
        assert!(!score.exceeds_target(10.0));
        // The comparison is strict: sitting exactly on the target is allowed.
        assert!(!score.exceeds_target(20.0));
        assert!(score.exceeds_target(21.0));
    }

    #[test]
    fn test_score_serialization() {
        let score = ScoreFunction::new(1.0, 2.0, 0.5, 15.0);
        let json = serde_json::to_string(&score).expect("ScoreFunction should serialize");
        let decoded: ScoreFunction =
            serde_json::from_str(&json).expect("serialized ScoreFunction should deserialize");
        assert_close(decoded.tc_weight, score.tc_weight);
        assert_close(decoded.sc_weight, score.sc_weight);
        assert_close(decoded.rw_weight, score.rw_weight);
        assert_close(decoded.sc_target, score.sc_target);
    }

    #[test]
    fn test_with_sc_target_builder() {
        let score = ScoreFunction::default().with_sc_target(25.0);
        assert_eq!(score.sc_target, 25.0);
        assert_eq!(score.tc_weight, 1.0);
    }

    #[test]
    fn test_with_tc_weight_builder() {
        let score = ScoreFunction::default().with_tc_weight(2.5);
        assert_eq!(score.tc_weight, 2.5);
        assert_eq!(score.sc_weight, 1.0);
    }

    #[test]
    fn test_with_sc_weight_builder() {
        let score = ScoreFunction::default().with_sc_weight(3.0);
        assert_eq!(score.sc_weight, 3.0);
        assert_eq!(score.tc_weight, 1.0);
    }

    #[test]
    fn test_with_rw_weight_builder() {
        let score = ScoreFunction::default().with_rw_weight(0.5);
        assert_eq!(score.rw_weight, 0.5);
        assert_eq!(score.tc_weight, 1.0);
    }

    #[test]
    fn test_builder_chaining() {
        let score = ScoreFunction::default()
            .with_tc_weight(2.0)
            .with_sc_weight(3.0)
            .with_rw_weight(0.5)
            .with_sc_target(15.0);
        assert_eq!(score.tc_weight, 2.0);
        assert_eq!(score.sc_weight, 3.0);
        assert_eq!(score.rw_weight, 0.5);
        assert_eq!(score.sc_target, 15.0);
    }
}