use polars::prelude::*;
use crate::indicators::oscillators::{calculate_rsi, calculate_stochastic};
use crate::indicators::moving_averages::{calculate_ema, calculate_sma};
/// Flags trend-pullback swing setups, producing one integer signal per row of `df`.
///
/// Signal values:
/// * `1`  - close is above the trend average, price has pulled back at least
///   `pullback_threshold` percent from the recent high, and both oscillators
///   are low (RSI < 40 and stochastic < 30).
/// * `-1` - close is at or below the trend average, price has bounced at least
///   `pullback_threshold` percent from the recent low, and both oscillators
///   are high (RSI > 60 and stochastic > 70).
/// * `0`  - everything else, including the warm-up rows and rows where any
///   required input is null/NaN.
///
/// # Arguments
/// * `df` - frame read through its `close`, `high` and `low` columns (must be `f64`).
/// * `trend_ma_period` - period passed to `calculate_ema`; defaults to 50.
/// * `pullback_threshold` - minimum pullback in percent; defaults to 3.0.
/// * `rsi_period` - period passed to `calculate_rsi`; defaults to 14.
/// * `stoch_period` - period passed to `calculate_stochastic`; defaults to 14.
///
/// # Errors
/// Propagates any error from the indicator helpers, from a missing
/// `close`/`high`/`low` column, or from a column that is not `f64`.
pub fn detect_swing_opportunities(
    df: &DataFrame,
    trend_ma_period: Option<usize>,
    pullback_threshold: Option<f64>,
    rsi_period: Option<usize>,
    stoch_period: Option<usize>,
) -> PolarsResult<Series> {
    // Number of bars before the current one scanned for the recent high/low.
    const LOOKBACK: usize = 5;

    let ma_period = trend_ma_period.unwrap_or(50);
    let pullback_pct = pullback_threshold.unwrap_or(3.0);
    let rsi_len = rsi_period.unwrap_or(14);
    let stoch_len = stoch_period.unwrap_or(14);

    let trend_ma = calculate_ema(df, "close", ma_period)?;
    let rsi = calculate_rsi(df, rsi_len, "close")?;
    // Only the first series of the returned pair is used (presumably %K; the `3`
    // is presumably a smoothing period - confirm against `calculate_stochastic`).
    let (stoch_k, _) = calculate_stochastic(df, stoch_len, 3, None)?;

    let close = df.column("close")?.f64()?;
    let low = df.column("low")?.f64()?;
    let high = df.column("high")?.f64()?;
    let ma_vals = trend_ma.f64()?;
    let rsi_vals = rsi.f64()?;
    let stoch_vals = stoch_k.f64()?;

    let height = df.height();
    // Rows before this index are never evaluated; the extra LOOKBACK also
    // guarantees `i - LOOKBACK` below cannot underflow.
    let min_periods = ma_period.max(rsi_len).max(stoch_len) + LOOKBACK;

    // Every row starts as "no signal"; only qualifying rows are overwritten.
    let mut swing_signals = vec![0_i32; height];

    for i in min_periods..height {
        // Nulls are folded into NaN so one check covers both kinds of gap.
        let ma_val = ma_vals.get(i).unwrap_or(f64::NAN);
        let close_val = close.get(i).unwrap_or(f64::NAN);
        let rsi_val = rsi_vals.get(i).unwrap_or(f64::NAN);
        let stoch_val = stoch_vals.get(i).unwrap_or(f64::NAN);
        if ma_val.is_nan() || close_val.is_nan() || rsi_val.is_nan() || stoch_val.is_nan() {
            continue;
        }

        let uptrend = close_val > ma_val;

        // Comparisons against NaN are false, so null/NaN bars in the window
        // never replace the running extreme.
        let recent_extreme = if uptrend {
            // Highest of the current close and the previous LOOKBACK highs.
            ((i - LOOKBACK)..i)
                .map(|j| high.get(j).unwrap_or(f64::NAN))
                .fold(close_val, |best, h| if h > best { h } else { best })
        } else {
            // Lowest of the current low and the previous LOOKBACK lows. A
            // null/NaN current low keeps the extreme NaN, which yields no signal.
            ((i - LOOKBACK)..i)
                .map(|j| low.get(j).unwrap_or(f64::NAN))
                .fold(low.get(i).unwrap_or(f64::NAN), |best, l| {
                    if l < best {
                        l
                    } else {
                        best
                    }
                })
        };

        // Distance between close and the extreme, as a percentage of the extreme.
        let pullback = ((close_val - recent_extreme) / recent_extreme * 100.0).abs();

        // A zero extreme makes the ratio infinite (or NaN); neither is a real
        // pullback, so such rows must not be allowed to pass the threshold.
        let deep_enough = pullback.is_finite() && pullback >= pullback_pct;
        if !deep_enough {
            continue;
        }

        let oversold = rsi_val < 40.0 && stoch_val < 30.0;
        let overbought = rsi_val > 60.0 && stoch_val > 70.0;
        swing_signals[i] = match (uptrend, oversold, overbought) {
            (true, true, _) => 1,
            (false, _, true) => -1,
            _ => 0,
        };
    }

    Ok(Series::new("swing_signal", swing_signals))
}
/// Classifies each row of `df` into a volatility-based risk level of 1, 2 or 3.
///
/// For every row `i >= lookback`, the average true range over the `lookback`
/// rows preceding `i` (row `i` itself is excluded) is divided by the average
/// close over the same rows and expressed as a percentage:
/// * below 1.0  -> level `1`
/// * below 3.0  -> level `2`
/// * otherwise  -> level `3`
///
/// Level `2` is also used for the first `lookback` rows and for any row whose
/// window has no usable true range or close, or whose average close is zero.
///
/// # Arguments
/// * `df` - frame read through its `high`, `low` and `close` columns (must be `f64`).
/// * `lookback_period` - window length in rows; defaults to 20.
///
/// # Errors
/// Propagates the error when one of the three columns is missing or not `f64`.
/// An empty frame yields an empty series rather than touching row 0.
pub fn calculate_swing_risk_level(
    df: &DataFrame,
    lookback_period: Option<usize>,
) -> PolarsResult<Series> {
    /// Mean of the non-NaN items, or NaN when there are none.
    fn mean_skipping_nan(values: impl Iterator<Item = f64>) -> f64 {
        let (sum, count) = values
            .filter(|v| !v.is_nan())
            .fold((0.0_f64, 0_usize), |(s, c), v| (s + v, c + 1));
        if count > 0 {
            sum / count as f64
        } else {
            f64::NAN
        }
    }

    let lookback = lookback_period.unwrap_or(20);
    let high = df.column("high")?.f64()?;
    let low = df.column("low")?.f64()?;
    let close = df.column("close")?.f64()?;
    let height = df.height();

    // Per-row true range; NaN marks rows with a missing input.
    let mut true_ranges: Vec<f64> = Vec::with_capacity(height);

    // Row 0 has no previous close, so its range is simply high - low. Only read
    // it when the row exists: an out-of-bounds `get` may panic depending on the
    // polars version.
    if height > 0 {
        true_ranges.push(high.get(0).unwrap_or(f64::NAN) - low.get(0).unwrap_or(f64::NAN));
    }
    for i in 1..height {
        let h = high.get(i).unwrap_or(f64::NAN);
        let l = low.get(i).unwrap_or(f64::NAN);
        let prev_close = close.get(i - 1).unwrap_or(f64::NAN);
        let range = if h.is_nan() || l.is_nan() || prev_close.is_nan() {
            f64::NAN
        } else {
            (h - l)
                .max((h - prev_close).abs())
                .max((l - prev_close).abs())
        };
        true_ranges.push(range);
    }

    // Level 2 is the fallback for warm-up rows and rows that cannot be scored,
    // so start from it and overwrite only the rows that get a real score.
    let mut risk_levels = vec![2_i32; height];

    for i in lookback..height {
        let window_start = i - lookback;

        let avg_range = mean_skipping_nan(true_ranges[window_start..i].iter().copied());
        if avg_range.is_nan() {
            continue;
        }

        let avg_price =
            mean_skipping_nan((window_start..i).map(|j| close.get(j).unwrap_or(f64::NAN)));
        if avg_price.is_nan() || avg_price == 0.0 {
            continue;
        }

        // Average true range as a percentage of the average close.
        let normalized = avg_range / avg_price * 100.0;
        risk_levels[i] = if normalized < 1.0 {
            1
        } else if normalized < 3.0 {
            2
        } else {
            3
        };
    }

    Ok(Series::new("swing_risk_level", risk_levels))
}
/// Runs both swing computations with their default parameters and attaches the
/// resulting `swing_signal` and `swing_risk_level` series to `df`.
///
/// Both series are computed before `df` is modified, so a failure in either
/// computation leaves the frame untouched.
///
/// # Errors
/// Propagates any error from `detect_swing_opportunities`,
/// `calculate_swing_risk_level`, or `DataFrame::with_column`.
pub fn add_swing_analysis(df: &mut DataFrame) -> PolarsResult<()> {
    let signals = detect_swing_opportunities(df, None, None, None, None)?;
    let risk = calculate_swing_risk_level(df, None)?;

    // Signal column first, then risk level - same insertion order as before.
    df.with_column(signals)?.with_column(risk).map(|_| ())
}