use polars::prelude::*;
use quantwave_core::features::{self as rust_features};
use quantwave_core::traits::Next;
use crate::QuantWaveNamespace;
/// Feature-extraction namespace over a borrowed `LazyFrame`.
///
/// The methods on this type clone the borrowed frame and return a new
/// `LazyFrame` with extra columns computed from its `close` column.
pub struct TaFeaturesNamespace<'a>(pub(crate) &'a LazyFrame);
impl<'a> QuantWaveNamespace<'a> {
    /// Opens the feature-extraction namespace over the same borrowed frame.
    ///
    /// Consumes this namespace and hands its inner `LazyFrame` reference to a
    /// new [`TaFeaturesNamespace`].
    pub fn features(self) -> TaFeaturesNamespace<'a> {
        let frame = self.0;
        TaFeaturesNamespace(frame)
    }
}
impl<'a> TaFeaturesNamespace<'a> {
/// Adds a `hurst_{period}` column computed from `close`.
///
/// Every time the expression is evaluated, a new
/// `HurstFeatureExtractor::new(period)` is fed the `close` values in row
/// order and the `persistence` field of each result is collected into the
/// output column. Null entries are passed to the extractor as `f64::NAN`.
///
/// The borrowed frame is cloned, so it is left untouched. If `close` cannot
/// be downcast with `f64()`, that error is propagated when the plan runs.
pub fn hurst(self, period: usize) -> LazyFrame {
    let column_name = format!("hurst_{}", period);
    let persistence = col("close").map(
        move |close| {
            let mut hurst = rust_features::HurstFeatureExtractor::new(period);
            let prices: &Float64Chunked = close.f64()?;
            // Nulls are replaced by NaN before they reach the extractor.
            let values: Vec<_> = prices
                .into_iter()
                .map(|px| hurst.next(px.unwrap_or(f64::NAN)).persistence)
                .collect();
            let series = Series::new(format!("hurst_{}", period).into(), values);
            Ok(Some(Column::from(series)))
        },
        GetOutput::from_type(DataType::Float64),
    );
    self.0
        .clone()
        .with_columns([persistence.alias(&column_name)])
}
/// Adds a `cyber_cycle` struct column computed from `close`.
///
/// A new `CyberCycleFeatureExtractor::new(length)` is fed the `close` values
/// in row order. The struct has four fields, filled from the extractor
/// result as follows:
///
/// * `cycle`    <- `cycle`
/// * `trigger`  <- `trigger`
/// * `momentum` <- `cycle_momentum`
/// * `signal`   <- `trigger_signal`
///
/// Null entries in `close` are passed to the extractor as `f64::NAN`. The
/// borrowed frame is cloned, not modified.
pub fn cyber_cycle(self, length: usize) -> LazyFrame {
    let output_type = GetOutput::from_type(DataType::Struct(vec![
        Field::new("cycle".into(), DataType::Float64),
        Field::new("trigger".into(), DataType::Float64),
        Field::new("momentum".into(), DataType::Float64),
        Field::new("signal".into(), DataType::Float64),
    ]));
    let cyber_cycle = col("close").map(
        move |close| {
            let mut cyber = rust_features::CyberCycleFeatureExtractor::new(length);
            let prices: &Float64Chunked = close.f64()?;
            let rows = close.len();
            let mut cycle = Vec::with_capacity(rows);
            let mut trigger = Vec::with_capacity(rows);
            let mut momentum = Vec::with_capacity(rows);
            let mut signal = Vec::with_capacity(rows);
            for px in prices {
                let feat = cyber.next(px.unwrap_or(f64::NAN));
                cycle.push(feat.cycle);
                trigger.push(feat.trigger);
                momentum.push(feat.cycle_momentum);
                signal.push(feat.trigger_signal);
            }
            let fields = [
                Series::new("cycle".into(), cycle),
                Series::new("trigger".into(), trigger),
                Series::new("momentum".into(), momentum),
                Series::new("signal".into(), signal),
            ];
            let packed =
                StructChunked::from_series("cyber_cycle_result".into(), rows, fields.iter())?;
            Ok(Some(Column::from(packed.into_series())))
        },
        output_type,
    );
    self.0
        .clone()
        .with_columns([cyber_cycle.alias("cyber_cycle")])
}
/// Adds a `griffiths_dc` column computed from `close`.
///
/// A new `GriffithsDominantCycleFeatureExtractor::new(lower, upper, length)`
/// is fed the `close` values in row order; the `dominant_cycle` field of each
/// result becomes the output value for that row. Null entries are passed to
/// the extractor as `f64::NAN`. The borrowed frame is cloned, not modified.
pub fn griffiths_dominant_cycle(self, lower: usize, upper: usize, length: usize) -> LazyFrame {
    let dominant_cycle = col("close").map(
        move |close| {
            let mut griffiths =
                rust_features::GriffithsDominantCycleFeatureExtractor::new(lower, upper, length);
            let prices: &Float64Chunked = close.f64()?;
            let values: Vec<_> = prices
                .into_iter()
                .map(|px| griffiths.next(px.unwrap_or(f64::NAN)).dominant_cycle)
                .collect();
            let series = Series::new("griffiths_dc".into(), values);
            Ok(Some(Column::from(series)))
        },
        GetOutput::from_type(DataType::Float64),
    );
    self.0
        .clone()
        .with_columns([dominant_cycle.alias("griffiths_dc")])
}
/// Adds a `regime_label` UInt32 column computed from `close`.
///
/// A new `HMM::bull_bear()` model is stepped with every usable `close`
/// value in row order and the resulting `MarketRegime` is encoded as:
///
/// * `Steady`     -> 0
/// * `Bull`       -> 1
/// * `Bear`       -> 2
/// * `Crisis`     -> 3
/// * `Cluster(c)` -> 4 + c
///
/// Rows whose `close` is null or NaN are labelled 0 (`Steady`) and are not
/// fed to the model. The borrowed frame is cloned, not modified.
pub fn regime_features(self) -> LazyFrame {
    use quantwave_core::regimes::MarketRegime;

    let labels = col("close").map(
        move |close| {
            let mut hmm = quantwave_core::regimes::hmm::HMM::bull_bear();
            let prices = close.f64()?;
            let codes: Vec<u32> = prices
                .into_iter()
                .map(|px| {
                    // Null and NaN inputs bypass the model entirely.
                    let regime = match px {
                        Some(value) if !value.is_nan() => hmm.next(value),
                        _ => MarketRegime::Steady,
                    };
                    match regime {
                        MarketRegime::Steady => 0,
                        MarketRegime::Bull => 1,
                        MarketRegime::Bear => 2,
                        MarketRegime::Crisis => 3,
                        MarketRegime::Cluster(c) => 4 + (c as u32),
                    }
                })
                .collect();
            let series = Series::new("regime_label".into(), codes);
            Ok(Some(Column::from(series)))
        },
        GetOutput::from_type(DataType::UInt32),
    );
    self.0
        .clone()
        .with_columns([labels.alias("regime_label")])
}
/// Adds an `itl` struct column computed from `close`.
///
/// A new `InstantaneousTrendlineFeatureExtractor::new()` is fed the `close`
/// values in row order. The struct carries two fields, `trend` and
/// `strength`, copied from the same-named fields of each extractor result.
/// Null entries are passed to the extractor as `f64::NAN`. The borrowed
/// frame is cloned, not modified.
pub fn instantaneous_trendline(self) -> LazyFrame {
    let output_type = GetOutput::from_type(DataType::Struct(vec![
        Field::new("trend".into(), DataType::Float64),
        Field::new("strength".into(), DataType::Float64),
    ]));
    let itl = col("close").map(
        move |close| {
            let mut trendline = rust_features::InstantaneousTrendlineFeatureExtractor::new();
            let prices: &Float64Chunked = close.f64()?;
            let rows = close.len();
            let mut trend = Vec::with_capacity(rows);
            let mut strength = Vec::with_capacity(rows);
            for px in prices {
                let feat = trendline.next(px.unwrap_or(f64::NAN));
                trend.push(feat.trend);
                strength.push(feat.strength);
            }
            let fields = [
                Series::new("trend".into(), trend),
                Series::new("strength".into(), strength),
            ];
            let packed = StructChunked::from_series("itl_result".into(), rows, fields.iter())?;
            Ok(Some(Column::from(packed.into_series())))
        },
        output_type,
    );
    self.0.clone().with_columns([itl.alias("itl")])
}
/// Adds a `regime_probs` struct column computed from `close`.
///
/// A new `RegimeProbFeatureExtractor::bull_bear()` is fed the `close` values
/// in row order. The `probs` entries of each result are mapped to the struct
/// fields as follows:
///
/// * `probs[0]` -> `prob_bull`
/// * `probs[1]` -> `prob_bear`
/// * `probs[2]` -> `prob_crisis`
/// * `probs[3]` -> `prob_steady`
/// * `probs[4]` -> `prob_other`
///
/// The struct fields are laid out in the order bull, bear, steady, crisis,
/// other. Null entries in `close` are passed to the extractor as `f64::NAN`.
/// The borrowed frame is cloned, not modified.
pub fn regime_probs(self) -> LazyFrame {
    let output_type = GetOutput::from_type(DataType::Struct(vec![
        Field::new("prob_bull".into(), DataType::Float64),
        Field::new("prob_bear".into(), DataType::Float64),
        Field::new("prob_steady".into(), DataType::Float64),
        Field::new("prob_crisis".into(), DataType::Float64),
        Field::new("prob_other".into(), DataType::Float64),
    ]));
    let regime_probs = col("close").map(
        move |close| {
            let mut model = rust_features::RegimeProbFeatureExtractor::bull_bear();
            let prices = close.f64()?;
            let rows = close.len();
            let mut p_bull = Vec::with_capacity(rows);
            let mut p_bear = Vec::with_capacity(rows);
            let mut p_steady = Vec::with_capacity(rows);
            let mut p_crisis = Vec::with_capacity(rows);
            let mut p_other = Vec::with_capacity(rows);
            for px in prices {
                let feat = model.next(px.unwrap_or(f64::NAN));
                // NOTE(review): index 2 feeds crisis and index 3 feeds steady,
                // while the struct lists steady before crisis. Confirm the
                // index order against RegimeProbFeatureExtractor.
                p_bull.push(feat.probs[0]);
                p_bear.push(feat.probs[1]);
                p_crisis.push(feat.probs[2]);
                p_steady.push(feat.probs[3]);
                p_other.push(feat.probs[4]);
            }
            let fields = [
                Series::new("prob_bull".into(), p_bull),
                Series::new("prob_bear".into(), p_bear),
                Series::new("prob_steady".into(), p_steady),
                Series::new("prob_crisis".into(), p_crisis),
                Series::new("prob_other".into(), p_other),
            ];
            let packed =
                StructChunked::from_series("regime_probs_result".into(), rows, fields.iter())?;
            Ok(Some(Column::from(packed.into_series())))
        },
        output_type,
    );
    self.0
        .clone()
        .with_columns([regime_probs.alias("regime_probs")])
}
/// Adds a `trendflex_{length}` column computed from `close`.
///
/// A new `TrendflexFeatureExtractor::new(length)` is fed the `close` values
/// in row order; the `trendflex` field of each result becomes the output
/// value for that row. Null entries are passed to the extractor as
/// `f64::NAN`. The borrowed frame is cloned, not modified.
pub fn trendflex(self, length: usize) -> LazyFrame {
    let column_name = format!("trendflex_{}", length);
    let trendflex = col("close").map(
        move |close| {
            let mut flex = rust_features::TrendflexFeatureExtractor::new(length);
            let prices: &Float64Chunked = close.f64()?;
            let values: Vec<_> = prices
                .into_iter()
                .map(|px| flex.next(px.unwrap_or(f64::NAN)).trendflex)
                .collect();
            let series = Series::new(format!("trendflex_{}", length).into(), values);
            Ok(Some(Column::from(series)))
        },
        GetOutput::from_type(DataType::Float64),
    );
    self.0
        .clone()
        .with_columns([trendflex.alias(&column_name)])
}
/// Adds an `ehlers_autocorr` struct column computed from `close`.
///
/// A new `EhlersAutocorrelationFeatureExtractor::new(length, num_lags)` is
/// fed the `close` values in row order. The struct carries two fields:
///
/// * `dominant_lag` (UInt32)     <- `dominant_lag`, cast with `as u32`
/// * `max_correlation` (Float64) <- `max_correlation`
///
/// Null entries in `close` are passed to the extractor as `f64::NAN`. The
/// borrowed frame is cloned, not modified.
pub fn ehlers_autocorrelation(self, length: usize, num_lags: usize) -> LazyFrame {
    let output_type = GetOutput::from_type(DataType::Struct(vec![
        Field::new("dominant_lag".into(), DataType::UInt32),
        Field::new("max_correlation".into(), DataType::Float64),
    ]));
    let autocorr = col("close").map(
        move |close| {
            let mut ehlers =
                rust_features::EhlersAutocorrelationFeatureExtractor::new(length, num_lags);
            let prices: &Float64Chunked = close.f64()?;
            let rows = close.len();
            let mut dominant_lag = Vec::with_capacity(rows);
            let mut max_correlation = Vec::with_capacity(rows);
            for px in prices {
                let feat = ehlers.next(px.unwrap_or(f64::NAN));
                // NOTE(review): `as` does not report out-of-range values;
                // presumably the lag is bounded by `num_lags` - confirm.
                dominant_lag.push(feat.dominant_lag as u32);
                max_correlation.push(feat.max_correlation);
            }
            let fields = [
                Series::new("dominant_lag".into(), dominant_lag),
                Series::new("max_correlation".into(), max_correlation),
            ];
            let packed = StructChunked::from_series(
                "ehlers_autocorr_result".into(),
                rows,
                fields.iter(),
            )?;
            Ok(Some(Column::from(packed.into_series())))
        },
        output_type,
    );
    self.0
        .clone()
        .with_columns([autocorr.alias("ehlers_autocorr")])
}
/// Applies every feature method of this namespace with fixed parameters.
///
/// The steps run in this order, each on the frame produced by the previous
/// one:
///
/// 1. `hurst(100)`
/// 2. `cyber_cycle(30)`
/// 3. `griffiths_dominant_cycle(6, 50, 30)`
/// 4. `regime_features()`
/// 5. `instantaneous_trendline()`
/// 6. `regime_probs()`
/// 7. `trendflex(30)`
/// 8. `ehlers_autocorrelation(30, 10)`
pub fn recommended_matrix(self) -> LazyFrame {
    use crate::QuantWaveExt;

    let frame = self.hurst(100);
    let frame = frame.ta().features().cyber_cycle(30);
    let frame = frame.ta().features().griffiths_dominant_cycle(6, 50, 30);
    let frame = frame.ta().features().regime_features();
    let frame = frame.ta().features().instantaneous_trendline();
    let frame = frame.ta().features().regime_probs();
    let frame = frame.ta().features().trendflex(30);
    frame.ta().features().ehlers_autocorrelation(30, 10)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::QuantWaveExt;

    /// Smoke test: runs each feature method once on a 40-row synthetic
    /// `close` series and checks that the expected output column exists
    /// (and, where asserted, has the expected dtype).
    #[test]
    fn smoke_ta_features_surface() -> PolarsResult<()> {
        // Sine wave on top of a linear drift around 100.
        let prices: Vec<f64> = (0..40)
            .map(|i| {
                let t = i as f64;
                100.0 + 3.0 * (t * 0.4).sin() + t * 0.1
            })
            .collect();
        let df = df!["close" => prices]?;

        let out = df.lazy().ta().features().hurst(8).collect()?;
        assert!(out.column("hurst_8").is_ok());
        assert_eq!(out.column("hurst_8")?.dtype(), &DataType::Float64);

        let out = out.lazy().ta().features().cyber_cycle(12).collect()?;
        let expected_cyber_dtype = DataType::Struct(vec![
            Field::new("cycle".into(), DataType::Float64),
            Field::new("trigger".into(), DataType::Float64),
            Field::new("momentum".into(), DataType::Float64),
            Field::new("signal".into(), DataType::Float64),
        ]);
        let cyber = out.column("cyber_cycle")?;
        assert_eq!(cyber.dtype().clone(), expected_cyber_dtype);
        let cyber_fields = cyber.struct_()?;
        let last_cycle = cyber_fields
            .field_by_name("cycle".into())?
            .f64()?
            .get(39);
        assert!(last_cycle.is_some());

        let out = out
            .lazy()
            .ta()
            .features()
            .griffiths_dominant_cycle(6, 40, 25)
            .collect()?;
        assert!(out.column("griffiths_dc").is_ok());
        assert_eq!(out.column("griffiths_dc")?.dtype(), &DataType::Float64);

        let out = out.lazy().ta().features().regime_features().collect()?;
        assert!(out.column("regime_label").is_ok());
        assert_eq!(out.column("regime_label")?.dtype(), &DataType::UInt32);

        let out = out
            .lazy()
            .ta()
            .features()
            .instantaneous_trendline()
            .collect()?;
        assert!(out.column("itl").is_ok());

        let out = out.lazy().ta().features().regime_probs().collect()?;
        assert!(out.column("regime_probs").is_ok());

        let out = out.lazy().ta().features().trendflex(20).collect()?;
        assert!(out.column("trendflex_20").is_ok());

        let out = out
            .lazy()
            .ta()
            .features()
            .ehlers_autocorrelation(30, 10)
            .collect()?;
        assert!(out.column("ehlers_autocorr").is_ok());

        Ok(())
    }
}