use std::collections::HashMap;
use std::rc::Rc;
use super::error::ModeChoiceError;
use crate::gmns::types::AgentType;
use crate::log_main;
use crate::od::OdMatrix;
use crate::od::dense::DenseOdMatrix;
use crate::verbose::EVENT_MODE_CHOICE;
use super::utility::ModeUtility;
/// Level-of-service skims for one mode.
///
/// Each matrix is read with the same `(i, j)` positions as the total demand matrix passed to
/// `MultinomialLogit::split`. Those positions are assumed to line up with that matrix's
/// zone ids (NOTE(review): this is not checked here). The three values at a pair are passed,
/// in this order, to `ModeUtility::compute`.
#[derive(Debug)]
pub struct ModeSkim {
/// Travel time per OD pair (units not stated here — TODO confirm against the skim builder).
pub time: DenseOdMatrix,
/// Distance per OD pair; reference-counted, presumably so several modes can share one matrix.
pub distance: Rc<DenseOdMatrix>,
/// Monetary cost per OD pair; reference-counted, presumably so several modes can share one matrix.
pub cost: Rc<DenseOdMatrix>,
}
/// Multinomial logit mode-choice model with one utility specification per alternative.
#[derive(Debug)]
pub struct MultinomialLogit {
/// Alternatives of the choice set; each one's `agent_type` selects its skim and output matrix.
pub utilities: Vec<ModeUtility>,
}
impl MultinomialLogit {
    /// Creates a model from one utility specification per alternative mode.
    pub fn new(utilities: Vec<ModeUtility>) -> Self {
        MultinomialLogit { utilities }
    }

    /// Builds a three-alternative model (auto, bike, walk) with fixed coefficients.
    ///
    /// | mode | ASC  | time coefficient |
    /// |------|------|------------------|
    /// | Auto |  0.0 | -0.03            |
    /// | Bike | -1.0 | -0.05            |
    /// | Walk | -2.0 | -0.08            |
    ///
    /// Only the ASC and time coefficient are set here. Every other coefficient keeps whatever
    /// `ModeUtility::new` starts with. NOTE(review): the time unit these coefficients assume
    /// is not stated here — confirm against the skims.
    pub fn default_auto_bike_walk() -> Self {
        MultinomialLogit {
            utilities: vec![
                ModeUtility::new(AgentType::Auto)
                    .with_asc(0.0)
                    .with_coeff_time(-0.03)
                    .build(),
                ModeUtility::new(AgentType::Bike)
                    .with_asc(-1.0)
                    .with_coeff_time(-0.05)
                    .build(),
                ModeUtility::new(AgentType::Walk)
                    .with_asc(-2.0)
                    .with_coeff_time(-0.08)
                    .build(),
            ],
        }
    }

    /// Splits `total_od` into one demand matrix per mode using multinomial logit shares.
    ///
    /// For every OD pair with positive demand, each alternative's utility is computed from
    /// its skim's time, distance and cost at that pair. Demand is then divided in proportion
    /// to `exp(V_k)`, normalised over all alternatives.
    ///
    /// Special cases:
    /// - A NaN utility marks its alternative as unavailable (share 0).
    /// - If no alternative has a utility above `-inf`, the pair is left unassigned.
    /// - If several alternatives have utility `+inf`, they share the demand equally.
    ///
    /// # Returns
    /// A map with one matrix per distinct `agent_type` in `utilities`. Each matrix is built
    /// with `total_od`'s zone ids. Alternatives that share an agent type have their shares
    /// summed into the same matrix.
    ///
    /// Cells for pairs that are skipped keep the value `DenseOdMatrix::new` starts with.
    /// Presumably that is zero — confirm. Skipped pairs are those with non-positive demand
    /// or with no available mode.
    ///
    /// # Errors
    /// Returns `ModeChoiceError::MissingSkim`, carrying the `Debug` form of the agent type,
    /// for the first utility whose agent type has no entry in `skims`.
    pub fn split(
        &self,
        total_od: &dyn OdMatrix,
        skims: &HashMap<AgentType, ModeSkim>,
    ) -> Result<HashMap<AgentType, DenseOdMatrix>, ModeChoiceError> {
        let zone_ids = total_od.zone_ids().to_vec();
        let n = zone_ids.len();
        let num_modes = self.utilities.len();

        // Resolve every skim up front so a missing one fails before any matrix is allocated.
        let skim_refs: Vec<&ModeSkim> = self
            .utilities
            .iter()
            .map(|utility| {
                skims.get(&utility.agent_type).ok_or_else(|| {
                    ModeChoiceError::MissingSkim(format!("{:?}", utility.agent_type))
                })
            })
            .collect::<Result<Vec<_>, _>>()?;

        // Alternatives sharing an agent type write into one output matrix. Their shares then
        // add up, instead of a later map insert discarding an earlier alternative's demand.
        let mut output_types: Vec<AgentType> = Vec::with_capacity(num_modes);
        let mut output_slot: Vec<usize> = Vec::with_capacity(num_modes);
        for utility in &self.utilities {
            let slot = match output_types.iter().position(|t| *t == utility.agent_type) {
                Some(slot) => slot,
                None => {
                    output_types.push(utility.agent_type);
                    output_types.len() - 1
                }
            };
            output_slot.push(slot);
        }
        let mut result_matrices: Vec<DenseOdMatrix> = output_types
            .iter()
            .map(|_| DenseOdMatrix::new(zone_ids.clone()))
            .collect();

        // Reused for every OD pair. It first holds each alternative's utility, then its weight.
        let mut weights: Vec<f64> = vec![0.0; num_modes];
        for i in 0..n {
            for j in 0..n {
                let total_demand = total_od.get_by_index(i, j);
                if total_demand <= 0.0 {
                    continue;
                }
                let per_mode = weights
                    .iter_mut()
                    .zip(&self.utilities)
                    .zip(&skim_refs);
                for ((w, utility), skim) in per_mode {
                    let time = skim.time.get_by_index(i, j);
                    let distance = skim.distance.get_by_index(i, j);
                    let cost = skim.cost.get_by_index(i, j);
                    *w = utility.compute(time, distance, cost);
                }
                let weight_sum = match Self::logit_weights(&mut weights) {
                    Some(sum) => sum,
                    // No alternative is available for this pair; leave its cells untouched.
                    None => continue,
                };
                for (&w, &slot) in weights.iter().zip(&output_slot) {
                    let matrix = &mut result_matrices[slot];
                    let current = matrix.get_by_index(i, j);
                    matrix.set_by_index(i, j, current + total_demand * (w / weight_sum));
                }
            }
        }

        let result: HashMap<AgentType, DenseOdMatrix> =
            output_types.into_iter().zip(result_matrices).collect();
        log_main!(
            EVENT_MODE_CHOICE,
            "Mode choice split complete",
            zones = n,
            modes = self.utilities.len()
        );
        Ok(result)
    }

    /// Replaces each utility in `utilities` with its unnormalised logit weight `exp(v - v_max)`.
    ///
    /// Shifting by the maximum keeps `exp` from overflowing. Returns the sum of the weights,
    /// which is at least 1. Returns `None` when no alternative is available: the slice is
    /// empty, or every utility is `-inf` or NaN.
    fn logit_weights(utilities: &mut [f64]) -> Option<f64> {
        // A NaN utility would otherwise make every probability of the pair NaN,
        // so it is treated as an unavailable alternative.
        for v in utilities.iter_mut() {
            if v.is_nan() {
                *v = f64::NEG_INFINITY;
            }
        }
        let v_max = utilities.iter().copied().fold(f64::NEG_INFINITY, f64::max);
        if v_max == f64::NEG_INFINITY {
            return None;
        }
        let mut sum = 0.0;
        for v in utilities.iter_mut() {
            // Comparing first avoids `inf - inf = NaN` when the best utility is `+inf`.
            // For a finite maximum it gives the same result as `exp(0.0) == 1.0`.
            *v = if *v == v_max { 1.0 } else { (*v - v_max).exp() };
            sum += *v;
        }
        Some(sum)
    }
}