pyth_lazer_client/
resilient_ws_connection.rs1use std::time::Duration;
2
3use backoff::{backoff::Backoff, ExponentialBackoff};
4use futures_util::StreamExt;
5use pyth_lazer_protocol::api::{SubscribeRequest, SubscriptionId, UnsubscribeRequest, WsRequest};
6use tokio::{pin, select, sync::mpsc, time::Instant};
7use tracing::{error, info, warn};
8use url::Url;
9
10use crate::{
11 ws_connection::{AnyResponse, PythLazerWSConnection},
12 CHANNEL_CAPACITY,
13};
14use anyhow::{bail, Context, Result};
15
/// Minimum lifetime of a connection attempt after which the reconnect backoff is reset.
const BACKOFF_RESET_DURATION: Duration = Duration::from_millis(10_000);
17
18pub struct PythLazerResilientWSConnection {
19 request_sender: mpsc::Sender<WsRequest>,
20}
21
22impl PythLazerResilientWSConnection {
23 pub fn new(
33 endpoint: Url,
34 access_token: String,
35 backoff: ExponentialBackoff,
36 timeout: Duration,
37 sender: mpsc::Sender<AnyResponse>,
38 ) -> Self {
39 let (request_sender, mut request_receiver) = mpsc::channel(CHANNEL_CAPACITY);
40 let mut task =
41 PythLazerResilientWSConnectionTask::new(endpoint, access_token, backoff, timeout);
42
43 tokio::spawn(async move {
44 if let Err(e) = task.run(sender, &mut request_receiver).await {
45 error!("Resilient WebSocket connection task failed: {}", e);
46 }
47 });
48
49 Self { request_sender }
50 }
51
52 pub async fn subscribe(&mut self, request: SubscribeRequest) -> Result<()> {
53 self.request_sender
54 .send(WsRequest::Subscribe(request))
55 .await
56 .context("Failed to send subscribe request")?;
57 Ok(())
58 }
59
60 pub async fn unsubscribe(&mut self, subscription_id: SubscriptionId) -> Result<()> {
61 self.request_sender
62 .send(WsRequest::Unsubscribe(UnsubscribeRequest {
63 subscription_id,
64 }))
65 .await
66 .context("Failed to send unsubscribe request")?;
67 Ok(())
68 }
69}
70
71struct PythLazerResilientWSConnectionTask {
72 endpoint: Url,
73 access_token: String,
74 subscriptions: Vec<SubscribeRequest>,
75 backoff: ExponentialBackoff,
76 timeout: Duration,
77}
78
79impl PythLazerResilientWSConnectionTask {
80 pub fn new(
81 endpoint: Url,
82 access_token: String,
83 backoff: ExponentialBackoff,
84 timeout: Duration,
85 ) -> Self {
86 Self {
87 endpoint,
88 access_token,
89 subscriptions: Vec::new(),
90 backoff,
91 timeout,
92 }
93 }
94
95 pub async fn run(
96 &mut self,
97 response_sender: mpsc::Sender<AnyResponse>,
98 request_receiver: &mut mpsc::Receiver<WsRequest>,
99 ) -> Result<()> {
100 loop {
101 let start_time = Instant::now();
102 if let Err(e) = self.start(response_sender.clone(), request_receiver).await {
103 if start_time.elapsed() > BACKOFF_RESET_DURATION
106 && start_time.elapsed() > self.timeout + Duration::from_secs(1)
107 {
108 self.backoff.reset();
109 }
110
111 let delay = self.backoff.next_backoff();
112 match delay {
113 Some(d) => {
114 info!("WebSocket connection failed: {}. Retrying in {:?}", e, d);
115 tokio::time::sleep(d).await;
116 }
117 None => {
118 bail!(
119 "Max retries reached for WebSocket connection to {}, this should never happen, please contact developers",
120 self.endpoint
121 );
122 }
123 }
124 }
125 }
126 }
127
128 pub async fn start(
129 &mut self,
130 sender: mpsc::Sender<AnyResponse>,
131 request_receiver: &mut mpsc::Receiver<WsRequest>,
132 ) -> Result<()> {
133 let mut ws_connection =
134 PythLazerWSConnection::new(self.endpoint.clone(), self.access_token.clone())?;
135 let stream = ws_connection.start().await?;
136 pin!(stream);
137
138 for subscription in self.subscriptions.clone() {
139 ws_connection
140 .send_request(WsRequest::Subscribe(subscription))
141 .await?;
142 }
143 loop {
144 let timeout_response = tokio::time::timeout(self.timeout, stream.next());
145
146 select! {
147 response = timeout_response => {
148 match response {
149 Ok(Some(response)) => match response {
150 Ok(response) => {
151 sender
152 .send(response)
153 .await
154 .context("Failed to send response")?;
155 }
156 Err(e) => {
157 bail!("WebSocket stream error: {}", e);
158 }
159 },
160 Ok(None) => {
161 bail!("WebSocket stream ended unexpectedly");
162 }
163 Err(_elapsed) => {
164 bail!("WebSocket stream timed out");
165 }
166 }
167 }
168 Some(request) = request_receiver.recv() => {
169 match request {
170 WsRequest::Subscribe(request) => {
171 self.subscribe(&mut ws_connection, request).await?;
172 }
173 WsRequest::Unsubscribe(request) => {
174 self.unsubscribe(&mut ws_connection, request).await?;
175 }
176 }
177 }
178 }
179 }
180 }
181
182 pub async fn subscribe(
183 &mut self,
184 ws_connection: &mut PythLazerWSConnection,
185 request: SubscribeRequest,
186 ) -> Result<()> {
187 self.subscriptions.push(request.clone());
188 ws_connection.subscribe(request).await
189 }
190
191 pub async fn unsubscribe(
192 &mut self,
193 ws_connection: &mut PythLazerWSConnection,
194 request: UnsubscribeRequest,
195 ) -> Result<()> {
196 if let Some(index) = self
197 .subscriptions
198 .iter()
199 .position(|r| r.subscription_id == request.subscription_id)
200 {
201 self.subscriptions.remove(index);
202 } else {
203 warn!(
204 "Unsubscribe called for non-existent subscription: {:?}",
205 request.subscription_id
206 );
207 }
208 ws_connection.unsubscribe(request).await
209 }
210}