1use std::cmp;
30use std::convert::TryFrom;
31use std::ops::Deref;
32use std::sync::Arc;
33
34use bitcoin::{FeeRate, Network};
35use log::warn;
36use tokio::sync::RwLock;
37use tonic::metadata::AsciiMetadataValue;
38use tonic::metadata::errors::InvalidMetadataValue;
39use tonic::service::interceptor::{InterceptedService, Interceptor};
40
41use ark::ArkInfo;
42
43use crate::{
44 mailbox, protos, ArkServiceClient, ConvertError, RequestExt,
45 MAX_PROTOCOL_VERSION, MIN_PROTOCOL_VERSION,
46};
47
48
// The native (`tonic-native`) and wasm-client (`tonic-web`) transports each define
// `mod transport`, so enabling both at once is rejected at compile time.
#[cfg(all(feature = "tonic-native", feature = "tonic-web"))]
compile_error!("features `tonic-native` and `tonic-web` are mutually exclusive");

// The SOCKS5 connector is only wired into the native transport module.
#[cfg(all(feature = "socks5-proxy", not(feature = "tonic-native")))]
compile_error!("the `socks5-proxy` feature is only usable in conjunction with `tonic-native`");


/// gRPC metadata key under which the configured access token is attached to requests.
pub const ACCESS_TOKEN_HEADER: &str = "ark-access-token";
/// Human-readable explanation used when no transport backend feature is enabled.
///
/// It is both the `Display` text of `CreateEndpointError::NoTransportBackend` and the
/// message of the `failed_precondition` status returned by the placeholder transport.
pub const NO_TRANSPORT_BACKEND_MESSAGE: &str =
    "no Ark RPC transport backend compiled in this build; enable `bark-server-rpc/tonic-native` or `bark-server-rpc/tonic-web`";
61
62
#[cfg(feature = "tonic-native")]
mod transport {
    // Native backend: the transport is a tonic `Channel` built from an `Endpoint`.
    use std::str::FromStr;
    use std::time::Duration;

    use http::Uri;
    use log::info;
    use tonic::transport::{Channel, Endpoint};

    use super::CreateEndpointError;

    pub type Transport = Channel;

    /// Opens a connection to the Ark server at `address`.
    ///
    /// # Errors
    /// Fails when `address` is rejected by `create_endpoint` or when the
    /// connection attempt itself fails.
    pub async fn connect(address: &str) -> Result<Transport, CreateEndpointError> {
        let endpoint = create_endpoint(address)?;
        let channel = endpoint.connect().await?;
        Ok(channel)
    }

    /// Like `connect`, but tunnels the connection through the SOCKS5 proxy at `proxy`.
    ///
    /// # Errors
    /// Additionally fails with `InvalidProxyUri` when `proxy` is not a valid URI.
    #[cfg(feature = "socks5-proxy")]
    pub async fn connect_with_proxy(
        address: &str,
        proxy: &str,
    ) -> Result<Transport, CreateEndpointError> {
        use hyper_socks2::SocksConnector;
        use hyper_util::client::legacy::connect::HttpConnector;

        let endpoint = create_endpoint(address)?;
        let proxy_addr = proxy.parse::<Uri>().map_err(CreateEndpointError::InvalidProxyUri)?;

        // `enforce_http(false)` keeps the inner connector from rejecting non-`http` URIs
        // such as `https` endpoints.
        let mut tcp = HttpConnector::new();
        tcp.enforce_http(false);
        let connector = SocksConnector { proxy_addr, auth: None, connector: tcp };

        info!("Connecting to Ark server via SOCKS5 proxy {}...", proxy);
        let channel = endpoint.connect_with_connector(connector).await?;
        Ok(channel)
    }

    /// Validates `address` and turns it into a configured tonic `Endpoint`.
    ///
    /// Only `http` and `https` schemes are accepted. `https` is configured with the
    /// enabled TLS roots when `tls-native-roots` or `tls-webpki-roots` is on, and is
    /// rejected otherwise.
    fn create_endpoint(address: &str) -> Result<Endpoint, CreateEndpointError> {
        let uri = Uri::from_str(address)?;

        let wants_tls = match uri.scheme_str() {
            Some("https") => true,
            Some("http") => false,
            other => {
                return Err(CreateEndpointError::InvalidScheme(other.unwrap_or("").to_string()));
            }
        };

        // HTTP/2 keep-alive pings every 20s (also while idle); 600s keep-alive and
        // request timeouts.
        let endpoint = Channel::builder(uri.clone())
            .http2_keep_alive_interval(Duration::from_secs(20))
            .keep_alive_timeout(Duration::from_secs(600))
            .keep_alive_while_idle(true)
            .timeout(Duration::from_secs(600));

        #[cfg(any(feature = "tls-native-roots", feature = "tls-webpki-roots"))]
        if wants_tls {
            use tonic::transport::ClientTlsConfig;

            info!("Connecting to Ark server at {} using TLS...", address);
            // The URI host (without port) is used as the TLS server name.
            let host = uri.authority()
                .ok_or(CreateEndpointError::MissingAuthority)?
                .host();
            let tls = ClientTlsConfig::new()
                .with_enabled_roots()
                .domain_name(host);
            return endpoint.tls_config(tls).map_err(CreateEndpointError::Transport);
        }
        #[cfg(not(any(feature = "tls-native-roots", feature = "tls-webpki-roots")))]
        if wants_tls {
            return Err(CreateEndpointError::InvalidScheme(
                "Missing TLS roots, https is unsupported".to_owned(),
            ));
        }

        info!("Connecting to Ark server at {} without TLS...", address);
        Ok(endpoint)
    }
}
154
#[cfg(feature = "tonic-web")]
mod transport {
    use super::CreateEndpointError;
    use tonic_web_wasm_client::Client as WasmClient;

    // gRPC-web transport provided by `tonic_web_wasm_client`.
    pub type Transport = WasmClient;

    /// Builds a gRPC-web client for `address`.
    ///
    /// This backend never returns an error; the `Result` keeps the signature
    /// identical to the other transport backends.
    pub async fn connect(address: &str) -> Result<Transport, CreateEndpointError> {
        let client = WasmClient::new(address.to_owned());
        Ok(client)
    }
}
166
167#[cfg(not(any(feature = "tonic-native", feature = "tonic-web")))]
171mod transport {
172 use std::convert::Infallible;
173 use std::future::{ready, Ready};
174 use std::task::{Context, Poll};
175
176 use http::{Request, Response};
177 use tonic::Status;
178 use tonic::body::Body;
179 use tonic::codegen::Service;
180
181 use super::NO_TRANSPORT_BACKEND_MESSAGE;
182
183 pub async fn connect(_address: &str) -> Result<Transport, crate::client::CreateEndpointError> {
184 Err(crate::client::CreateEndpointError::NoTransportBackend)
185 }
186
187 #[derive(Debug, Clone, Default)]
188 pub struct Transport;
189
190 impl Service<Request<Body>> for Transport {
191 type Response = Response<Body>;
192 type Error = Infallible;
193 type Future = Ready<Result<Self::Response, Self::Error>>;
194
195 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
196 Poll::Ready(Ok(()))
197 }
198
199 fn call(&mut self, _req: Request<Body>) -> Self::Future {
200 let status = Status::failed_precondition(NO_TRANSPORT_BACKEND_MESSAGE);
201 ready(Ok(status.into_http::<Body>()))
202 }
203 }
204}
205
206
207#[derive(Debug, thiserror::Error)]
208#[error("failed to create gRPC endpoint: {msg}")]
209pub enum CreateEndpointError {
210 #[error("{NO_TRANSPORT_BACKEND_MESSAGE}")]
211 NoTransportBackend,
212 #[error("failed to parse Ark server as a URI")]
213 InvalidUri(#[from] http::uri::InvalidUri),
214 #[error("Ark server scheme must be either http or https. Found: {0}")]
215 InvalidScheme(String),
216 #[error("Ark server URI is missing an authority part")]
217 MissingAuthority,
218 #[cfg(feature = "tonic-native")]
219 #[error(transparent)]
220 Transport(#[from] tonic::transport::Error),
221 #[cfg(feature = "socks5-proxy")]
222 #[error("invalid SOCKS5 proxy URI: {0:#}")]
223 InvalidProxyUri(http::uri::InvalidUri),
224}
225
226#[derive(Debug, thiserror::Error)]
227#[error("failed to connect to Ark server: {msg}")]
228pub enum ConnectError {
229 #[error("missing info '{0}' to connect")]
230 MissingInfo(&'static str),
231 #[error("invalid access token: {0}")]
232 InvalidAccessToken(#[from] #[source] InvalidMetadataValue),
233 #[error(transparent)]
234 CreateEndpoint(#[from] CreateEndpointError),
235 #[error("handshake request failed: {0}")]
236 Handshake(tonic::Status),
237 #[error("version mismatch. Client max is: {client_max}, server min is: {server_min}")]
238 ProtocolVersionMismatchClientTooOld { client_max: u64, server_min: u64 },
239 #[error("version mismatch. Client min is: {client_min}, server max is: {server_max}")]
240 ProtocolVersionMismatchServerTooOld { client_min: u64, server_max: u64 },
241 #[error("error getting ark info: {0}")]
242 GetArkInfo(tonic::Status),
243 #[error("invalid ark info from ark server: {0}")]
244 InvalidArkInfo(#[from] ConvertError),
245 #[error("network mismatch. Expected: {expected}, Got: {got}")]
246 NetworkMismatch { expected: Network, got: Network },
247 #[error("error getting offboard fee rate: {0}")]
248 GetOffboardFeeRate(tonic::Status),
249 #[error("tokio channel error: {0}")]
250 Tokio(#[from] tokio::sync::oneshot::error::RecvError),
251}
252
253#[derive(Clone)]
259#[deprecated(since = "0.1.3", note = "should not be used directly")]
260pub struct ProtocolVersionInterceptor {
261 pver: u64,
262}
263
264#[allow(deprecated)]
265impl tonic::service::Interceptor for ProtocolVersionInterceptor {
266 fn call(&mut self, mut req: tonic::Request<()>) -> Result<tonic::Request<()>, tonic::Status> {
267 #[allow(deprecated)]
268 req.set_pver(self.pver);
269 Ok(req)
270 }
271}
272
273#[derive(Clone)]
278pub struct ArkServiceInterceptor {
279 pver: Option<u64>,
280 access_token: Option<AsciiMetadataValue>,
281}
282
283impl tonic::service::Interceptor for ArkServiceInterceptor {
284 fn call(&mut self, mut req: tonic::Request<()>) -> Result<tonic::Request<()>, tonic::Status> {
285 if let Some(pver) = self.pver {
286 req.set_pver(pver);
287 }
288 if let Some(ref access_token) = self.access_token {
289 req.metadata_mut().insert(ACCESS_TOKEN_HEADER, access_token.clone());
290 }
291 Ok(req)
292 }
293}
294
295pub struct ArkInfoHandle {
299 pub info: ArkInfo,
300 pub waiter: Option<tokio::sync::oneshot::Receiver<Result<ArkInfo, ConnectError>>>,
301}
302
303impl Deref for ArkInfoHandle {
304 type Target = ArkInfo;
305
306 fn deref(&self) -> &Self::Target {
307 &self.info
308 }
309}
310
311pub struct ServerInfo {
312 pub pver: u64,
316 pub info: ArkInfo,
318}
319
320impl ServerInfo {
321 pub fn new(pver: u64, info: ArkInfo) -> Self {
322 Self { pver, info }
323 }
324}
325
326#[derive(Default)]
327pub struct ServerConnectionBuilder {
328 address: Option<String>,
329 network: Option<Network>,
330 #[cfg(feature = "socks5-proxy")]
331 proxy: Option<String>,
332 access_token: Option<String>,
333}
334
335impl ServerConnectionBuilder {
336 pub fn address(mut self, address: impl Into<String>) -> Self {
337 self.address = Some(address.into());
338 self
339 }
340
341 pub fn network(mut self, network: Network) -> Self {
342 self.network = Some(network);
343 self
344 }
345
346 #[cfg(feature = "socks5-proxy")]
347 pub fn proxy(mut self, proxy: impl Into<String>) -> Self {
348 self.proxy = Some(proxy.into());
349 self
350 }
351
352 pub fn access_token(mut self, access_token: impl Into<String>) -> Self {
353 self.access_token = Some(access_token.into());
354 self
355 }
356
357 pub async fn connect(self) -> Result<ServerConnection, ConnectError> {
358 ServerConnection::inner_connect(self).await
359 }
360}
361
/// An established connection to an Ark server.
///
/// Clones share the same `info` (it lives behind an `Arc<RwLock<_>>`).
#[derive(Clone)]
pub struct ServerConnection {
    /// Protocol version and Ark info obtained while connecting.
    info: Arc<RwLock<ServerInfo>>,
    /// Ark service client; requests pass through [`ArkServiceInterceptor`], which adds
    /// the negotiated protocol version and, if configured, the access token.
    pub client: ArkServiceClient<InterceptedService<transport::Transport, ArkServiceInterceptor>>,
    /// Mailbox service client built from the same transport and interceptor as `client`.
    pub mailbox_client: mailbox::MailboxServiceClient<InterceptedService<transport::Transport, ArkServiceInterceptor>>,
}
376
377impl ServerConnection {
378 fn handshake_req() -> protos::HandshakeRequest {
379 protos::HandshakeRequest {
380 bark_version: Some(env!("CARGO_PKG_VERSION").into()),
381 }
382 }
383
384 pub fn builder() -> ServerConnectionBuilder {
402 ServerConnectionBuilder::default()
403 }
404
405 async fn inner_connect(builder: ServerConnectionBuilder) -> Result<ServerConnection, ConnectError> {
407 let address = builder.address.ok_or(ConnectError::MissingInfo("address"))?;
408 let network = builder.network.ok_or(ConnectError::MissingInfo("network"))?;
409
410 #[cfg(feature = "socks5-proxy")]
411 let transport = if let Some(proxy) = builder.proxy {
412 transport::connect_with_proxy(&address, &proxy).await?
413 } else {
414 transport::connect(&address).await?
415 };
416 #[cfg(not(feature = "socks5-proxy"))]
417 let transport = transport::connect(&address).await?;
418
419 let mut interceptor = ArkServiceInterceptor {
420 pver: None,
421 access_token: builder.access_token.map(|t| t.try_into()).transpose()?,
422 };
423
424 let mut handshake_client = ArkServiceClient::with_interceptor(transport.clone(), interceptor.clone());
425 let handshake = handshake_client.handshake(Self::handshake_req()).await
426 .map_err(ConnectError::Handshake)?.into_inner();
427
428 let pver = check_handshake(handshake)?;
429 interceptor.pver = Some(pver);
430
431 let mut client = ArkServiceClient::with_interceptor(transport.clone(), interceptor.clone())
432 .max_decoding_message_size(64 * 1024 * 1024); let info = client.ark_info(network).await?;
435
436 let mailbox_client = mailbox::MailboxServiceClient::with_interceptor(transport, interceptor)
437 .max_decoding_message_size(64 * 1024 * 1024); let info = Arc::new(RwLock::new(ServerInfo::new(pver, info)));
440 Ok(ServerConnection {
441 info,
442 client,
443 mailbox_client,
444 })
445 }
446
447 #[deprecated(since = "0.1.3", note = "use builder() instead")]
448 pub async fn connect(
449 address: &str,
450 network: Network,
451 ) -> Result<ServerConnection, ConnectError> {
452 Self::builder().address(address).network(network).connect().await
453 }
454
455 #[cfg(feature = "socks5-proxy")]
456 #[deprecated(since = "0.1.3", note = "use builder() instead")]
457 pub async fn connect_via_proxy(
458 address: &str,
459 network: Network,
460 proxy: &str,
461 ) -> Result<ServerConnection, ConnectError> {
462 Self::builder().address(address).network(network).proxy(proxy).connect().await
463 }
464
465 pub async fn check_connection(&self) -> Result<(), ConnectError> {
467 let mut client = self.client.clone();
468 let handshake = client.handshake(Self::handshake_req()).await
469 .map_err(ConnectError::Handshake)?.into_inner();
470 check_handshake(handshake)?;
471 Ok(())
472 }
473
474 pub async fn ark_info(&self) -> ArkInfo {
476 self.info.read().await.info.clone()
477 }
478
479 pub async fn offboard_feerate(&self) -> Result<FeeRate, ConnectError> {
481 let resp = self.client.clone()
482 .get_offboard_fee_rate(protos::Empty {}).await
483 .map_err(ConnectError::GetOffboardFeeRate)?
484 .into_inner();
485 Ok(FeeRate::from_sat_per_kwu(resp.sat_vkb / 4))
486 }
487}
488trait ArkServiceClientExt {
489 async fn ark_info(&mut self, network: Network) -> Result<ArkInfo, ConnectError>;
490}
491
492impl<I: Interceptor> ArkServiceClientExt for ArkServiceClient<InterceptedService<transport::Transport, I>> {
493 async fn ark_info(&mut self, network: Network) -> Result<ArkInfo, ConnectError> {
494 let res = self.get_ark_info(protos::Empty {}).await
495 .map_err(ConnectError::GetArkInfo)?;
496 let info = ArkInfo::try_from(res.into_inner())
497 .map_err(ConnectError::InvalidArkInfo)?;
498 if network != info.network {
499 return Err(ConnectError::NetworkMismatch { expected: network, got: info.network });
500 }
501
502 Ok(info)
503 }
504}
505
506fn check_handshake(handshake: protos::HandshakeResponse) -> Result<u64, ConnectError> {
507 if let Some(ref msg) = handshake.psa {
508 warn!("Message from Ark server: \"{}\"", msg);
509 }
510
511 if MAX_PROTOCOL_VERSION < handshake.min_protocol_version {
512 return Err(ConnectError::ProtocolVersionMismatchClientTooOld {
513 client_max: MAX_PROTOCOL_VERSION, server_min: handshake.min_protocol_version
514 });
515 }
516 if MIN_PROTOCOL_VERSION > handshake.max_protocol_version {
517 return Err(ConnectError::ProtocolVersionMismatchServerTooOld {
518 client_min: MIN_PROTOCOL_VERSION, server_max: handshake.max_protocol_version
519 });
520 }
521
522 let pver = cmp::min(MAX_PROTOCOL_VERSION, handshake.max_protocol_version);
523 assert!((MIN_PROTOCOL_VERSION..=MAX_PROTOCOL_VERSION).contains(&pver));
524 assert!((handshake.min_protocol_version..=handshake.max_protocol_version).contains(&pver));
525
526 Ok(pver)
527}
528
#[cfg(test)]
mod tests {
    use super::{CreateEndpointError, NO_TRANSPORT_BACKEND_MESSAGE};

    /// The "no backend" error must tell users which features to enable.
    #[test]
    fn no_transport_backend_error_mentions_feature_selection() {
        let rendered = CreateEndpointError::NoTransportBackend.to_string();
        assert_eq!(rendered, NO_TRANSPORT_BACKEND_MESSAGE);
        for feature in ["bark-server-rpc/tonic-native", "bark-server-rpc/tonic-web"] {
            assert!(rendered.contains(feature));
        }
    }
}