use std::collections::{BTreeMap, HashMap};
use nalgebra::{DMatrix, DVector};
use crate::{PortfolioError, Result};
/// Copies the last row of a price matrix out as a column vector.
///
/// The returned vector has one entry per column of `prices`.
///
/// # Errors
///
/// Returns [`PortfolioError::InvalidArgument`] when `prices` has no rows.
pub fn get_latest_prices(prices: &DMatrix<f64>) -> Result<DVector<f64>> {
    // `checked_sub` yields `None` exactly when there is no last row to read.
    match prices.nrows().checked_sub(1) {
        Some(last_row) => Ok(prices.row(last_row).transpose().into_owned()),
        None => Err(PortfolioError::InvalidArgument(
            "price matrix has no rows".into(),
        )),
    }
}
/// Pairs every ticker with the last-row price of its column.
///
/// `tickers[j]` labels column `j` of `prices`. The result is a `BTreeMap`, so it iterates in
/// ticker order regardless of the column order. If the same ticker is listed more than once,
/// the price from the later column replaces the earlier one.
///
/// # Errors
///
/// * [`PortfolioError::DimensionMismatch`] when the number of tickers differs from the number
///   of columns. This is checked first.
/// * Any error returned by [`get_latest_prices`] (a matrix without rows).
pub fn get_latest_prices_labeled(
    prices: &DMatrix<f64>,
    tickers: &[String],
) -> Result<BTreeMap<String, f64>> {
    let column_count = prices.ncols();
    let ticker_count = tickers.len();
    if column_count != ticker_count {
        let message = format!(
            "prices has {} columns but {} tickers were supplied",
            column_count, ticker_count
        );
        return Err(PortfolioError::DimensionMismatch(message));
    }

    let latest = get_latest_prices(prices)?;
    let mut labeled = BTreeMap::new();
    for (ticker, price) in tickers.iter().zip(latest.iter()) {
        labeled.insert(ticker.clone(), *price);
    }
    Ok(labeled)
}
/// Inputs for turning continuous portfolio weights into whole-share positions.
///
/// Build it with [`DiscreteAllocation::new`] (index-based) or
/// [`DiscreteAllocation::new_labeled`] (ticker-based). All fields are public, so values changed
/// after construction are not re-validated.
#[derive(Debug, Clone, PartialEq)]
pub struct DiscreteAllocation {
    /// Target weight per asset. Strictly negative entries are treated as short positions by the
    /// greedy allocator; zero entries receive no shares there.
    pub weights: DVector<f64>,
    /// Price per share for each asset, aligned index-for-index with `weights`.
    pub prices: DVector<f64>,
    /// Cash amount to allocate.
    pub total_portfolio_value: f64,
    /// Optional short budget as a fraction of `total_portfolio_value`. When `None`, the greedy
    /// allocator uses the sum of the absolute values of the negative weights instead.
    pub short_ratio: Option<f64>,
    /// Ticker label per asset index. Populated by `new_labeled`, left `None` by `new`, and
    /// required by the `*_labeled` methods.
    pub tickers: Option<Vec<String>>,
}
impl DiscreteAllocation {
/// Creates an index-based allocation problem.
///
/// `weights[i]` and `prices[i]` describe the same asset. `short_ratio` and `tickers` start
/// out as `None`.
///
/// # Errors
///
/// * [`PortfolioError::DimensionMismatch`] when `weights` and `prices` differ in length.
/// * [`PortfolioError::InvalidArgument`] when `total_portfolio_value` or any price is
///   non-positive, or when `total_portfolio_value`, any price, or any weight is NaN or
///   infinite.
pub fn new(
    weights: DVector<f64>,
    prices: DVector<f64>,
    total_portfolio_value: f64,
) -> Result<Self> {
    let n = weights.len();
    if prices.len() != n {
        return Err(PortfolioError::DimensionMismatch(format!(
            "weights has length {} but prices has length {}",
            n,
            prices.len()
        )));
    }
    // NaN compares false against everything, so a bare `<= 0.0` test would wave it through;
    // finiteness is therefore checked explicitly before the sign.
    if !total_portfolio_value.is_finite() {
        return Err(PortfolioError::InvalidArgument(format!(
            "total_portfolio_value = {total_portfolio_value} is not finite"
        )));
    }
    if total_portfolio_value <= 0.0 {
        return Err(PortfolioError::InvalidArgument(
            "total_portfolio_value must be positive".into(),
        ));
    }
    for (i, &p) in prices.iter().enumerate() {
        if !p.is_finite() {
            return Err(PortfolioError::InvalidArgument(format!(
                "price[{i}] = {p} is not finite"
            )));
        }
        if p <= 0.0 {
            return Err(PortfolioError::InvalidArgument(format!(
                "price[{i}] = {p} is non-positive"
            )));
        }
    }
    // A NaN or infinite weight would turn the ideal share counts into NaN further down.
    for (i, &w) in weights.iter().enumerate() {
        if !w.is_finite() {
            return Err(PortfolioError::InvalidArgument(format!(
                "weight[{i}] = {w} is not finite"
            )));
        }
    }
    Ok(Self {
        weights,
        prices,
        total_portfolio_value,
        short_ratio: None,
        tickers: None,
    })
}
/// Creates a ticker-labeled allocation problem from two keyed maps.
///
/// Only tickers present in both `weights` and `latest_prices` are kept; the rest are dropped
/// without an error. The surviving tickers are stored in `tickers` in the iteration order of
/// the `weights` map, and the weight and price vectors follow that same order.
///
/// # Errors
///
/// * [`PortfolioError::InvalidArgument`] when the two maps share no ticker.
/// * Any error returned by [`DiscreteAllocation::new`] for the aligned vectors.
pub fn new_labeled(
    weights: BTreeMap<String, f64>,
    latest_prices: BTreeMap<String, f64>,
    total_portfolio_value: f64,
) -> Result<Self> {
    // One pass over `weights` gathers (ticker, weight, price) for every shared ticker.
    let shared: Vec<(String, f64, f64)> = weights
        .iter()
        .filter_map(|(ticker, &weight)| {
            latest_prices
                .get(ticker)
                .map(|&price| (ticker.clone(), weight, price))
        })
        .collect();
    if shared.is_empty() {
        return Err(PortfolioError::InvalidArgument(
            "weights and latest_prices share no common tickers".into(),
        ));
    }

    let count = shared.len();
    let weight_vec = DVector::from_iterator(count, shared.iter().map(|entry| entry.1));
    let price_vec = DVector::from_iterator(count, shared.iter().map(|entry| entry.2));
    let mut labeled = Self::new(weight_vec, price_vec, total_portfolio_value)?;
    labeled.tickers = Some(shared.into_iter().map(|entry| entry.0).collect());
    Ok(labeled)
}
/// Sets an explicit short budget ratio and returns the updated value.
///
/// When set, the greedy allocator uses this ratio instead of deriving one from the negative
/// weights.
///
/// # Errors
///
/// Returns [`PortfolioError::InvalidArgument`] when `short_ratio` is negative, NaN, or
/// infinite.
pub fn with_short_ratio(mut self, short_ratio: f64) -> Result<Self> {
    // `NaN < 0.0` is false, so without this check a NaN (or +inf) ratio would be stored and
    // later poison the short budget.
    if !short_ratio.is_finite() {
        return Err(PortfolioError::InvalidArgument(format!(
            "short_ratio = {short_ratio} is not finite"
        )));
    }
    if short_ratio < 0.0 {
        return Err(PortfolioError::InvalidArgument(
            "short_ratio must be non-negative".into(),
        ));
    }
    self.short_ratio = Some(short_ratio);
    Ok(self)
}
/// Returns the short ratio used by the greedy allocator.
///
/// An explicitly configured `short_ratio` wins; otherwise the ratio is the sum of the negated
/// negative weights.
fn effective_short_ratio(&self) -> f64 {
    match self.short_ratio {
        Some(explicit) => explicit,
        None => self
            .weights
            .iter()
            .copied()
            .filter(|w| *w < 0.0)
            .map(|w| -w)
            .sum(),
    }
}
/// Runs the greedy allocation with both options switched off.
///
/// Equivalent to `greedy_portfolio_with_options(false, false)`; see that method for the
/// meaning of the returned `(allocation, leftover)` pair.
pub fn greedy_portfolio(&self) -> Result<(HashMap<usize, i64>, f64)> {
    let (reinvest, verbose) = (false, false);
    self.greedy_portfolio_with_options(reinvest, verbose)
}
/// Greedily converts the weights into whole-share positions, handling longs and shorts
/// separately.
///
/// Positive weights are bought with the long budget and negative weights are sold with the
/// short budget; each side is delegated to `greedy_one_side`. The short budget is
/// `total_portfolio_value * short_ratio` when there is at least one negative weight and the
/// effective ratio is positive, otherwise zero. The long budget is `total_portfolio_value`,
/// plus the short budget when `reinvest` is true.
///
/// `_verbose` is accepted but not used.
///
/// # Returns
///
/// A map from asset index to share count (negative for shorts, zero positions omitted) and
/// the sum of the unspent long and short budgets. This implementation always returns `Ok`.
pub fn greedy_portfolio_with_options(
    &self,
    reinvest: bool,
    _verbose: bool,
) -> Result<(HashMap<usize, i64>, f64)> {
    let has_shorts = self.weights.iter().any(|w| *w < 0.0);
    let ratio = self.effective_short_ratio();

    // Budgets: start from the long-only case and adjust when shorting is active.
    let total = self.total_portfolio_value;
    let mut long_budget = total;
    let mut short_budget = 0.0;
    if has_shorts && ratio > 0.0 {
        short_budget = total * ratio;
        if reinvest {
            long_budget = total + short_budget;
        }
    }

    // Split the asset indices by weight sign; zero weights belong to neither side.
    let mut long_side: Vec<usize> = Vec::new();
    let mut short_side: Vec<usize> = Vec::new();
    for (idx, &weight) in self.weights.iter().enumerate() {
        if weight > 0.0 {
            long_side.push(idx);
        } else if weight < 0.0 {
            short_side.push(idx);
        }
    }
    let long_weight: f64 = long_side.iter().map(|&i| self.weights[i]).sum();
    let short_weight: f64 = short_side.iter().map(|&i| -self.weights[i]).sum();

    // `greedy_one_side` only reports non-zero positions and the two sides never share an
    // index, so its output can go straight into the result map.
    let mut allocation: HashMap<usize, i64> = HashMap::new();

    let long_cash = if long_weight > 0.0 && long_budget > 0.0 {
        let (bought, rest) = greedy_one_side(
            &long_side,
            &self.weights,
            &self.prices,
            long_budget,
            long_weight,
            false,
        );
        allocation.extend(bought);
        rest
    } else {
        long_budget
    };

    let short_cash = if has_shorts && short_weight > 0.0 && short_budget > 0.0 {
        let (sold, rest) = greedy_one_side(
            &short_side,
            &self.weights,
            &self.prices,
            short_budget,
            short_weight,
            true,
        );
        allocation.extend(sold);
        rest
    } else {
        short_budget
    };

    Ok((allocation, long_cash + short_cash))
}
/// Allocates by rounding each asset's ideal share count to the nearest integer.
///
/// The ideal count for asset `i` is `weights[i] * total_portfolio_value / prices[i]`.
/// Because every asset is rounded independently, the rounded positions can cost more than
/// `total_portfolio_value`, in which case the returned leftover is negative.
///
/// # Returns
///
/// A map from asset index to share count (zero positions omitted) and
/// `total_portfolio_value` minus the value of those positions. This implementation always
/// returns `Ok`.
pub fn rounded_portfolio(&self) -> Result<(HashMap<usize, i64>, f64)> {
    let budget = self.total_portfolio_value;
    let mut allocation = HashMap::new();
    let mut spent = 0.0_f64;
    for (idx, &weight) in self.weights.iter().enumerate() {
        let price = self.prices[idx];
        let count = (weight * budget / price).round() as i64;
        if count == 0 {
            continue;
        }
        allocation.insert(idx, count);
        spent += count as f64 * price;
    }
    Ok((allocation, budget - spent))
}
/// Values an index-keyed allocation at the stored prices.
///
/// Returns the sum of `shares * prices[index]` over every entry; short (negative) positions
/// contribute negative amounts.
///
/// # Panics
///
/// Panics if a key of `allocation` is not a valid index into `prices`.
pub fn allocation_value(&self, allocation: &HashMap<usize, i64>) -> f64 {
    let position_values = allocation
        .iter()
        .map(|(idx, count)| *count as f64 * self.prices[*idx]);
    position_values.sum()
}
/// Borrows the ticker labels, or explains that this instance was built without them.
///
/// # Errors
///
/// Returns [`PortfolioError::InvalidArgument`] when `tickers` is `None`.
fn require_tickers(&self) -> Result<&[String]> {
    match &self.tickers {
        Some(labels) => Ok(labels.as_slice()),
        None => Err(PortfolioError::InvalidArgument(
            "this DiscreteAllocation has no ticker labels; build with new_labeled to use \
             the *_labeled API"
                .into(),
        )),
    }
}
/// Converts an index-keyed allocation into a ticker-keyed one.
///
/// # Errors
///
/// Propagates the error from `require_tickers` when no labels are stored.
///
/// # Panics
///
/// Panics if a key of `alloc` is not a valid index into the ticker list.
fn relabel_alloc(&self, alloc: HashMap<usize, i64>) -> Result<HashMap<String, i64>> {
    let labels = self.require_tickers()?;
    let mut by_ticker = HashMap::with_capacity(alloc.len());
    for (idx, count) in alloc {
        by_ticker.insert(labels[idx].clone(), count);
    }
    Ok(by_ticker)
}
/// Ticker-keyed variant of `greedy_portfolio`.
///
/// # Errors
///
/// Fails when this instance carries no ticker labels, in addition to any error from
/// `greedy_portfolio`.
pub fn greedy_portfolio_labeled(&self) -> Result<(HashMap<String, i64>, f64)> {
    let (by_index, cash) = self.greedy_portfolio()?;
    let by_ticker = self.relabel_alloc(by_index)?;
    Ok((by_ticker, cash))
}
/// Ticker-keyed variant of `greedy_portfolio_with_options`.
///
/// `reinvest` and `verbose` are forwarded unchanged.
///
/// # Errors
///
/// Fails when this instance carries no ticker labels, in addition to any error from
/// `greedy_portfolio_with_options`.
pub fn greedy_portfolio_with_options_labeled(
    &self,
    reinvest: bool,
    verbose: bool,
) -> Result<(HashMap<String, i64>, f64)> {
    self.greedy_portfolio_with_options(reinvest, verbose)
        .and_then(|(by_index, cash)| {
            let by_ticker = self.relabel_alloc(by_index)?;
            Ok((by_ticker, cash))
        })
}
/// Ticker-keyed variant of `rounded_portfolio`.
///
/// # Errors
///
/// Fails when this instance carries no ticker labels, in addition to any error from
/// `rounded_portfolio`.
pub fn rounded_portfolio_labeled(&self) -> Result<(HashMap<String, i64>, f64)> {
    let (by_index, cash) = self.rounded_portfolio()?;
    self.relabel_alloc(by_index)
        .map(|by_ticker| (by_ticker, cash))
}
/// Values a ticker-keyed allocation at the stored prices.
///
/// Returns the sum of `shares * price` over every entry, where the price is looked up through
/// the ticker's position in the stored label list.
///
/// # Errors
///
/// Returns [`PortfolioError::InvalidArgument`] when this instance has no ticker labels or
/// when `allocation` mentions a ticker that is not among them.
pub fn allocation_value_labeled(&self, allocation: &HashMap<String, i64>) -> Result<f64> {
    let tickers = self.require_tickers()?;
    // Index the labels once instead of scanning the ticker list for every allocation entry.
    // `or_insert` keeps the first position of a repeated label, which is what a front-to-back
    // search would find.
    let mut index_of: HashMap<&str, usize> = HashMap::with_capacity(tickers.len());
    for (i, label) in tickers.iter().enumerate() {
        index_of.entry(label.as_str()).or_insert(i);
    }
    let mut value = 0.0;
    for (t, &s) in allocation {
        let i = index_of
            .get(t.as_str())
            .copied()
            .ok_or_else(|| PortfolioError::InvalidArgument(format!("unknown ticker {t}")))?;
        value += s as f64 * self.prices[i];
    }
    Ok(value)
}
}
/// Allocates whole shares for one side (long or short) of the portfolio.
///
/// For every asset in `indices` the ideal, fractional share count is
/// `|weights[i]| / weight_total * budget / prices[i]`. Each count is truncated to a whole
/// number first. The assets are then visited once, largest fractional remainder first, and
/// each receives at most one extra share if the remaining budget covers its price (within a
/// `1e-9` tolerance).
///
/// # Parameters
///
/// * `indices` - asset indices belonging to this side.
/// * `weights`, `prices` - full per-asset vectors, indexed by the values in `indices`.
/// * `budget` - cash available to this side.
/// * `weight_total` - normaliser for the weights; the caller passes the sum of the absolute
///   weights of `indices`.
/// * `is_short` - when true, share counts are reported as negative numbers.
///
/// # Returns
///
/// The non-zero `(asset index, shares)` pairs in the order of `indices`, and the unspent
/// budget. Because of the tolerance, the unspent budget can be marginally below zero.
fn greedy_one_side(
    indices: &[usize],
    weights: &DVector<f64>,
    prices: &DVector<f64>,
    budget: f64,
    weight_total: f64,
    is_short: bool,
) -> (Vec<(usize, i64)>, f64) {
    let ideal: Vec<f64> = indices
        .iter()
        .map(|&i| {
            let w = weights[i].abs() / weight_total;
            w * budget / prices[i]
        })
        .collect();

    // Whole-share floor of every ideal position.
    let mut shares: Vec<i64> = ideal.iter().map(|&x| x.trunc() as i64).collect();
    let spent: f64 = (0..indices.len())
        .map(|k| shares[k] as f64 * prices[indices[k]])
        .sum();
    let mut remaining = budget - spent;

    let mut fracs: Vec<(usize, f64)> = ideal
        .iter()
        .enumerate()
        .map(|(k, &x)| (k, x - x.trunc()))
        .collect();
    // Descending by fractional remainder. `total_cmp` is a genuine total order even when a
    // remainder is NaN, which `sort_by` requires to behave predictably; the sort is stable, so
    // equal remainders keep their `indices` order.
    fracs.sort_by(|a, b| b.1.total_cmp(&a.1));

    // Single top-up pass: at most one extra share per asset.
    for &(k, _) in &fracs {
        let price = prices[indices[k]];
        if remaining >= price - 1e-9 {
            shares[k] += 1;
            remaining -= price;
        }
    }

    let mut out = Vec::with_capacity(indices.len());
    for (k, &i) in indices.iter().enumerate() {
        let s = if is_short { -shares[k] } else { shares[k] };
        if s != 0 {
            out.push((i, s));
        }
    }
    (out, remaining)
}
// Unit tests for the price helpers and for `DiscreteAllocation` (index-based and labeled API).
#[cfg(test)]
mod tests {
use super::*;
use approx::assert_relative_eq;
use nalgebra::dmatrix;
// The last row of the matrix is returned, one entry per column.
#[test]
fn get_latest_prices_returns_last_row() {
let prices = dmatrix![
100.0, 200.0;
101.0, 199.0;
102.0, 198.0
];
let latest = get_latest_prices(&prices).unwrap();
assert_eq!(latest.len(), 2);
assert_relative_eq!(latest[0], 102.0);
assert_relative_eq!(latest[1], 198.0);
}
// Long-only greedy allocation: leftover stays non-negative and the spend stays within budget
// (both up to the 1e-9 tolerance used by the allocator).
#[test]
fn greedy_does_not_exceed_budget() {
let weights = DVector::from_vec(vec![0.6, 0.3, 0.1]);
let prices = DVector::from_vec(vec![100.0, 50.0, 25.0]);
let budget = 10_000.0;
let da = DiscreteAllocation::new(weights, prices, budget).unwrap();
let (alloc, leftover) = da.greedy_portfolio().unwrap();
assert!(
leftover >= -1e-9,
"leftover should be non-negative, got {leftover}"
);
let spent = da.allocation_value(&alloc);
assert!(
spent <= budget + 1e-9,
"spent {spent} exceeds budget {budget}"
);
}
// When the ideal share counts are already whole numbers, they are hit exactly and nothing is
// left over.
#[test]
fn greedy_allocates_close_to_target() {
let weights = DVector::from_vec(vec![0.5, 0.5]);
let prices = DVector::from_vec(vec![100.0, 100.0]);
let budget = 10_000.0;
let da = DiscreteAllocation::new(weights, prices, budget).unwrap();
let (alloc, leftover) = da.greedy_portfolio().unwrap();
assert_eq!(*alloc.get(&0).unwrap_or(&0), 50);
assert_eq!(*alloc.get(&1).unwrap_or(&0), 50);
assert_relative_eq!(leftover, 0.0, epsilon = 1e-9);
}
// Prices that do not divide the budget evenly still respect the budget.
#[test]
fn greedy_handles_fractional_shares() {
let weights = DVector::from_vec(vec![1.0 / 3.0, 1.0 / 3.0, 1.0 / 3.0]);
let prices = DVector::from_vec(vec![7.0, 11.0, 13.0]);
let budget = 1_000.0;
let da = DiscreteAllocation::new(weights, prices, budget).unwrap();
let (alloc, leftover) = da.greedy_portfolio().unwrap();
assert!(leftover >= -1e-9);
let spent = da.allocation_value(&alloc);
assert!(spent <= budget + 1e-9);
}
// `new` rejects weight and price vectors of different lengths.
#[test]
fn dimension_mismatch_errors() {
let w = DVector::from_vec(vec![0.5, 0.5]);
let p = DVector::from_vec(vec![100.0]);
assert!(DiscreteAllocation::new(w, p, 1000.0).is_err());
}
// `new` rejects a price of zero.
#[test]
fn non_positive_price_errors() {
let w = DVector::from_vec(vec![1.0]);
let p = DVector::from_vec(vec![0.0]);
assert!(DiscreteAllocation::new(w, p, 1000.0).is_err());
}
// The ideal count per asset is 0.5 * 10_100 / 100 = 50.5, which rounds to 51. Both assets get
// 51 shares, so the rounded allocation costs 10_200 and overshoots the budget by design.
#[test]
fn rounded_portfolio_allocates_nearest_share() {
let weights = DVector::from_vec(vec![0.5, 0.5]);
let prices = DVector::from_vec(vec![100.0, 100.0]);
let budget = 10_100.0; let da = DiscreteAllocation::new(weights, prices, budget).unwrap();
let (alloc, _) = da.rounded_portfolio().unwrap();
assert_eq!(*alloc.get(&0).unwrap_or(&0), 51);
assert_eq!(*alloc.get(&1).unwrap_or(&0), 51);
}
// Only the sign of each position is checked: the negative weight yields a short (negative)
// share count, the positive weights yield long positions.
#[test]
fn greedy_handles_signed_weights_with_shorts() {
let weights = DVector::from_vec(vec![0.7, 0.6, -0.3]);
let prices = DVector::from_vec(vec![100.0, 200.0, 50.0]);
let da = DiscreteAllocation::new(weights, prices, 10_000.0).unwrap();
let (alloc, _leftover) = da.greedy_portfolio().unwrap();
assert!(alloc.get(&2).copied().unwrap_or(0) < 0);
assert!(alloc.get(&0).copied().unwrap_or(0) > 0);
assert!(alloc.get(&1).copied().unwrap_or(0) > 0);
}
// An asset with weight 0.0 must not appear as a key in the result map.
#[test]
fn zero_weight_assets_are_absent_from_allocation() {
let weights = DVector::from_vec(vec![1.0, 0.0]);
let prices = DVector::from_vec(vec![50.0, 200.0]);
let da = DiscreteAllocation::new(weights, prices, 5_000.0).unwrap();
let (alloc, _) = da.greedy_portfolio().unwrap();
assert!(!alloc.contains_key(&1));
}
// The maps are filled in different insertion orders; alignment must go by ticker, and the
// stored tickers come back in BTreeMap (sorted) order.
// Expected shares: AAPL 0.6 * 10_000 / 100 = 60, MSFT 0.4 * 10_000 / 50 = 80.
#[test]
fn labeled_api_aligns_weights_and_prices_by_ticker() {
let mut weights = BTreeMap::new();
weights.insert("MSFT".to_string(), 0.4);
weights.insert("AAPL".to_string(), 0.6);
let mut prices = BTreeMap::new();
prices.insert("AAPL".to_string(), 100.0);
prices.insert("MSFT".to_string(), 50.0);
let da = DiscreteAllocation::new_labeled(weights, prices, 10_000.0).unwrap();
assert_eq!(da.tickers.as_deref().unwrap(), &["AAPL", "MSFT"]);
let (shares, leftover) = da.greedy_portfolio_labeled().unwrap();
assert!(leftover >= -1e-9);
assert_eq!(*shares.get("AAPL").unwrap(), 60);
assert_eq!(*shares.get("MSFT").unwrap(), 80);
}
// Tickers present in only one of the two maps (MSFT, GOOG) are dropped without an error.
#[test]
fn labeled_api_intersects_ticker_sets() {
let mut weights = BTreeMap::new();
weights.insert("AAPL".to_string(), 0.5);
weights.insert("MSFT".to_string(), 0.5);
let mut prices = BTreeMap::new();
prices.insert("AAPL".to_string(), 100.0);
prices.insert("GOOG".to_string(), 200.0);
let da = DiscreteAllocation::new_labeled(weights, prices, 1_000.0).unwrap();
assert_eq!(da.tickers.as_deref().unwrap(), &["AAPL"]);
}
// An empty intersection of tickers is an error rather than an empty allocation problem.
#[test]
fn labeled_api_errors_when_no_common_tickers() {
let mut weights = BTreeMap::new();
weights.insert("AAPL".to_string(), 1.0);
let mut prices = BTreeMap::new();
prices.insert("MSFT".to_string(), 100.0);
assert!(DiscreteAllocation::new_labeled(weights, prices, 1_000.0).is_err());
}
// An instance built with the index-based `new` has no labels, so the labeled API must fail.
#[test]
fn labeled_api_errors_without_tickers() {
let weights = DVector::from_vec(vec![0.5, 0.5]);
let prices = DVector::from_vec(vec![100.0, 100.0]);
let da = DiscreteAllocation::new(weights, prices, 1_000.0).unwrap();
assert!(da.greedy_portfolio_labeled().is_err());
}
// Column 0 is labeled MSFT and column 1 AAPL; the map keys iterate alphabetically while each
// value still comes from its own column of the last row.
#[test]
fn get_latest_prices_labeled_returns_alphabetical_map() {
let prices = dmatrix![
100.0, 200.0;
101.0, 199.0
];
let tickers = vec!["MSFT".to_string(), "AAPL".to_string()];
let map = get_latest_prices_labeled(&prices, &tickers).unwrap();
let keys: Vec<&String> = map.keys().collect();
assert_eq!(keys, vec!["AAPL", "MSFT"]);
assert_relative_eq!(*map.get("MSFT").unwrap(), 101.0);
assert_relative_eq!(*map.get("AAPL").unwrap(), 199.0);
}
}