Skip to main content

quantwave_plugins/
custom_4.rs

1use polars::prelude::*;
2use pyo3_polars::derive::polars_expr;
3use serde::Deserialize;
4use quantwave_core::*;
5use quantwave_core::traits::Next;
6
7#[derive(Deserialize)]
8struct RegimesNextStateProbKwargs {
9    num_states: usize,
10    steps: usize,
11}
12
13pub fn regimes_next_state_prob_output(_: &[Field]) -> PolarsResult<Field> {
14    Ok(Field::new(
15        "next_state_probs".into(),
16        DataType::List(Box::new(DataType::Float64)),
17    ))
18}
19
20#[polars_expr(output_type_func=regimes_next_state_prob_output)]
21fn regimes_next_state_prob(inputs: &[Series], kwargs: RegimesNextStateProbKwargs) -> PolarsResult<Series> {
22    let s = &inputs[0];
23    let ca = s.u32()?;
24    let states: Vec<u32> = ca.into_iter().map(|v| v.unwrap_or(0)).collect();
25    let matrix = quantwave_core::RegimeAnalytics::transition_matrix(&states, kwargs.num_states);
26    
27    let mut builders = ListPrimitiveChunkedBuilder::<Float64Type>::new(
28        "next_state_probs".into(),
29        s.len(),
30        s.len() * kwargs.num_states,
31        DataType::Float64,
32    );
33
34    for &current in &states {
35        let probs = quantwave_core::RegimeAnalytics::forecast_state(&matrix, current, kwargs.steps);
36        builders.append_slice(&probs);
37    }
38    
39    let list_ca = builders.finish();
40    Ok(list_ca.into_series())
41}
42
43#[polars_expr(output_type=UInt32)]
44fn hmm_bull_bear(inputs: &[Series]) -> PolarsResult<Series> {
45    let s = &inputs[0];
46    let ca = s.f64()?;
47    let mut hmm = quantwave_core::regimes::hmm::HMM::bull_bear();
48    let mut values = Vec::with_capacity(s.len());
49
50    for i in 0..s.len() {
51        let val = ca.get(i).unwrap_or(f64::NAN);
52        let regime = hmm.next(val);
53        let out = match regime {
54            quantwave_core::regimes::MarketRegime::Bull => 1u32,
55            quantwave_core::regimes::MarketRegime::Bear => 2,
56            _ => 0,
57        };
58        values.push(out);
59    }
60
61    Ok(Series::new("hmm_regime".into(), values))
62}
63
64#[derive(Deserialize)]
65struct AlmaKwargs {
66    period: usize,
67    offset: f64,
68    sigma: f64,
69}
70
71#[polars_expr(output_type=Float64)]
72fn alma(inputs: &[Series], kwargs: AlmaKwargs) -> PolarsResult<Series> {
73    let s = &inputs[0];
74    let ca = s.f64()?;
75    let mut alma = quantwave_core::ALMA::new(kwargs.period, kwargs.offset, kwargs.sigma);
76    let mut values = Vec::with_capacity(s.len());
77
78    for i in 0..s.len() {
79        let val = ca.get(i).unwrap_or(0.0);
80        values.push(alma.next(val));
81    }
82
83    Ok(Series::new("alma".into(), values))
84}
85
86#[polars_expr(output_type=Float64)]
87fn regimes_stability_score(inputs: &[Series]) -> PolarsResult<Series> {
88    let s = &inputs[0];
89    let ca = s.u32()?;
90    let states: Vec<u32> = ca.into_iter().map(|v| v.unwrap_or(0)).collect();
91    let score = quantwave_core::RegimeAnalytics::stability_score(&states);
92    
93    Ok(Series::new("stability_score".into(), vec![score; s.len()]))
94}
95
96#[derive(Deserialize)]
97struct GeometricPatternsKwargs {
98    swing_strength: usize,
99}
100
101pub fn geometric_patterns_output(_: &[Field]) -> PolarsResult<Field> {
102    Ok(Field::new(
103        "geometric_patterns".into(),
104        DataType::Struct(vec![
105            Field::new("flag".into(), DataType::Struct(vec![
106                Field::new("id".into(), DataType::UInt32),
107                Field::new("is_bull".into(), DataType::Boolean),
108                Field::new("pole_length".into(), DataType::Float64),
109                Field::new("pole_length_atr".into(), DataType::Float64),
110                Field::new("breakout_confirmed".into(), DataType::Boolean),
111                Field::new("breakout_price".into(), DataType::Float64),
112            ])),
113            Field::new("hs".into(), DataType::Struct(vec![
114                Field::new("id".into(), DataType::UInt32),
115                Field::new("is_bearish".into(), DataType::Boolean),
116                Field::new("height".into(), DataType::Float64),
117                Field::new("height_atr".into(), DataType::Float64),
118                Field::new("score".into(), DataType::Float64),
119                Field::new("breakout_confirmed".into(), DataType::Boolean),
120            ])),
121        ]),
122    ))
123}
124
125#[polars_expr(output_type_func=geometric_patterns_output)]
126fn geometric_patterns(inputs: &[Series], kwargs: GeometricPatternsKwargs) -> PolarsResult<Series> {
127    let highs = &inputs[0].f64()?;
128    let lows = &inputs[1].f64()?;
129    
130    let n = highs.len();
131    let mut scanner = quantwave_core::GeometricPatternScanner::new(kwargs.swing_strength);
132
133    let mut flag_ids: Vec<u32> = Vec::with_capacity(n);
134    let mut flag_is_bull: Vec<bool> = Vec::with_capacity(n);
135    let mut flag_pole_len: Vec<f64> = Vec::with_capacity(n);
136    let mut flag_pole_atr: Vec<f64> = Vec::with_capacity(n);
137    let mut flag_breakout: Vec<bool> = Vec::with_capacity(n);
138    let mut flag_bp: Vec<f64> = Vec::with_capacity(n);
139
140    let mut hs_ids: Vec<u32> = Vec::with_capacity(n);
141    let mut hs_bear: Vec<bool> = Vec::with_capacity(n);
142    let mut hs_height: Vec<f64> = Vec::with_capacity(n);
143    let mut hs_height_atr: Vec<f64> = Vec::with_capacity(n);
144    let mut hs_score: Vec<f64> = Vec::with_capacity(n);
145    let mut hs_breakout: Vec<bool> = Vec::with_capacity(n);
146
147    for i in 0..n {
148        let h = highs.get(i).unwrap_or(f64::NAN);
149        let l = lows.get(i).unwrap_or(f64::NAN);
150        let hh = if h.is_nan() || l.is_nan() { f64::NAN } else { h.max(l) };
151        let ll = if h.is_nan() || l.is_nan() { f64::NAN } else { l.min(h) };
152        let (_state, flag, hs) = scanner.next((hh, ll));
153
154        if let Some(f) = flag {
155            flag_ids.push(f.id);
156            flag_is_bull.push(f.is_bull);
157            flag_pole_len.push(f.pole_length);
158            flag_pole_atr.push(f.pole_length_atr);
159            flag_breakout.push(f.breakout_confirmed);
160            flag_bp.push(f.breakout_price);
161        } else {
162            flag_ids.push(0);
163            flag_is_bull.push(false);
164            flag_pole_len.push(f64::NAN);
165            flag_pole_atr.push(f64::NAN);
166            flag_breakout.push(false);
167            flag_bp.push(f64::NAN);
168        }
169
170        if let Some(hp) = hs {
171            hs_ids.push(hp.id);
172            hs_bear.push(hp.is_bearish);
173            hs_height.push(hp.height);
174            hs_height_atr.push(hp.height_atr);
175            hs_score.push(hp.score);
176            hs_breakout.push(hp.breakout_confirmed);
177        } else {
178            hs_ids.push(0);
179            hs_bear.push(false);
180            hs_height.push(f64::NAN);
181            hs_height_atr.push(f64::NAN);
182            hs_score.push(f64::NAN);
183            hs_breakout.push(false);
184        }
185    }
186
187    let s_fid = Series::new("id".into(), flag_ids);
188    let s_fbull = Series::new("is_bull".into(), flag_is_bull);
189    let s_fplen = Series::new("pole_length".into(), flag_pole_len);
190    let s_fpatr = Series::new("pole_length_atr".into(), flag_pole_atr);
191    let s_fbo = Series::new("breakout_confirmed".into(), flag_breakout);
192    let s_fbp = Series::new("breakout_price".into(), flag_bp);
193
194    let flag_struct = StructChunked::from_series(
195        "flag".into(),
196        n,
197        [s_fid, s_fbull, s_fplen, s_fpatr, s_fbo, s_fbp].iter(),
198    )?;
199
200    let s_hid = Series::new("id".into(), hs_ids);
201    let s_hbear = Series::new("is_bearish".into(), hs_bear);
202    let s_hh = Series::new("height".into(), hs_height);
203    let s_hhatr = Series::new("height_atr".into(), hs_height_atr);
204    let s_hsc = Series::new("score".into(), hs_score);
205    let s_hbo = Series::new("breakout_confirmed".into(), hs_breakout);
206
207    let hs_struct = StructChunked::from_series(
208        "hs".into(),
209        n,
210        [s_hid, s_hbear, s_hh, s_hhatr, s_hsc, s_hbo].iter(),
211    )?;
212
213    let combined = StructChunked::from_series(
214        "geo_patterns".into(),
215        n,
216        [flag_struct.into_series(), hs_struct.into_series()].iter(),
217    )?;
218    Ok(combined.into_series())
219}