1use std::cmp;
30use std::convert::TryFrom;
31use std::ops::Deref;
32use std::sync::Arc;
33
34use bitcoin::Network;
35use log::warn;
36use tokio::sync::RwLock;
37use tonic::metadata::AsciiMetadataValue;
38use tonic::metadata::errors::InvalidMetadataValue;
39use tonic::service::interceptor::{InterceptedService, Interceptor};
40
41use ark::ArkInfo;
42
43use crate::{mailbox, protos, ArkServiceClient, ConvertError, RequestExt};
44
45
46#[cfg(all(feature = "tonic-native", feature = "tonic-web"))]
47compile_error!("features `tonic-native` and `tonic-web` are mutually exclusive");
48
49#[cfg(not(any(feature = "tonic-native", feature = "tonic-web")))]
50compile_error!("either `tonic-native` or `tonic-web` feature must be enabled");
51
52#[cfg(all(feature = "tonic-web", feature = "socks5-proxy"))]
53compile_error!("`tonic-web` does not support the `socks5-proxy` feature");
54
55
56pub const ACCESS_TOKEN_HEADER: &str = "ark-access-token";
58
59
/// Native (non-wasm) gRPC transport built on `tonic::transport::Channel`.
#[cfg(feature = "tonic-native")]
mod transport {
    use std::str::FromStr;
    use std::time::Duration;

    use http::Uri;
    use log::info;
    use tonic::transport::{Channel, Endpoint};

    use super::CreateEndpointError;

    /// Transport type backing the gRPC clients on native targets.
    pub type Transport = Channel;

    /// Used both as the keep-alive timeout and the request timeout of every endpoint.
    const ENDPOINT_TIMEOUT: Duration = Duration::from_secs(600);

    /// Connects directly to the Ark server at `address`.
    ///
    /// # Errors
    /// Fails if `address` is not a valid `http`/`https` URI, if `https` is
    /// requested but cannot be configured, or if the connection attempt fails.
    pub async fn connect(address: &str) -> Result<Transport, CreateEndpointError> {
        Ok(create_endpoint(address)?.connect().await?)
    }

    /// Connects to the Ark server at `address` through the SOCKS5 proxy at `proxy`.
    ///
    /// # Errors
    /// Same as [`connect`], plus [`CreateEndpointError::InvalidProxyUri`] when
    /// `proxy` cannot be parsed as a URI.
    #[cfg(feature = "socks5-proxy")]
    pub async fn connect_with_proxy(
        address: &str,
        proxy: &str,
    ) -> Result<Transport, CreateEndpointError> {
        use hyper_socks2::SocksConnector;
        use hyper_util::client::legacy::connect::HttpConnector;

        let endpoint = create_endpoint(address)?;
        let proxy_uri = proxy.parse::<Uri>().map_err(CreateEndpointError::InvalidProxyUri)?;
        let connector = {
            let mut http = HttpConnector::new();
            // By default the connector rejects non-`http` URIs; allow `https`
            // targets so they can be reached through the proxy as well.
            http.enforce_http(false);
            SocksConnector {
                proxy_addr: proxy_uri,
                auth: None,
                connector: http,
            }
        };
        info!("Connecting to Ark server via SOCKS5 proxy {}...", proxy);
        Ok(endpoint.connect_with_connector(connector).await?)
    }

    /// Parses and validates `address` and builds an endpoint for it.
    ///
    /// Only the `http` and `https` schemes are accepted. For `https` the TLS
    /// configuration is delegated to [`configure_tls`], whose implementation
    /// depends on which TLS-roots feature is enabled.
    fn create_endpoint(address: &str) -> Result<Endpoint, CreateEndpointError> {
        let uri = Uri::from_str(address)?;

        let scheme = uri.scheme_str().unwrap_or("");
        if scheme != "http" && scheme != "https" {
            return Err(CreateEndpointError::InvalidScheme(scheme.to_string()));
        }

        let endpoint = Channel::builder(uri.clone())
            .keep_alive_timeout(ENDPOINT_TIMEOUT)
            .timeout(ENDPOINT_TIMEOUT);

        if scheme == "https" {
            return configure_tls(endpoint, &uri, address);
        }

        info!("Connecting to Ark server at {} without TLS...", address);
        Ok(endpoint)
    }

    /// Enables TLS on `endpoint`, verifying the server against the host part of `uri`.
    ///
    /// # Errors
    /// [`CreateEndpointError::MissingAuthority`] if `uri` has no authority, or
    /// [`CreateEndpointError::Transport`] if tonic rejects the TLS configuration.
    #[cfg(any(feature = "tls-native-roots", feature = "tls-webpki-roots"))]
    fn configure_tls(
        endpoint: Endpoint,
        uri: &Uri,
        address: &str,
    ) -> Result<Endpoint, CreateEndpointError> {
        use tonic::transport::ClientTlsConfig;

        info!("Connecting to Ark server at {} using TLS...", address);
        let domain = uri
            .authority()
            .ok_or(CreateEndpointError::MissingAuthority)?
            .host();

        let tls_config = ClientTlsConfig::new()
            .with_enabled_roots()
            .domain_name(domain);
        endpoint.tls_config(tls_config).map_err(CreateEndpointError::Transport)
    }

    /// Without a TLS-roots feature there is no way to verify servers, so `https`
    /// addresses are rejected.
    #[cfg(not(any(feature = "tls-native-roots", feature = "tls-webpki-roots")))]
    fn configure_tls(
        _endpoint: Endpoint,
        _uri: &Uri,
        _address: &str,
    ) -> Result<Endpoint, CreateEndpointError> {
        Err(CreateEndpointError::InvalidScheme(
            "Missing TLS roots, https is unsupported".to_owned(),
        ))
    }
}
150
/// Browser (wasm) gRPC-web transport built on `tonic_web_wasm_client`.
#[cfg(feature = "tonic-web")]
mod transport {
    use super::CreateEndpointError;
    use tonic_web_wasm_client::Client as WasmClient;

    /// Transport type backing the gRPC clients on wasm targets.
    pub type Transport = WasmClient;

    /// Creates a gRPC-web client targeting `address`.
    ///
    /// This function itself never returns an error; it only constructs the
    /// client. The `Result` keeps the signature aligned with the native transport.
    pub async fn connect(address: &str) -> Result<Transport, CreateEndpointError> {
        let client = WasmClient::new(address.to_owned());
        Ok(client)
    }
}
162
/// Lowest Ark protocol version this client is able to speak.
pub const MIN_PROTOCOL_VERSION: u64 = 1;

/// Highest Ark protocol version this client is able to speak.
pub const MAX_PROTOCOL_VERSION: u64 = 1;

/// Number of seconds cached server info stays fresh before it is re-fetched
/// (ten minutes).
pub const ARK_INFO_TTL_SECS: u64 = 600;
177
178#[derive(Debug, thiserror::Error)]
179#[error("failed to create gRPC endpoint: {msg}")]
180pub enum CreateEndpointError {
181 #[error("failed to parse Ark server as a URI")]
182 InvalidUri(#[from] http::uri::InvalidUri),
183 #[error("Ark server scheme must be either http or https. Found: {0}")]
184 InvalidScheme(String),
185 #[error("Ark server URI is missing an authority part")]
186 MissingAuthority,
187 #[cfg(feature = "tonic-native")]
188 #[error(transparent)]
189 Transport(#[from] tonic::transport::Error),
190 #[cfg(feature = "socks5-proxy")]
191 #[error("invalid SOCKS5 proxy URI: {0:#}")]
192 InvalidProxyUri(http::uri::InvalidUri),
193}
194
195#[derive(Debug, thiserror::Error)]
196#[error("failed to connect to Ark server: {msg}")]
197pub enum ConnectError {
198 #[error("missing info '{0}' to connect")]
199 MissingInfo(&'static str),
200 #[error("invalid access token: {0}")]
201 InvalidAccessToken(#[from] #[source] InvalidMetadataValue),
202 #[error(transparent)]
203 CreateEndpoint(#[from] CreateEndpointError),
204 #[error("handshake request failed: {0}")]
205 Handshake(tonic::Status),
206 #[error("version mismatch. Client max is: {client_max}, server min is: {server_min}")]
207 ProtocolVersionMismatchClientTooOld { client_max: u64, server_min: u64 },
208 #[error("version mismatch. Client min is: {client_min}, server max is: {server_max}")]
209 ProtocolVersionMismatchServerTooOld { client_min: u64, server_max: u64 },
210 #[error("error getting ark info: {0}")]
211 GetArkInfo(tonic::Status),
212 #[error("invalid ark info from ark server: {0}")]
213 InvalidArkInfo(#[from] ConvertError),
214 #[error("network mismatch. Expected: {expected}, Got: {got}")]
215 NetworkMismatch { expected: Network, got: Network },
216 #[error("tokio channel error: {0}")]
217 Tokio(#[from] tokio::sync::oneshot::error::RecvError),
218}
219
220#[derive(Clone)]
226#[deprecated(since = "0.1.3", note = "should not be used directly")]
227pub struct ProtocolVersionInterceptor {
228 pver: u64,
229}
230
231#[allow(deprecated)]
232impl tonic::service::Interceptor for ProtocolVersionInterceptor {
233 fn call(&mut self, mut req: tonic::Request<()>) -> Result<tonic::Request<()>, tonic::Status> {
234 #[allow(deprecated)]
235 req.set_pver(self.pver);
236 Ok(req)
237 }
238}
239
240#[derive(Clone)]
245pub struct ArkServiceInterceptor {
246 pver: Option<u64>,
247 access_token: Option<AsciiMetadataValue>,
248}
249
250impl tonic::service::Interceptor for ArkServiceInterceptor {
251 fn call(&mut self, mut req: tonic::Request<()>) -> Result<tonic::Request<()>, tonic::Status> {
252 if let Some(pver) = self.pver {
253 req.set_pver(pver);
254 }
255 if let Some(ref access_token) = self.access_token {
256 req.metadata_mut().insert(ACCESS_TOKEN_HEADER, access_token.clone());
257 }
258 Ok(req)
259 }
260}
261
262pub struct ArkInfoHandle {
266 pub info: ArkInfo,
267 pub waiter: Option<tokio::sync::oneshot::Receiver<Result<ArkInfo, ConnectError>>>,
268}
269
270impl Deref for ArkInfoHandle {
271 type Target = ArkInfo;
272
273 fn deref(&self) -> &Self::Target {
274 &self.info
275 }
276}
277
278pub struct ServerInfo {
279 pub pver: u64,
283 pub info: ArkInfo,
285 pub refresh_at_secs: u64,
287}
288
289impl ServerInfo {
290 fn ttl() -> u64 {
292 ark::time::timestamp_secs() + ARK_INFO_TTL_SECS
293 }
294
295 pub fn new(pver: u64, info: ArkInfo) -> Self {
296 Self { pver, info, refresh_at_secs: Self::ttl() }
297 }
298
299 pub fn update(&mut self, info: ArkInfo) {
300 self.info = info;
301 self.refresh_at_secs = Self::ttl();
302 }
303
304 pub fn is_outdated(&self) -> bool {
306 ark::time::timestamp_secs() > self.refresh_at_secs
307 }
308}
309
310#[derive(Default)]
311pub struct ServerConnectionBuilder {
312 address: Option<String>,
313 network: Option<Network>,
314 #[cfg(feature = "socks5-proxy")]
315 proxy: Option<String>,
316 access_token: Option<String>,
317}
318
319impl ServerConnectionBuilder {
320 pub fn address(mut self, address: impl Into<String>) -> Self {
321 self.address = Some(address.into());
322 self
323 }
324
325 pub fn network(mut self, network: Network) -> Self {
326 self.network = Some(network);
327 self
328 }
329
330 #[cfg(feature = "socks5-proxy")]
331 pub fn proxy(mut self, proxy: impl Into<String>) -> Self {
332 self.proxy = Some(proxy.into());
333 self
334 }
335
336 pub fn access_token(mut self, access_token: impl Into<String>) -> Self {
337 self.access_token = Some(access_token.into());
338 self
339 }
340
341 pub async fn connect(self) -> Result<ServerConnection, ConnectError> {
342 ServerConnection::inner_connect(self).await
343 }
344}
345
346#[derive(Clone)]
353pub struct ServerConnection {
354 info: Arc<RwLock<ServerInfo>>,
355 pub client: ArkServiceClient<InterceptedService<transport::Transport, ArkServiceInterceptor>>,
357 pub mailbox_client: mailbox::MailboxServiceClient<InterceptedService<transport::Transport, ArkServiceInterceptor>>,
359}
360
361impl ServerConnection {
362 fn handshake_req() -> protos::HandshakeRequest {
363 protos::HandshakeRequest {
364 bark_version: Some(env!("CARGO_PKG_VERSION").into()),
365 }
366 }
367
368 pub fn builder() -> ServerConnectionBuilder {
386 ServerConnectionBuilder::default()
387 }
388
389 async fn inner_connect(builder: ServerConnectionBuilder) -> Result<ServerConnection, ConnectError> {
391 let address = builder.address.ok_or(ConnectError::MissingInfo("address"))?;
392 let network = builder.network.ok_or(ConnectError::MissingInfo("network"))?;
393
394 #[cfg(feature = "socks5-proxy")]
395 let transport = if let Some(proxy) = builder.proxy {
396 transport::connect_with_proxy(&address, &proxy).await?
397 } else {
398 transport::connect(&address).await?
399 };
400 #[cfg(not(feature = "socks5-proxy"))]
401 let transport = transport::connect(&address).await?;
402
403 let mut interceptor = ArkServiceInterceptor {
404 pver: None,
405 access_token: builder.access_token.map(|t| t.try_into()).transpose()?,
406 };
407
408 let mut handshake_client = ArkServiceClient::with_interceptor(transport.clone(), interceptor.clone());
409 let handshake = handshake_client.handshake(Self::handshake_req()).await
410 .map_err(ConnectError::Handshake)?.into_inner();
411
412 let pver = check_handshake(handshake)?;
413 interceptor.pver = Some(pver);
414
415 let mut client = ArkServiceClient::with_interceptor(transport.clone(), interceptor.clone())
416 .max_decoding_message_size(64 * 1024 * 1024); let info = client.ark_info(network).await?;
419
420 let mailbox_client = mailbox::MailboxServiceClient::with_interceptor(transport, interceptor)
421 .max_decoding_message_size(64 * 1024 * 1024); let info = Arc::new(RwLock::new(ServerInfo::new(pver, info)));
424 Ok(ServerConnection {
425 info,
426 client,
427 mailbox_client,
428 })
429 }
430
431 #[deprecated(since = "0.1.3", note = "use builder() instead")]
432 pub async fn connect(
433 address: &str,
434 network: Network,
435 ) -> Result<ServerConnection, ConnectError> {
436 Self::builder().address(address).network(network).connect().await
437 }
438
439 #[cfg(feature = "socks5-proxy")]
440 #[deprecated(since = "0.1.3", note = "use builder() instead")]
441 pub async fn connect_via_proxy(
442 address: &str,
443 network: Network,
444 proxy: &str,
445 ) -> Result<ServerConnection, ConnectError> {
446 Self::builder().address(address).network(network).proxy(proxy).connect().await
447 }
448
449 pub async fn check_connection(&self) -> Result<(), ConnectError> {
451 let mut client = self.client.clone();
452 let handshake = client.handshake(Self::handshake_req()).await
453 .map_err(ConnectError::Handshake)?.into_inner();
454 check_handshake(handshake)?;
455 Ok(())
456 }
457
458 pub async fn ark_info(&self) -> Result<ArkInfo, ConnectError> {
466 let mut current = self.info.write().await;
467
468 let new_info = self.client.clone().ark_info(current.info.network).await?;
469 if current.is_outdated() {
470 current.update(new_info.clone());
471 return Ok(new_info);
472 }
473
474 Ok(current.info.clone())
475 }
476}
477trait ArkServiceClientExt {
478 async fn ark_info(&mut self, network: Network) -> Result<ArkInfo, ConnectError>;
479}
480
481impl<I: Interceptor> ArkServiceClientExt for ArkServiceClient<InterceptedService<transport::Transport, I>> {
482 async fn ark_info(&mut self, network: Network) -> Result<ArkInfo, ConnectError> {
483 let res = self.get_ark_info(protos::Empty {}).await
484 .map_err(ConnectError::GetArkInfo)?;
485 let info = ArkInfo::try_from(res.into_inner())
486 .map_err(ConnectError::InvalidArkInfo)?;
487 if network != info.network {
488 return Err(ConnectError::NetworkMismatch { expected: network, got: info.network });
489 }
490
491 Ok(info)
492 }
493}
494
495fn check_handshake(handshake: protos::HandshakeResponse) -> Result<u64, ConnectError> {
496 if let Some(ref msg) = handshake.psa {
497 warn!("Message from Ark server: \"{}\"", msg);
498 }
499
500 if MAX_PROTOCOL_VERSION < handshake.min_protocol_version {
501 return Err(ConnectError::ProtocolVersionMismatchClientTooOld {
502 client_max: MAX_PROTOCOL_VERSION, server_min: handshake.min_protocol_version
503 });
504 }
505 if MIN_PROTOCOL_VERSION > handshake.max_protocol_version {
506 return Err(ConnectError::ProtocolVersionMismatchServerTooOld {
507 client_min: MIN_PROTOCOL_VERSION, server_max: handshake.max_protocol_version
508 });
509 }
510
511 let pver = cmp::min(MAX_PROTOCOL_VERSION, handshake.max_protocol_version);
512 assert!((MIN_PROTOCOL_VERSION..=MAX_PROTOCOL_VERSION).contains(&pver));
513 assert!((handshake.min_protocol_version..=handshake.max_protocol_version).contains(&pver));
514
515 Ok(pver)
516}