Skip to main content

quantwave_plugins/
custom_7.rs

1use polars::prelude::*;
2use pyo3_polars::derive::polars_expr;
3use serde::Deserialize;
4use quantwave_core::*;
5use quantwave_core::traits::Next;
6
7#[derive(Deserialize)]
8pub struct AdaptiveEmaKwargs {
9    pub period: usize,
10    pub pds: usize,
11}
12
13#[derive(Deserialize)]
14pub struct RegimesTransitionMatrixKwargs {
15    pub num_states: usize,
16}
17
18#[derive(Deserialize)]
19pub struct VpnKwargs {
20    pub period: usize,
21    pub smooth_period: usize,
22}
23
24fn ms_garch_output(_: &[Field]) -> PolarsResult<Field> {
25    Ok(Field::new(
26        "ms_garch_data".into(),
27        DataType::Struct(vec![
28            Field::new("regime".into(), DataType::UInt32),
29            Field::new("estimated_vol".into(), DataType::Float64),
30        ]),
31    ))
32}
33
34#[polars_expr(output_type_func=ms_garch_output)]
35fn regimes_ms_garch(inputs: &[Series]) -> PolarsResult<Series> {
36    let s = &inputs[0];
37    let ca = s.f64()?;
38    let mut model = quantwave_core::regimes::ms_garch::MSGarch::low_high_vol();
39    let mut regimes = Vec::with_capacity(s.len());
40    let mut vols = Vec::with_capacity(s.len());
41
42    for i in 0..s.len() {
43        let ret = ca.get(i).unwrap_or(0.0);
44        let (regime, vol) = model.next(ret);
45        
46        let r_val = match regime {
47            quantwave_core::regimes::MarketRegime::Steady => 0u32,
48            quantwave_core::regimes::MarketRegime::Crisis => 1,
49            quantwave_core::regimes::MarketRegime::Bull => 2,
50            quantwave_core::regimes::MarketRegime::Bear => 3,
51            quantwave_core::regimes::MarketRegime::Cluster(c) => 4 + (c as u32),
52        };
53        regimes.push(r_val);
54        vols.push(vol);
55    }
56
57    let s_regime = Series::new("regime".into(), regimes);
58    let s_vol = Series::new("estimated_vol".into(), vols);
59    let struct_series = StructChunked::from_series(
60        "ms_garch_data".into(),
61        s.len(),
62        [s_regime, s_vol].iter(),
63    )?;
64    Ok(struct_series.into_series())
65}
66
67fn adaptive_ema_output(_: &[Field]) -> PolarsResult<Field> {
68    Ok(Field::new("adaptive_ema".into(), DataType::Float64))
69}
70
71#[polars_expr(output_type_func=adaptive_ema_output)]
72fn adaptive_ema(inputs: &[Series], kwargs: AdaptiveEmaKwargs) -> PolarsResult<Series> {
73    let high = inputs[0].f64()?;
74    let low = inputs[1].f64()?;
75    let close = inputs[2].f64()?;
76
77    let mut indicator = quantwave_core::AdaptiveEMA::new(kwargs.period, kwargs.pds);
78    let mut values = Vec::with_capacity(high.len());
79
80    for i in 0..high.len() {
81        let h = high.get(i).unwrap_or(f64::NAN);
82        let l = low.get(i).unwrap_or(f64::NAN);
83        let c = close.get(i).unwrap_or(f64::NAN);
84        values.push(indicator.next((h, l, c)));
85    }
86
87    Ok(Series::new("adaptive_ema".into(), values))
88}
89
90fn regimes_transition_matrix_output(_: &[Field]) -> PolarsResult<Field> {
91    Ok(Field::new(
92        "regime_transition_matrix".into(),
93        DataType::List(Box::new(DataType::Float64)),
94    ))
95}
96
97#[polars_expr(output_type_func=regimes_transition_matrix_output)]
98fn regimes_transition_matrix(inputs: &[Series], kwargs: RegimesTransitionMatrixKwargs) -> PolarsResult<Series> {
99    let s = &inputs[0];
100    let ca = s.u32()?;
101    let states: Vec<u32> = ca.into_iter().map(|v| v.unwrap_or(0)).collect();
102    let matrix = quantwave_core::RegimeAnalytics::transition_matrix(&states, kwargs.num_states);
103    
104    let mut builders = ListPrimitiveChunkedBuilder::<Float64Type>::new(
105        "transition_matrix".into(),
106        matrix.len(),
107        matrix.len() * kwargs.num_states,
108        DataType::Float64,
109    );
110    for row in matrix {
111        builders.append_slice(&row);
112    }
113    
114    let list_ca = builders.finish();
115    Ok(list_ca.into_series())
116}
117
118fn vpn_output(_: &[Field]) -> PolarsResult<Field> {
119    Ok(Field::new("vpn".into(), DataType::Float64))
120}
121
122#[polars_expr(output_type_func=vpn_output)]
123fn vpn(inputs: &[Series], kwargs: VpnKwargs) -> PolarsResult<Series> {
124    let high = inputs[0].f64()?;
125    let low = inputs[1].f64()?;
126    let close = inputs[2].f64()?;
127    let volume = inputs[3].f64()?;
128
129    let mut indicator = quantwave_core::VPNIndicator::new(kwargs.period, kwargs.smooth_period);
130    let mut values = Vec::with_capacity(high.len());
131
132    for i in 0..high.len() {
133        let h = high.get(i).unwrap_or(f64::NAN);
134        let l = low.get(i).unwrap_or(f64::NAN);
135        let c = close.get(i).unwrap_or(f64::NAN);
136        let v = volume.get(i).unwrap_or(f64::NAN);
137        values.push(indicator.next((h, l, c, v)));
138    }
139
140    Ok(Series::new("vpn".into(), values))
141}
142
143fn regimes_hsmm_output(_: &[Field]) -> PolarsResult<Field> {
144    Ok(Field::new("hsmm_regime".into(), DataType::UInt32))
145}
146
147#[polars_expr(output_type_func=regimes_hsmm_output)]
148fn regimes_hsmm(inputs: &[Series]) -> PolarsResult<Series> {
149    let s = &inputs[0];
150    let ca = s.f64()?;
151    // Default 2-state HSMM: Poisson durations (5 days Bull, 2 days Bear)
152    let mut model = quantwave_core::regimes::hsmm::HSMM::new(
153        vec![vec![0.0, 1.0], vec![1.0, 0.0]], // Always switch
154        vec![0.001, -0.002],
155        vec![0.01, 0.02],
156        vec![
157            quantwave_core::regimes::hsmm::DurationDistribution::Poisson { lambda: 5.0 },
158            quantwave_core::regimes::hsmm::DurationDistribution::Poisson { lambda: 2.0 },
159        ],
160    );
161    let mut values = Vec::with_capacity(s.len());
162
163    for i in 0..s.len() {
164        let val = ca.get(i).unwrap_or(f64::NAN);
165        let regime = model.next(val);
166        let out = match regime {
167            quantwave_core::regimes::MarketRegime::Steady => 0u32,
168            quantwave_core::regimes::MarketRegime::Crisis => 1,
169            _ => 2, // Map others
170        };
171        values.push(out);
172    }
173
174    Ok(Series::new("hsmm_regime".into(), values))
175}