use std::sync::Arc;
use chrono::{DateTime, Duration, Utc};
use itertools::iproduct;
use rand::seq::SliceRandom;
use rayon::iter::{IntoParallelIterator, ParallelIterator};
use serde::Serialize;
use serde_with::{DurationSeconds, serde_as};
use crate::{
agent::{Agent, AgentIdentifier, GridAxis, news::NewsPhase},
data::{
domain::{CandleDirection, Price, Quantity, TradeId},
event::{EconomicCalendarId, Ohlcv, OhlcvId},
view::StreamView,
},
error::{AgentError, ChapatyResult},
gym::trading::{
action::{Action, Actions, OpenCmd},
observation::Observation,
types::TradeType,
},
};
/// News-fade trading agent: trades against the direction of the candle that opens at an
/// economic news release.
///
/// After a release on `economic_cal_id`, the agent takes the `ohlcv_id` candle whose open
/// timestamp equals the release time, waits `wait_duration`, and then opens a position
/// opposite to that candle's direction (the state machine lives in `Agent::act`).
///
/// Only the strategy parameters are serialized; identifiers and runtime state are skipped.
#[serde_as]
#[derive(Debug, Clone, Copy, Serialize)]
pub struct NewsFade {
    /// Calendar whose most recent event triggers a trade setup.
    #[serde(skip)]
    economic_cal_id: EconomicCalendarId,
    /// Candle stream used for the news candle, the entry-price estimate and the orders.
    #[serde(skip)]
    ohlcv_id: OhlcvId,
    /// Delay between the news timestamp and trade entry; serialized as an integer number
    /// of seconds.
    #[serde_as(as = "DurationSeconds<i64>")]
    wait_duration: Duration,
    /// Take-profit distance from the news candle's close, in multiples of its body size.
    take_profit_risk_factor: f64,
    /// Stop-loss distance as a fraction of the entry-to-take-profit distance.
    risk_reward_ratio: f64,
    /// Runtime state: waiting for news, or handling a specific release.
    #[serde(skip)]
    phase: NewsPhase,
    /// Counter used to generate `TradeId`s; incremented before each new id is issued.
    #[serde(skip)]
    trade_counter: i64,
    /// Timestamp of the last news event that was fully handled, so it is not traded twice.
    #[serde(skip)]
    last_processed_news: Option<DateTime<Utc>>,
}
impl NewsFade {
    /// Creates an agent for the given news calendar and candle stream with the default
    /// parameters: a 420 s wait after the release, take-profit factor `1.27` and
    /// risk/reward ratio `0.276`.
    ///
    /// Runtime state starts empty: default phase, no trades issued, no news processed.
    pub fn baseline(economic_cal_id: EconomicCalendarId, ohlcv_id: OhlcvId) -> Self {
        Self {
            economic_cal_id,
            ohlcv_id,
            wait_duration: Duration::seconds(420),
            take_profit_risk_factor: 1.27,
            risk_reward_ratio: 0.276,
            phase: NewsPhase::default(),
            trade_counter: 0,
            last_processed_news: None,
        }
    }

    /// Economic calendar whose events trigger this agent.
    pub fn economic_calendar_id(&self) -> EconomicCalendarId {
        self.economic_cal_id
    }

    /// Candle stream used for the news candle, pricing and orders.
    pub fn ohlcv_id(&self) -> OhlcvId {
        self.ohlcv_id
    }

    /// Delay between the news timestamp and trade entry.
    pub fn wait_duration(&self) -> Duration {
        self.wait_duration
    }

    /// Take-profit distance from the news candle's close, in multiples of its body size.
    pub fn take_profit_risk_factor(&self) -> f64 {
        self.take_profit_risk_factor
    }

    /// Stop-loss distance as a fraction of the entry-to-take-profit distance.
    pub fn risk_reward_ratio(&self) -> f64 {
        self.risk_reward_ratio
    }

    /// Replaces the economic calendar whose events trigger the agent.
    pub fn with_calendar_id(self, economic_cal_id: EconomicCalendarId) -> Self {
        Self {
            economic_cal_id,
            ..self
        }
    }

    /// Replaces the candle stream used for the news candle, pricing and orders.
    pub fn with_ohlcv_id(self, ohlcv_id: OhlcvId) -> Self {
        Self { ohlcv_id, ..self }
    }

    /// Sets how long to wait after the news timestamp before entering a trade.
    ///
    /// Despite its name, this takes a [`Duration`], not a number of candles.
    pub fn with_candles_after_news(self, duration: Duration) -> Self {
        Self {
            wait_duration: duration,
            ..self
        }
    }

    /// Sets the take-profit distance as a multiple of the news candle's body size.
    ///
    /// The value is not validated; a negative factor would place the take-profit on the
    /// other side of the news candle's close.
    pub fn with_take_profit_risk_factor(self, factor: f64) -> Self {
        Self {
            take_profit_risk_factor: factor,
            ..self
        }
    }

    /// Sets the stop-loss distance as a fraction of the entry-to-take-profit distance.
    ///
    /// # Errors
    /// Returns [`AgentError::InvalidInput`] unless `ratio` is finite and greater than
    /// `0.0`. `NaN` and infinities are rejected as well: `NaN <= 0.0` is `false`, so a
    /// plain comparison would let `NaN` through and produce a `NaN` stop-loss, and an
    /// infinite ratio would produce an infinite stop-loss.
    pub fn with_risk_reward_ratio(self, ratio: f64) -> ChapatyResult<Self> {
        if !ratio.is_finite() || ratio <= 0.0 {
            return Err(AgentError::InvalidInput(
                "risk_reward_ratio must be finite and > 0.0".to_string(),
            )
            .into());
        }
        Ok(Self {
            risk_reward_ratio: ratio,
            ..self
        })
    }
}
impl NewsFade {
    /// Derives the fade trade from the candle that opened at the news release.
    ///
    /// The trade goes against the candle: a bearish candle yields a long, a bullish
    /// candle a short. The take-profit sits `take_profit_risk_factor` candle bodies
    /// (`|open - close|`) away from the candle's close: above it for a long, below it
    /// for a short (for a positive factor).
    ///
    /// Returns `None` for a doji, which has no direction to fade.
    fn take_profit_target(&self, news_candle: &Ohlcv) -> Option<TakeProfitTarget> {
        let anchor = news_candle.close.0;
        let body = (news_candle.open.0 - anchor).abs();
        let reach = body * self.take_profit_risk_factor;

        let (trade_type, target) = match news_candle.direction() {
            CandleDirection::Doji => return None,
            CandleDirection::Bearish => (TradeType::Long, anchor + reach),
            CandleDirection::Bullish => (TradeType::Short, anchor - reach),
        };

        Some(TakeProfitTarget {
            take_profit_price: Price(target),
            trade_type,
        })
    }
}
impl Agent for NewsFade {
    /// Advances the news-fade state machine by one observation.
    ///
    /// * While this agent has an active trade, nothing happens.
    /// * In `AwaitingNews`, if the latest event of `economic_cal_id` differs from the last
    ///   handled one, it is latched together with the latest `ohlcv_id` candle, provided
    ///   that candle opened exactly at the event's timestamp. The agent then switches to
    ///   `PostNews`.
    /// * In `PostNews` without such a candle, the agent drops back to `AwaitingNews`
    ///   without marking the news as handled, so the lookup is retried on the next call.
    /// * Once `wait_duration` has elapsed since the news, a trade opposite to the news
    ///   candle is opened. No trade is opened if the candle is a doji or if price has
    ///   already reached the take-profit level. In all three cases the news is marked as
    ///   handled.
    ///
    /// # Errors
    /// Propagates the error of `try_resolved_close_price`. The agent then stays in
    /// `PostNews` without consuming a trade id, so the next call retries.
    fn act(&mut self, obs: Observation) -> ChapatyResult<Actions> {
        let current_time = obs.market_view.current_timestamp();
        let agent_id = self.identifier();

        // One position at a time: stay idle while one of our trades is still active.
        if obs.states.any_active_trade_for_agent(&agent_id) {
            return Ok(Actions::no_op());
        }

        // AwaitingNews -> PostNews: latch the newest unhandled release together with the
        // candle that opened exactly at its timestamp (if that candle is the latest one).
        if let NewsPhase::AwaitingNews = self.phase
            && let Some(news_event) = obs
                .market_view
                .economic_news()
                .last_event(&self.economic_cal_id)
        {
            if Some(news_event.timestamp) == self.last_processed_news {
                return Ok(Actions::no_op());
            }
            let news_candle = obs
                .market_view
                .ohlcv()
                .last_event(&self.ohlcv_id)
                .filter(|candle| candle.open_timestamp == news_event.timestamp)
                .copied();
            self.phase = NewsPhase::PostNews {
                news_time: news_event.timestamp,
                news_candle,
            };
        }

        // Without a news candle there is nothing to fade. The news is deliberately NOT
        // marked as processed here, so the candle lookup is retried on the next call.
        let NewsPhase::PostNews {
            news_time,
            news_candle: Some(candle),
        } = self.phase
        else {
            self.phase = NewsPhase::AwaitingNews;
            return Ok(Actions::no_op());
        };

        if current_time < news_time + self.wait_duration {
            return Ok(Actions::no_op());
        }

        // Doji: no direction to fade, so consume this news event.
        let Some(tp_target) = self.take_profit_target(&candle) else {
            self.last_processed_news = Some(news_time);
            self.phase = NewsPhase::AwaitingNews;
            return Ok(Actions::no_op());
        };

        // Resolve the price BEFORE touching `trade_counter`. If resolution fails, the
        // error is propagated with no trade id consumed, and the agent stays in
        // `PostNews` to retry on the next call.
        let estimated_entry = obs
            .market_view
            .try_resolved_close_price(&self.ohlcv_id.symbol)?;

        // If price already reached the fade target during the wait, the take-profit now
        // sits on the losing side of the entry, and `stop_loss_price` would put the stop
        // on the winning side. Skip this news instead of sending that inverted order.
        if !tp_target.is_ahead_of(estimated_entry) {
            self.last_processed_news = Some(news_time);
            self.phase = NewsPhase::AwaitingNews;
            return Ok(Actions::no_op());
        }

        self.trade_counter += 1;
        let cmd = OpenCmd {
            agent_id,
            trade_id: TradeId(self.trade_counter),
            trade_type: tp_target.trade_type,
            quantity: Quantity(1.0),
            // No explicit entry price (presumably a market entry — confirm against
            // `OpenCmd`); `estimated_entry` is only used to derive the stop-loss.
            entry_price: None,
            stop_loss: Some(tp_target.stop_loss_price(estimated_entry, self.risk_reward_ratio)),
            take_profit: Some(tp_target.take_profit_price),
        };

        self.last_processed_news = Some(news_time);
        self.phase = NewsPhase::AwaitingNews;
        Ok(Actions::from((self.ohlcv_id.into(), Action::Open(cmd))))
    }

    /// Returns the fixed name `"NewsFade"`; every instance shares this identifier.
    ///
    /// NOTE(review): `act` passes this identifier to `any_active_trade_for_agent`, so
    /// several instances sharing one environment would presumably see each other's
    /// trades — confirm before running them side by side.
    fn identifier(&self) -> AgentIdentifier {
        AgentIdentifier::Named(Arc::new("NewsFade".to_string()))
    }

    /// Clears the runtime state (phase, trade counter, processed-news marker) while
    /// keeping all strategy parameters.
    fn reset(&mut self) {
        self.phase = NewsPhase::AwaitingNews;
        self.trade_counter = 0;
        self.last_processed_news = None;
    }
}

/// Side and take-profit level of a fade trade derived from a news candle.
struct TakeProfitTarget {
    /// Absolute take-profit price.
    take_profit_price: Price,
    /// Side of the trade (opposite to the news candle's direction).
    trade_type: TradeType,
}

impl TakeProfitTarget {
    /// Returns `true` when the take-profit is still strictly on the profitable side of
    /// `entry_price`: above it for a long, below it for a short.
    ///
    /// Also returns `false` if either price is `NaN`.
    fn is_ahead_of(&self, entry_price: Price) -> bool {
        let (target, entry) = (self.take_profit_price.0, entry_price.0);
        match self.trade_type {
            TradeType::Long => target > entry,
            TradeType::Short => target < entry,
        }
    }

    /// Places the stop-loss on the opposite side of `entry_price` from the take-profit,
    /// at `risk_reward_ratio` times the entry-to-take-profit distance.
    ///
    /// Only meaningful when [`Self::is_ahead_of`] holds for `entry_price`; otherwise the
    /// stop would land on the profitable side of the entry.
    fn stop_loss_price(&self, entry_price: Price, risk_reward_ratio: f64) -> Price {
        let tp = self.take_profit_price.0;
        let entry = entry_price.0;
        let sl = match self.trade_type {
            TradeType::Long => entry - (tp - entry) * risk_reward_ratio,
            TradeType::Short => entry + (entry - tp) * risk_reward_ratio,
        };
        Price(sl)
    }
}
/// Parameter grid over [`NewsFade`] agents for a fixed calendar and candle stream.
///
/// `build` expands it into the cartesian product of the wait range and both axes.
pub struct NewsFadeGrid {
    /// Calendar passed to every generated agent.
    cal_id: EconomicCalendarId,
    /// Candle stream passed to every generated agent.
    ohlcv_id: OhlcvId,
    /// `(start, end)` of the wait-duration sweep; `build` steps through it in whole
    /// minutes with `end` excluded.
    wait_duration: (Duration, Duration),
    /// Axis of `take_profit_risk_factor` values.
    tp_risk_factor: GridAxis,
    /// Axis of `risk_reward_ratio` values.
    risk_reward: GridAxis,
}
impl NewsFadeGrid {
pub fn baseline(cal_id: EconomicCalendarId, ohlcv_id: OhlcvId) -> ChapatyResult<Self> {
Ok(Self {
cal_id,
ohlcv_id,
wait_duration: (Duration::minutes(5), Duration::minutes(30)),
tp_risk_factor: GridAxis::new("0.5", "3.0", "0.01")?,
risk_reward: GridAxis::new("0.1", "1.0", "0.01")?,
})
}
pub fn with_candles_after_news(self, start: Duration, end: Duration) -> Self {
Self {
wait_duration: (start, end),
..self
}
}
pub fn with_take_profit_risk_factor(self, axis: GridAxis) -> Self {
Self {
tp_risk_factor: axis,
..self
}
}
pub fn with_risk_reward_ratio(self, axis: GridAxis) -> Self {
Self {
risk_reward: axis,
..self
}
}
pub fn build(self) -> (usize, impl ParallelIterator<Item = (usize, NewsFade)>) {
let (start_wait, end_wait) = self.wait_duration;
let candles_after_news = (start_wait.num_minutes()..end_wait.num_minutes())
.map(Duration::minutes)
.collect::<Vec<_>>();
let take_profit_factors = self.tp_risk_factor.generate();
let risk_rewards = self.risk_reward.generate();
let mut args = iproduct!(risk_rewards, candles_after_news, take_profit_factors)
.enumerate()
.map(|(uid, (rrr, wait, tprf))| NewsFadeArgs {
uid,
rrr,
wait,
tprf,
})
.collect::<Vec<_>>();
let mut rng = rand::rng();
args.shuffle(&mut rng);
let total_combinations = args.len();
let cal_id = self.cal_id;
let ohlcv_id = self.ohlcv_id;
let iterator = args.into_par_iter().map(move |arg| {
(
arg.uid,
NewsFade::baseline(cal_id, ohlcv_id)
.with_candles_after_news(arg.wait)
.with_take_profit_risk_factor(arg.tprf)
.with_risk_reward_ratio(arg.rrr)
.expect("Valid grid parameters"),
)
});
(total_combinations, iterator)
}
}
/// One point of a [`NewsFadeGrid`] before it is turned into an agent.
#[derive(Debug, Clone, Copy)]
struct NewsFadeArgs {
    /// Index of this combination in the unshuffled cartesian product.
    uid: usize,
    /// Risk/reward ratio for the agent.
    rrr: f64,
    /// Take-profit risk factor for the agent.
    tprf: f64,
    /// Wait duration after the news for the agent.
    wait: Duration,
}