Skip to main content

quantwave_plugins/
custom_9.rs

1use polars::prelude::*;
2use pyo3_polars::derive::polars_expr;
3use serde::Deserialize;
4use quantwave_core::*;
5use quantwave_core::traits::Next;
6
7// 1. sdo
8#[derive(Deserialize)]
9struct SdoKwargs {
10    lookback_period: usize,
11    period: usize,
12    ema_pds: usize,
13}
14
15fn sdo_output(_: &[Field]) -> PolarsResult<Field> {
16    Ok(Field::new("sdo".into(), DataType::Float64))
17}
18
19#[polars_expr(output_type_func=sdo_output)]
20fn sdo(inputs: &[Series], kwargs: SdoKwargs) -> PolarsResult<Series> {
21    let s = inputs[0].f64()?;
22    let mut indicator = quantwave_core::SDO::new(kwargs.lookback_period, kwargs.period, kwargs.ema_pds);
23    let mut values = Vec::with_capacity(s.len());
24
25    for i in 0..s.len() {
26        let val = s.get(i).unwrap_or(f64::NAN);
27        if val.is_nan() {
28            values.push(f64::NAN);
29        } else {
30            values.push(indicator.next(val));
31        }
32    }
33
34    Ok(Float64Chunked::from_slice("sdo".into(), &values).into_series())
35}
36
37// 2. regimes_tar
38#[derive(Deserialize)]
39struct RegimesTarKwargs {
40    thresholds: Vec<f64>,
41}
42
43fn regimes_tar_output(_: &[Field]) -> PolarsResult<Field> {
44    Ok(Field::new("tar_regime".into(), DataType::UInt32))
45}
46
47#[polars_expr(output_type_func=regimes_tar_output)]
48fn regimes_tar(inputs: &[Series], kwargs: RegimesTarKwargs) -> PolarsResult<Series> {
49    let s = inputs[0].f64()?;
50    let mut model = quantwave_core::regimes::tar::TAR::multi(kwargs.thresholds.clone());
51    let mut results = Vec::with_capacity(s.len());
52
53    for i in 0..s.len() {
54        let val = s.get(i).unwrap_or(f64::NAN);
55        // Even for NaN, we pass it to the model.next to keep behavior identical
56        // to the user's closure if they didn't check for NaN explicitly.
57        let regime = model.next(val);
58        let out = match regime {
59            quantwave_core::regimes::MarketRegime::Steady => 0u32,
60            quantwave_core::regimes::MarketRegime::Crisis => 1,
61            quantwave_core::regimes::MarketRegime::Bull => 2,
62            quantwave_core::regimes::MarketRegime::Bear => 3,
63            quantwave_core::regimes::MarketRegime::Cluster(c) => 4 + (c as u32),
64        };
65        results.push(out);
66    }
67
68    Ok(UInt32Chunked::from_slice("tar_regime".into(), &results).into_series())
69}
70
71// 3. ichimoku_cloud
72#[derive(Deserialize)]
73struct IchimokuCloudKwargs {
74    p1: usize,
75    p2: usize,
76    p3: usize,
77}
78
79fn ichimoku_cloud_output(_: &[Field]) -> PolarsResult<Field> {
80    Ok(Field::new(
81        "ichimoku_output".into(),
82        DataType::Struct(vec![
83            Field::new("tenkan".into(), DataType::Float64),
84            Field::new("kijun".into(), DataType::Float64),
85            Field::new("senkou_a".into(), DataType::Float64),
86            Field::new("senkou_b".into(), DataType::Float64),
87        ]),
88    ))
89}
90
91#[polars_expr(output_type_func=ichimoku_cloud_output)]
92fn ichimoku_cloud(inputs: &[Series], kwargs: IchimokuCloudKwargs) -> PolarsResult<Series> {
93    let high = inputs[0].f64()?;
94    let low = inputs[1].f64()?;
95
96    let mut ic = quantwave_core::IchimokuCloud::new(kwargs.p1, kwargs.p2, kwargs.p3);
97    let mut t_vals = Vec::with_capacity(high.len());
98    let mut k_vals = Vec::with_capacity(high.len());
99    let mut sa_vals = Vec::with_capacity(high.len());
100    let mut sb_vals = Vec::with_capacity(high.len());
101
102    for i in 0..high.len() {
103        let h = high.get(i).unwrap_or(f64::NAN);
104        let l = low.get(i).unwrap_or(f64::NAN);
105        
106        if h.is_nan() || l.is_nan() {
107            t_vals.push(Some(f64::NAN));
108            k_vals.push(Some(f64::NAN));
109            sa_vals.push(Some(f64::NAN));
110            sb_vals.push(Some(f64::NAN));
111        } else {
112            let (t, k, sa, sb) = ic.next((h, l));
113            t_vals.push(Some(t));
114            k_vals.push(Some(k));
115            sa_vals.push(Some(sa));
116            sb_vals.push(Some(sb));
117        }
118    }
119
120    let t_series = Float64Chunked::new("tenkan".into(), t_vals).into_series();
121    let k_series = Float64Chunked::new("kijun".into(), k_vals).into_series();
122    let sa_series = Float64Chunked::new("senkou_a".into(), sa_vals).into_series();
123    let sb_series = Float64Chunked::new("senkou_b".into(), sb_vals).into_series();
124
125    let out = StructChunked::from_series(
126        "ichimoku_output".into(),
127        high.len(),
128        [t_series, k_series, sa_series, sb_series].iter(),
129    )?;
130
131    Ok(out.into_series())
132}
133
134// 4. mama
135#[derive(Deserialize)]
136struct MamaKwargs {
137    fastlimit: f64,
138    slowlimit: f64,
139}
140
141fn mama_output(_: &[Field]) -> PolarsResult<Field> {
142    Ok(Field::new(
143        "mama_result".into(),
144        DataType::Struct(vec![
145            Field::new("mama".into(), DataType::Float64),
146            Field::new("fama".into(), DataType::Float64),
147        ]),
148    ))
149}
150
151#[polars_expr(output_type_func=mama_output)]
152fn mama(inputs: &[Series], kwargs: MamaKwargs) -> PolarsResult<Series> {
153    let s = inputs[0].f64()?;
154    let mut indicator = MAMA::new(kwargs.fastlimit, kwargs.slowlimit);
155    let mut mama_vals = Vec::with_capacity(s.len());
156    let mut fama_vals = Vec::with_capacity(s.len());
157
158    for i in 0..s.len() {
159        let val = s.get(i).unwrap_or(f64::NAN);
160        if val.is_nan() {
161            mama_vals.push(Some(f64::NAN));
162            fama_vals.push(Some(f64::NAN));
163        } else {
164            let (m, f) = indicator.next(val);
165            mama_vals.push(Some(m));
166            fama_vals.push(Some(f));
167        }
168    }
169
170    let s_mama = Float64Chunked::new("mama".into(), mama_vals).into_series();
171    let s_fama = Float64Chunked::new("fama".into(), fama_vals).into_series();
172
173    let out = StructChunked::from_series(
174        "mama_result".into(),
175        s.len(),
176        [s_mama, s_fama].iter(),
177    )?;
178
179    Ok(out.into_series())
180}
181
182// 5. atr_trailing_stop
183#[derive(Deserialize)]
184struct AtrTrailingStopKwargs {
185    period: usize,
186    multiplier: f64,
187}
188
189fn atr_trailing_stop_output(_: &[Field]) -> PolarsResult<Field> {
190    Ok(Field::new(
191        "atr_ts_output".into(),
192        DataType::Struct(vec![
193            Field::new("stop".into(), DataType::Float64),
194            Field::new("direction".into(), DataType::Float64),
195        ]),
196    ))
197}
198
199#[polars_expr(output_type_func=atr_trailing_stop_output)]
200fn atr_trailing_stop(inputs: &[Series], kwargs: AtrTrailingStopKwargs) -> PolarsResult<Series> {
201    let high = inputs[0].f64()?;
202    let low = inputs[1].f64()?;
203    let close = inputs[2].f64()?;
204
205    let mut atr_ts = quantwave_core::ATRTrailingStop::new(kwargs.period, kwargs.multiplier);
206    let mut stops = Vec::with_capacity(high.len());
207    let mut directions = Vec::with_capacity(high.len());
208
209    for i in 0..high.len() {
210        let h = high.get(i).unwrap_or(f64::NAN);
211        let l = low.get(i).unwrap_or(f64::NAN);
212        let c = close.get(i).unwrap_or(f64::NAN);
213
214        if h.is_nan() || l.is_nan() || c.is_nan() {
215            stops.push(Some(f64::NAN));
216            directions.push(Some(f64::NAN));
217        } else {
218            let (stop, dir) = atr_ts.next((h, l, c));
219            stops.push(Some(stop));
220            directions.push(Some(dir as f64));
221        }
222    }
223
224    let stop_series = Float64Chunked::new("stop".into(), stops).into_series();
225    let dir_series = Float64Chunked::new("direction".into(), directions).into_series();
226
227    let out = StructChunked::from_series(
228        "atr_ts_output".into(),
229        high.len(),
230        [stop_series, dir_series].iter(),
231    )?;
232
233    Ok(out.into_series())
234}