1use polars::prelude::*;
2use pyo3_polars::derive::polars_expr;
3use serde::Deserialize;
4
5use quantwave_core::traits::Next;
6use quantwave_core::SAREXT;
7use quantwave_core::KeltnerChannels;
8use quantwave_core::RegimeAnalytics;
9use quantwave_core::MarketStructure;
10use quantwave_core::regimes::hmm_gas::HMMGAS;
11use quantwave_core::regimes::MarketRegime;
12use quantwave_core::Bias;
13
14#[derive(Deserialize)]
15pub struct SarExtKwargs {
16 startvalue: f64,
17 offsetonreverse: f64,
18 accelerationinitlong: f64,
19 accelerationlong: f64,
20 accelerationmaxlong: f64,
21 accelerationinitshort: f64,
22 accelerationshort: f64,
23 accelerationmaxshort: f64,
24}
25
26#[polars_expr(output_type=Float64)]
27fn sarext(inputs: &[Series], kwargs: SarExtKwargs) -> PolarsResult<Series> {
28 let high = inputs[0].f64()?;
29 let low = inputs[1].f64()?;
30
31 let mut indicator = SAREXT::new(
32 kwargs.startvalue,
33 kwargs.offsetonreverse,
34 kwargs.accelerationinitlong,
35 kwargs.accelerationlong,
36 kwargs.accelerationmaxlong,
37 kwargs.accelerationinitshort,
38 kwargs.accelerationshort,
39 kwargs.accelerationmaxshort,
40 );
41
42 let mut out_vec = Vec::with_capacity(high.len());
43
44 for (h, l) in high.into_iter().zip(low.into_iter()) {
45 match (h, l) {
46 (Some(hv), Some(lv)) if !hv.is_nan() && !lv.is_nan() => {
47 out_vec.push(Some(indicator.next((hv, lv))));
48 }
49 (Some(_), Some(_)) => out_vec.push(Some(f64::NAN)),
50 _ => out_vec.push(None),
51 }
52 }
53
54 let out = Float64Chunked::new("sarext".into(), out_vec);
55 Ok(out.into_series())
56}
57
58#[derive(Deserialize)]
59pub struct KeltnerChannelsKwargs {
60 ema_period: usize,
61 atr_period: usize,
62 multiplier: f64,
63}
64
65pub fn keltner_channels_output(_: &[Field]) -> PolarsResult<Field> {
66 Ok(Field::new(
67 "keltner_channels".into(),
68 DataType::Struct(vec![
69 Field::new("upper".into(), DataType::Float64),
70 Field::new("middle".into(), DataType::Float64),
71 Field::new("lower".into(), DataType::Float64),
72 ]),
73 ))
74}
75
76#[polars_expr(output_type_func=keltner_channels_output)]
77fn keltner_channels(inputs: &[Series], kwargs: KeltnerChannelsKwargs) -> PolarsResult<Series> {
78 let high = inputs[0].f64()?;
79 let low = inputs[1].f64()?;
80 let close = inputs[2].f64()?;
81
82 let mut kc = KeltnerChannels::new(
83 kwargs.ema_period, kwargs.atr_period, kwargs.multiplier,
84 );
85
86 let mut uppers = Vec::with_capacity(high.len());
87 let mut middles = Vec::with_capacity(high.len());
88 let mut lowers = Vec::with_capacity(high.len());
89
90 for i in 0..high.len() {
91 let h_opt = high.get(i);
92 let l_opt = low.get(i);
93 let c_opt = close.get(i);
94
95 match (h_opt, l_opt, c_opt) {
96 (Some(hv), Some(lv), Some(cv)) => {
97 let (upper, middle, lower) = kc.next((hv, lv, cv));
98 uppers.push(Some(upper));
99 middles.push(Some(middle));
100 lowers.push(Some(lower));
101 }
102 _ => {
103 uppers.push(None);
104 middles.push(None);
105 lowers.push(None);
106 }
107 }
108 }
109
110 let upper_series = Float64Chunked::new("upper".into(), uppers).into_series();
111 let middle_series = Float64Chunked::new("middle".into(), middles).into_series();
112 let lower_series = Float64Chunked::new("lower".into(), lowers).into_series();
113
114 let out = StructChunked::from_series(
115 "keltner_channels".into(),
116 high.len(),
117 [upper_series, middle_series, lower_series].iter(),
118 )?;
119
120 Ok(out.into_series())
121}
122
123#[derive(Deserialize)]
124pub struct RegimesDurationStatsKwargs {
125 num_states: usize,
126}
127
128pub fn regimes_duration_stats_output(_: &[Field]) -> PolarsResult<Field> {
129 Ok(Field::new(
130 "duration_stats".into(),
131 DataType::Struct(vec![
132 Field::new("regime_id".into(), DataType::UInt32),
133 Field::new("mean_duration".into(), DataType::Float64),
134 Field::new("median_duration".into(), DataType::Float64),
135 Field::new("std_duration".into(), DataType::Float64),
136 Field::new("max_duration".into(), DataType::UInt32),
137 Field::new("total_observations".into(), DataType::UInt32),
138 ]),
139 ))
140}
141
142#[polars_expr(output_type_func=regimes_duration_stats_output)]
143fn regimes_duration_stats(inputs: &[Series], kwargs: RegimesDurationStatsKwargs) -> PolarsResult<Series> {
144 let s = inputs[0].u32()?;
145 let states: Vec<u32> = s.into_iter().map(|v| v.unwrap_or(0)).collect();
146 let stats = RegimeAnalytics::duration_stats(&states, kwargs.num_states);
147
148 let mut regime_ids = Vec::with_capacity(stats.len());
149 let mut means = Vec::with_capacity(stats.len());
150 let mut medians = Vec::with_capacity(stats.len());
151 let mut stds = Vec::with_capacity(stats.len());
152 let mut maxes = Vec::with_capacity(stats.len());
153 let mut totals = Vec::with_capacity(stats.len());
154
155 for stat in stats {
156 regime_ids.push(Some(stat.regime_id));
157 means.push(Some(stat.mean_duration));
158 medians.push(Some(stat.median_duration));
159 stds.push(Some(stat.std_duration));
160 maxes.push(Some(stat.max_duration as u32));
161 totals.push(Some(stat.total_observations as u32));
162 }
163
164 let s_id = UInt32Chunked::new("regime_id".into(), regime_ids).into_series();
165 let s_mean = Float64Chunked::new("mean_duration".into(), means).into_series();
166 let s_median = Float64Chunked::new("median_duration".into(), medians).into_series();
167 let s_std = Float64Chunked::new("std_duration".into(), stds).into_series();
168 let s_max = UInt32Chunked::new("max_duration".into(), maxes).into_series();
169 let s_total = UInt32Chunked::new("total_observations".into(), totals).into_series();
170
171 let out = StructChunked::from_series(
172 "duration_stats".into(),
173 s_id.len(),
174 [s_id, s_mean, s_median, s_std, s_max, s_total].iter(),
175 )?;
176
177 Ok(out.into_series())
178}
179
180#[derive(Deserialize)]
181pub struct MarketStructureKwargs {
182 swing_strength: usize,
183}
184
185pub fn market_structure_output(_: &[Field]) -> PolarsResult<Field> {
186 Ok(Field::new(
187 "market_structure_result".into(),
188 DataType::Struct(vec![
189 Field::new("bias".into(), DataType::UInt32),
190 Field::new("last_high_price".into(), DataType::Float64),
191 Field::new("last_high_bar".into(), DataType::UInt64),
192 Field::new("last_low_price".into(), DataType::Float64),
193 Field::new("last_low_bar".into(), DataType::UInt64),
194 Field::new("has_flip".into(), DataType::Boolean),
195 Field::new("flip_bearish".into(), DataType::Boolean),
196 Field::new("flip_price".into(), DataType::Float64),
197 Field::new("flip_bar".into(), DataType::UInt64),
198 Field::new("flip_strength".into(), DataType::UInt32),
199 Field::new("swing_depth".into(), DataType::UInt32),
200 Field::new("bar_index".into(), DataType::UInt64),
201 ]),
202 ))
203}
204
205#[polars_expr(output_type_func=market_structure_output)]
206fn market_structure(inputs: &[Series], kwargs: MarketStructureKwargs) -> PolarsResult<Series> {
207 let highs = inputs[0].f64()?;
208 let lows = inputs[1].f64()?;
209
210 let mut ms = MarketStructure::new(kwargs.swing_strength);
211 let n = highs.len();
212
213 let mut bias_vals = Vec::with_capacity(n);
214 let mut lh_p = Vec::with_capacity(n);
215 let mut lh_b = Vec::with_capacity(n);
216 let mut ll_p = Vec::with_capacity(n);
217 let mut ll_b = Vec::with_capacity(n);
218 let mut has_f = Vec::with_capacity(n);
219 let mut f_bear = Vec::with_capacity(n);
220 let mut f_p = Vec::with_capacity(n);
221 let mut f_ba = Vec::with_capacity(n);
222 let mut f_str = Vec::with_capacity(n);
223 let mut depths = Vec::with_capacity(n);
224 let mut bars = Vec::with_capacity(n);
225
226 for i in 0..n {
227 let h = highs.get(i).unwrap_or(f64::NAN);
228 let l = lows.get(i).unwrap_or(f64::NAN);
229
230 let hh = if h.is_nan() || l.is_nan() { f64::NAN } else { h.max(l) };
231 let ll = if h.is_nan() || l.is_nan() { f64::NAN } else { l.min(h) };
232
233 let state = ms.next((hh, ll));
234
235 let b = match state.bias {
236 Bias::Neutral => 0u32,
237 Bias::Bullish => 1,
238 Bias::Bearish => 2,
239 };
240 bias_vals.push(Some(b));
241
242 match &state.last_swing_high {
243 Some(sh) => { lh_p.push(Some(sh.price)); lh_b.push(Some(sh.bar as u64)); }
244 None => { lh_p.push(Some(f64::NAN)); lh_b.push(Some(0)); }
245 }
246 match &state.last_swing_low {
247 Some(sl) => { ll_p.push(Some(sl.price)); ll_b.push(Some(sl.bar as u64)); }
248 None => { ll_p.push(Some(f64::NAN)); ll_b.push(Some(0)); }
249 }
250
251 if let Some(f) = &state.current_flip {
252 has_f.push(Some(true));
253 f_bear.push(Some(f.is_bearish));
254 f_p.push(Some(f.price));
255 f_ba.push(Some(f.bar as u64));
256 f_str.push(Some(f.structure_strength));
257 } else {
258 has_f.push(Some(false));
259 f_bear.push(Some(false));
260 f_p.push(Some(f64::NAN));
261 f_ba.push(Some(0));
262 f_str.push(Some(0));
263 }
264
265 depths.push(Some(state.swing_depth_used as u32));
266 bars.push(Some(state.bar_index as u64));
267 }
268
269 let s_bias = UInt32Chunked::new("bias".into(), bias_vals).into_series();
270 let s_lhp = Float64Chunked::new("last_high_price".into(), lh_p).into_series();
271 let s_lhb = UInt64Chunked::new("last_high_bar".into(), lh_b).into_series();
272 let s_llp = Float64Chunked::new("last_low_price".into(), ll_p).into_series();
273 let s_llb = UInt64Chunked::new("last_low_bar".into(), ll_b).into_series();
274 let s_hasf = BooleanChunked::new("has_flip".into(), has_f).into_series();
275 let s_fb = BooleanChunked::new("flip_bearish".into(), f_bear).into_series();
276 let s_fp = Float64Chunked::new("flip_price".into(), f_p).into_series();
277 let s_fba = UInt64Chunked::new("flip_bar".into(), f_ba).into_series();
278 let s_fstr = UInt32Chunked::new("flip_strength".into(), f_str).into_series();
279 let s_dep = UInt32Chunked::new("swing_depth".into(), depths).into_series();
280 let s_bar = UInt64Chunked::new("bar_index".into(), bars).into_series();
281
282 let out = StructChunked::from_series(
283 "market_structure_result".into(),
284 n,
285 [
286 s_bias, s_lhp, s_lhb, s_llp, s_llb, s_hasf, s_fb, s_fp, s_fba, s_fstr, s_dep, s_bar,
287 ].iter(),
288 )?;
289
290 Ok(out.into_series())
291}
292
293#[polars_expr(output_type=UInt32)]
294fn regimes_hmm_gas(inputs: &[Series]) -> PolarsResult<Series> {
295 let s = inputs[0].f64()?;
296
297 let mut model = HMMGAS::new(
298 [0.1, 0.05, 0.9], [0.1, 0.05, 0.9], [0.001, -0.002],
301 [0.01, 0.02],
302 );
303
304 let mut values = Vec::with_capacity(s.len());
305
306 for i in 0..s.len() {
307 let val = s.get(i).unwrap_or(f64::NAN);
308 let regime = model.next(val);
309 let out = match regime {
310 MarketRegime::Steady => 0u32,
311 MarketRegime::Crisis => 1,
312 _ => 2,
313 };
314 values.push(Some(out));
315 }
316
317 let out = UInt32Chunked::new("hmm_gas_regime".into(), values);
318 Ok(out.into_series())
319}