1use std::collections::HashMap;
4use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};
5use std::sync::Arc;
6use std::time::Duration;
7
8use futures_util::{SinkExt, StreamExt};
9use titan_api_codec::codec::ws::v1::ClientCodec;
10use titan_api_codec::codec::Codec;
11use titan_api_types::ws::v1::{
12 ClientRequest, RequestData, ResponseSuccess, ServerMessage, StreamData, SwapQuoteRequest,
13};
14use tokio::net::TcpStream;
15use tokio::sync::{mpsc, oneshot, RwLock};
16use tokio_tungstenite::tungstenite::Message;
17use tokio_tungstenite::{MaybeTlsStream, WebSocketStream};
18use tokio_util::sync::CancellationToken;
19
20use crate::config::TitanConfig;
21use crate::error::TitanClientError;
22use crate::state::ConnectionState;
23
/// Concrete WebSocket stream type used by this module: a tungstenite
/// WebSocket over a TCP socket that may or may not be wrapped in TLS.
type WsStream = WebSocketStream<MaybeTlsStream<TcpStream>>;
/// Outcome delivered to the caller of a request: the server's success
/// payload, or a client-side/server-side error.
type ResponseResult = Result<ResponseSuccess, TitanClientError>;
/// In-flight requests keyed by request ID; each entry holds the one-shot
/// channel on which the matching response (or failure) is delivered.
type PendingRequestsMap = Arc<RwLock<HashMap<u32, oneshot::Sender<ResponseResult>>>>;
/// Callback invoked when a stream is removed from the registry
/// (stream end, failed resume, or teardown).
type OnEndCallback = Arc<dyn Fn() + Send + Sync>;
28
/// Base delay, in milliseconds, of the exponential reconnect backoff
/// (the delay used for the first reconnect attempt).
pub const INITIAL_BACKOFF_MS: u64 = 100;

/// Keepalive ping interval, in milliseconds, used when the configured
/// `ping_interval_ms` is zero.
pub const DEFAULT_PING_INTERVAL_MS: u64 = 25_000;

/// Default pong timeout in milliseconds.
// NOTE(review): not referenced anywhere in this file; presumably consumed by
// the configuration defaults — confirm.
pub const DEFAULT_PONG_TIMEOUT_MS: u64 = 10_000;
37
/// Everything needed to deliver data for a quote stream and to re-create it
/// on a fresh connection after a reconnect.
#[derive(Clone)]
pub struct ResumableStream {
    /// Original request; re-sent as `NewSwapQuoteStream` when resuming.
    pub request: SwapQuoteRequest,
    /// Channel on which incoming `StreamData` for this stream is forwarded.
    pub sender: mpsc::Sender<StreamData>,
    /// Optional callback invoked when the stream is removed from the registry.
    pub on_end: Option<OnEndCallback>,
    /// Optional shared cell that is updated with the new server-assigned
    /// stream ID each time the stream is resumed.
    pub effective_id: Option<Arc<AtomicU32>>,
    /// When set, the stream is dropped instead of being resumed.
    pub stopped: Arc<AtomicBool>,
}

/// Registry of active streams keyed by their current stream ID.
type ResumableStreamsMap = Arc<RwLock<HashMap<u32, ResumableStream>>>;
54
/// A request queued for the background connection task, paired with the
/// channel on which its result must be reported.
pub struct PendingRequest {
    /// The fully built request (ID already assigned) to put on the wire.
    pub request: ClientRequest,
    /// One-shot channel that receives the response or the failure.
    pub response_tx: oneshot::Sender<ResponseResult>,
}
60
/// Handle to a WebSocket connection that is driven by a background task.
///
/// The handle queues requests for the task, tracks registered streams and
/// exposes the current connection state.
pub struct Connection {
    /// Configuration the connection was created with (kept but not read
    /// through this handle).
    #[expect(dead_code)]
    config: TitanConfig,
    /// Source of request IDs; starts at 1 and is incremented per request.
    request_id: AtomicU32,
    /// Queue of outgoing requests consumed by the background task.
    sender: mpsc::Sender<PendingRequest>,
    /// Cancelled on shutdown to stop the background task.
    shutdown: CancellationToken,
    /// Publishes connection state changes; also the source for `state()`.
    state_tx: tokio::sync::watch::Sender<ConnectionState>,
    /// In-flight requests, shared with the background task (not read
    /// through this handle).
    #[expect(dead_code)]
    pending_requests: PendingRequestsMap,
    /// Registry of active streams, shared with the background task.
    resumable_streams: ResumableStreamsMap,
}
73
/// Borrowed state handed to `run_single_connection` for the lifetime of one
/// WebSocket session (bundled into a struct to keep the signature short).
struct RunSingleConnectionArgs<'a> {
    /// The live WebSocket for this session.
    ws_stream: &'a mut WsStream,
    /// Queue of requests submitted through the `Connection` handle.
    request_rx: &'a mut mpsc::Receiver<PendingRequest>,
    /// In-flight requests awaiting a response.
    pending_requests: &'a PendingRequestsMap,
    /// Registry of active streams.
    resumable_streams: &'a ResumableStreamsMap,
    /// Publisher for connection state changes.
    state_tx: &'a tokio::sync::watch::Sender<ConnectionState>,
    /// Next request ID the background task may use for its own requests;
    /// advanced past every request ID seen on the queue.
    request_id_counter: &'a mut u32,
    /// Client configuration (ping interval, pong timeout, ...).
    config: &'a TitanConfig,
    /// Cancellation token signalling client shutdown.
    shutdown: &'a CancellationToken,
}
84
85impl Connection {
86 #[tracing::instrument(skip_all)]
90 pub async fn connect(config: TitanConfig) -> Result<Self, TitanClientError> {
91 let (state_tx, _state_rx) = tokio::sync::watch::channel(ConnectionState::Disconnected {
92 reason: "Connecting...".to_string(),
93 });
94
95 let pending_requests: PendingRequestsMap = Arc::new(RwLock::new(HashMap::new()));
96 let resumable_streams: ResumableStreamsMap = Arc::new(RwLock::new(HashMap::new()));
97 let shutdown = CancellationToken::new();
98
99 let ws_stream = Self::establish_connection(&config).await?;
101
102 let (sender, receiver) = mpsc::channel::<PendingRequest>(32);
104
105 let pending_clone = pending_requests.clone();
107 let streams_clone = resumable_streams.clone();
108 let state_tx_clone = state_tx.clone();
109 let config_clone = config.clone();
110
111 tokio::spawn(Self::run_connection_loop_with_reconnect(
112 ws_stream,
113 receiver,
114 pending_clone,
115 streams_clone,
116 state_tx_clone,
117 config_clone,
118 shutdown.clone(),
119 ));
120
121 state_tx.send_replace(ConnectionState::Connected);
122
123 Ok(Self {
124 config,
125 request_id: AtomicU32::new(1),
126 sender,
127 shutdown,
128 state_tx,
129 pending_requests,
130 resumable_streams,
131 })
132 }
133
134 async fn establish_connection(config: &TitanConfig) -> Result<WsStream, TitanClientError> {
136 use tokio_tungstenite::tungstenite::client::IntoClientRequest;
137 use tokio_tungstenite::Connector;
138
139 let url = if config.url.contains("/ws") || config.url.ends_with('/') {
140 format!("{}?auth={}", config.url, config.token)
141 } else {
142 format!("{}/?auth={}", config.url, config.token)
143 };
144
145 let mut request = url.into_client_request().map_err(|e| {
146 TitanClientError::Unexpected(anyhow::anyhow!("Failed to build request: {e}"))
147 })?;
148
149 request.headers_mut().insert(
150 "Sec-WebSocket-Protocol",
151 titan_api_types::ws::v1::WEBSOCKET_SUBPROTO_BASE
152 .parse()
153 .map_err(|e| {
154 TitanClientError::Unexpected(anyhow::anyhow!(
155 "Sec-WebSocket-Protocol fail: {e}"
156 ))
157 })?,
158 );
159
160 let tls_config = if config.danger_accept_invalid_certs {
161 crate::tls::build_dangerous_tls_config()
162 } else {
163 crate::tls::build_default_tls_config()
164 }
165 .map_err(|e| TitanClientError::Unexpected(anyhow::anyhow!("TLS config failed: {e}")))?;
166 let connector = Connector::Rustls(Arc::new(tls_config));
167 let (ws_stream, _response) =
168 tokio_tungstenite::connect_async_tls_with_config(request, None, false, Some(connector))
169 .await
170 .map_err(TitanClientError::WebSocket)?;
171 Ok(ws_stream)
172 }
173
174 async fn run_connection_loop_with_reconnect(
176 initial_ws_stream: WsStream,
177 mut request_rx: mpsc::Receiver<PendingRequest>,
178 pending_requests: PendingRequestsMap,
179 resumable_streams: ResumableStreamsMap,
180 state_tx: tokio::sync::watch::Sender<ConnectionState>,
181 config: TitanConfig,
182 shutdown: CancellationToken,
183 ) {
184 let mut ws_stream = initial_ws_stream;
185 let mut reconnect_attempt: u32 = 0;
186 let mut request_id_counter: u32 = 1;
187
188 loop {
189 let disconnect_reason = Self::run_single_connection(RunSingleConnectionArgs {
191 ws_stream: &mut ws_stream,
192 request_rx: &mut request_rx,
193 pending_requests: &pending_requests,
194 resumable_streams: &resumable_streams,
195 state_tx: &state_tx,
196 request_id_counter: &mut request_id_counter,
197 config: &config,
198 shutdown: &shutdown,
199 })
200 .await;
201
202 Self::fail_pending_requests(&pending_requests, &disconnect_reason).await;
204
205 if shutdown.is_cancelled() {
206 break;
207 }
208
209 if request_rx.is_closed() {
211 tracing::info!("Request channel closed, shutting down connection");
212 break;
213 }
214
215 reconnect_attempt += 1;
217
218 if let Some(max) = config.max_reconnect_attempts {
220 if reconnect_attempt > max {
221 tracing::error!("Max reconnect attempts ({}) reached, giving up", max);
222 let _ = state_tx.send(ConnectionState::Disconnected {
223 reason: format!(
224 "Max reconnect attempts reached. Last error: {}",
225 disconnect_reason
226 ),
227 });
228 break;
229 }
230 }
231
232 let backoff_ms = calculate_backoff(reconnect_attempt, config.max_reconnect_delay_ms);
234
235 tracing::debug!(
236 attempt = reconnect_attempt,
237 backoff_ms,
238 "Reconnecting after disconnection: {}",
239 disconnect_reason
240 );
241
242 let _ = state_tx.send(ConnectionState::Reconnecting {
243 attempt: reconnect_attempt,
244 });
245
246 tokio::time::sleep(Duration::from_millis(backoff_ms)).await;
248
249 match Self::establish_connection(&config).await {
251 Ok(new_stream) => {
252 ws_stream = new_stream;
253 reconnect_attempt = 0;
254 let _ = state_tx.send(ConnectionState::Connected);
255 tracing::debug!("Reconnected successfully");
256
257 Self::resume_streams(
259 &mut ws_stream,
260 &pending_requests,
261 &resumable_streams,
262 &mut request_id_counter,
263 )
264 .await;
265 }
266 Err(e) => {
267 tracing::warn!("Reconnection failed: {}", e);
268 }
269 }
270 }
271
272 Self::cleanup_pending_requests(&pending_requests).await;
274 Self::cleanup_resumable_streams(&resumable_streams).await;
275 }
276
277 async fn resume_streams(
279 ws_stream: &mut WsStream,
280 pending_requests: &PendingRequestsMap,
281 resumable_streams: &ResumableStreamsMap,
282 request_id_counter: &mut u32,
283 ) {
284 let streams_to_resume: Vec<(u32, ResumableStream)> = {
285 let streams = resumable_streams.read().await;
286 streams.iter().map(|(k, v)| (*k, v.clone())).collect()
287 };
288
289 if streams_to_resume.is_empty() {
290 return;
291 }
292
293 tracing::info!(
294 "Resuming {} streams after reconnection",
295 streams_to_resume.len()
296 );
297
298 let codec = ClientCodec::Uncompressed;
299 let mut encoder = codec.encoder();
300 let mut decoder = codec.decoder();
301
302 for (old_stream_id, resumable) in streams_to_resume {
303 if resumable.stopped.load(Ordering::SeqCst) || resumable.sender.is_closed() {
304 let mut streams = resumable_streams.write().await;
305 if let Some(stream) = streams.remove(&old_stream_id) {
306 if let Some(ref on_end) = stream.on_end {
307 on_end();
308 }
309 }
310 continue;
311 }
312
313 let request_id = *request_id_counter;
314 *request_id_counter += 1;
315
316 let request = ClientRequest {
317 id: request_id,
318 data: RequestData::NewSwapQuoteStream(resumable.request.clone()),
319 };
320
321 let encoded = match encoder.encode_mut(&request) {
323 Ok(data) => data.to_vec(),
324 Err(e) => {
325 tracing::error!("Failed to encode stream resume request: {}", e);
326 let mut streams = resumable_streams.write().await;
327 if let Some(stream) = streams.remove(&old_stream_id) {
328 if let Some(ref on_end) = stream.on_end {
329 on_end();
330 }
331 }
332 continue;
333 }
334 };
335
336 if let Err(e) = ws_stream.send(Message::Binary(encoded.into())).await {
337 tracing::error!("Failed to send stream resume request: {}", e);
338 let mut streams = resumable_streams.write().await;
339 if let Some(stream) = streams.remove(&old_stream_id) {
340 if let Some(ref on_end) = stream.on_end {
341 on_end();
342 }
343 }
344 continue;
345 }
346
347 loop {
349 match ws_stream.next().await {
350 Some(Ok(Message::Binary(data))) => match decoder.decode_mut(data) {
351 Ok(ServerMessage::Response(response)) => {
352 if response.request_id != request_id {
353 Self::handle_server_message(
354 ServerMessage::Response(response),
355 pending_requests,
356 resumable_streams,
357 )
358 .await;
359 continue;
360 }
361
362 if let Some(stream_info) = response.stream {
363 let new_stream_id = stream_info.id;
364
365 let mut streams = resumable_streams.write().await;
367 if let Some(stream) = streams.remove(&old_stream_id) {
368 if stream.stopped.load(Ordering::SeqCst)
369 || stream.sender.is_closed()
370 {
371 if let Some(ref on_end) = stream.on_end {
372 on_end();
373 }
374 } else {
375 if let Some(ref effective_id) = stream.effective_id {
377 effective_id.store(new_stream_id, Ordering::SeqCst);
378 }
379 streams.insert(new_stream_id, stream);
380 tracing::info!(
381 old_id = old_stream_id,
382 new_id = new_stream_id,
383 "Stream resumed with new ID"
384 );
385 }
386 }
387 } else {
388 tracing::error!(
389 "Stream resume response missing stream info for {}",
390 old_stream_id
391 );
392 let mut streams = resumable_streams.write().await;
393 if let Some(stream) = streams.remove(&old_stream_id) {
394 if let Some(ref on_end) = stream.on_end {
395 on_end();
396 }
397 }
398 }
399 break;
400 }
401 Ok(ServerMessage::Error(error)) => {
402 if error.request_id != request_id {
403 Self::handle_server_message(
404 ServerMessage::Error(error),
405 pending_requests,
406 resumable_streams,
407 )
408 .await;
409 continue;
410 }
411
412 tracing::error!(
413 "Failed to resume stream {}: {}",
414 old_stream_id,
415 error.message
416 );
417 let mut streams = resumable_streams.write().await;
419 if let Some(stream) = streams.remove(&old_stream_id) {
420 if let Some(ref on_end) = stream.on_end {
421 on_end();
422 }
423 }
424 break;
425 }
426 Ok(other) => {
427 Self::handle_server_message(other, pending_requests, resumable_streams)
428 .await;
429 }
430 Err(e) => {
431 tracing::error!("Failed to decode stream resume response: {}", e);
432 }
433 },
434 Some(Ok(Message::Ping(data))) => {
435 let _ = ws_stream.send(Message::Pong(data)).await;
436 }
437 Some(Ok(Message::Pong(_))) => {}
438 Some(Ok(Message::Close(frame))) => {
439 let reason = frame.map_or_else(
440 || "Server closed connection".to_string(),
441 |f| f.reason.to_string(),
442 );
443 tracing::warn!("WebSocket closed during stream resumption: {reason}");
444 break;
445 }
446 Some(Ok(_)) => {}
447 Some(Err(e)) => {
448 tracing::error!("WebSocket error during stream resumption: {}", e);
449 break;
450 }
451 None => {
452 tracing::error!("Connection closed during stream resumption");
453 break;
454 }
455 }
456 }
457 }
458 }
459
460 async fn run_single_connection(args: RunSingleConnectionArgs<'_>) -> String {
462 let RunSingleConnectionArgs {
463 ws_stream,
464 request_rx,
465 pending_requests,
466 resumable_streams,
467 state_tx,
468 request_id_counter,
469 config,
470 shutdown,
471 } = args;
472 let codec = ClientCodec::Uncompressed;
473 let mut encoder = codec.encoder();
474 let mut decoder = codec.decoder();
475
476 let (mut ws_sink, mut ws_stream_rx) = ws_stream.split();
477
478 let ping_interval_ms = if config.ping_interval_ms > 0 {
479 config.ping_interval_ms
480 } else {
481 DEFAULT_PING_INTERVAL_MS
482 };
483 let mut ping_timer = tokio::time::interval(Duration::from_millis(ping_interval_ms));
484 ping_timer.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
485
486 let pong_timeout = Duration::from_millis(config.pong_timeout_ms);
487 let mut last_ping = tokio::time::Instant::now();
488 let mut awaiting_pong = false;
489
490 loop {
491 tokio::select! {
492 () = shutdown.cancelled() => {
493 return "Client shutdown".to_string();
494 }
495 maybe_req = request_rx.recv() => {
496 let Some(pending_req) = maybe_req else {
497 return "Request channel closed".to_string();
498 };
499
500 let request_id = pending_req.request.id;
501 *request_id_counter = request_id.max(*request_id_counter) + 1;
502
503 {
504 let mut pending_map = pending_requests.write().await;
505 pending_map.insert(request_id, pending_req.response_tx);
506 }
507
508 match encoder.encode_mut(&pending_req.request) {
509 Ok(data) => {
510 if let Err(e) = ws_sink.send(Message::Binary(data.to_vec().into())).await {
511 tracing::error!("Failed to send WebSocket message: {e}");
512 let mut pending_map = pending_requests.write().await;
513 if let Some(tx) = pending_map.remove(&request_id) {
514 let _ = tx.send(Err(TitanClientError::ConnectionClosed {
515 reason: format!("Send failed: {e}"),
516 }));
517 }
518 }
519 }
520 Err(e) => {
521 tracing::error!("Failed to encode request: {e}");
522 let mut pending_map = pending_requests.write().await;
523 if let Some(tx) = pending_map.remove(&request_id) {
524 let _ = tx.send(Err(TitanClientError::Unexpected(anyhow::anyhow!(
525 "Encode failed: {e}"
526 ))));
527 }
528 }
529 }
530 }
531
532 Some(msg_result) = ws_stream_rx.next() => {
533 match msg_result {
534 Ok(Message::Binary(data)) => {
535 match decoder.decode_mut(data) {
536 Ok(server_msg) => {
537 Self::handle_server_message(
538 server_msg,
539 pending_requests,
540 resumable_streams,
541 ).await;
542 }
543 Err(e) => {
544 tracing::error!("Failed to decode server message: {e}");
545 }
546 }
547 }
548 Ok(Message::Close(frame)) => {
549 let reason = frame.map_or_else(|| "Server closed connection".to_string(), |f| f.reason.to_string());
550 tracing::warn!("WebSocket closed: {reason}");
551 let _ = state_tx.send(ConnectionState::Disconnected {
552 reason: reason.clone(),
553 });
554 return reason;
555 }
556 Ok(Message::Ping(data)) => {
557 let _ = ws_sink.send(Message::Pong(data)).await;
558 }
559 Ok(Message::Pong(_)) => {
560 awaiting_pong = false;
561 tracing::trace!("Received pong from server");
562 }
563 Ok(_) => {}
564 Err(e) => {
565 let reason = format!("WebSocket error: {e}");
566 let error_str = e.to_string();
567 if error_str.contains("Connection reset without closing handshake") {
568 tracing::debug!("{reason}");
569 } else {
570 tracing::error!("{reason}");
571 }
572 let _ = state_tx.send(ConnectionState::Disconnected {
573 reason: reason.clone(),
574 });
575 return reason;
576 }
577 }
578 }
579
580 _ = ping_timer.tick() => {
581 if config.pong_timeout_ms > 0 && awaiting_pong && last_ping.elapsed() > pong_timeout {
582 let reason = "Pong timeout".to_string();
583 let timeout_ms = config.pong_timeout_ms;
584 tracing::debug!("No pong received within {timeout_ms}ms, triggering reconnect");
585 let _ = state_tx.send(ConnectionState::Disconnected {
586 reason: reason.clone(),
587 });
588 return reason;
589 }
590
591 if let Err(e) = ws_sink.send(Message::Ping(vec![].into())).await {
592 let reason = format!("Failed to send ping: {e}");
593 tracing::warn!("{reason}");
594 let _ = state_tx.send(ConnectionState::Disconnected {
595 reason: reason.clone(),
596 });
597 return reason;
598 }
599 awaiting_pong = true;
600 last_ping = tokio::time::Instant::now();
601 tracing::trace!("Sent keepalive ping");
602 }
603
604 else => {
605 return "Channel closed".to_string();
606 }
607 }
608 }
609 }
610
611 async fn handle_server_message(
613 msg: ServerMessage,
614 pending_requests: &PendingRequestsMap,
615 resumable_streams: &ResumableStreamsMap,
616 ) {
617 match msg {
618 ServerMessage::Response(response) => {
619 let mut pending = pending_requests.write().await;
620 if let Some(tx) = pending.remove(&response.request_id) {
621 let _ = tx.send(Ok(response));
622 }
623 }
624 ServerMessage::Error(error) => {
625 let mut pending = pending_requests.write().await;
626 if let Some(tx) = pending.remove(&error.request_id) {
627 let _ = tx.send(Err(TitanClientError::ServerError {
628 code: error.code,
629 message: error.message,
630 }));
631 }
632 }
633 ServerMessage::StreamData(data) => {
634 let streams = resumable_streams.read().await;
635 if let Some(stream) = streams.get(&data.id) {
636 let _ = stream.sender.send(data).await;
637 }
638 }
639 ServerMessage::StreamEnd(end) => {
640 let mut streams = resumable_streams.write().await;
641 if let Some(stream) = streams.remove(&end.id) {
642 if let Some(ref on_end) = stream.on_end {
643 on_end();
644 }
645 }
646 }
647 ServerMessage::Other(_) => {
648 tracing::warn!("Received unknown server message type");
649 }
650 }
651 }
652
653 async fn cleanup_pending_requests(pending_requests: &PendingRequestsMap) {
655 Self::fail_pending_requests(pending_requests, "Connection closed").await;
656 }
657
658 async fn fail_pending_requests(pending_requests: &PendingRequestsMap, reason: &str) {
659 let mut pending_map = pending_requests.write().await;
660 for (_request_id, tx) in pending_map.drain() {
661 let _ = tx.send(Err(TitanClientError::ConnectionClosed {
662 reason: reason.to_string(),
663 }));
664 }
665 }
666
667 async fn cleanup_resumable_streams(resumable_streams: &ResumableStreamsMap) {
669 let mut streams = resumable_streams.write().await;
670 for (_id, stream) in streams.drain() {
671 if let Some(ref on_end) = stream.on_end {
672 on_end();
673 }
674 }
675 }
676
677 #[tracing::instrument(skip_all)]
679 pub async fn send_request(
680 &self,
681 data: RequestData,
682 ) -> Result<ResponseSuccess, TitanClientError> {
683 if self.shutdown.is_cancelled() {
684 return Err(TitanClientError::ConnectionClosed {
685 reason: "Client shutdown".to_string(),
686 });
687 }
688
689 let request_id = self.request_id.fetch_add(1, Ordering::SeqCst);
690 let request = ClientRequest {
691 id: request_id,
692 data,
693 };
694
695 let (response_tx, response_rx) = oneshot::channel();
696
697 self.sender
698 .send(PendingRequest {
699 request,
700 response_tx,
701 })
702 .await
703 .map_err(|_| TitanClientError::ConnectionClosed {
704 reason: "Connection closed".to_string(),
705 })?;
706
707 let response = response_rx
708 .await
709 .map_err(|_| TitanClientError::ConnectionClosed {
710 reason: "Response channel closed".to_string(),
711 })?;
712
713 response
714 }
715
716 pub async fn register_stream(
718 &self,
719 stream_id: u32,
720 request: SwapQuoteRequest,
721 sender: mpsc::Sender<StreamData>,
722 on_end: Option<OnEndCallback>,
723 effective_id: Option<Arc<AtomicU32>>,
724 stopped: Arc<AtomicBool>,
725 ) {
726 let mut streams = self.resumable_streams.write().await;
727 streams.insert(
728 stream_id,
729 ResumableStream {
730 request,
731 sender,
732 on_end,
733 effective_id,
734 stopped,
735 },
736 );
737 }
738
739 pub async fn unregister_stream(&self, stream_id: u32) {
741 let mut streams = self.resumable_streams.write().await;
742 streams.remove(&stream_id);
743 }
744
745 pub fn state_receiver(&self) -> tokio::sync::watch::Receiver<ConnectionState> {
747 self.state_tx.subscribe()
748 }
749
750 pub fn state(&self) -> ConnectionState {
752 self.state_tx.borrow().clone()
753 }
754
755 pub async fn active_stream_ids(&self) -> Vec<u32> {
757 let streams = self.resumable_streams.read().await;
758 streams.keys().copied().collect()
759 }
760
761 #[tracing::instrument(skip_all)]
765 pub async fn stop_all_streams(&self) {
766 use titan_api_types::ws::v1::StopStreamRequest;
767
768 let stream_ids = self.active_stream_ids().await;
769
770 if stream_ids.is_empty() {
771 return;
772 }
773
774 tracing::info!("Stopping {} active streams", stream_ids.len());
775
776 for stream_id in stream_ids {
777 let _ = self
779 .send_request(RequestData::StopStream(StopStreamRequest { id: stream_id }))
780 .await;
781 }
782
783 let mut streams = self.resumable_streams.write().await;
785 for (_id, stream) in streams.drain() {
786 if let Some(ref on_end) = stream.on_end {
787 on_end();
788 }
789 }
790 }
791
792 #[tracing::instrument(skip_all)]
794 pub async fn shutdown(&self) {
795 self.stop_all_streams().await;
797
798 self.shutdown.cancel();
799
800 let _ = self.state_tx.send(ConnectionState::Disconnected {
802 reason: "Client shutdown".to_string(),
803 });
804
805 }
808}
809
810fn calculate_backoff(attempt: u32, max_delay_ms: u64) -> u64 {
812 let base_delay = INITIAL_BACKOFF_MS * 2u64.saturating_pow(attempt.saturating_sub(1));
813 base_delay.min(max_delay_ms)
814}