1use std::cmp;
30use std::convert::TryFrom;
31use std::ops::Deref;
32use std::sync::Arc;
33
34use bitcoin::Network;
35use log::warn;
36use tokio::sync::RwLock;
37use tonic::service::interceptor::InterceptedService;
38
39use ark::ArkInfo;
40
41use crate::{mailbox, protos, ArkServiceClient, ConvertError, RequestExt};
42
43#[cfg(all(feature = "tonic-native", feature = "tonic-web"))]
44compile_error!("features `tonic-native` and `tonic-web` are mutually exclusive");
45
46#[cfg(not(any(feature = "tonic-native", feature = "tonic-web")))]
47compile_error!("either `tonic-native` or `tonic-web` feature must be enabled");
48
49#[cfg(all(feature = "tonic-web", feature = "socks5-proxy"))]
50compile_error!("`tonic-web` does not support the `socks5-proxy` feature");
51
52#[cfg(feature = "tonic-native")]
53mod transport {
54 use std::str::FromStr;
55 use std::time::Duration;
56
57 use http::Uri;
58 use log::info;
59 use tonic::transport::{Channel, Endpoint};
60
61 use super::CreateEndpointError;
62
63 pub type Transport = Channel;
64
65 pub async fn connect(address: &str) -> Result<Transport, CreateEndpointError> {
71 Ok(create_endpoint(address)?.connect().await?)
72 }
73
74 #[cfg(feature = "socks5-proxy")]
76 pub async fn connect_with_proxy(
77 address: &str,
78 proxy: &str,
79 ) -> Result<Transport, CreateEndpointError> {
80 use hyper_socks2::SocksConnector;
81 use hyper_util::client::legacy::connect::HttpConnector;
82
83 let endpoint = create_endpoint(address)?;
84 let proxy_uri = proxy.parse::<Uri>().map_err(CreateEndpointError::InvalidProxyUri)?;
85 let connector = {
86 let mut http = HttpConnector::new();
89 http.enforce_http(false);
90 SocksConnector {
91 proxy_addr: proxy_uri,
92 auth: None,
93 connector: http,
94 }
95 };
96 info!("Connecting to Ark server via SOCKS5 proxy {}...", proxy);
97 Ok(endpoint.connect_with_connector(connector).await?)
98 }
99
100 fn create_endpoint(address: &str) -> Result<Endpoint, CreateEndpointError> {
104 let uri = Uri::from_str(address)?;
105
106 let scheme = uri.scheme_str().unwrap_or("");
107 if scheme != "http" && scheme != "https" {
108 return Err(CreateEndpointError::InvalidScheme(scheme.to_string()));
109 }
110
111 #[warn(unused_mut)]
113 let mut endpoint = Channel::builder(uri.clone())
114 .keep_alive_timeout(Duration::from_secs(600))
115 .timeout(Duration::from_secs(600));
116
117 #[cfg(any(feature = "tls-native-roots", feature = "tls-webpki-roots"))]
118 if scheme == "https" {
119 use tonic::transport::ClientTlsConfig;
120
121 info!("Connecting to Ark server at {} using TLS...", address);
122 let uri_auth = uri.clone().into_parts().authority
123 .ok_or(CreateEndpointError::MissingAuthority)?;
124 let domain = uri_auth.host();
125
126 let tls_config = ClientTlsConfig::new()
127 .with_enabled_roots()
128 .domain_name(domain);
129 endpoint = endpoint.tls_config(tls_config).map_err(CreateEndpointError::Transport)?;
130 return Ok(endpoint);
131 }
132 #[cfg(not(any(feature = "tls-native-roots", feature = "tls-webpki-roots")))]
133 if scheme == "https" {
134 return Err(CreateEndpointError::InvalidScheme(
135 "Missing TLS roots, https is unsupported".to_owned(),
136 ));
137 }
138 info!("Connecting to Ark server at {} without TLS...", address);
139 Ok(endpoint)
140 }
141}
142
143#[cfg(feature = "tonic-web")]
144mod transport {
145 use super::CreateEndpointError;
146 use tonic_web_wasm_client::Client as WasmClient;
147
148 pub type Transport = WasmClient;
149
150 pub async fn connect(address: &str) -> Result<Transport, CreateEndpointError> {
151 Ok(tonic_web_wasm_client::Client::new(address.to_string()))
152 }
153}
154
155pub const MIN_PROTOCOL_VERSION: u64 = 1;
159
160pub const MAX_PROTOCOL_VERSION: u64 = 1;
164
165pub const ARK_INFO_TTL_SECS: u64 = 10 * 60;
169
170#[derive(Debug, thiserror::Error)]
171#[error("failed to create gRPC endpoint: {msg}")]
172pub enum CreateEndpointError {
173 #[error("failed to parse Ark server as a URI")]
174 InvalidUri(#[from] http::uri::InvalidUri),
175 #[error("Ark server scheme must be either http or https. Found: {0}")]
176 InvalidScheme(String),
177 #[error("Ark server URI is missing an authority part")]
178 MissingAuthority,
179 #[cfg(feature = "tonic-native")]
180 #[error(transparent)]
181 Transport(#[from] tonic::transport::Error),
182 #[cfg(feature = "socks5-proxy")]
183 #[error("invalid SOCKS5 proxy URI: {0:#}")]
184 InvalidProxyUri(http::uri::InvalidUri),
185}
186
187#[derive(Debug, thiserror::Error)]
188#[error("failed to connect to Ark server: {msg}")]
189pub enum ConnectError {
190 #[error(transparent)]
191 CreateEndpoint(#[from] CreateEndpointError),
192 #[error("handshake request failed: {0}")]
193 Handshake(tonic::Status),
194 #[error("version mismatch. Client max is: {client_max}, server min is: {server_min}")]
195 ProtocolVersionMismatchClientTooOld { client_max: u64, server_min: u64 },
196 #[error("version mismatch. Client min is: {client_min}, server max is: {server_max}")]
197 ProtocolVersionMismatchServerTooOld { client_min: u64, server_max: u64 },
198 #[error("error getting ark info: {0}")]
199 GetArkInfo(tonic::Status),
200 #[error("invalid ark info from ark server: {0}")]
201 InvalidArkInfo(#[from] ConvertError),
202 #[error("network mismatch. Expected: {expected}, Got: {got}")]
203 NetworkMismatch { expected: Network, got: Network },
204 #[error("tokio channel error: {0}")]
205 Tokio(#[from] tokio::sync::oneshot::error::RecvError),
206}
207
208#[derive(Clone)]
214pub struct ProtocolVersionInterceptor {
215 pver: u64,
216}
217
218impl tonic::service::Interceptor for ProtocolVersionInterceptor {
219 fn call(&mut self, mut req: tonic::Request<()>) -> Result<tonic::Request<()>, tonic::Status> {
220 req.set_pver(self.pver);
221 Ok(req)
222 }
223}
224
225pub struct ArkInfoHandle {
229 pub info: ArkInfo,
230 pub waiter: Option<tokio::sync::oneshot::Receiver<Result<ArkInfo, ConnectError>>>,
231}
232
233impl Deref for ArkInfoHandle {
234 type Target = ArkInfo;
235
236 fn deref(&self) -> &Self::Target {
237 &self.info
238 }
239}
240
241pub struct ServerInfo {
242 pub pver: u64,
246 pub info: ArkInfo,
248 pub refresh_at_secs: u64,
250}
251
252impl ServerInfo {
253 fn ttl() -> u64 {
255 ark::time::timestamp_secs() + ARK_INFO_TTL_SECS
256 }
257
258 pub fn new(pver: u64, info: ArkInfo) -> Self {
259 Self { pver, info, refresh_at_secs: Self::ttl() }
260 }
261
262 pub fn update(&mut self, info: ArkInfo) {
263 self.info = info;
264 self.refresh_at_secs = Self::ttl();
265 }
266
267 pub fn is_outdated(&self) -> bool {
269 ark::time::timestamp_secs() > self.refresh_at_secs
270 }
271}
272
273#[derive(Clone)]
280pub struct ServerConnection {
281 info: Arc<RwLock<ServerInfo>>,
282 pub client: ArkServiceClient<InterceptedService<transport::Transport, ProtocolVersionInterceptor>>,
284 pub mailbox_client: mailbox::MailboxServiceClient<InterceptedService<transport::Transport, ProtocolVersionInterceptor>>,
286}
287
288impl ServerConnection {
289 fn handshake_req() -> protos::HandshakeRequest {
290 protos::HandshakeRequest {
291 bark_version: Some(env!("CARGO_PKG_VERSION").into()),
292 }
293 }
294
295 pub async fn connect(
313 address: &str,
314 network: Network,
315 ) -> Result<ServerConnection, ConnectError> {
316 let transport = transport::connect(address).await?;
317 Self::connect_inner(transport, network).await
318 }
319
320 #[cfg(feature = "socks5-proxy")]
322 pub async fn connect_via_proxy(
323 address: &str,
324 network: Network,
325 proxy: &str,
326 ) -> Result<ServerConnection, ConnectError> {
327 let transport = transport::connect_with_proxy(address, proxy).await?;
328 Self::connect_inner(transport, network).await
329 }
330
331 async fn connect_inner(
332 transport: transport::Transport,
333 network: Network,
334 ) -> Result<ServerConnection, ConnectError> {
335 let mut handshake_client = ArkServiceClient::new(transport.clone());
336 let handshake = handshake_client.handshake(Self::handshake_req()).await
337 .map_err(ConnectError::Handshake)?.into_inner();
338
339 let pver = check_handshake(handshake)?;
340
341 let interceptor = ProtocolVersionInterceptor { pver };
342 let mut client = ArkServiceClient::with_interceptor(transport.clone(), interceptor.clone())
343 .max_decoding_message_size(64 * 1024 * 1024); let info = client.ark_info(network).await?;
346
347 let mailbox_client = mailbox::MailboxServiceClient::with_interceptor(transport, interceptor)
348 .max_decoding_message_size(64 * 1024 * 1024); let info = Arc::new(RwLock::new(ServerInfo::new(pver, info)));
351 Ok(ServerConnection {
352 info,
353 client,
354 mailbox_client,
355 })
356 }
357
358 pub async fn check_connection(&self) -> Result<(), ConnectError> {
360 let mut client = self.client.clone();
361 let handshake = client.handshake(Self::handshake_req()).await
362 .map_err(ConnectError::Handshake)?.into_inner();
363 check_handshake(handshake)?;
364 Ok(())
365 }
366
367 pub async fn ark_info(&self) -> Result<ArkInfo, ConnectError> {
375 let mut current = self.info.write().await;
376
377 let new_info = self.client.clone().ark_info(current.info.network).await?;
378 if current.is_outdated() {
379 current.update(new_info.clone());
380 return Ok(new_info);
381 }
382
383 Ok(current.info.clone())
384 }
385}
386trait ArkServiceClientExt {
387 async fn ark_info(&mut self, network: Network) -> Result<ArkInfo, ConnectError>;
388}
389
390impl ArkServiceClientExt for ArkServiceClient<InterceptedService<transport::Transport, ProtocolVersionInterceptor>> {
391 async fn ark_info(&mut self, network: Network) -> Result<ArkInfo, ConnectError> {
392 let res = self.get_ark_info(protos::Empty {}).await
393 .map_err(ConnectError::GetArkInfo)?;
394 let info = ArkInfo::try_from(res.into_inner())
395 .map_err(ConnectError::InvalidArkInfo)?;
396 if network != info.network {
397 return Err(ConnectError::NetworkMismatch { expected: network, got: info.network });
398 }
399
400 Ok(info)
401 }
402}
403
404fn check_handshake(handshake: protos::HandshakeResponse) -> Result<u64, ConnectError> {
405 if let Some(ref msg) = handshake.psa {
406 warn!("Message from Ark server: \"{}\"", msg);
407 }
408
409 if MAX_PROTOCOL_VERSION < handshake.min_protocol_version {
410 return Err(ConnectError::ProtocolVersionMismatchClientTooOld {
411 client_max: MAX_PROTOCOL_VERSION, server_min: handshake.min_protocol_version
412 });
413 }
414 if MIN_PROTOCOL_VERSION > handshake.max_protocol_version {
415 return Err(ConnectError::ProtocolVersionMismatchServerTooOld {
416 client_min: MIN_PROTOCOL_VERSION, server_max: handshake.max_protocol_version
417 });
418 }
419
420 let pver = cmp::min(MAX_PROTOCOL_VERSION, handshake.max_protocol_version);
421 assert!((MIN_PROTOCOL_VERSION..=MAX_PROTOCOL_VERSION).contains(&pver));
422 assert!((handshake.min_protocol_version..=handshake.max_protocol_version).contains(&pver));
423
424 Ok(pver)
425}