1use std::any::type_name;
4use std::collections::HashMap;
5use std::sync::atomic::{AtomicU64, Ordering};
6use std::sync::Arc;
7
8use async_recursion::async_recursion;
9use futures::{FutureExt, SinkExt, StreamExt};
10use futures_util::future::{select, Either};
11use rand::Rng;
12use tokio::{net::TcpStream, sync::mpsc, sync::oneshot, sync::RwLock, time::Duration};
13use tokio_native_tls::TlsStream;
14use tokio_tungstenite::tungstenite::protocol::Message;
15use tokio_tungstenite::{connect_async, stream::Stream, WebSocketStream};
16use tracing::{error, trace, trace_span, warn, Instrument};
17
18use nash_protocol::errors::{ProtocolError, Result};
19use nash_protocol::protocol::subscriptions::SubscriptionResponse;
20use nash_protocol::protocol::{ErrorResponse, NashProtocol, NashProtocolPipeline, NashProtocolSubscription, ResponseOrError, State, MAX_R_VAL_POOL_SIZE};
21use nash_protocol::types::Blockchain;
22
23use crate::http_extension::HttpClientState;
24use crate::Environment;
25
26use super::absinthe::{AbsintheEvent, AbsintheTopic, AbsintheWSRequest, AbsintheWSResponse};
27
28type WebSocket = WebSocketStream<Stream<TcpStream, TlsStream<TcpStream>>>;
29
30const HEARTBEAT_MESSAGE_ID: u64 = 0;
31pub fn spawn_heartbeat_loop(
33 period: Duration,
34 client_id: u64,
35 outgoing_sender: mpsc::UnboundedSender<(AbsintheWSRequest, Option<oneshot::Receiver<bool>>)>,
36) {
37 tokio::spawn(async move {
38 loop {
39 let heartbeat = AbsintheWSRequest::new(
40 client_id,
41 HEARTBEAT_MESSAGE_ID, AbsintheTopic::Phoenix,
43 AbsintheEvent::Heartbeat,
44 None,
45 );
46 if let Err(_ignore) = outgoing_sender.send((heartbeat, None)) {
47 break;
49 }
50 tokio::time::sleep(period).await;
51 }
52 });
53}
54
55pub fn spawn_sender_loop(
58 timeout_duration: Duration,
59 mut websocket: WebSocket,
60 mut ws_outgoing_receiver: mpsc::UnboundedReceiver<(
61 AbsintheWSRequest,
62 Option<oneshot::Receiver<bool>>,
63 )>,
64 mut ws_disconnect_receiver: mpsc::UnboundedReceiver<()>,
65 message_broker_link: mpsc::UnboundedSender<BrokerAction>,
66) {
67 tokio::spawn(async move {
68 while ws_disconnect_receiver.recv().now_or_never().is_none() {
71 let next_outgoing = ws_outgoing_receiver.recv();
73 let next_incoming = tokio::time::timeout(timeout_duration, websocket.next());
74 tokio::pin!(next_outgoing);
75 tokio::pin!(next_incoming);
76 match select(next_outgoing, next_incoming).await {
77 Either::Left((outgoing, _)) => {
78 if let Some((request, _ready_rx)) = outgoing {
79 if let Ok(request_raw) = serde_json::to_string(&request) {
80 match websocket.send(Message::Text(request_raw)).await {
82 Ok(_) => {
83 trace!(id = ?request.message_id(), "SEND");
84 }
85 Err(e) => {
86 error!(error = %e, "SEND channel error");
87 let error = ProtocolError("failed to send message on WS connection, likely disconnected");
88 let _ =
89 message_broker_link.send(BrokerAction::Message(Err(error)));
90 break;
91 }
92 }
93 } else {
94 error!(request = ?request, "SEND invalid request");
95 }
96 } else {
97 error!("SEND channel error");
98 let error =
99 ProtocolError("outgoing channel died or errored, likely disconnected");
100 let _ = message_broker_link.send(BrokerAction::Message(Err(error)));
101 break;
102 }
103 }
104 Either::Right((incoming, _)) => {
105 if let Ok(incoming) = incoming {
106 if let Some(Ok(message)) = incoming {
107 let raw_response = message.into_text().map_err(|e| {
108 ProtocolError::coerce_static_from_str(e.to_string().as_str())
109 });
110 let response: Result<AbsintheWSResponse> = raw_response.and_then(|r| {
111 serde_json::from_str(&r).map_err(|e| {
112 ProtocolError::coerce_static_from_str(e.to_string().as_str())
113 })
114 });
115 match response {
116 Ok(response) => {
117 trace!(id = ?response.message_id(), "RECV success");
118 let _ = message_broker_link
119 .send(BrokerAction::Message(Ok(response)));
120 }
121 Err(e) => {
122 error!(error = %e, "RECV invalid response message");
123 let _ = message_broker_link.send(BrokerAction::Message(Err(e)));
124 break;
125 }
126 }
127 } else {
128 error!("RECV channel error");
129 let error = ProtocolError(
130 "incoming channel died or errored, likely disconnected",
131 );
132 let _ = message_broker_link.send(BrokerAction::Message(Err(error)));
133 break;
134 }
135 } else {
136 error!("RECV timed out");
137 let error = ProtocolError("incoming WS timed out, likely disconnected");
138 let _ = message_broker_link.send(BrokerAction::Message(Err(error)));
139 break;
140 }
141 }
142 };
143 }
144 error!("DISCONNECT");
145 let error = ProtocolError("Disconnected.");
146 message_broker_link
147 .send(BrokerAction::Message(Err(error)))
148 .ok();
149 });
150}
151
152fn global_subscription_loop<T: NashProtocolSubscription + Send + Sync + 'static>(
153 mut callback_channel: mpsc::UnboundedReceiver<Result<AbsintheWSResponse>>,
154 user_callback_sender: mpsc::UnboundedSender<
155 Result<ResponseOrError<<T as NashProtocolSubscription>::SubscriptionResponse>>,
156 >,
157 global_subscription_sender: mpsc::UnboundedSender<
158 Result<ResponseOrError<SubscriptionResponse>>,
159 >,
160 request: T,
161 state: Arc<RwLock<State>>,
162) {
163 tokio::spawn(async move {
164 loop {
165 let response = callback_channel.recv().await;
166 match response {
168 Some(Ok(response)) => {
169 if let Ok(json_payload) = response.subscription_json_payload() {
171 let output = match request
173 .subscription_response_from_json(json_payload.clone(), state.clone())
174 .await
175 {
176 Ok(response) => {
177 match response {
178 ResponseOrError::Error(err_resp) => {
179 Ok(ResponseOrError::Error(err_resp))
180 }
181 response => {
182 let sub_response = response.response().unwrap();
184 match request
185 .process_subscription_response(
186 sub_response,
187 state.clone(),
188 )
189 .await
190 {
191 Ok(_) => Ok(response),
192 Err(e) => Err(e),
193 }
194 }
195 }
196 }
197 Err(e) => Err(e),
198 };
199 if let Err(_e) = user_callback_sender.send(output) {
201 }
204
205 if let Err(_e) = global_subscription_sender.send(
207 request
208 .wrap_response_as_any_subscription(json_payload, state.clone())
209 .await,
210 ) {
211 break;
212 }
213 } else {
214 break;
216 }
217 }
218 Some(Err(e)) => {
219 let _ = global_subscription_sender.send(Err(e));
221 break;
224 }
225 None => {
226 let _ = global_subscription_sender
227 .send(Err(ProtocolError("channel returned None. dead?")));
228 break;
229 }
230 }
231 }
232 });
233}
234
235pub enum BrokerAction {
237 RegisterRequest(
238 u64,
239 oneshot::Sender<Result<AbsintheWSResponse>>,
240 oneshot::Sender<bool>,
241 ),
242 RegisterSubscription(String, mpsc::UnboundedSender<Result<AbsintheWSResponse>>),
243 Message(Result<AbsintheWSResponse>),
244}
245
246struct MessageBroker {
247 link: mpsc::UnboundedSender<BrokerAction>,
248}
249
250impl MessageBroker {
251 pub fn new() -> Self {
252 let (link, mut internal_receiver) = mpsc::unbounded_channel();
253 tokio::spawn(async move {
254 let mut request_map = HashMap::new();
255 let mut subscription_map = HashMap::new();
256 loop {
257 if let Some(next_incoming) = internal_receiver.recv().await {
258 match next_incoming {
259 BrokerAction::RegisterRequest(id, channel, _ready_tx) => {
261 trace!(%id, "BROKER request");
262 request_map.insert(id, channel);
263 }
265 BrokerAction::RegisterSubscription(id, channel) => {
266 trace!(%id, "BROKER subscription");
267 subscription_map.insert(id, channel);
268 }
269 BrokerAction::Message(Ok(response)) => {
271 if let Some(id) = response.subscription_id() {
273 if let Some(channel) = subscription_map.get_mut(&id) {
274 if let Err(_ignore) = channel.send(Ok(response)) {
277 break;
279 }
280 }
281 }
282 else if let Some(id) = response.message_id() {
284 if let Some(channel) = request_map.remove(&id) {
285 if channel.send(Ok(response)).is_ok() {
286 trace!(id, "BROKER response");
287 } else {
288 break;
290 }
291 } else {
292 if id != HEARTBEAT_MESSAGE_ID {
293 warn!(id, ?response, "BROKER response without return channel");
294 }
295 }
296 } else {
297 warn!(?response, "BROKER response without id");
298 }
299 }
300 BrokerAction::Message(Err(e)) => {
301 error!(error = %e, "BROKER channel error");
302 for (_id, channel) in subscription_map.drain() {
304 let _ = channel.send(Err(e.clone()));
305 }
306 for (_id, channel) in request_map.drain() {
307 let _ = channel.send(Err(e.clone()));
308 }
309 break;
311 }
312 }
313 } else {
314 break;
315 }
316 }
317 });
318 Self { link }
319 }
320}
321
322pub struct WsClientState {
323 ws_outgoing_sender: mpsc::UnboundedSender<(AbsintheWSRequest, Option<oneshot::Receiver<bool>>)>,
324 ws_disconnect_sender: mpsc::UnboundedSender<()>,
325 global_subscription_sender:
326 mpsc::UnboundedSender<Result<ResponseOrError<SubscriptionResponse>>>,
327 next_message_id: Arc<AtomicU64>,
328 message_broker: MessageBroker,
329
330 client_id: u64,
331 timeout: Duration,
332}
333
334impl WsClientState {
335 fn incr_message_id(&self) -> u64 {
336 self.next_message_id.fetch_add(1, Ordering::SeqCst)
337 }
338}
339
340pub struct InnerClient {
341 pub(crate) ws_state: WsClientState,
342 pub(crate) http_state: HttpClientState,
343 pub state: Arc<RwLock<State>>,
344}
345
346impl InnerClient {
347 pub async fn setup(
348 mut state: State,
349 client_id: u64,
350 env: Environment,
351 timeout: Duration,
352 affiliate_code: Option<String>,
353 turn_off_sign_states: bool,
354 ) -> Result<(
355 Self,
356 mpsc::UnboundedReceiver<Result<ResponseOrError<SubscriptionResponse>>>,
357 )> {
358 state.affiliate_code = affiliate_code;
359 state.dont_sign_states = turn_off_sign_states;
360 let (ws_state, global_subscription_receiver) =
361 Self::setup_ws(&mut state, client_id, env, timeout).await?;
362 let http_state = Self::setup_http(&mut state, env, timeout).await?;
363 let client = InnerClient {
364 ws_state,
365 http_state,
366 state: Arc::new(RwLock::new(state)),
367 };
368 Ok((client, global_subscription_receiver))
369 }
370 pub(crate) async fn setup_ws(
372 state: &mut State,
373 client_id: u64,
374 env: Environment,
375 timeout: Duration,
376 ) -> Result<(
377 WsClientState,
378 mpsc::UnboundedReceiver<Result<ResponseOrError<SubscriptionResponse>>>,
379 )> {
380 let version = "2.0.0";
381 let domain = env.url();
382 let conn_path = match &state.signer {
384 Some(signer) => format!(
385 "wss://{}/api/socket/websocket?token={}&vsn={}",
386 domain, signer.api_keys.session_id, version
387 ),
388 None => format!("wss://{}/api/socket/websocket?vsn={}", domain, version),
389 };
390
391 let (socket, _response) = connect_async(&conn_path).await.map_err(|error| {
393 ProtocolError::coerce_static_from_str(&format!("Could not connect to WS: {}", error))
394 })?;
395
396 let (ws_outgoing_sender, ws_outgoing_receiver) = mpsc::unbounded_channel();
398 let (ws_disconnect_sender, ws_disconnect_receiver) = mpsc::unbounded_channel();
399 let (global_subscription_sender, global_subscription_receiver) = mpsc::unbounded_channel();
400
401 let message_broker = MessageBroker::new();
402
403 spawn_sender_loop(
405 timeout,
406 socket,
407 ws_outgoing_receiver,
408 ws_disconnect_receiver,
409 message_broker.link.clone(),
410 );
411
412 let message_id = 1;
414 ws_outgoing_sender
415 .send((AbsintheWSRequest::init_msg(client_id, message_id), None))
416 .map_err(|_| ProtocolError("Could not initialize connection with Nash"))?;
417
418 spawn_heartbeat_loop(timeout, client_id, ws_outgoing_sender.clone());
420
421 let client_state = WsClientState {
422 ws_outgoing_sender,
423 ws_disconnect_sender,
424 global_subscription_sender,
425 message_broker,
426 next_message_id: Arc::new(AtomicU64::new(message_id + 1)),
427 client_id,
428 timeout,
429 };
430 Ok((client_state, global_subscription_receiver))
431 }
432
433 async fn request(
435 &self,
436 request: serde_json::Value,
437 ) -> Result<oneshot::Receiver<Result<AbsintheWSResponse>>> {
438 let message_id = self.ws_state.incr_message_id();
439 let graphql_msg = AbsintheWSRequest::new(
440 self.ws_state.client_id,
441 message_id,
442 AbsintheTopic::Control,
443 AbsintheEvent::Doc,
444 Some(request),
445 );
446 let (for_broker, callback_channel) = oneshot::channel();
448 let (ready_tx, ready_rx) = oneshot::channel();
449 let broker_link = self.ws_state.message_broker.link.clone();
450 trace!(id = %message_id, "attached id");
452 broker_link
453 .send(BrokerAction::RegisterRequest(
454 message_id, for_broker, ready_tx,
455 ))
456 .map_err(|_| ProtocolError("Could not register request with broker"))?;
457 self.ws_state
459 .ws_outgoing_sender
460 .send((graphql_msg, Some(ready_rx)))
461 .map_err(|_| ProtocolError("Request failed to send over channel"))?;
462 Ok(callback_channel)
464 }
465
466 async fn execute_protocol<T: NashProtocol>(
470 &self,
471 request: T,
472 ) -> Result<ResponseOrError<T::Response>> {
473 let query = request.graphql(self.state.clone()).await?;
474 let ws_response = tokio::time::timeout(self.ws_state.timeout, self.request(query).await?)
475 .await
476 .map_err(|_| ProtocolError("Request timeout"))?
477 .map_err(|_| ProtocolError("Failed to receive response from return channel"))??;
478 let json_payload = ws_response.json_payload()?;
479 let protocol_response = request
480 .response_from_json(json_payload, self.state.clone())
481 .await?;
482 match protocol_response {
483 ResponseOrError::Response(ref response) => {
484 request
485 .process_response(&response.data, self.state.clone())
486 .await?;
487 }
488 ResponseOrError::Error(ref error_response) => {
489 request
490 .process_error(error_response, self.state.clone())
491 .await?;
492 }
493 }
494 Ok(protocol_response)
495 }
496
497 #[async_recursion]
498 pub async fn run<T: NashProtocolPipeline + Clone>(
499 &self,
500 request: T,
501 ) -> Result<ResponseOrError<<T::ActionType as NashProtocol>::Response>> {
502 async {
503 let response = {
504 if let Some(_permit) = request.acquire_permit(self.state.clone()).await {
505 self.run_helper(request).await
506 } else {
507 self.run_helper(request).await
508 }
509 };
510 if let Err(ref e) = response {
511 error!(error = %e, "request error");
512 }
513 response
514 }
515 .instrument(trace_span!(
516 "RUN (ws)",
517 request = type_name::<T>(),
518 id = %rand::thread_rng().gen::<u32>()))
519 .await
520 }
521
522 async fn run_helper<T: NashProtocolPipeline + Clone>(
524 &self,
525 request: T,
526 ) -> Result<ResponseOrError<<T::ActionType as NashProtocol>::Response>> {
527 let before_actions = request.run_before(self.state.clone()).await?;
529 if let Some(actions) = before_actions {
530 for action in actions {
531 self.run(action).await?;
532 }
533 }
534 let mut protocol_state = request.init_state(self.state.clone()).await;
536 loop {
538 if let Some(protocol_request) = request
539 .next_step(&protocol_state, self.state.clone())
540 .await?
541 {
542 let protocol_response = self.execute_protocol(protocol_request).await?;
543 if protocol_response.is_error() {
545 Self::manage_client_error(
546 self.state.clone(),
547 protocol_response.error().unwrap(),
548 )
549 .await;
550
551 return Ok(ResponseOrError::Error(
552 protocol_response
553 .consume_error()
554 .expect("Destructure error after check. Impossible to fail."),
555 ));
556 }
557 request
559 .process_step(
560 protocol_response
561 .consume_response()
562 .expect("Destructure response after check. Impossible to fail."),
563 &mut protocol_state,
564 )
565 .await;
566 } else {
567 break;
569 }
570 }
571 let after_actions = request.run_after(self.state.clone()).await?;
573 if let Some(actions) = after_actions {
575 for action in actions {
576 self.run(action).await?;
577 }
578 }
579 request.output(protocol_state)
581 }
582
583 pub async fn subscribe_protocol<T: NashProtocolSubscription + Send + Sync + 'static>(
585 &self,
586 request: T,
587 ) -> Result<
588 mpsc::UnboundedReceiver<
589 Result<ResponseOrError<<T as NashProtocolSubscription>::SubscriptionResponse>>,
590 >,
591 > {
592 let query = request.graphql(self.state.clone()).await?;
593 let subscription_response = self
595 .request(query)
596 .await?
597 .await
598 .map_err(|_| ProtocolError("Could not get subscription response"))??;
599 let (for_broker, callback_channel) = mpsc::unbounded_channel();
601 let broker_link = self.ws_state.message_broker.link.clone();
602 let subscription_id = subscription_response
605 .subscription_setup_id()
606 .ok_or(ProtocolError("Response does not include subscription id"))?;
607 broker_link
608 .send(BrokerAction::RegisterSubscription(
609 subscription_id,
610 for_broker,
611 ))
612 .map_err(|_| ProtocolError("Could not register subscription with broker"))?;
613
614 let (user_callback_sender, user_callback_receiver) = mpsc::unbounded_channel();
615
616 global_subscription_loop(
617 callback_channel,
618 user_callback_sender,
619 self.ws_state.global_subscription_sender.clone(),
620 request.clone(),
621 self.state.clone(),
622 );
623
624 Ok(user_callback_receiver)
625 }
626
627 pub async fn disconnect(&self) {
628 self.ws_state.ws_disconnect_sender.send(()).ok();
629 }
630
631 pub async fn manage_client_error(_state: Arc<RwLock<State>>, response: &ErrorResponse) {
632 error!(?response, "client error response");
633 }
634}
635
636pub struct Client {
637 pub inner: Arc<InnerClient>,
638 pub(crate) global_subscription_receiver:
639 mpsc::UnboundedReceiver<Result<ResponseOrError<SubscriptionResponse>>>,
640}
641
642impl Client {
643 pub async fn from_keys_path(
646 keys_path: Option<&str>,
647 affiliate_code: Option<String>,
648 turn_off_sign_states: bool,
649 client_id: u64,
650 env: Environment,
651 timeout: Duration,
652 ) -> Result<Self> {
653 let state = State::from_keys_path(keys_path)?;
654 Self::setup(
655 state,
656 affiliate_code,
657 turn_off_sign_states,
658 client_id,
659 env,
660 timeout,
661 )
662 .await
663 }
664
665 pub async fn from_keys(
667 secret: &str,
668 session: &str,
669 affiliate_code: Option<String>,
670 turn_off_sign_states: bool,
671 client_id: u64,
672 env: Environment,
673 timeout: Duration,
674 ) -> Result<Self> {
675 let state = State::from_keys(secret, session)?;
676 Self::setup(
677 state,
678 affiliate_code,
679 turn_off_sign_states,
680 client_id,
681 env,
682 timeout,
683 )
684 .await
685 }
686
687 async fn setup(
688 state: State,
689 affiliate_code: Option<String>,
690 turn_off_sign_states: bool,
691 client_id: u64,
692 env: Environment,
693 timeout: Duration,
694 ) -> Result<Self> {
695 let (inner, global_subscription_receiver) = InnerClient::setup(
696 state,
697 client_id,
698 env,
699 timeout,
700 affiliate_code,
701 turn_off_sign_states,
702 )
703 .await?;
704 let client = Self {
705 inner: Arc::new(inner),
706 global_subscription_receiver,
707 };
708 client.run(nash_protocol::protocol::list_markets::ListMarketsRequest).await?;
710 client.run(nash_protocol::protocol::dh_fill_pool::DhFillPoolRequest::new(Blockchain::NEO, MAX_R_VAL_POOL_SIZE)?).await?;
711 client.run(nash_protocol::protocol::dh_fill_pool::DhFillPoolRequest::new(Blockchain::Bitcoin, MAX_R_VAL_POOL_SIZE)?).await?;
713 Ok(client)
714 }
715
716 #[inline]
721 pub async fn run<T: NashProtocolPipeline + Clone>(
722 &self,
723 request: T,
724 ) -> Result<ResponseOrError<<T::ActionType as NashProtocol>::Response>> {
725 self.inner.run(request).await
726 }
727
728 #[inline]
730 pub async fn subscribe_protocol<T: NashProtocolSubscription + Send + Sync + 'static>(
731 &self,
732 request: T,
733 ) -> Result<
734 mpsc::UnboundedReceiver<
735 Result<ResponseOrError<<T as NashProtocolSubscription>::SubscriptionResponse>>,
736 >,
737 > {
738 self.inner.subscribe_protocol(request).await
739 }
740
741 #[inline]
743 pub async fn disconnect(&self) {
744 self.inner.disconnect().await;
745 }
746
747 pub async fn turn_off_sign_states(&self) {
748 let mut state = self.inner.state.write().await;
749 state.dont_sign_states = true;
750 }
751
752 pub fn start_background_sign_states_loop(&self, interval: Duration) {
753 let weak_inner = Arc::downgrade(&self.inner);
754 tokio::spawn(async move {
755 while let Some(inner) = weak_inner.upgrade() {
756 let tick_start = tokio::time::Instant::now();
757 let remaining_orders = inner.state.read().await.get_remaining_orders();
758 if remaining_orders < 10 {
759 trace!(%remaining_orders, "sign_all_states triggered");
760 let request = inner
761 .run(nash_protocol::protocol::sign_all_states::SignAllStates::new())
762 .await;
763 if let Err(e) = request {
764 error!(error = %e, "sign_all_states errored");
765 }
766 }
767 tokio::time::sleep_until(tick_start + interval).await;
768 }
769 });
770 }
771
772 pub fn start_background_fill_pool_loop(
773 &self,
774 interval: Duration,
775 chains: Option<Vec<Blockchain>>,
776 ) {
777 let weak_inner = Arc::downgrade(&self.inner);
778 tokio::spawn(async move {
779 while let Some(inner) = weak_inner.upgrade() {
780 let tick_start = tokio::time::Instant::now();
781 let fill_pool_schedules = inner
782 .state
783 .read()
784 .await
785 .acquire_fill_pool_schedules(chains.as_ref(), None)
786 .await;
787 match fill_pool_schedules {
788 Ok(fill_pool_schedules) => {
789 for (request, permit) in fill_pool_schedules {
790 let response = inner.run_http_with_permit(request, permit).await;
791 if let Err(e) = response {
792 error!(error = %e, "request errored");
793 }
794 }
795 }
796 Err(e) => error!(%e, "getting fill pool schedules errored"),
797 }
798 tokio::time::sleep_until(tick_start + interval).await;
799 }
800 });
801 }
802}