titan_rust_client/
connection.rs

1//! WebSocket connection management with auto-reconnect and stream resumption.
2
3use std::collections::HashMap;
4use std::sync::atomic::{AtomicU32, Ordering};
5use std::sync::Arc;
6use std::time::Duration;
7
8use futures_util::{SinkExt, StreamExt};
9use titan_api_codec::codec::ws::v1::ClientCodec;
10use titan_api_codec::codec::Codec;
11use titan_api_types::ws::v1::{
12    ClientRequest, RequestData, ResponseError, ResponseSuccess, ServerMessage, StreamData,
13    SwapQuoteRequest,
14};
15use tokio::net::TcpStream;
16use tokio::sync::{mpsc, oneshot, RwLock};
17use tokio_tungstenite::tungstenite::Message;
18use tokio_tungstenite::{MaybeTlsStream, WebSocketStream};
19
20use crate::config::TitanConfig;
21use crate::error::TitanClientError;
22use crate::state::ConnectionState;
23
24type WsStream = WebSocketStream<MaybeTlsStream<TcpStream>>;
25type ResponseResult = Result<ResponseSuccess, ResponseError>;
26type PendingRequestsMap = Arc<RwLock<HashMap<u32, oneshot::Sender<ResponseResult>>>>;
27
28/// Initial backoff delay in milliseconds.
29const INITIAL_BACKOFF_MS: u64 = 100;
30
31/// Information needed to resume a stream after reconnection.
32#[derive(Clone)]
33pub struct ResumableStream {
34    /// The original request used to create the stream.
35    pub request: SwapQuoteRequest,
36    /// Channel to send stream data to.
37    pub sender: mpsc::Sender<StreamData>,
38}
39
40type ResumableStreamsMap = Arc<RwLock<HashMap<u32, ResumableStream>>>;
41
42/// Internal message for sending requests through the connection
43pub struct PendingRequest {
44    pub request: ClientRequest,
45    pub response_tx: oneshot::Sender<ResponseResult>,
46}
47
48/// Manages a WebSocket connection to the Titan API with auto-reconnect.
49pub struct Connection {
50    #[allow(dead_code)]
51    config: TitanConfig,
52    request_id: AtomicU32,
53    sender: mpsc::Sender<PendingRequest>,
54    state_tx: tokio::sync::watch::Sender<ConnectionState>,
55    #[allow(dead_code)]
56    pending_requests: PendingRequestsMap,
57    resumable_streams: ResumableStreamsMap,
58}
59
60impl Connection {
61    /// Create a new connection with the given config.
62    ///
63    /// Connects eagerly and auto-reconnects on disconnection.
64    #[tracing::instrument(skip_all)]
65    pub async fn connect(config: TitanConfig) -> Result<Self, TitanClientError> {
66        let (state_tx, _state_rx) = tokio::sync::watch::channel(ConnectionState::Disconnected {
67            reason: "Connecting...".to_string(),
68        });
69
70        let pending_requests: PendingRequestsMap = Arc::new(RwLock::new(HashMap::new()));
71        let resumable_streams: ResumableStreamsMap = Arc::new(RwLock::new(HashMap::new()));
72
73        // Connect to WebSocket
74        let ws_stream = Self::establish_connection(&config).await?;
75
76        // Create channel for sending requests
77        let (sender, receiver) = mpsc::channel::<PendingRequest>(32);
78
79        // Spawn background task with reconnection support
80        let pending_clone = pending_requests.clone();
81        let streams_clone = resumable_streams.clone();
82        let state_tx_clone = state_tx.clone();
83        let config_clone = config.clone();
84
85        tokio::spawn(Self::run_connection_loop_with_reconnect(
86            ws_stream,
87            receiver,
88            pending_clone,
89            streams_clone,
90            state_tx_clone,
91            config_clone,
92        ));
93
94        state_tx.send_replace(ConnectionState::Connected);
95
96        Ok(Self {
97            config,
98            request_id: AtomicU32::new(1),
99            sender,
100            state_tx,
101            pending_requests,
102            resumable_streams,
103        })
104    }
105
106    /// Establish WebSocket connection with authentication.
107    async fn establish_connection(config: &TitanConfig) -> Result<WsStream, TitanClientError> {
108        use tokio_tungstenite::tungstenite::client::IntoClientRequest;
109
110        // Build URL with auth token as query param
111        // Only add trailing slash if URL has no path (e.g., ws://host:port -> ws://host:port/)
112        let url = if config.url.contains("/ws") || config.url.ends_with('/') {
113            // URL already has a path, just append query param
114            format!("{}?auth={}", config.url, config.token)
115        } else {
116            // URL has no path, add trailing slash first
117            format!("{}/?auth={}", config.url, config.token)
118        };
119
120        let mut request = url.into_client_request().map_err(|e| {
121            TitanClientError::Unexpected(anyhow::anyhow!("Failed to build request: {}", e))
122        })?;
123
124        // Add the Titan subprotocol header
125        request.headers_mut().insert(
126            "Sec-WebSocket-Protocol",
127            titan_api_types::ws::v1::WEBSOCKET_SUBPROTO_BASE
128                .parse()
129                .unwrap(),
130        );
131
132        // connect_async handles TLS automatically when rustls-tls-native-roots is enabled
133        let (ws_stream, _response) = tokio_tungstenite::connect_async(request)
134            .await
135            .map_err(TitanClientError::WebSocket)?;
136
137        Ok(ws_stream)
138    }
139
140    /// Connection loop with automatic reconnection and stream resumption.
141    async fn run_connection_loop_with_reconnect(
142        initial_ws_stream: WsStream,
143        mut request_rx: mpsc::Receiver<PendingRequest>,
144        pending_requests: PendingRequestsMap,
145        resumable_streams: ResumableStreamsMap,
146        state_tx: tokio::sync::watch::Sender<ConnectionState>,
147        config: TitanConfig,
148    ) {
149        let mut ws_stream = initial_ws_stream;
150        let mut reconnect_attempt: u32 = 0;
151        let mut request_id_counter: u32 = 1;
152
153        loop {
154            // Run the connection loop until disconnection
155            let disconnect_reason = Self::run_single_connection(
156                &mut ws_stream,
157                &mut request_rx,
158                &pending_requests,
159                &resumable_streams,
160                &state_tx,
161                &mut request_id_counter,
162            )
163            .await;
164
165            // Check if request channel is closed (client dropped)
166            if request_rx.is_closed() {
167                tracing::info!("Request channel closed, shutting down connection");
168                break;
169            }
170
171            // Start reconnection attempts
172            reconnect_attempt += 1;
173
174            // Check max attempts
175            if let Some(max) = config.max_reconnect_attempts {
176                if reconnect_attempt > max {
177                    tracing::error!("Max reconnect attempts ({}) reached, giving up", max);
178                    let _ = state_tx.send(ConnectionState::Disconnected {
179                        reason: format!(
180                            "Max reconnect attempts reached. Last error: {}",
181                            disconnect_reason
182                        ),
183                    });
184                    break;
185                }
186            }
187
188            // Calculate backoff delay with exponential increase
189            let backoff_ms = calculate_backoff(reconnect_attempt, config.max_reconnect_delay_ms);
190
191            tracing::info!(
192                attempt = reconnect_attempt,
193                backoff_ms,
194                "Reconnecting after disconnection: {}",
195                disconnect_reason
196            );
197
198            let _ = state_tx.send(ConnectionState::Reconnecting {
199                attempt: reconnect_attempt,
200            });
201
202            // Wait before reconnecting
203            tokio::time::sleep(Duration::from_millis(backoff_ms)).await;
204
205            // Attempt to reconnect
206            match Self::establish_connection(&config).await {
207                Ok(new_stream) => {
208                    ws_stream = new_stream;
209                    reconnect_attempt = 0;
210                    let _ = state_tx.send(ConnectionState::Connected);
211                    tracing::info!("Reconnected successfully");
212
213                    // Resume streams after reconnection
214                    Self::resume_streams(
215                        &mut ws_stream,
216                        &resumable_streams,
217                        &mut request_id_counter,
218                    )
219                    .await;
220                }
221                Err(e) => {
222                    tracing::warn!("Reconnection failed: {}", e);
223                    continue;
224                }
225            }
226        }
227
228        // Final cleanup
229        Self::cleanup_pending_requests(&pending_requests).await;
230    }
231
232    /// Resume all active streams after reconnection.
233    async fn resume_streams(
234        ws_stream: &mut WsStream,
235        resumable_streams: &ResumableStreamsMap,
236        request_id_counter: &mut u32,
237    ) {
238        let streams_to_resume: Vec<(u32, ResumableStream)> = {
239            let streams = resumable_streams.read().await;
240            streams.iter().map(|(k, v)| (*k, v.clone())).collect()
241        };
242
243        if streams_to_resume.is_empty() {
244            return;
245        }
246
247        tracing::info!(
248            "Resuming {} streams after reconnection",
249            streams_to_resume.len()
250        );
251
252        let codec = ClientCodec::Uncompressed;
253        let mut encoder = codec.encoder();
254        let mut decoder = codec.decoder();
255
256        for (old_stream_id, resumable) in streams_to_resume {
257            let request_id = *request_id_counter;
258            *request_id_counter += 1;
259
260            let request = ClientRequest {
261                id: request_id,
262                data: RequestData::NewSwapQuoteStream(resumable.request.clone()),
263            };
264
265            // Encode and send the request
266            let encoded = match encoder.encode_mut(&request) {
267                Ok(data) => data.to_vec(),
268                Err(e) => {
269                    tracing::error!("Failed to encode stream resume request: {}", e);
270                    continue;
271                }
272            };
273
274            if let Err(e) = ws_stream.send(Message::Binary(encoded.into())).await {
275                tracing::error!("Failed to send stream resume request: {}", e);
276                continue;
277            }
278
279            // Wait for response to get new stream ID
280            match ws_stream.next().await {
281                Some(Ok(Message::Binary(data))) => {
282                    match decoder.decode_mut(data) {
283                        Ok(ServerMessage::Response(response)) => {
284                            if let Some(stream_info) = response.stream {
285                                let new_stream_id = stream_info.id;
286
287                                // Update the stream mapping
288                                let mut streams = resumable_streams.write().await;
289                                if let Some(stream) = streams.remove(&old_stream_id) {
290                                    streams.insert(new_stream_id, stream);
291                                    tracing::info!(
292                                        old_id = old_stream_id,
293                                        new_id = new_stream_id,
294                                        "Stream resumed with new ID"
295                                    );
296                                }
297                            }
298                        }
299                        Ok(ServerMessage::Error(error)) => {
300                            tracing::error!(
301                                "Failed to resume stream {}: {}",
302                                old_stream_id,
303                                error.message
304                            );
305                            // Remove the failed stream
306                            let mut streams = resumable_streams.write().await;
307                            streams.remove(&old_stream_id);
308                        }
309                        Ok(_) => {
310                            tracing::warn!("Unexpected response type during stream resumption");
311                        }
312                        Err(e) => {
313                            tracing::error!("Failed to decode stream resume response: {}", e);
314                        }
315                    }
316                }
317                Some(Ok(_)) => {
318                    tracing::warn!("Unexpected message type during stream resumption");
319                }
320                Some(Err(e)) => {
321                    tracing::error!("WebSocket error during stream resumption: {}", e);
322                    break;
323                }
324                None => {
325                    tracing::error!("Connection closed during stream resumption");
326                    break;
327                }
328            }
329        }
330    }
331
332    /// Run a single connection until disconnection.
333    async fn run_single_connection(
334        ws_stream: &mut WsStream,
335        request_rx: &mut mpsc::Receiver<PendingRequest>,
336        pending_requests: &PendingRequestsMap,
337        resumable_streams: &ResumableStreamsMap,
338        state_tx: &tokio::sync::watch::Sender<ConnectionState>,
339        request_id_counter: &mut u32,
340    ) -> String {
341        let codec = ClientCodec::Uncompressed;
342        let mut encoder = codec.encoder();
343        let mut decoder = codec.decoder();
344
345        let (mut ws_sink, mut ws_stream_rx) = ws_stream.split();
346
347        loop {
348            tokio::select! {
349                Some(pending_req) = request_rx.recv() => {
350                    let request_id = pending_req.request.id;
351                    *request_id_counter = request_id.max(*request_id_counter) + 1;
352
353                    {
354                        let mut pending_map = pending_requests.write().await;
355                        pending_map.insert(request_id, pending_req.response_tx);
356                    }
357
358                    match encoder.encode_mut(&pending_req.request) {
359                        Ok(data) => {
360                            if let Err(e) = ws_sink.send(Message::Binary(data.to_vec().into())).await {
361                                tracing::error!("Failed to send WebSocket message: {}", e);
362                                let mut pending_map = pending_requests.write().await;
363                                if let Some(tx) = pending_map.remove(&request_id) {
364                                    let _ = tx.send(Err(ResponseError {
365                                        request_id,
366                                        code: 0,
367                                        message: format!("Send failed: {}", e),
368                                    }));
369                                }
370                            }
371                        }
372                        Err(e) => {
373                            tracing::error!("Failed to encode request: {}", e);
374                            let mut pending_map = pending_requests.write().await;
375                            if let Some(tx) = pending_map.remove(&request_id) {
376                                let _ = tx.send(Err(ResponseError {
377                                    request_id,
378                                    code: 0,
379                                    message: format!("Encode failed: {}", e),
380                                }));
381                            }
382                        }
383                    }
384                }
385
386                Some(msg_result) = ws_stream_rx.next() => {
387                    match msg_result {
388                        Ok(Message::Binary(data)) => {
389                            match decoder.decode_mut(data) {
390                                Ok(server_msg) => {
391                                    Self::handle_server_message(
392                                        server_msg,
393                                        pending_requests,
394                                        resumable_streams,
395                                    ).await;
396                                }
397                                Err(e) => {
398                                    tracing::error!("Failed to decode server message: {}", e);
399                                }
400                            }
401                        }
402                        Ok(Message::Close(frame)) => {
403                            let reason = frame
404                                .map(|f| f.reason.to_string())
405                                .unwrap_or_else(|| "Server closed connection".to_string());
406                            tracing::warn!("WebSocket closed: {}", reason);
407                            let _ = state_tx.send(ConnectionState::Disconnected {
408                                reason: reason.clone(),
409                            });
410                            return reason;
411                        }
412                        Ok(Message::Ping(data)) => {
413                            let _ = ws_sink.send(Message::Pong(data)).await;
414                        }
415                        Ok(_) => {}
416                        Err(e) => {
417                            let reason = format!("WebSocket error: {}", e);
418                            let error_str = e.to_string();
419                            if error_str.contains("Connection reset without closing handshake") {
420                                tracing::info!("{}", reason);
421                            } else {
422                                tracing::error!("{}", reason);
423                            }
424                            let _ = state_tx.send(ConnectionState::Disconnected {
425                                reason: reason.clone(),
426                            });
427                            return reason;
428                        }
429                    }
430                }
431
432                else => {
433                    return "Channel closed".to_string();
434                }
435            }
436        }
437    }
438
439    /// Handle a message received from the server.
440    async fn handle_server_message(
441        msg: ServerMessage,
442        pending_requests: &PendingRequestsMap,
443        resumable_streams: &ResumableStreamsMap,
444    ) {
445        match msg {
446            ServerMessage::Response(response) => {
447                let mut pending = pending_requests.write().await;
448                if let Some(tx) = pending.remove(&response.request_id) {
449                    let _ = tx.send(Ok(response));
450                }
451            }
452            ServerMessage::Error(error) => {
453                let mut pending = pending_requests.write().await;
454                if let Some(tx) = pending.remove(&error.request_id) {
455                    let _ = tx.send(Err(error));
456                }
457            }
458            ServerMessage::StreamData(data) => {
459                let streams = resumable_streams.read().await;
460                if let Some(stream) = streams.get(&data.id) {
461                    let _ = stream.sender.send(data).await;
462                }
463            }
464            ServerMessage::StreamEnd(end) => {
465                let mut streams = resumable_streams.write().await;
466                streams.remove(&end.id);
467            }
468            ServerMessage::Other(_) => {
469                tracing::warn!("Received unknown server message type");
470            }
471        }
472    }
473
474    /// Cleanup pending requests on final shutdown.
475    async fn cleanup_pending_requests(pending_requests: &PendingRequestsMap) {
476        let mut pending_map = pending_requests.write().await;
477        for (request_id, tx) in pending_map.drain() {
478            let _ = tx.send(Err(ResponseError {
479                request_id,
480                code: 0,
481                message: "Connection closed".to_string(),
482            }));
483        }
484    }
485
486    /// Send a request and wait for response.
487    #[tracing::instrument(skip_all)]
488    pub async fn send_request(
489        &self,
490        data: RequestData,
491    ) -> Result<ResponseSuccess, TitanClientError> {
492        let request_id = self.request_id.fetch_add(1, Ordering::SeqCst);
493        let request = ClientRequest {
494            id: request_id,
495            data,
496        };
497
498        let (response_tx, response_rx) = oneshot::channel();
499
500        self.sender
501            .send(PendingRequest {
502                request,
503                response_tx,
504            })
505            .await
506            .map_err(|_| TitanClientError::Unexpected(anyhow::anyhow!("Connection closed")))?;
507
508        let response = response_rx.await.map_err(|_| {
509            TitanClientError::Unexpected(anyhow::anyhow!("Response channel closed"))
510        })?;
511
512        response.map_err(|e| TitanClientError::ServerError {
513            code: e.code,
514            message: e.message,
515        })
516    }
517
518    /// Register a resumable stream.
519    pub async fn register_stream(
520        &self,
521        stream_id: u32,
522        request: SwapQuoteRequest,
523        sender: mpsc::Sender<StreamData>,
524    ) {
525        let mut streams = self.resumable_streams.write().await;
526        streams.insert(stream_id, ResumableStream { request, sender });
527    }
528
529    /// Unregister a stream.
530    pub async fn unregister_stream(&self, stream_id: u32) {
531        let mut streams = self.resumable_streams.write().await;
532        streams.remove(&stream_id);
533    }
534
535    /// Get a receiver for connection state changes.
536    pub fn state_receiver(&self) -> tokio::sync::watch::Receiver<ConnectionState> {
537        self.state_tx.subscribe()
538    }
539
540    /// Get the current connection state.
541    pub fn state(&self) -> ConnectionState {
542        self.state_tx.borrow().clone()
543    }
544
545    /// Get all active stream IDs.
546    pub async fn active_stream_ids(&self) -> Vec<u32> {
547        let streams = self.resumable_streams.read().await;
548        streams.keys().copied().collect()
549    }
550
551    /// Stop all active streams gracefully.
552    ///
553    /// Sends StopStream for each active stream and clears the stream map.
554    #[tracing::instrument(skip_all)]
555    pub async fn stop_all_streams(&self) {
556        use titan_api_types::ws::v1::StopStreamRequest;
557
558        let stream_ids = self.active_stream_ids().await;
559
560        if stream_ids.is_empty() {
561            return;
562        }
563
564        tracing::info!("Stopping {} active streams", stream_ids.len());
565
566        for stream_id in stream_ids {
567            // Send stop request (fire and forget)
568            let _ = self
569                .send_request(RequestData::StopStream(StopStreamRequest { id: stream_id }))
570                .await;
571        }
572
573        // Clear all streams
574        let mut streams = self.resumable_streams.write().await;
575        streams.clear();
576    }
577
578    /// Graceful shutdown: stop all streams and signal connection loop to exit.
579    #[tracing::instrument(skip_all)]
580    pub async fn shutdown(&self) {
581        // Stop all streams first
582        self.stop_all_streams().await;
583
584        // Update state
585        let _ = self.state_tx.send(ConnectionState::Disconnected {
586            reason: "Client shutdown".to_string(),
587        });
588
589        // The connection loop will exit when it detects the sender is closed
590        // (which happens when Connection is dropped)
591    }
592}
593
594/// Calculate exponential backoff.
595fn calculate_backoff(attempt: u32, max_delay_ms: u64) -> u64 {
596    let base_delay = INITIAL_BACKOFF_MS * 2u64.saturating_pow(attempt.saturating_sub(1));
597    base_delay.min(max_delay_ms)
598}