use crate::config::BacktestConfig;
use crate::is_valid_price;
use crate::position::Position;
use std::collections::HashMap;
/// A position whose intrabar price path touched an exit threshold, as reported by
/// [`detect_touched_exit`].
#[derive(Debug, Clone, PartialEq)]
pub struct TouchedExitResult {
    /// Key of the position in the positions map (also its index into the price slices).
    pub stock_id: usize,
    /// Ratio of the exit level to the position's level at the bar's close.
    pub exit_ratio: f64,
    /// `true` when the exit is a take-profit, `false` when it is a stop (loss or trailing).
    pub is_take_profit: bool,
}
/// Returns the ids of positions whose fixed-percentage exits fire at `prices`.
///
/// All checks are measured against the position's `stop_entry_price`:
/// - stop loss: return since entry is at or below `-stop_loss` (skipped when
///   `stop_loss >= 1.0`);
/// - take profit: return since entry is at or above `take_profit` (skipped when
///   `take_profit` is infinite);
/// - trailing stop: the drop from `max_price` to the current price, as a fraction of
///   the entry price, is at or above `trail_stop` (skipped when infinite).
///
/// Positions whose id is outside `prices` or whose entry price is not positive are
/// ignored. Ids are returned in the map's iteration order.
pub fn detect_stops(
    positions: &HashMap<usize, Position>,
    prices: &[f64],
    config: &BacktestConfig,
) -> Vec<usize> {
    let mut hits = Vec::new();
    for (&stock_id, pos) in positions {
        let price = match prices.get(stock_id) {
            Some(&p) => p,
            None => continue,
        };
        let entry = pos.stop_entry_price;
        if entry <= 0.0 {
            continue;
        }

        let change = (price - entry) / entry;
        let hit_stop_loss = config.stop_loss < 1.0 && change <= -config.stop_loss;
        let hit_take_profit = config.take_profit < f64::INFINITY && change >= config.take_profit;
        // Trailing drawdown is normalised by the entry price, not by the peak.
        let hit_trail = config.trail_stop < f64::INFINITY
            && (pos.max_price - price) / entry >= config.trail_stop;

        if hit_stop_loss || hit_take_profit || hit_trail {
            hits.push(stock_id);
        }
    }
    hits
}
/// Finlab-compatible stop detection based on each position's cumulative return.
///
/// `pos.cr` is compared directly against thresholds expressed around 1.0:
/// - long (`last_market_value >= 0.0`): take profit when `cr >= 1 + take_profit`
///   (only if `take_profit` is finite); stop when
///   `cr < max(1 - stop_loss, maxcr - trail_stop)` (the trailing term is dropped when
///   `trail_stop` is infinite);
/// - short: stop when `cr >= min(1 + stop_loss, maxcr + trail_stop)` (trailing term
///   dropped when infinite); take profit when `cr < 1 - take_profit` (only if
///   `take_profit` is finite).
///
/// Positions are skipped when their id is outside `prices`, their `stop_entry_price`
/// is not positive, or the current price fails `is_valid_price`. Ids are returned in
/// the map's iteration order.
pub fn detect_stops_finlab(
    positions: &HashMap<usize, Position>,
    prices: &[f64],
    config: &BacktestConfig,
) -> Vec<usize> {
    // Config-level switches are the same for every position; evaluate them once.
    let take_profit_enabled = config.take_profit < f64::INFINITY;
    let trail_enabled = config.trail_stop < f64::INFINITY;

    positions
        .iter()
        .filter_map(|(&stock_id, pos)| {
            let current_price = *prices.get(stock_id)?;
            if pos.stop_entry_price <= 0.0 || !is_valid_price(current_price) {
                return None;
            }

            // The cumulative return is used as-is; scaling it by
            // `current_price / current_price` would only add rounding error.
            let cr = pos.cr;
            let maxcr = pos.maxcr;

            let triggered = if pos.last_market_value >= 0.0 {
                let trail_floor = if trail_enabled {
                    maxcr - config.trail_stop
                } else {
                    f64::NEG_INFINITY
                };
                let min_r = (1.0 - config.stop_loss).max(trail_floor);
                (take_profit_enabled && cr >= 1.0 + config.take_profit) || cr < min_r
            } else {
                let trail_ceiling = if trail_enabled {
                    maxcr + config.trail_stop
                } else {
                    f64::INFINITY
                };
                let max_r = (1.0 + config.stop_loss).min(trail_ceiling);
                cr >= max_r || (take_profit_enabled && cr < 1.0 - config.take_profit)
            };

            if triggered {
                Some(stock_id)
            } else {
                None
            }
        })
        .collect()
}
/// Detects exits that were touched inside the current bar, using its open, high,
/// low and close prices.
///
/// Derived quantities and price mapping:
/// - `r = close / pos.previous_price`.
/// - The position's level at the previous close is `pos.cr / r`.
/// - The bar's open, high and low are mapped onto that scale.
///
/// Thresholds:
/// - Long (`last_market_value > 0.0`):
///   - upper: `1 + take_profit`
///   - lower: `max(1 - stop_loss, maxcr - trail_stop)`
/// - Otherwise (short):
///   - upper: `min(1 + stop_loss, maxcr + trail_stop)`
///   - lower: `1 - take_profit`
/// - An infinite `trail_stop` drops the trailing term.
///
/// Checks run in this order:
/// 1. The open is beyond either threshold (a gap), so the exit is at the open.
/// 2. The high is at or above the upper threshold.
/// 3. The low is at or below the lower threshold.
///
/// The reported `exit_ratio` is the exit level divided by `pos.cr` (the level at the
/// close), which equals exit price / close price in every branch.
///
/// A position is skipped when any of these holds:
/// - its id is outside any of the OHLC slices;
/// - any OHLC price is NaN;
/// - the close or `pos.previous_price` is not positive;
/// - `pos.cr` is NaN or not positive.
///
/// `_prev_prices` is unused; the previous price comes from `pos.previous_price`.
///
/// NOTE(review): this function decides long vs short with `last_market_value > 0.0`,
/// while `detect_stops_finlab` uses `>= 0.0`. Confirm which one is intended for a
/// zero market value.
pub fn detect_touched_exit(
    positions: &HashMap<usize, Position>,
    open_prices: &[f64],
    high_prices: &[f64],
    low_prices: &[f64],
    close_prices: &[f64],
    _prev_prices: &[f64],
    config: &BacktestConfig,
) -> Vec<TouchedExitResult> {
    positions
        .iter()
        .filter_map(|(&stock_id, pos)| {
            let open_price = *open_prices.get(stock_id)?;
            let high_price = *high_prices.get(stock_id)?;
            let low_price = *low_prices.get(stock_id)?;
            let close_price = *close_prices.get(stock_id)?;
            let prev_price = pos.previous_price;

            let any_nan = open_price.is_nan()
                || high_price.is_nan()
                || low_price.is_nan()
                || close_price.is_nan();
            if any_nan || close_price <= 0.0 || prev_price <= 0.0 {
                return None;
            }

            let cr_new = pos.cr;
            if cr_new.is_nan() || cr_new <= 0.0 {
                return None;
            }
            let maxcr = pos.maxcr;

            // r: close-to-close move for this bar; cr_prev: level at the previous close.
            let r = close_price / prev_price;
            let cr_prev = cr_new / r;
            let open_r = cr_prev * (open_price / prev_price);
            let high_r = cr_prev * (high_price / prev_price);
            let low_r = cr_prev * (low_price / prev_price);

            let is_long = pos.last_market_value > 0.0;
            let (max_r, min_r) = if is_long {
                let trail_floor = if config.trail_stop < f64::INFINITY {
                    maxcr - config.trail_stop
                } else {
                    f64::NEG_INFINITY
                };
                (
                    1.0 + config.take_profit,
                    (1.0 - config.stop_loss).max(trail_floor),
                )
            } else {
                let trail_ceiling = if config.trail_stop < f64::INFINITY {
                    maxcr + config.trail_stop
                } else {
                    f64::INFINITY
                };
                (
                    (1.0 + config.stop_loss).min(trail_ceiling),
                    1.0 - config.take_profit,
                )
            };

            let (exit_level, is_take_profit) = if open_r >= max_r || open_r <= min_r {
                // Gapped through a threshold at the open: the exit fills at the open.
                let take_profit_hit = if is_long {
                    open_r >= max_r
                } else {
                    open_r <= min_r
                };
                (open_r, take_profit_hit)
            } else if high_r >= max_r {
                (max_r, is_long)
            } else if low_r <= min_r {
                (min_r, !is_long)
            } else {
                return None;
            };

            Some(TouchedExitResult {
                stock_id,
                // `cr_new` is the level at the close, so this is exit price / close price
                // for every branch (for the open gap it equals open / close).
                exit_ratio: exit_level / cr_new,
                is_take_profit,
            })
        })
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Config with the cost and flag values shared by every test, plus the given
    /// exit thresholds.
    fn make_config(stop_loss: f64, take_profit: f64, trail_stop: f64) -> BacktestConfig {
        BacktestConfig {
            fee_ratio: 0.001425,
            tax_ratio: 0.003,
            stop_loss,
            take_profit,
            trail_stop,
            position_limit: 1.0,
            retain_cost_when_rebalance: true,
            stop_trading_next_period: false,
            finlab_mode: false,
            touched_exit: false,
        }
    }

    /// A map holding one position under stock id 0.
    ///
    /// The position is built with `Position::new(1000.0, 100.0)`, given a stop entry
    /// price of 100.0, and then adjusted by `setup`.
    fn single_position(setup: impl FnOnce(&mut Position)) -> HashMap<usize, Position> {
        let mut pos = Position::new(1000.0, 100.0);
        pos.stop_entry_price = 100.0;
        setup(&mut pos);
        let mut positions = HashMap::new();
        positions.insert(0, pos);
        positions
    }

    /// Runs `detect_touched_exit` on one OHLC bar for stock id 0, with a previous
    /// price slice of `[100.0]`.
    fn touched_on_bar(
        positions: &HashMap<usize, Position>,
        (open, high, low, close): (f64, f64, f64, f64),
        config: &BacktestConfig,
    ) -> Vec<TouchedExitResult> {
        detect_touched_exit(positions, &[open], &[high], &[low], &[close], &[100.0], config)
    }

    #[test]
    fn test_detect_stops_stop_loss() {
        let positions = single_position(|_| {});
        let config = make_config(0.10, f64::INFINITY, f64::INFINITY);
        assert_eq!(detect_stops(&positions, &[89.0], &config), vec![0]);
    }

    #[test]
    fn test_detect_stops_take_profit() {
        let positions = single_position(|_| {});
        let config = make_config(1.0, 0.20, f64::INFINITY);
        assert_eq!(detect_stops(&positions, &[121.0], &config), vec![0]);
    }

    #[test]
    fn test_detect_stops_trailing() {
        // Falling from a 120 peak to 105 is 15% of the 100 entry price.
        let positions = single_position(|pos| pos.max_price = 120.0);
        let config = make_config(1.0, f64::INFINITY, 0.10);
        assert_eq!(detect_stops(&positions, &[105.0], &config), vec![0]);
    }

    #[test]
    fn test_detect_stops_no_trigger() {
        let positions = single_position(|_| {});
        let config = make_config(0.10, 0.20, f64::INFINITY);
        assert!(detect_stops(&positions, &[105.0], &config).is_empty());
    }

    #[test]
    fn test_detect_stops_finlab_take_profit() {
        let positions = single_position(|pos| {
            pos.cr = 1.21;
            pos.maxcr = 1.21;
            pos.last_market_value = 1210.0;
        });
        let config = make_config(1.0, 0.20, f64::INFINITY);
        assert_eq!(detect_stops_finlab(&positions, &[121.0], &config), vec![0]);
    }

    #[test]
    fn test_detect_stops_finlab_stop_loss() {
        let positions = single_position(|pos| {
            pos.cr = 0.89;
            pos.maxcr = 1.0;
            pos.last_market_value = 890.0;
        });
        let config = make_config(0.10, f64::INFINITY, f64::INFINITY);
        assert_eq!(detect_stops_finlab(&positions, &[89.0], &config), vec![0]);
    }

    #[test]
    fn test_detect_stops_finlab_trail_stop() {
        let positions = single_position(|pos| {
            pos.cr = 1.05;
            pos.maxcr = 1.20;
            pos.last_market_value = 1050.0;
        });
        let config = make_config(1.0, f64::INFINITY, 0.10);
        assert_eq!(detect_stops_finlab(&positions, &[105.0], &config), vec![0]);
    }

    #[test]
    fn test_touched_exit_result_fields() {
        let result = TouchedExitResult {
            stock_id: 5,
            exit_ratio: 0.95,
            is_take_profit: false,
        };
        assert_eq!(result.stock_id, 5);
        assert!((result.exit_ratio - 0.95).abs() < 1e-10);
        assert!(!result.is_take_profit);
    }

    #[test]
    fn test_detect_touched_exit_open_touch() {
        let positions = single_position(|pos| {
            pos.cr = 0.85;
            pos.maxcr = 1.0;
            pos.last_market_value = 850.0;
            pos.previous_price = 100.0;
        });
        let config = make_config(0.10, f64::INFINITY, f64::INFINITY);
        let results = touched_on_bar(&positions, (88.0, 96.0, 87.0, 95.0), &config);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].stock_id, 0);
        assert!(!results[0].is_take_profit);
    }

    #[test]
    fn test_detect_touched_exit_no_touch() {
        let positions = single_position(|pos| {
            pos.cr = 1.05;
            pos.maxcr = 1.05;
            pos.last_market_value = 1050.0;
            pos.previous_price = 100.0;
        });
        let config = make_config(0.10, 0.20, f64::INFINITY);
        let results = touched_on_bar(&positions, (102.0, 108.0, 98.0, 105.0), &config);
        assert!(results.is_empty());
    }
}