use super::{
cost_ops, network::NetworkCostRate, CostAggregation, CostFeature, TraversalCost,
VehicleCostRate,
};
use crate::model::cost::CostConstraint;
use crate::model::cost::CostModelError;
use crate::model::state::StateModel;
use crate::model::state::StateVariable;
use crate::model::traversal::EdgeTraversalContext;
use indexmap::IndexMap;
use itertools::Itertools;
use serde_json::json;
use std::collections::HashMap;
use std::sync::Arc;
/// Weighted, per-state-variable cost model.
///
/// [`CostModel::new`] builds one [`CostFeature`] for every variable in the supplied
/// [`StateModel`]. The features are stored in the order the state model iterates them.
/// The raw weight and rate mappings, plus the aggregation setting, are kept so that
/// [`CostModel::serialize_cost_info`] can describe the configuration.
pub struct CostModel {
    /// Per-variable cost features, keyed by state variable name.
    features: IndexMap<String, CostFeature>,
    /// Weights exactly as provided to [`CostModel::new`].
    weights_mapping: Arc<HashMap<String, f64>>,
    /// Vehicle cost rates exactly as provided to [`CostModel::new`].
    vehicle_rate_mapping: Arc<HashMap<String, VehicleCostRate>>,
    /// Network cost rates exactly as provided to [`CostModel::new`].
    network_rate_mapping: Arc<HashMap<String, NetworkCostRate>>,
    /// Aggregation setting; in this type it is only read by `serialize_cost_info`.
    cost_aggregation: CostAggregation,
    /// Constraint applied whenever a feature cost is inserted into a [`TraversalCost`].
    cost_constraint: CostConstraint,
}

impl CostModel {
    // JSON keys emitted by `serialize_cost_info`.
    const INDEX: &'static str = "index";
    const VEHICLE_RATE: &'static str = "vehicle_rate";
    const NETWORK_RATE: &'static str = "network_rate";
    const WEIGHT: &'static str = "weight";
    const COST_AGGREGATION: &'static str = "cost_aggregation";
    const DESCRIPTION: &'static str = "description";

    /// Builds a cost model with one feature per state variable.
    ///
    /// For each variable in `state_model`, the matching entries (if any) in
    /// `weights_mapping`, `vehicle_rate_mapping` and `network_rate_mapping` are passed
    /// to [`CostFeature::new`], together with the variable's accumulator flag.
    ///
    /// # Errors
    ///
    /// - [`CostModelError::InvalidWeightNames`] if `weights_mapping` contains a name
    ///   that is not a state variable. The error carries the unknown names and the
    ///   state model's variable names.
    /// - [`CostModelError::InvalidCostVariables`] if the weights of all built features
    ///   sum to exactly `0.0`.
    pub fn new(
        weights_mapping: Arc<HashMap<String, f64>>,
        vehicle_rate_mapping: Arc<HashMap<String, VehicleCostRate>>,
        network_rate_mapping: Arc<HashMap<String, NetworkCostRate>>,
        cost_aggregation: CostAggregation,
        state_model: Arc<StateModel>,
        cost_constraint: CostConstraint,
    ) -> Result<CostModel, CostModelError> {
        // A weight for an unknown variable is rejected rather than silently ignored.
        let unknown_names = weights_mapping
            .keys()
            .filter(|name| !state_model.contains_key(name))
            .collect_vec();
        if !unknown_names.is_empty() {
            let unknown = unknown_names.iter().map(|name| name.to_string()).collect();
            let known = state_model.keys().cloned().collect_vec();
            return Err(CostModelError::InvalidWeightNames(unknown, known));
        }

        let mut features = IndexMap::new();
        let mut weight_sum = 0.0;
        for (name, config) in state_model.iter() {
            // CostFeature::new decides what a variable absent from a mapping receives.
            let feature = CostFeature::new(
                name.clone(),
                weights_mapping.get(name),
                vehicle_rate_mapping.get(name),
                network_rate_mapping.get(name),
                config.is_accumulator(),
            );
            weight_sum += feature.weight;
            features.insert(name.clone(), feature);
        }

        // Reject a model whose feature weights sum to zero.
        if weight_sum == 0.0 {
            return Err(CostModelError::InvalidCostVariables(vec![]));
        }

        Ok(Self {
            features,
            weights_mapping,
            vehicle_rate_mapping,
            network_rate_mapping,
            cost_aggregation,
            cost_constraint,
        })
    }

    /// Computes the cost of traversing the edge described by `ctx`.
    ///
    /// Each feature's cost is the sum of a vehicle term (from its vehicle cost rate)
    /// and a network term (from its network cost rate):
    ///
    /// - accumulator features use the change between states, i.e.
    ///   `rate(current_state) - rate(previous_state)` for both terms;
    /// - all other features are evaluated on `current_state` only.
    ///
    /// Each feature cost is inserted into the result with the feature's weight and
    /// this model's [`CostConstraint`].
    ///
    /// NOTE(review): for accumulator features the network term is also a delta taken
    /// with the same `ctx`. If a network rate ignores its state argument, this delta is
    /// always zero. Confirm that this is intended.
    ///
    /// # Errors
    ///
    /// Propagates the first error raised by a vehicle or network cost rate.
    pub fn traversal_cost(
        &self,
        ctx: &EdgeTraversalContext,
        previous_state: &[StateVariable],
        current_state: &[StateVariable],
        state_model: &StateModel,
    ) -> Result<TraversalCost, CostModelError> {
        let mut traversal = TraversalCost::empty();
        for (name, feature) in &self.features {
            let vehicle = &feature.vehicle_cost_rate;
            let network = &feature.network_cost_rate;
            let feature_cost = if feature.is_accumulator {
                // Rates are evaluated vehicle-then-network and current-then-previous,
                // so the first failing rate decides which error is returned.
                let vehicle_now = vehicle.compute_cost(name, current_state, state_model)?;
                let vehicle_before = vehicle.compute_cost(name, previous_state, state_model)?;
                let network_now = network.network_cost(ctx, current_state, state_model)?;
                let network_before = network.network_cost(ctx, previous_state, state_model)?;
                (vehicle_now - vehicle_before) + (network_now - network_before)
            } else {
                let vehicle_now = vehicle.compute_cost(name, current_state, state_model)?;
                let network_now = network.network_cost(ctx, current_state, state_model)?;
                vehicle_now + network_now
            };
            traversal.insert(name, feature_cost, feature.weight, self.cost_constraint);
        }
        Ok(traversal)
    }

    /// Estimates a cost from a single state, using only the vehicle cost rates.
    ///
    /// Network cost rates are not consulted. Each feature's cost is inserted with its
    /// weight and this model's [`CostConstraint`].
    ///
    /// # Errors
    ///
    /// Propagates the first error raised by a vehicle cost rate.
    pub fn estimate_cost(
        &self,
        state: &[StateVariable],
        state_model: &StateModel,
    ) -> Result<TraversalCost, CostModelError> {
        let mut estimate = TraversalCost::empty();
        for (name, feature) in &self.features {
            let vehicle_cost = feature
                .vehicle_cost_rate
                .compute_cost(name, state, state_model)?;
            estimate.insert(name, vehicle_cost, feature.weight, self.cost_constraint);
        }
        Ok(estimate)
    }

    /// Describes this model's configuration as a JSON object.
    ///
    /// The object has one entry per feature, keyed by feature name. Each entry holds:
    /// `weight`, `vehicle_rate`, `network_rate` (the rate type), `index` (position in
    /// the feature order) and a `description`. There is also a top-level
    /// `cost_aggregation` entry.
    ///
    /// NOTE(review): `cost_aggregation` is inserted after the features, so a feature
    /// with that exact name would be overwritten. The current body never returns `Err`.
    pub fn serialize_cost_info(&self) -> Result<serde_json::Value, CostModelError> {
        let mut info = serde_json::Map::with_capacity(self.features.len());
        for (index, (name, feature)) in self.features.iter().enumerate() {
            let description = cost_ops::describe_cost_feature_configuration(
                name,
                Arc::clone(&self.weights_mapping),
                Arc::clone(&self.vehicle_rate_mapping),
                Arc::clone(&self.network_rate_mapping),
            );
            let entry = json!({
                Self::WEIGHT: feature.weight,
                Self::VEHICLE_RATE: feature.vehicle_cost_rate,
                Self::NETWORK_RATE: feature.network_cost_rate.rate_type(),
                Self::INDEX: index,
                Self::DESCRIPTION: description,
            });
            info.insert(name.clone(), entry);
        }
        info.insert(
            Self::COST_AGGREGATION.to_string(),
            json!(self.cost_aggregation),
        );
        Ok(serde_json::Value::Object(info))
    }
}
#[cfg(test)]
mod test {
    //! Unit tests for `CostModel` construction, traversal costs and estimate costs.
    use super::*;
    use crate::algorithm::search::{Direction, SearchTree};
    use crate::model::label::Label;
    use crate::model::network::{Edge, EdgeId, EdgeListId, Vertex, VertexId};
    use crate::model::state::StateVariableConfig;
    use crate::model::unit::{AsF64, Cost, DistanceUnit, TimeUnit};
    use crate::util::geo::InternalCoord;
    use geo::coord;
    use std::collections::HashMap;
    use std::sync::Arc;
    use uom::si::f64::*;
    use uom::si::length::meter;
    use uom::si::time::second;

    /// Builds a vertex at the origin with the given id.
    fn create_vertex(id: VertexId) -> Vertex {
        Vertex {
            vertex_id: id,
            coordinate: InternalCoord(coord! {x: 0.0, y: 0.0}),
        }
    }

    /// Builds a 100 m edge from `src` to `dst` in edge list 0.
    fn create_edge(id: EdgeId, src: VertexId, dst: VertexId) -> Edge {
        Edge {
            edge_list_id: EdgeListId(0),
            edge_id: id,
            src_vertex_id: src,
            dst_vertex_id: dst,
            distance: Length::new::<meter>(100.0),
        }
    }

    /// Builds an empty forward search tree for traversal contexts.
    fn create_test_tree() -> SearchTree {
        SearchTree::new_stateful(Direction::Forward)
    }

    #[test]
    fn test_cost_model_new_valid_weights() {
        let features = vec![(
            "distance".to_string(),
            StateVariableConfig::Distance {
                initial: Length::new::<meter>(0.0),
                accumulator: true,
                output_unit: Some(DistanceUnit::Meters),
            },
        )];
        let state_model = Arc::new(StateModel::new(features));
        let mut weights = HashMap::new();
        weights.insert("distance".to_string(), 1.0);
        let weights = Arc::new(weights);
        let mut vehicle_rates = HashMap::new();
        vehicle_rates.insert(
            "distance".to_string(),
            VehicleCostRate::Distance {
                factor: 1.0,
                unit: DistanceUnit::Meters,
            },
        );
        let vehicle_rates = Arc::new(vehicle_rates);
        let network_rates = Arc::new(HashMap::new());
        let cost_aggregation = CostAggregation::Sum;
        let result = CostModel::new(
            weights,
            vehicle_rates,
            network_rates,
            cost_aggregation,
            state_model,
            CostConstraint::StrictlyPositive,
        );
        assert!(result.is_ok());
    }

    #[test]
    fn test_cost_model_new_invalid_weight_names() {
        let features = vec![(
            "distance".to_string(),
            StateVariableConfig::Distance {
                initial: Length::new::<meter>(0.0),
                accumulator: true,
                output_unit: Some(DistanceUnit::Meters),
            },
        )];
        let state_model = Arc::new(StateModel::new(features));
        let mut weights = HashMap::new();
        weights.insert("invalid_feature".to_string(), 1.0);
        let weights = Arc::new(weights);
        let vehicle_rates = Arc::new(HashMap::new());
        let network_rates = Arc::new(HashMap::new());
        let cost_aggregation = CostAggregation::Sum;
        let result = CostModel::new(
            weights,
            vehicle_rates,
            network_rates,
            cost_aggregation,
            state_model,
            CostConstraint::StrictlyPositive,
        );
        assert!(matches!(
            result,
            Err(CostModelError::InvalidWeightNames(_, _))
        ));
    }

    #[test]
    fn test_cost_model_new_zero_total_weight() {
        let features = vec![(
            "distance".to_string(),
            StateVariableConfig::Distance {
                initial: Length::new::<meter>(0.0),
                accumulator: true,
                output_unit: Some(DistanceUnit::Meters),
            },
        )];
        let state_model = Arc::new(StateModel::new(features));
        let mut weights = HashMap::new();
        weights.insert("distance".to_string(), 0.0);
        let weights = Arc::new(weights);
        let vehicle_rates = Arc::new(HashMap::new());
        let network_rates = Arc::new(HashMap::new());
        let cost_aggregation = CostAggregation::Sum;
        let result = CostModel::new(
            weights,
            vehicle_rates,
            network_rates,
            cost_aggregation,
            state_model,
            CostConstraint::StrictlyPositive,
        );
        assert!(matches!(
            result,
            Err(CostModelError::InvalidCostVariables(_))
        ));
    }

    #[test]
    fn test_traversal_cost_accumulator_computes_delta() {
        let features = vec![(
            "distance".to_string(),
            StateVariableConfig::Distance {
                initial: Length::new::<meter>(0.0),
                accumulator: true,
                output_unit: Some(DistanceUnit::Meters),
            },
        )];
        let state_model = Arc::new(StateModel::new(features));
        let mut weights = HashMap::new();
        weights.insert("distance".to_string(), 1.0);
        let weights = Arc::new(weights);
        let mut vehicle_rates = HashMap::new();
        vehicle_rates.insert(
            "distance".to_string(),
            VehicleCostRate::Distance {
                factor: 1.0,
                unit: DistanceUnit::Meters,
            },
        );
        let vehicle_rates = Arc::new(vehicle_rates);
        let network_rates = Arc::new(HashMap::new());
        let cost_aggregation = CostAggregation::Sum;
        let cost_model = CostModel::new(
            weights,
            vehicle_rates,
            network_rates,
            cost_aggregation,
            state_model.clone(),
            CostConstraint::StrictlyPositive,
        )
        .expect("Failed to create cost model");
        let previous_state = vec![StateVariable(100.0)];
        let current_state = vec![StateVariable(150.0)];
        let v1 = create_vertex(VertexId(0));
        let l = Label::Vertex(v1.vertex_id);
        let v2 = create_vertex(VertexId(1));
        let e = create_edge(EdgeId(0), VertexId(0), VertexId(1));
        let tree = create_test_tree();
        let ctx = EdgeTraversalContext::new(&l, &v1, &e, &v2, &tree);
        let result = cost_model
            .traversal_cost(&ctx, &previous_state, &current_state, &state_model)
            .expect("Failed to compute traversal cost");
        assert!(result.edge_cost.as_f64() > 0.0);
        assert_eq!(result.edge_cost, result.objective_cost);
    }

    #[test]
    fn test_traversal_cost_non_accumulator_uses_current_value() {
        let features = vec![(
            "speed".to_string(),
            StateVariableConfig::Speed {
                initial: Velocity::default(),
                accumulator: false,
                output_unit: None,
            },
        )];
        let state_model = Arc::new(StateModel::new(features));
        let mut weights = HashMap::new();
        weights.insert("speed".to_string(), 2.0);
        let weights = Arc::new(weights);
        let mut vehicle_rates = HashMap::new();
        vehicle_rates.insert("speed".to_string(), VehicleCostRate::Raw);
        let vehicle_rates = Arc::new(vehicle_rates);
        let network_rates = Arc::new(HashMap::new());
        let cost_aggregation = CostAggregation::Sum;
        let cost_model = CostModel::new(
            weights,
            vehicle_rates,
            network_rates,
            cost_aggregation,
            state_model.clone(),
            CostConstraint::StrictlyPositive,
        )
        .expect("Failed to create cost model");
        let previous_state = vec![StateVariable(30.0)];
        let current_state = vec![StateVariable(25.0)];
        let v1 = create_vertex(VertexId(0));
        let l = Label::Vertex(v1.vertex_id);
        let v2 = create_vertex(VertexId(1));
        let e = create_edge(EdgeId(0), VertexId(0), VertexId(1));
        let tree = create_test_tree();
        let ctx = EdgeTraversalContext::new(&l, &v1, &e, &v2, &tree);
        let result = cost_model
            .traversal_cost(&ctx, &previous_state, &current_state, &state_model)
            .expect("Failed to compute traversal cost");
        assert_eq!(result.edge_cost, Cost::new(25.0));
        assert_eq!(result.objective_cost, Cost::new(50.0));
    }

    #[test]
    fn test_traversal_cost_multiple_features_mixed() {
        let features = vec![
            (
                "distance".to_string(),
                StateVariableConfig::Distance {
                    initial: Length::new::<meter>(0.0),
                    accumulator: true,
                    output_unit: Some(DistanceUnit::Meters),
                },
            ),
            (
                "time".to_string(),
                StateVariableConfig::Time {
                    initial: Time::new::<second>(0.0),
                    accumulator: true,
                    output_unit: Some(TimeUnit::Seconds),
                },
            ),
            (
                "speed".to_string(),
                StateVariableConfig::Speed {
                    initial: Velocity::default(),
                    accumulator: false,
                    output_unit: None,
                },
            ),
        ];
        let state_model = Arc::new(StateModel::new(features));
        let mut weights = HashMap::new();
        weights.insert("distance".to_string(), 1.0);
        weights.insert("time".to_string(), 2.0);
        weights.insert("speed".to_string(), 0.5);
        let weights = Arc::new(weights);
        let mut vehicle_rates = HashMap::new();
        vehicle_rates.insert(
            "distance".to_string(),
            VehicleCostRate::Distance {
                factor: 1.0,
                unit: DistanceUnit::Meters,
            },
        );
        vehicle_rates.insert(
            "time".to_string(),
            VehicleCostRate::Time {
                factor: 1.0,
                unit: TimeUnit::Seconds,
            },
        );
        vehicle_rates.insert("speed".to_string(), VehicleCostRate::Raw);
        let vehicle_rates = Arc::new(vehicle_rates);
        let network_rates = Arc::new(HashMap::new());
        let cost_aggregation = CostAggregation::Sum;
        let cost_model = CostModel::new(
            weights,
            vehicle_rates,
            network_rates,
            cost_aggregation,
            state_model.clone(),
            CostConstraint::StrictlyPositive,
        )
        .expect("Failed to create cost model");
        let previous_state = vec![
            StateVariable(100.0),
            StateVariable(50.0),
            StateVariable(60.0),
        ];
        let current_state = vec![
            StateVariable(200.0),
            StateVariable(80.0),
            StateVariable(45.0),
        ];
        let v1 = create_vertex(VertexId(0));
        let l = Label::Vertex(v1.vertex_id);
        let v2 = create_vertex(VertexId(1));
        let e = create_edge(EdgeId(0), VertexId(0), VertexId(1));
        let tree = create_test_tree();
        let ctx = EdgeTraversalContext::new(&l, &v1, &e, &v2, &tree);
        let result = cost_model
            .traversal_cost(&ctx, &previous_state, &current_state, &state_model)
            .expect("Failed to compute traversal cost");
        assert!(result.edge_cost.as_f64() > 0.0);
        assert!(result.objective_cost.as_f64() > 0.0);
        assert_ne!(result.edge_cost, result.objective_cost);
    }

    #[test]
    fn test_traversal_cost_with_network_rates() {
        let features = vec![(
            "distance".to_string(),
            StateVariableConfig::Distance {
                initial: Length::new::<meter>(0.0),
                accumulator: true,
                output_unit: Some(DistanceUnit::Meters),
            },
        )];
        let state_model = Arc::new(StateModel::new(features));
        let mut weights = HashMap::new();
        weights.insert("distance".to_string(), 1.0);
        let weights = Arc::new(weights);
        let mut vehicle_rates = HashMap::new();
        vehicle_rates.insert(
            "distance".to_string(),
            VehicleCostRate::Distance {
                factor: 1.0,
                unit: DistanceUnit::Meters,
            },
        );
        let vehicle_rates = Arc::new(vehicle_rates);
        let mut edge_costs = HashMap::new();
        edge_costs.insert(EdgeId(0), Cost::new(10.0));
        let mut network_rates = HashMap::new();
        network_rates.insert(
            "distance".to_string(),
            NetworkCostRate::EdgeLookup { lookup: edge_costs },
        );
        let network_rates = Arc::new(network_rates);
        let cost_aggregation = CostAggregation::Sum;
        let cost_model = CostModel::new(
            weights,
            vehicle_rates,
            network_rates,
            cost_aggregation,
            state_model.clone(),
            CostConstraint::StrictlyPositive,
        )
        .expect("Failed to create cost model");
        let previous_state = vec![StateVariable(100.0)];
        let current_state = vec![StateVariable(150.0)];
        let v1 = create_vertex(VertexId(0));
        let l = Label::Vertex(v1.vertex_id);
        let v2 = create_vertex(VertexId(1));
        let e = create_edge(EdgeId(0), VertexId(0), VertexId(1));
        let tree = create_test_tree();
        let ctx = EdgeTraversalContext::new(&l, &v1, &e, &v2, &tree);
        let result = cost_model
            .traversal_cost(&ctx, &previous_state, &current_state, &state_model)
            .expect("Failed to compute traversal cost");
        assert!(result.edge_cost.as_f64() > 0.0);
        assert_eq!(result.edge_cost, result.objective_cost);
    }

    #[test]
    fn test_estimate_cost() {
        let features = vec![(
            "distance".to_string(),
            StateVariableConfig::Distance {
                initial: Length::new::<meter>(0.0),
                accumulator: true,
                output_unit: Some(DistanceUnit::Meters),
            },
        )];
        let state_model = Arc::new(StateModel::new(features));
        let mut weights = HashMap::new();
        weights.insert("distance".to_string(), 3.0);
        let weights = Arc::new(weights);
        let mut vehicle_rates = HashMap::new();
        vehicle_rates.insert(
            "distance".to_string(),
            VehicleCostRate::Distance {
                factor: 1.0,
                unit: DistanceUnit::Meters,
            },
        );
        let vehicle_rates = Arc::new(vehicle_rates);
        let network_rates = Arc::new(HashMap::new());
        let cost_aggregation = CostAggregation::Sum;
        let cost_model = CostModel::new(
            weights,
            vehicle_rates,
            network_rates,
            cost_aggregation,
            state_model.clone(),
            CostConstraint::StrictlyPositive,
        )
        .expect("Failed to create cost model");
        let state = vec![StateVariable(100.0)];
        let result = cost_model
            .estimate_cost(&state, &state_model)
            .expect("Failed to estimate cost");
        assert_eq!(
            result.objective_cost.as_f64(),
            result.edge_cost.as_f64() * 3.0
        );
    }

    /// A single feature with weight 0 must be rejected at construction time, even
    /// when it has a vehicle rate configured.
    #[test]
    fn test_zero_weight_feature_with_vehicle_rate_is_rejected() {
        let features = vec![(
            "distance".to_string(),
            StateVariableConfig::Distance {
                initial: Length::new::<meter>(0.0),
                accumulator: true,
                output_unit: Some(DistanceUnit::Meters),
            },
        )];
        let state_model = Arc::new(StateModel::new(features));
        let mut weights = HashMap::new();
        weights.insert("distance".to_string(), 0.0);
        let weights = Arc::new(weights);
        let mut vehicle_rates = HashMap::new();
        vehicle_rates.insert(
            "distance".to_string(),
            VehicleCostRate::Distance {
                factor: 1.0,
                unit: DistanceUnit::Meters,
            },
        );
        let vehicle_rates = Arc::new(vehicle_rates);
        let network_rates = Arc::new(HashMap::new());
        let cost_aggregation = CostAggregation::Sum;
        let result = CostModel::new(
            weights,
            vehicle_rates,
            network_rates,
            cost_aggregation,
            state_model.clone(),
            CostConstraint::StrictlyPositive,
        );
        assert!(matches!(
            result,
            Err(CostModelError::InvalidCostVariables(_))
        ));
    }

    #[test]
    fn test_accumulator_network_cost_delta() {
        let features = vec![(
            "distance".to_string(),
            StateVariableConfig::Distance {
                initial: Length::new::<meter>(0.0),
                accumulator: true,
                output_unit: Some(DistanceUnit::Meters),
            },
        )];
        let state_model = Arc::new(StateModel::new(features));
        let mut weights = HashMap::new();
        weights.insert("distance".to_string(), 1.0);
        let weights = Arc::new(weights);
        let mut vehicle_rates = HashMap::new();
        vehicle_rates.insert("distance".to_string(), VehicleCostRate::Zero);
        let vehicle_rates = Arc::new(vehicle_rates);
        let mut vertex_costs = HashMap::new();
        vertex_costs.insert(VertexId(0), Cost::new(5.0));
        vertex_costs.insert(VertexId(1), Cost::new(15.0));
        let mut network_rates = HashMap::new();
        network_rates.insert(
            "distance".to_string(),
            NetworkCostRate::VertexLookup {
                lookup: vertex_costs,
            },
        );
        let network_rates = Arc::new(network_rates);
        let cost_aggregation = CostAggregation::Sum;
        let cost_model = CostModel::new(
            weights,
            vehicle_rates,
            network_rates,
            cost_aggregation,
            state_model.clone(),
            CostConstraint::StrictlyPositive,
        )
        .expect("Failed to create cost model");
        let previous_state = vec![StateVariable(100.0)];
        let current_state = vec![StateVariable(150.0)];
        let v1 = create_vertex(VertexId(0));
        let l = Label::Vertex(v1.vertex_id);
        let v2 = create_vertex(VertexId(1));
        let e = create_edge(EdgeId(0), VertexId(0), VertexId(1));
        let tree = create_test_tree();
        let ctx = EdgeTraversalContext::new(&l, &v1, &e, &v2, &tree);
        let result = cost_model
            .traversal_cost(&ctx, &previous_state, &current_state, &state_model)
            .expect("Failed to compute traversal cost");
        assert!(result.edge_cost.as_f64().abs() < 1e-6);
    }

    #[test]
    fn test_accumulator_vs_non_accumulator_difference() {
        let features_acc = vec![(
            "test_feature".to_string(),
            StateVariableConfig::Distance {
                initial: Length::new::<meter>(0.0),
                accumulator: true,
                output_unit: Some(DistanceUnit::Meters),
            },
        )];
        let state_model_acc = Arc::new(StateModel::new(features_acc));
        let mut weights = HashMap::new();
        weights.insert("test_feature".to_string(), 1.0);
        let weights_acc = Arc::new(weights.clone());
        let mut vehicle_rates = HashMap::new();
        vehicle_rates.insert("test_feature".to_string(), VehicleCostRate::Raw);
        let vehicle_rates_acc = Arc::new(vehicle_rates.clone());
        let cost_model_acc = CostModel::new(
            weights_acc,
            vehicle_rates_acc,
            Arc::new(HashMap::new()),
            CostAggregation::Sum,
            state_model_acc.clone(),
            CostConstraint::StrictlyPositive,
        )
        .expect("Failed to create accumulator cost model");
        let features_non_acc = vec![(
            "test_feature".to_string(),
            StateVariableConfig::Distance {
                initial: Length::new::<meter>(0.0),
                accumulator: false,
                output_unit: Some(DistanceUnit::Meters),
            },
        )];
        let state_model_non_acc = Arc::new(StateModel::new(features_non_acc));
        let weights_non_acc = Arc::new(weights);
        let vehicle_rates_non_acc = Arc::new(vehicle_rates);
        let cost_model_non_acc = CostModel::new(
            weights_non_acc,
            vehicle_rates_non_acc,
            Arc::new(HashMap::new()),
            CostAggregation::Sum,
            state_model_non_acc.clone(),
            CostConstraint::StrictlyPositive,
        )
        .expect("Failed to create non-accumulator cost model");
        let previous_state = vec![StateVariable(100.0)];
        let current_state = vec![StateVariable(150.0)];
        let v1 = create_vertex(VertexId(0));
        let l = Label::Vertex(v1.vertex_id);
        let v2 = create_vertex(VertexId(1));
        let e = create_edge(EdgeId(0), VertexId(0), VertexId(1));
        let tree = create_test_tree();
        let ctx = EdgeTraversalContext::new(&l, &v1, &e, &v2, &tree);
        let result_acc = cost_model_acc
            .traversal_cost(&ctx, &previous_state, &current_state, &state_model_acc)
            .expect("Failed to compute accumulator cost");
        let result_non_acc = cost_model_non_acc
            .traversal_cost(&ctx, &previous_state, &current_state, &state_model_non_acc)
            .expect("Failed to compute non-accumulator cost");
        let expected_delta = current_state[0].0 - previous_state[0].0;
        assert_eq!(result_acc.edge_cost, Cost::new(expected_delta));
        assert_eq!(result_non_acc.edge_cost, Cost::new(current_state[0].0));
        assert_ne!(result_acc.edge_cost, result_non_acc.edge_cost);
        // With these inputs the current value (150) is exactly three times the delta (50).
        assert_eq!(
            result_non_acc.edge_cost.as_f64(),
            result_acc.edge_cost.as_f64() * 3.0
        );
    }
}