//! quantwave-plugins 0.6.0
//!
//! Polars expression plugins for quantwave.
use polars::prelude::*;
use pyo3_polars::derive::polars_expr;
use serde::Deserialize;
use quantwave_core::*;
use quantwave_core::traits::Next;

/// Keyword arguments for [`regimes_next_state_prob`], deserialized from the
/// kwargs passed to the plugin call.
#[derive(Deserialize)]
struct RegimesNextStateProbKwargs {
    /// Number of regime states; passed to `RegimeAnalytics::transition_matrix`
    /// and used to size the output list buffer.
    num_states: usize,
    /// Passed to `RegimeAnalytics::forecast_state` as its horizon argument
    /// (presumably the number of transitions to project ahead — confirm in quantwave_core).
    steps: usize,
}

/// Output schema for `regimes_next_state_prob`: a `List(Float64)` column named
/// `next_state_probs`, independent of the input fields.
pub fn regimes_next_state_prob_output(_: &[Field]) -> PolarsResult<Field> {
    let element = DataType::Float64;
    let list_dtype = DataType::List(Box::new(element));
    Ok(Field::new("next_state_probs".into(), list_dtype))
}

/// Per-row forecast of the regime-state probabilities.
///
/// The input must be a `UInt32` column of state labels. A transition matrix with
/// `num_states` states is fitted once on the whole column via
/// `RegimeAnalytics::transition_matrix`. Each row then receives the list returned by
/// `RegimeAnalytics::forecast_state(&matrix, current_state, steps)`.
///
/// Null states are treated as state `0`, both when fitting the matrix and when
/// forecasting. NOTE(review): confirm this is intended — it introduces transitions
/// into/out of state 0 around gaps.
///
/// # Errors
/// Returns an error if the input column is not `UInt32`.
#[polars_expr(output_type_func=regimes_next_state_prob_output)]
fn regimes_next_state_prob(inputs: &[Series], kwargs: RegimesNextStateProbKwargs) -> PolarsResult<Series> {
    let s = &inputs[0];
    let ca = s.u32()?;
    let states: Vec<u32> = ca.into_iter().map(|v| v.unwrap_or(0)).collect();
    let matrix = quantwave_core::RegimeAnalytics::transition_matrix(&states, kwargs.num_states);

    let mut builder = ListPrimitiveChunkedBuilder::<Float64Type>::new(
        "next_state_probs".into(),
        s.len(),
        s.len() * kwargs.num_states,
        DataType::Float64,
    );

    // `forecast_state` is called only with (matrix, current, steps), and `matrix` and
    // `steps` are fixed for the whole column. Assuming it is pure, every row that shares
    // a state gets the same forecast, so compute it once per distinct state instead of
    // once per row.
    let mut forecasts = std::collections::HashMap::new();
    for &current in &states {
        let probs = forecasts.entry(current).or_insert_with(|| {
            quantwave_core::RegimeAnalytics::forecast_state(&matrix, current, kwargs.steps)
        });
        builder.append_slice(probs);
    }

    Ok(builder.finish().into_series())
}

/// Label each row with the regime reported by `HMM::bull_bear()`.
///
/// The `Float64` input is fed to a single HMM instance in row order. Null values are
/// fed as `NaN`. The output column `hmm_regime` holds `1` for `Bull`, `2` for `Bear`,
/// and `0` for any other regime.
///
/// # Errors
/// Returns an error if the input column is not `Float64`.
#[polars_expr(output_type=UInt32)]
fn hmm_bull_bear(inputs: &[Series]) -> PolarsResult<Series> {
    let observations = inputs[0].f64()?;
    let mut model = quantwave_core::regimes::hmm::HMM::bull_bear();

    let codes: Vec<u32> = observations
        .into_iter()
        .map(|obs| match model.next(obs.unwrap_or(f64::NAN)) {
            quantwave_core::regimes::MarketRegime::Bull => 1,
            quantwave_core::regimes::MarketRegime::Bear => 2,
            _ => 0,
        })
        .collect();

    Ok(Series::new("hmm_regime".into(), codes))
}

/// Keyword arguments for [`alma`], deserialized from the kwargs passed to the
/// plugin call. All three are forwarded unchanged to `quantwave_core::ALMA::new`.
#[derive(Deserialize)]
struct AlmaKwargs {
    /// First argument to `ALMA::new` (presumably the window length — confirm).
    period: usize,
    /// Second argument to `ALMA::new` (presumably the Gaussian offset — confirm).
    offset: f64,
    /// Third argument to `ALMA::new` (presumably the Gaussian sigma — confirm).
    sigma: f64,
}

/// Run `quantwave_core::ALMA` over a `Float64` column.
///
/// `period`, `offset` and `sigma` are forwarded unchanged to `ALMA::new`. Non-null values
/// are fed to the indicator in row order, and each row receives the indicator's output.
/// Null rows produce a null output and are *not* fed to the indicator.
///
/// # Errors
/// Returns an error if the input column is not `Float64`.
#[polars_expr(output_type=Float64)]
fn alma(inputs: &[Series], kwargs: AlmaKwargs) -> PolarsResult<Series> {
    let ca = inputs[0].f64()?;
    let mut indicator = quantwave_core::ALMA::new(kwargs.period, kwargs.offset, kwargs.sigma);

    // Previously nulls were replaced by 0.0 and pushed into the indicator. That enters a
    // fake zero price into the moving-average window, presumably distorting the following
    // outputs. Propagate the null instead and leave the indicator state untouched.
    let values: Vec<Option<f64>> = ca
        .into_iter()
        .map(|v| v.map(|x| indicator.next(x)))
        .collect();

    Ok(Series::new("alma".into(), values))
}

/// Compute `RegimeAnalytics::stability_score` over the whole `UInt32` state column.
///
/// The single score is repeated on every row of the `stability_score` output, so the
/// result has the same length as the input. Null states are treated as state `0`.
///
/// # Errors
/// Returns an error if the input column is not `UInt32`.
#[polars_expr(output_type=Float64)]
fn regimes_stability_score(inputs: &[Series]) -> PolarsResult<Series> {
    let series = &inputs[0];
    let observed: Vec<u32> = series
        .u32()?
        .into_iter()
        .map(|state| state.unwrap_or_default())
        .collect();

    let score = quantwave_core::RegimeAnalytics::stability_score(&observed);
    let broadcast: Vec<_> = std::iter::repeat(score).take(series.len()).collect();

    Ok(Series::new("stability_score".into(), broadcast))
}

/// Keyword arguments for [`geometric_patterns`], deserialized from the kwargs passed
/// to the plugin call.
#[derive(Deserialize)]
struct GeometricPatternsKwargs {
    /// Forwarded to `GeometricPatternScanner::new` (presumably the number of bars on
    /// each side that define a swing point — confirm in quantwave_core).
    swing_strength: usize,
}

pub fn geometric_patterns_output(_: &[Field]) -> PolarsResult<Field> {
    Ok(Field::new(
        "geometric_patterns".into(),
        DataType::Struct(vec![
            Field::new("flag".into(), DataType::Struct(vec![
                Field::new("id".into(), DataType::UInt32),
                Field::new("is_bull".into(), DataType::Boolean),
                Field::new("pole_length".into(), DataType::Float64),
                Field::new("pole_length_atr".into(), DataType::Float64),
                Field::new("breakout_confirmed".into(), DataType::Boolean),
                Field::new("breakout_price".into(), DataType::Float64),
            ])),
            Field::new("hs".into(), DataType::Struct(vec![
                Field::new("id".into(), DataType::UInt32),
                Field::new("is_bearish".into(), DataType::Boolean),
                Field::new("height".into(), DataType::Float64),
                Field::new("height_atr".into(), DataType::Float64),
                Field::new("score".into(), DataType::Float64),
                Field::new("breakout_confirmed".into(), DataType::Boolean),
            ])),
        ]),
    ))
}

#[polars_expr(output_type_func=geometric_patterns_output)]
fn geometric_patterns(inputs: &[Series], kwargs: GeometricPatternsKwargs) -> PolarsResult<Series> {
    let highs = &inputs[0].f64()?;
    let lows = &inputs[1].f64()?;
    
    let n = highs.len();
    let mut scanner = quantwave_core::GeometricPatternScanner::new(kwargs.swing_strength);

    let mut flag_ids: Vec<u32> = Vec::with_capacity(n);
    let mut flag_is_bull: Vec<bool> = Vec::with_capacity(n);
    let mut flag_pole_len: Vec<f64> = Vec::with_capacity(n);
    let mut flag_pole_atr: Vec<f64> = Vec::with_capacity(n);
    let mut flag_breakout: Vec<bool> = Vec::with_capacity(n);
    let mut flag_bp: Vec<f64> = Vec::with_capacity(n);

    let mut hs_ids: Vec<u32> = Vec::with_capacity(n);
    let mut hs_bear: Vec<bool> = Vec::with_capacity(n);
    let mut hs_height: Vec<f64> = Vec::with_capacity(n);
    let mut hs_height_atr: Vec<f64> = Vec::with_capacity(n);
    let mut hs_score: Vec<f64> = Vec::with_capacity(n);
    let mut hs_breakout: Vec<bool> = Vec::with_capacity(n);

    for i in 0..n {
        let h = highs.get(i).unwrap_or(f64::NAN);
        let l = lows.get(i).unwrap_or(f64::NAN);
        let hh = if h.is_nan() || l.is_nan() { f64::NAN } else { h.max(l) };
        let ll = if h.is_nan() || l.is_nan() { f64::NAN } else { l.min(h) };
        let (_state, flag, hs) = scanner.next((hh, ll));

        if let Some(f) = flag {
            flag_ids.push(f.id);
            flag_is_bull.push(f.is_bull);
            flag_pole_len.push(f.pole_length);
            flag_pole_atr.push(f.pole_length_atr);
            flag_breakout.push(f.breakout_confirmed);
            flag_bp.push(f.breakout_price);
        } else {
            flag_ids.push(0);
            flag_is_bull.push(false);
            flag_pole_len.push(f64::NAN);
            flag_pole_atr.push(f64::NAN);
            flag_breakout.push(false);
            flag_bp.push(f64::NAN);
        }

        if let Some(hp) = hs {
            hs_ids.push(hp.id);
            hs_bear.push(hp.is_bearish);
            hs_height.push(hp.height);
            hs_height_atr.push(hp.height_atr);
            hs_score.push(hp.score);
            hs_breakout.push(hp.breakout_confirmed);
        } else {
            hs_ids.push(0);
            hs_bear.push(false);
            hs_height.push(f64::NAN);
            hs_height_atr.push(f64::NAN);
            hs_score.push(f64::NAN);
            hs_breakout.push(false);
        }
    }

    let s_fid = Series::new("id".into(), flag_ids);
    let s_fbull = Series::new("is_bull".into(), flag_is_bull);
    let s_fplen = Series::new("pole_length".into(), flag_pole_len);
    let s_fpatr = Series::new("pole_length_atr".into(), flag_pole_atr);
    let s_fbo = Series::new("breakout_confirmed".into(), flag_breakout);
    let s_fbp = Series::new("breakout_price".into(), flag_bp);

    let flag_struct = StructChunked::from_series(
        "flag".into(),
        n,
        [s_fid, s_fbull, s_fplen, s_fpatr, s_fbo, s_fbp].iter(),
    )?;

    let s_hid = Series::new("id".into(), hs_ids);
    let s_hbear = Series::new("is_bearish".into(), hs_bear);
    let s_hh = Series::new("height".into(), hs_height);
    let s_hhatr = Series::new("height_atr".into(), hs_height_atr);
    let s_hsc = Series::new("score".into(), hs_score);
    let s_hbo = Series::new("breakout_confirmed".into(), hs_breakout);

    let hs_struct = StructChunked::from_series(
        "hs".into(),
        n,
        [s_hid, s_hbear, s_hh, s_hhatr, s_hsc, s_hbo].iter(),
    )?;

    let combined = StructChunked::from_series(
        "geo_patterns".into(),
        n,
        [flag_struct.into_series(), hs_struct.into_series()].iter(),
    )?;
    Ok(combined.into_series())
}