pyth_lazer_client/
resilient_ws_connection.rs

1use std::time::Duration;
2
3use backoff::{backoff::Backoff, ExponentialBackoff};
4use futures_util::StreamExt;
5use pyth_lazer_protocol::api::{SubscribeRequest, SubscriptionId, UnsubscribeRequest, WsRequest};
6use tokio::{pin, select, sync::mpsc, time::Instant};
7use tracing::{error, info, warn};
8use url::Url;
9
10use crate::{
11    ws_connection::{AnyResponse, PythLazerWSConnection},
12    CHANNEL_CAPACITY,
13};
14use anyhow::{bail, Context, Result};
15
/// Minimum time a connection attempt must have lasted before a later failure resets the
/// reconnect backoff. The attempt must additionally have outlived the configured timeout
/// plus one second.
const BACKOFF_RESET_DURATION: Duration = Duration::from_secs(10);
17
18pub struct PythLazerResilientWSConnection {
19    request_sender: mpsc::Sender<WsRequest>,
20}
21
22impl PythLazerResilientWSConnection {
23    /// Creates a new resilient WebSocket client instance
24    ///
25    /// # Arguments
26    /// * `endpoint` - The WebSocket URL of the Lazer service
27    /// * `access_token` - Access token for authentication
28    /// * `sender` - A sender to send responses back to the client
29    ///
30    /// # Returns
31    /// Returns a new client instance (not yet connected)
32    pub fn new(
33        endpoint: Url,
34        access_token: String,
35        backoff: ExponentialBackoff,
36        timeout: Duration,
37        sender: mpsc::Sender<AnyResponse>,
38    ) -> Self {
39        let (request_sender, mut request_receiver) = mpsc::channel(CHANNEL_CAPACITY);
40        let mut task =
41            PythLazerResilientWSConnectionTask::new(endpoint, access_token, backoff, timeout);
42
43        tokio::spawn(async move {
44            if let Err(e) = task.run(sender, &mut request_receiver).await {
45                error!("Resilient WebSocket connection task failed: {}", e);
46            }
47        });
48
49        Self { request_sender }
50    }
51
52    pub async fn subscribe(&mut self, request: SubscribeRequest) -> Result<()> {
53        self.request_sender
54            .send(WsRequest::Subscribe(request))
55            .await
56            .context("Failed to send subscribe request")?;
57        Ok(())
58    }
59
60    pub async fn unsubscribe(&mut self, subscription_id: SubscriptionId) -> Result<()> {
61        self.request_sender
62            .send(WsRequest::Unsubscribe(UnsubscribeRequest {
63                subscription_id,
64            }))
65            .await
66            .context("Failed to send unsubscribe request")?;
67        Ok(())
68    }
69}
70
71struct PythLazerResilientWSConnectionTask {
72    endpoint: Url,
73    access_token: String,
74    subscriptions: Vec<SubscribeRequest>,
75    backoff: ExponentialBackoff,
76    timeout: Duration,
77}
78
79impl PythLazerResilientWSConnectionTask {
80    pub fn new(
81        endpoint: Url,
82        access_token: String,
83        backoff: ExponentialBackoff,
84        timeout: Duration,
85    ) -> Self {
86        Self {
87            endpoint,
88            access_token,
89            subscriptions: Vec::new(),
90            backoff,
91            timeout,
92        }
93    }
94
95    pub async fn run(
96        &mut self,
97        response_sender: mpsc::Sender<AnyResponse>,
98        request_receiver: &mut mpsc::Receiver<WsRequest>,
99    ) -> Result<()> {
100        loop {
101            let start_time = Instant::now();
102            if let Err(e) = self.start(response_sender.clone(), request_receiver).await {
103                // If a connection was working for BACKOFF_RESET_DURATION
104                // and timeout + 1sec, it was considered successful therefore reset the backoff
105                if start_time.elapsed() > BACKOFF_RESET_DURATION
106                    && start_time.elapsed() > self.timeout + Duration::from_secs(1)
107                {
108                    self.backoff.reset();
109                }
110
111                let delay = self.backoff.next_backoff();
112                match delay {
113                    Some(d) => {
114                        info!("WebSocket connection failed: {}. Retrying in {:?}", e, d);
115                        tokio::time::sleep(d).await;
116                    }
117                    None => {
118                        bail!(
119                            "Max retries reached for WebSocket connection to {}, this should never happen, please contact developers",
120                            self.endpoint
121                        );
122                    }
123                }
124            }
125        }
126    }
127
128    pub async fn start(
129        &mut self,
130        sender: mpsc::Sender<AnyResponse>,
131        request_receiver: &mut mpsc::Receiver<WsRequest>,
132    ) -> Result<()> {
133        let mut ws_connection =
134            PythLazerWSConnection::new(self.endpoint.clone(), self.access_token.clone())?;
135        let stream = ws_connection.start().await?;
136        pin!(stream);
137
138        for subscription in self.subscriptions.clone() {
139            ws_connection
140                .send_request(WsRequest::Subscribe(subscription))
141                .await?;
142        }
143        loop {
144            let timeout_response = tokio::time::timeout(self.timeout, stream.next());
145
146            select! {
147                response = timeout_response => {
148                    match response {
149                        Ok(Some(response)) => match response {
150                            Ok(response) => {
151                                sender
152                                    .send(response)
153                                    .await
154                                    .context("Failed to send response")?;
155                            }
156                            Err(e) => {
157                                bail!("WebSocket stream error: {}", e);
158                            }
159                        },
160                        Ok(None) => {
161                            bail!("WebSocket stream ended unexpectedly");
162                        }
163                        Err(_elapsed) => {
164                            bail!("WebSocket stream timed out");
165                        }
166                    }
167                }
168                Some(request) = request_receiver.recv() => {
169                   match request {
170                        WsRequest::Subscribe(request) => {
171                            self.subscribe(&mut ws_connection, request).await?;
172                        }
173                        WsRequest::Unsubscribe(request) => {
174                            self.unsubscribe(&mut ws_connection, request).await?;
175                        }
176                   }
177                }
178            }
179        }
180    }
181
182    pub async fn subscribe(
183        &mut self,
184        ws_connection: &mut PythLazerWSConnection,
185        request: SubscribeRequest,
186    ) -> Result<()> {
187        self.subscriptions.push(request.clone());
188        ws_connection.subscribe(request).await
189    }
190
191    pub async fn unsubscribe(
192        &mut self,
193        ws_connection: &mut PythLazerWSConnection,
194        request: UnsubscribeRequest,
195    ) -> Result<()> {
196        if let Some(index) = self
197            .subscriptions
198            .iter()
199            .position(|r| r.subscription_id == request.subscription_id)
200        {
201            self.subscriptions.remove(index);
202        } else {
203            warn!(
204                "Unsubscribe called for non-existent subscription: {:?}",
205                request.subscription_id
206            );
207        }
208        ws_connection.unsubscribe(request).await
209    }
210}