use chrono::{DateTime, NaiveDateTime};
use polars::error::{PolarsError, PolarsResult};
use polars::prelude::*;
use rust_decimal::prelude::FromPrimitive;
use rust_decimal::Decimal;

use crate::{Candle, Side, Signal};
/// Returns the rows of `updated` whose `time` value does not occur in `data`.
///
/// This is an anti-join on the `time` column: every column of `updated` is kept,
/// and only rows with an unseen timestamp survive.
///
/// # Panics
///
/// Panics if the join fails (for example when either frame has no `time` column).
pub fn extract_new_rows(updated: &DataFrame, data: &DataFrame) -> DataFrame {
    let anti_join = JoinArgs::new(JoinType::Anti);
    let joined = updated.join(data, ["time"], ["time"], anti_join);
    joined.unwrap()
}
pub fn extract_candles_from_df(df: &DataFrame) -> PolarsResult<Vec<Candle>> {
let time = df.column("time")?.datetime()?;
let high = df.column("high")?.f64()?;
let low = df.column("low")?.f64()?;
let open = df.column("open")?.f64()?;
let close = df.column("close")?.f64()?;
let volume = df.column("volume")?.f64()?;
Ok((0..time.len())
.into_iter()
.map(|i| Candle {
time: DateTime::from_timestamp_millis(time.get(i).unwrap())
.unwrap()
.naive_utc(),
high: Decimal::from_f64(high.get(i).unwrap()).unwrap(),
low: Decimal::from_f64(low.get(i).unwrap()).unwrap(),
open: Decimal::from_f64(open.get(i).unwrap()).unwrap(),
close: Decimal::from_f64(close.get(i).unwrap()).unwrap(),
volume: Decimal::from_f64(volume.get(i).unwrap()).unwrap(),
})
.collect())
}
/// Decodes the `i8` column `column_name` into one [`Signal`] per row.
///
/// Non-null codes are converted with `Signal::from`; null cells become
/// [`Signal::Hold`].
///
/// # Errors
///
/// Returns an error if `column_name` is missing or is not an `i8` column.
pub fn extract_signals_from_df(df: &DataFrame, column_name: &str) -> PolarsResult<Vec<Signal>> {
    let codes = df.column(column_name)?.i8()?;
    let signals: Vec<Signal> = codes
        .into_iter()
        .map(|code| match code {
            Some(code) => Signal::from(code),
            None => Signal::Hold,
        })
        .collect();
    Ok(signals)
}
pub fn extract_side_from_df(df: &DataFrame, column_name: &str) -> PolarsResult<Vec<Side>> {
Ok(df
.column(column_name)?
.i8()?
.into_iter()
.map(|value| Side::from(value.unwrap()))
.collect())
}
/// Reasons two candle data frames can fail the row-by-row `time` comparison
/// performed by [`check_candle_alignment`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AlignmentError {
    /// The two `time` columns have different numbers of rows.
    DifferentLengths,
    /// The `time` columns have equal length but differ in at least one row.
    TimestampsNotAligned,
}

impl std::fmt::Display for AlignmentError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            AlignmentError::DifferentLengths => "time columns have different lengths",
            AlignmentError::TimestampsNotAligned => "time columns contain different timestamps",
        };
        f.write_str(message)
    }
}

impl std::error::Error for AlignmentError {}
pub fn check_candle_alignment(a: &DataFrame, b: &DataFrame) -> Result<(), AlignmentError> {
let market_data_index = a.column("time").unwrap().datetime().unwrap();
let historical_data_index = b.column("time").unwrap().datetime().unwrap();
if market_data_index.len() != historical_data_index.len() {
return Err(AlignmentError::DifferentLengths);
}
let index_alignment_mask: Vec<bool> = market_data_index
.iter()
.zip(historical_data_index.iter())
.map(|(a, b)| a != b)
.collect();
if index_alignment_mask.iter().any(|&x| x) {
return Err(AlignmentError::TimestampsNotAligned);
}
Ok(())
}
/// Returns the last `length` rows of `candles` whose `time` is strictly before `end_time`.
///
/// Rows keep their existing order; fewer than `length` rows are returned when
/// not enough rows precede `end_time`.
///
/// # Panics
///
/// Panics if the lazy query fails (for example when there is no `time` column).
pub fn trim_candles(candles: &DataFrame, end_time: NaiveDateTime, length: IdxSize) -> DataFrame {
    let before_end = col("time").lt(lit(end_time));
    let query = candles.clone().lazy().filter(before_end).tail(length);
    query.collect().unwrap()
}
#[cfg(test)]
mod tests {
    use crate::utils::extract_new_rows;
    use polars::prelude::*;

    /// Times 41 and 51 exist only in `candles`, so the anti-join must return
    /// exactly those two rows with all six columns carried over unchanged.
    #[test]
    fn test_extract_new_rows() {
        let candles = df!(
            "time" => &[1, 2, 3, 41, 51],
            "open" => &[1, 2, 3, 42, 52],
            "high" => &[1, 2, 3, 43, 53],
            "low" => &[1, 2, 3, 44, 54],
            "close" => &[1, 2, 3, 45, 55],
            "volume" => &[1, 2, 3, 46, 56],
        )
        .unwrap();
        let indicator_data = df!(
            "time" => &[1, 2, 3],
            "open" => &[1, 2, 3],
            "high" => &[1, 2, 3],
            "low" => &[1, 2, 3],
            "close" => &[1, 2, 3],
            "volume" => &[1, 2, 3],
        )
        .unwrap();

        let new_rows = extract_new_rows(&candles, &indicator_data);
        assert_eq!(new_rows.shape(), (2, 6));

        let expected: [(&str, [i32; 2]); 6] = [
            ("time", [41, 51]),
            ("open", [42, 52]),
            ("high", [43, 53]),
            ("low", [44, 54]),
            ("close", [45, 55]),
            ("volume", [46, 56]),
        ];
        for (name, values) in expected.iter() {
            let column = new_rows.column(*name).unwrap().i32().unwrap();
            for (row, value) in values.iter().enumerate() {
                assert_eq!(column.get(row), Some(*value), "column {name}, row {row}");
            }
        }
    }
}