use serde::{Deserialize, Serialize};
use super::error::EvalError;
use super::types::ParameterKind;
/// Bounds, optional grid step, and default value for one tunable parameter.
///
/// [`ParameterRange::new`] guarantees finite bounds with `min < max` and a
/// finite default inside `[min, max]`; it does not validate `step`. The derived
/// `Deserialize` impl builds the struct directly without going through `new`,
/// so deserialized values carry none of these guarantees.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ParameterRange {
/// Which parameter this range applies to.
kind: ParameterKind,
/// Inclusive lower bound; also the anchor of the quantization grid.
min: f64,
/// Inclusive upper bound.
max: f64,
/// Grid spacing starting at `min`; `None` means values are not snapped to a grid.
step: Option<f64>,
/// Value used when the parameter is not being tuned (exposed as `default_value()`).
default: f64,
}
impl ParameterRange {
    /// Slack, in grid steps, that absorbs floating-point round-off when counting
    /// grid points (e.g. `0.3 / 0.1` evaluates to `2.9999999999999996`).
    const GRID_TOLERANCE: f64 = 1e-9;

    /// Extra decimal digits, beyond those needed to resolve the step, kept when
    /// `quantize` rounds away floating-point noise.
    const NOISE_GUARD_DIGITS: i32 = 9;

    /// Creates a validated range.
    ///
    /// # Errors
    ///
    /// Returns [`EvalError::InvalidRange`] if `min` or `max` is not finite or
    /// `min >= max`, and [`EvalError::DefaultOutOfRange`] if `default` is not
    /// finite or lies outside `[min, max]`.
    ///
    /// `step` is stored as given and is not validated; [`Self::step_count`] and
    /// [`Self::quantize`] treat a non-finite or non-positive step as absent.
    pub fn new(
        kind: ParameterKind,
        min: f64,
        max: f64,
        step: Option<f64>,
        default: f64,
    ) -> Result<Self, EvalError> {
        if !min.is_finite() || !max.is_finite() || min >= max {
            return Err(EvalError::InvalidRange { min, max });
        }
        if !default.is_finite() || default < min || default > max {
            return Err(EvalError::DefaultOutOfRange { default, min, max });
        }
        Ok(Self {
            kind,
            min,
            max,
            step,
            default,
        })
    }

    /// The parameter this range applies to.
    #[must_use]
    pub fn kind(&self) -> ParameterKind {
        self.kind
    }

    /// Inclusive lower bound.
    #[must_use]
    pub fn min(&self) -> f64 {
        self.min
    }

    /// Inclusive upper bound.
    #[must_use]
    pub fn max(&self) -> f64 {
        self.max
    }

    /// The configured grid step, exactly as stored (possibly unusable).
    #[must_use]
    pub fn step(&self) -> Option<f64> {
        self.step
    }

    /// The default value for this parameter.
    #[must_use]
    pub fn default_value(&self) -> f64 {
        self.default
    }

    /// Number of grid points `min, min + step, min + 2*step, ...` that lie in `[min, max]`.
    ///
    /// Returns `None` when there is no step, when the step is not a finite,
    /// strictly positive number, or when the count does not fit in `usize`.
    #[must_use]
    pub fn step_count(&self) -> Option<usize> {
        let step = self.usable_step()?;
        // The tolerance keeps a grid point that lands exactly on `max` from being
        // dropped by division round-off (`0.3 / 0.1 == 2.9999999999999996`).
        let intervals = ((self.max - self.min) / step + Self::GRID_TOLERANCE).floor();
        // `f64 as usize` saturates (NaN and negatives become 0, huge values become
        // `usize::MAX`), so the cast cannot misbehave; a saturated count is then
        // rejected by `checked_add` instead of overflowing.
        #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
        let intervals = intervals as usize;
        intervals.checked_add(1)
    }

    /// Clamps `value` into `[min, max]`; `NaN` is returned unchanged.
    ///
    /// Behaves like [`f64::clamp`] for ranges built by [`Self::new`], but never
    /// panics: a range that bypassed `new` (e.g. via deserialization) may have
    /// `min > max` or a `NaN` bound, and `f64::clamp` would abort on it.
    #[must_use]
    pub fn clamp(&self, value: f64) -> f64 {
        let mut clamped = value;
        if clamped < self.min {
            clamped = self.min;
        }
        if clamped > self.max {
            clamped = self.max;
        }
        clamped
    }

    /// Whether `value` lies in the inclusive interval `[min, max]` (`NaN` never does).
    #[must_use]
    pub fn contains(&self, value: f64) -> bool {
        (self.min..=self.max).contains(&value)
    }

    /// Snaps `value` to the nearest grid point `min + k * step`, then clamps to `[min, max]`.
    ///
    /// Floating-point noise from the multiplication is rounded away, so e.g.
    /// `0.1 * 7.0` quantizes to exactly `0.7`. Without a usable step (absent,
    /// non-finite, or non-positive), `value` is returned unchanged and unclamped.
    #[must_use]
    pub fn quantize(&self, value: f64) -> f64 {
        let Some(step) = self.usable_step() else {
            return value;
        };
        let snapped = self.min + ((value - self.min) / step).round() * step;
        self.clamp(Self::strip_float_noise(snapped, step))
    }

    /// The stored step, if it is a finite, strictly positive number.
    ///
    /// `new` never validates the step and deserialization skips `new` entirely,
    /// so the grid helpers re-check it here rather than trusting the field.
    fn usable_step(&self) -> Option<f64> {
        self.step.filter(|s| s.is_finite() && *s > 0.0)
    }

    /// Rounds `x` at a decimal precision finer than `step` to remove float noise.
    ///
    /// The precision scales with the step (`ceil(-log10(step))` decimals plus
    /// `NOISE_GUARD_DIGITS`), so neighbouring grid points and a non-round `min`
    /// survive, which a fixed number of decimals would not guarantee.
    fn strip_float_noise(x: f64, step: f64) -> f64 {
        // `step` is finite and positive, so `-log10(step)` is finite and well
        // within `i32` (|log10| of any finite positive f64 is below 330).
        #[allow(clippy::cast_possible_truncation)]
        let step_decimals = (-step.log10()).ceil().max(0.0) as i32;
        let scale = 10f64.powi(step_decimals + Self::NOISE_GUARD_DIGITS);
        let scaled = x * scale;
        if scaled.is_finite() {
            scaled.round() / scale
        } else {
            // Extremely small steps overflow `scale`; skip the cleanup rather
            // than turn the value into inf/NaN.
            x
        }
    }
}
/// The set of per-parameter ranges to search over.
///
/// Container-level `#[serde(default)]` fills a missing `parameters` field from
/// [`SearchSpace::default`]. Ranges read through `Deserialize` are not passed
/// through [`ParameterRange::new`], so check [`SearchSpace::is_valid`] before
/// trusting a loaded space.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct SearchSpace {
/// Ranges in declaration order; `range_for` returns the first entry matching a kind.
pub parameters: Vec<ParameterRange>,
}
impl Default for SearchSpace {
    /// The built-in space: temperature, top-p, top-k, and the frequency and
    /// presence penalties, each with a step and a default inside its bounds.
    fn default() -> Self {
        // (kind, min, max, step, default, invariant stated to `expect`)
        let presets = [
            (ParameterKind::Temperature, 0.0, 1.0, 0.1, 0.7, "default Temperature range is valid"),
            (ParameterKind::TopP, 0.1, 1.0, 0.05, 0.9, "default TopP range is valid"),
            (ParameterKind::TopK, 1.0, 100.0, 5.0, 40.0, "default TopK range is valid"),
            (
                ParameterKind::FrequencyPenalty,
                -2.0,
                2.0,
                0.2,
                0.0,
                "default FrequencyPenalty range is valid",
            ),
            (
                ParameterKind::PresencePenalty,
                -2.0,
                2.0,
                0.2,
                0.0,
                "default PresencePenalty range is valid",
            ),
        ];
        let parameters = presets
            .into_iter()
            .map(|(kind, lo, hi, step, initial, invariant)| {
                ParameterRange::new(kind, lo, hi, Some(step), initial).expect(invariant)
            })
            .collect();
        Self { parameters }
    }
}
impl SearchSpace {
    /// Returns the range configured for `kind`, if any.
    ///
    /// If `kind` appears more than once, only the first entry is returned.
    #[must_use]
    pub fn range_for(&self, kind: ParameterKind) -> Option<&ParameterRange> {
        self.parameters.iter().find(|r| r.kind() == kind)
    }

    /// Checks every range against the invariants [`ParameterRange::new`]
    /// enforces, plus a usable step.
    ///
    /// A range is valid when both bounds are finite with `min < max`, its
    /// default lies inside `[min, max]`, and its step (if any) is finite and
    /// strictly positive. An empty space is valid. Call this on deserialized
    /// spaces, whose ranges never went through the constructor.
    #[must_use]
    pub fn is_valid(&self) -> bool {
        self.parameters.iter().all(|r| {
            r.min().is_finite()
                && r.max().is_finite()
                && r.min() < r.max()
                // Mirrors `new`'s default check; with finite bounds this also
                // rules out NaN and infinite defaults.
                && r.contains(r.default_value())
                && r.step().is_none_or(|s| s.is_finite() && s > 0.0)
        })
    }

    /// Total grid points summed across parameters (not their Cartesian product).
    ///
    /// Presumably the evaluation count of a one-parameter-at-a-time sweep;
    /// confirm with the caller. Ranges for which [`ParameterRange::step_count`]
    /// is `None` contribute nothing. The total saturates at `usize::MAX`
    /// instead of overflowing.
    #[must_use]
    pub fn grid_size(&self) -> usize {
        self.parameters
            .iter()
            .filter_map(ParameterRange::step_count)
            .fold(0, usize::saturating_add)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a range through the validating constructor; panics if it is rejected.
    fn make_range(
        kind: ParameterKind,
        min: f64,
        max: f64,
        step: Option<f64>,
        default: f64,
    ) -> ParameterRange {
        ParameterRange::new(kind, min, max, step, default).unwrap()
    }

    /// `true` when `a` and `b` differ by less than machine epsilon.
    fn approx_eq(a: f64, b: f64) -> bool {
        (a - b).abs() < f64::EPSILON
    }

    #[test]
    fn new_valid_range() {
        let range = make_range(ParameterKind::Temperature, 0.0, 1.0, Some(0.5), 0.5);
        assert_eq!(range.kind(), ParameterKind::Temperature);
        assert!(approx_eq(range.min(), 0.0));
        assert!(approx_eq(range.max(), 1.0));
        assert!(approx_eq(range.default_value(), 0.5));
        assert_eq!(range.step(), Some(0.5));
    }

    #[test]
    fn new_invalid_range_min_ge_max() {
        // Both an inverted and an empty interval must be rejected.
        for (min, max) in [(1.0, 0.0), (0.5, 0.5)] {
            let result = ParameterRange::new(ParameterKind::Temperature, min, max, None, 0.5);
            assert!(matches!(result, Err(EvalError::InvalidRange { .. })));
        }
    }

    #[test]
    fn new_invalid_range_nonfinite_bounds() {
        for (min, max) in [(f64::NAN, 1.0), (0.0, f64::INFINITY)] {
            let result = ParameterRange::new(ParameterKind::Temperature, min, max, None, 0.5);
            assert!(matches!(result, Err(EvalError::InvalidRange { .. })));
        }
    }

    #[test]
    fn new_invalid_default_out_of_range() {
        for default in [2.0, -0.1] {
            let result =
                ParameterRange::new(ParameterKind::Temperature, 0.0, 1.0, None, default);
            assert!(matches!(result, Err(EvalError::DefaultOutOfRange { .. })));
        }
    }

    #[test]
    fn step_count_with_step() {
        let range = make_range(ParameterKind::Temperature, 0.0, 1.0, Some(0.5), 0.5);
        // Grid points: 0.0, 0.5, 1.0.
        assert_eq!(range.step_count(), Some(3));
    }

    #[test]
    fn step_count_no_step() {
        let range = make_range(ParameterKind::Temperature, 0.0, 1.0, None, 0.5);
        assert_eq!(range.step_count(), None);
    }

    #[test]
    fn step_count_zero_step() {
        // Set the field directly to model a range that never went through `new`.
        let mut range = make_range(ParameterKind::Temperature, 0.0, 1.0, None, 0.5);
        range.step = Some(0.0);
        assert_eq!(range.step_count(), None);
    }

    #[test]
    fn clamp_below_min() {
        let top_p = make_range(ParameterKind::TopP, 0.1, 1.0, Some(0.1), 0.9);
        assert!(approx_eq(top_p.clamp(-1.0), 0.1));
    }

    #[test]
    fn clamp_above_max() {
        let top_p = make_range(ParameterKind::TopP, 0.1, 1.0, Some(0.1), 0.9);
        assert!(approx_eq(top_p.clamp(2.0), 1.0));
    }

    #[test]
    fn clamp_within_range() {
        let range = make_range(ParameterKind::Temperature, 0.0, 2.0, Some(0.1), 0.7);
        assert!(approx_eq(range.clamp(1.0), 1.0));
    }

    #[test]
    fn contains_within_range() {
        let range = make_range(ParameterKind::Temperature, 0.0, 2.0, Some(0.1), 0.7);
        // Both bounds are inclusive.
        for inside in [1.0, 0.0, 2.0] {
            assert!(range.contains(inside));
        }
        for outside in [-0.1, 2.1] {
            assert!(!range.contains(outside));
        }
    }

    #[test]
    fn quantize_snaps_to_nearest_step() {
        let range = make_range(ParameterKind::Temperature, 0.0, 2.0, Some(0.1), 0.7);
        let q = range.quantize(0.73);
        assert!((q - 0.7).abs() < 1e-10, "expected 0.7, got {q}");
    }

    #[test]
    fn quantize_no_step_returns_value_unchanged() {
        let range = make_range(ParameterKind::Temperature, 0.0, 2.0, None, 0.7);
        assert!(approx_eq(range.quantize(1.234), 1.234));
    }

    #[test]
    fn quantize_clamps_result() {
        let range = make_range(ParameterKind::Temperature, 0.0, 1.0, Some(0.1), 0.5);
        assert!(range.quantize(100.0) <= 1.0, "quantize must clamp to max");
    }

    #[test]
    fn quantize_avoids_fp_accumulation() {
        let range = make_range(ParameterKind::Temperature, 0.0, 2.0, Some(0.1), 0.7);
        // 0.1 * 7.0 is 0.7000000000000001 in binary floating point.
        let accumulated = 0.1_f64 * 7.0;
        let q = range.quantize(accumulated);
        assert!(
            (q - 0.7).abs() < 1e-10,
            "expected 0.7, got {q} (accumulated={accumulated})"
        );
    }

    #[test]
    fn default_search_space_has_five_parameters() {
        assert_eq!(SearchSpace::default().parameters.len(), 5);
    }

    #[test]
    fn default_grid_size_is_reasonable() {
        let size = SearchSpace::default().grid_size();
        assert!((1..200).contains(&size));
    }

    #[test]
    fn range_for_finds_temperature() {
        let space = SearchSpace::default();
        let temperature = space.range_for(ParameterKind::Temperature);
        assert!(temperature.is_some_and(|r| approx_eq(r.default_value(), 0.7)));
    }

    #[test]
    fn range_for_missing_returns_none() {
        let space = SearchSpace::default();
        assert!(space.range_for(ParameterKind::RetrievalTopK).is_none());
    }

    #[test]
    fn grid_size_empty_space_is_zero() {
        let empty = SearchSpace {
            parameters: Vec::new(),
        };
        assert_eq!(empty.grid_size(), 0);
    }

    #[test]
    fn quantize_with_nonzero_min_anchors_to_min() {
        // Grid is 1, 6, 11, ... because it starts at min = 1, not at 0.
        let top_k = make_range(ParameterKind::TopK, 1.0, 100.0, Some(5.0), 40.0);
        let q = top_k.quantize(6.0);
        assert!(
            (q - 6.0).abs() < 1e-10,
            "expected 6.0 (min-anchored grid), got {q}"
        );
        let q2 = top_k.quantize(3.0);
        assert!((q2 - 1.0).abs() < 1e-10, "expected 1.0, got {q2}");
    }

    #[test]
    fn quantize_negative_step_returns_unchanged() {
        let mut range = make_range(ParameterKind::Temperature, 0.0, 2.0, None, 0.7);
        range.step = Some(-0.1);
        assert!(approx_eq(range.quantize(0.75), 0.75));
    }

    #[test]
    fn parameter_range_is_valid_for_default() {
        for range in &SearchSpace::default().parameters {
            assert!(
                range.min() < range.max(),
                "default range {:?} has min >= max",
                range.kind()
            );
        }
    }

    #[test]
    fn search_space_is_valid_for_default() {
        assert!(SearchSpace::default().is_valid());
    }

    #[test]
    fn search_space_invalid_when_range_inverted() {
        let mut space = SearchSpace::default();
        let first = &mut space.parameters[0];
        first.min = 2.0;
        first.max = 0.0;
        assert!(!space.is_valid());
    }
}