use polars::prelude::*;
use quantwave_core::features::{self as rust_features};
use quantwave_core::traits::Next;
use crate::QuantWaveNamespace;
/// Namespace wrapper around a borrowed [`LazyFrame`] that exposes the
/// feature-extraction methods defined in this file (`hurst`, `cyber_cycle`,
/// `griffiths_dominant_cycle`, `regime_features`).
///
/// Each method clones the wrapped frame and appends a column computed from
/// its `close` column.
pub struct TaFeaturesNamespace<'a>(pub(crate) &'a LazyFrame);
impl<'a> QuantWaveNamespace<'a> {
/// Enters the feature-extraction namespace.
///
/// Consumes this namespace and hands its inner frame reference (`self.0`)
/// to a new [`TaFeaturesNamespace`].
pub fn features(self) -> TaFeaturesNamespace<'a> {
TaFeaturesNamespace(self.0)
}
}
impl<'a> TaFeaturesNamespace<'a> {
    /// Appends a Float64 column named `hurst_{period}` to a clone of the
    /// wrapped frame.
    ///
    /// The `close` column is fed, in row order, through a
    /// `HurstFeatureExtractor` constructed with `period`; the `persistence`
    /// field of each result becomes the output value. Null inputs are passed
    /// to the extractor as `NaN`.
    ///
    /// # Errors
    ///
    /// The expression yields a Polars error when `close` is not a Float64
    /// column (the `f64()` downcast fails).
    pub fn hurst(self, period: usize) -> LazyFrame {
        let hurst_expr = col("close")
            .map(
                move |close| {
                    let mut extractor = rust_features::HurstFeatureExtractor::new(period);
                    let prices: &Float64Chunked = close.f64()?;
                    // Nulls are handed to the extractor as NaN rather than skipped.
                    let persistence: Vec<_> = prices
                        .iter()
                        .map(|price| extractor.next(price.unwrap_or(f64::NAN)).persistence)
                        .collect();
                    let series = Series::new(format!("hurst_{}", period).into(), persistence);
                    Ok(Some(Column::from(series)))
                },
                GetOutput::from_type(DataType::Float64),
            )
            .alias(&format!("hurst_{}", period));

        self.0.clone().with_columns([hurst_expr])
    }

    /// Appends a struct column named `cyber_cycle` to a clone of the wrapped
    /// frame.
    ///
    /// The `close` column is fed, in row order, through a
    /// `CyberCycleFeatureExtractor` constructed with `length`. The struct has
    /// four Float64 fields taken from each result:
    ///
    /// * `cycle`    <- `cycle`
    /// * `trigger`  <- `trigger`
    /// * `momentum` <- `cycle_momentum`
    /// * `signal`   <- `trigger_signal`
    ///
    /// Null inputs are passed to the extractor as `NaN`.
    ///
    /// # Errors
    ///
    /// The expression yields a Polars error when `close` is not a Float64
    /// column or when the struct cannot be assembled from the field series.
    pub fn cyber_cycle(self, length: usize) -> LazyFrame {
        let cyber_cycle_expr = col("close")
            .map(
                move |close| {
                    let mut extractor = rust_features::CyberCycleFeatureExtractor::new(length);
                    let prices: &Float64Chunked = close.f64()?;
                    let rows = close.len();

                    let mut cycle = Vec::with_capacity(rows);
                    let mut trigger = Vec::with_capacity(rows);
                    let mut momentum = Vec::with_capacity(rows);
                    let mut signal = Vec::with_capacity(rows);
                    for price in prices.iter() {
                        let feature = extractor.next(price.unwrap_or(f64::NAN));
                        cycle.push(feature.cycle);
                        trigger.push(feature.trigger);
                        momentum.push(feature.cycle_momentum);
                        signal.push(feature.trigger_signal);
                    }

                    // Field order here must match the declared output schema below.
                    let fields = [
                        Series::new("cycle".into(), cycle),
                        Series::new("trigger".into(), trigger),
                        Series::new("momentum".into(), momentum),
                        Series::new("signal".into(), signal),
                    ];
                    let packed = StructChunked::from_series(
                        "cyber_cycle_result".into(),
                        rows,
                        fields.iter(),
                    )?;
                    Ok(Some(Column::from(packed.into_series())))
                },
                GetOutput::from_type(DataType::Struct(vec![
                    Field::new("cycle".into(), DataType::Float64),
                    Field::new("trigger".into(), DataType::Float64),
                    Field::new("momentum".into(), DataType::Float64),
                    Field::new("signal".into(), DataType::Float64),
                ])),
            )
            .alias("cyber_cycle");

        self.0.clone().with_columns([cyber_cycle_expr])
    }

    /// Appends a Float64 column named `griffiths_dc` to a clone of the wrapped
    /// frame.
    ///
    /// The `close` column is fed, in row order, through a
    /// `GriffithsDominantCycleFeatureExtractor` constructed with
    /// `(lower, upper, length)`; the `dominant_cycle` field of each result
    /// becomes the output value. Null inputs are passed to the extractor as
    /// `NaN`.
    ///
    /// NOTE(review): `lower`/`upper` presumably bound the searched cycle
    /// period and `length` is a window size -- confirm against the extractor.
    ///
    /// # Errors
    ///
    /// The expression yields a Polars error when `close` is not a Float64
    /// column.
    pub fn griffiths_dominant_cycle(self, lower: usize, upper: usize, length: usize) -> LazyFrame {
        let dominant_cycle_expr = col("close")
            .map(
                move |close| {
                    let mut extractor = rust_features::GriffithsDominantCycleFeatureExtractor::new(
                        lower, upper, length,
                    );
                    let prices: &Float64Chunked = close.f64()?;
                    let dominant_cycle: Vec<_> = prices
                        .iter()
                        .map(|price| extractor.next(price.unwrap_or(f64::NAN)).dominant_cycle)
                        .collect();
                    let series = Series::new("griffiths_dc".into(), dominant_cycle);
                    Ok(Some(Column::from(series)))
                },
                GetOutput::from_type(DataType::Float64),
            )
            .alias("griffiths_dc");

        self.0.clone().with_columns([dominant_cycle_expr])
    }

    /// Appends a UInt32 column named `regime_label` to a clone of the wrapped
    /// frame.
    ///
    /// The `close` column is fed, in row order, through the HMM returned by
    /// `HMM::bull_bear()`, and each resulting regime is encoded as:
    ///
    /// | regime       | label   |
    /// |--------------|---------|
    /// | `Steady`     | `0`     |
    /// | `Bull`       | `1`     |
    /// | `Bear`       | `2`     |
    /// | `Crisis`     | `3`     |
    /// | `Cluster(c)` | `4 + c` |
    ///
    /// Null or `NaN` inputs are labelled `Steady` (`0`) without being passed
    /// to the HMM, so they do not advance its state.
    ///
    /// # Errors
    ///
    /// The expression yields a Polars error when `close` is not a Float64
    /// column.
    pub fn regime_features(self) -> LazyFrame {
        use quantwave_core::regimes::MarketRegime;

        let regime_expr = col("close")
            .map(
                move |close| {
                    let mut hmm = quantwave_core::regimes::hmm::HMM::bull_bear();
                    let prices = close.f64()?;
                    let labels: Vec<u32> = prices
                        .iter()
                        .map(|price| {
                            // Missing and NaN observations bypass the model entirely.
                            let regime = match price {
                                Some(value) if !value.is_nan() => hmm.next(value),
                                _ => MarketRegime::Steady,
                            };
                            match regime {
                                MarketRegime::Steady => 0,
                                MarketRegime::Bull => 1,
                                MarketRegime::Bear => 2,
                                MarketRegime::Crisis => 3,
                                MarketRegime::Cluster(c) => 4 + (c as u32),
                            }
                        })
                        .collect();
                    Ok(Some(Column::from(Series::new("regime_label".into(), labels))))
                },
                GetOutput::from_type(DataType::UInt32),
            )
            .alias("regime_label");

        self.0.clone().with_columns([regime_expr])
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::QuantWaveExt;

    /// Smoke test for the whole feature namespace: chains every method on a
    /// 40-row synthetic `close` series and checks that each output column
    /// exists with its declared dtype, and that earlier columns survive the
    /// later calls.
    #[test]
    fn smoke_ta_features_surface() -> PolarsResult<()> {
        // Sine wave on a slow upward drift, 40 samples.
        let prices: Vec<f64> = (0..40)
            .map(|i| 100.0 + 3.0 * (i as f64 * 0.4).sin() + (i as f64) * 0.1)
            .collect();
        let lf = df!["close" => prices]?.lazy();

        // Hurst persistence.
        let frame = lf.clone().ta().features().hurst(8).collect()?;
        assert!(frame.column("hurst_8").is_ok());
        assert_eq!(frame.column("hurst_8")?.dtype(), &DataType::Float64);

        // Cyber cycle struct column.
        let frame = frame.lazy().ta().features().cyber_cycle(12).collect()?;
        let expected_struct = DataType::Struct(vec![
            Field::new("cycle".into(), DataType::Float64),
            Field::new("trigger".into(), DataType::Float64),
            Field::new("momentum".into(), DataType::Float64),
            Field::new("signal".into(), DataType::Float64),
        ]);
        let cyber_cycle = frame.column("cyber_cycle")?;
        assert_eq!(cyber_cycle.dtype(), &expected_struct);
        let cycle_field = cyber_cycle.struct_()?.field_by_name("cycle".into())?;
        let last_cycle = cycle_field.f64()?.get(39);
        assert!(last_cycle.is_some());

        // Griffiths dominant cycle.
        let frame = frame
            .lazy()
            .ta()
            .features()
            .griffiths_dominant_cycle(6, 40, 25)
            .collect()?;
        assert!(frame.column("griffiths_dc").is_ok());
        assert_eq!(frame.column("griffiths_dc")?.dtype(), &DataType::Float64);

        // Regime labels.
        let frame = frame.lazy().ta().features().regime_features().collect()?;
        assert!(frame.column("regime_label").is_ok());
        assert_eq!(frame.column("regime_label")?.dtype(), &DataType::UInt32);

        // Every column added along the way must still be present at the end.
        for name in ["hurst_8", "cyber_cycle", "griffiths_dc", "regime_label"] {
            assert!(frame.column(name).is_ok());
        }
        Ok(())
    }
}