trz_gateway_client/
client.rs1use std::sync::Arc;
4use std::sync::Mutex;
5use std::sync::atomic::AtomicBool;
6use std::sync::atomic::Ordering::SeqCst;
7use std::time::Instant;
8
9use connect::ConnectError;
10use futures::FutureExt;
11use futures::future::Shared;
12use nameth::NamedEnumValues as _;
13use nameth::nameth;
14use tokio::sync::oneshot;
15use tracing::Instrument;
16use tracing::info;
17use tracing::info_span;
18use tracing::warn;
19use trz_gateway_common::declare_identifier;
20use trz_gateway_common::handle::ServerHandle;
21use trz_gateway_common::id::ClientId;
22use trz_gateway_common::id::ClientName;
23use trz_gateway_common::retry_strategy::RetryStrategy;
24use trz_gateway_common::security_configuration::certificate::CertificateConfig;
25use trz_gateway_common::security_configuration::certificate::tls_server::ToTlsServer as _;
26use trz_gateway_common::security_configuration::certificate::tls_server::ToTlsServerError;
27use trz_gateway_common::security_configuration::custom_server_certificate_verifier::ChainOnlyServerCertificateVerifier;
28use trz_gateway_common::security_configuration::trusted_store::TrustedStoreConfig;
29use trz_gateway_common::security_configuration::trusted_store::tls_client::ToTlsClient as _;
30use trz_gateway_common::security_configuration::trusted_store::tls_client::ToTlsClientError;
31use uuid::Uuid;
32
33use self::config::SniOverrideError;
34use self::config::url;
35use self::service::ClientService;
36use crate::tunnel_config::TunnelConfig;
37
38pub mod certificate;
39pub mod config;
40pub mod connect;
41mod connection;
42mod health;
43pub mod service;
44
/// Client side of the gateway tunnel.
///
/// Built by `Client::new` from a `TunnelConfig` and started with `Client::run`.
pub struct Client {
    /// Name of this client. It is also the last path segment of the tunnel URL
    /// (`/remote/tunnel/{client_name}`) and is recorded on the `Run` tracing span.
    pub client_name: ClientName,

    /// Tunnel URL, produced by `config::url` for the `/remote/tunnel/{client_name}` path.
    uri: String,

    /// Optional SNI override, copied from the config's `sni_override()`.
    sni_override: Option<String>,

    /// TLS connector built from the gateway PKI, using `ChainOnlyServerCertificateVerifier`.
    tls_client: tokio_tungstenite::Connector,

    // NOTE(review): presumably used to terminate TLS for traffic arriving over the
    // tunnel — confirm in the `connection` module.
    /// TLS acceptor built from the client certificate.
    tls_server: tokio_rustls::TlsAcceptor,

    /// Service object obtained from the config's `client_service()`.
    client_service: Arc<dyn ClientService>,

    /// Reconnect backoff policy. `run_impl` works on a clone of it and resets that
    /// clone after a long-lived connection.
    retry_strategy: RetryStrategy,

    /// Current auth code. The `Arc` comes from the config's `current_auth_code()`, so the
    /// value is shared with whoever else holds that handle.
    current_auth_code: Arc<Mutex<AuthCode>>,
}
81
// Declares the `AuthCode` identifier type held in `Client::current_auth_code`.
// NOTE(review): the generated shape (presumably a string newtype like `ClientId`) is
// defined by `declare_identifier!` in trz_gateway_common — confirm there.
declare_identifier!(AuthCode);
83
84impl Client {
85 pub fn new<C: TunnelConfig>(config: C) -> Result<Arc<Self>, NewClientError<C>> {
87 let tls_client = config
88 .gateway_pki()
89 .to_tls_client(ChainOnlyServerCertificateVerifier)?;
90 let tls_server = config.client_certificate().to_tls_server()?;
91 let client_name = config.client_name();
92 let tunnel_path = format!("/remote/tunnel/{client_name}");
93 Ok(Arc::new(Client {
94 client_name,
95 uri: url(&config, &tunnel_path)?.to_string(),
96 sni_override: config.sni_override().map(ToOwned::to_owned),
97 tls_client: tokio_tungstenite::Connector::Rustls(tls_client.into()),
98 tls_server: tokio_rustls::TlsAcceptor::from(tls_server),
99 client_service: Arc::new(config.client_service()),
100 retry_strategy: config.retry_strategy(),
101 current_auth_code: config.current_auth_code(),
102 }))
103 }
104
105 pub async fn run(self: &Arc<Self>) -> Result<ServerHandle<()>, ConnectError> {
107 let this = self.clone();
108 let client_name = &this.client_name;
109 let span = info_span!("Run", %client_name);
110 async move {
111 let client_id = ClientId::from(Uuid::new_v4().to_string());
112 info!(%client_id, "Allocated new client id");
113 let (shutdown_rx, terminated_tx, handle) = ServerHandle::new("Client");
114 let (serving_tx, serving_rx) = oneshot::channel();
115 let task = run_impl(this, client_id, serving_tx, shutdown_rx, terminated_tx);
116 tokio::spawn(task.in_current_span());
117 let _ = serving_rx.await;
118 Ok(handle)
119 }
120 .instrument(span)
121 .await
122 }
123}
124
125async fn run_impl(
126 this: Arc<Client>,
127 client_id: ClientId,
128
129 serving_tx: oneshot::Sender<()>,
131
132 shutdown_rx: impl Future<Output = ()> + Send + 'static,
134
135 terminated_tx: oneshot::Sender<()>,
137) {
138 scopeguard::defer! { let _ = terminated_tx.send(()); };
139 let retry_strategy0 = this.retry_strategy.clone();
140 let mut retry_strategy = retry_strategy0.clone();
141 let shutdown_rx = shutdown_rx.shared();
142
143 let is_shutdown = is_shutdown(shutdown_rx.clone());
144
145 let mut serving_tx: Option<oneshot::Sender<()>> = Some(serving_tx);
146 loop {
147 let start = Instant::now();
148 let result = this
149 .connect(
150 client_id.clone(),
151 shutdown_rx.clone(),
152 retry_strategy.peek() / 2,
153 &mut serving_tx,
154 )
155 .await;
156 if is_shutdown.load(SeqCst) {
157 return;
158 }
159 let uptime = Instant::now() - start;
160 if uptime < retry_strategy0.max_delay() {
161 match result {
162 Ok(()) => {
163 info! { "Connection closed, retrying in {}...", humantime::format_duration(retry_strategy.peek()) }
164 }
165 Err(error) => {
166 warn! { %error, "Connection failed, retrying in {}...", humantime::format_duration(retry_strategy.peek()) }
167 }
168 }
169 if let futures::future::Either::Right(((), _retry_strategy_wait)) =
170 futures::future::select(Box::pin(retry_strategy.wait()), shutdown_rx.clone()).await
171 {
172 return;
173 }
174 } else {
175 retry_strategy = retry_strategy0.clone();
176 }
177 }
178}
179
180fn is_shutdown(shutdown_rx: Shared<impl Future<Output = ()> + Send + 'static>) -> Arc<AtomicBool> {
181 let is_shutdown = Arc::new(AtomicBool::new(false));
182 tokio::spawn({
183 let is_shutdown = is_shutdown.clone();
184 async move {
185 let _ = shutdown_rx.await;
186 is_shutdown.store(true, SeqCst);
187 }
188 });
189 return is_shutdown;
190}
191
/// Errors returned by `Client::new`.
///
/// The enum is generic over the tunnel configuration, so the TLS variants can carry the
/// configuration's own certificate and trusted-store error types.
///
/// Each variant displays as `[{n}] {inner}`, where `n` is `self.name()` from nameth
/// (presumably the variant name — confirm in nameth).
#[nameth]
#[derive(thiserror::Error, Debug)]
pub enum NewClientError<C: TunnelConfig> {
    /// Building the tunnel URL (`config::url`) failed.
    #[error("[{n}] {0}", n = self.name())]
    SniOverride(#[from] SniOverrideError),

    /// Building the gateway TLS client configuration from the gateway PKI failed.
    #[error("[{n}] {0}", n = self.name())]
    ToTlsClient(#[from] ToTlsClientError<<C::GatewayPki as TrustedStoreConfig>::Error>),

    /// Building the TLS server configuration from the client certificate failed.
    #[error("[{n}] {0}", n = self.name())]
    ToTlsServer(#[from] ToTlsServerError<<C::ClientCertificate as CertificateConfig>::Error>),
}