1use polars::prelude::*;
2use pyo3_polars::derive::polars_expr;
3use serde::Deserialize;
4use quantwave_core::*;
5use quantwave_core::traits::Next;
6
7#[derive(Deserialize)]
8struct RegimesNextStateProbKwargs {
9 num_states: usize,
10 steps: usize,
11}
12
13pub fn regimes_next_state_prob_output(_: &[Field]) -> PolarsResult<Field> {
14 Ok(Field::new(
15 "next_state_probs".into(),
16 DataType::List(Box::new(DataType::Float64)),
17 ))
18}
19
20#[polars_expr(output_type_func=regimes_next_state_prob_output)]
21fn regimes_next_state_prob(inputs: &[Series], kwargs: RegimesNextStateProbKwargs) -> PolarsResult<Series> {
22 let s = &inputs[0];
23 let ca = s.u32()?;
24 let states: Vec<u32> = ca.into_iter().map(|v| v.unwrap_or(0)).collect();
25 let matrix = quantwave_core::RegimeAnalytics::transition_matrix(&states, kwargs.num_states);
26
27 let mut builders = ListPrimitiveChunkedBuilder::<Float64Type>::new(
28 "next_state_probs".into(),
29 s.len(),
30 s.len() * kwargs.num_states,
31 DataType::Float64,
32 );
33
34 for ¤t in &states {
35 let probs = quantwave_core::RegimeAnalytics::forecast_state(&matrix, current, kwargs.steps);
36 builders.append_slice(&probs);
37 }
38
39 let list_ca = builders.finish();
40 Ok(list_ca.into_series())
41}
42
43#[polars_expr(output_type=UInt32)]
44fn hmm_bull_bear(inputs: &[Series]) -> PolarsResult<Series> {
45 let s = &inputs[0];
46 let ca = s.f64()?;
47 let mut hmm = quantwave_core::regimes::hmm::HMM::bull_bear();
48 let mut values = Vec::with_capacity(s.len());
49
50 for i in 0..s.len() {
51 let val = ca.get(i).unwrap_or(f64::NAN);
52 let regime = hmm.next(val);
53 let out = match regime {
54 quantwave_core::regimes::MarketRegime::Bull => 1u32,
55 quantwave_core::regimes::MarketRegime::Bear => 2,
56 _ => 0,
57 };
58 values.push(out);
59 }
60
61 Ok(Series::new("hmm_regime".into(), values))
62}
63
64#[derive(Deserialize)]
65struct AlmaKwargs {
66 period: usize,
67 offset: f64,
68 sigma: f64,
69}
70
71#[polars_expr(output_type=Float64)]
72fn alma(inputs: &[Series], kwargs: AlmaKwargs) -> PolarsResult<Series> {
73 let s = &inputs[0];
74 let ca = s.f64()?;
75 let mut alma = quantwave_core::ALMA::new(kwargs.period, kwargs.offset, kwargs.sigma);
76 let mut values = Vec::with_capacity(s.len());
77
78 for i in 0..s.len() {
79 let val = ca.get(i).unwrap_or(0.0);
80 values.push(alma.next(val));
81 }
82
83 Ok(Series::new("alma".into(), values))
84}
85
86#[polars_expr(output_type=Float64)]
87fn regimes_stability_score(inputs: &[Series]) -> PolarsResult<Series> {
88 let s = &inputs[0];
89 let ca = s.u32()?;
90 let states: Vec<u32> = ca.into_iter().map(|v| v.unwrap_or(0)).collect();
91 let score = quantwave_core::RegimeAnalytics::stability_score(&states);
92
93 Ok(Series::new("stability_score".into(), vec![score; s.len()]))
94}
95
96#[derive(Deserialize)]
97struct GeometricPatternsKwargs {
98 swing_strength: usize,
99}
100
101pub fn geometric_patterns_output(_: &[Field]) -> PolarsResult<Field> {
102 Ok(Field::new(
103 "geometric_patterns".into(),
104 DataType::Struct(vec![
105 Field::new("flag".into(), DataType::Struct(vec![
106 Field::new("id".into(), DataType::UInt32),
107 Field::new("is_bull".into(), DataType::Boolean),
108 Field::new("pole_length".into(), DataType::Float64),
109 Field::new("pole_length_atr".into(), DataType::Float64),
110 Field::new("breakout_confirmed".into(), DataType::Boolean),
111 Field::new("breakout_price".into(), DataType::Float64),
112 ])),
113 Field::new("hs".into(), DataType::Struct(vec![
114 Field::new("id".into(), DataType::UInt32),
115 Field::new("is_bearish".into(), DataType::Boolean),
116 Field::new("height".into(), DataType::Float64),
117 Field::new("height_atr".into(), DataType::Float64),
118 Field::new("score".into(), DataType::Float64),
119 Field::new("breakout_confirmed".into(), DataType::Boolean),
120 ])),
121 ]),
122 ))
123}
124
125#[polars_expr(output_type_func=geometric_patterns_output)]
126fn geometric_patterns(inputs: &[Series], kwargs: GeometricPatternsKwargs) -> PolarsResult<Series> {
127 let highs = &inputs[0].f64()?;
128 let lows = &inputs[1].f64()?;
129
130 let n = highs.len();
131 let mut scanner = quantwave_core::GeometricPatternScanner::new(kwargs.swing_strength);
132
133 let mut flag_ids: Vec<u32> = Vec::with_capacity(n);
134 let mut flag_is_bull: Vec<bool> = Vec::with_capacity(n);
135 let mut flag_pole_len: Vec<f64> = Vec::with_capacity(n);
136 let mut flag_pole_atr: Vec<f64> = Vec::with_capacity(n);
137 let mut flag_breakout: Vec<bool> = Vec::with_capacity(n);
138 let mut flag_bp: Vec<f64> = Vec::with_capacity(n);
139
140 let mut hs_ids: Vec<u32> = Vec::with_capacity(n);
141 let mut hs_bear: Vec<bool> = Vec::with_capacity(n);
142 let mut hs_height: Vec<f64> = Vec::with_capacity(n);
143 let mut hs_height_atr: Vec<f64> = Vec::with_capacity(n);
144 let mut hs_score: Vec<f64> = Vec::with_capacity(n);
145 let mut hs_breakout: Vec<bool> = Vec::with_capacity(n);
146
147 for i in 0..n {
148 let h = highs.get(i).unwrap_or(f64::NAN);
149 let l = lows.get(i).unwrap_or(f64::NAN);
150 let hh = if h.is_nan() || l.is_nan() { f64::NAN } else { h.max(l) };
151 let ll = if h.is_nan() || l.is_nan() { f64::NAN } else { l.min(h) };
152 let (_state, flag, hs) = scanner.next((hh, ll));
153
154 if let Some(f) = flag {
155 flag_ids.push(f.id);
156 flag_is_bull.push(f.is_bull);
157 flag_pole_len.push(f.pole_length);
158 flag_pole_atr.push(f.pole_length_atr);
159 flag_breakout.push(f.breakout_confirmed);
160 flag_bp.push(f.breakout_price);
161 } else {
162 flag_ids.push(0);
163 flag_is_bull.push(false);
164 flag_pole_len.push(f64::NAN);
165 flag_pole_atr.push(f64::NAN);
166 flag_breakout.push(false);
167 flag_bp.push(f64::NAN);
168 }
169
170 if let Some(hp) = hs {
171 hs_ids.push(hp.id);
172 hs_bear.push(hp.is_bearish);
173 hs_height.push(hp.height);
174 hs_height_atr.push(hp.height_atr);
175 hs_score.push(hp.score);
176 hs_breakout.push(hp.breakout_confirmed);
177 } else {
178 hs_ids.push(0);
179 hs_bear.push(false);
180 hs_height.push(f64::NAN);
181 hs_height_atr.push(f64::NAN);
182 hs_score.push(f64::NAN);
183 hs_breakout.push(false);
184 }
185 }
186
187 let s_fid = Series::new("id".into(), flag_ids);
188 let s_fbull = Series::new("is_bull".into(), flag_is_bull);
189 let s_fplen = Series::new("pole_length".into(), flag_pole_len);
190 let s_fpatr = Series::new("pole_length_atr".into(), flag_pole_atr);
191 let s_fbo = Series::new("breakout_confirmed".into(), flag_breakout);
192 let s_fbp = Series::new("breakout_price".into(), flag_bp);
193
194 let flag_struct = StructChunked::from_series(
195 "flag".into(),
196 n,
197 [s_fid, s_fbull, s_fplen, s_fpatr, s_fbo, s_fbp].iter(),
198 )?;
199
200 let s_hid = Series::new("id".into(), hs_ids);
201 let s_hbear = Series::new("is_bearish".into(), hs_bear);
202 let s_hh = Series::new("height".into(), hs_height);
203 let s_hhatr = Series::new("height_atr".into(), hs_height_atr);
204 let s_hsc = Series::new("score".into(), hs_score);
205 let s_hbo = Series::new("breakout_confirmed".into(), hs_breakout);
206
207 let hs_struct = StructChunked::from_series(
208 "hs".into(),
209 n,
210 [s_hid, s_hbear, s_hh, s_hhatr, s_hsc, s_hbo].iter(),
211 )?;
212
213 let combined = StructChunked::from_series(
214 "geo_patterns".into(),
215 n,
216 [flag_struct.into_series(), hs_struct.into_series()].iter(),
217 )?;
218 Ok(combined.into_series())
219}