use polars::prelude::*;
use crate::indicators::moving_averages::{calculate_sma, calculate_ema};
use crate::indicators::trend::calculate_adx;
/// Computes a per-row trend-strength score from the series returned by
/// `calculate_adx`, scaled by how price and a short moving average sit
/// relative to a long moving average, then smoothed over recent outputs.
///
/// # Arguments
/// * `df` - frame that must contain an `f64` `close` column (plus whatever
///   `calculate_adx` itself requires).
/// * `period` - period passed to `calculate_adx`; defaults to 14.
/// * `smooth_period` - number of previous outputs averaged together with the
///   current raw value; defaults to 3.
/// * `ma_period` - window of the long simple moving average; defaults to 50.
///   The short average uses a quarter of this window (at least 1).
///
/// # Returns
/// A `Series` named `trend_strength` with one value per row of `df`. The
/// first `max(period, ma_period, smooth_period)` rows are NaN, as is every
/// row where any input is missing or NaN. Values are capped at 100.
///
/// # Errors
/// Propagates any error from `calculate_adx`, `calculate_sma`, the `close`
/// column lookup, or the `f64` downcasts.
pub fn calculate_trend_strength(
    df: &DataFrame,
    period: Option<usize>,
    smooth_period: Option<usize>,
    ma_period: Option<usize>,
) -> PolarsResult<Series> {
    let adx_period = period.unwrap_or(14);
    let smoothing = smooth_period.unwrap_or(3);
    let ma_len = ma_period.unwrap_or(50);
    // A long window below 4 would otherwise request a zero-length short window.
    let short_len = (ma_len / 4).max(1);

    let adx = calculate_adx(df, adx_period)?;
    let adx_values = adx.f64()?;
    let sma = calculate_sma(df, "close", ma_len)?;
    let sma_vals = sma.f64()?;
    let short_ma = calculate_sma(df, "close", short_len)?;
    let short_ma_vals = short_ma.f64()?;
    let close = df.column("close")?.f64()?;

    let height = df.height();
    let min_periods = adx_period.max(ma_len).max(smoothing);

    // Warm-up rows carry no score.
    let mut trend_strength = Vec::with_capacity(height);
    trend_strength.resize(min_periods.min(height), f64::NAN);

    for i in min_periods..height {
        let adx_val = adx_values.get(i).unwrap_or(f64::NAN);
        let sma_val = sma_vals.get(i).unwrap_or(f64::NAN);
        let short_sma_val = short_ma_vals.get(i).unwrap_or(f64::NAN);
        let close_val = close.get(i).unwrap_or(f64::NAN);

        if adx_val.is_nan() || sma_val.is_nan() || short_sma_val.is_nan() || close_val.is_nan() {
            trend_strength.push(f64::NAN);
            continue;
        }

        let close_above = close_val > sma_val;
        let close_below = close_val < sma_val;
        let short_above = short_sma_val > sma_val;
        let short_below = short_sma_val < sma_val;

        // Boost when price and the short average agree on the side of the
        // long average, dampen when they disagree, neutral on exact ties.
        let ma_alignment = if (close_above && short_above) || (close_below && short_below) {
            1.25
        } else if (close_above && short_below) || (close_below && short_above) {
            0.75
        } else {
            1.0
        };

        let mut strength = adx_val * ma_alignment;

        // Smooth only once a full window of post-warm-up outputs exists.
        // `i >= min_periods` inside this loop, so the subtraction cannot
        // underflow (and avoids the overflow of `min_periods + smoothing`).
        if i - min_periods >= smoothing {
            // Earlier outputs may be NaN (data gaps). Folding them in would
            // yield NaN, which `f64::min` below would silently turn into
            // 100.0, so they are skipped and the divisor shrinks to match.
            // Iterating newest-first keeps the summation order stable.
            let (sum, count) = trend_strength[i - smoothing..i]
                .iter()
                .rev()
                .filter(|prev| !prev.is_nan())
                .fold((strength, 1.0_f64), |(acc, n), prev| (acc + prev, n + 1.0));
            strength = sum / count;
        }

        trend_strength.push(strength.min(100.0));
    }

    Ok(Series::new("trend_strength", trend_strength))
}
/// Assigns each row a signed trend class from the `trend_strength` column
/// and the 20- and 50-row simple moving averages of `close`.
///
/// Class values:
/// * `2` / `-2` - up / down trend with strength of at least 30.
/// * `1` / `-1` - up / down trend with strength below 30.
/// * `0` - no agreed direction, one of the first 50 rows, or a row where any
///   input is missing or NaN.
///
/// A row counts as "up" when both `close` and the 20-row average are above
/// the 50-row average, and "down" when both are below it.
///
/// # Errors
/// Returns `PolarsError::ComputeError` if `trend_strength` or `close` is
/// absent, and propagates errors from `calculate_sma` and the `f64` downcasts.
pub fn classify_trend(df: &DataFrame) -> PolarsResult<Series> {
    let missing = ["trend_strength", "close"]
        .iter()
        .find(|name| !df.schema().contains(**name));
    if let Some(col) = missing {
        return Err(PolarsError::ComputeError(
            format!("Required column '{}' not found", col).into(),
        ));
    }

    let strength = df.column("trend_strength")?.f64()?;
    let close = df.column("close")?.f64()?;
    let fast_series = calculate_sma(df, "close", 20)?;
    let slow_series = calculate_sma(df, "close", 50)?;
    let fast = fast_series.f64()?;
    let slow = slow_series.f64()?;

    let labels: Vec<i32> = (0..df.height())
        .map(|row| {
            // The first 50 rows are always neutral.
            if row < 50 {
                return 0;
            }
            match (strength.get(row), close.get(row), fast.get(row), slow.get(row)) {
                (Some(s), Some(price), Some(short_ma), Some(medium_ma))
                    if !(s.is_nan()
                        || price.is_nan()
                        || short_ma.is_nan()
                        || medium_ma.is_nan()) =>
                {
                    // Magnitude comes from strength, sign from MA alignment.
                    let magnitude = if s >= 30.0 { 2 } else { 1 };
                    if price > medium_ma && short_ma > medium_ma {
                        magnitude
                    } else if price < medium_ma && short_ma < medium_ma {
                        -magnitude
                    } else {
                        0
                    }
                }
                // Any null or NaN input leaves the row unclassified.
                _ => 0,
            }
        })
        .collect();

    Ok(Series::new("trend_classification", labels))
}
/// Adds the `trend_strength` series and then the `trend_classification`
/// series to `df` in place.
///
/// `period` is forwarded to `calculate_trend_strength`; its smoothing and
/// moving-average parameters are left at their defaults.
///
/// # Errors
/// Propagates the first error from either calculation or from
/// `DataFrame::with_column`. The strength column is attached before
/// classification runs, so a classification failure leaves it on `df`.
pub fn add_trend_strength_analysis(df: &mut DataFrame, period: usize) -> PolarsResult<()> {
    let strength_series = calculate_trend_strength(df, Some(period), None, None)?;
    df.with_column(strength_series)?;

    // Classification reads the `trend_strength` column attached just above.
    let class_series = classify_trend(df)?;
    df.with_column(class_series).map(|_| ())
}