use crate::FLOAT_EPSILON;
/// Conversion of a per-asset input into a vector of target portfolio weights.
///
/// This file implements the trait for `Vec<bool>` (entry signals) and
/// `Vec<f64>` (raw weights).
pub trait IntoWeights {
/// Builds the target weights for `self`.
///
/// * `stopped_stocks` - per-asset flags, indexed like `self`; the
///   implementations in this file give a flagged asset a weight of `0.0`.
/// * `position_limit` - cap applied to each individual weight.
///
/// NOTE(review): despite the `into_` prefix this method borrows `self`
/// rather than consuming it; the name is kept for caller compatibility.
fn into_weights(&self, stopped_stocks: &[bool], position_limit: f64) -> Vec<f64>;
}
/// Boolean entry signals: `true` marks an asset that should be held.
impl IntoWeights for Vec<bool> {
/// Delegates to [`calculate_target_weights`], passing the signals as a slice.
fn into_weights(&self, stopped_stocks: &[bool], position_limit: f64) -> Vec<f64> {
calculate_target_weights(self, stopped_stocks, position_limit)
}
}
/// Raw (possibly negative) per-asset weights.
impl IntoWeights for Vec<f64> {
/// Delegates to [`normalize_weights_finlab`], passing the weights as a slice.
fn into_weights(&self, stopped_stocks: &[bool], position_limit: f64) -> Vec<f64> {
normalize_weights_finlab(self, stopped_stocks, position_limit)
}
}
/// Normalises raw weights, drops stopped assets and clips every position.
///
/// The steps, in order:
///
/// 1. Gross exposure is measured as the sum of absolute weights, once over
///    all assets and once over the assets that are not stopped.
/// 2. If the non-stopped gross exposure is below `FLOAT_EPSILON`, every
///    weight is returned as `0.0`.
/// 3. Each non-stopped weight is multiplied by `gross_all / gross_live`, so
///    the exposure of stopped assets is handed proportionally to the rest.
/// 4. The result is divided by `max(gross_all, 1.0)`: a book whose gross
///    exposure exceeds 1 is scaled down to 1, a smaller one is left as is.
/// 5. Each weight is clamped to `[-position_limit, position_limit]`. Clipped
///    exposure is *not* redistributed.
///
/// Stopped assets always receive `0.0`. An asset with no entry in
/// `stopped_stocks` (slice shorter than `weights`) counts as not stopped.
///
/// # Panics
///
/// Panics inside `f64::clamp` if `position_limit` is negative or NaN and at
/// least one non-stopped weight reaches the clamping step.
pub fn normalize_weights_finlab(
    weights: &[f64],
    stopped_stocks: &[bool],
    position_limit: f64,
) -> Vec<f64> {
    // A missing flag means "not stopped".
    let is_stopped = |idx: usize| stopped_stocks.get(idx).copied().unwrap_or(false);

    let gross_all: f64 = weights.iter().map(|w| w.abs()).sum();
    let gross_live: f64 = weights
        .iter()
        .enumerate()
        .filter_map(|(idx, w)| if is_stopped(idx) { None } else { Some(w.abs()) })
        .sum();

    // Nothing tradable is left: go flat.
    if gross_live < FLOAT_EPSILON {
        return vec![0.0; weights.len()];
    }

    // Factor that hands the stopped assets' exposure to the survivors.
    let boost = if gross_all > FLOAT_EPSILON {
        gross_all / gross_live
    } else {
        1.0
    };
    // Only shrink an over-invested book; never inflate an under-invested one.
    let denominator = gross_all.max(1.0);

    weights
        .iter()
        .enumerate()
        .map(|(idx, &raw)| {
            if is_stopped(idx) {
                0.0
            } else {
                (raw * boost / denominator).clamp(-position_limit, position_limit)
            }
        })
        .collect()
}
/// Converts boolean entry signals into equal target weights.
///
/// An asset is *active* when its signal is `true` and it is not stopped; an
/// asset with no entry in `stopped_stocks` counts as not stopped. Every
/// active asset receives `min(1 / active_count, position_limit)` and every
/// other asset receives `0.0`. If no asset is active the result is all zeros.
///
/// The cap is applied once and the clipped exposure is left uninvested, the
/// same way [`normalize_weights_finlab`] clips without renormalising.
/// Renormalising after the cap (as an earlier version did) scales every
/// weight straight back to `1 / active_count`, because all active weights
/// are equal and no asset has spare room to absorb the excess; the limit
/// then had no effect at all.
///
/// NOTE(review): `position_limit` is presumably meant to be positive; a
/// non-positive value is passed through `min` unchanged, as before.
///
/// # Returns
///
/// A vector with one weight per entry of `signals`.
pub fn calculate_target_weights(
    signals: &[bool],
    stopped_stocks: &[bool],
    position_limit: f64,
) -> Vec<f64> {
    let is_active =
        |idx: usize, signal: bool| signal && !stopped_stocks.get(idx).copied().unwrap_or(false);

    let active_count = signals
        .iter()
        .enumerate()
        .filter(|&(idx, &signal)| is_active(idx, signal))
        .count();
    if active_count == 0 {
        return vec![0.0; signals.len()];
    }

    // Equal weight, capped per position; the cap is final (no redistribution).
    let weight = (1.0 / active_count as f64).min(position_limit);

    signals
        .iter()
        .enumerate()
        .map(|(idx, &signal)| if is_active(idx, signal) { weight } else { 0.0 })
        .collect()
}
/// Caps every weight at `limit`, redistributing the excess in place.
///
/// Each round clips all weights above `limit` down to `limit` and then, if
/// the total is positive, divides every weight by the total so the weights
/// sum to 1 again. Rounds repeat until a round finds nothing to clip or the
/// round budget is used up. If no weight exceeds `limit` on entry, the slice
/// is left untouched (it is not normalised).
///
/// Only the upper bound is enforced; weights below `-limit` are not clipped.
///
/// When the limit cannot be met while staying fully invested (for example
/// two assets with a limit of `0.3`), rescaling pushes the weights back
/// above the limit every round. A final clip after the budget is exhausted
/// guarantees the returned weights never exceed `limit`; in that case they
/// sum to less than 1.
pub fn apply_position_limit(weights: &mut [f64], limit: f64) {
    /// Upper bound on clip-and-rescale rounds.
    const MAX_ROUNDS: usize = 100;

    /// Clips every weight above `limit`; reports whether anything was clipped.
    fn clip_to_limit(weights: &mut [f64], limit: f64) -> bool {
        let mut clipped_any = false;
        for w in weights.iter_mut().filter(|w| **w > limit) {
            *w = limit;
            clipped_any = true;
        }
        clipped_any
    }

    for _ in 0..MAX_ROUNDS {
        if !clip_to_limit(weights, limit) {
            // Every weight already respects the limit.
            return;
        }
        let total: f64 = weights.iter().sum();
        if total > 0.0 {
            for w in weights.iter_mut() {
                *w /= total;
            }
        }
    }

    // The last round ended on a rescale, which may have lifted weights back
    // above the limit; enforce the cap one final time.
    clip_to_limit(weights, limit);
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used by every floating-point comparison below.
    const TOLERANCE: f64 = 1e-10;

    /// Asserts that `actual` lies strictly within `TOLERANCE` of `expected`.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < TOLERANCE,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    /// Compares `actual[i]` with `expected[i]` for every expected entry.
    fn assert_weights(actual: &[f64], expected: &[f64]) {
        for (idx, &want) in expected.iter().enumerate() {
            assert_close(actual[idx], want);
        }
    }

    // Weights that already sum to 1 pass through unchanged.
    #[test]
    fn test_normalize_weights_finlab_basic() {
        let out = normalize_weights_finlab(&[0.4, 0.3, 0.3], &[false, false, false], 1.0);
        assert_eq!(out.len(), 3);
        assert_weights(&out, &[0.4, 0.3, 0.3]);
    }

    // Gross exposure above 1 is scaled down to 1.
    #[test]
    fn test_normalize_weights_finlab_sum_greater_than_one() {
        let out = normalize_weights_finlab(&[0.6, 0.6, 0.4], &[false, false, false], 1.0);
        assert_weights(&out, &[0.6 / 1.6, 0.6 / 1.6, 0.4 / 1.6]);
    }

    // Gross exposure below 1 is not scaled up.
    #[test]
    fn test_normalize_weights_finlab_sum_less_than_one() {
        let out = normalize_weights_finlab(&[0.2, 0.3], &[false, false], 1.0);
        assert_weights(&out, &[0.2, 0.3]);
    }

    // Clipping happens after normalisation and is not redistributed.
    #[test]
    fn test_normalize_weights_finlab_with_position_limit() {
        let out = normalize_weights_finlab(&[0.8, 0.4], &[false, false], 0.5);
        assert_weights(&out, &[0.5, 0.4 / 1.2]);
    }

    // A stopped asset is zeroed; the survivors absorb its exposure.
    #[test]
    fn test_normalize_weights_finlab_with_stopped_stocks() {
        let out = normalize_weights_finlab(&[0.5, 0.5, 0.5], &[false, true, false], 1.0);
        assert_weights(&out, &[0.5, 0.0, 0.5]);
    }

    // With a fully invested book the survivors still sum to 1.
    #[test]
    fn test_normalize_weights_finlab_with_stopped_rescales() {
        let third = 1.0 / 3.0;
        let out = normalize_weights_finlab(&[third, third, third], &[false, true, false], 1.0);
        assert_weights(&out, &[0.5, 0.0, 0.5]);
        let total: f64 = out.iter().sum();
        assert_close(total, 1.0);
    }

    // Signs are preserved for short positions.
    #[test]
    fn test_normalize_weights_finlab_negative_weights() {
        let out = normalize_weights_finlab(&[0.5, -0.3], &[false, false], 1.0);
        assert_weights(&out, &[0.5, -0.3]);
    }

    // Two active signals split the book equally.
    #[test]
    fn test_calculate_target_weights_basic() {
        let out = calculate_target_weights(&[true, true, false], &[false, false, false], 1.0);
        assert_eq!(out.len(), 3);
        assert_weights(&out, &[0.5, 0.5, 0.0]);
    }

    // A stopped asset is excluded from the equal split.
    #[test]
    fn test_calculate_target_weights_with_stopped() {
        let out = calculate_target_weights(&[true, true, true], &[false, true, false], 1.0);
        assert_weights(&out, &[0.5, 0.0, 0.5]);
    }

    // After capping, no weight exceeds the limit and the book sums to 1.
    #[test]
    fn test_apply_position_limit() {
        let mut book = vec![0.6, 0.3, 0.2, 0.1];
        apply_position_limit(&mut book, 0.3);
        assert!(book.iter().all(|&w| w <= 0.3 + 1e-10));
        let total: f64 = book.iter().sum();
        assert_close(total, 1.0);
    }
}