use crate::{HashSet, PortfolioOutcome, ProductOutcome};
use fts_core::models::{DemandCurve, DemandGroup, Map, ProductGroup};
use std::hash::Hash;
// Compiled only when the `clarabel` cargo feature is enabled.
// NOTE(review): presumably the Clarabel-backed solver implementation — confirm against the module.
#[cfg(feature = "clarabel")]
pub mod clarabel;
// Compiled only when the `osqp` cargo feature is enabled.
// NOTE(review): presumably the OSQP-backed solver implementation — confirm against the module.
#[cfg(feature = "osqp")]
pub mod osqp;
/// Canonicalizes the raw inputs and allocates the outcome maps.
///
/// For every portfolio, in map order:
/// * demand-group entries are dropped when their weight equals `0.0` or when
///   their id is not a key of `demand_curves`; the survivors are reordered to
///   follow the position of their id inside `demand_curves`;
/// * product-group entries with a weight equal to `0.0` are dropped and the
///   remainder is sorted by product id;
/// * a default [`PortfolioOutcome`] is registered for the portfolio (even if
///   both of its groups ended up empty), and a default [`ProductOutcome`] is
///   registered for each product id that is still referenced.
///
/// Afterwards the product outcomes are sorted by product id, and every demand
/// curve that no pruned demand group refers to is removed.
///
/// Returns `(demand_curves, portfolios, portfolio_outcomes, product_outcomes)`,
/// where the first two are the pruned versions of the arguments.
pub(crate) fn prepare<
    DemandId: Clone + Eq + Hash,
    PortfolioId: Clone + Eq + Hash,
    ProductId: Clone + Eq + Hash + Ord,
>(
    mut demand_curves: Map<DemandId, DemandCurve>,
    mut portfolios: Map<PortfolioId, (DemandGroup<DemandId>, ProductGroup<ProductId>)>,
) -> (
    Map<DemandId, DemandCurve>,
    Map<PortfolioId, (DemandGroup<DemandId>, ProductGroup<ProductId>)>,
    Map<PortfolioId, PortfolioOutcome>,
    Map<ProductId, ProductOutcome>,
) {
    // One outcome slot per portfolio is known up front, so size that map once.
    let mut portfolio_outcomes = Map::<PortfolioId, PortfolioOutcome>::default();
    portfolio_outcomes.reserve_exact(portfolios.len());
    let mut product_outcomes = Map::<ProductId, ProductOutcome>::default();

    // Borrowed ids of every demand curve that survives pruning in some portfolio.
    let mut referenced_demands = HashSet::default();

    for (id, (demands, products)) in portfolios.iter_mut() {
        portfolio_outcomes.entry(id.clone()).or_default();

        // Keep only non-zero weights that point at a known demand curve.
        demands.retain(|demand_id, weight| {
            *weight != 0.0 && demand_curves.contains_key(demand_id)
        });
        // Mirror the ordering of `demand_curves` inside the group.
        demands.sort_unstable_by(|lhs, _, rhs, _| {
            let lhs_rank = demand_curves.get_index_of(lhs);
            let rhs_rank = demand_curves.get_index_of(rhs);
            lhs_rank.cmp(&rhs_rank)
        });
        referenced_demands.extend(demands.keys());

        // Zero-weight products contribute nothing; order the rest by id.
        products.retain(|_, weight| *weight != 0.0);
        products.sort_unstable_keys();
        for product_id in products.keys().cloned() {
            product_outcomes.entry(product_id).or_default();
        }
    }

    product_outcomes.sort_unstable_keys();

    // Discard demand curves that no portfolio refers to any more.
    demand_curves.retain(|demand_id, _| referenced_demands.contains(demand_id));

    (
        demand_curves,
        portfolios,
        portfolio_outcomes,
        product_outcomes,
    )
}
/// Writes a solver's raw primal/dual values into the outcome maps.
///
/// * `dual` is paired positionally with the entries of `product_outcomes`
///   (in map order); each paired value becomes that product's `price`.
///   Pairing stops as soon as either side runs out.
/// * `primal` supplies one value per portfolio whose demand group and product
///   group are both non-empty, consumed in the iteration order of
///   `portfolios`. Every other portfolio is given a rate of `0.0` and does
///   not consume a value.
/// * A portfolio with at least one product has its `price` overwritten with
///   the sum of `product price * weight` over its product group; a portfolio
///   without products keeps whatever `price` it already had.
/// * Each product's `rate` is increased by `|weight * rate|` for every
///   portfolio that holds it, on top of the value it had on entry, and is
///   finally multiplied by `0.5`.
///   NOTE(review): the halving presumably compensates for counting both sides
///   of every trade — confirm.
///
/// # Panics
///
/// Panics if `primal` yields fewer values than required, if a portfolio id is
/// missing from `portfolio_outcomes`, or if a product id is missing from
/// `product_outcomes`.
pub(crate) fn finalize<
    'a,
    'b,
    DemandId: Clone + Eq + Hash,
    PortfolioId: Clone + Eq + Hash,
    ProductId: Clone + Eq + Hash + Ord,
>(
    mut primal: impl Iterator<Item = &'a f64>,
    dual: impl Iterator<Item = &'b f64>,
    portfolios: &Map<PortfolioId, (DemandGroup<DemandId>, ProductGroup<ProductId>)>,
    portfolio_outcomes: &mut Map<PortfolioId, PortfolioOutcome>,
    product_outcomes: &mut Map<ProductId, ProductOutcome>,
) {
    // Dual values line up with the product outcomes position by position.
    product_outcomes
        .values_mut()
        .zip(dual)
        .for_each(|(outcome, &price)| outcome.price = price);

    for (id, (demands, products)) in portfolios.iter() {
        // Only portfolios with something on both sides own a primal value.
        let has_variable = demands.len() > 0 && products.len() > 0;
        let rate = if has_variable {
            *primal.next().unwrap()
        } else {
            0.0
        };

        let outcome = portfolio_outcomes.get_mut(id).unwrap();
        outcome.rate = rate;

        // Without products there is nothing to price; leave `price` untouched.
        if products.len() == 0 {
            continue;
        }

        outcome.price = 0.0;
        for (product_id, &weight) in products.iter() {
            let product = product_outcomes.get_mut(product_id).unwrap();
            outcome.price += product.price * weight;
            product.rate += (weight * rate).abs();
        }
    }

    product_outcomes
        .values_mut()
        .for_each(|outcome| outcome.rate *= 0.5);
}