use std::collections::HashSet;
use log::{self, log_enabled, debug};
use num_traits::ToPrimitive;
use crate::brokers::BrokerInfo;
use crate::commissions::CommissionCalc;
use crate::core::{GenericResult, EmptyResult};
use crate::currency::Cash;
use crate::currency::converter::CurrencyConverterRc;
use crate::time;
use crate::types::{Decimal, TradeType};
use crate::util;
use super::asset_allocation::{Portfolio, AssetAllocation, Holding, StockHolding};
/// Calculates target values for all assets of `portfolio` and updates its target cash assets and
/// commissions accordingly.
///
/// Asset target values are calculated first without taking commissions into account. Then the
/// commissions of the planned trades are reserved, and the remaining free cash is distributed
/// between the assets. Trade commissions are tracked incrementally during distribution, while
/// the change in additional commissions is applied at the end.
pub fn rebalance_portfolio(portfolio: &mut Portfolio, converter: CurrencyConverterRc) -> EmptyResult {
    let portfolio_info = PortfolioInfo::new(portfolio);
    calculate_restrictions(&mut portfolio.assets);

    debug!("");
    debug!("Calculating assets target value...");
    let assets_target_value = portfolio.target_net_value - portfolio.min_cash_assets;
    AssetGroupRebalancer::rebalance(
        &portfolio.name, &mut portfolio.assets, assets_target_value, portfolio.min_trade_volume);

    let result_value = calculate_result_value(&portfolio_info, converter.clone(), &mut portfolio.assets)?;
    portfolio.target_cash_assets = portfolio.target_net_value - result_value;

    // Reserve commissions for the trades planned so far.
    let (trade_commissions_before, additional_commissions_before) =
        calculate_total_commissions(portfolio, converter.clone())?;
    let commissions_before = trade_commissions_before + additional_commissions_before;
    assert!(portfolio.commissions.is_zero());
    portfolio.change_commission(commissions_before);

    // Every trade made here reports its trade commission delta to the portfolio, so afterwards
    // the commission change must match the change in total trade commissions.
    distribute_cash_assets(portfolio, &portfolio_info, converter.clone())?;

    let (trade_commissions_after, additional_commissions_after) =
        calculate_total_commissions(portfolio, converter)?;
    assert_eq!(
        portfolio.commissions - commissions_before,
        trade_commissions_after - trade_commissions_before,
    );
    portfolio.change_commission(additional_commissions_after - additional_commissions_before);

    Ok(())
}
/// Calculates `min_value` / `max_value` restrictions for every asset in `assets` (recursing into
/// groups) and stores them in the assets.
///
/// A stock's minimum is its current value when selling is restricted and zero otherwise. Its
/// maximum is its current value when buying is restricted and unbounded (`None`) otherwise. A
/// group's restrictions are the totals of its holdings' restrictions. Assets with zero expected
/// weight are then capped by `propagate_zero_weight()`.
///
/// Returns the total minimum value of `assets` and their total maximum value. The maximum is
/// `None` if at least one asset has no upper bound.
fn calculate_restrictions(assets: &mut [AssetAllocation]) -> (Decimal, Option<Decimal>) {
    let mut total_min_value = dec!(0);
    let mut total_max_value = dec!(0);
    let mut all_with_max_value = true;

    for asset in assets.iter_mut() {
        let (min_value, max_value) = match &mut asset.holding {
            Holding::Group(holdings) => calculate_restrictions(holdings),
            Holding::Stock(_) => {
                let min_value = if asset.restrict_selling.unwrap_or(false) {
                    asset.current_value
                } else {
                    dec!(0)
                };

                let max_value = if asset.restrict_buying.unwrap_or(false) {
                    Some(asset.current_value)
                } else {
                    None
                };

                (min_value, max_value)
            },
        };

        asset.min_value = min_value;
        asset.max_value = max_value;

        if asset.expected_weight.is_zero() {
            propagate_zero_weight(asset)
        }

        total_min_value += asset.min_value;

        if let Some(max_value) = asset.max_value {
            assert!(max_value >= asset.min_value);
            total_max_value += max_value;
        } else {
            all_with_max_value = false;
        }
    }

    let total_max_value = if all_with_max_value {
        Some(total_max_value)
    } else {
        None
    };

    (total_min_value, total_max_value)
}
/// Caps an asset with zero expected weight at its minimum value by setting `max_value` to
/// `min_value`.
///
/// If the minimum is zero, the same cap is applied recursively to all holdings of a group, each
/// of which must then also have a zero minimum. If the minimum is non-zero, any existing maximum
/// must already equal it.
fn propagate_zero_weight(asset: &mut AssetAllocation) {
    if !asset.min_value.is_zero() {
        if let Some(max_value) = asset.max_value {
            assert_eq!(max_value, asset.min_value);
        }
    } else if let Holding::Group(ref mut holdings) = asset.holding {
        for holding in holdings.iter_mut() {
            assert!(holding.min_value.is_zero());
            propagate_zero_weight(holding);
        }
    }

    asset.max_value = Some(asset.min_value);
}
/// Working state for rebalancing a single asset group (see `AssetGroupRebalancer::rebalance()`).
struct AssetGroupRebalancer<'a> {
    /// Group name, used in debug log output.
    name: &'a str,
    /// Assets of the group; their `target_value`s are updated in place.
    assets: &'a mut Vec<AssetAllocation>,
    /// Total value that should be distributed between `assets`.
    target_total_value: Decimal,
    /// Minimum trade volume; smaller trades are rounded to zero when calculating the initial
    /// target values.
    min_trade_volume: Decimal,
    /// Value not allocated to any asset yet (negative when the assets are over-allocated).
    balance: Decimal,
}
impl<'a> AssetGroupRebalancer<'a> {
fn rebalance(
name: &str, assets: &mut Vec<AssetAllocation>, target_total_value: Decimal,
min_trade_volume: Decimal
) -> Decimal {
let mut rebalancer = AssetGroupRebalancer {
name, assets, target_total_value, min_trade_volume,
balance: dec!(0),
};
debug!("{name}:", name=name);
rebalancer.calculate_initial_target_values();
rebalancer.apply_restrictions();
rebalancer.correct_balance();
rebalancer.propagate_changes();
rebalancer.balance
}
fn calculate_initial_target_values(&mut self) {
debug!("* Initial target values:");
for asset in self.assets.iter_mut() {
asset.target_value = self.target_total_value * asset.expected_weight;
debug!(" * {name}: {current_value} -> {target_value}",
name=asset.full_name(), current_value=asset.current_value.normalize(),
target_value=asset.target_value.normalize());
}
let state = self.get_current_state();
for asset in self.assets.iter_mut() {
let mut difference = asset.target_value - asset.current_value;
if let Holding::Stock(ref holding) = asset.holding {
let trade_granularity = holding.trade_granularity();
difference = util::round(difference / trade_granularity, 0) * trade_granularity;
if difference.is_sign_negative() && -difference > asset.current_value {
difference += trade_granularity
}
}
if difference.abs() < self.min_trade_volume {
difference = dec!(0);
}
let target_value = asset.current_value + difference;
self.balance += asset.target_value - target_value;
asset.target_value = target_value;
}
self.log_state_changes("Rounding", state);
}
fn apply_restrictions(&mut self) {
let state = self.get_current_state();
let mut logged = false;
let mut log_restriction_applying = |name: &str, action: &str, value: Decimal| {
if !logged {
debug!("* Applying restrictions:");
logged = true;
}
debug!(" * {name}: {action} is blocked at {value}",
name=name, action=action, value=value.normalize());
};
for asset in self.assets.iter_mut() {
if let Some(max_value) = asset.max_value {
if asset.target_value > max_value {
if asset.restrict_buying.unwrap_or(false) && asset.target_value > asset.current_value {
log_restriction_applying(&asset.full_name(), "buying", max_value);
asset.buy_blocked = true;
}
self.balance += asset.target_value - max_value;
asset.target_value = max_value;
}
}
let min_value = asset.min_value;
if asset.target_value < min_value {
log_restriction_applying(&asset.full_name(), "selling", min_value);
asset.sell_blocked = true;
self.balance += asset.target_value - min_value;
asset.target_value = min_value;
}
}
self.log_state_changes("Restrictions applying", state);
}
fn correct_balance(&mut self) {
let state = self.get_current_state();
for trade_type in [TradeType::Sell, TradeType::Buy].iter().cloned() {
let mut correctable_assets: HashSet<usize> = (0..self.assets.len()).collect();
while match trade_type {
TradeType::Sell => self.balance.is_sign_negative(),
TradeType::Buy => self.balance.is_sign_positive(),
} {
let mut best_trade: Option<PossibleTrade> = None;
for index in correctable_assets.iter().cloned().collect::<Vec<_>>() {
let asset = &mut self.assets[index];
let expected_value = self.target_total_value * asset.expected_weight;
let possible_trade = calculate_min_trade_volume(
trade_type, asset, expected_value, self.balance, self.min_trade_volume);
match possible_trade {
Some(mut trade) => {
trade.path.push(index);
best_trade = Some(get_best_trade(trade_type, best_trade, trade));
},
None => {
correctable_assets.remove(&index);
},
};
}
let trade = match best_trade {
Some(trade) => trade,
None => break,
};
assert_eq!(trade.path.len(), 1);
let asset = &mut self.assets[*trade.path.last().unwrap()];
asset.target_value += trade.volume;
self.balance -= trade.volume;
}
}
self.log_state_changes("Balance correction", state);
}
fn propagate_changes(&mut self) {
let state = self.get_current_state();
let mut propagated = false;
for asset in self.assets.iter_mut() {
let asset_name = asset.full_name();
if let Holding::Group(ref mut holdings) = asset.holding {
let balance = AssetGroupRebalancer::rebalance(
&asset_name, holdings, asset.target_value, self.min_trade_volume);
asset.target_value -= balance;
self.balance += balance;
propagated = true;
}
}
if propagated {
debug!("{name}:", name=self.name);
self.log_state_changes("Target value change propagation", state);
}
}
fn get_current_state(&self) -> Option<AssetGroupRebalancingState> {
if !log_enabled!(log::Level::Debug) {
return None;
}
let mut state = AssetGroupRebalancingState {
target_values: Vec::new(),
balance: self.balance,
};
for asset in self.assets.iter() {
state.target_values.push(asset.target_value);
}
Some(state)
}
fn log_state_changes(&self, changes_summary: &str, prev_state: Option<AssetGroupRebalancingState>) {
let prev_state = match prev_state {
Some(state) => state,
None => return,
};
let changed = self.balance != prev_state.balance ||
self.assets.iter().enumerate().any(|item| {
let (index, asset) = item;
asset.target_value != prev_state.target_values[index]
});
if !changed {
return;
}
debug!("* {changes_summary} ({prev_balance} -> {balance}):",
changes_summary=changes_summary, prev_balance=prev_state.balance.normalize(),
balance=self.balance.normalize());
for (index, asset) in self.assets.iter().enumerate() {
let prev_target_value = prev_state.target_values[index];
if prev_target_value != asset.target_value {
debug!(" * {name}: {prev_target_value} -> {target_value}",
name=asset.full_name(), prev_target_value=prev_target_value.normalize(),
target_value=asset.target_value.normalize())
}
}
}
}
/// Snapshot of the portfolio properties needed for commission calculation.
///
/// It is copied out of `Portfolio` (see `PortfolioInfo::new()`) so that it can be used while
/// the portfolio's assets are borrowed mutably.
struct PortfolioInfo {
    broker: BrokerInfo,
    currency: String,
    /// The portfolio's current (not target) net value.
    net_value: Decimal,
}
impl PortfolioInfo {
fn new(portfolio: &Portfolio) -> PortfolioInfo {
PortfolioInfo {
broker: portfolio.broker.clone(),
currency: portfolio.currency.clone(),
net_value: portfolio.current_net_value,
}
}
}
/// Snapshot of an asset group's rebalancing state, used to log what a rebalancing step changed.
struct AssetGroupRebalancingState {
    /// Target values of the group's assets, in the same order as the assets.
    target_values: Vec<Decimal>,
    balance: Decimal,
}
/// Applies the calculated target values to the stock holdings of `assets` (recursing into
/// groups) by converting them into target shares via `change_to()`.
///
/// Returns the total target value of `assets`. No target shares may have been changed before
/// this call. The commission deltas returned by `change_to()` are ignored here.
fn calculate_result_value(
    portfolio: &PortfolioInfo, converter: CurrencyConverterRc,
    assets: &mut [AssetAllocation],
) -> GenericResult<Decimal> {
    let mut result = dec!(0);

    for asset in assets {
        let name = asset.full_name();

        let value = match &mut asset.holding {
            Holding::Group(holdings) => calculate_result_value(portfolio, converter.clone(), holdings)?,
            Holding::Stock(holding) => {
                assert_eq!(holding.target_shares, holding.current_shares);
                change_to(portfolio, converter.clone(), &name, holding, asset.target_value)?;
                asset.target_value
            },
        };

        result += value;
    }

    Ok(result)
}
/// Distributes the portfolio's free cash (target cash assets above `min_cash_assets`) between
/// the assets.
///
/// It sells while free cash is negative and then buys while free cash is positive, one minimal
/// trade at a time (see `find_assets_for_cash_distribution()`). After each trade, the trade
/// commission delta is passed to `Portfolio::change_commission()`.
fn distribute_cash_assets(
    portfolio: &mut Portfolio, portfolio_info: &PortfolioInfo, converter: CurrencyConverterRc,
) -> EmptyResult {
    debug!("");
    debug!("Cash assets distribution:");

    for trade_type in [TradeType::Sell, TradeType::Buy].iter().cloned() {
        loop {
            let free_cash_assets = portfolio.target_cash_assets - portfolio.min_cash_assets;

            // Zero free cash means nothing to distribute. Check it explicitly, because
            // is_sign_positive() is also true for zero and a zero value may carry a negative
            // sign bit.
            let needs_trade = !free_cash_assets.is_zero() && match trade_type {
                TradeType::Sell => free_cash_assets.is_sign_negative(),
                TradeType::Buy => free_cash_assets.is_sign_positive(),
            };
            if !needs_trade {
                break;
            }

            let expected_net_value = portfolio.target_net_value - portfolio.min_cash_assets;
            let trade = find_assets_for_cash_distribution(
                trade_type, &portfolio.assets, expected_net_value, free_cash_assets,
                portfolio.min_trade_volume);

            let trade = match trade {
                Some(trade) => trade,
                None => break,
            };

            portfolio.target_cash_assets -= trade.volume;
            let commission = process_trade(
                portfolio_info, converter.clone(), &mut portfolio.assets, trade)?;
            portfolio.change_commission(commission);
        }
    }

    Ok(())
}
/// A candidate trade found during balance correction or cash distribution.
struct PossibleTrade {
    /// Asset indices from the traded asset up to the top level: the first element is the index
    /// within the innermost group, and the last one is the index in the outermost asset list.
    path: Vec<usize>,
    /// Trade volume, in the same units as asset values: positive for buys, negative for sells.
    volume: Decimal,
    /// Ratio of the resulting value to the expected value (see `calculate_trade_result()`),
    /// used by `get_best_trade()` to rank candidates.
    result: f64,
}
/// Finds the best minimal trade of `trade_type` among `assets`, recursing into groups.
///
/// For a group, the candidate from inside the group is re-ranked by how the trade affects the
/// group as a whole. Returns `None` if no asset can be traded.
fn find_assets_for_cash_distribution(
    trade_type: TradeType, assets: &[AssetAllocation], expected_total_value: Decimal,
    cash_assets: Decimal, min_trade_volume: Decimal
) -> Option<PossibleTrade> {
    let mut best_trade: Option<PossibleTrade> = None;

    for (index, asset) in assets.iter().enumerate() {
        let expected_value = expected_total_value * asset.expected_weight;

        let mut trade = match &asset.holding {
            Holding::Stock(_) => match calculate_min_trade_volume(
                trade_type, asset, expected_value, cash_assets, min_trade_volume
            ) {
                Some(trade) => trade,
                None => continue,
            },
            Holding::Group(holdings) => {
                let mut group_trade = match find_assets_for_cash_distribution(
                    trade_type, holdings, expected_value, cash_assets, min_trade_volume
                ) {
                    Some(trade) => trade,
                    None => continue,
                };
                group_trade.result = calculate_trade_result(
                    expected_value, asset.target_value, group_trade.volume);
                group_trade
            },
        };

        trade.path.push(index);
        best_trade = Some(get_best_trade(trade_type, best_trade, trade));
    }

    best_trade
}
/// Applies `trade` to the asset tree: it follows `trade.path` down to the stock and updates the
/// target value of every asset on the way.
///
/// Returns the commission delta reported by `change_to()` for the traded stock.
fn process_trade(
    portfolio: &PortfolioInfo, converter: CurrencyConverterRc,
    assets: &mut [AssetAllocation], mut trade: PossibleTrade,
) -> GenericResult<Decimal> {
    // The path is stored innermost-first, so the last element is the index at this level.
    let index = trade.path.pop().unwrap();
    let asset = &mut assets[index];

    let name = asset.full_name();
    let prev_target_value = asset.target_value;
    let new_target_value = prev_target_value + trade.volume;

    let commission = match &mut asset.holding {
        Holding::Group(holdings) => process_trade(portfolio, converter, holdings, trade)?,
        Holding::Stock(holding) => {
            assert!(trade.path.is_empty());
            debug!("* {name}: {prev_target_value} -> {target_value}",
                name=name, prev_target_value=prev_target_value.normalize(),
                target_value=new_target_value.normalize());
            change_to(portfolio, converter, &name, holding, new_target_value)?
        },
    };

    asset.target_value = new_target_value;
    Ok(commission)
}
/// Builds the minimal possible trade of `trade_type` for `asset`.
///
/// The volume is negative for sells. Buys are limited by `cash_assets`. Returns `None` if no
/// such trade is possible.
fn calculate_min_trade_volume(
    trade_type: TradeType, asset: &AssetAllocation, expected_value: Decimal,
    cash_assets: Decimal, min_trade_volume: Decimal
) -> Option<PossibleTrade> {
    let volume = match trade_type {
        TradeType::Sell => -calculate_min_sell_volume(asset, min_trade_volume)?,
        TradeType::Buy => {
            let buy_volume = calculate_min_buy_volume(asset, min_trade_volume)?;
            if buy_volume > cash_assets {
                return None;
            }
            buy_volume
        },
    };

    let result = calculate_trade_result(expected_value, asset.target_value, volume);
    Some(PossibleTrade { path: Vec::new(), volume, result })
}
/// Returns the ratio of the value resulting from the trade to the expected value.
///
/// When the expected value is zero, the ratio is 1.0 if the resulting value is also zero and
/// `f64::MAX` otherwise.
fn calculate_trade_result(expected_value: Decimal, target_value: Decimal, trade_volume: Decimal) -> f64 {
    let result_value = target_value + trade_volume;

    if !expected_value.is_zero() {
        return result_value.to_f64().unwrap() / expected_value.to_f64().unwrap();
    }

    if result_value.is_zero() { 1.0 } else { f64::MAX }
}
/// Chooses the better of `best_trade` and `trade`.
///
/// For sells, the trade with the higher result wins; for buys, the lower one. On equal results
/// the trade with the smaller (signed) volume wins, and `best_trade` is kept on a full tie.
fn get_best_trade(trade_type: TradeType, best_trade: Option<PossibleTrade>, trade: PossibleTrade) -> PossibleTrade {
    let best_trade = match best_trade {
        Some(best_trade) => best_trade,
        None => return trade,
    };

    // Exact comparison is intended: equal results come from identical calculations.
    #[allow(clippy::float_cmp)]
    let keep_best = if best_trade.result == trade.result {
        best_trade.volume <= trade.volume
    } else {
        match trade_type {
            TradeType::Sell => best_trade.result > trade.result,
            TradeType::Buy => best_trade.result < trade.result,
        }
    };

    if keep_best { best_trade } else { trade }
}
/// Sets the holding's target shares to match `target_value` at the holding's price.
///
/// Returns the change in trade commission: the commission for the new target minus the
/// commission for the previous target, each calculated with a fresh `CommissionCalc`.
fn change_to(
    portfolio: &PortfolioInfo, converter: CurrencyConverterRc,
    name: &str, holding: &mut StockHolding, target_value: Decimal,
) -> GenericResult<Decimal> {
    let commission_for = |shares| -> GenericResult<Decimal> {
        let mut commission_calc = CommissionCalc::new(
            converter.clone(), portfolio.broker.commission_spec.clone(),
            Cash::new(&portfolio.currency, portfolio.net_value))?;
        calculate_target_commission(
            name, holding, shares, &mut commission_calc,
            &portfolio.currency, converter.clone())
    };

    let new_target_shares = (target_value / holding.price).normalize();
    assert!(util::decimal_precision(new_target_shares) <= util::decimal_precision(holding.current_shares));

    let previous_commission = commission_for(holding.target_shares)?;
    let new_commission = commission_for(new_target_shares)?;

    holding.target_shares = new_target_shares;
    Ok(new_commission - previous_commission)
}
/// Registers the trade that turns the holding's current shares into `target_shares` in
/// `commission_calc`.
///
/// Returns its commission converted to `currency`, or zero when no trade is needed.
fn calculate_target_commission(
    name: &str, holding: &StockHolding, target_shares: Decimal, commission_calc: &mut CommissionCalc,
    currency: &str, converter: CurrencyConverterRc,
) -> GenericResult<Decimal> {
    let current_shares = holding.current_shares;
    if target_shares == current_shares {
        return Ok(dec!(0));
    }

    let (trade_type, shares) = if target_shares < current_shares {
        (TradeType::Sell, current_shares - target_shares)
    } else {
        (TradeType::Buy, target_shares - current_shares)
    };

    let conclusion_date = time::today_trade_conclusion_time().date;
    let commission = commission_calc
        .add_trade(conclusion_date, trade_type, shares, holding.currency_price)
        .map_err(|e| format!("{}: {}", name, e))?;

    converter.convert_to(conclusion_date, commission, currency)
}
/// Calculates the commissions of all trades currently planned in `portfolio` using a fresh
/// `CommissionCalc`.
///
/// Returns `(trade commissions, additional commissions)`, where the additional commissions are
/// those reported by `CommissionCalc::calculate()`. Both values are converted to the portfolio
/// currency.
fn calculate_total_commissions(portfolio: &Portfolio, converter: CurrencyConverterRc) -> GenericResult<(Decimal, Decimal)> {
    let currency = &portfolio.currency;

    let mut commission_calc = CommissionCalc::new(
        converter.clone(), portfolio.broker.commission_spec.clone(),
        Cash::new(currency, portfolio.current_net_value))?;

    // All planned trades are registered before calculate() is called below.
    let trade_commissions = calculate_trade_commissions(
        &portfolio.assets, &mut commission_calc, currency, converter.clone())?;

    let conclusion_date = time::today_trade_conclusion_time().date;
    let mut additional_commissions = dec!(0);

    for commissions in commission_calc.calculate()?.values() {
        for commission in commissions.iter() {
            additional_commissions += converter.convert_to(conclusion_date, commission, currency)?;
        }
    }

    Ok((trade_commissions, additional_commissions))
}
/// Registers the planned trade of every stock in `assets` (recursing into groups) in `calc`.
///
/// Returns the sum of their trade commissions in `currency`.
fn calculate_trade_commissions(
    assets: &[AssetAllocation], calc: &mut CommissionCalc,
    currency: &str, converter: CurrencyConverterRc,
) -> GenericResult<Decimal> {
    assets.iter().try_fold(dec!(0), |total, asset| -> GenericResult<Decimal> {
        let commission = match &asset.holding {
            Holding::Stock(holding) => calculate_target_commission(
                &asset.full_name(), holding, holding.target_shares, calc,
                currency, converter.clone(),
            )?,
            Holding::Group(holdings) => calculate_trade_commissions(
                holdings, calc, currency, converter.clone())?,
        };
        Ok(total + commission)
    })
}
/// Calculates the smallest volume by which `asset`'s target value can be reduced.
///
/// The cases are:
/// - a planned buy is shrunk by one granularity step, or cancelled if it would fall below
///   `min_trade_volume`;
/// - a sell of at least `min_trade_volume` grows by one step;
/// - otherwise the sell grows to `min_trade_volume`, rounded up to the granularity.
///
/// Returns `None` if the resulting target value would fall below the asset's `min_value`.
fn calculate_min_sell_volume(asset: &AssetAllocation, min_trade_volume: Decimal) -> Option<Decimal> {
    let granularity = asset.iterative_trading_granularity(TradeType::Sell);
    let target = asset.target_value;
    let current = asset.current_value;

    let volume = if target > current {
        if target - granularity >= current + min_trade_volume {
            granularity
        } else {
            target - current
        }
    } else if target <= current - min_trade_volume {
        granularity
    } else {
        round_min_trade_volume(min_trade_volume - (current - target), granularity)
    };

    if target - volume < asset.min_value {
        None
    } else {
        Some(volume)
    }
}
/// Calculates the smallest volume by which `asset`'s target value can be increased.
///
/// The cases are:
/// - a planned sell is shrunk by one granularity step, or cancelled if it would fall below
///   `min_trade_volume`;
/// - a buy of at least `min_trade_volume` grows by one step;
/// - otherwise the buy grows to `min_trade_volume`, rounded up to the granularity.
///
/// Returns `None` if the resulting target value would exceed the asset's `max_value`.
fn calculate_min_buy_volume(asset: &AssetAllocation, min_trade_volume: Decimal) -> Option<Decimal> {
    let granularity = asset.iterative_trading_granularity(TradeType::Buy);
    let target = asset.target_value;
    let current = asset.current_value;

    let volume = if target < current {
        if target + granularity <= current - min_trade_volume {
            granularity
        } else {
            current - target
        }
    } else if target >= current + min_trade_volume {
        granularity
    } else {
        round_min_trade_volume(min_trade_volume - (target - current), granularity)
    };

    match asset.max_value {
        Some(max_value) if target + volume > max_value => None,
        _ => Some(volume),
    }
}
/// Rounds `volume` up to a whole number of `granularity` steps.
fn round_min_trade_volume(volume: Decimal, granularity: Decimal) -> Decimal {
    let steps = (volume / granularity).ceil();
    steps * granularity
}