1use std::collections::HashMap;
4use std::sync::atomic::{AtomicU32, Ordering};
5use std::sync::Arc;
6use std::time::Duration;
7
8use futures_util::{SinkExt, StreamExt};
9use titan_api_codec::codec::ws::v1::ClientCodec;
10use titan_api_codec::codec::Codec;
11use titan_api_types::ws::v1::{
12 ClientRequest, RequestData, ResponseError, ResponseSuccess, ServerMessage, StreamData,
13 SwapQuoteRequest,
14};
15use tokio::net::TcpStream;
16use tokio::sync::{mpsc, oneshot, RwLock};
17use tokio_tungstenite::tungstenite::Message;
18use tokio_tungstenite::{MaybeTlsStream, WebSocketStream};
19
20use crate::config::TitanConfig;
21use crate::error::TitanClientError;
22use crate::state::ConnectionState;
23
24type WsStream = WebSocketStream<MaybeTlsStream<TcpStream>>;
25type ResponseResult = Result<ResponseSuccess, ResponseError>;
26type PendingRequestsMap = Arc<RwLock<HashMap<u32, oneshot::Sender<ResponseResult>>>>;
27
28const INITIAL_BACKOFF_MS: u64 = 100;
30
31#[derive(Clone)]
33pub struct ResumableStream {
34 pub request: SwapQuoteRequest,
36 pub sender: mpsc::Sender<StreamData>,
38}
39
40type ResumableStreamsMap = Arc<RwLock<HashMap<u32, ResumableStream>>>;
41
42pub struct PendingRequest {
44 pub request: ClientRequest,
45 pub response_tx: oneshot::Sender<ResponseResult>,
46}
47
48pub struct Connection {
50 #[allow(dead_code)]
51 config: TitanConfig,
52 request_id: AtomicU32,
53 sender: mpsc::Sender<PendingRequest>,
54 state_tx: tokio::sync::watch::Sender<ConnectionState>,
55 #[allow(dead_code)]
56 pending_requests: PendingRequestsMap,
57 resumable_streams: ResumableStreamsMap,
58}
59
60impl Connection {
61 #[tracing::instrument(skip_all)]
65 pub async fn connect(config: TitanConfig) -> Result<Self, TitanClientError> {
66 let (state_tx, _state_rx) = tokio::sync::watch::channel(ConnectionState::Disconnected {
67 reason: "Connecting...".to_string(),
68 });
69
70 let pending_requests: PendingRequestsMap = Arc::new(RwLock::new(HashMap::new()));
71 let resumable_streams: ResumableStreamsMap = Arc::new(RwLock::new(HashMap::new()));
72
73 let ws_stream = Self::establish_connection(&config).await?;
75
76 let (sender, receiver) = mpsc::channel::<PendingRequest>(32);
78
79 let pending_clone = pending_requests.clone();
81 let streams_clone = resumable_streams.clone();
82 let state_tx_clone = state_tx.clone();
83 let config_clone = config.clone();
84
85 tokio::spawn(Self::run_connection_loop_with_reconnect(
86 ws_stream,
87 receiver,
88 pending_clone,
89 streams_clone,
90 state_tx_clone,
91 config_clone,
92 ));
93
94 state_tx.send_replace(ConnectionState::Connected);
95
96 Ok(Self {
97 config,
98 request_id: AtomicU32::new(1),
99 sender,
100 state_tx,
101 pending_requests,
102 resumable_streams,
103 })
104 }
105
106 async fn establish_connection(config: &TitanConfig) -> Result<WsStream, TitanClientError> {
108 use tokio_tungstenite::tungstenite::client::IntoClientRequest;
109 use tokio_tungstenite::Connector;
110
111 let url = if config.url.contains("/ws") || config.url.ends_with('/') {
114 format!("{}?auth={}", config.url, config.token)
116 } else {
117 format!("{}/?auth={}", config.url, config.token)
119 };
120
121 let mut request = url.into_client_request().map_err(|e| {
122 TitanClientError::Unexpected(anyhow::anyhow!("Failed to build request: {}", e))
123 })?;
124
125 request.headers_mut().insert(
127 "Sec-WebSocket-Protocol",
128 titan_api_types::ws::v1::WEBSOCKET_SUBPROTO_BASE
129 .parse()
130 .unwrap(),
131 );
132
133 let (ws_stream, _response) = if config.danger_accept_invalid_certs {
134 let tls_config = crate::tls::build_dangerous_tls_config();
135 let connector = Connector::Rustls(Arc::new(tls_config));
136 tokio_tungstenite::connect_async_tls_with_config(request, None, false, Some(connector))
137 .await
138 .map_err(TitanClientError::WebSocket)?
139 } else {
140 tokio_tungstenite::connect_async(request)
141 .await
142 .map_err(TitanClientError::WebSocket)?
143 };
144
145 Ok(ws_stream)
146 }
147
148 async fn run_connection_loop_with_reconnect(
150 initial_ws_stream: WsStream,
151 mut request_rx: mpsc::Receiver<PendingRequest>,
152 pending_requests: PendingRequestsMap,
153 resumable_streams: ResumableStreamsMap,
154 state_tx: tokio::sync::watch::Sender<ConnectionState>,
155 config: TitanConfig,
156 ) {
157 let mut ws_stream = initial_ws_stream;
158 let mut reconnect_attempt: u32 = 0;
159 let mut request_id_counter: u32 = 1;
160
161 loop {
162 let disconnect_reason = Self::run_single_connection(
164 &mut ws_stream,
165 &mut request_rx,
166 &pending_requests,
167 &resumable_streams,
168 &state_tx,
169 &mut request_id_counter,
170 )
171 .await;
172
173 if request_rx.is_closed() {
175 tracing::info!("Request channel closed, shutting down connection");
176 break;
177 }
178
179 reconnect_attempt += 1;
181
182 if let Some(max) = config.max_reconnect_attempts {
184 if reconnect_attempt > max {
185 tracing::error!("Max reconnect attempts ({}) reached, giving up", max);
186 let _ = state_tx.send(ConnectionState::Disconnected {
187 reason: format!(
188 "Max reconnect attempts reached. Last error: {}",
189 disconnect_reason
190 ),
191 });
192 break;
193 }
194 }
195
196 let backoff_ms = calculate_backoff(reconnect_attempt, config.max_reconnect_delay_ms);
198
199 tracing::info!(
200 attempt = reconnect_attempt,
201 backoff_ms,
202 "Reconnecting after disconnection: {}",
203 disconnect_reason
204 );
205
206 let _ = state_tx.send(ConnectionState::Reconnecting {
207 attempt: reconnect_attempt,
208 });
209
210 tokio::time::sleep(Duration::from_millis(backoff_ms)).await;
212
213 match Self::establish_connection(&config).await {
215 Ok(new_stream) => {
216 ws_stream = new_stream;
217 reconnect_attempt = 0;
218 let _ = state_tx.send(ConnectionState::Connected);
219 tracing::info!("Reconnected successfully");
220
221 Self::resume_streams(
223 &mut ws_stream,
224 &resumable_streams,
225 &mut request_id_counter,
226 )
227 .await;
228 }
229 Err(e) => {
230 tracing::warn!("Reconnection failed: {}", e);
231 continue;
232 }
233 }
234 }
235
236 Self::cleanup_pending_requests(&pending_requests).await;
238 }
239
240 async fn resume_streams(
242 ws_stream: &mut WsStream,
243 resumable_streams: &ResumableStreamsMap,
244 request_id_counter: &mut u32,
245 ) {
246 let streams_to_resume: Vec<(u32, ResumableStream)> = {
247 let streams = resumable_streams.read().await;
248 streams.iter().map(|(k, v)| (*k, v.clone())).collect()
249 };
250
251 if streams_to_resume.is_empty() {
252 return;
253 }
254
255 tracing::info!(
256 "Resuming {} streams after reconnection",
257 streams_to_resume.len()
258 );
259
260 let codec = ClientCodec::Uncompressed;
261 let mut encoder = codec.encoder();
262 let mut decoder = codec.decoder();
263
264 for (old_stream_id, resumable) in streams_to_resume {
265 let request_id = *request_id_counter;
266 *request_id_counter += 1;
267
268 let request = ClientRequest {
269 id: request_id,
270 data: RequestData::NewSwapQuoteStream(resumable.request.clone()),
271 };
272
273 let encoded = match encoder.encode_mut(&request) {
275 Ok(data) => data.to_vec(),
276 Err(e) => {
277 tracing::error!("Failed to encode stream resume request: {}", e);
278 continue;
279 }
280 };
281
282 if let Err(e) = ws_stream.send(Message::Binary(encoded.into())).await {
283 tracing::error!("Failed to send stream resume request: {}", e);
284 continue;
285 }
286
287 match ws_stream.next().await {
289 Some(Ok(Message::Binary(data))) => {
290 match decoder.decode_mut(data) {
291 Ok(ServerMessage::Response(response)) => {
292 if let Some(stream_info) = response.stream {
293 let new_stream_id = stream_info.id;
294
295 let mut streams = resumable_streams.write().await;
297 if let Some(stream) = streams.remove(&old_stream_id) {
298 streams.insert(new_stream_id, stream);
299 tracing::info!(
300 old_id = old_stream_id,
301 new_id = new_stream_id,
302 "Stream resumed with new ID"
303 );
304 }
305 }
306 }
307 Ok(ServerMessage::Error(error)) => {
308 tracing::error!(
309 "Failed to resume stream {}: {}",
310 old_stream_id,
311 error.message
312 );
313 let mut streams = resumable_streams.write().await;
315 streams.remove(&old_stream_id);
316 }
317 Ok(_) => {
318 tracing::warn!("Unexpected response type during stream resumption");
319 }
320 Err(e) => {
321 tracing::error!("Failed to decode stream resume response: {}", e);
322 }
323 }
324 }
325 Some(Ok(_)) => {
326 tracing::warn!("Unexpected message type during stream resumption");
327 }
328 Some(Err(e)) => {
329 tracing::error!("WebSocket error during stream resumption: {}", e);
330 break;
331 }
332 None => {
333 tracing::error!("Connection closed during stream resumption");
334 break;
335 }
336 }
337 }
338 }
339
340 async fn run_single_connection(
342 ws_stream: &mut WsStream,
343 request_rx: &mut mpsc::Receiver<PendingRequest>,
344 pending_requests: &PendingRequestsMap,
345 resumable_streams: &ResumableStreamsMap,
346 state_tx: &tokio::sync::watch::Sender<ConnectionState>,
347 request_id_counter: &mut u32,
348 ) -> String {
349 let codec = ClientCodec::Uncompressed;
350 let mut encoder = codec.encoder();
351 let mut decoder = codec.decoder();
352
353 let (mut ws_sink, mut ws_stream_rx) = ws_stream.split();
354
355 loop {
356 tokio::select! {
357 Some(pending_req) = request_rx.recv() => {
358 let request_id = pending_req.request.id;
359 *request_id_counter = request_id.max(*request_id_counter) + 1;
360
361 {
362 let mut pending_map = pending_requests.write().await;
363 pending_map.insert(request_id, pending_req.response_tx);
364 }
365
366 match encoder.encode_mut(&pending_req.request) {
367 Ok(data) => {
368 if let Err(e) = ws_sink.send(Message::Binary(data.to_vec().into())).await {
369 tracing::error!("Failed to send WebSocket message: {}", e);
370 let mut pending_map = pending_requests.write().await;
371 if let Some(tx) = pending_map.remove(&request_id) {
372 let _ = tx.send(Err(ResponseError {
373 request_id,
374 code: 0,
375 message: format!("Send failed: {}", e),
376 }));
377 }
378 }
379 }
380 Err(e) => {
381 tracing::error!("Failed to encode request: {}", e);
382 let mut pending_map = pending_requests.write().await;
383 if let Some(tx) = pending_map.remove(&request_id) {
384 let _ = tx.send(Err(ResponseError {
385 request_id,
386 code: 0,
387 message: format!("Encode failed: {}", e),
388 }));
389 }
390 }
391 }
392 }
393
394 Some(msg_result) = ws_stream_rx.next() => {
395 match msg_result {
396 Ok(Message::Binary(data)) => {
397 match decoder.decode_mut(data) {
398 Ok(server_msg) => {
399 Self::handle_server_message(
400 server_msg,
401 pending_requests,
402 resumable_streams,
403 ).await;
404 }
405 Err(e) => {
406 tracing::error!("Failed to decode server message: {}", e);
407 }
408 }
409 }
410 Ok(Message::Close(frame)) => {
411 let reason = frame
412 .map(|f| f.reason.to_string())
413 .unwrap_or_else(|| "Server closed connection".to_string());
414 tracing::warn!("WebSocket closed: {}", reason);
415 let _ = state_tx.send(ConnectionState::Disconnected {
416 reason: reason.clone(),
417 });
418 return reason;
419 }
420 Ok(Message::Ping(data)) => {
421 let _ = ws_sink.send(Message::Pong(data)).await;
422 }
423 Ok(_) => {}
424 Err(e) => {
425 let reason = format!("WebSocket error: {}", e);
426 let error_str = e.to_string();
427 if error_str.contains("Connection reset without closing handshake") {
428 tracing::info!("{}", reason);
429 } else {
430 tracing::error!("{}", reason);
431 }
432 let _ = state_tx.send(ConnectionState::Disconnected {
433 reason: reason.clone(),
434 });
435 return reason;
436 }
437 }
438 }
439
440 else => {
441 return "Channel closed".to_string();
442 }
443 }
444 }
445 }
446
447 async fn handle_server_message(
449 msg: ServerMessage,
450 pending_requests: &PendingRequestsMap,
451 resumable_streams: &ResumableStreamsMap,
452 ) {
453 match msg {
454 ServerMessage::Response(response) => {
455 let mut pending = pending_requests.write().await;
456 if let Some(tx) = pending.remove(&response.request_id) {
457 let _ = tx.send(Ok(response));
458 }
459 }
460 ServerMessage::Error(error) => {
461 let mut pending = pending_requests.write().await;
462 if let Some(tx) = pending.remove(&error.request_id) {
463 let _ = tx.send(Err(error));
464 }
465 }
466 ServerMessage::StreamData(data) => {
467 let streams = resumable_streams.read().await;
468 if let Some(stream) = streams.get(&data.id) {
469 let _ = stream.sender.send(data).await;
470 }
471 }
472 ServerMessage::StreamEnd(end) => {
473 let mut streams = resumable_streams.write().await;
474 streams.remove(&end.id);
475 }
476 ServerMessage::Other(_) => {
477 tracing::warn!("Received unknown server message type");
478 }
479 }
480 }
481
482 async fn cleanup_pending_requests(pending_requests: &PendingRequestsMap) {
484 let mut pending_map = pending_requests.write().await;
485 for (request_id, tx) in pending_map.drain() {
486 let _ = tx.send(Err(ResponseError {
487 request_id,
488 code: 0,
489 message: "Connection closed".to_string(),
490 }));
491 }
492 }
493
494 #[tracing::instrument(skip_all)]
496 pub async fn send_request(
497 &self,
498 data: RequestData,
499 ) -> Result<ResponseSuccess, TitanClientError> {
500 let request_id = self.request_id.fetch_add(1, Ordering::SeqCst);
501 let request = ClientRequest {
502 id: request_id,
503 data,
504 };
505
506 let (response_tx, response_rx) = oneshot::channel();
507
508 self.sender
509 .send(PendingRequest {
510 request,
511 response_tx,
512 })
513 .await
514 .map_err(|_| TitanClientError::Unexpected(anyhow::anyhow!("Connection closed")))?;
515
516 let response = response_rx.await.map_err(|_| {
517 TitanClientError::Unexpected(anyhow::anyhow!("Response channel closed"))
518 })?;
519
520 response.map_err(|e| TitanClientError::ServerError {
521 code: e.code,
522 message: e.message,
523 })
524 }
525
526 pub async fn register_stream(
528 &self,
529 stream_id: u32,
530 request: SwapQuoteRequest,
531 sender: mpsc::Sender<StreamData>,
532 ) {
533 let mut streams = self.resumable_streams.write().await;
534 streams.insert(stream_id, ResumableStream { request, sender });
535 }
536
537 pub async fn unregister_stream(&self, stream_id: u32) {
539 let mut streams = self.resumable_streams.write().await;
540 streams.remove(&stream_id);
541 }
542
543 pub fn state_receiver(&self) -> tokio::sync::watch::Receiver<ConnectionState> {
545 self.state_tx.subscribe()
546 }
547
548 pub fn state(&self) -> ConnectionState {
550 self.state_tx.borrow().clone()
551 }
552
553 pub async fn active_stream_ids(&self) -> Vec<u32> {
555 let streams = self.resumable_streams.read().await;
556 streams.keys().copied().collect()
557 }
558
559 #[tracing::instrument(skip_all)]
563 pub async fn stop_all_streams(&self) {
564 use titan_api_types::ws::v1::StopStreamRequest;
565
566 let stream_ids = self.active_stream_ids().await;
567
568 if stream_ids.is_empty() {
569 return;
570 }
571
572 tracing::info!("Stopping {} active streams", stream_ids.len());
573
574 for stream_id in stream_ids {
575 let _ = self
577 .send_request(RequestData::StopStream(StopStreamRequest { id: stream_id }))
578 .await;
579 }
580
581 let mut streams = self.resumable_streams.write().await;
583 streams.clear();
584 }
585
586 #[tracing::instrument(skip_all)]
588 pub async fn shutdown(&self) {
589 self.stop_all_streams().await;
591
592 let _ = self.state_tx.send(ConnectionState::Disconnected {
594 reason: "Client shutdown".to_string(),
595 });
596
597 }
600}
601
602fn calculate_backoff(attempt: u32, max_delay_ms: u64) -> u64 {
604 let base_delay = INITIAL_BACKOFF_MS * 2u64.saturating_pow(attempt.saturating_sub(1));
605 base_delay.min(max_delay_ms)
606}