use crate::indicators::{
moving_averages::{calculate_ema, calculate_sma},
oscillators::{calculate_macd, calculate_rsi},
volatility::{
calculate_atr, calculate_bb_b, calculate_bollinger_bands, calculate_gk_volatility,
},
};
use crate::util::dataframe_utils::ensure_f64_column;
use crate::util::time_utils::create_cyclical_time_features;
use polars::prelude::*;
pub fn add_technical_indicators(df: &mut DataFrame) -> PolarsResult<DataFrame> {
let numeric_columns = ["open", "high", "low", "close", "volume"];
for col_name in numeric_columns {
if !df.schema().contains(col_name) {
continue;
}
ensure_f64_column(df, col_name)?;
}
let sma20 = calculate_sma(df, "close", 20)?.with_name("sma_20".into());
let sma50 = calculate_sma(df, "close", 50)?.with_name("sma_50".into());
let ema20 = calculate_ema(df, "close", 20)?.with_name("ema_20".into());
let rsi = calculate_rsi(df, 14, "close")?.with_name("rsi_14".into());
let (macd, macd_signal) = calculate_macd(df, 12, 26, 9, "close")?;
let macd = macd.with_name("macd".into());
let macd_signal = macd_signal.with_name("macd_signal".into());
let (bb_middle, bb_upper, bb_lower) = calculate_bollinger_bands(df, 20, 2.0, "close")?;
let bb_middle = bb_middle.with_name("bb_middle".into());
let bb_upper = bb_upper.with_name("bb_upper".into());
let bb_lower = bb_lower.with_name("bb_lower".into());
let bb_b = calculate_bb_b(df, 20, 2.0, "close")?.with_name("bb_b".into());
let atr = calculate_atr(df, 14)?.with_name("atr_14".into());
let gk_vol = calculate_gk_volatility(df, 10)?.with_name("gk_volatility".into());
let close = df.column("close")?.f64()?;
let prev_close = close.shift(1);
let returns = ((close.clone() - prev_close.clone()) / prev_close.clone())
.with_name("returns".into())
.into_series();
let high = df.column("high")?.f64()?;
let low = df.column("low")?.f64()?;
let price_range = ((high.clone() - low.clone()) / close.clone())
.with_name("price_range".into())
.into_series();
let close_lag_5 = close.shift(5).with_name("close_lag_5".into());
let close_lag_15 = close.shift(15).with_name("close_lag_15".into());
let close_lag_30 = close.shift(30).with_name("close_lag_30".into());
let close_lag_5_clone = close_lag_5.clone();
let returns_5min = ((close.clone() - close_lag_5_clone.clone()) / close_lag_5_clone)
.with_name("returns_5min".into())
.into_series();
let mut vol_15min = Vec::with_capacity(df.height());
for i in 0..df.height() {
if i < 15 {
vol_15min.push(0.0);
continue;
}
let mut returns = Vec::with_capacity(15);
for j in (i - 15)..i {
let current_opt = close.get(j);
let previous_opt = if j > 0 { close.get(j - 1) } else { None };
if let (Some(current), Some(previous)) = (current_opt, previous_opt) {
if previous != 0.0 {
returns.push((current - previous) / previous);
}
}
}
if returns.is_empty() {
vol_15min.push(0.0);
continue;
}
let mean = returns.iter().sum::<f64>() / returns.len() as f64;
let variance =
returns.iter().map(|r| (r - mean).powi(2)).sum::<f64>() / returns.len() as f64;
vol_15min.push(variance.sqrt());
}
let volatility_15min = Series::new("volatility_15min".into(), vol_15min);
let mut time_features = Vec::new();
if df.schema().contains("time") {
time_features = create_cyclical_time_features(df, "time", "%Y-%m-%d %H:%M:%S UTC")?;
}
let mut features_to_add = vec![
sma20,
sma50,
ema20,
rsi,
macd,
macd_signal,
bb_middle,
bb_upper,
bb_lower,
bb_b,
atr,
gk_vol,
returns,
price_range,
close_lag_5.into_series(),
close_lag_15.into_series(),
close_lag_30.into_series(),
returns_5min,
volatility_15min,
];
features_to_add.extend(time_features);
for feature in features_to_add {
df.with_column(feature)?;
}
Ok(df.clone())
}