quantwave_plugins/
custom_7.rs1use polars::prelude::*;
2use pyo3_polars::derive::polars_expr;
3use serde::Deserialize;
4use quantwave_core::*;
5use quantwave_core::traits::Next;
6
7#[derive(Deserialize)]
8pub struct AdaptiveEmaKwargs {
9 pub period: usize,
10 pub pds: usize,
11}
12
13#[derive(Deserialize)]
14pub struct RegimesTransitionMatrixKwargs {
15 pub num_states: usize,
16}
17
18#[derive(Deserialize)]
19pub struct VpnKwargs {
20 pub period: usize,
21 pub smooth_period: usize,
22}
23
24fn ms_garch_output(_: &[Field]) -> PolarsResult<Field> {
25 Ok(Field::new(
26 "ms_garch_data".into(),
27 DataType::Struct(vec![
28 Field::new("regime".into(), DataType::UInt32),
29 Field::new("estimated_vol".into(), DataType::Float64),
30 ]),
31 ))
32}
33
34#[polars_expr(output_type_func=ms_garch_output)]
35fn regimes_ms_garch(inputs: &[Series]) -> PolarsResult<Series> {
36 let s = &inputs[0];
37 let ca = s.f64()?;
38 let mut model = quantwave_core::regimes::ms_garch::MSGarch::low_high_vol();
39 let mut regimes = Vec::with_capacity(s.len());
40 let mut vols = Vec::with_capacity(s.len());
41
42 for i in 0..s.len() {
43 let ret = ca.get(i).unwrap_or(0.0);
44 let (regime, vol) = model.next(ret);
45
46 let r_val = match regime {
47 quantwave_core::regimes::MarketRegime::Steady => 0u32,
48 quantwave_core::regimes::MarketRegime::Crisis => 1,
49 quantwave_core::regimes::MarketRegime::Bull => 2,
50 quantwave_core::regimes::MarketRegime::Bear => 3,
51 quantwave_core::regimes::MarketRegime::Cluster(c) => 4 + (c as u32),
52 };
53 regimes.push(r_val);
54 vols.push(vol);
55 }
56
57 let s_regime = Series::new("regime".into(), regimes);
58 let s_vol = Series::new("estimated_vol".into(), vols);
59 let struct_series = StructChunked::from_series(
60 "ms_garch_data".into(),
61 s.len(),
62 [s_regime, s_vol].iter(),
63 )?;
64 Ok(struct_series.into_series())
65}
66
67fn adaptive_ema_output(_: &[Field]) -> PolarsResult<Field> {
68 Ok(Field::new("adaptive_ema".into(), DataType::Float64))
69}
70
71#[polars_expr(output_type_func=adaptive_ema_output)]
72fn adaptive_ema(inputs: &[Series], kwargs: AdaptiveEmaKwargs) -> PolarsResult<Series> {
73 let high = inputs[0].f64()?;
74 let low = inputs[1].f64()?;
75 let close = inputs[2].f64()?;
76
77 let mut indicator = quantwave_core::AdaptiveEMA::new(kwargs.period, kwargs.pds);
78 let mut values = Vec::with_capacity(high.len());
79
80 for i in 0..high.len() {
81 let h = high.get(i).unwrap_or(f64::NAN);
82 let l = low.get(i).unwrap_or(f64::NAN);
83 let c = close.get(i).unwrap_or(f64::NAN);
84 values.push(indicator.next((h, l, c)));
85 }
86
87 Ok(Series::new("adaptive_ema".into(), values))
88}
89
90fn regimes_transition_matrix_output(_: &[Field]) -> PolarsResult<Field> {
91 Ok(Field::new(
92 "regime_transition_matrix".into(),
93 DataType::List(Box::new(DataType::Float64)),
94 ))
95}
96
97#[polars_expr(output_type_func=regimes_transition_matrix_output)]
98fn regimes_transition_matrix(inputs: &[Series], kwargs: RegimesTransitionMatrixKwargs) -> PolarsResult<Series> {
99 let s = &inputs[0];
100 let ca = s.u32()?;
101 let states: Vec<u32> = ca.into_iter().map(|v| v.unwrap_or(0)).collect();
102 let matrix = quantwave_core::RegimeAnalytics::transition_matrix(&states, kwargs.num_states);
103
104 let mut builders = ListPrimitiveChunkedBuilder::<Float64Type>::new(
105 "transition_matrix".into(),
106 matrix.len(),
107 matrix.len() * kwargs.num_states,
108 DataType::Float64,
109 );
110 for row in matrix {
111 builders.append_slice(&row);
112 }
113
114 let list_ca = builders.finish();
115 Ok(list_ca.into_series())
116}
117
118fn vpn_output(_: &[Field]) -> PolarsResult<Field> {
119 Ok(Field::new("vpn".into(), DataType::Float64))
120}
121
122#[polars_expr(output_type_func=vpn_output)]
123fn vpn(inputs: &[Series], kwargs: VpnKwargs) -> PolarsResult<Series> {
124 let high = inputs[0].f64()?;
125 let low = inputs[1].f64()?;
126 let close = inputs[2].f64()?;
127 let volume = inputs[3].f64()?;
128
129 let mut indicator = quantwave_core::VPNIndicator::new(kwargs.period, kwargs.smooth_period);
130 let mut values = Vec::with_capacity(high.len());
131
132 for i in 0..high.len() {
133 let h = high.get(i).unwrap_or(f64::NAN);
134 let l = low.get(i).unwrap_or(f64::NAN);
135 let c = close.get(i).unwrap_or(f64::NAN);
136 let v = volume.get(i).unwrap_or(f64::NAN);
137 values.push(indicator.next((h, l, c, v)));
138 }
139
140 Ok(Series::new("vpn".into(), values))
141}
142
143fn regimes_hsmm_output(_: &[Field]) -> PolarsResult<Field> {
144 Ok(Field::new("hsmm_regime".into(), DataType::UInt32))
145}
146
147#[polars_expr(output_type_func=regimes_hsmm_output)]
148fn regimes_hsmm(inputs: &[Series]) -> PolarsResult<Series> {
149 let s = &inputs[0];
150 let ca = s.f64()?;
151 let mut model = quantwave_core::regimes::hsmm::HSMM::new(
153 vec![vec![0.0, 1.0], vec![1.0, 0.0]], vec![0.001, -0.002],
155 vec![0.01, 0.02],
156 vec![
157 quantwave_core::regimes::hsmm::DurationDistribution::Poisson { lambda: 5.0 },
158 quantwave_core::regimes::hsmm::DurationDistribution::Poisson { lambda: 2.0 },
159 ],
160 );
161 let mut values = Vec::with_capacity(s.len());
162
163 for i in 0..s.len() {
164 let val = ca.get(i).unwrap_or(f64::NAN);
165 let regime = model.next(val);
166 let out = match regime {
167 quantwave_core::regimes::MarketRegime::Steady => 0u32,
168 quantwave_core::regimes::MarketRegime::Crisis => 1,
169 _ => 2, };
171 values.push(out);
172 }
173
174 Ok(Series::new("hsmm_regime".into(), values))
175}