use polars::prelude::*;
use pyo3_polars::derive::polars_expr;
use serde::Deserialize;
use quantwave_core::*;
use quantwave_core::traits::Next;
/// Keyword arguments for the `regimes_next_state_prob` expression, deserialized
/// from the kwargs supplied with the plugin call.
#[derive(Deserialize)]
struct RegimesNextStateProbKwargs {
/// Number of regime states; passed to `RegimeAnalytics::transition_matrix` and
/// also used to size the output value buffer.
num_states: usize,
/// Forecast horizon passed to `RegimeAnalytics::forecast_state`
/// (presumably a number of bars ahead — TODO confirm).
steps: usize,
}
/// Output schema for `regimes_next_state_prob`: a `List[Float64]` column named
/// `next_state_probs`. The input fields are ignored.
pub fn regimes_next_state_prob_output(_input_fields: &[Field]) -> PolarsResult<Field> {
    let probs_dtype = DataType::List(Box::new(DataType::Float64));
    let field = Field::new("next_state_probs".into(), probs_dtype);
    Ok(field)
}
/// For every row, forecasts the probability of each regime `kwargs.steps` ahead.
///
/// The input must be a `UInt32` column of regime labels. A single transition
/// matrix over `kwargs.num_states` states is estimated from the whole column with
/// `RegimeAnalytics::transition_matrix`. Each row's label is then projected
/// forward with `RegimeAnalytics::forecast_state`. The row's probabilities are
/// emitted as one `List[Float64]` entry named `next_state_probs`.
///
/// Null labels are treated as state 0, both when estimating the matrix and when
/// forecasting.
///
/// # Errors
/// Returns an error if the input column is not `UInt32`.
#[polars_expr(output_type_func=regimes_next_state_prob_output)]
fn regimes_next_state_prob(inputs: &[Series], kwargs: RegimesNextStateProbKwargs) -> PolarsResult<Series> {
    let s = &inputs[0];
    let ca = s.u32()?;
    // NOTE(review): mapping nulls to state 0 also biases the transition counts
    // toward state 0; confirm this is the intended convention.
    let states: Vec<u32> = ca.into_iter().map(|v| v.unwrap_or(0)).collect();
    let matrix = quantwave_core::RegimeAnalytics::transition_matrix(&states, kwargs.num_states);
    let mut builder = ListPrimitiveChunkedBuilder::<Float64Type>::new(
        "next_state_probs".into(),
        s.len(),
        // Capacity hint only: assumes one probability per state per row (TODO
        // confirm `forecast_state` returns `num_states` values).
        s.len() * kwargs.num_states,
        DataType::Float64,
    );
    // Destructure the borrowed label (`&current`) so `current` is a plain `u32`.
    for &current in &states {
        let probs = quantwave_core::RegimeAnalytics::forecast_state(&matrix, current, kwargs.steps);
        builder.append_slice(&probs);
    }
    Ok(builder.finish().into_series())
}
/// Labels each row with a regime code from the `HMM::bull_bear` model.
///
/// The input must be a `Float64` column. Rows are fed to the model in order, and
/// nulls are passed in as `NaN`. The output is a `UInt32` series named
/// `hmm_regime`:
/// - `1` = `MarketRegime::Bull`
/// - `2` = `MarketRegime::Bear`
/// - `0` = any other regime the model reports
///
/// # Errors
/// Returns an error if the input column is not `Float64`.
#[polars_expr(output_type=UInt32)]
fn hmm_bull_bear(inputs: &[Series]) -> PolarsResult<Series> {
    use quantwave_core::regimes::MarketRegime;

    let observations = inputs[0].f64()?;
    let mut model = quantwave_core::regimes::hmm::HMM::bull_bear();
    let codes: Vec<u32> = observations
        .into_iter()
        .map(|obs| match model.next(obs.unwrap_or(f64::NAN)) {
            MarketRegime::Bull => 1,
            MarketRegime::Bear => 2,
            _ => 0,
        })
        .collect();
    Ok(Series::new("hmm_regime".into(), codes))
}
/// Keyword arguments for the `alma` expression, deserialized from the kwargs
/// supplied with the plugin call. All three are forwarded to `ALMA::new`.
#[derive(Deserialize)]
struct AlmaKwargs {
/// Period passed to `ALMA::new` (presumably the window length — TODO confirm).
period: usize,
/// Offset passed to `ALMA::new` (presumably where the weight peak sits within
/// the window — TODO confirm).
offset: f64,
/// Sigma passed to `ALMA::new` (presumably the width of the weighting curve —
/// TODO confirm).
sigma: f64,
}
/// ALMA (Arnaud Legoux moving average) of a `Float64` column.
///
/// An `ALMA` indicator is built from `kwargs` (`period`, `offset`, `sigma`). Each
/// row is fed to it in order, and its per-row output becomes the `alma` series.
///
/// Null inputs are passed to the indicator as `NaN`, matching the NaN convention
/// used by the other expressions in this module. Substituting `0.0` (the previous
/// behavior) dragged the average toward zero around every missing value.
/// NOTE(review): confirm `ALMA` recovers once a `NaN` leaves its window.
///
/// # Errors
/// Returns an error if the input column is not `Float64`.
#[polars_expr(output_type=Float64)]
fn alma(inputs: &[Series], kwargs: AlmaKwargs) -> PolarsResult<Series> {
    let ca = inputs[0].f64()?;
    let mut indicator = quantwave_core::ALMA::new(kwargs.period, kwargs.offset, kwargs.sigma);
    // Iterate the chunked array directly rather than via per-row `get(i)`, which
    // has to locate the owning chunk on every call.
    let values: Vec<_> = ca
        .into_iter()
        .map(|v| indicator.next(v.unwrap_or(f64::NAN)))
        .collect();
    Ok(Series::new("alma".into(), values))
}
/// Broadcasts one regime-stability score over every row.
///
/// The input must be a `UInt32` column of regime labels, with nulls treated as
/// state 0. `RegimeAnalytics::stability_score` is evaluated once over the full
/// label sequence. Its result is repeated once per input row in a series named
/// `stability_score`.
///
/// # Errors
/// Returns an error if the input column is not `UInt32`.
#[polars_expr(output_type=Float64)]
fn regimes_stability_score(inputs: &[Series]) -> PolarsResult<Series> {
    let labels = inputs[0].u32()?;
    let sequence: Vec<u32> = labels
        .into_iter()
        .map(|label| label.unwrap_or(0))
        .collect();
    let score = quantwave_core::RegimeAnalytics::stability_score(&sequence);
    let broadcast: Vec<_> = std::iter::repeat(score).take(labels.len()).collect();
    Ok(Series::new("stability_score".into(), broadcast))
}
/// Keyword arguments for the `geometric_patterns` expression, deserialized from
/// the kwargs supplied with the plugin call.
#[derive(Deserialize)]
struct GeometricPatternsKwargs {
/// Passed to `GeometricPatternScanner::new` (presumably the number of bars on
/// each side that define a swing point — TODO confirm).
swing_strength: usize,
}
pub fn geometric_patterns_output(_: &[Field]) -> PolarsResult<Field> {
Ok(Field::new(
"geometric_patterns".into(),
DataType::Struct(vec![
Field::new("flag".into(), DataType::Struct(vec![
Field::new("id".into(), DataType::UInt32),
Field::new("is_bull".into(), DataType::Boolean),
Field::new("pole_length".into(), DataType::Float64),
Field::new("pole_length_atr".into(), DataType::Float64),
Field::new("breakout_confirmed".into(), DataType::Boolean),
Field::new("breakout_price".into(), DataType::Float64),
])),
Field::new("hs".into(), DataType::Struct(vec![
Field::new("id".into(), DataType::UInt32),
Field::new("is_bearish".into(), DataType::Boolean),
Field::new("height".into(), DataType::Float64),
Field::new("height_atr".into(), DataType::Float64),
Field::new("score".into(), DataType::Float64),
Field::new("breakout_confirmed".into(), DataType::Boolean),
])),
]),
))
}
#[polars_expr(output_type_func=geometric_patterns_output)]
fn geometric_patterns(inputs: &[Series], kwargs: GeometricPatternsKwargs) -> PolarsResult<Series> {
let highs = &inputs[0].f64()?;
let lows = &inputs[1].f64()?;
let n = highs.len();
let mut scanner = quantwave_core::GeometricPatternScanner::new(kwargs.swing_strength);
let mut flag_ids: Vec<u32> = Vec::with_capacity(n);
let mut flag_is_bull: Vec<bool> = Vec::with_capacity(n);
let mut flag_pole_len: Vec<f64> = Vec::with_capacity(n);
let mut flag_pole_atr: Vec<f64> = Vec::with_capacity(n);
let mut flag_breakout: Vec<bool> = Vec::with_capacity(n);
let mut flag_bp: Vec<f64> = Vec::with_capacity(n);
let mut hs_ids: Vec<u32> = Vec::with_capacity(n);
let mut hs_bear: Vec<bool> = Vec::with_capacity(n);
let mut hs_height: Vec<f64> = Vec::with_capacity(n);
let mut hs_height_atr: Vec<f64> = Vec::with_capacity(n);
let mut hs_score: Vec<f64> = Vec::with_capacity(n);
let mut hs_breakout: Vec<bool> = Vec::with_capacity(n);
for i in 0..n {
let h = highs.get(i).unwrap_or(f64::NAN);
let l = lows.get(i).unwrap_or(f64::NAN);
let hh = if h.is_nan() || l.is_nan() { f64::NAN } else { h.max(l) };
let ll = if h.is_nan() || l.is_nan() { f64::NAN } else { l.min(h) };
let (_state, flag, hs) = scanner.next((hh, ll));
if let Some(f) = flag {
flag_ids.push(f.id);
flag_is_bull.push(f.is_bull);
flag_pole_len.push(f.pole_length);
flag_pole_atr.push(f.pole_length_atr);
flag_breakout.push(f.breakout_confirmed);
flag_bp.push(f.breakout_price);
} else {
flag_ids.push(0);
flag_is_bull.push(false);
flag_pole_len.push(f64::NAN);
flag_pole_atr.push(f64::NAN);
flag_breakout.push(false);
flag_bp.push(f64::NAN);
}
if let Some(hp) = hs {
hs_ids.push(hp.id);
hs_bear.push(hp.is_bearish);
hs_height.push(hp.height);
hs_height_atr.push(hp.height_atr);
hs_score.push(hp.score);
hs_breakout.push(hp.breakout_confirmed);
} else {
hs_ids.push(0);
hs_bear.push(false);
hs_height.push(f64::NAN);
hs_height_atr.push(f64::NAN);
hs_score.push(f64::NAN);
hs_breakout.push(false);
}
}
let s_fid = Series::new("id".into(), flag_ids);
let s_fbull = Series::new("is_bull".into(), flag_is_bull);
let s_fplen = Series::new("pole_length".into(), flag_pole_len);
let s_fpatr = Series::new("pole_length_atr".into(), flag_pole_atr);
let s_fbo = Series::new("breakout_confirmed".into(), flag_breakout);
let s_fbp = Series::new("breakout_price".into(), flag_bp);
let flag_struct = StructChunked::from_series(
"flag".into(),
n,
[s_fid, s_fbull, s_fplen, s_fpatr, s_fbo, s_fbp].iter(),
)?;
let s_hid = Series::new("id".into(), hs_ids);
let s_hbear = Series::new("is_bearish".into(), hs_bear);
let s_hh = Series::new("height".into(), hs_height);
let s_hhatr = Series::new("height_atr".into(), hs_height_atr);
let s_hsc = Series::new("score".into(), hs_score);
let s_hbo = Series::new("breakout_confirmed".into(), hs_breakout);
let hs_struct = StructChunked::from_series(
"hs".into(),
n,
[s_hid, s_hbear, s_hh, s_hhatr, s_hsc, s_hbo].iter(),
)?;
let combined = StructChunked::from_series(
"geo_patterns".into(),
n,
[flag_struct.into_series(), hs_struct.into_series()].iter(),
)?;
Ok(combined.into_series())
}