nash_protocol/protocol/
hooks.rs

//! This module contains types that allow protocol requests and pipelines to call other
//! requests and pipelines via before and after hooks. These need to be explicitly encoded
//! in enum types so that hooks of different concrete types share a single type the compiler
//! can store and dispatch on.
4
5use std::sync::Arc;
6
7use async_trait::async_trait;
8use tokio::sync::{Mutex, RwLock};
9
10use crate::errors::{ProtocolError, Result};
11use crate::protocol::ErrorResponse;
12
13use super::asset_nonces::{AssetNoncesRequest, AssetNoncesResponse};
14use super::cancel_all_orders::{CancelAllOrders, CancelAllOrdersResponse};
15use super::dh_fill_pool::{DhFillPoolRequest, DhFillPoolResponse};
16use super::list_markets::{ListMarketsRequest, ListMarketsResponse};
17use super::orderbook::{OrderbookRequest, OrderbookResponse};
18use super::place_order::{LimitOrderRequest, PlaceOrderResponse};
19use super::sign_all_states::{SignAllPipelineState, SignAllStates};
20use super::sign_states::{SignStatesRequest, SignStatesResponse};
21use super::traits::{NashProtocol, NashProtocolPipeline};
22use super::{ResponseOrError, State};
23
24/// An enum wrapping all the different protocol requests
25#[derive(Debug, Clone)]
26pub enum NashProtocolRequest {
27    AssetNonces(AssetNoncesRequest),
28    DhFill(
29        DhFillPoolRequest,
30        Option<Arc<Mutex<Option<tokio::sync::OwnedSemaphorePermit>>>>,
31    ),
32    LimitOrder(LimitOrderRequest),
33    Orderbook(OrderbookRequest),
34    CancelOrders(CancelAllOrders),
35    SignState(SignStatesRequest),
36    ListMarkets(ListMarketsRequest),
37}
38
39/// An enum wrapping all the different protocol responses
40#[derive(Debug)]
41pub enum NashProtocolResponse {
42    AssetNonces(AssetNoncesResponse),
43    DhFill(DhFillPoolResponse),
44    LimitOrder(PlaceOrderResponse),
45    Orderbook(OrderbookResponse),
46    CancelOrders(CancelAllOrdersResponse),
47    SignState(SignStatesResponse),
48    ListMarkets(ListMarketsResponse),
49}
50
51/// Implement NashProtocol for the enum, threading through to the base implementation
52/// for each of the captured types. This could probably be automated wiht a macro.
53#[async_trait]
54impl NashProtocol for NashProtocolRequest {
55    type Response = NashProtocolResponse;
56
57    async fn acquire_permit(
58        &self,
59        state: Arc<RwLock<State>>,
60    ) -> Option<tokio::sync::OwnedSemaphorePermit> {
61        match self {
62            Self::AssetNonces(nonces) => NashProtocol::acquire_permit(nonces, state).await,
63            Self::DhFill(dh_fill, permit) => match permit {
64                Some(permit) => permit.lock().await.take(),
65                None => NashProtocol::acquire_permit(dh_fill, state).await,
66            },
67            Self::LimitOrder(limit_order) => NashProtocol::acquire_permit(limit_order, state).await,
68            Self::Orderbook(orderbook) => NashProtocol::acquire_permit(orderbook, state).await,
69            Self::SignState(sign_state) => NashProtocol::acquire_permit(sign_state, state).await,
70            Self::CancelOrders(cancel_all) => NashProtocol::acquire_permit(cancel_all, state).await,
71            Self::ListMarkets(list_markets) => {
72                NashProtocol::acquire_permit(list_markets, state).await
73            }
74        }
75    }
76
77    async fn graphql(&self, state: Arc<RwLock<State>>) -> Result<serde_json::Value> {
78        match self {
79            Self::AssetNonces(nonces) => nonces.graphql(state).await,
80            Self::DhFill(dh_fill, _permit) => dh_fill.graphql(state).await,
81            Self::LimitOrder(limit_order) => limit_order.graphql(state).await,
82            Self::Orderbook(orderbook) => orderbook.graphql(state).await,
83            Self::SignState(sign_state) => sign_state.graphql(state).await,
84            Self::CancelOrders(cancel_all) => cancel_all.graphql(state).await,
85            Self::ListMarkets(list_markets) => list_markets.graphql(state).await,
86        }
87    }
88
89    async fn response_from_json(
90        &self,
91        response: serde_json::Value,
92        state: Arc<RwLock<State>>,
93    ) -> Result<ResponseOrError<Self::Response>> {
94        match self {
95            Self::AssetNonces(nonces) => Ok(nonces
96                .response_from_json(response, state)
97                .await?
98                .map(Box::new(|res| NashProtocolResponse::AssetNonces(res)))),
99            Self::DhFill(dh_fill, _permit) => Ok(dh_fill
100                .response_from_json(response, state)
101                .await?
102                .map(Box::new(|res| NashProtocolResponse::DhFill(res)))),
103            Self::LimitOrder(limit_order) => Ok(limit_order
104                .response_from_json(response, state)
105                .await?
106                .map(Box::new(|res| NashProtocolResponse::LimitOrder(res)))),
107            Self::Orderbook(orderbook) => Ok(orderbook
108                .response_from_json(response, state)
109                .await?
110                .map(Box::new(|res| NashProtocolResponse::Orderbook(res)))),
111            Self::SignState(sign_state) => Ok(sign_state
112                .response_from_json(response, state)
113                .await?
114                .map(Box::new(|res| NashProtocolResponse::SignState(res)))),
115            Self::CancelOrders(cancel_all) => Ok(cancel_all
116                .response_from_json(response, state)
117                .await?
118                .map(Box::new(|res| NashProtocolResponse::CancelOrders(res)))),
119            Self::ListMarkets(list_markets) => Ok(list_markets
120                .response_from_json(response, state)
121                .await?
122                .map(Box::new(|res| NashProtocolResponse::ListMarkets(res)))),
123        }
124    }
125
126    async fn process_response(
127        &self,
128        response: &Self::Response,
129        state: Arc<RwLock<State>>,
130    ) -> Result<()> {
131        match (self, response) {
132            (Self::AssetNonces(nonces), NashProtocolResponse::AssetNonces(response)) => {
133                nonces.process_response(response, state).await?
134            }
135            (Self::DhFill(dh_fill, _permit), NashProtocolResponse::DhFill(response)) => {
136                dh_fill.process_response(response, state).await?
137            }
138            (Self::SignState(sign_req), NashProtocolResponse::SignState(response)) => {
139                sign_req.process_response(response, state).await?
140            }
141            (Self::LimitOrder(limit_order), NashProtocolResponse::LimitOrder(response)) => {
142                limit_order.process_response(response, state).await?
143            }
144            (Self::CancelOrders(cancel_all), NashProtocolResponse::CancelOrders(response)) => {
145                cancel_all.process_response(response, state).await?
146            }
147            (Self::ListMarkets(list_markets), NashProtocolResponse::ListMarkets(response)) => {
148                list_markets.process_response(response, state).await?
149            }
150            _ => {
151                return Err(ProtocolError(
152                    "Attempting to process a differently typed response. This should never happen.
153                    If you are seeing this error, there is something wrong with the client
154                    implementation of the generic protocol runtime loop.",
155                ))
156            }
157        };
158        Ok(())
159    }
160
161    async fn process_error(
162        &self,
163        response: &ErrorResponse,
164        state: Arc<RwLock<State>>,
165    ) -> Result<()> {
166        match self {
167            Self::AssetNonces(nonces) => nonces.process_error(response, state).await,
168            Self::DhFill(dh_fill, _permit) => dh_fill.process_error(response, state).await,
169            Self::LimitOrder(limit_order) => limit_order.process_error(response, state).await,
170            Self::Orderbook(orderbook) => orderbook.process_error(response, state).await,
171            Self::SignState(sign_state) => sign_state.process_error(response, state).await,
172            Self::CancelOrders(cancel_all) => cancel_all.process_error(response, state).await,
173            Self::ListMarkets(list_markets) => list_markets.process_error(response, state).await,
174        }
175    }
176
177    async fn run_before(&self, state: Arc<RwLock<State>>) -> Result<Option<Vec<ProtocolHook>>> {
178        match self {
179            Self::AssetNonces(nonces) => NashProtocol::run_before(nonces, state).await,
180            Self::DhFill(dh_fill, _permit) => NashProtocol::run_before(dh_fill, state).await,
181            Self::LimitOrder(limit_order) => NashProtocol::run_before(limit_order, state).await,
182            Self::Orderbook(orderbook) => NashProtocol::run_before(orderbook, state).await,
183            Self::SignState(sign_state) => NashProtocol::run_before(sign_state, state).await,
184            Self::CancelOrders(cancel_all) => NashProtocol::run_before(cancel_all, state).await,
185            Self::ListMarkets(list_markets) => NashProtocol::run_before(list_markets, state).await,
186        }
187    }
188
189    async fn run_after(&self, state: Arc<RwLock<State>>) -> Result<Option<Vec<ProtocolHook>>> {
190        match self {
191            Self::AssetNonces(nonces) => NashProtocol::run_after(nonces, state).await,
192            Self::DhFill(dh_fill, _permit) => NashProtocol::run_after(dh_fill, state).await,
193            Self::LimitOrder(limit_order) => NashProtocol::run_after(limit_order, state).await,
194            Self::Orderbook(orderbook) => NashProtocol::run_after(orderbook, state).await,
195            Self::SignState(sign_state) => NashProtocol::run_after(sign_state, state).await,
196            Self::CancelOrders(cancel_all) => NashProtocol::run_after(cancel_all, state).await,
197            Self::ListMarkets(list_markets) => NashProtocol::run_after(list_markets, state).await,
198        }
199    }
200}
201
202/// Captures and protocol request or pipeline that can be executed in a `run_before`
203/// or `run_after` hook for the `NashProtocol` and `NashProtocolPipeline` traits.
204#[derive(Clone, Debug)]
205pub enum ProtocolHook {
206    SignAllState(SignAllStates),
207    Protocol(NashProtocolRequest),
208}
209
210/// State representation for `ProtocolHook`
211pub enum ProtocolHookState {
212    SignAllStates(<SignAllStates as NashProtocolPipeline>::PipelineState),
213    Protocol(<NashProtocolRequest as NashProtocolPipeline>::PipelineState),
214}
215
216/// Implement `NashProtocolPipeline` for `ProtocolHook` so that hooks can be run as typical pipelines
217#[async_trait]
218impl NashProtocolPipeline for ProtocolHook {
219    type PipelineState = ProtocolHookState;
220    type ActionType = NashProtocolRequest;
221
222    async fn acquire_permit(
223        &self,
224        state: Arc<RwLock<State>>,
225    ) -> Option<tokio::sync::OwnedSemaphorePermit> {
226        match self {
227            Self::SignAllState(sign_all) => {
228                NashProtocolPipeline::acquire_permit(sign_all, state).await
229            }
230            Self::Protocol(protocol) => NashProtocol::acquire_permit(protocol, state).await,
231        }
232    }
233
234    async fn init_state(&self, state: Arc<RwLock<State>>) -> Self::PipelineState {
235        match self {
236            Self::SignAllState(sign_all) => {
237                ProtocolHookState::SignAllStates(sign_all.init_state(state).await)
238            }
239            Self::Protocol(protocol) => {
240                ProtocolHookState::Protocol(protocol.init_state(state).await)
241            }
242        }
243    }
244
245    async fn next_step(
246        &self,
247        pipeline_state: &Self::PipelineState,
248        client_state: Arc<RwLock<State>>,
249    ) -> Result<Option<Self::ActionType>> {
250        match (self, pipeline_state) {
251            (Self::SignAllState(sign_all), ProtocolHookState::SignAllStates(sign_all_state)) => {
252                Ok(sign_all
253                    .next_step(sign_all_state, client_state)
254                    .await?
255                    .map(|x| NashProtocolRequest::SignState(x)))
256            }
257            (Self::Protocol(protocol), ProtocolHookState::Protocol(protocol_state)) => {
258                protocol.next_step(protocol_state, client_state).await
259            }
260            _ => Err(ProtocolError("Protocol does not align with action")),
261        }
262    }
263
264    async fn process_step(
265        &self,
266        result: <NashProtocolRequest as NashProtocol>::Response,
267        pipeline_state: &mut Self::PipelineState,
268    ) {
269        match (self, pipeline_state) {
270            (ProtocolHook::SignAllState(request), ProtocolHookState::SignAllStates(state)) => {
271                if let NashProtocolResponse::SignState(response) = result {
272                    request.process_step(response, state).await
273                }
274            }
275            (ProtocolHook::Protocol(request), ProtocolHookState::Protocol(state)) => {
276                request.process_step(result, state).await
277            }
278            _ => {}
279        }
280    }
281
282    fn output(
283        &self,
284        pipeline_state: Self::PipelineState,
285    ) -> Result<ResponseOrError<NashProtocolResponse>> {
286        match pipeline_state {
287            ProtocolHookState::SignAllStates(SignAllPipelineState {
288                previous_response: Some(response),
289                ..
290            }) => Ok(ResponseOrError::from_data(NashProtocolResponse::SignState(
291                response,
292            ))),
293            ProtocolHookState::Protocol(Some(protocol_state)) => Ok(protocol_state),
294            _ => Err(ProtocolError("Pipeline did not return state")),
295        }
296    }
297
298    async fn run_before(&self, state: Arc<RwLock<State>>) -> Result<Option<Vec<ProtocolHook>>> {
299        match self {
300            Self::Protocol(protocol) => NashProtocol::run_before(protocol, state).await,
301            Self::SignAllState(sign_all) => NashProtocolPipeline::run_before(sign_all, state).await,
302        }
303    }
304
305    async fn run_after(&self, state: Arc<RwLock<State>>) -> Result<Option<Vec<ProtocolHook>>> {
306        match self {
307            Self::Protocol(protocol) => NashProtocol::run_after(protocol, state).await,
308            Self::SignAllState(sign_all) => NashProtocolPipeline::run_after(sign_all, state).await,
309        }
310    }
311}