use std::sync::Arc;
use chrono::Duration;
use itertools::iproduct;
use rand::seq::SliceRandom;
use rayon::iter::{IntoParallelIterator, ParallelIterator};
use serde::Serialize;
use serde_with::{DurationSeconds, serde_as};
use crate::{
agent::{Agent, AgentIdentifier, GridAxis, news::NewsPhase},
data::{
domain::{CandleDirection, Price, Quantity, TradeId},
event::{EconomicCalendarId, MarketId, Ohlcv, OhlcvId},
view::StreamView,
},
error::{AgentError, ChapatyResult},
gym::trading::{
action::{Action, Actions, OpenCmd},
observation::Observation,
types::TradeType,
},
};
/// Agent that trades a price breakout of the candle opened at an economic news release.
///
/// Only the tunable strategy parameters are serialized; stream identifiers and runtime
/// state (`phase`, `trade_counter`) are marked `#[serde(skip)]`.
#[serde_as]
#[derive(Debug, Clone, Copy, Serialize)]
pub struct NewsBreakout {
/// Economic calendar whose most recent event is treated as the news release.
#[serde(skip)]
economic_cal_id: EconomicCalendarId,
/// OHLCV stream providing the news candle; its symbol is also used to resolve the entry
/// price, and it is converted into the `MarketId` the order is sent to.
#[serde(skip)]
ohlcv_id: OhlcvId,
/// Minimum time after the news timestamp before an entry is allowed (serialized as seconds).
#[serde_as(as = "DurationSeconds<i64>")]
earliest_entry: Duration,
/// Maximum time after the news timestamp at which an entry is still allowed; past it the
/// agent goes back to waiting for news (serialized as seconds).
#[serde_as(as = "DurationSeconds<i64>")]
latest_entry: Duration,
/// Stop-loss distance from the news candle's close, in multiples of the candle body size.
stop_loss_risk_factor: f64,
/// Divisor applied to the risk distance (entry to stop) to get the take-profit distance.
risk_reward_ratio: f64,
/// Current state of the news state machine; runtime-only.
#[serde(skip)]
phase: NewsPhase,
/// Counter used to generate sequential `TradeId`s; reset to zero by `reset`.
#[serde(skip)]
trade_counter: i64,
}
impl NewsBreakout {
    /// Creates an agent with the default strategy parameters.
    ///
    /// Defaults: entries allowed from 480 s to 3000 s after the news timestamp, a stop-loss
    /// risk factor of `0.89` and a risk/reward ratio of `0.726`. The agent starts in the
    /// default [`NewsPhase`] with its trade counter at zero.
    pub fn baseline(economic_cal_id: EconomicCalendarId, ohlcv_id: OhlcvId) -> Self {
        Self {
            economic_cal_id,
            ohlcv_id,
            earliest_entry: Duration::seconds(480),
            latest_entry: Duration::seconds(3000),
            stop_loss_risk_factor: 0.89,
            risk_reward_ratio: 0.726,
            phase: NewsPhase::default(),
            trade_counter: 0,
        }
    }

    /// Economic calendar whose latest event is treated as the news release.
    pub fn economic_calendar_id(&self) -> EconomicCalendarId {
        self.economic_cal_id
    }

    /// OHLCV stream providing the news candle and the traded symbol.
    pub fn ohlcv_id(&self) -> OhlcvId {
        self.ohlcv_id
    }

    /// Minimum time after the news timestamp before an entry is allowed.
    pub fn earliest_entry(&self) -> Duration {
        self.earliest_entry
    }

    /// Maximum time after the news timestamp at which an entry is still allowed.
    pub fn latest_entry(&self) -> Duration {
        self.latest_entry
    }

    /// Stop-loss distance from the news candle's close, in multiples of its body size.
    pub fn stop_loss_risk_factor(&self) -> f64 {
        self.stop_loss_risk_factor
    }

    /// Divisor applied to the risk distance to obtain the take-profit distance.
    pub fn risk_reward_ratio(&self) -> f64 {
        self.risk_reward_ratio
    }

    /// Returns a copy that reads news from `economic_cal_id`.
    pub fn with_calendar_id(self, economic_cal_id: EconomicCalendarId) -> Self {
        Self {
            economic_cal_id,
            ..self
        }
    }

    /// Returns a copy that reads candles (and trades) on `ohlcv_id`.
    pub fn with_ohlcv_id(self, ohlcv_id: OhlcvId) -> Self {
        Self { ohlcv_id, ..self }
    }

    /// Returns a copy whose earliest entry, measured from the news timestamp, is `duration`.
    pub fn with_earliest_entry_candle(self, duration: Duration) -> Self {
        Self {
            earliest_entry: duration,
            ..self
        }
    }

    /// Returns a copy whose latest entry, measured from the news timestamp, is `duration`.
    pub fn with_latest_entry_candle(self, duration: Duration) -> Self {
        Self {
            latest_entry: duration,
            ..self
        }
    }

    /// Returns a copy with the given stop-loss risk factor (multiples of the candle body).
    pub fn with_stop_loss_risk_factor(self, factor: f64) -> Self {
        Self {
            stop_loss_risk_factor: factor,
            ..self
        }
    }

    /// Returns a copy with the given risk/reward ratio.
    ///
    /// # Errors
    ///
    /// Returns [`AgentError::InvalidInput`] (converted into the crate error type) when
    /// `ratio` is not a finite number strictly greater than `0.0`. The ratio is used as a
    /// divisor when the take-profit is computed, so zero, negative, NaN or infinite values
    /// would yield a meaningless target.
    pub fn with_risk_reward_ratio(self, ratio: f64) -> ChapatyResult<Self> {
        // `ratio <= 0.0` alone would accept NaN (every comparison with NaN is false) and +inf.
        if !ratio.is_finite() || ratio <= 0.0 {
            return Err(AgentError::InvalidInput(
                "risk_reward_ratio must be finite and > 0.0".to_string(),
            )
            .into());
        }
        Ok(Self {
            risk_reward_ratio: ratio,
            ..self
        })
    }
}
impl NewsBreakout {
    /// Derives the trade side and protective stop from the news candle.
    ///
    /// The stop is placed `stop_loss_risk_factor` candle bodies away from the close:
    /// above it with a `Short` side for a bearish candle, below it with a `Long` side for
    /// a bullish candle. Returns `None` for a [`CandleDirection::Doji`] candle.
    fn stop_loss_target(&self, news_candle: &Ohlcv) -> Option<StopLossTarget> {
        let close_px = news_candle.close.0;
        let offset = (news_candle.open.0 - close_px).abs() * self.stop_loss_risk_factor;

        let (level, side) = match news_candle.direction() {
            CandleDirection::Bearish => (close_px + offset, TradeType::Short),
            CandleDirection::Bullish => (close_px - offset, TradeType::Long),
            CandleDirection::Doji => return None,
        };

        Some(StopLossTarget {
            stop_loss_price: Price(level),
            trade_type: side,
        })
    }
}
impl Agent for NewsBreakout {
fn act(&mut self, obs: Observation) -> ChapatyResult<Actions> {
let economic_cal_id = self.economic_cal_id;
let ohlcv_id = self.ohlcv_id;
let current_time = obs.market_view.current_timestamp();
if obs.states.any_active_trade_for_agent(&self.identifier()) {
return Ok(Actions::no_op());
}
if let NewsPhase::AwaitingNews = self.phase
&& let Some(news_event) = obs.market_view.economic_news().last_event(&economic_cal_id)
{
let news_candle_candidate = obs
.market_view
.ohlcv()
.last_event(&ohlcv_id)
.filter(|candle| candle.open_timestamp == news_event.timestamp)
.copied();
self.phase = NewsPhase::PostNews {
news_time: news_event.timestamp,
news_candle: news_candle_candidate,
};
}
let (news_time, candle) = if let NewsPhase::PostNews {
news_time,
news_candle: Some(candle),
} = self.phase
{
(news_time, candle)
} else {
self.phase = NewsPhase::AwaitingNews;
return Ok(Actions::no_op());
};
let time_since_news = current_time - news_time;
if time_since_news < self.earliest_entry {
return Ok(Actions::no_op());
}
if time_since_news > self.latest_entry {
self.phase = NewsPhase::AwaitingNews;
return Ok(Actions::no_op());
}
let entry_price = obs.market_view.try_resolved_close_price(&ohlcv_id.symbol)?;
let breakout_up = entry_price.0 > candle.high.0;
let breakout_down = entry_price.0 < candle.low.0;
if !breakout_up && !breakout_down {
return Ok(Actions::no_op()); }
let sl_target = match self.stop_loss_target(&candle) {
Some(tp) => tp,
None => {
self.phase = NewsPhase::AwaitingNews;
return Ok(Actions::no_op());
}
};
self.trade_counter += 1;
let trade_id = TradeId(self.trade_counter);
let quantity = Quantity(1.0);
let cmd = OpenCmd {
agent_id: self.identifier(),
trade_id,
trade_type: sl_target.trade_type,
quantity,
entry_price: None,
stop_loss: Some(sl_target.stop_loss_price),
take_profit: Some(sl_target.take_profit_price(entry_price, self.risk_reward_ratio)),
};
self.phase = NewsPhase::AwaitingNews;
let market_id: MarketId = ohlcv_id.into();
Ok(Actions::from((market_id, Action::Open(cmd))))
}
fn identifier(&self) -> AgentIdentifier {
AgentIdentifier::Named(Arc::new("NewsBreakout".to_string()))
}
fn reset(&mut self) {
self.phase = NewsPhase::AwaitingNews;
self.trade_counter = 0;
}
}
/// Trade side and stop-loss level derived from a directional news candle.
struct StopLossTarget {
/// Price at which the position is stopped out.
stop_loss_price: Price,
/// Side of the trade: `Long` for a bullish news candle, `Short` for a bearish one.
trade_type: TradeType,
}
impl StopLossTarget {
    /// Computes the take-profit level for a position entered at `entry_price`.
    ///
    /// The risk distance is the gap between the entry and the stop loss. The take-profit
    /// is offset from the entry by that distance divided by `risk_reward_ratio`: added for
    /// `Long` trades, subtracted for `Short` trades.
    fn take_profit_price(&self, entry_price: Price, risk_reward_ratio: f64) -> Price {
        let entry = entry_price.0;
        let stop = self.stop_loss_price.0;

        let target = match self.trade_type {
            TradeType::Short => entry - (stop - entry) / risk_reward_ratio,
            TradeType::Long => entry + (entry - stop) / risk_reward_ratio,
        };

        Price(target)
    }
}
/// Parameter grid used to generate many [`NewsBreakout`] agents for a parameter sweep.
pub struct NewsBreakoutGrid {
/// Economic calendar passed to every generated agent.
cal_id: EconomicCalendarId,
/// OHLCV stream passed to every generated agent.
market_id: OhlcvId,
/// `(start, end)` range of earliest-entry offsets, expanded in whole minutes by `build`.
earliest_entry: (Duration, Duration),
/// `(start, end)` range of latest-entry offsets, expanded in whole minutes by `build`.
latest_entry: (Duration, Duration),
/// Axis of stop-loss risk factors to sweep.
stop_loss_risk_factor: GridAxis,
/// Axis of risk/reward ratios to sweep.
risk_reward_ratio: GridAxis,
}
impl NewsBreakoutGrid {
    /// Creates a grid with the default search ranges.
    ///
    /// Earliest entry spans 1–6 minutes and latest entry 20–28 minutes (expanded
    /// end-exclusively by [`build`](Self::build)). The stop-loss factor axis is
    /// `GridAxis::new("0.5", "1.5", "0.01")` and the risk/reward axis is
    /// `GridAxis::new("0.1", "2.6", "0.01")`.
    ///
    /// # Errors
    ///
    /// Propagates any error from `GridAxis::new`.
    pub fn baseline(cal_id: EconomicCalendarId, market_id: OhlcvId) -> ChapatyResult<Self> {
        Ok(Self {
            cal_id,
            market_id,
            earliest_entry: (Duration::minutes(1), Duration::minutes(6)),
            latest_entry: (Duration::minutes(20), Duration::minutes(28)),
            stop_loss_risk_factor: GridAxis::new("0.5", "1.5", "0.01")?,
            risk_reward_ratio: GridAxis::new("0.1", "2.6", "0.01")?,
        })
    }

    /// Replaces the earliest-entry range (`end` is exclusive when expanded).
    pub fn with_earliest_entry_range(self, start: Duration, end: Duration) -> Self {
        Self {
            earliest_entry: (start, end),
            ..self
        }
    }

    /// Replaces the latest-entry range (`end` is exclusive when expanded).
    pub fn with_latest_entry_range(self, start: Duration, end: Duration) -> Self {
        Self {
            latest_entry: (start, end),
            ..self
        }
    }

    /// Replaces the stop-loss risk factor axis.
    pub fn with_stop_loss_risk_factor(self, axis: GridAxis) -> Self {
        Self {
            stop_loss_risk_factor: axis,
            ..self
        }
    }

    /// Replaces the risk/reward ratio axis.
    pub fn with_risk_reward_ratio(self, axis: GridAxis) -> Self {
        Self {
            risk_reward_ratio: axis,
            ..self
        }
    }

    /// Expands the grid into concrete [`NewsBreakout`] agents.
    ///
    /// Entry offsets are stepped in whole minutes over half-open `[start, end)` ranges.
    /// Combinations are dropped when:
    /// - `earliest >= latest`, or
    /// - the risk/reward ratio would be rejected by [`NewsBreakout::with_risk_reward_ratio`]
    ///   (non-finite or `<= 0.0`).
    ///
    /// Every remaining combination gets a dense `uid` in `0..total` before the list is
    /// shuffled, so the parallel iterator yields agents in random order.
    ///
    /// Returns the number of combinations and a parallel iterator over `(uid, agent)` pairs.
    pub fn build(self) -> (usize, impl ParallelIterator<Item = (usize, NewsBreakout)>) {
        let stop_loss_risk_factors = self.stop_loss_risk_factor.generate();
        let risk_reward_ratios = self.risk_reward_ratio.generate();
        let earliest_entries = Self::minute_steps(self.earliest_entry);
        let latest_entries = Self::minute_steps(self.latest_entry);

        let mut args = iproduct!(
            risk_reward_ratios,
            stop_loss_risk_factors,
            latest_entries,
            earliest_entries
        )
        // Filter before `enumerate` so uids stay dense, and so the `expect` below can
        // only fire on a genuine invariant violation instead of panicking inside a rayon
        // worker for a user-supplied ratio.
        .filter(|(rrr, _, latest, earliest)| {
            earliest < latest && rrr.is_finite() && *rrr > 0.0
        })
        .enumerate()
        .map(|(uid, (rrr, slrf, latest, earliest))| NewsBreakoutArgs {
            uid,
            rrr,
            slrf,
            latest,
            earliest,
        })
        .collect::<Vec<_>>();

        let mut rng = rand::rng();
        args.shuffle(&mut rng);

        let total_combinations = args.len();
        let cal_id = self.cal_id;
        let market_id = self.market_id;
        let iterator = args.into_par_iter().map(move |arg| {
            (
                arg.uid,
                NewsBreakout::baseline(cal_id, market_id)
                    .with_earliest_entry_candle(arg.earliest)
                    .with_latest_entry_candle(arg.latest)
                    .with_stop_loss_risk_factor(arg.slrf)
                    .with_risk_reward_ratio(arg.rrr)
                    .expect("Valid grid parameters"),
            )
        });
        (total_combinations, iterator)
    }

    /// Whole-minute offsets from `start` (inclusive) to `end` (exclusive), using
    /// [`Duration::num_minutes`] on both bounds.
    fn minute_steps((start, end): (Duration, Duration)) -> Vec<Duration> {
        (start.num_minutes()..end.num_minutes())
            .map(Duration::minutes)
            .collect()
    }
}
/// One point of the expanded [`NewsBreakoutGrid`], before it is turned into an agent.
#[derive(Debug, Clone, Copy)]
struct NewsBreakoutArgs {
/// Dense index assigned before shuffling; returned alongside the built agent.
uid: usize,
/// Risk/reward ratio for the agent.
rrr: f64,
/// Stop-loss risk factor for the agent.
slrf: f64,
/// Latest entry offset from the news timestamp.
latest: Duration,
/// Earliest entry offset from the news timestamp.
earliest: Duration,
}