use polars::{
prelude::{EWMOptions, LazyFrame, RollingOptionsFixedWindow, SortMultipleOptions, col, lit},
series::ops::NullBehavior,
};
use serde::{Deserialize, Serialize};
use crate::{error::ChapatyResult, transport::schema::CanonicalCol};
/// Window length (number of rows) of an exponential moving average.
///
/// The wrapped value sets the smoothing factor `2 / (window + 1)`. It is also
/// the minimum number of observations required before an EMA value is emitted.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[derive(Serialize, Deserialize)]
pub struct EmaWindow(pub u16);
/// Window length (number of rows) of a simple moving average.
///
/// The wrapped value is used as both the rolling window size and the minimum
/// number of observations required before a mean is emitted.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[derive(Serialize, Deserialize)]
pub struct SmaWindow(pub u16);
/// Window length (number of rows) of a relative strength index.
///
/// The wrapped value sets the gain/loss smoothing factor `1 / window`. It is
/// also the minimum number of observations required before an RSI value is
/// emitted.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[derive(Serialize, Deserialize)]
pub struct RsiWindow(pub u16);
/// An indicator that is pre-computed in batch over an OHLCV frame.
///
/// Each variant carries its own window type. [`BatchOhlcvIndicator::pre_compute`]
/// dispatches each variant to the matching routine.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
#[derive(Serialize, Deserialize)]
pub enum BatchOhlcvIndicator {
    /// Exponential moving average of the close column.
    Ema(EmaWindow),
    /// Simple (rolling) moving average of the close column.
    Sma(SmaWindow),
    /// Relative strength index of the close column.
    Rsi(RsiWindow),
}
impl BatchOhlcvIndicator {
pub fn pre_compute(&self, lf: LazyFrame) -> ChapatyResult<LazyFrame> {
match self {
BatchOhlcvIndicator::Ema(ema) => ema.pre_compute_ema(lf),
BatchOhlcvIndicator::Sma(sma) => sma.pre_compute_sma(lf),
BatchOhlcvIndicator::Rsi(rsi) => rsi.pre_compute_rsi(lf),
}
}
}
/// Builder-style attachment of batch indicators to a value.
pub trait WithBatchIndicators: Sized {
    /// The indicator description this builder accepts.
    type BatchIndicator: Clone;

    /// Returns `self` with one more indicator attached.
    fn with_indicator(self, kind: Self::BatchIndicator) -> Self;

    /// Attaches every indicator in `kinds`, in slice order.
    ///
    /// Each element is cloned and passed to
    /// [`WithBatchIndicators::with_indicator`]. An empty slice returns `self`
    /// unchanged.
    fn with_indicators(self, kinds: &[Self::BatchIndicator]) -> Self {
        let mut current = self;
        for kind in kinds.iter().cloned() {
            current = current.with_indicator(kind);
        }
        current
    }
}
impl EmaWindow {
    /// Builds a lazy EMA-of-close series for this window length.
    ///
    /// Steps:
    /// 1. Sort `lf` by timestamp.
    /// 2. Smooth the close column with a recursive (`adjust: false`)
    ///    exponential moving average. It uses `alpha = 2 / (window + 1)`,
    ///    skips nulls, and requires `window` observations before emitting a value.
    /// 3. Return the timestamp and the smoothed value, named `CanonicalCol::Price`.
    ///
    /// Rows that contain nulls, including the warm-up rows, are dropped.
    ///
    /// # Errors
    /// Always returns `Ok` here. Any Polars error surfaces when the lazy plan
    /// is collected.
    fn pre_compute_ema(&self, lf: LazyFrame) -> ChapatyResult<LazyFrame> {
        let span = self.0 as f64;
        let warmup = usize::from(self.0);
        let ewm = EWMOptions {
            alpha: 2.0 / (span + 1.0),
            adjust: false,
            bias: false,
            min_periods: warmup,
            ignore_nulls: true,
        };

        let timestamp = col(CanonicalCol::Timestamp).alias(CanonicalCol::Timestamp);
        let price = col(CanonicalCol::Close)
            .ewm_mean(ewm)
            .alias(CanonicalCol::Price);
        let by_time = SortMultipleOptions::default().with_maintain_order(false);

        let plan = lf
            .sort([CanonicalCol::Timestamp], by_time)
            .select([timestamp, price])
            .drop_nulls(None);
        Ok(plan)
    }
}
impl SmaWindow {
    /// Builds a lazy SMA-of-close series for this window length.
    ///
    /// Steps:
    /// 1. Sort `lf` by timestamp.
    /// 2. Take the fixed-size, unweighted, trailing (`center: false`) rolling
    ///    mean of the close column over `window` rows.
    /// 3. Return the timestamp and the mean, named `CanonicalCol::Price`.
    ///
    /// `min_periods` equals `window_size`, so the leading warm-up rows are
    /// null. Those rows, and any other row containing a null, are removed by
    /// `drop_nulls`.
    ///
    /// # Errors
    /// Always returns `Ok` here. Any Polars error surfaces when the lazy plan
    /// is collected.
    fn pre_compute_sma(&self, lf: LazyFrame) -> ChapatyResult<LazyFrame> {
        let length = usize::from(self.0);
        let rolling = RollingOptionsFixedWindow {
            window_size: length,
            min_periods: length,
            weights: None,
            center: false,
            fn_params: None,
        };

        let timestamp = col(CanonicalCol::Timestamp).alias(CanonicalCol::Timestamp);
        let price = col(CanonicalCol::Close)
            .rolling_mean(rolling)
            .alias(CanonicalCol::Price);

        let sorted = lf.sort(
            [CanonicalCol::Timestamp],
            SortMultipleOptions::default().with_maintain_order(false),
        );
        Ok(sorted.select([timestamp, price]).drop_nulls(None))
    }
}
impl RsiWindow {
    /// Builds a lazy RSI-of-close series for this window length.
    ///
    /// Steps:
    /// 1. Sort `lf` by timestamp and take one-step close-to-close differences.
    ///    The first difference is null.
    /// 2. Split the differences into gains (negatives clipped to 0) and losses
    ///    (positives clipped to 0, then made non-negative with `abs`).
    /// 3. Smooth both with a recursive (`adjust: false`) EWM. It uses
    ///    `alpha = 1 / window`, skips nulls, and requires `window` observations.
    ///    The recursion is seeded with the first observation, not with an SMA
    ///    of the first `window` differences.
    /// 4. Compute `RSI = 100 - 100 / (1 + avg_gain / avg_loss)` and return it
    ///    with the timestamp, named `CanonicalCol::Price`.
    ///
    /// Undefined values are removed from the output. Warm-up rows are null and
    /// dropped by `drop_nulls`. When both averages are exactly zero (e.g. a
    /// flat-price run at the start of the series), the ratio is `0 / 0 = NaN`.
    /// NaN is not null, so those rows are filtered out explicitly.
    ///
    /// # Errors
    /// Always returns `Ok` here. Any Polars error surfaces when the lazy plan
    /// is collected.
    fn pre_compute_rsi(&self, lf: LazyFrame) -> ChapatyResult<LazyFrame> {
        let window = self.0;
        let options = EWMOptions {
            alpha: 1.0 / (window as f64),
            adjust: false,
            bias: false,
            min_periods: window as usize,
            ignore_nulls: true,
        };
        let rsi_expr = {
            let delta = col(CanonicalCol::Close).diff(lit(1), NullBehavior::Ignore);
            let gain = delta.clone().clip(lit(0), lit(f64::MAX));
            let loss = delta.clip(lit(f64::MIN), lit(0)).abs();
            let avg_gain = gain.ewm_mean(options);
            let avg_loss = loss.ewm_mean(options);
            let rs = avg_gain / avg_loss;
            lit(100.0) - (lit(100.0) / (lit(1.0) + rs))
        };
        Ok(lf
            .sort(
                [CanonicalCol::Timestamp],
                SortMultipleOptions::default().with_maintain_order(false),
            )
            .select([
                col(CanonicalCol::Timestamp).alias(CanonicalCol::Timestamp),
                rsi_expr.alias(CanonicalCol::Price),
            ])
            .drop_nulls(None)
            // 0/0 (no gains and no losses yet) yields NaN, which `drop_nulls`
            // keeps; drop it so only defined RSI values reach consumers.
            .filter(col(CanonicalCol::Price).is_not_nan()))
    }
}