Skip to main content

quantwave_plugins/
custom_5.rs

1use polars::prelude::*;
2use pyo3_polars::derive::polars_expr;
3use serde::Deserialize;
4use quantwave_core::*;
5use quantwave_core::traits::Next;
6
7#[derive(Deserialize)]
8pub struct TaVarKwargs {
9    pub period: usize,
10    pub nbdev: f64,
11}
12
13#[polars_expr(output_type=Float64)]
14fn ta_var(inputs: &[Series], kwargs: TaVarKwargs) -> PolarsResult<Series> {
15    let s = &inputs[0];
16    let ca = s.f64()?;
17    let mut indicator = TaVAR::new(kwargs.period, kwargs.nbdev);
18    let mut values = Vec::with_capacity(s.len());
19    for i in 0..s.len() {
20        let val = ca.get(i).unwrap_or(f64::NAN);
21        values.push(indicator.next(val));
22    }
23    Ok(Series::new("ta_var".into(), values))
24}
25
26#[derive(Deserialize)]
27pub struct VfiKwargs {
28    pub period: usize,
29    pub coef: f64,
30    pub vcoef: f64,
31    pub smoothing_period: usize,
32}
33
34#[polars_expr(output_type=Float64)]
35fn vfi(inputs: &[Series], kwargs: VfiKwargs) -> PolarsResult<Series> {
36    let high = inputs[0].f64()?;
37    let low = inputs[1].f64()?;
38    let close = inputs[2].f64()?;
39    let volume = inputs[3].f64()?;
40
41    let mut indicator = quantwave_core::Vfi::new(
42        kwargs.period,
43        kwargs.coef,
44        kwargs.vcoef,
45        kwargs.smoothing_period,
46    );
47    let mut values = Vec::with_capacity(high.len());
48
49    for i in 0..high.len() {
50        let h = high.get(i).unwrap_or(f64::NAN);
51        let l = low.get(i).unwrap_or(f64::NAN);
52        let c = close.get(i).unwrap_or(f64::NAN);
53        let v = volume.get(i).unwrap_or(f64::NAN);
54        values.push(indicator.next((h, l, c, v)));
55    }
56
57    Ok(Series::new("vfi".into(), values))
58}
59
60#[derive(Deserialize)]
61pub struct WaveTrendKwargs {
62    pub n1: usize,
63    pub n2: usize,
64    pub n3: usize,
65}
66
67fn wavetrend_output(_input_fields: &[Field]) -> PolarsResult<Field> {
68    Ok(Field::new(
69        "wavetrend".into(),
70        DataType::Struct(vec![
71            Field::new("wt1".into(), DataType::Float64),
72            Field::new("wt2".into(), DataType::Float64),
73        ]),
74    ))
75}
76
77#[polars_expr(output_type_func=wavetrend_output)]
78fn wavetrend(inputs: &[Series], kwargs: WaveTrendKwargs) -> PolarsResult<Series> {
79    let high = inputs[0].f64()?;
80    let low = inputs[1].f64()?;
81    let close = inputs[2].f64()?;
82
83    let mut wt = quantwave_core::WaveTrend::new(kwargs.n1, kwargs.n2, kwargs.n3);
84    let mut wt1_vals = Vec::with_capacity(high.len());
85    let mut wt2_vals = Vec::with_capacity(high.len());
86
87    for i in 0..high.len() {
88        let h = high.get(i).unwrap_or(0.0);
89        let l = low.get(i).unwrap_or(0.0);
90        let c = close.get(i).unwrap_or(0.0);
91        let (wt1, wt2) = wt.next((h, l, c));
92        wt1_vals.push(wt1);
93        wt2_vals.push(wt2);
94    }
95
96    let wt1_series = Series::new("wt1".into(), wt1_vals);
97    let wt2_series = Series::new("wt2".into(), wt2_vals);
98
99    let out = StructChunked::from_series(
100        "wavetrend".into(),
101        high.len(),
102        [wt1_series, wt2_series].iter(),
103    )?;
104    Ok(out.into_series())
105}
106
107#[derive(Deserialize)]
108pub struct RegimesEnsembleKwargs {
109    pub weights: Vec<f64>,
110}
111
112#[polars_expr(output_type=UInt32)]
113fn regimes_ensemble(inputs: &[Series], kwargs: RegimesEnsembleKwargs) -> PolarsResult<Series> {
114    let n_rows = inputs[0].len();
115    let n_dims = inputs.len();
116    let ensemble = quantwave_core::regimes::ensemble::RegimeEnsemble::new(kwargs.weights);
117
118    let mut cas = Vec::with_capacity(n_dims);
119    for s in inputs {
120        cas.push(s.u32()?);
121    }
122
123    let mut results = Vec::with_capacity(n_rows);
124    for i in 0..n_rows {
125        let mut row_regimes = Vec::with_capacity(n_dims);
126        for ca in &cas {
127            let val = ca.get(i).unwrap_or(0);
128            let regime = match val {
129                0 => quantwave_core::regimes::MarketRegime::Steady,
130                1 => quantwave_core::regimes::MarketRegime::Crisis,
131                2 => quantwave_core::regimes::MarketRegime::Bull,
132                3 => quantwave_core::regimes::MarketRegime::Bear,
133                v if v >= 4 => quantwave_core::regimes::MarketRegime::Cluster((v - 4) as u8),
134                _ => quantwave_core::regimes::MarketRegime::Steady,
135            };
136            row_regimes.push(regime);
137        }
138
139        let consensus = ensemble.vote(&row_regimes);
140        let out = match consensus {
141            quantwave_core::regimes::MarketRegime::Steady => 0u32,
142            quantwave_core::regimes::MarketRegime::Crisis => 1,
143            quantwave_core::regimes::MarketRegime::Bull => 2,
144            quantwave_core::regimes::MarketRegime::Bear => 3,
145            quantwave_core::regimes::MarketRegime::Cluster(c) => 4 + (c as u32),
146        };
147        results.push(out);
148    }
149
150    Ok(Series::new("ensemble_regime".into(), results))
151}
152
153#[derive(Deserialize)]
154pub struct TaStddevKwargs {
155    pub period: usize,
156    pub nbdev: f64,
157}
158
159#[polars_expr(output_type=Float64)]
160fn ta_stddev(inputs: &[Series], kwargs: TaStddevKwargs) -> PolarsResult<Series> {
161    let s = &inputs[0];
162    let ca = s.f64()?;
163    let mut indicator = TaSTDDEV::new(kwargs.period, kwargs.nbdev);
164    let mut values = Vec::with_capacity(s.len());
165    for i in 0..s.len() {
166        let val = ca.get(i).unwrap_or(f64::NAN);
167        values.push(indicator.next(val));
168    }
169    Ok(Series::new("ta_stddev".into(), values))
170}