1use std::cmp;
30use std::convert::TryFrom;
31use std::ops::Deref;
32use std::sync::Arc;
33
34use bitcoin::Network;
35use log::warn;
36use tokio::sync::RwLock;
37use tonic::metadata::AsciiMetadataValue;
38use tonic::metadata::errors::InvalidMetadataValue;
39use tonic::service::interceptor::{InterceptedService, Interceptor};
40
41use ark::ArkInfo;
42
43use crate::{mailbox, protos, ArkServiceClient, ConvertError, RequestExt};
44
45
46#[cfg(all(feature = "tonic-native", feature = "tonic-web"))]
47compile_error!("features `tonic-native` and `tonic-web` are mutually exclusive");
48
49#[cfg(all(feature = "socks5-proxy", not(feature = "tonic-native")))]
50compile_error!("the `socks5-proxy` feature is only usable in conjunction with `tonic-native`");
51
52
53pub const ACCESS_TOKEN_HEADER: &str = "ark-access-token";
55pub const NO_TRANSPORT_BACKEND_MESSAGE: &str =
57 "no Ark RPC transport backend compiled in this build; enable `bark-server-rpc/tonic-native` or `bark-server-rpc/tonic-web`";
58
59
60#[cfg(feature = "tonic-native")]
61mod transport {
62 use std::str::FromStr;
63 use std::time::Duration;
64
65 use http::Uri;
66 use log::info;
67 use tonic::transport::{Channel, Endpoint};
68
69 use super::CreateEndpointError;
70
71 pub type Transport = Channel;
72
73 pub async fn connect(address: &str) -> Result<Transport, CreateEndpointError> {
79 Ok(create_endpoint(address)?.connect().await?)
80 }
81
82 #[cfg(feature = "socks5-proxy")]
84 pub async fn connect_with_proxy(
85 address: &str,
86 proxy: &str,
87 ) -> Result<Transport, CreateEndpointError> {
88 use hyper_socks2::SocksConnector;
89 use hyper_util::client::legacy::connect::HttpConnector;
90
91 let endpoint = create_endpoint(address)?;
92 let proxy_uri = proxy.parse::<Uri>().map_err(CreateEndpointError::InvalidProxyUri)?;
93 let connector = {
94 let mut http = HttpConnector::new();
97 http.enforce_http(false);
98 SocksConnector {
99 proxy_addr: proxy_uri,
100 auth: None,
101 connector: http,
102 }
103 };
104 info!("Connecting to Ark server via SOCKS5 proxy {}...", proxy);
105 Ok(endpoint.connect_with_connector(connector).await?)
106 }
107
108 fn create_endpoint(address: &str) -> Result<Endpoint, CreateEndpointError> {
112 let uri = Uri::from_str(address)?;
113
114 let scheme = uri.scheme_str().unwrap_or("");
115 if scheme != "http" && scheme != "https" {
116 return Err(CreateEndpointError::InvalidScheme(scheme.to_string()));
117 }
118
119 #[cfg_attr(not(any(feature = "tls-native-roots", feature = "tls-webpki-roots")), allow(unused_mut))]
120 let mut endpoint = Channel::builder(uri.clone())
121 .http2_keep_alive_interval(Duration::from_secs(20))
122 .keep_alive_timeout(Duration::from_secs(600))
123 .keep_alive_while_idle(true)
124 .timeout(Duration::from_secs(600));
125
126 #[cfg(any(feature = "tls-native-roots", feature = "tls-webpki-roots"))]
127 if scheme == "https" {
128 use tonic::transport::ClientTlsConfig;
129
130 info!("Connecting to Ark server at {} using TLS...", address);
131 let uri_auth = uri.clone().into_parts().authority
132 .ok_or(CreateEndpointError::MissingAuthority)?;
133 let domain = uri_auth.host();
134
135 let tls_config = ClientTlsConfig::new()
136 .with_enabled_roots()
137 .domain_name(domain);
138 endpoint = endpoint.tls_config(tls_config).map_err(CreateEndpointError::Transport)?;
139 return Ok(endpoint);
140 }
141 #[cfg(not(any(feature = "tls-native-roots", feature = "tls-webpki-roots")))]
142 if scheme == "https" {
143 return Err(CreateEndpointError::InvalidScheme(
144 "Missing TLS roots, https is unsupported".to_owned(),
145 ));
146 }
147 info!("Connecting to Ark server at {} without TLS...", address);
148 Ok(endpoint)
149 }
150}
151
152#[cfg(feature = "tonic-web")]
153mod transport {
154 use super::CreateEndpointError;
155 use tonic_web_wasm_client::Client as WasmClient;
156
157 pub type Transport = WasmClient;
158
159 pub async fn connect(address: &str) -> Result<Transport, CreateEndpointError> {
160 Ok(tonic_web_wasm_client::Client::new(address.to_string()))
161 }
162}
163
164#[cfg(not(any(feature = "tonic-native", feature = "tonic-web")))]
168mod transport {
169 use std::convert::Infallible;
170 use std::future::{ready, Ready};
171 use std::task::{Context, Poll};
172
173 use http::{Request, Response};
174 use tonic::Status;
175 use tonic::body::Body;
176 use tonic::codegen::Service;
177
178 use super::NO_TRANSPORT_BACKEND_MESSAGE;
179
180 pub async fn connect(_address: &str) -> Result<Transport, crate::client::CreateEndpointError> {
181 Err(crate::client::CreateEndpointError::NoTransportBackend)
182 }
183
184 #[derive(Debug, Clone, Default)]
185 pub struct Transport;
186
187 impl Service<Request<Body>> for Transport {
188 type Response = Response<Body>;
189 type Error = Infallible;
190 type Future = Ready<Result<Self::Response, Self::Error>>;
191
192 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
193 Poll::Ready(Ok(()))
194 }
195
196 fn call(&mut self, _req: Request<Body>) -> Self::Future {
197 let status = Status::failed_precondition(NO_TRANSPORT_BACKEND_MESSAGE);
198 ready(Ok(status.into_http::<Body>()))
199 }
200 }
201}
202
203pub const MIN_PROTOCOL_VERSION: u64 = 1;
207
208pub const MAX_PROTOCOL_VERSION: u64 = 1;
212
213pub const ARK_INFO_TTL_SECS: u64 = 10 * 60;
217
218#[derive(Debug, thiserror::Error)]
219#[error("failed to create gRPC endpoint: {msg}")]
220pub enum CreateEndpointError {
221 #[error("{NO_TRANSPORT_BACKEND_MESSAGE}")]
222 NoTransportBackend,
223 #[error("failed to parse Ark server as a URI")]
224 InvalidUri(#[from] http::uri::InvalidUri),
225 #[error("Ark server scheme must be either http or https. Found: {0}")]
226 InvalidScheme(String),
227 #[error("Ark server URI is missing an authority part")]
228 MissingAuthority,
229 #[cfg(feature = "tonic-native")]
230 #[error(transparent)]
231 Transport(#[from] tonic::transport::Error),
232 #[cfg(feature = "socks5-proxy")]
233 #[error("invalid SOCKS5 proxy URI: {0:#}")]
234 InvalidProxyUri(http::uri::InvalidUri),
235}
236
237#[derive(Debug, thiserror::Error)]
238#[error("failed to connect to Ark server: {msg}")]
239pub enum ConnectError {
240 #[error("missing info '{0}' to connect")]
241 MissingInfo(&'static str),
242 #[error("invalid access token: {0}")]
243 InvalidAccessToken(#[from] #[source] InvalidMetadataValue),
244 #[error(transparent)]
245 CreateEndpoint(#[from] CreateEndpointError),
246 #[error("handshake request failed: {0}")]
247 Handshake(tonic::Status),
248 #[error("version mismatch. Client max is: {client_max}, server min is: {server_min}")]
249 ProtocolVersionMismatchClientTooOld { client_max: u64, server_min: u64 },
250 #[error("version mismatch. Client min is: {client_min}, server max is: {server_max}")]
251 ProtocolVersionMismatchServerTooOld { client_min: u64, server_max: u64 },
252 #[error("error getting ark info: {0}")]
253 GetArkInfo(tonic::Status),
254 #[error("invalid ark info from ark server: {0}")]
255 InvalidArkInfo(#[from] ConvertError),
256 #[error("network mismatch. Expected: {expected}, Got: {got}")]
257 NetworkMismatch { expected: Network, got: Network },
258 #[error("tokio channel error: {0}")]
259 Tokio(#[from] tokio::sync::oneshot::error::RecvError),
260}
261
262#[derive(Clone)]
268#[deprecated(since = "0.1.3", note = "should not be used directly")]
269pub struct ProtocolVersionInterceptor {
270 pver: u64,
271}
272
273#[allow(deprecated)]
274impl tonic::service::Interceptor for ProtocolVersionInterceptor {
275 fn call(&mut self, mut req: tonic::Request<()>) -> Result<tonic::Request<()>, tonic::Status> {
276 #[allow(deprecated)]
277 req.set_pver(self.pver);
278 Ok(req)
279 }
280}
281
282#[derive(Clone)]
287pub struct ArkServiceInterceptor {
288 pver: Option<u64>,
289 access_token: Option<AsciiMetadataValue>,
290}
291
292impl tonic::service::Interceptor for ArkServiceInterceptor {
293 fn call(&mut self, mut req: tonic::Request<()>) -> Result<tonic::Request<()>, tonic::Status> {
294 if let Some(pver) = self.pver {
295 req.set_pver(pver);
296 }
297 if let Some(ref access_token) = self.access_token {
298 req.metadata_mut().insert(ACCESS_TOKEN_HEADER, access_token.clone());
299 }
300 Ok(req)
301 }
302}
303
304pub struct ArkInfoHandle {
308 pub info: ArkInfo,
309 pub waiter: Option<tokio::sync::oneshot::Receiver<Result<ArkInfo, ConnectError>>>,
310}
311
312impl Deref for ArkInfoHandle {
313 type Target = ArkInfo;
314
315 fn deref(&self) -> &Self::Target {
316 &self.info
317 }
318}
319
320pub struct ServerInfo {
321 pub pver: u64,
325 pub info: ArkInfo,
327 pub refresh_at_secs: u64,
329}
330
331impl ServerInfo {
332 fn ttl() -> u64 {
334 ark::time::timestamp_secs() + ARK_INFO_TTL_SECS
335 }
336
337 pub fn new(pver: u64, info: ArkInfo) -> Self {
338 Self { pver, info, refresh_at_secs: Self::ttl() }
339 }
340
341 pub fn update(&mut self, info: ArkInfo) {
342 self.info = info;
343 self.refresh_at_secs = Self::ttl();
344 }
345
346 pub fn is_outdated(&self) -> bool {
348 ark::time::timestamp_secs() > self.refresh_at_secs
349 }
350}
351
352#[derive(Default)]
353pub struct ServerConnectionBuilder {
354 address: Option<String>,
355 network: Option<Network>,
356 #[cfg(feature = "socks5-proxy")]
357 proxy: Option<String>,
358 access_token: Option<String>,
359}
360
361impl ServerConnectionBuilder {
362 pub fn address(mut self, address: impl Into<String>) -> Self {
363 self.address = Some(address.into());
364 self
365 }
366
367 pub fn network(mut self, network: Network) -> Self {
368 self.network = Some(network);
369 self
370 }
371
372 #[cfg(feature = "socks5-proxy")]
373 pub fn proxy(mut self, proxy: impl Into<String>) -> Self {
374 self.proxy = Some(proxy.into());
375 self
376 }
377
378 pub fn access_token(mut self, access_token: impl Into<String>) -> Self {
379 self.access_token = Some(access_token.into());
380 self
381 }
382
383 pub async fn connect(self) -> Result<ServerConnection, ConnectError> {
384 ServerConnection::inner_connect(self).await
385 }
386}
387
388#[derive(Clone)]
395pub struct ServerConnection {
396 info: Arc<RwLock<ServerInfo>>,
397 pub client: ArkServiceClient<InterceptedService<transport::Transport, ArkServiceInterceptor>>,
399 pub mailbox_client: mailbox::MailboxServiceClient<InterceptedService<transport::Transport, ArkServiceInterceptor>>,
401}
402
403impl ServerConnection {
404 fn handshake_req() -> protos::HandshakeRequest {
405 protos::HandshakeRequest {
406 bark_version: Some(env!("CARGO_PKG_VERSION").into()),
407 }
408 }
409
410 pub fn builder() -> ServerConnectionBuilder {
428 ServerConnectionBuilder::default()
429 }
430
431 async fn inner_connect(builder: ServerConnectionBuilder) -> Result<ServerConnection, ConnectError> {
433 let address = builder.address.ok_or(ConnectError::MissingInfo("address"))?;
434 let network = builder.network.ok_or(ConnectError::MissingInfo("network"))?;
435
436 #[cfg(feature = "socks5-proxy")]
437 let transport = if let Some(proxy) = builder.proxy {
438 transport::connect_with_proxy(&address, &proxy).await?
439 } else {
440 transport::connect(&address).await?
441 };
442 #[cfg(not(feature = "socks5-proxy"))]
443 let transport = transport::connect(&address).await?;
444
445 let mut interceptor = ArkServiceInterceptor {
446 pver: None,
447 access_token: builder.access_token.map(|t| t.try_into()).transpose()?,
448 };
449
450 let mut handshake_client = ArkServiceClient::with_interceptor(transport.clone(), interceptor.clone());
451 let handshake = handshake_client.handshake(Self::handshake_req()).await
452 .map_err(ConnectError::Handshake)?.into_inner();
453
454 let pver = check_handshake(handshake)?;
455 interceptor.pver = Some(pver);
456
457 let mut client = ArkServiceClient::with_interceptor(transport.clone(), interceptor.clone())
458 .max_decoding_message_size(64 * 1024 * 1024); let info = client.ark_info(network).await?;
461
462 let mailbox_client = mailbox::MailboxServiceClient::with_interceptor(transport, interceptor)
463 .max_decoding_message_size(64 * 1024 * 1024); let info = Arc::new(RwLock::new(ServerInfo::new(pver, info)));
466 Ok(ServerConnection {
467 info,
468 client,
469 mailbox_client,
470 })
471 }
472
473 #[deprecated(since = "0.1.3", note = "use builder() instead")]
474 pub async fn connect(
475 address: &str,
476 network: Network,
477 ) -> Result<ServerConnection, ConnectError> {
478 Self::builder().address(address).network(network).connect().await
479 }
480
481 #[cfg(feature = "socks5-proxy")]
482 #[deprecated(since = "0.1.3", note = "use builder() instead")]
483 pub async fn connect_via_proxy(
484 address: &str,
485 network: Network,
486 proxy: &str,
487 ) -> Result<ServerConnection, ConnectError> {
488 Self::builder().address(address).network(network).proxy(proxy).connect().await
489 }
490
491 pub async fn check_connection(&self) -> Result<(), ConnectError> {
493 let mut client = self.client.clone();
494 let handshake = client.handshake(Self::handshake_req()).await
495 .map_err(ConnectError::Handshake)?.into_inner();
496 check_handshake(handshake)?;
497 Ok(())
498 }
499
500 pub async fn ark_info(&self) -> Result<ArkInfo, ConnectError> {
508 let mut current = self.info.write().await;
509
510 let new_info = self.client.clone().ark_info(current.info.network).await?;
511 if current.is_outdated() {
512 current.update(new_info.clone());
513 return Ok(new_info);
514 }
515
516 Ok(current.info.clone())
517 }
518}
519trait ArkServiceClientExt {
520 async fn ark_info(&mut self, network: Network) -> Result<ArkInfo, ConnectError>;
521}
522
523impl<I: Interceptor> ArkServiceClientExt for ArkServiceClient<InterceptedService<transport::Transport, I>> {
524 async fn ark_info(&mut self, network: Network) -> Result<ArkInfo, ConnectError> {
525 let res = self.get_ark_info(protos::Empty {}).await
526 .map_err(ConnectError::GetArkInfo)?;
527 let info = ArkInfo::try_from(res.into_inner())
528 .map_err(ConnectError::InvalidArkInfo)?;
529 if network != info.network {
530 return Err(ConnectError::NetworkMismatch { expected: network, got: info.network });
531 }
532
533 Ok(info)
534 }
535}
536
537fn check_handshake(handshake: protos::HandshakeResponse) -> Result<u64, ConnectError> {
538 if let Some(ref msg) = handshake.psa {
539 warn!("Message from Ark server: \"{}\"", msg);
540 }
541
542 if MAX_PROTOCOL_VERSION < handshake.min_protocol_version {
543 return Err(ConnectError::ProtocolVersionMismatchClientTooOld {
544 client_max: MAX_PROTOCOL_VERSION, server_min: handshake.min_protocol_version
545 });
546 }
547 if MIN_PROTOCOL_VERSION > handshake.max_protocol_version {
548 return Err(ConnectError::ProtocolVersionMismatchServerTooOld {
549 client_min: MIN_PROTOCOL_VERSION, server_max: handshake.max_protocol_version
550 });
551 }
552
553 let pver = cmp::min(MAX_PROTOCOL_VERSION, handshake.max_protocol_version);
554 assert!((MIN_PROTOCOL_VERSION..=MAX_PROTOCOL_VERSION).contains(&pver));
555 assert!((handshake.min_protocol_version..=handshake.max_protocol_version).contains(&pver));
556
557 Ok(pver)
558}
559
560#[cfg(test)]
561mod tests {
562 use super::{CreateEndpointError, NO_TRANSPORT_BACKEND_MESSAGE};
563
564 #[test]
565 fn no_transport_backend_error_mentions_feature_selection() {
566 let err = CreateEndpointError::NoTransportBackend;
567 assert_eq!(err.to_string(), NO_TRANSPORT_BACKEND_MESSAGE);
568 assert!(err.to_string().contains("bark-server-rpc/tonic-native"));
569 assert!(err.to_string().contains("bark-server-rpc/tonic-web"));
570 }
571}