use rand::Rng;
/// Samples one alternative index from the multinomial-logit (MNL) distribution over `costs`.
///
/// Index `i` is drawn with probability proportional to `exp(-theta * costs[i])`, so with
/// `theta > 0` cheaper alternatives are more likely and larger `theta` concentrates the
/// choice on the cheapest one. Utilities are shifted by their maximum before
/// exponentiating so large `theta * cost` values cannot overflow the weights.
///
/// Edge cases:
/// * empty or single-element `costs` → `0`, without drawing from `rng`.
/// * an alternative whose utility `-theta * cost` is NaN is never chosen (unless every
///   utility is NaN, in which case `0` is returned).
/// * if the best utility is infinite (e.g. a `-inf` cost with `theta > 0`, or every cost
///   `+inf`), one of the alternatives attaining it is picked uniformly at random.
pub fn mnl_select<R: Rng>(costs: &[f64], theta: f64, rng: &mut R) -> usize {
    if costs.len() <= 1 {
        return 0;
    }

    let utils: Vec<f64> = costs.iter().map(|&c| -theta * c).collect();
    // `f64::max` ignores NaN, so this is the largest non-NaN utility (or -inf if none).
    let max_util = utils.iter().copied().fold(f64::NEG_INFINITY, f64::max);

    if !max_util.is_finite() {
        // Shifting by an infinite max would produce NaN weights; instead choose uniformly
        // among the alternatives that attain it. Exact `==` never matches NaN.
        let ties = utils.iter().filter(|&&u| u == max_util).count();
        if ties == 0 {
            return 0;
        }
        let pick = rng.gen_range(0..ties);
        return utils
            .iter()
            .enumerate()
            .filter(|&(_, &u)| u == max_util)
            .nth(pick)
            .map_or(0, |(i, _)| i);
    }

    let weights: Vec<f64> = utils
        .iter()
        .map(|&u| if u.is_nan() { 0.0 } else { (u - max_util).exp() })
        .collect();
    // The maximising alternative contributes exp(0) = 1 and every weight is in [0, 1],
    // so `sum` is finite and at least 1.
    let sum: f64 = weights.iter().sum();

    // Inverse-CDF sampling. `r` lies in [0, sum]; the strict `<` guarantees that a
    // zero-weight alternative is never returned, even when `r == 0`.
    let r: f64 = rng.gen::<f64>() * sum;
    let mut cumulative = 0.0;
    let mut last_positive = 0;
    for (i, &w) in weights.iter().enumerate() {
        if w > 0.0 {
            cumulative += w;
            last_positive = i;
            if r < cumulative {
                return i;
            }
        }
    }
    // Only reached when rounding makes `r == sum`; return the last selectable alternative.
    last_positive
}
/// Returns the multinomial-logit (softmax) choice probabilities for `costs`.
///
/// Alternative `i` gets `exp(-theta * costs[i]) / Σ_j exp(-theta * costs[j])`. Utilities
/// are shifted by their maximum before exponentiating so large `theta * cost` values do
/// not overflow or underflow the whole sum.
///
/// Edge cases:
/// * empty `costs` → empty vector; a single cost → `[1.0]`.
/// * an alternative whose utility `-theta * cost` is NaN gets probability 0.
/// * if the best utility is infinite (e.g. a `-inf` cost with `theta > 0`, or every cost
///   `+inf`), the mass is split evenly over the alternatives attaining it.
/// * if every utility is NaN, all mass goes to index 0.
pub fn mnl_probabilities(costs: &[f64], theta: f64) -> Vec<f64> {
    match costs.len() {
        0 => return Vec::new(),
        1 => return vec![1.0],
        _ => {}
    }

    let utils: Vec<f64> = costs.iter().map(|&c| -theta * c).collect();
    // `f64::max` ignores NaN, so this is the largest non-NaN utility (or -inf if none).
    let max_util = utils.iter().copied().fold(f64::NEG_INFINITY, f64::max);

    if max_util.is_finite() {
        let weights: Vec<f64> = utils
            .iter()
            .map(|&u| if u.is_nan() { 0.0 } else { (u - max_util).exp() })
            .collect();
        // The maximising alternative contributes exp(0) = 1 and every weight lies in
        // [0, 1], so `sum` is in [1, len]: finite and non-zero.
        let sum: f64 = weights.iter().sum();
        return weights.into_iter().map(|w| w / sum).collect();
    }

    // Degenerate case: the best utility is ±inf. Exact `==` never matches NaN.
    let ties = utils.iter().filter(|&&u| u == max_util).count();
    if ties == 0 {
        let mut probs = vec![0.0; costs.len()];
        probs[0] = 1.0;
        return probs;
    }
    let share = 1.0 / ties as f64;
    utils
        .iter()
        .map(|&u| if u == max_util { share } else { 0.0 })
        .collect()
}
/// Returns the MNL log-sum (expected minimum perceived cost) of `costs`:
/// `-(1 / theta) * ln Σ exp(-theta * c)`.
///
/// Mathematically the result is at most the smallest cost and does not increase when
/// alternatives are added. The sum is computed with a max-shift for numerical stability.
///
/// Edge cases:
/// * empty `costs`, or `theta` that is not strictly positive (including NaN) → `+inf`.
/// * NaN costs are ignored.
/// * if a utility `-theta * c` is infinite (a `±inf` cost, or overflow for huge
///   `theta * c`), the result is the smallest non-NaN cost. If there is none, it is `+inf`.
pub fn mnl_logsum(costs: &[f64], theta: f64) -> f64 {
    // The explicit NaN check is needed because `NaN <= 0.0` is false.
    if costs.is_empty() || theta.is_nan() || theta <= 0.0 {
        return f64::INFINITY;
    }

    // `f64::max` ignores NaN, so this is the largest non-NaN utility (or -inf if none).
    let max_util = costs
        .iter()
        .map(|&c| -theta * c)
        .fold(f64::NEG_INFINITY, f64::max);

    if !max_util.is_finite() {
        // Shifting by an infinite max would give NaN. In this limit the log-sum collapses
        // to the minimum cost; `f64::min` skips NaN entries.
        return costs.iter().copied().fold(f64::INFINITY, f64::min);
    }

    // The max term contributes exp(0) = 1, so `sum_exp >= 1` and `ln` is finite.
    let sum_exp: f64 = costs
        .iter()
        .map(|&c| -theta * c)
        .filter(|u| !u.is_nan())
        .map(|u| (u - max_util).exp())
        .sum();
    -(max_util + sum_exp.ln()) / theta
}
#[cfg(test)]
mod tests {
    // Unit tests for the MNL helpers. Randomised tests use a fixed seed for determinism.
    use super::*;
    use rand::rngs::StdRng;
    use rand::SeedableRng;

    #[test]
    fn single_cost_returns_zero() {
        let mut rng = StdRng::seed_from_u64(42);
        assert_eq!(mnl_select(&[10.0], 1.0, &mut rng), 0);
    }

    #[test]
    fn empty_costs_returns_zero() {
        let mut rng = StdRng::seed_from_u64(42);
        assert_eq!(mnl_select(&[], 1.0, &mut rng), 0);
    }

    #[test]
    fn high_theta_favors_cheapest() {
        let costs = [10.0, 20.0, 30.0];
        let mut rng = StdRng::seed_from_u64(42);
        let mut count_cheapest = 0;
        for _ in 0..1000 {
            if mnl_select(&costs, 100.0, &mut rng) == 0 {
                count_cheapest += 1;
            }
        }
        assert!(
            count_cheapest > 990,
            "Expected nearly all cheapest, got {count_cheapest}/1000"
        );
    }

    #[test]
    fn low_theta_spreads_selection() {
        let costs = [10.0, 10.5, 11.0];
        let mut rng = StdRng::seed_from_u64(42);
        let mut counts = [0usize; 3];
        let trials = 3000;
        for _ in 0..trials {
            counts[mnl_select(&costs, 0.1, &mut rng)] += 1;
        }
        for &c in &counts {
            assert!(
                c > trials / 10,
                "Expected spread selection, got counts {counts:?}"
            );
        }
    }

    #[test]
    fn select_frequencies_match_probabilities() {
        let costs = [1.0, 2.0, 3.0];
        let theta = 1.0;
        let probs = mnl_probabilities(&costs, theta);
        let mut rng = StdRng::seed_from_u64(7);
        let trials = 20_000;
        let mut counts = [0usize; 3];
        for _ in 0..trials {
            counts[mnl_select(&costs, theta, &mut rng)] += 1;
        }
        for (&c, &p) in counts.iter().zip(&probs) {
            let freq = c as f64 / trials as f64;
            // ~6 standard deviations for the largest-variance alternative at 20k trials.
            assert!(
                (freq - p).abs() < 0.02,
                "frequency {freq} too far from probability {p}; counts {counts:?}"
            );
        }
    }

    #[test]
    fn probabilities_sum_to_one() {
        let costs = [10.0, 12.0, 15.0];
        let probs = mnl_probabilities(&costs, 1.0);
        assert_eq!(probs.len(), 3);
        let sum: f64 = probs.iter().sum();
        assert!((sum - 1.0).abs() < 1e-10);
    }

    #[test]
    fn probabilities_decrease_with_cost() {
        let costs = [5.0, 10.0, 20.0];
        let probs = mnl_probabilities(&costs, 1.0);
        assert!(probs[0] > probs[1]);
        assert!(probs[1] > probs[2]);
    }

    #[test]
    fn probabilities_empty_and_single() {
        assert!(mnl_probabilities(&[], 1.0).is_empty());
        assert_eq!(mnl_probabilities(&[3.0], 1.0), vec![1.0]);
    }

    #[test]
    fn equal_costs_give_uniform_probabilities() {
        let probs = mnl_probabilities(&[4.0, 4.0, 4.0, 4.0], 2.0);
        assert_eq!(probs.len(), 4);
        for p in probs {
            assert!((p - 0.25).abs() < 1e-12, "expected 0.25, got {p}");
        }
    }

    #[test]
    fn logsum_bounded_by_min_cost() {
        let costs = [10.0, 15.0, 20.0];
        let ls = mnl_logsum(&costs, 1.0);
        assert!(ls <= 10.0 + 1e-6, "logsum={ls} should be <= 10.0");
    }

    #[test]
    fn logsum_decreases_with_more_alternatives() {
        let costs_2 = [10.0, 15.0];
        let costs_3 = [10.0, 15.0, 12.0];
        let ls_2 = mnl_logsum(&costs_2, 1.0);
        let ls_3 = mnl_logsum(&costs_3, 1.0);
        assert!(
            ls_3 <= ls_2 + 1e-6,
            "More alternatives should decrease logsum: {ls_3} vs {ls_2}"
        );
    }

    #[test]
    fn logsum_single_cost_is_that_cost() {
        let ls = mnl_logsum(&[7.5], 1.0);
        assert!((ls - 7.5).abs() < 1e-12, "logsum={ls} should be 7.5");
    }

    #[test]
    fn logsum_invalid_inputs_are_infinite() {
        assert_eq!(mnl_logsum(&[], 1.0), f64::INFINITY);
        assert_eq!(mnl_logsum(&[1.0, 2.0], 0.0), f64::INFINITY);
        assert_eq!(mnl_logsum(&[1.0, 2.0], -1.0), f64::INFINITY);
    }
}