use crate::config::{CompassConfigurationError, ConfigJsonExtensions};
use crate::model::cost::{CostModelConfig, CostModelError};
use crate::model::{
cost::{network::NetworkCostRate, CostAggregation, CostModel, VehicleCostRate},
state::StateModel,
};
use std::{
collections::{HashMap, HashSet},
sync::Arc,
};
/// Service-level defaults used to assemble a `CostModel` for each incoming query.
///
/// The maps are wrapped in `Arc` so they can be handed to a `CostModel` without
/// copying when a query does not override them.
pub struct CostModelService {
/// Default vehicle cost rates keyed by name. Entries found under the query's
/// `vehicle_rates` field are merged over these, key by key, when building a model.
pub vehicle_rates: Arc<HashMap<String, VehicleCostRate>>,
/// Network cost rates keyed by name. These are passed to every built `CostModel`
/// as-is; queries are not consulted for them.
pub network_rates: Arc<HashMap<String, NetworkCostRate>>,
/// Default weights keyed by name. Entries found under the query's `weights` field
/// are merged over these, key by key; the resulting names are compared against
/// the names exposed by the state model.
pub weights: Arc<HashMap<String, f64>>,
/// Cost aggregation used when the query does not provide `cost_aggregation`.
pub cost_aggregation: CostAggregation,
/// When `false`, building a model fails if the effective weights contain a name
/// that the state model does not provide. When `true`, such names are tolerated.
pub ignore_unknown_weights: bool,
}
impl CostModelService {
    /// Builds a [`CostModel`] for a single query.
    ///
    /// Service-level defaults are combined with optional per-query overrides:
    ///
    /// * `weights` and `vehicle_rates` found in `query` are merged over the service
    ///   defaults key by key: a query entry replaces the default with the same name,
    ///   every other default is kept.
    /// * `cost_aggregation` found in `query` replaces the service default.
    /// * Network rates always come from the service.
    ///
    /// # Arguments
    ///
    /// * `query` - JSON query that may carry `weights`, `vehicle_rates` and
    ///   `cost_aggregation` overrides.
    /// * `state_model` - state model whose feature names the effective weights are
    ///   validated against; it is handed on to the resulting [`CostModel`].
    ///
    /// # Errors
    ///
    /// * Propagates any error produced while reading `weights`, `vehicle_rates` or
    ///   `cost_aggregation` from `query`.
    /// * Returns [`CompassConfigurationError::UserConfigurationError`] when
    ///   `ignore_unknown_weights` is `false` and the effective weights name something
    ///   the state model does not provide. The offending names are listed in sorted
    ///   order so the message is stable between runs.
    /// * Returns [`CompassConfigurationError::UserConfigurationError`] when
    ///   [`CostModel::new`] rejects the assembled inputs.
    pub fn build(
        &self,
        query: &serde_json::Value,
        state_model: Arc<StateModel>,
    ) -> Result<CostModel, CompassConfigurationError> {
        // Query weights override the defaults per key. Without query weights the
        // shared default map is reused as-is (reference-count bump only).
        let weights: Arc<HashMap<String, f64>> = query
            .get_config_serde_optional::<HashMap<String, f64>>(&"weights", &"cost_model")?
            .map(|query_weights| {
                let mut merged = self.weights.as_ref().clone();
                merged.extend(query_weights);
                Arc::new(merged)
            })
            .unwrap_or_else(|| Arc::clone(&self.weights));

        // Validate weight names only when the service is configured to reject
        // unknown ones; otherwise the state model does not need to be inspected.
        if !self.ignore_unknown_weights {
            let state_features = state_model.to_vec();
            let known_names: HashSet<&String> =
                state_features.iter().map(|(name, _)| name).collect();
            let mut unknown: Vec<&str> = weights
                .keys()
                .filter(|name| !known_names.contains(name))
                .map(String::as_str)
                .collect();
            if !unknown.is_empty() {
                // HashMap iteration order is arbitrary; sort for a reproducible message.
                unknown.sort_unstable();
                let extras = unknown.join(",");
                let msg = format!("unknown weights in query: [{extras}]");
                return Err(CompassConfigurationError::UserConfigurationError(msg));
            }
        }

        // Same override-per-key policy as for the weights.
        let vehicle_rates: Arc<HashMap<String, VehicleCostRate>> = query
            .get_config_serde_optional::<HashMap<String, VehicleCostRate>>(
                &"vehicle_rates",
                &"cost_model",
            )?
            .map(|query_rates| {
                let mut merged = self.vehicle_rates.as_ref().clone();
                merged.extend(query_rates);
                Arc::new(merged)
            })
            .unwrap_or_else(|| Arc::clone(&self.vehicle_rates));

        let cost_aggregation: CostAggregation = query
            .get_config_serde_optional(&"cost_aggregation", &"cost_model")?
            .unwrap_or_else(|| self.cost_aggregation.to_owned());

        CostModel::new(
            weights,
            vehicle_rates,
            Arc::clone(&self.network_rates),
            cost_aggregation,
            state_model,
        )
        .map_err(|e| {
            CompassConfigurationError::UserConfigurationError(format!(
                "failed to build cost model: {e}"
            ))
        })
    }
}
impl TryFrom<&CostModelConfig> for CostModelService {
fn try_from(value: &CostModelConfig) -> Result<Self, Self::Error> {
let network_rates = value.get_network_rates()?;
let service = CostModelService {
vehicle_rates: Arc::new(value.vehicle_rates.clone().unwrap_or_default()),
network_rates: Arc::new(network_rates),
weights: Arc::new(value.weights.clone().unwrap_or_default()),
cost_aggregation: value.cost_aggregation.unwrap_or_default(),
ignore_unknown_weights: value.ignore_unknown_user_provided_weights.unwrap_or(true),
};
Ok(service)
}
type Error = CostModelError;
}