/// Relative dose multipliers.
///
/// [`SteeringCalibration::dose_levels_absolute`] multiplies the target
/// baseline by each entry to obtain absolute levels.
pub const DOSE_LEVELS: [f32; 6] = [0.5, 1.0, 2.0, 3.0, 4.0, 6.0];

/// Baseline measurements for one layer, together with the values derived
/// from them when the calibration is constructed.
#[derive(Debug, Clone)]
#[must_use]
pub struct SteeringCalibration {
    /// Baseline level supplied for the source side.
    pub source_baseline: f32,
    /// Baseline level supplied for the target side; the denominator of every
    /// ratio computed by this type.
    pub target_baseline: f32,
    /// Level to steer toward. Starts out equal to `source_baseline` and can
    /// be replaced through [`SteeringCalibration::with_target`].
    pub recommended_target: f32,
    /// `source_baseline / target_baseline`, or `0.0` when the target baseline
    /// is not greater than `1e-10`.
    pub attention_ratio: f32,
    /// Layer index, stored exactly as supplied.
    pub layer: usize,
    /// Source sample count, stored exactly as supplied.
    pub n_source_samples: usize,
    /// Target sample count, stored exactly as supplied.
    pub n_target_samples: usize,
}

impl SteeringCalibration {
    /// Creates a calibration from the two baselines and bookkeeping values.
    ///
    /// `recommended_target` is initialised to `source_baseline`.
    /// `attention_ratio` is `source_baseline / target_baseline` when
    /// `target_baseline > 1e-10` and `0.0` otherwise (including NaN).
    pub fn new(
        source_baseline: f32,
        target_baseline: f32,
        layer: usize,
        n_source_samples: usize,
        n_target_samples: usize,
    ) -> Self {
        // A (near-)zero denominator would blow the ratio up, so it is pinned
        // to 0.0 in that case.
        let target_is_usable = target_baseline > 1e-10;
        Self {
            source_baseline,
            target_baseline,
            recommended_target: source_baseline,
            attention_ratio: if target_is_usable {
                source_baseline / target_baseline
            } else {
                0.0
            },
            layer,
            n_source_samples,
            n_target_samples,
        }
    }

    /// Returns the calibration with `recommended_target` replaced by `target`;
    /// every other field is carried over unchanged.
    pub const fn with_target(self, target: f32) -> Self {
        Self {
            recommended_target: target,
            ..self
        }
    }

    /// Returns `target / target_baseline`, i.e. the multiplier that maps the
    /// target baseline onto `target`.
    ///
    /// Falls back to `1.0` when `target_baseline` is not greater than `1e-10`.
    #[must_use]
    pub fn scale_factor_for_target(&self, target: f32) -> f32 {
        match self.target_baseline {
            baseline if baseline > 1e-10 => target / baseline,
            _ => 1.0,
        }
    }

    /// Returns the multiplier that maps the target baseline onto the source
    /// baseline (same fallback as [`Self::scale_factor_for_target`]).
    #[must_use]
    pub fn scale_factor_to_source(&self) -> f32 {
        let goal = self.source_baseline;
        self.scale_factor_for_target(goal)
    }

    /// Pairs every entry of [`DOSE_LEVELS`] with its absolute level,
    /// `target_baseline * scale`, in the order of the constant.
    #[must_use]
    pub fn dose_levels_absolute(&self) -> Vec<(f32, f32)> {
        let baseline = self.target_baseline;
        let mut levels = Vec::with_capacity(DOSE_LEVELS.len());
        for &scale in &DOSE_LEVELS {
            levels.push((scale, baseline * scale));
        }
        levels
    }
}
/// A single measurement on a dose-response curve.
#[derive(Debug, Clone)]
pub struct DoseResponsePoint {
    /// Scale factor (the "dose") at which the measurement was taken.
    pub scale_factor: f32,
    /// Attention level recorded at that scale factor.
    pub attention_level: f32,
    /// KL divergence recorded with the measurement; stored only, never read
    /// by the methods in this block.
    pub kl_divergence: f32,
}

/// An ordered list of [`DoseResponsePoint`]s for one sample, with identifying
/// metadata.
#[derive(Debug, Clone)]
pub struct DoseResponseCurve {
    /// Identifier of the sample, stored as supplied.
    pub sample_id: String,
    /// Condition label, stored as supplied.
    pub condition: String,
    /// Layer index, stored as supplied.
    pub layer: usize,
    /// Baseline attention value, stored as supplied.
    pub baseline_attention: f32,
    /// Measured points in insertion order.
    pub points: Vec<DoseResponsePoint>,
}

impl DoseResponseCurve {
    /// Creates a curve with the given metadata and no points.
    #[must_use]
    pub const fn new(
        sample_id: String,
        condition: String,
        layer: usize,
        baseline_attention: f32,
    ) -> Self {
        Self {
            points: Vec::new(),
            sample_id,
            condition,
            layer,
            baseline_attention,
        }
    }

    /// Appends one measurement to the end of the curve.
    pub fn add_point(&mut self, scale_factor: f32, attention_level: f32, kl_divergence: f32) {
        let point = DoseResponsePoint {
            scale_factor,
            attention_level,
            kl_divergence,
        };
        self.points.push(point);
    }

    /// Estimates the scale factor at which the curve reaches `target`.
    ///
    /// Consecutive points are examined in insertion order. The first pair
    /// whose attention levels bracket `target` (inclusive, in either
    /// direction) and differ by at least `1e-10` is linearly interpolated.
    /// Flat pairs are skipped. Returns `None` when no pair qualifies, which
    /// includes curves with fewer than two points.
    #[must_use]
    pub fn scale_for_target(&self, target: f32) -> Option<f32> {
        // Pair each point with its successor: (p0, p1), (p1, p2), ...
        let successors = self.points.iter().skip(1);
        self.points.iter().zip(successors).find_map(|(start, end)| {
            let (from, to) = (start.attention_level, end.attention_level);
            let rising = from <= target && target <= to;
            let falling = to <= target && target <= from;
            if !(rising || falling) {
                return None;
            }
            let span = to - from;
            // A flat segment cannot be interpolated; move on to the next pair.
            if span.abs() < 1e-10 {
                return None;
            }
            let fraction = (target - from) / span;
            Some(fraction.mul_add(end.scale_factor - start.scale_factor, start.scale_factor))
        })
    }
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::float_cmp)]
mod tests {
    use super::*;

    /// Absolute tolerance shared by the floating-point comparisons below.
    const TOLERANCE: f32 = 1e-6;

    /// Returns `true` when `actual` lies strictly within [`TOLERANCE`] of
    /// `expected`.
    fn is_close(actual: f32, expected: f32) -> bool {
        (actual - expected).abs() < TOLERANCE
    }

    /// `new` stores the baselines and layer and derives the ratio.
    #[test]
    fn calibration_new() {
        let calibration = SteeringCalibration::new(0.09, 0.025, 16, 10, 10);
        assert!(is_close(calibration.source_baseline, 0.09));
        assert!(is_close(calibration.target_baseline, 0.025));
        assert!(is_close(calibration.attention_ratio, 3.6));
        assert_eq!(calibration.layer, 16);
    }

    /// Scale factors are the requested level divided by the target baseline.
    #[test]
    fn scale_factor_calculation() {
        let calibration = SteeringCalibration::new(0.09, 0.025, 16, 10, 10);
        let to_source = calibration.scale_factor_to_source();
        assert!(is_close(to_source, 3.6));
        let to_custom = calibration.scale_factor_for_target(0.05);
        assert!(is_close(to_custom, 2.0));
    }

    /// The dose table has six entries with the expected first, second and
    /// last values.
    #[test]
    fn dose_levels_count() {
        assert_eq!(DOSE_LEVELS.len(), 6);
        assert!(is_close(DOSE_LEVELS[0], 0.5));
        assert!(is_close(DOSE_LEVELS[1], 1.0));
        assert!(is_close(DOSE_LEVELS[5], 6.0));
    }

    /// A target between the last two points interpolates to a scale strictly
    /// between their scale factors.
    #[test]
    fn dose_response_curve_interpolation() {
        let mut curve =
            DoseResponseCurve::new("test_sample".to_string(), "rust".to_string(), 16, 0.025);
        curve.add_point(1.0, 0.025, 0.0);
        curve.add_point(2.0, 0.050, 0.01);
        curve.add_point(4.0, 0.100, 0.05);
        let interpolated = curve.scale_for_target(0.075);
        assert!(interpolated.is_some());
        assert!(matches!(interpolated, Some(scale) if scale > 2.0 && scale < 4.0));
    }

    /// A target above every recorded attention level has no interpolation.
    #[test]
    fn dose_response_out_of_range() {
        let mut curve = DoseResponseCurve::new("test".to_string(), "rust".to_string(), 16, 0.025);
        curve.add_point(1.0, 0.025, 0.0);
        curve.add_point(2.0, 0.050, 0.01);
        let unreachable_target = curve.scale_for_target(0.200);
        assert!(unreachable_target.is_none());
    }

    /// `with_target` overrides the recommended target.
    #[test]
    fn calibration_with_custom_target() {
        let calibration = SteeringCalibration::new(0.09, 0.025, 16, 10, 10).with_target(0.05);
        assert!(is_close(calibration.recommended_target, 0.05));
    }

    /// Absolute dose levels scale the target baseline by each table entry.
    #[test]
    fn dose_levels_absolute() {
        let calibration = SteeringCalibration::new(0.09, 0.025, 16, 10, 10);
        let absolute_levels = calibration.dose_levels_absolute();
        assert_eq!(absolute_levels.len(), 6);
        let lowest = absolute_levels.first().unwrap();
        assert!(is_close(lowest.1, 0.0125));
    }
}