use super::{
cost_aggregation::CostAggregation, cost_error::CostError,
network::network_cost_rate::NetworkCostRate, vehicle::vehicle_cost_rate::VehicleCostRate,
};
use crate::model::{property::edge::Edge, traversal::state::state_variable::StateVar, unit::Cost};
/// Computes the aggregated, weighted vehicle cost of a single state transition.
///
/// For every `(name, index)` pair in `indices`, the change in that state variable
/// (`next - prev`) is converted to a cost by `rates[index]` (via `map_value`) and
/// then scaled by `weights[index]`. The per-feature `(name, cost)` pairs are combined
/// by `cost_aggregation.agg_iter`.
///
/// # Arguments
///
/// * `state_sequence` - `(prev_state, next_state)` for the transition.
/// * `indices` - feature names paired with their position in the state vectors.
/// * `weights` - per-feature multipliers, looked up by state index.
/// * `rates` - per-feature vehicle cost rates, looked up by state index.
/// * `cost_aggregation` - strategy used to combine the per-feature costs.
///
/// # Errors
///
/// Each feature can produce:
/// * `CostError::StateIndexOutOfBounds` if the index is missing from either state;
/// * `CostError::CostVectorOutOfBounds` tagged `"vehicle_rates"` if `rates` lacks the index;
/// * `CostError::CostVectorOutOfBounds` tagged `"weights"` if `weights` lacks the index.
///
/// These failures are yielded as `Err` items to `agg_iter`, whose result is returned.
pub fn calculate_vehicle_costs(
    state_sequence: (&[StateVar], &[StateVar]),
    indices: &[(String, usize)],
    weights: &[f64],
    rates: &[VehicleCostRate],
    cost_aggregation: &CostAggregation,
) -> Result<Cost, CostError> {
    let (prev_state, next_state) = state_sequence;
    let per_feature = indices.iter().map(|(name, index)| {
        let index = *index;
        // Both lookups report the same error, so they can be checked together.
        let (before, after) = prev_state
            .get(index)
            .zip(next_state.get(index))
            .ok_or_else(|| CostError::StateIndexOutOfBounds(index, name.clone()))?;
        let delta: StateVar = *after - *before;
        let rate = rates.get(index).ok_or_else(|| {
            CostError::CostVectorOutOfBounds(index, String::from("vehicle_rates"))
        })?;
        let weight = weights
            .get(index)
            .ok_or_else(|| CostError::CostVectorOutOfBounds(index, String::from("weights")))?;
        let weighted = rate.map_value(delta) * weight;
        Ok((name, weighted))
    });
    cost_aggregation.agg_iter(per_feature)
}
/// Computes the aggregated, weighted network cost of traversing a single edge.
///
/// For every `(name, index)` pair in `indices`, `rates[index].traversal_cost` is
/// evaluated with the previous and next values of that state variable and `edge`.
/// The result is scaled by `weights[index]`. The per-feature `(name, cost)` pairs are
/// combined by `cost_aggregation.agg_iter`.
///
/// # Arguments
///
/// * `state_sequence` - `(prev_state, next_state)` surrounding the traversal.
/// * `edge` - the edge being traversed.
/// * `indices` - feature names paired with their position in the state vectors.
/// * `weights` - per-feature multipliers, looked up by state index.
/// * `rates` - per-feature network cost rates, looked up by state index.
/// * `cost_aggregation` - strategy used to combine the per-feature costs.
///
/// # Errors
///
/// Each feature can produce:
/// * `CostError::StateIndexOutOfBounds` if the index is missing from either state;
/// * `CostError::CostVectorOutOfBounds` tagged `"weights"` if `weights` lacks the index
///   (checked before the rate lookup);
/// * `CostError::CostVectorOutOfBounds` tagged `"network_cost_rate"` if `rates` lacks it;
/// * any error returned by `traversal_cost`.
///
/// These failures are yielded as `Err` items to `agg_iter`, whose result is returned.
pub fn calculate_network_traversal_costs(
    state_sequence: (&[StateVar], &[StateVar]),
    edge: &Edge,
    indices: &[(String, usize)],
    weights: &[f64],
    rates: &[NetworkCostRate],
    cost_aggregation: &CostAggregation,
) -> Result<Cost, CostError> {
    let (prev_state, next_state) = state_sequence;
    let per_feature = indices.iter().map(|(name, index)| {
        let index = *index;
        // Both lookups report the same error, so they can be checked together.
        let (before, after) = prev_state
            .get(index)
            .zip(next_state.get(index))
            .ok_or_else(|| CostError::StateIndexOutOfBounds(index, name.clone()))?;
        let weight = weights
            .get(index)
            .ok_or_else(|| CostError::CostVectorOutOfBounds(index, String::from("weights")))?;
        let rate = rates.get(index).ok_or_else(|| {
            CostError::CostVectorOutOfBounds(index, String::from("network_cost_rate"))
        })?;
        let traversal_cost = rate.traversal_cost(*before, *after, edge)?;
        Ok((name, traversal_cost * weight))
    });
    cost_aggregation.agg_iter(per_feature)
}
/// Computes the aggregated, weighted network cost of moving from one edge to the next.
///
/// For every `(name, index)` pair in `indices`:
/// * if `rates` has no entry at `index`, the feature contributes `Cost::ZERO`, and its
///   state and weight entries are not inspected;
/// * otherwise `rates[index].access_cost` is evaluated with the previous and next values
///   of that state variable and both edges. The result is scaled by `weights[index]`,
///   which defaults to `1.0` when absent.
///
/// The per-feature `(name, cost)` pairs are combined by `cost_aggregation.agg_iter`.
///
/// # Arguments
///
/// * `state_sequence` - `(prev_state, next_state)` surrounding the edge transition.
/// * `edge_sequence` - `(prev_edge, next_edge)`.
/// * `indices` - feature names paired with their position in the state vectors.
/// * `weights` - optional per-feature multipliers, looked up by state index.
/// * `rates` - per-feature network cost rates, looked up by state index.
/// * `cost_aggregation` - strategy used to combine the per-feature costs.
///
/// # Errors
///
/// For features that have a rate, each can produce:
/// * `CostError::StateIndexOutOfBounds` if the index is missing from either state;
/// * any error returned by `access_cost`.
///
/// These failures are yielded as `Err` items to `agg_iter`, whose result is returned.
pub fn calculate_network_access_costs(
    state_sequence: (&[StateVar], &[StateVar]),
    edge_sequence: (&Edge, &Edge),
    indices: &[(String, usize)],
    weights: &[f64],
    rates: &[NetworkCostRate],
    cost_aggregation: &CostAggregation,
) -> Result<Cost, CostError> {
    let (prev_state, next_state) = state_sequence;
    let (prev_edge, next_edge) = edge_sequence;
    let per_feature = indices.iter().map(|(name, index)| {
        let index = *index;
        // A feature without a configured rate contributes nothing, and is not an error.
        let rate = match rates.get(index) {
            Some(rate) => rate,
            None => return Ok((name, Cost::ZERO)),
        };
        // Both lookups report the same error, so they can be checked together.
        let (before, after) = prev_state
            .get(index)
            .zip(next_state.get(index))
            .ok_or_else(|| CostError::StateIndexOutOfBounds(index, name.clone()))?;
        let access_cost = rate.access_cost(*before, *after, prev_edge, next_edge)?;
        let weight = weights.get(index).unwrap_or(&1.0);
        Ok((name, access_cost * weight))
    });
    cost_aggregation.agg_iter(per_feature)
}