1use crate::key::OkxKey as Key;
2
3use super::types::{
4 request::HttpRequest,
5 response::{FullHttpResponse, HttpResponse},
6};
7use exc_core::{retry::RetryPolicy, ExchangeError};
8use futures::{
9 future::{ready, BoxFuture},
10 FutureExt, TryFutureExt,
11};
12use http::{Request, Response};
13use hyper::Body;
14use tower::{retry::Retry, Layer, Service, ServiceBuilder};
15
16pub struct OkxHttpApiLayer<F> {
18 testing: bool,
19 aws: bool,
20 key: Option<Key>,
21 retry_policy: RetryPolicy<HttpRequest, HttpResponse, F>,
22}
23
24impl<F> OkxHttpApiLayer<F> {
25 pub fn private(&mut self, key: Key) -> &mut Self {
27 self.key = Some(key);
28 self
29 }
30
31 pub fn testing(&mut self, enable: bool) -> &mut Self {
33 self.testing = enable;
34 self
35 }
36
37 pub fn aws(&mut self, enable: bool) -> &mut Self {
39 self.aws = enable;
40 self
41 }
42
43 pub fn retry<F2>(
45 self,
46 policy: RetryPolicy<HttpRequest, HttpResponse, F2>,
47 ) -> OkxHttpApiLayer<F2>
48 where
49 F2: Clone,
50 {
51 OkxHttpApiLayer {
52 aws: self.aws,
53 retry_policy: policy,
54 key: self.key,
55 testing: self.testing,
56 }
57 }
58
59 pub fn host(&self) -> &'static str {
61 match (self.testing, self.aws) {
62 (true, _) => "https://www.okx.com",
63 (false, true) => "https://aws.okx.com",
64 (false, false) => "https://www.okx.com",
65 }
66 }
67
68 pub fn retry_on<F2>(self, f: F2) -> OkxHttpApiLayer<F2>
70 where
71 F2: Fn(&ExchangeError) -> bool,
72 F2: Send + 'static + Clone,
73 {
74 self.retry(RetryPolicy::default().retry_on(f))
75 }
76
77 pub fn retry_on_error(self) -> OkxHttpApiLayer<fn(&ExchangeError) -> bool> {
79 self.retry(RetryPolicy::default().retry_on(|_| true))
80 }
81}
82
83impl Default for OkxHttpApiLayer<fn(&ExchangeError) -> bool> {
84 fn default() -> Self {
85 Self::new()
86 }
87}
88
89impl OkxHttpApiLayer<fn(&ExchangeError) -> bool> {
90 pub fn new() -> Self {
92 Self {
93 aws: false,
94 retry_policy: RetryPolicy::never(),
95 key: None,
96 testing: false,
97 }
98 }
99}
100
101impl<S, F> Layer<S> for OkxHttpApiLayer<F>
102where
103 S: Service<Request<Body>, Response = Response<Body>> + Clone,
104 S::Future: Send + 'static,
105 S::Error: 'static,
106 ExchangeError: From<S::Error>,
107 F: Fn(&ExchangeError) -> bool,
108 F: Send + 'static + Clone,
109{
110 type Service = Retry<RetryPolicy<HttpRequest, HttpResponse, F>, OkxHttpApi<S>>;
111
112 fn layer(&self, inner: S) -> Self::Service {
113 let svc = OkxHttpApi {
114 host: self.host().to_string(),
115 http: inner,
116 key: self.key.clone(),
117 testing: self.testing,
118 };
119 ServiceBuilder::default()
120 .retry(self.retry_policy.clone())
121 .service(svc)
122 }
123}
124
125#[derive(Clone)]
127pub struct OkxHttpApi<S> {
128 host: String,
129 key: Option<Key>,
130 http: S,
131 testing: bool,
132}
133
/// Header name attached (with value `1`) to every request when testing is
/// enabled, to mark it as simulated (demo) trading.
const TESTING_HEADER: &str = "x-simulated-trading";
135
136impl<S> Service<HttpRequest> for OkxHttpApi<S>
137where
138 S: Service<Request<Body>, Response = Response<Body>>,
139 S::Future: Send + 'static,
140 S::Error: 'static,
141 ExchangeError: From<S::Error>,
142{
143 type Response = HttpResponse;
144 type Error = ExchangeError;
145 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
146
147 fn poll_ready(
148 &mut self,
149 cx: &mut std::task::Context<'_>,
150 ) -> std::task::Poll<Result<(), Self::Error>> {
151 self.http.poll_ready(cx).map_err(ExchangeError::from)
152 }
153
154 fn call(&mut self, req: HttpRequest) -> Self::Future {
155 let req = match req {
156 HttpRequest::Get(get) => serde_qs::to_string(&get)
157 .map_err(|err| ExchangeError::Other(err.into()))
158 .and_then(|q| {
159 let uri = format!("{}{}?{}", self.host, get.uri(), q);
160 Request::get(uri)
161 .body(Body::empty())
162 .map_err(|err| ExchangeError::Other(err.into()))
163 }),
164 HttpRequest::PrivateGet(get) => {
165 if let Some(key) = self.key.as_ref() {
166 get.to_request(&self.host, key)
167 } else {
168 Err(ExchangeError::KeyError(anyhow::anyhow!(
169 "key has not been set"
170 )))
171 }
172 }
173 };
174 match req {
175 Ok(mut req) => {
176 if self.testing {
177 req.headers_mut()
178 .insert(TESTING_HEADER, http::HeaderValue::from_static("1"));
179 }
180 self.http
181 .call(req)
182 .map_err(ExchangeError::from)
183 .and_then(|resp| {
184 trace!("http response; status: {:?}", resp.status());
185 hyper::body::to_bytes(resp.into_body())
186 .map_err(|err| ExchangeError::Other(err.into()))
187 })
188 .and_then(|bytes| {
189 tracing::trace!(?bytes, "http response;");
190 let resp = serde_json::from_slice::<FullHttpResponse>(&bytes)
191 .map_err(|err| ExchangeError::Other(err.into()));
192
193 futures::future::ready(resp)
194 })
195 .and_then(|resp| ready(resp.try_into()))
196 .boxed()
197 }
198 Err(err) => ready(Err(err)).boxed(),
199 }
200 }
201}