use std::sync::Arc;
use rayon::iter::{IntoParallelIterator, ParallelIterator};
use serde::Serialize;
use crate::{
agent::{
Agent, AgentIdentifier,
news::{
breakout::{NewsBreakout, NewsBreakoutGrid},
fade::{NewsFade, NewsFadeGrid},
},
},
error::ChapatyResult,
gym::trading::{
action::{Action, Actions, MarketCloseCmd},
observation::Observation,
},
};
/// Agent that combines a [`NewsBreakout`] and a [`NewsFade`] sub-agent.
///
/// Breakout open signals take priority over fade open signals; the arbitration rules are
/// implemented in this type's `Agent::act`.
#[derive(Debug, Clone, Copy, Serialize)]
pub struct NewsHybrid {
    /// Breakout sub-agent. When it signals an open, its actions win and any active fade
    /// trade found in the observation is closed.
    pub breakout: NewsBreakout,
    /// Fade sub-agent. Its open signals are only forwarded when the breakout sub-agent has
    /// no active trade.
    pub fade: NewsFade,
}
impl Agent for NewsHybrid {
    /// Steps both sub-agents on the same observation and arbitrates between them.
    ///
    /// Both sub-agents are always stepped, fade first and then breakout. An error from
    /// either is propagated immediately via `?`. The outcome is then chosen as follows:
    ///
    /// 1. **Breakout opens**: the breakout sub-agent emitted an open action for its own
    ///    `ohlcv_id`. Its actions are returned. If the observation reports an active trade
    ///    for the fade sub-agent, a full-quantity `MarketClose` for that trade is appended.
    /// 2. **Fade opens, breakout idle**: the fade sub-agent opens and the breakout
    ///    sub-agent has no active trade. The fade sub-agent's actions are returned.
    /// 3. **Anything else**: `Actions::no_op()` is returned.
    ///
    /// NOTE(review): in case 3 every action the sub-agents produced is discarded, including
    /// non-open ones. Confirm the sub-agents never rely on emitting closes themselves.
    ///
    /// NOTE(review): only the single trade returned by `find_active_trade_for_agent` is
    /// closed. If the fade agent can hold several trades, the others stay open.
    fn act(&mut self, obs: Observation) -> ChapatyResult<Actions> {
        // `obs` is cloned for each sub-agent because it is still needed below
        // to inspect the active trades.
        let fade_actions = self.fade.act(obs.clone())?;
        let breakout_actions = self.breakout.act(obs.clone())?;

        let breakout_opens = breakout_actions.any_open_action(&self.breakout.ohlcv_id().into());
        let fade_opens = fade_actions.any_open_action(&self.fade.ohlcv_id().into());

        if breakout_opens {
            let fade_id = self.fade.identifier();
            let actions = match obs.states.find_active_trade_for_agent(&fade_id) {
                Some((market_id, state)) => {
                    // Flatten the fade position completely before following the breakout.
                    let close = MarketCloseCmd {
                        agent_id: fade_id,
                        trade_id: state.trade_id(),
                        quantity: Some(state.quantity()),
                    };
                    breakout_actions.with_action(market_id, Action::MarketClose(close))
                }
                None => breakout_actions,
            };
            return Ok(actions);
        }

        if !fade_opens {
            return Ok(Actions::no_op());
        }

        // A fade entry is suppressed while a breakout trade is still running.
        let breakout_id = self.breakout.identifier();
        let breakout_busy = obs
            .states
            .find_active_trade_for_agent(&breakout_id)
            .is_some();

        Ok(if breakout_busy {
            Actions::no_op()
        } else {
            fade_actions
        })
    }

    /// Returns the fixed name `"NewsHybrid"`.
    ///
    /// The name does not depend on the sub-agents' parameters, so every instance reports
    /// the same identifier.
    fn identifier(&self) -> AgentIdentifier {
        let name = String::from("NewsHybrid");
        AgentIdentifier::Named(Arc::new(name))
    }

    /// Resets both sub-agents, breakout first and then fade.
    fn reset(&mut self) {
        let Self { breakout, fade } = self;
        breakout.reset();
        fade.reset();
    }
}
/// Parameter grid for [`NewsHybrid`].
///
/// [`NewsHybridGrid::build`] yields the Cartesian product of the fade grid and the
/// breakout grid.
pub struct NewsHybridGrid {
    /// Grid of fade sub-agent configurations.
    pub fade: NewsFadeGrid,
    /// Grid of breakout sub-agent configurations.
    pub breakout: NewsBreakoutGrid,
}
impl NewsHybridGrid {
pub fn build(self) -> (usize, impl ParallelIterator<Item = (usize, NewsHybrid)>) {
let (len_breakout, iter_breakout) = self.breakout.build();
let (len_fade, iter_fade) = self.fade.build();
let total_combinations = len_breakout * len_fade;
let fade_agents = iter_fade.map(|(_, agent)| agent).collect::<Vec<_>>();
let fade_arc = Arc::new(fade_agents);
let iterator = iter_breakout.flat_map(move |(b_uid, breakout)| {
let fade_ref = fade_arc.clone();
let len_fade = fade_ref.len();
(0..len_fade).into_par_iter().map(move |f_uid| {
let fade = fade_ref[f_uid];
let hybrid_uid = (b_uid * len_fade) + f_uid;
(hybrid_uid, NewsHybrid { breakout, fade })
})
});
(total_combinations, iterator)
}
}