/// Default for [`CostParams::e_expand`]: the factor by which `top_k` is
/// multiplied when a request asks for entity expansion.
pub const DEFAULT_E_EXPAND: u32 = 5;
/// Per-request values that [`cost_units`] prices.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct CostInputs {
    /// Base cost of the request; presumably the number of results asked for.
    pub top_k: u32,
    /// When `true`, the base cost is multiplied by [`CostParams::e_expand`].
    pub expand_entities: bool,
}
/// Tunable knobs for [`cost_units`]. Check with [`CostParams::validate`]
/// before use.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct CostParams {
    /// Multiplier applied to `top_k` for entity-expanding requests.
    /// Must be at least 1; the validation message refers to it as
    /// `admission.e_expand` (presumably its config key — confirm).
    pub e_expand: u32,
}
impl Default for CostParams {
fn default() -> Self {
Self {
e_expand: DEFAULT_E_EXPAND,
}
}
}
impl CostParams {
    /// Checks that these parameters are usable for cost computation.
    ///
    /// # Errors
    ///
    /// Returns `Err` with a static message when `e_expand` is `0`; any value
    /// of 1 or more is accepted.
    pub fn validate(&self) -> Result<(), &'static str> {
        match self.e_expand {
            0 => Err("admission.e_expand must be >= 1"),
            _ => Ok(()),
        }
    }
}
/// Computes the cost, in abstract units, of a single request.
///
/// The base cost is `inputs.top_k`. When `inputs.expand_entities` is set, the
/// base cost is multiplied by `params.e_expand`.
///
/// `params` is expected to have passed [`CostParams::validate`]. This
/// function does not rely on that, because the fields are public and can be
/// set directly. An `e_expand` of `0` is treated as `1`, so an unvalidated
/// config can never price an expanded request below the same request without
/// expansion. A zero multiplier would otherwise make it cost nothing.
///
/// Both factors are widened losslessly to `u64` before multiplying. The
/// product of two `u32` values always fits in `u64`, so the saturating
/// multiply is only a guard in case the input types are ever widened.
pub fn cost_units(inputs: CostInputs, params: CostParams) -> u64 {
    let multiplier = if inputs.expand_entities {
        // `validate` rejects 0, but clamp anyway so expansion never makes a
        // request cheaper.
        params.e_expand.max(1)
    } else {
        1
    };
    u64::from(inputs.top_k).saturating_mul(u64::from(multiplier))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Prices a request built from `top_k` / `expand_entities` under `params`.
    fn price(top_k: u32, expand_entities: bool, params: CostParams) -> u64 {
        let request = CostInputs {
            top_k,
            expand_entities,
        };
        cost_units(request, params)
    }

    #[test]
    fn no_expand_is_top_k() {
        assert_eq!(price(10, false, CostParams::default()), 10);
    }

    #[test]
    fn expand_multiplies_by_e_expand() {
        let expected = 10 * DEFAULT_E_EXPAND as u64;
        assert_eq!(price(10, true, CostParams::default()), expected);
    }

    /// Pins the cost of the request shape from the TERM-1423 repro
    /// (large `top_k` with entity expansion) under default parameters.
    #[test]
    fn term_1423_repro_pattern_is_expensive() {
        assert_eq!(price(200, true, CostParams::default()), 1000);
    }

    #[test]
    fn custom_e_expand_is_honored() {
        assert_eq!(price(10, true, CostParams { e_expand: 3 }), 30);
    }

    #[test]
    fn zero_top_k_is_zero_cost() {
        assert_eq!(price(0, true, CostParams::default()), 0);
    }

    #[test]
    fn validate_rejects_zero_e_expand() {
        assert!(CostParams { e_expand: 0 }.validate().is_err());
    }

    #[test]
    fn validate_accepts_one_or_higher() {
        for e_expand in [1, 100] {
            assert!(CostParams { e_expand }.validate().is_ok());
        }
    }

    #[test]
    fn saturating_arithmetic_no_panic_at_extreme_inputs() {
        let expected = u64::from(u32::MAX) * u64::from(DEFAULT_E_EXPAND);
        assert_eq!(price(u32::MAX, true, CostParams::default()), expected);
    }
}