use polars::prelude::*;
use crate::indicators::moving_averages::calculate_ema;
use crate::indicators::oscillators::calculate_rsi;
mod trend_strength;
pub(crate) mod swing_detection;
mod multi_timeframe;
mod mean_reversion;
mod support_resistance;
pub use trend_strength::*;
pub use swing_detection::*;
pub use multi_timeframe::*;
pub use mean_reversion::*;
pub use support_resistance::*;
/// Returns a copy of `df` with every short-term analysis appended to it.
///
/// The input frame is not modified: the analyses run one after another on a clone,
/// and the clone is returned once all of them succeed.
///
/// # Errors
/// Propagates the first `PolarsError` returned by any of the individual analyses;
/// later analyses are not run in that case.
pub fn add_short_term_indicators(df: &DataFrame) -> PolarsResult<DataFrame> {
    let mut enriched = df.clone();

    // Second argument of the trend-strength pass; presumably a lookback length — TODO confirm.
    let trend_window = 14;

    // NOTE(review): later passes may read columns written by earlier ones, so keep this order
    // unless the submodules are confirmed to be independent.
    trend_strength::add_trend_strength_analysis(&mut enriched, trend_window)?;
    swing_detection::add_swing_analysis(&mut enriched)?;
    multi_timeframe::add_multi_timeframe_analysis(&mut enriched, None, None)?;
    mean_reversion::add_mean_reversion_analysis(&mut enriched)?;
    support_resistance::add_support_resistance_analysis(&mut enriched)?;

    Ok(enriched)
}
/// Combines the short-term analysis columns into one discrete swing-trading signal per row.
///
/// Each row is scored by tallying bullish and bearish votes:
/// * `trend_classification` casts 2 votes in the direction of its sign;
/// * `swing_signal`, `multi_timeframe_alignment` and `mean_reversion_signal` cast 1 vote each;
/// * when a `risk_reward_ratio` column exists and the row's value is neither null nor NaN,
///   a ratio `>= 2.0` adds one bullish vote if bulls already lead, otherwise a ratio `<= 0.5`
///   adds one bearish vote if bears already lead.
///
/// The tally is mapped to `2` (at least 3 bullish and no bearish votes), `1` (bulls lead),
/// `-2` (at least 3 bearish and no bullish votes), `-1` (bears lead) or `0` (tie).
/// Rows whose `trend_strength` is null or NaN always get `0`; null vote inputs count as `0`.
///
/// Returns an `i32` series named `swing_trading_signal` with one entry per row of `df`.
///
/// # Errors
/// * `PolarsError::ComputeError` naming the first missing required column.
/// * Propagates dtype errors if `trend_strength`/`risk_reward_ratio` are not `f64` or the
///   vote columns are not `i32`.
pub fn generate_swing_trading_signals(df: &DataFrame) -> PolarsResult<Series> {
    let schema = df.schema();

    let missing = [
        "trend_strength",
        "trend_classification",
        "swing_signal",
        "multi_timeframe_alignment",
        "mean_reversion_signal",
    ]
    .into_iter()
    .find(|name| !schema.contains(name));
    if let Some(name) = missing {
        return Err(PolarsError::ComputeError(
            format!("Required indicator '{}' not found", name).into(),
        ));
    }

    let strength_col = df.column("trend_strength")?.f64()?;
    let trend_col = df.column("trend_classification")?.i32()?;
    let swing_col = df.column("swing_signal")?.i32()?;
    let alignment_col = df.column("multi_timeframe_alignment")?.i32()?;
    let mean_rev_col = df.column("mean_reversion_signal")?.i32()?;

    // The risk/reward column is optional; when absent it simply casts no extra vote.
    let rr_col = if schema.contains("risk_reward_ratio") {
        Some(df.column("risk_reward_ratio")?.f64()?)
    } else {
        None
    };

    let signals: Vec<i32> = (0..df.height())
        .map(|row| {
            let strength = strength_col.get(row).unwrap_or(f64::NAN);
            if strength.is_nan() {
                // Without a usable trend strength the row is considered undecided.
                return 0;
            }
            score_swing_votes(
                trend_col.get(row).unwrap_or(0),
                swing_col.get(row).unwrap_or(0),
                alignment_col.get(row).unwrap_or(0),
                mean_rev_col.get(row).unwrap_or(0),
                rr_col.map(|rr| rr.get(row).unwrap_or(f64::NAN)),
            )
        })
        .collect();

    Ok(Series::new("swing_trading_signal".into(), signals))
}

/// Tallies one row's bullish/bearish votes and maps the result to a signal in `-2..=2`.
///
/// `risk_reward` is `None` when the ratio column is absent; a NaN value casts no vote.
fn score_swing_votes(
    trend: i32,
    swing: i32,
    alignment: i32,
    mean_rev: i32,
    risk_reward: Option<f64>,
) -> i32 {
    // (vote, weight): the trend classification counts double.
    let weighted = [(trend, 2), (swing, 1), (alignment, 1), (mean_rev, 1)];
    let (mut bulls, mut bears) =
        weighted
            .iter()
            .fold((0, 0), |(bull, bear), &(vote, weight)| match vote.signum() {
                1 => (bull + weight, bear),
                -1 => (bull, bear + weight),
                _ => (bull, bear),
            });

    // A favourable/unfavourable risk-reward only reinforces a side that already leads.
    if let Some(rr) = risk_reward.filter(|r| !r.is_nan()) {
        if rr >= 2.0 && bulls > bears {
            bulls += 1;
        } else if rr <= 0.5 && bears > bulls {
            bears += 1;
        }
    }

    // Arm order matters: strong signals require a unanimous side.
    match (bulls, bears) {
        (bull, 0) if bull >= 3 => 2,
        (bull, bear) if bull > bear => 1,
        (0, bear) if bear >= 3 => -2,
        (bull, bear) if bear > bull => -1,
        _ => 0,
    }
}
/// Computes a per-row position size as a fraction capped at `risk_percentage / 100`.
///
/// * `risk_percentage` — risk budget in percent (default `2.0`, i.e. a cap of `0.02`).
///
/// The per-row signal comes from [`generate_swing_trading_signals`]. If that fails (e.g.
/// some indicator columns are missing), it falls back to the raw `swing_signal` column.
/// Sizing rules per row:
/// * signal `0` (or null) → `0.0`;
/// * `swing_risk_level` `1` → full budget, `3` → 50%, anything else (including `2` and
///   null) → 75%;
/// * a strong signal (`±2`) keeps that base size, any other non-zero signal gets 75% of it;
/// * if a `risk_reward_ratio` column exists and the row's value is not null/NaN:
///   `>= 3.0` → ×1.25, `>= 2.0` → ×1.1, `<= 0.5` → ×0.5, then the result is capped at the
///   budget.
///
/// Returns an `f64` series named `position_size` with one entry per row of `df`.
///
/// # Errors
/// * `PolarsError::ComputeError` if `risk_percentage` is negative or not finite.
/// * `PolarsError::ComputeError` if `swing_risk_level` is missing.
/// * `PolarsError::ComputeError` if neither a combined signal nor `swing_signal` is available.
/// * Propagates dtype errors if the signal/risk columns are not `i32` or
///   `risk_reward_ratio` is not `f64`.
pub fn calculate_position_sizing(
    df: &DataFrame,
    risk_percentage: Option<f64>,
) -> PolarsResult<Series> {
    let risk_percentage = risk_percentage.unwrap_or(2.0);
    // A negative or non-finite budget would silently yield negative/NaN sizes for every
    // active row, so reject it up front.
    if !risk_percentage.is_finite() || risk_percentage < 0.0 {
        return Err(PolarsError::ComputeError(
            format!(
                "risk_percentage must be a finite, non-negative number, got {}",
                risk_percentage
            )
            .into(),
        ));
    }
    let max_risk = risk_percentage / 100.0;

    let schema = df.schema();
    if !schema.contains("swing_risk_level") {
        return Err(PolarsError::ComputeError(
            "swing_risk_level not found. Calculate swing analysis first.".into(),
        ));
    }

    // Prefer the combined signal. If it cannot be built, fall back to the raw swing signal;
    // the reason for the failure is deliberately discarded (best-effort fallback).
    // Both arms produce an owned Int32Chunked so their types agree.
    let signal_vals = match generate_swing_trading_signals(df) {
        Ok(combined) => combined.i32()?.clone(),
        Err(_) => {
            if !schema.contains("swing_signal") {
                return Err(PolarsError::ComputeError(
                    "No trading signals found for position sizing".into(),
                ));
            }
            df.column("swing_signal")?.i32()?.clone()
        }
    };

    let risk_level = df.column("swing_risk_level")?.i32()?;
    let rr_ratio = if schema.contains("risk_reward_ratio") {
        Some(df.column("risk_reward_ratio")?.f64()?)
    } else {
        None
    };

    let mut position_sizes = Vec::with_capacity(df.height());
    for i in 0..df.height() {
        let signal = signal_vals.get(i).unwrap_or(0);
        if signal == 0 {
            position_sizes.push(0.0);
            continue;
        }

        // Null risk levels are treated like level 2.
        let base_size = match risk_level.get(i).unwrap_or(2) {
            1 => max_risk,
            3 => max_risk * 0.5,
            _ => max_risk * 0.75,
        };

        // `matches!` instead of `signal.abs()`: the fallback `swing_signal` column is not
        // range-checked, and `i32::MIN.abs()` overflows.
        let mut size = if matches!(signal, -2 | 2) {
            base_size
        } else {
            base_size * 0.75
        };

        if let Some(rr) = rr_ratio {
            let rr_val = rr.get(i).unwrap_or(f64::NAN);
            if !rr_val.is_nan() {
                if rr_val >= 3.0 {
                    size *= 1.25;
                } else if rr_val >= 2.0 {
                    size *= 1.1;
                } else if rr_val <= 0.5 {
                    size *= 0.5;
                }
                // The reward boosts must never push a position past the risk budget.
                size = size.min(max_risk);
            }
        }

        position_sizes.push(size);
    }

    Ok(Series::new("position_size".into(), position_sizes))
}