Skip to main content

quantwave_plugins/
custom_2.rs

1use polars::prelude::*;
2use pyo3_polars::derive::polars_expr;
3use serde::Deserialize;
4use quantwave_core::*;
5use quantwave_core::traits::Next;
6
7fn u8_to_matype(matype: u8) -> talib::MaType {
8    match matype {
9        0 => talib::MaType::Sma,
10        1 => talib::MaType::Ema,
11        2 => talib::MaType::Wma,
12        3 => talib::MaType::Dema,
13        4 => talib::MaType::Tema,
14        5 => talib::MaType::Trima,
15        6 => talib::MaType::Kama,
16        7 => talib::MaType::Mama,
17        8 => talib::MaType::T3,
18        _ => talib::MaType::Sma,
19    }
20}
21
22// 1. MAVP
23#[derive(Deserialize)]
24struct MavpKwargs {
25    minperiod: usize,
26    maxperiod: usize,
27    matype: u8,
28}
29
30#[polars_expr(output_type=Float64)]
31fn mavp(inputs: &[Series], kwargs: MavpKwargs) -> PolarsResult<Series> {
32    let in1_ca = inputs[0].f64()?;
33    let in2_ca = inputs[1].f64()?;
34
35    let mut indicator = MAVP::new(kwargs.minperiod, kwargs.maxperiod, u8_to_matype(kwargs.matype));
36    let mut values = Vec::with_capacity(in1_ca.len());
37
38    for i in 0..in1_ca.len() {
39        let i1 = in1_ca.get(i).unwrap_or(f64::NAN);
40        let i2 = in2_ca.get(i).unwrap_or(f64::NAN);
41        values.push(indicator.next((i1, i2)));
42    }
43
44    Ok(Series::new("mavp".into(), values))
45}
46
47// 2. MFI
48#[derive(Deserialize)]
49struct MfiKwargs {
50    period: usize,
51}
52
53#[polars_expr(output_type=Float64)]
54fn mfi(inputs: &[Series], kwargs: MfiKwargs) -> PolarsResult<Series> {
55    let high = inputs[0].f64()?;
56    let low = inputs[1].f64()?;
57    let close = inputs[2].f64()?;
58    let volume = inputs[3].f64()?;
59
60    let mut indicator = MFI::new(kwargs.period);
61    let mut values = Vec::with_capacity(high.len());
62
63    for i in 0..high.len() {
64        let h = high.get(i).unwrap_or(f64::NAN);
65        let l = low.get(i).unwrap_or(f64::NAN);
66        let c = close.get(i).unwrap_or(f64::NAN);
67        let v = volume.get(i).unwrap_or(f64::NAN);
68        values.push(indicator.next((h, l, c, v)));
69    }
70
71    Ok(Series::new("mfi".into(), values))
72}
73
74// 3. REVERSE_EMA
75#[derive(Deserialize)]
76struct ReverseEmaKwargs {
77    alpha: f64,
78}
79
80#[polars_expr(output_type=Float64)]
81fn reverse_ema(inputs: &[Series], kwargs: ReverseEmaKwargs) -> PolarsResult<Series> {
82    let ca = inputs[0].f64()?;
83
84    let mut indicator = quantwave_core::ReverseEMA::new(kwargs.alpha);
85    let mut values = Vec::with_capacity(ca.len());
86
87    for i in 0..ca.len() {
88        let val = ca.get(i).unwrap_or(f64::NAN);
89        values.push(indicator.next(val));
90    }
91
92    Ok(Series::new("reverse_ema".into(), values))
93}
94
95// 4. VOLATILITY_CLUSTERER
96#[derive(Deserialize)]
97struct VolatilityClustererKwargs {
98    atr_period: usize,
99    window_size: usize,
100    k: usize,
101}
102
103#[polars_expr(output_type=UInt32)]
104fn volatility_clusterer(inputs: &[Series], kwargs: VolatilityClustererKwargs) -> PolarsResult<Series> {
105    let high = inputs[0].f64()?;
106    let low = inputs[1].f64()?;
107    let close = inputs[2].f64()?;
108
109    let mut clusterer = quantwave_core::regimes::volatility_clustering::VolatilityClusterer::new(
110        kwargs.atr_period,
111        kwargs.window_size,
112        kwargs.k,
113    );
114    let mut values = Vec::with_capacity(high.len());
115
116    for i in 0..high.len() {
117        let h = high.get(i).unwrap_or(f64::NAN);
118        let l = low.get(i).unwrap_or(f64::NAN);
119        let c = close.get(i).unwrap_or(f64::NAN);
120        let regime = clusterer.next((h, l, c));
121        let val = match regime {
122            quantwave_core::regimes::MarketRegime::Steady => 0u32,
123            quantwave_core::regimes::MarketRegime::Bull => 1,
124            quantwave_core::regimes::MarketRegime::Bear => 2,
125            quantwave_core::regimes::MarketRegime::Crisis => 3,
126            quantwave_core::regimes::MarketRegime::Cluster(c) => 4 + (c as u32),
127        };
128        values.push(val);
129    }
130
131    Ok(Series::new("volatility_regime".into(), values))
132}
133
134// 5. PIVOT_POINTS
135pub fn pivot_points_output(_: &[Field]) -> PolarsResult<Field> {
136    Ok(Field::new(
137        "pivot_points".into(),
138        DataType::Struct(vec![
139            Field::new("p".into(), DataType::Float64),
140            Field::new("r1".into(), DataType::Float64),
141            Field::new("s1".into(), DataType::Float64),
142            Field::new("r2".into(), DataType::Float64),
143            Field::new("s2".into(), DataType::Float64),
144        ]),
145    ))
146}
147
148#[polars_expr(output_type_func=pivot_points_output)]
149fn pivot_points(inputs: &[Series]) -> PolarsResult<Series> {
150    let high = inputs[0].f64()?;
151    let low = inputs[1].f64()?;
152    let close = inputs[2].f64()?;
153
154    let mut pivot = quantwave_core::PivotPoints::new();
155    let mut p_vals = Vec::with_capacity(high.len());
156    let mut r1_vals = Vec::with_capacity(high.len());
157    let mut s1_vals = Vec::with_capacity(high.len());
158    let mut r2_vals = Vec::with_capacity(high.len());
159    let mut s2_vals = Vec::with_capacity(high.len());
160
161    for i in 0..high.len() {
162        let h = high.get(i).unwrap_or(0.0);
163        let l = low.get(i).unwrap_or(0.0);
164        let c = close.get(i).unwrap_or(0.0);
165        let (p, r1, s1, r2, s2) = pivot.next((h, l, c));
166        p_vals.push(p);
167        r1_vals.push(r1);
168        s1_vals.push(s1);
169        r2_vals.push(r2);
170        s2_vals.push(s2);
171    }
172
173    let p_series = Series::new("p".into(), p_vals);
174    let r1_series = Series::new("r1".into(), r1_vals);
175    let s1_series = Series::new("s1".into(), s1_vals);
176    let r2_series = Series::new("r2".into(), r2_vals);
177    let s2_series = Series::new("s2".into(), s2_vals);
178
179    let out = StructChunked::from_series(
180        "pivot_output".into(),
181        high.len(),
182        [p_series, r1_series, s1_series, r2_series, s2_series].iter(),
183    )?;
184    Ok(out.into_series())
185}