1use crate::auth::Auth;
15use crate::client::Statistics;
16use crate::connection::Connection;
17use crate::connection::State;
18#[cfg(feature = "websockets")]
19use crate::connection::WebSocketAdapter;
20use crate::options::CallbackArg1;
21use crate::tls;
22use crate::AuthError;
23use crate::ClientError;
24use crate::ClientOp;
25use crate::ConnectError;
26use crate::ConnectErrorKind;
27use crate::ConnectInfo;
28use crate::Event;
29use crate::Protocol;
30use crate::ServerAddr;
31use crate::ServerError;
32use crate::ServerInfo;
33use crate::ServerOp;
34use crate::SocketAddr;
35use crate::ToServerAddrs;
36use crate::LANG;
37use crate::VERSION;
38#[cfg(feature = "nkeys")]
39use base64::engine::general_purpose::URL_SAFE_NO_PAD;
40#[cfg(feature = "nkeys")]
41use base64::engine::Engine;
42use rand::rng;
43use rand::seq::SliceRandom;
44use std::cmp;
45use std::fmt;
46use std::io;
47use std::path::PathBuf;
48use std::sync::atomic::AtomicUsize;
49use std::sync::atomic::Ordering;
50use std::sync::Arc;
51use std::time::Duration;
52use tokio::net::{TcpSocket, TcpStream};
53use tokio::time::sleep;
54use tokio_rustls::rustls;
55
/// One entry in the connector's server pool, together with the connection bookkeeping the
/// connector keeps for it.
#[derive(Debug, Clone)]
pub struct Server {
    /// Address of the server.
    pub addr: ServerAddr,
    /// Number of failed connection attempts recorded against this server; reset to 0 after a
    /// successful connection.
    pub failed_attempts: usize,
    /// Whether a connection to this server has succeeded at least once.
    pub did_connect: bool,
    /// Whether the server was learned from a server's INFO `connect_urls` rather than being
    /// configured explicitly.
    pub is_discovered: bool,
    /// Text of the most recent connection error for this server; cleared on success.
    pub last_error: Option<String>,
}
91
92impl fmt::Display for Server {
93 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
94 write!(f, "{}", self.addr.as_url_str())
95 }
96}
97
/// Server selection returned by the reconnect-to-server callback.
#[derive(Debug, Clone)]
pub struct ReconnectToServer {
    /// Server to connect to. It must already be in the pool; otherwise the connector logs a
    /// warning, emits `ServerNotInPool` and falls back to its default selection.
    pub addr: ServerAddr,
    /// Delay to wait before connecting. `None` uses the configured reconnect delay callback.
    pub delay: Option<Duration>,
}
112
/// Callback invoked at the start of each connection round with a snapshot of the server pool
/// and the `ServerInfo` of the most recent successful connection (default before the first).
/// Returning `None` keeps the connector's default server selection.
pub(crate) type ReconnectToServerCallback =
    CallbackArg1<(Vec<Server>, ServerInfo), Option<ReconnectToServer>>;
115
116impl Server {
117 fn new(addr: ServerAddr) -> Self {
118 Server {
119 addr,
120 failed_attempts: 0,
121 did_connect: false,
122 is_discovered: false,
123 last_error: None,
124 }
125 }
126
127 fn new_discovered(addr: ServerAddr) -> Self {
128 Server {
129 addr,
130 failed_attempts: 0,
131 did_connect: false,
132 is_discovered: true,
133 last_error: None,
134 }
135 }
136}
137
/// Options controlling how the `Connector` establishes connections.
pub(crate) struct ConnectorOptions {
    /// Require TLS for every connection: forces the TLS upgrade and is reported in CONNECT.
    pub(crate) tls_required: bool,
    /// Certificate files for TLS, consumed by `tls::config_tls` (not shown here) — presumably
    /// root CAs; TODO confirm.
    pub(crate) certificates: Vec<PathBuf>,
    /// Client certificate file, presumably used by `tls::config_tls` for mutual TLS.
    pub(crate) client_cert: Option<PathBuf>,
    /// Private key file, presumably paired with `client_cert` in `tls::config_tls`.
    pub(crate) client_key: Option<PathBuf>,
    /// Pre-built rustls client configuration; how it combines with the file-based options is
    /// decided in `tls::config_tls` (not shown here).
    pub(crate) tls_client_config: Option<rustls::ClientConfig>,
    /// For non-websocket connections, perform the TLS handshake before reading INFO.
    pub(crate) tls_first: bool,
    /// Static credentials used to fill CONNECT (user/password, token, nkey, JWT).
    pub(crate) auth: Auth,
    /// When set, CONNECT is sent with `echo: false`.
    pub(crate) no_echo: bool,
    /// Upper bound for each per-address connection attempt, handshake included.
    pub(crate) connection_timeout: Duration,
    /// Client name reported in CONNECT.
    pub(crate) name: Option<String>,
    /// Do not add servers advertised in INFO `connect_urls` to the pool.
    pub(crate) ignore_discovered_servers: bool,
    /// Try servers in pool order instead of shuffling and ordering them by failure count.
    pub(crate) retain_servers_order: bool,
    /// Read buffer capacity for plain TCP connections (websocket and TLS connections pass 0).
    pub(crate) read_buffer_capacity: u16,
    /// Maps the current attempt number to the delay slept before that attempt.
    pub(crate) reconnect_delay_callback: Arc<dyn Fn(usize) -> Duration + Send + Sync + 'static>,
    /// Called with the server nonce; the returned `Auth` replaces the credentials in CONNECT.
    pub(crate) auth_callback: Option<CallbackArg1<Vec<u8>, Result<Auth, AuthError>>>,
    /// Maximum number of consecutive connection attempts; `None` means unlimited.
    pub(crate) max_reconnects: Option<usize>,
    /// Local address that plain TCP sockets are bound to before connecting.
    pub(crate) local_address: Option<SocketAddr>,
    /// Optional callback that chooses the next server to connect to.
    pub(crate) reconnect_to_server_callback: Option<ReconnectToServerCallback>,
}
158
/// Establishes connections to servers from a pool, tracking per-server failure state and the
/// attempt counter used for reconnect backoff and `max_reconnects`.
pub(crate) struct Connector {
    /// Server pool: configured servers plus any discovered ones.
    servers: Vec<Server>,
    /// Connection options.
    options: ConnectorOptions,
    /// Shared client statistics; `connects` is incremented on every successful connection.
    pub(crate) connect_stats: Arc<Statistics>,
    /// Attempts since the last successful connection or pool replacement.
    attempts: usize,
    /// Sender for client events (connected, errors).
    pub(crate) events_tx: tokio::sync::mpsc::Sender<Event>,
    /// Publishes connection state changes.
    pub(crate) state_tx: tokio::sync::watch::Sender<State>,
    /// Shared maximum payload size, updated from each connected server's INFO.
    pub(crate) max_payload: Arc<AtomicUsize>,
    /// INFO of the most recent successful connection, handed to the reconnect callback.
    last_info: ServerInfo,
}
172
/// Default reconnect backoff.
///
/// The first attempt (and attempt 0) waits nothing. Attempt `n >= 2` waits `2^(n-1)`
/// milliseconds, capped at 4 seconds. The exponent saturates, so huge attempt counts simply
/// hit the cap.
pub(crate) fn reconnect_delay_callback_default(attempts: usize) -> Duration {
    const MAX_DELAY: Duration = Duration::from_secs(4);

    match attempts.checked_sub(1) {
        None | Some(0) => Duration::from_millis(0),
        Some(exponent) => {
            let exponent = u32::try_from(exponent).unwrap_or(u32::MAX);
            let backoff = Duration::from_millis(2_u64.saturating_pow(exponent));
            cmp::min(backoff, MAX_DELAY)
        }
    }
}
182
183impl Connector {
184 pub(crate) fn new<A: ToServerAddrs>(
185 addrs: A,
186 options: ConnectorOptions,
187 events_tx: tokio::sync::mpsc::Sender<Event>,
188 state_tx: tokio::sync::watch::Sender<State>,
189 max_payload: Arc<AtomicUsize>,
190 connect_stats: Arc<Statistics>,
191 ) -> Result<Connector, io::Error> {
192 let servers = addrs.to_server_addrs()?.map(Server::new).collect();
193
194 Ok(Connector {
195 attempts: 0,
196 servers,
197 options,
198 events_tx,
199 state_tx,
200 max_payload,
201 connect_stats,
202 last_info: ServerInfo::default(),
203 })
204 }
205
206 pub(crate) fn set_server_pool(&mut self, addrs: Vec<ServerAddr>) -> Result<(), String> {
212 if addrs.is_empty() {
213 return Err("server pool cannot be empty".to_string());
214 }
215
216 let has_ws = addrs.iter().any(|a| a.is_websocket());
218 let has_non_ws = addrs.iter().any(|a| !a.is_websocket());
219 if has_ws && has_non_ws {
220 return Err("cannot mix websocket and non-websocket URLs in server pool".to_string());
221 }
222
223 let new_servers = addrs
224 .into_iter()
225 .map(|addr| {
226 if let Some(existing) = self.servers.iter().find(|s| s.addr == addr) {
227 Server {
228 addr,
229 failed_attempts: existing.failed_attempts,
230 did_connect: existing.did_connect,
231 is_discovered: false,
232 last_error: existing.last_error.clone(),
233 }
234 } else {
235 Server::new(addr)
236 }
237 })
238 .collect();
239 self.servers = new_servers;
240 self.attempts = 0;
241 Ok(())
242 }
243
244 pub(crate) fn server_pool(&self) -> Vec<Server> {
246 self.servers.to_vec()
247 }
248
249 pub(crate) async fn connect(&mut self) -> Result<(ServerInfo, Connection), ConnectError> {
250 loop {
251 match self.try_connect().await {
252 Ok(inner) => {
253 return Ok(inner);
254 }
255 Err(error) => match error.kind() {
256 ConnectErrorKind::MaxReconnects => {
257 return Err(ConnectError::with_source(
258 crate::ConnectErrorKind::MaxReconnects,
259 error,
260 ))
261 }
262 other => {
263 self.events_tx
264 .try_send(Event::ClientError(ClientError::Other(other.to_string())))
265 .ok();
266 }
267 },
268 }
269 }
270 }
271
272 pub(crate) async fn try_connect(&mut self) -> Result<(ServerInfo, Connection), ConnectError> {
273 tracing::debug!(attempt = %self.attempts, "connecting to server");
274 let mut error = None;
275
276 if let Some(ref callback) = self.options.reconnect_to_server_callback {
278 let pool_snapshot = self.servers.to_vec();
279 let info_snapshot = self.last_info.clone();
280 let selection = callback.call((pool_snapshot, info_snapshot)).await;
281
282 if let Some(target) = selection {
283 if self.servers.iter().any(|s| s.addr == target.addr) {
285 self.attempts += 1;
286 if let Some(max_reconnects) = self.options.max_reconnects {
287 if self.attempts > max_reconnects {
288 self.events_tx
289 .try_send(Event::ClientError(ClientError::MaxReconnects))
290 .ok();
291 return Err(ConnectError::new(crate::ConnectErrorKind::MaxReconnects));
292 }
293 }
294
295 let delay = match target.delay {
298 Some(d) => d,
299 None => (self.options.reconnect_delay_callback)(self.attempts),
300 };
301 if !delay.is_zero() {
302 sleep(delay).await;
303 }
304
305 match self.try_connect_to_server(&target.addr).await {
306 Ok(result) => return Ok(result),
307 Err(inner) => match inner.kind() {
308 ConnectErrorKind::AuthorizationViolation
309 | ConnectErrorKind::Authentication => return Err(inner),
310 _ => {
311 tracing::debug!(
312 server = ?target.addr,
313 error = %inner,
314 "callback-selected server connection failed"
315 );
316 error.replace(inner);
317 }
318 },
319 }
320 return Err(error.unwrap());
321 } else {
322 tracing::warn!(
323 server = ?target.addr,
324 "reconnect callback returned server not in pool, using default selection"
325 );
326 self.events_tx
327 .try_send(Event::ClientError(ClientError::ServerNotInPool))
328 .ok();
329 }
331 }
332 }
334
335 let mut servers = self.servers.clone();
337 if !self.options.retain_servers_order {
338 servers.shuffle(&mut rng());
339 servers.sort_by_key(|a| a.failed_attempts);
341 }
342
343 for entry in servers {
344 self.attempts += 1;
345 if let Some(max_reconnects) = self.options.max_reconnects {
346 if self.attempts > max_reconnects {
347 tracing::error!(
348 attempts = %self.attempts,
349 max_reconnects = %max_reconnects,
350 "max reconnection attempts reached"
351 );
352 self.events_tx
353 .try_send(Event::ClientError(ClientError::MaxReconnects))
354 .ok();
355 return Err(ConnectError::new(crate::ConnectErrorKind::MaxReconnects));
356 }
357 }
358
359 let duration = (self.options.reconnect_delay_callback)(self.attempts);
360 tracing::debug!(
361 attempt = %self.attempts,
362 server = ?entry.addr,
363 delay_ms = %duration.as_millis(),
364 "attempting connection"
365 );
366
367 sleep(duration).await;
368
369 match self.try_connect_to_server(&entry.addr).await {
370 Ok(result) => return Ok(result),
371 Err(inner) => match inner.kind() {
372 ConnectErrorKind::AuthorizationViolation | ConnectErrorKind::Authentication => {
373 return Err(inner);
374 }
375 _ => {
376 tracing::debug!(
377 server = ?entry.addr,
378 error = %inner,
379 "connection attempt failed"
380 );
381 error.replace(inner);
382 }
383 },
384 }
385 }
386
387 Err(error.unwrap_or_else(|| {
388 ConnectError::with_source(
389 crate::ConnectErrorKind::Io,
390 "all connection attempts failed",
391 )
392 }))
393 }
394
395 async fn try_connect_to_server(
398 &mut self,
399 server_addr: &ServerAddr,
400 ) -> Result<(ServerInfo, Connection), ConnectError> {
401 let socket_addrs = server_addr
402 .socket_addrs()
403 .await
404 .map_err(|err| ConnectError::with_source(crate::ConnectErrorKind::Dns, err))?;
405
406 let mut last_err = None;
407 for socket_addr in socket_addrs {
408 match tokio::time::timeout(
409 self.options.connection_timeout,
410 self.try_connect_to(
411 &socket_addr,
412 server_addr.tls_required(),
413 server_addr.clone(),
414 ),
415 )
416 .await
417 {
418 Ok(Ok((server_info, connection))) => {
419 tracing::info!(
420 server = %server_info.port,
421 max_payload = %server_info.max_payload,
422 "connected successfully"
423 );
424 self.attempts = 0;
425 self.connect_stats.connects.add(1, Ordering::Relaxed);
426 self.events_tx.try_send(Event::Connected).ok();
427 self.state_tx.send(State::Connected).ok();
428 self.max_payload.store(
429 server_info.max_payload,
430 std::sync::atomic::Ordering::Relaxed,
431 );
432 self.last_info = server_info.clone();
433
434 if let Some(entry) = self.servers.iter_mut().find(|s| s.addr == *server_addr) {
436 entry.did_connect = true;
437 entry.failed_attempts = 0;
438 entry.last_error = None;
439 }
440
441 return Ok((server_info, connection));
442 }
443
444 Ok(Err(inner)) => {
445 if let Some(entry) = self.servers.iter_mut().find(|s| s.addr == *server_addr) {
447 entry.failed_attempts += 1;
448 entry.last_error = Some(inner.to_string());
449 }
450 last_err = Some(inner);
451 }
452
453 Err(_) => {
454 tracing::debug!(
455 server = ?server_addr,
456 "connection handshake timed out"
457 );
458 if let Some(entry) = self.servers.iter_mut().find(|s| s.addr == *server_addr) {
459 entry.failed_attempts += 1;
460 entry.last_error = Some("timed out".to_string());
461 }
462 last_err = Some(ConnectError::new(crate::ConnectErrorKind::TimedOut));
463 }
464 }
465 }
466
467 Err(last_err.unwrap_or_else(|| {
468 ConnectError::with_source(crate::ConnectErrorKind::Dns, "no addresses resolved")
469 }))
470 }
471
472 pub(crate) async fn try_connect_to(
473 &mut self,
474 socket_addr: &SocketAddr,
475 tls_required: bool,
476 server_addr: ServerAddr,
477 ) -> Result<(ServerInfo, Connection), ConnectError> {
478 tracing::debug!(
479 socket_addr = %socket_addr,
480 tls_required = %tls_required,
481 "establishing connection"
482 );
483 let mut connection = match server_addr.scheme() {
484 #[cfg(feature = "websockets")]
485 "ws" => {
486 let ws = tokio_websockets::client::Builder::new()
487 .uri(server_addr.as_url_str())
488 .map_err(|err| {
489 ConnectError::with_source(crate::ConnectErrorKind::ServerParse, err)
490 })?
491 .connect()
492 .await
493 .map_err(|err| ConnectError::with_source(crate::ConnectErrorKind::Io, err))?;
494
495 let con = WebSocketAdapter::new(ws.0);
496 Connection::new(Box::new(con), 0, self.connect_stats.clone())
497 }
498 #[cfg(feature = "websockets")]
499 "wss" => {
500 let tls_config =
501 Arc::new(tls::config_tls(&self.options).await.map_err(|err| {
502 ConnectError::with_source(crate::ConnectErrorKind::Tls, err)
503 })?);
504 let tls_connector = tokio_rustls::TlsConnector::from(tls_config);
505 let ws = tokio_websockets::client::Builder::new()
506 .connector(&tokio_websockets::Connector::Rustls(tls_connector))
507 .uri(server_addr.as_url_str())
508 .map_err(|err| {
509 ConnectError::with_source(crate::ConnectErrorKind::ServerParse, err)
510 })?
511 .connect()
512 .await
513 .map_err(|err| ConnectError::with_source(crate::ConnectErrorKind::Io, err))?;
514 let con = WebSocketAdapter::new(ws.0);
515 Connection::new(Box::new(con), 0, self.connect_stats.clone())
516 }
517 _ => {
518 let tcp_stream = if let Some(local_addr) = self.options.local_address {
519 let socket = if local_addr.is_ipv4() {
520 TcpSocket::new_v4()?
521 } else {
522 TcpSocket::new_v6()?
523 };
524 socket.bind(local_addr).map_err(|err| {
525 ConnectError::with_source(crate::ConnectErrorKind::Io, err)
526 })?;
527 socket.connect(*socket_addr).await?
528 } else {
529 TcpStream::connect(socket_addr).await?
530 };
531 tcp_stream.set_nodelay(true)?;
532
533 Connection::new(
534 Box::new(tcp_stream),
535 self.options.read_buffer_capacity.into(),
536 self.connect_stats.clone(),
537 )
538 }
539 };
540
541 let tls_connection = |connection: Connection| async {
542 tracing::debug!("upgrading connection to TLS");
543 let tls_config = Arc::new(
544 tls::config_tls(&self.options)
545 .await
546 .map_err(|err| ConnectError::with_source(crate::ConnectErrorKind::Tls, err))?,
547 );
548 let tls_connector = tokio_rustls::TlsConnector::from(tls_config);
549
550 let domain = crate::rustls::pki_types::ServerName::try_from(server_addr.host())
551 .map_err(|err| ConnectError::with_source(crate::ConnectErrorKind::Tls, err))?;
552
553 let tls_stream = tls_connector
554 .connect(domain.to_owned(), connection.stream)
555 .await?;
556
557 Ok::<Connection, ConnectError>(Connection::new(
558 Box::new(tls_stream),
559 0,
560 self.connect_stats.clone(),
561 ))
562 };
563
564 if self.options.tls_first && !server_addr.is_websocket() {
568 connection = tls_connection(connection).await?;
569 }
570
571 let op = connection.read_op().await?;
572 let info = match op {
573 Some(ServerOp::Info(info)) => {
574 tracing::debug!(
575 server_id = %info.server_id,
576 version = %info.version,
577 "received server info"
578 );
579 info
580 }
581 Some(op) => {
582 tracing::error!(received_op = ?op, "expected INFO, got different operation");
583 return Err(ConnectError::with_source(
584 crate::ConnectErrorKind::Io,
585 format!("expected INFO, got {op:?}"),
586 ));
587 }
588 None => {
589 tracing::error!("expected INFO, got nothing");
590 return Err(ConnectError::with_source(
591 crate::ConnectErrorKind::Io,
592 "expected INFO, got nothing",
593 ));
594 }
595 };
596
597 if !self.options.tls_first
599 && !server_addr.is_websocket()
600 && (self.options.tls_required || info.tls_required || tls_required)
601 {
602 connection = tls_connection(connection).await?;
603 };
604
605 if !self.options.ignore_discovered_servers {
607 for url in &info.connect_urls {
608 let discovered_addr = url.parse::<ServerAddr>().map_err(|err| {
609 ConnectError::with_source(crate::ConnectErrorKind::ServerParse, err)
610 })?;
611 if !self.servers.iter().any(|s| s.addr == discovered_addr) {
612 tracing::debug!(
613 discovered_url = %url,
614 "adding discovered server"
615 );
616 self.servers.push(Server::new_discovered(discovered_addr));
617 }
618 }
619 }
620
621 let tls_required = self.options.tls_required || server_addr.tls_required();
623 let mut connect_info = ConnectInfo {
624 tls_required,
625 name: self.options.name.clone(),
626 pedantic: false,
627 verbose: false,
628 lang: LANG.to_string(),
629 version: VERSION.to_string(),
630 protocol: Protocol::Dynamic,
631 user: self.options.auth.username.to_owned(),
632 pass: self.options.auth.password.to_owned(),
633 auth_token: self.options.auth.token.to_owned(),
634 user_jwt: None,
635 nkey: None,
636 signature: None,
637 echo: !self.options.no_echo,
638 headers: true,
639 no_responders: true,
640 };
641
642 #[cfg(feature = "nkeys")]
643 if let Some(nkey) = self.options.auth.nkey.as_ref() {
644 match nkeys::KeyPair::from_seed(nkey.as_str()) {
645 Ok(key_pair) => {
646 let nonce = info.nonce.clone();
647 match key_pair.sign(nonce.as_bytes()) {
648 Ok(signed) => {
649 connect_info.nkey = Some(key_pair.public_key());
650 connect_info.signature = Some(URL_SAFE_NO_PAD.encode(signed));
651 }
652 Err(_) => {
653 tracing::error!("failed to sign nonce with nkey");
654 return Err(ConnectError::new(crate::ConnectErrorKind::Authentication));
655 }
656 };
657 }
658 Err(_) => {
659 tracing::error!("failed to create key pair from nkey seed");
660 return Err(ConnectError::new(crate::ConnectErrorKind::Authentication));
661 }
662 }
663 }
664
665 #[cfg(feature = "nkeys")]
666 if let Some(jwt) = self.options.auth.jwt.as_ref() {
667 if let Some(sign_fn) = self.options.auth.signature_callback.as_ref() {
668 match sign_fn.call(info.nonce.clone()).await {
669 Ok(sig) => {
670 connect_info.user_jwt = Some(jwt.clone());
671 connect_info.signature = Some(sig);
672 }
673 Err(_) => {
674 tracing::error!("failed to sign nonce with JWT callback");
675 return Err(ConnectError::new(crate::ConnectErrorKind::Authentication));
676 }
677 }
678 }
679 }
680
681 if let Some(callback) = self.options.auth_callback.as_ref() {
682 let auth: crate::Auth = callback
683 .call(info.nonce.as_bytes().to_vec())
684 .await
685 .map_err(|err| {
686 tracing::error!(error = %err, "auth callback failed");
687 ConnectError::with_source(crate::ConnectErrorKind::Authentication, err)
688 })?;
689 connect_info.user = auth.username;
690 connect_info.pass = auth.password;
691 connect_info.user_jwt = auth.jwt;
692 #[cfg(feature = "nkeys")]
693 {
694 connect_info.signature = auth
695 .signature
696 .map(|signature| URL_SAFE_NO_PAD.encode(signature));
697 }
698 #[cfg(not(feature = "nkeys"))]
699 {
700 if auth.signature.is_some() {
701 tracing::error!("signature authentication requires 'nkeys' feature");
702 return Err(ConnectError::new(crate::ConnectErrorKind::Authentication));
703 }
704 connect_info.signature = None;
705 }
706 connect_info.auth_token = auth.token;
707 connect_info.nkey = auth.nkey;
708 }
709
710 connection
712 .easy_write_and_flush([ClientOp::Connect(connect_info), ClientOp::Ping].iter())
713 .await?;
714
715 match connection.read_op().await? {
716 Some(ServerOp::Error(err)) => match err {
717 ServerError::AuthorizationViolation => {
718 tracing::error!(error = %err, "authorization violation");
719 Err(ConnectError::with_source(
720 crate::ConnectErrorKind::AuthorizationViolation,
721 err,
722 ))
723 }
724 err => {
725 tracing::error!(error = %err, "server error during connection");
726 Err(ConnectError::with_source(crate::ConnectErrorKind::Io, err))
727 }
728 },
729 Some(_) => Ok((*info, connection)),
730 None => {
731 tracing::error!("connection closed unexpectedly");
732 Err(ConnectError::with_source(
733 crate::ConnectErrorKind::Io,
734 "broken pipe",
735 ))
736 }
737 }
738 }
739}
740
#[cfg(test)]
mod tests {
    use super::*;

    /// The default backoff waits nothing for the first attempts, then doubles in milliseconds
    /// and is capped at 4 seconds.
    #[test]
    fn reconnect_delay_callback_duration() {
        let cases: [(usize, u128); 6] = [
            (0, 0),
            (1, 0),
            (4, 8),
            (12, 2048),
            (13, 4000),
            (50, 4000),
        ];

        for (attempts, expected_ms) in cases {
            let delay = reconnect_delay_callback_default(attempts);
            assert_eq!(delay.as_millis(), expected_ms);
        }
    }
}