use std::collections::{HashMap, VecDeque};
use arrow::array::{Float64Array, Int32Array, StringViewArray};
use crate::config::BacktestConfig;
use crate::position::{Position, PositionSnapshot};
use crate::tracker::{BacktestResult, TradeRecord, NoopSymbolTracker, SymbolTracker, TradeTracker};
use crate::{is_valid_price, FLOAT_EPSILON};
/// Cash plus open positions. Starting capital is normalised to `1.0`, so
/// [`Portfolio::balance`] reads directly as a cumulative return.
pub struct Portfolio {
    /// Uninvested cash.
    pub cash: f64,
    /// Open positions keyed by symbol.
    pub positions: HashMap<String, Position>,
}

impl Portfolio {
    /// Creates a portfolio with no positions and `1.0` in cash.
    pub fn new() -> Self {
        Portfolio {
            cash: 1.0,
            positions: HashMap::new(),
        }
    }

    /// Total equity: cash plus the signed `last_market_value` of every position.
    pub fn balance(&self) -> f64 {
        let invested: f64 = self
            .positions
            .values()
            .map(|position| position.last_market_value)
            .sum();
        self.cash + invested
    }
}

impl Default for Portfolio {
    /// Equivalent to [`Portfolio::new`].
    fn default() -> Self {
        Portfolio::new()
    }
}
/// Borrowed Arrow columns for a long-format backtest input: each row carries a
/// date, a symbol and that symbol's values on that date.
///
/// Every column is read by row index `0..dates.len()`. NOTE(review): the columns are
/// presumably expected to have equal length and be sorted by date — confirm with callers.
pub struct LongFormatArrowInput<'a> {
    /// Integer date of each row; the calendar helpers treat it as days since
    /// 1970-01-01 (TODO confirm with producers).
    pub dates: &'a Int32Array,
    /// Symbol of each row.
    pub symbols: &'a StringViewArray,
    /// Price used to mark positions to market (rows with an invalid price are ignored).
    pub prices: &'a Float64Array,
    /// Raw target weight; zero or NaN is treated as "no signal".
    pub weights: &'a Float64Array,
    /// Optional open price. OHLC data is only used when open, high and low are all present.
    pub open_prices: Option<&'a Float64Array>,
    /// Optional high price (see `open_prices`).
    pub high_prices: Option<&'a Float64Array>,
    /// Optional low price (see `open_prices`).
    pub low_prices: Option<&'a Float64Array>,
    /// Optional per-row factor passed to the trade tracker; defaults to 1.0 when absent.
    pub factor: Option<&'a Float64Array>,
}
/// How often target weights are turned into a rebalance.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ResampleFreq {
    /// Every trading day (once the first signal has been seen).
    Daily,
    /// Last trading day of each Monday-to-Sunday week.
    Weekly,
    /// Trading days falling on the given weekday (0 = Monday … 6 = Sunday).
    WeeklyOn(u8),
    /// Last trading day of each calendar month.
    Monthly,
    /// Parsed from "MS"; triggered on the same month boundary as `Monthly`.
    MonthStart,
    /// Last trading day of each calendar quarter.
    Quarterly,
    /// Parsed from "QS"; triggered on the same quarter boundary as `Quarterly`.
    QuarterStart,
    /// Last trading day of each calendar year.
    Yearly,
    /// Rebalance whenever the normalised weights change; chosen when no code is given.
    PositionChange,
}

impl ResampleFreq {
    /// Parses a pandas-like frequency alias.
    ///
    /// `None` selects [`ResampleFreq::PositionChange`]; any unrecognised code falls back
    /// to [`ResampleFreq::Daily`]. Both "W" and "W-FRI" map to [`ResampleFreq::Weekly`].
    pub fn from_str(s: Option<&str>) -> Self {
        let code = match s {
            None => return Self::PositionChange,
            Some(code) => code,
        };
        match code {
            "D" => Self::Daily,
            "W" | "W-FRI" => Self::Weekly,
            "W-MON" => Self::WeeklyOn(0),
            "W-TUE" => Self::WeeklyOn(1),
            "W-WED" => Self::WeeklyOn(2),
            "W-THU" => Self::WeeklyOn(3),
            "W-SAT" => Self::WeeklyOn(5),
            "W-SUN" => Self::WeeklyOn(6),
            "M" | "ME" => Self::Monthly,
            "MS" => Self::MonthStart,
            "Q" | "QE" => Self::Quarterly,
            "QS" => Self::QuarterStart,
            "Y" | "YE" | "A" => Self::Yearly,
            _ => Self::Daily,
        }
    }
}
/// Delay, in calendar days, between a resample boundary and the rebalance it schedules.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct ResampleOffset {
    /// Non-negative delay in calendar days.
    pub days: i32,
}

impl ResampleOffset {
    /// Creates an offset of `days` calendar days, clamping negative values to 0.
    pub fn new(days: i32) -> Self {
        Self { days: days.max(0) }
    }

    /// Parses an offset such as `"3d"` / `"3D"` (days) or `"2W"` / `"2w"` (weeks).
    ///
    /// Surrounding whitespace is ignored. Returns `None` for:
    /// - `None` or empty input;
    /// - an unknown unit;
    /// - a count that is not a non-negative `i32`;
    /// - a week count whose length in days does not fit in `i32`.
    pub fn from_str(s: Option<&str>) -> Option<Self> {
        let s = s?.trim();
        let (count_str, days_per_unit) =
            if let Some(n) = s.strip_suffix('d').or_else(|| s.strip_suffix('D')) {
                (n, 1)
            } else if let Some(n) = s.strip_suffix('W').or_else(|| s.strip_suffix('w')) {
                (n, 7)
            } else {
                return None;
            };
        let count = count_str.parse::<i32>().ok().filter(|&c| c >= 0)?;
        // checked_mul: a huge week count must not panic (debug) or wrap into a
        // negative day offset (release).
        let days = count.checked_mul(days_per_unit)?;
        Some(Self { days })
    }
}
/// A rebalance scheduled by an offset: it fires once the backtest loop reaches a
/// date strictly after `target_date` (a period boundary plus the offset in days).
#[derive(Debug, Clone, Copy)]
struct DelayedRebalance {
    /// Day count after which the rebalance fires.
    target_date: i32,
}
/// Per-row accessors for open, high and low prices, used by touched-exit stop detection.
pub struct OhlcGetters<FO, FH, FL>
where
    FO: Fn(usize) -> f64,
    FH: Fn(usize) -> f64,
    FL: Fn(usize) -> f64,
{
    /// Returns the open price of row `i`.
    pub get_open: FO,
    /// Returns the high price of row `i`.
    pub get_high: FH,
    /// Returns the low price of row `i`.
    pub get_low: FL,
}
/// Core long-format backtest loop shared by every public entry point.
///
/// Rows are consumed in order `0..n_rows`. Consecutive rows with the same date are
/// accumulated into per-day maps (`today_prices`, `today_weights`, `today_factor`,
/// plus open/high/low when `ohlc` is supplied). When the date changes, the previous
/// day (`prev_date`) is closed out in this order:
/// 1. positions are marked to market;
/// 2. stop exits are detected and executed;
/// 3. any pending rebalance is executed at `prev_date`'s prices;
/// 4. the resample rule decides whether `prev_date` queues a new rebalance;
/// 5. delayed (offset) rebalances that are due fire;
/// 6. the portfolio balance is recorded.
///
/// NOTE(review): grouping only detects a *change* of date, so rows are presumably
/// expected to be sorted by date — confirm against callers.
///
/// Parameters:
/// - `get_*`: per-row accessors.
/// - `resample`: the rebalance frequency.
/// - `offset`: an optional calendar-day delay applied after each period boundary.
/// - `tracker`: receives price updates and trade open/close events.
/// - `ohlc`: per-row open/high/low accessors. They are used for intraday ("touched")
///   stops only when `config.touched_exit` is also set.
///
/// Returns `(dates, creturn)`. There is one entry per processed date, starting at the
/// first date for which `has_first_signal` is true. `creturn` is `portfolio.balance()`,
/// which starts from 1.0 cash.
fn backtest_impl<'a, FD, FS, FP, FW, FF, T, FO, FH, FL>(
    n_rows: usize,
    get_date: FD,
    get_symbol: FS,
    get_price: FP,
    get_weight: FW,
    get_factor: FF,
    resample: ResampleFreq,
    offset: Option<ResampleOffset>,
    config: &BacktestConfig,
    tracker: &mut T,
    ohlc: Option<OhlcGetters<FO, FH, FL>>,
) -> (Vec<i32>, Vec<f64>)
where
    FD: Fn(usize) -> i32,
    FS: Fn(usize) -> &'a str,
    FP: Fn(usize) -> f64,
    FW: Fn(usize) -> f64,
    FF: Fn(usize) -> f64,
    FO: Fn(usize) -> f64,
    FH: Fn(usize) -> f64,
    FL: Fn(usize) -> f64,
    T: TradeTracker<Key = String, Date = i32, Record = TradeRecord>,
{
    if n_rows == 0 {
        return (vec![], vec![]);
    }
    let mut portfolio = Portfolio::new();
    let mut dates: Vec<i32> = Vec::new();
    let mut creturn: Vec<f64> = Vec::new();
    // Symbols stopped out since the last executed rebalance. They are excluded from
    // normalized weights and, when config.stop_trading_next_period is set, from the
    // next rebalance.
    let mut stopped_stocks: HashMap<String, bool> = HashMap::new();
    // Rebalance queued on one date change and executed on the next one.
    let mut pending_weights: Option<HashMap<String, f64>> = None;
    let mut pending_signal_date: Option<i32> = None;
    // Close-only mode: stops detected on one day and executed on the following day.
    let mut pending_stop_exits: Vec<String> = Vec::new();
    let mut active_weights: HashMap<String, f64> = HashMap::new();
    // Output is only recorded once the first signal has been observed.
    let mut has_first_signal = false;
    let mut position_changed = false;
    let mut current_date: Option<i32> = None;
    let mut today_prices: HashMap<&str, f64> = HashMap::new();
    let mut today_weights: HashMap<&str, f64> = HashMap::new();
    let mut today_factor: HashMap<&str, f64> = HashMap::new();
    let mut delayed_rebalances: VecDeque<DelayedRebalance> = VecDeque::new();
    let offset_days = offset.map(|o| o.days).unwrap_or(0);
    let mut today_open: HashMap<&str, f64> = HashMap::new();
    let mut today_high: HashMap<&str, f64> = HashMap::new();
    let mut today_low: HashMap<&str, f64> = HashMap::new();
    let touched_exit_enabled = config.touched_exit && ohlc.is_some();
    for i in 0..n_rows {
        let date = get_date(i);
        let symbol = get_symbol(i);
        let price = get_price(i);
        let weight = get_weight(i);
        let date_changed = current_date.map_or(true, |d| d != date);
        // A new date closes out the previous one. The `today_*` maps still hold
        // prev_date's data at this point.
        if date_changed && current_date.is_some() {
            let prev_date = current_date.unwrap();
            // 1. Mark positions to prev_date's prices and report them to the tracker.
            update_positions(&mut portfolio, &today_prices);
            for (sym, _pos) in portfolio.positions.iter() {
                if let Some(&price) = today_prices.get(sym.as_str()) {
                    tracker.record_price(sym, price, price);
                }
            }
            // 2. Stop exits.
            if touched_exit_enabled {
                // Touched mode: stops detected from OHLC are executed on the same day.
                let ohlc_refs = Some((&today_open, &today_high, &today_low));
                let stops = detect_stops_unified(&portfolio, &today_prices, ohlc_refs, config);
                execute_stops_impl(
                    &mut portfolio,
                    stops,
                    &mut stopped_stocks,
                    &today_prices,
                    &today_factor,
                    config,
                    prev_date,
                    Some(prev_date), tracker,
                );
                pending_stop_exits.clear();
            } else {
                // Close-only mode: first execute the stops detected on the previous
                // processed day (exit_ratio 1.0, i.e. at this day's price), then detect
                // new stops to execute on the next day.
                let pending_results: Vec<StopResult> = pending_stop_exits
                    .drain(..)
                    .map(|sym| StopResult { symbol: sym, exit_ratio: 1.0 })
                    .collect();
                execute_stops_impl(
                    &mut portfolio,
                    pending_results,
                    &mut stopped_stocks,
                    &today_prices,
                    &today_factor,
                    config,
                    prev_date,
                    None, tracker,
                );
                let new_stops = detect_stops_unified(&portfolio, &today_prices, None, config);
                for stop in new_stops {
                    pending_stop_exits.push(stop.symbol);
                }
            }
            update_previous_prices(&mut portfolio, &today_prices);
            // NOTE(review): `delayed_triggered` is always false, so the empty branch is
            // dead code and a pending rebalance always takes the `else if` path.
            let delayed_triggered = false;
            if delayed_triggered {
            } else if let Some(target_weights) = pending_weights.take() {
                // 3. Execute the rebalance queued on the previous date change, at
                //    prev_date's prices (signal on one day, trade on the next).
                let sig_date = pending_signal_date.take().unwrap_or(prev_date);
                execute_rebalance_impl(
                    &mut portfolio,
                    &target_weights,
                    &today_prices,
                    &today_factor,
                    &stopped_stocks,
                    config,
                    prev_date,
                    sig_date,
                    tracker,
                );
                stopped_stocks.clear();
            }
            // 4. Decide whether prev_date is a rebalance boundary for `resample`.
            let normalized = normalize_weights(
                &today_weights,
                &stopped_stocks,
                config.position_limit,
            );
            let is_month_end = is_month_end_i32(prev_date, date);
            let is_week_end = is_week_end_i32(prev_date, date);
            let is_quarter_end = is_quarter_end_i32(prev_date, date);
            let is_year_end = is_year_end_i32(prev_date, date);
            let prev_weekday = weekday_of_i32(prev_date);
            // With an offset, the first signal is instead set when a delayed rebalance fires (step 5).
            let set_first_signal_at_boundary = offset_days == 0;
            let has_signals = !normalized.is_empty();
            // MonthStart/QuarterStart share the month/quarter-end trigger of their *End variants.
            match resample {
                ResampleFreq::Daily => {
                    if has_signals && set_first_signal_at_boundary {
                        has_first_signal = true;
                    }
                    if has_first_signal {
                        active_weights = normalized.clone();
                    }
                }
                ResampleFreq::Weekly => {
                    if is_week_end {
                        active_weights = normalized.clone();
                        if has_signals && set_first_signal_at_boundary {
                            has_first_signal = true;
                        }
                    }
                }
                ResampleFreq::WeeklyOn(weekday) => {
                    if prev_weekday == weekday {
                        active_weights = normalized.clone();
                        if has_signals && set_first_signal_at_boundary {
                            has_first_signal = true;
                        }
                    }
                }
                ResampleFreq::Monthly | ResampleFreq::MonthStart => {
                    if is_month_end {
                        active_weights = normalized.clone();
                        if has_signals && set_first_signal_at_boundary {
                            has_first_signal = true;
                        }
                    }
                }
                ResampleFreq::Quarterly | ResampleFreq::QuarterStart => {
                    if is_quarter_end {
                        active_weights = normalized.clone();
                        if has_signals && set_first_signal_at_boundary {
                            has_first_signal = true;
                        }
                    }
                }
                ResampleFreq::Yearly => {
                    if is_year_end {
                        active_weights = normalized.clone();
                        if has_signals && set_first_signal_at_boundary {
                            has_first_signal = true;
                        }
                    }
                }
                ResampleFreq::PositionChange => {
                    if has_signals && set_first_signal_at_boundary {
                        has_first_signal = true;
                    }
                    position_changed = weights_differ(&active_weights, &normalized);
                    active_weights = normalized;
                }
            }
            let should_rebalance = match resample {
                ResampleFreq::Monthly | ResampleFreq::MonthStart => is_month_end,
                ResampleFreq::Weekly => is_week_end,
                ResampleFreq::WeeklyOn(weekday) => prev_weekday == weekday,
                ResampleFreq::Quarterly | ResampleFreq::QuarterStart => is_quarter_end,
                ResampleFreq::Yearly => is_year_end,
                ResampleFreq::Daily => true,
                ResampleFreq::PositionChange => position_changed,
            };
            if should_rebalance {
                if offset_days > 0 {
                    // Schedule one delayed rebalance per boundary in [prev_date, date).
                    // NOTE(review): if get_all_period_boundaries yields no boundary for this
                    // `resample`, nothing is scheduled for it.
                    let boundaries = get_all_period_boundaries(prev_date, date, resample);
                    for boundary in boundaries {
                        delayed_rebalances.push_back(DelayedRebalance {
                            target_date: boundary + offset_days,
                        });
                    }
                } else if has_first_signal {
                    // Queue the boundary's weights; they execute on the next date change.
                    pending_weights = Some(active_weights.clone());
                    pending_signal_date = Some(prev_date);
                }
            }
            // 5. Pop every delayed rebalance whose target date is strictly before `date`.
            let mut any_triggered = false;
            while let Some(delayed) = delayed_rebalances.front() {
                if date > delayed.target_date {
                    let _delayed = delayed_rebalances.pop_front().unwrap();
                    any_triggered = true;
                } else {
                    break;
                }
            }
            if any_triggered {
                // All entries that fired collapse into a single rebalance that uses
                // prev_date's weights, not the weights observed at the boundary.
                let signal_weights = normalize_weights(
                    &today_weights,
                    &stopped_stocks,
                    config.position_limit,
                );
                if !has_first_signal && !signal_weights.is_empty() {
                    has_first_signal = true;
                }
                pending_weights = Some(signal_weights);
                pending_signal_date = Some(prev_date);
                let _ = delayed_triggered;
            }
            // 6. Record prev_date's equity once the first signal has been seen.
            if has_first_signal {
                dates.push(prev_date);
                creturn.push(portfolio.balance());
            }
            today_prices.clear();
            today_weights.clear();
            today_factor.clear();
            if touched_exit_enabled {
                today_open.clear();
                today_high.clear();
                today_low.clear();
            }
        }
        // Accumulate this row into the current date's maps.
        current_date = Some(date);
        if is_valid_price(price) {
            today_prices.insert(symbol, price);
        }
        // NaN or (near-)zero weights are treated as "no signal".
        if !weight.is_nan() && weight.abs() > FLOAT_EPSILON {
            today_weights.insert(symbol, weight);
        }
        let factor = get_factor(i);
        if factor.is_finite() && factor > 0.0 {
            today_factor.insert(symbol, factor);
        }
        // NOTE(review): OHLC values are inserted whenever `ohlc` is Some, but the maps are
        // only cleared when touched_exit_enabled. The wrappers in this file only pass
        // Some when config.touched_exit is set.
        if let Some(ref ohlc_getters) = ohlc {
            today_open.insert(symbol, (ohlc_getters.get_open)(i));
            today_high.insert(symbol, (ohlc_getters.get_high)(i));
            today_low.insert(symbol, (ohlc_getters.get_low)(i));
        }
    }
    // Close out the final date, since no later date change will trigger it. Delayed
    // rebalances still queued at this point are not processed.
    if let Some(last_date) = current_date {
        if !today_prices.is_empty() {
            update_positions(&mut portfolio, &today_prices);
            for (sym, _pos) in portfolio.positions.iter() {
                if let Some(&price) = today_prices.get(sym.as_str()) {
                    tracker.record_price(sym, price, price);
                }
            }
            if touched_exit_enabled {
                let ohlc_refs = Some((&today_open, &today_high, &today_low));
                let stops = detect_stops_unified(&portfolio, &today_prices, ohlc_refs, config);
                execute_stops_impl(
                    &mut portfolio,
                    stops,
                    &mut stopped_stocks,
                    &today_prices,
                    &today_factor,
                    config,
                    last_date,
                    Some(last_date),
                    tracker,
                );
            } else {
                // Only stops detected on the previous day are executed; no new detection.
                execute_pending_stops_impl(
                    &mut portfolio,
                    &mut pending_stop_exits,
                    &mut stopped_stocks,
                    &today_prices,
                    &today_factor,
                    config,
                    last_date,
                    tracker,
                );
            }
            update_previous_prices(&mut portfolio, &today_prices);
            if let Some(target_weights) = pending_weights.take() {
                let sig_date = pending_signal_date.take().unwrap_or(last_date);
                execute_rebalance_impl(
                    &mut portfolio,
                    &target_weights,
                    &today_prices,
                    &today_factor,
                    &stopped_stocks,
                    config,
                    last_date,
                    sig_date,
                    tracker,
                );
            }
            // Unlike the in-loop path, any non-empty weight on the last date starts output.
            if !has_first_signal && !today_weights.is_empty() {
                has_first_signal = true;
            }
            if has_first_signal {
                dates.push(last_date);
                creturn.push(portfolio.balance());
            }
        }
    }
    (dates, creturn)
}
/// Runs a backtest over per-row accessor closures without recording trades.
///
/// The OHLC accessors are forwarded to the core loop only when `config.touched_exit`
/// is set. Otherwise they are dropped and stops are evaluated on prices alone. A no-op
/// tracker is used, so the returned `trades` is always empty.
pub fn backtest_with_accessor<'a, FD, FS, FP, FW, FF, FO, FH, FL>(
    n_rows: usize,
    get_date: FD,
    get_symbol: FS,
    get_price: FP,
    get_weight: FW,
    get_factor: FF,
    ohlc_accessors: Option<(FO, FH, FL)>,
    resample: ResampleFreq,
    offset: Option<ResampleOffset>,
    config: &BacktestConfig,
) -> BacktestResult
where
    FD: Fn(usize) -> i32,
    FS: Fn(usize) -> &'a str,
    FP: Fn(usize) -> f64,
    FW: Fn(usize) -> f64,
    FF: Fn(usize) -> f64,
    FO: Fn(usize) -> f64,
    FH: Fn(usize) -> f64,
    FL: Fn(usize) -> f64,
{
    let mut tracker = NoopSymbolTracker::default();
    let (dates, creturn) = match ohlc_accessors {
        Some((get_open, get_high, get_low)) if config.touched_exit => backtest_impl(
            n_rows,
            get_date,
            get_symbol,
            get_price,
            get_weight,
            get_factor,
            resample,
            offset,
            config,
            &mut tracker,
            Some(OhlcGetters {
                get_open,
                get_high,
                get_low,
            }),
        ),
        _ => {
            // Concrete placeholder types are needed because `None` carries no closures.
            let no_ohlc: Option<OhlcGetters<fn(usize) -> f64, fn(usize) -> f64, fn(usize) -> f64>> =
                None;
            backtest_impl(
                n_rows,
                get_date,
                get_symbol,
                get_price,
                get_weight,
                get_factor,
                resample,
                offset,
                config,
                &mut tracker,
                no_ohlc,
            )
        }
    };
    BacktestResult {
        dates,
        creturn,
        trades: Vec::new(),
    }
}
/// Runs a backtest over borrowed Arrow columns without recording trades.
///
/// OHLC columns are offered to the engine only when open, high and low are all present.
/// A missing factor column defaults every row's factor to 1.0.
pub fn backtest_long_arrow(
    input: &LongFormatArrowInput,
    resample: ResampleFreq,
    offset: Option<ResampleOffset>,
    config: &BacktestConfig,
) -> BacktestResult {
    let ohlc_accessors = match (input.open_prices, input.high_prices, input.low_prices) {
        (Some(open_arr), Some(high_arr), Some(low_arr)) => Some((
            move |i: usize| open_arr.value(i),
            move |i: usize| high_arr.value(i),
            move |i: usize| low_arr.value(i),
        )),
        _ => None,
    };
    let get_factor: Box<dyn Fn(usize) -> f64> = match input.factor {
        Some(factor_arr) => Box::new(move |i: usize| factor_arr.value(i)),
        None => Box::new(|_: usize| 1.0),
    };
    backtest_with_accessor(
        input.dates.len(),
        |i| input.dates.value(i),
        |i| input.symbols.value(i),
        |i| input.prices.value(i),
        |i| input.weights.value(i),
        |i| get_factor(i),
        ohlc_accessors,
        resample,
        offset,
        config,
    )
}
/// Runs a backtest over plain slices without recording trades.
///
/// All slices are indexed by row `0..dates.len()`, so shorter slices panic on
/// out-of-bounds access. OHLC data is offered only when open, high and low are all given;
/// a missing `factor` defaults every row's factor to 1.0.
pub fn backtest_long_slice(
    dates: &[i32],
    symbols: &[&str],
    prices: &[f64],
    weights: &[f64],
    factor: Option<&[f64]>,
    open_prices: Option<&[f64]>,
    high_prices: Option<&[f64]>,
    low_prices: Option<&[f64]>,
    resample: ResampleFreq,
    offset: Option<ResampleOffset>,
    config: &BacktestConfig,
) -> BacktestResult {
    let ohlc_accessors = match (open_prices, high_prices, low_prices) {
        (Some(open), Some(high), Some(low)) => Some((
            move |i: usize| open[i],
            move |i: usize| high[i],
            move |i: usize| low[i],
        )),
        _ => None,
    };
    let get_factor: Box<dyn Fn(usize) -> f64> = match factor {
        Some(f) => Box::new(move |i: usize| f[i]),
        None => Box::new(|_: usize| 1.0),
    };
    backtest_with_accessor(
        dates.len(),
        |i| dates[i],
        |i| symbols[i],
        |i| prices[i],
        |i| weights[i],
        |i| get_factor(i),
        ohlc_accessors,
        resample,
        offset,
        config,
    )
}
/// Runs a backtest over borrowed Arrow columns and returns the trade records too.
///
/// OHLC columns drive touched-exit stops only when `config.touched_exit` is set and
/// open, high and low are all present. Trades are produced by
/// `SymbolTracker::finalize` with the configured fee and tax ratios.
pub fn backtest_with_report_long_arrow(
    input: &LongFormatArrowInput,
    resample: ResampleFreq,
    offset: Option<ResampleOffset>,
    config: &BacktestConfig,
) -> BacktestResult {
    let mut tracker = SymbolTracker::new();
    let get_factor: Box<dyn Fn(usize) -> f64> = match input.factor {
        Some(factor_arr) => Box::new(move |i: usize| factor_arr.value(i)),
        None => Box::new(|_: usize| 1.0),
    };
    let (dates, creturn) = match (input.open_prices, input.high_prices, input.low_prices) {
        (Some(open_arr), Some(high_arr), Some(low_arr)) if config.touched_exit => {
            let getters = OhlcGetters {
                get_open: |i: usize| open_arr.value(i),
                get_high: |i: usize| high_arr.value(i),
                get_low: |i: usize| low_arr.value(i),
            };
            backtest_impl(
                input.dates.len(),
                |i| input.dates.value(i),
                |i| input.symbols.value(i),
                |i| input.prices.value(i),
                |i| input.weights.value(i),
                |i| get_factor(i),
                resample,
                offset,
                config,
                &mut tracker,
                Some(getters),
            )
        }
        _ => {
            // Concrete placeholder types are needed because `None` carries no closures.
            let no_ohlc: Option<OhlcGetters<fn(usize) -> f64, fn(usize) -> f64, fn(usize) -> f64>> =
                None;
            backtest_impl(
                input.dates.len(),
                |i| input.dates.value(i),
                |i| input.symbols.value(i),
                |i| input.prices.value(i),
                |i| input.weights.value(i),
                |i| get_factor(i),
                resample,
                offset,
                config,
                &mut tracker,
                no_ohlc,
            )
        }
    };
    let trades = tracker.finalize(config.fee_ratio, config.tax_ratio);
    BacktestResult { dates, creturn, trades }
}
/// Returns `true` when the two day counts fall in different calendar months
/// (a change of year counts as a change of month).
fn is_month_end_i32(prev_days: i32, next_days: i32) -> bool {
    days_to_year_month(prev_days) != days_to_year_month(next_days)
}
/// Returns `true` when `prev_days` and `next_days` fall in different Monday-to-Sunday weeks.
///
/// Day 0 (1970-01-01) is a Thursday, so adding 3 aligns the 7-day buckets to Mondays.
/// Euclidean division keeps every bucket exactly 7 days wide for negative (pre-1970)
/// day counts. Truncating `/` would merge days -9..=3 into one bucket.
fn is_week_end_i32(prev_days: i32, next_days: i32) -> bool {
    let week_of = |days: i32| (days + 3).div_euclid(7);
    week_of(prev_days) != week_of(next_days)
}
fn is_quarter_end_i32(prev_days: i32, next_days: i32) -> bool {
let (prev_year, prev_month) = days_to_year_month(prev_days);
let (next_year, next_month) = days_to_year_month(next_days);
let prev_quarter = (prev_year, (prev_month - 1) / 3);
let next_quarter = (next_year, (next_month - 1) / 3);
prev_quarter != next_quarter
}
fn is_year_end_i32(prev_days: i32, next_days: i32) -> bool {
let (prev_year, _) = days_to_year_month(prev_days);
let (next_year, _) = days_to_year_month(next_days);
prev_year != next_year
}
/// Weekday of a day count, with 0 = Monday … 6 = Sunday.
///
/// Day 0 (1970-01-01) is a Thursday (index 3). The `rem_euclid` is applied before the
/// shift so negative inputs work and `days + 3` can never overflow.
#[inline]
fn weekday_of_i32(days: i32) -> u8 {
    let shifted = days.rem_euclid(7) + 3;
    (shifted % 7) as u8
}
/// Lists every period-end date `d` with `prev_date <= d < date` for `resample`.
///
/// It is used when a rebalance offset is configured: each boundary returned is scheduled
/// to fire `offset` days later. Dates are day counts (0 = 1970-01-01).
///
/// - `Weekly`: Sundays.
/// - `WeeklyOn(w)`: dates on weekday `w` (0 = Monday … 6 = Sunday). The anchor is `w`
///   itself, not Sunday. Boundaries are only requested when `prev_date` falls on `w`, and
///   the next trading day usually comes before the following Sunday. Sunday anchoring
///   would therefore yield no boundary and never schedule the delayed rebalance.
/// - `Monthly` / `MonthStart`: last calendar day of each month.
/// - `Quarterly` / `QuarterStart`: last calendar day of each quarter.
/// - `Yearly`: December 31.
/// - `Daily` / `PositionChange`: no boundaries.
fn get_all_period_boundaries(
    prev_date: i32,
    date: i32,
    resample: ResampleFreq,
) -> Vec<i32> {
    // Every date in [from, until) whose weekday (per weekday_of_i32) equals `anchor`.
    fn weekday_dates(from: i32, until: i32, anchor: u8) -> impl Iterator<Item = i32> {
        let lead = (i32::from(anchor) - i32::from(weekday_of_i32(from))).rem_euclid(7);
        (from + lead..until).step_by(7)
    }

    let mut boundaries = Vec::new();
    match resample {
        ResampleFreq::Weekly => boundaries.extend(weekday_dates(prev_date, date, 6)),
        ResampleFreq::WeeklyOn(weekday) => {
            boundaries.extend(weekday_dates(prev_date, date, weekday));
        }
        ResampleFreq::Monthly | ResampleFreq::MonthStart => {
            let (mut year, mut month) = days_to_year_month(prev_date);
            loop {
                let month_end = last_day_of_month_i32(ymd_to_days(year, month, 1));
                if month_end >= prev_date && month_end < date {
                    boundaries.push(month_end);
                }
                if month == 12 {
                    month = 1;
                    year += 1;
                } else {
                    month += 1;
                }
                if ymd_to_days(year, month, 1) >= date {
                    break;
                }
            }
        }
        ResampleFreq::Quarterly | ResampleFreq::QuarterStart => {
            let (mut year, month) = days_to_year_month(prev_date);
            // Quarter number 1..=4.
            let mut qtr = ((month - 1) / 3) + 1;
            loop {
                let qtr_end_month = qtr * 3;
                let qtr_end = last_day_of_quarter_i32(ymd_to_days(year, qtr_end_month, 1));
                if qtr_end >= prev_date && qtr_end < date {
                    boundaries.push(qtr_end);
                }
                if qtr == 4 {
                    qtr = 1;
                    year += 1;
                } else {
                    qtr += 1;
                }
                if ymd_to_days(year, qtr * 3, 1) >= date {
                    break;
                }
            }
        }
        ResampleFreq::Yearly => {
            let (mut year, _) = days_to_year_month(prev_date);
            loop {
                let year_end = last_day_of_year_i32(ymd_to_days(year, 12, 31));
                if year_end >= prev_date && year_end < date {
                    boundaries.push(year_end);
                }
                year += 1;
                if ymd_to_days(year, 1, 1) >= date {
                    break;
                }
            }
        }
        // NOTE(review): these modes produce no boundary, so a positive offset never
        // schedules a delayed rebalance for them — confirm that this is intended.
        ResampleFreq::Daily | ResampleFreq::PositionChange => {}
    }
    boundaries
}
/// Converts a day count (0 = 1970-01-01) to a proleptic-Gregorian `(year, month)`,
/// with `month` in `1..=12`.
///
/// This is Howard Hinnant's `civil_from_days`. It works on 400-year eras
/// (146097 days) and a March-based internal year, so the leap day is the last day of
/// each internal year.
#[inline]
fn days_to_year_month(days: i32) -> (i32, u32) {
    // Re-base the epoch to 0000-03-01.
    let shifted = days + 719468;
    let era = shifted.div_euclid(146097);
    let day_of_era = shifted.rem_euclid(146097) as u32; // [0, 146096]
    let year_of_era =
        (day_of_era - day_of_era / 1460 + day_of_era / 36524 - day_of_era / 146096) / 365; // [0, 399]
    let day_of_year = day_of_era - (365 * year_of_era + year_of_era / 4 - year_of_era / 100); // [0, 365]
    let march_month = (5 * day_of_year + 2) / 153; // [0, 11], 0 = March
    let month = if march_month < 10 {
        march_month + 3
    } else {
        march_month - 9
    };
    // January and February belong to the following civil year.
    let march_year = year_of_era as i32 + era * 400;
    let year = if month <= 2 { march_year + 1 } else { march_year };
    (year, month)
}
/// Converts a proleptic-Gregorian date to a day count (0 = 1970-01-01).
///
/// This is the inverse of `days_to_year_month` (Hinnant's `days_from_civil`). No
/// validation is performed, so callers must pass a real calendar date.
#[inline]
fn ymd_to_days(year: i32, month: u32, day: u32) -> i32 {
    // January/February are months 13/14 of the previous March-based year.
    let (y, m) = if month <= 2 {
        (year - 1, month + 12)
    } else {
        (year, month)
    };
    let era = y.div_euclid(400);
    let year_of_era = y.rem_euclid(400) as u32; // [0, 399]
    let day_of_year = (153 * (m - 3) + 2) / 5 + day - 1; // [0, 365]
    let day_of_era = year_of_era * 365 + year_of_era / 4 - year_of_era / 100 + day_of_year;
    era * 146097 + day_of_era as i32 - 719468
}
/// Number of days in `month` (1..=12) of `year` under Gregorian leap-year rules.
///
/// Out-of-range months fall back to 30 instead of panicking.
#[inline]
fn days_in_month(year: i32, month: u32) -> u32 {
    let is_leap = year % 400 == 0 || (year % 4 == 0 && year % 100 != 0);
    match month {
        2 if is_leap => 29,
        2 => 28,
        4 | 6 | 9 | 11 => 30,
        1 | 3 | 5 | 7 | 8 | 10 | 12 => 31,
        _ => 30,
    }
}
/// Day count of the last calendar day of the month containing `days`.
#[inline]
fn last_day_of_month_i32(days: i32) -> i32 {
    let (y, m) = days_to_year_month(days);
    ymd_to_days(y, m, days_in_month(y, m))
}
/// Day count of the last calendar day of the quarter containing `days`
/// (March 31, June 30, September 30 or December 31).
#[inline]
fn last_day_of_quarter_i32(days: i32) -> i32 {
    let (y, m) = days_to_year_month(days);
    let quarter_index = (m - 1) / 3; // 0..=3
    let end_month = 3 * (quarter_index + 1);
    ymd_to_days(y, end_month, days_in_month(y, end_month))
}
/// Day count of December 31 of the year containing `days`.
#[inline]
fn last_day_of_year_i32(days: i32) -> i32 {
    let year = days_to_year_month(days).0;
    ymd_to_days(year, 12, 31)
}
/// Applies today's price to every position that has one, via `Position::update_with_return`.
/// Positions without a price today are left untouched.
fn update_positions(portfolio: &mut Portfolio, prices: &HashMap<&str, f64>) {
    portfolio
        .positions
        .iter_mut()
        .filter_map(|(symbol, position)| {
            prices
                .get(symbol.as_str())
                .map(|&quote| (position, quote))
        })
        .for_each(|(position, quote)| position.update_with_return(quote));
}
/// Stores today's price as `previous_price` for each position whose price today passes
/// `is_valid_price`; other positions keep their old reference price.
fn update_previous_prices(portfolio: &mut Portfolio, prices: &HashMap<&str, f64>) {
    for (symbol, position) in &mut portfolio.positions {
        let valid_quote = prices
            .get(symbol.as_str())
            .filter(|quote| is_valid_price(**quote));
        if let Some(&quote) = valid_quote {
            position.previous_price = quote;
        }
    }
}
/// Fully closes every symbol queued in `pending_stops`, draining the queue.
///
/// Each open position is removed and reported to `tracker` as closed at today's price,
/// falling back to its `previous_price`, with no exit signal date. Its
/// `last_market_value` is credited to cash net of `fee_ratio + tax_ratio`. When
/// `config.stop_trading_next_period` is set, the symbol is also marked in
/// `stopped_stocks`. Queued symbols that no longer have a position are ignored.
fn execute_pending_stops_impl<T>(
    portfolio: &mut Portfolio,
    pending_stops: &mut Vec<String>,
    stopped_stocks: &mut HashMap<String, bool>,
    today_prices: &HashMap<&str, f64>,
    today_factor: &HashMap<&str, f64>,
    config: &BacktestConfig,
    current_date: i32,
    tracker: &mut T,
)
where
    T: TradeTracker<Key = String, Date = i32, Record = TradeRecord>,
{
    let cost_rate = config.fee_ratio + config.tax_ratio;
    for symbol in pending_stops.drain(..) {
        let position = match portfolio.positions.remove(&symbol) {
            Some(position) => position,
            None => continue,
        };
        let key = symbol.as_str();
        let exit_price = today_prices
            .get(key)
            .copied()
            .unwrap_or(position.previous_price);
        let exit_factor = today_factor.get(key).copied().unwrap_or(1.0);
        tracker.close_trade(
            &symbol,
            current_date,
            None,
            exit_price,
            exit_factor,
            config.fee_ratio,
            config.tax_ratio,
        );
        let held_value = position.last_market_value;
        portfolio.cash += held_value - held_value.abs() * cost_rate;
        if config.stop_trading_next_period {
            stopped_stocks.insert(symbol, true);
        }
    }
}
/// A stop exit detected for one position.
#[derive(Debug, Clone)]
pub struct StopResult {
    /// Symbol whose position should be closed.
    pub symbol: String,
    /// Multiplier applied to the position's last market value and to today's price
    /// when the exit is executed; `1.0` exits at today's price.
    pub exit_ratio: f64,
}
/// Detects positions that hit their stop-loss, take-profit or trailing-stop levels.
///
/// Levels are compared against the position's `cr`:
/// - stop-loss at `1 ∓ stop_loss`;
/// - take-profit at `1 ± take_profit`;
/// - trailing stop at `maxcr ∓ trail_stop`.
///
/// Directions flip for short positions (`last_market_value < 0`). Positions whose
/// close price fails `is_valid_price` are skipped.
///
/// - With `ohlc` (touched mode), open/high/low are mapped onto the `cr` scale as
///   `cr / r * (price / previous_price)`, where `r = close / previous_price`. The first
///   touch in the order open → high → low sets a fractional `exit_ratio`.
/// - Without `ohlc` (close-only mode), a position is flagged with `exit_ratio = 1.0`
///   when `cr` crosses a level. Take-profit is only checked when it is finite.
fn detect_stops_unified(
    portfolio: &Portfolio,
    close_prices: &HashMap<&str, f64>,
    ohlc: Option<(&HashMap<&str, f64>, &HashMap<&str, f64>, &HashMap<&str, f64>)>,
    config: &BacktestConfig,
) -> Vec<StopResult> {
    let mut results = Vec::new();
    for (sym, pos) in portfolio.positions.iter() {
        let sym_str = sym.as_str();
        let close_price = close_prices.get(sym_str).copied().unwrap_or(0.0);
        if !is_valid_price(close_price) {
            continue;
        }
        let is_long = pos.last_market_value >= 0.0;
        let cr = pos.cr;
        let maxcr = pos.maxcr;
        // (min_r, max_r): exit bounds on cr. For longs, min_r is the stop/trailing level
        // and max_r the take-profit; for shorts the roles are swapped.
        let (min_r, max_r) = if is_long {
            let stop_threshold = 1.0 - config.stop_loss;
            let trail_threshold = if config.trail_stop < f64::INFINITY {
                maxcr - config.trail_stop
            } else {
                f64::NEG_INFINITY
            };
            (stop_threshold.max(trail_threshold), 1.0 + config.take_profit)
        } else {
            let stop_threshold = 1.0 + config.stop_loss;
            let trail_threshold = if config.trail_stop < f64::INFINITY {
                maxcr + config.trail_stop
            } else {
                f64::INFINITY
            };
            (1.0 - config.take_profit, stop_threshold.min(trail_threshold))
        };
        if let Some((open_prices, high_prices, low_prices)) = ohlc {
            let open_price = open_prices.get(sym_str).copied().unwrap_or(f64::NAN);
            let high_price = high_prices.get(sym_str).copied().unwrap_or(f64::NAN);
            let low_price = low_prices.get(sym_str).copied().unwrap_or(f64::NAN);
            let prev_price = pos.previous_price;
            if open_price.is_nan() || high_price.is_nan() || low_price.is_nan()
                || prev_price <= 0.0 || cr.is_nan() || cr <= 0.0
            {
                continue;
            }
            // Today's price return; cr / r is the cr level at the previous close.
            let r = close_price / prev_price;
            if r.is_nan() || r <= 0.0 {
                continue;
            }
            let open_r = cr / r * (open_price / prev_price);
            let high_r = cr / r * (high_price / prev_price);
            let low_r = cr / r * (low_price / prev_price);
            let touch_open = open_r >= max_r || open_r <= min_r;
            let touch_high = high_r >= max_r;
            let touch_low = low_r <= min_r;
            if touch_open {
                // NOTE(review): the high/low branches below express the exit as
                // level / cr (touch price relative to close), but this branch divides by
                // `r`; `open_r / cr` would be the analogous ratio — confirm which is intended.
                results.push(StopResult { symbol: sym.clone(), exit_ratio: open_r / r });
            } else if touch_high {
                results.push(StopResult { symbol: sym.clone(), exit_ratio: max_r / cr });
            } else if touch_low {
                results.push(StopResult { symbol: sym.clone(), exit_ratio: min_r / cr });
            }
        } else {
            // NOTE(review): algebraically this is just `cr` (close_price cancels out).
            let cr_at_close = cr * close_price / close_price;
            if is_long {
                // Long: take-profit on the upside, stop/trailing on the downside.
                if config.take_profit < f64::INFINITY && cr_at_close >= max_r {
                    results.push(StopResult { symbol: sym.clone(), exit_ratio: 1.0 });
                    continue;
                }
                if cr_at_close < min_r {
                    results.push(StopResult { symbol: sym.clone(), exit_ratio: 1.0 });
                }
            } else {
                // Short: stop/trailing on the upside, take-profit on the downside.
                if cr_at_close >= max_r {
                    results.push(StopResult { symbol: sym.clone(), exit_ratio: 1.0 });
                    continue;
                }
                if config.take_profit < f64::INFINITY && cr_at_close < min_r {
                    results.push(StopResult { symbol: sym.clone(), exit_ratio: 1.0 });
                }
            }
        }
    }
    results
}
fn execute_stops_impl<T>(
portfolio: &mut Portfolio,
stops: Vec<StopResult>,
stopped_stocks: &mut HashMap<String, bool>,
today_prices: &HashMap<&str, f64>,
today_factor: &HashMap<&str, f64>,
config: &BacktestConfig,
current_date: i32,
exit_sig_date: Option<i32>, tracker: &mut T,
)
where
T: TradeTracker<Key = String, Date = i32, Record = TradeRecord>,
{
for stop in stops {
if let Some(pos) = portfolio.positions.remove(&stop.symbol) {
let exit_value = pos.last_market_value * stop.exit_ratio;
let close_price = today_prices.get(stop.symbol.as_str()).copied().unwrap_or(pos.previous_price);
let exit_price = close_price * stop.exit_ratio;
let exit_factor = today_factor.get(stop.symbol.as_str()).copied().unwrap_or(1.0);
tracker.close_trade(
&stop.symbol,
current_date,
exit_sig_date,
exit_price,
exit_factor,
config.fee_ratio,
config.tax_ratio,
);
let sell_value = exit_value - exit_value.abs() * (config.fee_ratio + config.tax_ratio);
portfolio.cash += sell_value;
if config.stop_trading_next_period {
stopped_stocks.insert(stop.symbol, true);
}
}
}
}
/// Rebuilds the whole portfolio so that it matches `target_weights` at today's prices.
///
/// Steps:
/// 1. Copy each position's `last_market_value` into `value`, then report every open
///    position to `tracker` as closed on `current_date` (with `signal_date`).
/// 2. If `config.stop_trading_next_period` is set, drop symbols marked in
///    `stopped_stocks` and, when some remain, scale the survivors back up to the
///    original gross (absolute) weight.
/// 3. If the resulting gross weight is 0 or the balance is not positive, liquidate
///    everything at `value`, paying `fee_ratio + tax_ratio`, and return.
/// 4. Otherwise set each target to `weight * balance / max(gross_weight, 1)`, trading
///    from the previous snapshot. Increasing exposure is charged `fee_ratio`; reducing
///    it is charged `fee_ratio + tax_ratio`. Every resulting position is reported to
///    `tracker` as opened.
/// 5. Previously held symbols absent from the targets are sold out.
fn execute_rebalance_impl<T>(
    portfolio: &mut Portfolio,
    target_weights: &HashMap<String, f64>,
    today_prices: &HashMap<&str, f64>,
    today_factor: &HashMap<&str, f64>,
    stopped_stocks: &HashMap<String, bool>,
    config: &BacktestConfig,
    current_date: i32,
    signal_date: i32,
    tracker: &mut T,
)
where
    T: TradeTracker<Key = String, Date = i32, Record = TradeRecord>,
{
    for (_sym, pos) in portfolio.positions.iter_mut() {
        pos.value = pos.last_market_value;
    }
    // Every open position is reported closed; the ones that remain targeted are
    // reported opened again further down via `tracker.open_trade`.
    let open_positions: Vec<String> = portfolio.positions.keys().cloned().collect();
    for sym in &open_positions {
        let exit_price = today_prices.get(sym.as_str()).copied().unwrap_or(f64::NAN);
        let exit_factor = today_factor.get(sym.as_str()).copied().unwrap_or(1.0);
        tracker.close_trade(
            sym,
            current_date,
            Some(signal_date),
            exit_price,
            exit_factor,
            config.fee_ratio,
            config.tax_ratio,
        );
    }
    let balance = portfolio.balance();
    // Optionally drop stopped-out symbols and rescale the rest to the original gross weight.
    let (effective_weights, total_target_weight) = if config.stop_trading_next_period {
        let original_sum: f64 = target_weights.values().map(|w| w.abs()).sum();
        let filtered: HashMap<String, f64> = target_weights
            .iter()
            .filter(|(sym, _)| !stopped_stocks.get(*sym).copied().unwrap_or(false))
            .map(|(k, &v)| (k.clone(), v))
            .collect();
        let remaining_sum: f64 = filtered.values().map(|w| w.abs()).sum();
        if remaining_sum > 0.0 && remaining_sum < original_sum {
            let scale_factor = original_sum / remaining_sum;
            let scaled: HashMap<String, f64> = filtered
                .into_iter()
                .map(|(k, v)| (k, v * scale_factor))
                .collect();
            let new_sum: f64 = scaled.values().map(|w| w.abs()).sum();
            (scaled, new_sum)
        } else {
            (filtered, remaining_sum)
        }
    } else {
        (target_weights.clone(), target_weights.values().map(|w| w.abs()).sum())
    };
    // Nothing to hold, or no equity left: liquidate everything.
    if total_target_weight == 0.0 || balance <= 0.0 {
        let all_positions: Vec<String> = portfolio.positions.keys().cloned().collect();
        for sym in all_positions {
            if let Some(pos) = portfolio.positions.remove(&sym) {
                let sell_value = pos.value - pos.value.abs() * (config.fee_ratio + config.tax_ratio);
                portfolio.cash += sell_value;
            }
        }
        return;
    }
    // Gross exposure never exceeds the balance; weights summing below 1 are not scaled up.
    let ratio = balance / total_target_weight.max(1.0);
    let old_snapshots: HashMap<String, PositionSnapshot> = portfolio
        .positions
        .iter()
        .map(|(k, v)| (k.clone(), PositionSnapshot::from(v)))
        .collect();
    portfolio.positions.clear();
    let mut cash = portfolio.cash;
    for (sym, &target_weight) in &effective_weights {
        let price_opt = today_prices.get(sym.as_str()).copied();
        let price_valid = price_opt.map_or(false, |p| is_valid_price(p));
        let target_value = target_weight * ratio;
        let snapshot = old_snapshots.get(sym);
        // NOTE(review): the current holding is taken from the snapshot's `cost_basis`,
        // while the invalid-price zero-weight branch below sells `market_value`; confirm
        // that PositionSnapshot makes these agree.
        let current_value = snapshot.map(|s| s.cost_basis).unwrap_or(0.0);
        if !price_valid {
            // No usable price today: trade by value only and create the position via
            // Position::new_with_nan_price (entry price reported to the tracker as NaN).
            if target_weight.abs() < FLOAT_EPSILON {
                if let Some(snap) = snapshot {
                    if snap.market_value.abs() > FLOAT_EPSILON {
                        let sell_fee = snap.market_value.abs() * (config.fee_ratio + config.tax_ratio);
                        cash += snap.market_value - sell_fee;
                    }
                }
                continue;
            }
            if target_value.abs() > FLOAT_EPSILON {
                let amount = target_value - current_value;
                let is_buy = amount > 0.0;
                // "Entry" = trading in the direction of the target (growing exposure).
                let is_entry =
                    (target_value >= 0.0 && amount > 0.0) || (target_value <= 0.0 && amount < 0.0);
                let cost = if is_entry {
                    amount.abs() * config.fee_ratio
                } else {
                    amount.abs() * (config.fee_ratio + config.tax_ratio)
                };
                let new_value = if is_buy {
                    cash -= amount;
                    current_value + amount - cost
                } else {
                    let sell_amount = amount.abs();
                    cash += sell_amount - cost;
                    current_value - sell_amount
                };
                if new_value.abs() > FLOAT_EPSILON {
                    let entry_factor = today_factor.get(sym.as_str()).copied().unwrap_or(1.0);
                    tracker.open_trade(sym.clone(), current_date, signal_date, f64::NAN, target_weight, entry_factor);
                    portfolio.positions.insert(
                        sym.clone(),
                        Position::new_with_nan_price(new_value),
                    );
                }
            }
            continue;
        }
        let price = price_opt.unwrap();
        // Zero target: sell the whole holding.
        if target_weight.abs() < FLOAT_EPSILON {
            if current_value.abs() > FLOAT_EPSILON {
                let sell_fee = current_value.abs() * (config.fee_ratio + config.tax_ratio);
                cash += current_value - sell_fee;
            }
            continue;
        }
        let amount = target_value - current_value;
        let is_buy = amount > 0.0;
        let is_entry = (target_value >= 0.0 && amount > 0.0) || (target_value <= 0.0 && amount < 0.0);
        let cost = if is_entry {
            amount.abs() * config.fee_ratio
        } else {
            amount.abs() * (config.fee_ratio + config.tax_ratio)
        };
        // Buys pay `amount` in cash and hold `amount - cost`; sells receive `amount - cost`.
        let new_position_value = if is_buy {
            cash -= amount;
            current_value + amount - cost
        } else {
            let sell_amount = amount.abs();
            cash += sell_amount - cost;
            current_value - sell_amount
        };
        if new_position_value.abs() > FLOAT_EPSILON {
            let entry_factor = today_factor.get(sym.as_str()).copied().unwrap_or(1.0);
            tracker.open_trade(sym.clone(), current_date, signal_date, price, target_weight, entry_factor);
            // Same-sign holding that continues; optionally keep its cost data.
            let is_continuing = current_value.abs() > FLOAT_EPSILON && current_value * target_weight > 0.0;
            let new_pos = if config.retain_cost_when_rebalance && is_continuing {
                // is_continuing implies a snapshot exists (current_value would be 0 otherwise).
                let snap = snapshot.unwrap();
                Position::new_from_snapshot(new_position_value, price, snap)
            } else {
                Position::new(new_position_value, price)
            };
            portfolio.positions.insert(sym.clone(), new_pos);
        }
    }
    // Sell out previous holdings that are no longer targeted.
    for (sym, snap) in old_snapshots.iter() {
        if !effective_weights.contains_key(sym) && snap.cost_basis.abs() > FLOAT_EPSILON {
            let sell_fee = snap.cost_basis.abs() * (config.fee_ratio + config.tax_ratio);
            cash += snap.cost_basis - sell_fee;
        }
    }
    portfolio.cash = cash;
}
/// Returns `true` when the two weight maps differ.
///
/// They differ if they hold different key sets or any shared key's weights differ by more
/// than `FLOAT_EPSILON`. A NaN difference does not count as a change.
fn weights_differ(a: &HashMap<String, f64>, b: &HashMap<String, f64>) -> bool {
    a.len() != b.len()
        || a.iter().any(|(symbol, &left)| match b.get(symbol) {
            Some(&right) => (left - right).abs() > FLOAT_EPSILON,
            None => true,
        })
}
/// Normalises raw signal weights into portfolio weights.
///
/// Steps:
/// 1. Drop near-zero weights (`|w| <= FLOAT_EPSILON`) and symbols marked in `stopped_stocks`.
/// 2. Divide the rest by `max(sum |w|, 1)`, so gross exposure never exceeds 1 and
///    smaller totals are not scaled up.
/// 3. Clamp each weight to `[-position_limit, position_limit]`.
///
/// Returns an empty map when nothing survives the filter.
///
/// # Panics
/// Panics if `position_limit` is negative or NaN, per the `f64::clamp` precondition.
fn normalize_weights(
    weights: &HashMap<&str, f64>,
    stopped_stocks: &HashMap<String, bool>,
    position_limit: f64,
) -> HashMap<String, f64> {
    let mut kept: Vec<(&str, f64)> = Vec::with_capacity(weights.len());
    for (&symbol, &raw) in weights {
        let stopped = stopped_stocks.get(symbol).copied().unwrap_or(false);
        if raw.abs() > FLOAT_EPSILON && !stopped {
            kept.push((symbol, raw));
        }
    }
    if kept.is_empty() {
        return HashMap::new();
    }
    let divisor = kept.iter().map(|&(_, raw)| raw.abs()).sum::<f64>().max(1.0);
    kept.into_iter()
        .map(|(symbol, raw)| {
            let capped = (raw / divisor).clamp(-position_limit, position_limit);
            (symbol.to_owned(), capped)
        })
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::StringViewBuilder;

    /// Converts an ISO `YYYY-MM-DD` string into days since 1970-01-01, the
    /// integer representation these tests feed into the `dates` column.
    ///
    /// Uses the proleptic-Gregorian "days from civil" algorithm. It is exact
    /// for dates before 1970 and applies the century rules, so 2100 is not a
    /// leap year but 2000 is.
    ///
    /// # Panics
    /// Panics if `s` is not three `-`-separated integers, or if the result does
    /// not fit in an `i32`. This helper only ever receives literal test dates.
    fn date_to_days(s: &str) -> i32 {
        let fields: Vec<i64> = s
            .split('-')
            .map(|p| p.parse::<i64>().expect("date component must be an integer"))
            .collect();
        let (year, month, day) = match fields.as_slice() {
            &[y, m, d] => (y, m, d),
            _ => panic!("expected YYYY-MM-DD, got {:?}", s),
        };
        // Treat January/February as months 10/11 of the previous year so the
        // leap day falls at the very end of the shifted year.
        let y = if month <= 2 { year - 1 } else { year };
        let era = y.div_euclid(400); // floor division, correct for negative years
        let yoe = y.rem_euclid(400); // year of era, 0..=399
        let mp = (month + 9) % 12; // March = 0 ... February = 11
        let doy = (153 * mp + 2) / 5 + day - 1; // day of the March-based year
        let doe = yoe * 365 + yoe / 4 - yoe / 100 + doy; // day of the 400-year era
        // 719_468 = number of days from 0000-03-01 to 1970-01-01.
        let days = era * 146_097 + doe - 719_468;
        i32::try_from(days).expect("date out of i32 range")
    }

    fn make_symbols(strs: Vec<&str>) -> StringViewArray {
        let mut builder = StringViewBuilder::new();
        for s in strs {
            builder.append_value(s);
        }
        builder.finish()
    }

    fn make_input<'a>(
        dates: &'a Int32Array,
        symbols: &'a StringViewArray,
        prices: &'a Float64Array,
        weights: &'a Float64Array,
    ) -> LongFormatArrowInput<'a> {
        LongFormatArrowInput {
            dates,
            symbols,
            prices,
            weights,
            open_prices: None,
            high_prices: None,
            low_prices: None,
            factor: None,
        }
    }

    #[test]
    fn test_date_to_days() {
        assert_eq!(date_to_days("1970-01-01"), 0);
        assert_eq!(date_to_days("2024-01-01"), 19_723);
        assert_eq!(date_to_days("2024-02-01"), 19_754);
        assert_eq!(date_to_days("2000-03-01"), 11_017);
        assert_eq!(date_to_days("1968-01-01"), -731);
        // 2100 is not a leap year.
        assert_eq!(date_to_days("2100-03-01"), 47_541);
    }

    #[test]
    fn test_backtest_empty() {
        let dates = Int32Array::from(Vec::<i32>::new());
        let symbols = make_symbols(vec![]);
        let prices = Float64Array::from(Vec::<f64>::new());
        let weights = Float64Array::from(Vec::<f64>::new());
        let input = make_input(&dates, &symbols, &prices, &weights);
        let result = backtest_long_arrow(&input, ResampleFreq::Daily, None, &BacktestConfig::default());
        assert!(result.creturn.is_empty());
    }

    #[test]
    fn test_backtest_single_stock() {
        let dates = Int32Array::from(vec![
            date_to_days("2024-01-01"),
            date_to_days("2024-01-02"),
            date_to_days("2024-01-03"),
            date_to_days("2024-01-04"),
        ]);
        let symbols = make_symbols(vec!["AAPL", "AAPL", "AAPL", "AAPL"]);
        let prices = Float64Array::from(vec![100.0, 100.0, 110.0, 121.0]);
        let weights = Float64Array::from(vec![1.0, 1.0, 1.0, 1.0]);
        let input = make_input(&dates, &symbols, &prices, &weights);
        let config = BacktestConfig {
            fee_ratio: 0.0,
            tax_ratio: 0.0,
            finlab_mode: true,
            ..Default::default()
        };
        let result = backtest_long_arrow(&input, ResampleFreq::Daily, None, &config);
        assert_eq!(result.creturn.len(), 4);
        assert!((result.creturn[0] - 1.0).abs() < FLOAT_EPSILON, "Day 0: {}", result.creturn[0]);
        assert!((result.creturn[1] - 1.0).abs() < FLOAT_EPSILON, "Day 1: {}", result.creturn[1]);
        assert!((result.creturn[2] - 1.1).abs() < FLOAT_EPSILON, "Day 2: {}", result.creturn[2]);
        assert!((result.creturn[3] - 1.21).abs() < FLOAT_EPSILON, "Day 3: {}", result.creturn[3]);
    }

    #[test]
    fn test_backtest_two_stocks() {
        let d1 = date_to_days("2024-01-01");
        let d2 = date_to_days("2024-01-02");
        let d3 = date_to_days("2024-01-03");
        let dates = Int32Array::from(vec![d1, d1, d2, d2, d3, d3]);
        let symbols = make_symbols(vec!["AAPL", "GOOG", "AAPL", "GOOG", "AAPL", "GOOG"]);
        let prices = Float64Array::from(vec![100.0, 100.0, 100.0, 100.0, 110.0, 90.0]);
        let weights = Float64Array::from(vec![0.5, 0.5, 0.0, 0.0, 0.0, 0.0]);
        let input = make_input(&dates, &symbols, &prices, &weights);
        let config = BacktestConfig {
            fee_ratio: 0.0,
            tax_ratio: 0.0,
            finlab_mode: true,
            ..Default::default()
        };
        let result = backtest_long_arrow(&input, ResampleFreq::Daily, None, &config);
        assert_eq!(result.creturn.len(), 3);
        assert!((result.creturn[0] - 1.0).abs() < FLOAT_EPSILON);
        assert!((result.creturn[1] - 1.0).abs() < FLOAT_EPSILON);
        assert!((result.creturn[2] - 1.0).abs() < FLOAT_EPSILON);
    }

    #[test]
    fn test_backtest_with_fees() {
        let dates = Int32Array::from(vec![
            date_to_days("2024-01-01"),
            date_to_days("2024-01-02"),
            date_to_days("2024-01-03"),
        ]);
        let symbols = make_symbols(vec!["AAPL", "AAPL", "AAPL"]);
        let prices = Float64Array::from(vec![100.0, 100.0, 100.0]);
        let weights = Float64Array::from(vec![1.0, 0.0, 0.0]);
        let input = make_input(&dates, &symbols, &prices, &weights);
        let config = BacktestConfig {
            fee_ratio: 0.01,
            tax_ratio: 0.0,
            finlab_mode: true,
            ..Default::default()
        };
        let result = backtest_long_arrow(&input, ResampleFreq::Daily, None, &config);
        assert!((result.creturn[1] - 0.99).abs() < 1e-6, "Day 1: {}", result.creturn[1]);
    }

    #[test]
    fn test_monthly_rebalance() {
        let dates = Int32Array::from(vec![
            date_to_days("2024-01-30"),
            date_to_days("2024-01-31"),
            date_to_days("2024-02-01"),
            date_to_days("2024-02-02"),
        ]);
        let symbols = make_symbols(vec!["AAPL", "AAPL", "AAPL", "AAPL"]);
        let prices = Float64Array::from(vec![100.0, 100.0, 100.0, 110.0]);
        let weights = Float64Array::from(vec![0.0, 1.0, 0.0, 0.0]);
        let input = make_input(&dates, &symbols, &prices, &weights);
        let config = BacktestConfig {
            fee_ratio: 0.0,
            tax_ratio: 0.0,
            finlab_mode: true,
            ..Default::default()
        };
        let result = backtest_long_arrow(&input, ResampleFreq::Monthly, None, &config);
        assert_eq!(result.creturn.len(), 3, "Expected 3 days from signal day onward");
        assert!((result.creturn[0] - 1.0).abs() < FLOAT_EPSILON, "Jan 31 (signal): {}", result.creturn[0]);
        assert!((result.creturn[1] - 1.0).abs() < FLOAT_EPSILON, "Feb 1 (entry): {}", result.creturn[1]);
        assert!((result.creturn[2] - 1.1).abs() < FLOAT_EPSILON, "Feb 2 (+10%): {}", result.creturn[2]);
    }

    #[test]
    fn test_slice_interface() {
        let dates = [
            date_to_days("2024-01-01"),
            date_to_days("2024-01-02"),
            date_to_days("2024-01-03"),
        ];
        let symbols = ["AAPL", "AAPL", "AAPL"];
        let prices = [100.0, 100.0, 110.0];
        let weights = [1.0, 0.0, 0.0];
        let config = BacktestConfig {
            fee_ratio: 0.0,
            tax_ratio: 0.0,
            finlab_mode: true,
            ..Default::default()
        };
        let result = backtest_long_slice(&dates, &symbols, &prices, &weights, None, None, None, None, ResampleFreq::Daily, None, &config);
        assert_eq!(result.creturn.len(), 3);
        assert!((result.creturn[2] - 1.1).abs() < FLOAT_EPSILON);
    }
}