use std::cmp;
use std::convert::TryFrom;
use std::ops::Deref;
use std::sync::Arc;
use bitcoin::{FeeRate, Network};
use log::warn;
use tokio::sync::RwLock;
use tonic::metadata::AsciiMetadataValue;
use tonic::metadata::errors::InvalidMetadataValue;
use tonic::service::interceptor::{InterceptedService, Interceptor};
use ark::ArkInfo;
use crate::{mailbox, protos, ArkServiceClient, ConvertError, RequestExt};
// At most one transport backend can be compiled in: `tonic-native` (tonic's
// `Channel`) or `tonic-web` (`tonic_web_wasm_client`). Selecting both is a
// build error rather than a silent precedence rule.
#[cfg(all(feature = "tonic-native", feature = "tonic-web"))]
compile_error!("features `tonic-native` and `tonic-web` are mutually exclusive");
// SOCKS5 proxying is only implemented by the native transport
// (`transport::connect_with_proxy`), so it requires `tonic-native`.
#[cfg(all(feature = "socks5-proxy", not(feature = "tonic-native")))]
compile_error!("the `socks5-proxy` feature is only usable in conjunction with `tonic-native`");
/// gRPC metadata key under which the configured access token is sent.
pub const ACCESS_TOKEN_HEADER: &str = "ark-access-token";

/// User-facing explanation returned when the crate was built without any
/// transport feature, naming the features that would provide one.
pub const NO_TRANSPORT_BACKEND_MESSAGE: &str = concat!(
	"no Ark RPC transport backend compiled in this build; ",
	"enable `bark-server-rpc/tonic-native` or `bark-server-rpc/tonic-web`",
);
#[cfg(feature = "tonic-native")]
mod transport {
	//! Native transport backend built on tonic's `Channel`.

	use std::str::FromStr;
	use std::time::Duration;

	use http::Uri;
	use log::info;
	use tonic::transport::{Channel, Endpoint};

	use super::CreateEndpointError;

	/// Transport type used by the generated gRPC clients for this backend.
	pub type Transport = Channel;

	/// Opens a direct connection to the Ark server at `address`.
	///
	/// # Errors
	///
	/// Fails if `address` is rejected by `create_endpoint` or if tonic cannot
	/// establish the connection.
	pub async fn connect(address: &str) -> Result<Transport, CreateEndpointError> {
		Ok(create_endpoint(address)?.connect().await?)
	}

	/// Opens a connection to the Ark server at `address`, tunnelled through the
	/// SOCKS5 proxy at `proxy`.
	///
	/// # Errors
	///
	/// Fails if `address` is rejected by `create_endpoint`, if `proxy` does not
	/// parse as a URI ([`CreateEndpointError::InvalidProxyUri`]), or if tonic
	/// cannot establish the connection.
	#[cfg(feature = "socks5-proxy")]
	pub async fn connect_with_proxy(
		address: &str,
		proxy: &str,
	) -> Result<Transport, CreateEndpointError> {
		use hyper_socks2::SocksConnector;
		use hyper_util::client::legacy::connect::HttpConnector;

		let endpoint = create_endpoint(address)?;
		let proxy_uri = proxy.parse::<Uri>().map_err(CreateEndpointError::InvalidProxyUri)?;
		let connector = {
			let mut http = HttpConnector::new();
			// Accept non-`http` target URIs (e.g. `https`) instead of rejecting them.
			http.enforce_http(false);
			SocksConnector {
				proxy_addr: proxy_uri,
				auth: None,
				connector: http,
			}
		};
		info!("Connecting to Ark server via SOCKS5 proxy {}...", proxy);
		Ok(endpoint.connect_with_connector(connector).await?)
	}

	/// Parses `address` and builds a tonic [`Endpoint`] with fixed keep-alive
	/// and timeout settings, enabling TLS for `https` addresses when a TLS
	/// roots feature is compiled in.
	///
	/// # Errors
	///
	/// - [`CreateEndpointError::InvalidUri`] if `address` is not a URI.
	/// - [`CreateEndpointError::InvalidScheme`] if the scheme is neither `http`
	///   nor `https`, or if it is `https` but no TLS roots feature is enabled.
	/// - [`CreateEndpointError::MissingAuthority`] if an `https` URI has no host.
	/// - [`CreateEndpointError::Transport`] if tonic rejects the TLS config.
	fn create_endpoint(address: &str) -> Result<Endpoint, CreateEndpointError> {
		let uri = Uri::from_str(address)?;
		let scheme = uri.scheme_str().unwrap_or("");
		if scheme != "http" && scheme != "https" {
			return Err(CreateEndpointError::InvalidScheme(scheme.to_string()));
		}

		#[cfg_attr(not(any(feature = "tls-native-roots", feature = "tls-webpki-roots")), allow(unused_mut))]
		let mut endpoint = Channel::builder(uri.clone())
			.http2_keep_alive_interval(Duration::from_secs(20))
			.keep_alive_timeout(Duration::from_secs(600))
			.keep_alive_while_idle(true)
			.timeout(Duration::from_secs(600));

		#[cfg(any(feature = "tls-native-roots", feature = "tls-webpki-roots"))]
		if scheme == "https" {
			use tonic::transport::ClientTlsConfig;

			info!("Connecting to Ark server at {} using TLS...", address);
			// The TLS server name is the host of the URI's authority;
			// `Uri::host` reads it in place, without cloning or decomposing `uri`.
			let domain = uri.host().ok_or(CreateEndpointError::MissingAuthority)?;
			let tls_config = ClientTlsConfig::new()
				.with_enabled_roots()
				.domain_name(domain);
			endpoint = endpoint.tls_config(tls_config).map_err(CreateEndpointError::Transport)?;
			return Ok(endpoint);
		}

		#[cfg(not(any(feature = "tls-native-roots", feature = "tls-webpki-roots")))]
		if scheme == "https" {
			return Err(CreateEndpointError::InvalidScheme(
				"Missing TLS roots, https is unsupported".to_owned(),
			));
		}

		info!("Connecting to Ark server at {} without TLS...", address);
		Ok(endpoint)
	}
}
#[cfg(feature = "tonic-web")]
mod transport {
	//! gRPC-web transport backend for WebAssembly targets, built on
	//! `tonic_web_wasm_client`.

	use tonic_web_wasm_client::Client as WasmClient;

	use super::CreateEndpointError;

	/// Transport type used by the generated gRPC clients for this backend.
	pub type Transport = WasmClient;

	/// Creates a gRPC-web client whose base URL is `address`, passed through
	/// unchanged.
	///
	/// This function itself never returns an error; the `Result` keeps the
	/// signature aligned with the other transport backends.
	pub async fn connect(address: &str) -> Result<Transport, CreateEndpointError> {
		let base_url = address.to_owned();
		Ok(WasmClient::new(base_url))
	}
}
#[cfg(not(any(feature = "tonic-native", feature = "tonic-web")))]
mod transport {
use std::convert::Infallible;
use std::future::{ready, Ready};
use std::task::{Context, Poll};
use http::{Request, Response};
use tonic::Status;
use tonic::body::Body;
use tonic::codegen::Service;
use super::NO_TRANSPORT_BACKEND_MESSAGE;
pub async fn connect(_address: &str) -> Result<Transport, crate::client::CreateEndpointError> {
Err(crate::client::CreateEndpointError::NoTransportBackend)
}
#[derive(Debug, Clone, Default)]
pub struct Transport;
impl Service<Request<Body>> for Transport {
type Response = Response<Body>;
type Error = Infallible;
type Future = Ready<Result<Self::Response, Self::Error>>;
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, _req: Request<Body>) -> Self::Future {
let status = Status::failed_precondition(NO_TRANSPORT_BACKEND_MESSAGE);
ready(Ok(status.into_http::<Body>()))
}
}
}
/// Lowest Ark RPC protocol version this client accepts during the handshake.
pub const MIN_PROTOCOL_VERSION: u64 = 1;
/// Highest Ark RPC protocol version this client accepts during the handshake.
pub const MAX_PROTOCOL_VERSION: u64 = 1;
/// Errors raised while turning an Ark server address into a gRPC transport.
///
/// Every variant carries its own display message; there is intentionally no
/// enum-level `#[error]` fallback.
#[derive(Debug, thiserror::Error)]
pub enum CreateEndpointError {
	/// The crate was built without any transport backend feature.
	#[error("{NO_TRANSPORT_BACKEND_MESSAGE}")]
	NoTransportBackend,
	/// The server address could not be parsed as a URI.
	#[error("failed to parse Ark server as a URI")]
	InvalidUri(#[from] http::uri::InvalidUri),
	/// The URI scheme was not `http`/`https`, or `https` was requested while
	/// no TLS roots feature is compiled in.
	#[error("Ark server scheme must be either http or https. Found: {0}")]
	InvalidScheme(String),
	/// The URI has no authority, so no TLS server name can be derived.
	#[error("Ark server URI is missing an authority part")]
	MissingAuthority,
	/// The underlying tonic transport reported an error.
	#[cfg(feature = "tonic-native")]
	#[error(transparent)]
	Transport(#[from] tonic::transport::Error),
	/// The SOCKS5 proxy address could not be parsed as a URI.
	#[cfg(feature = "socks5-proxy")]
	#[error("invalid SOCKS5 proxy URI: {0:#}")]
	InvalidProxyUri(http::uri::InvalidUri),
}
/// Errors raised while establishing or using a [`ServerConnection`].
///
/// Every variant carries its own display message; there is intentionally no
/// enum-level `#[error]` fallback.
#[derive(Debug, thiserror::Error)]
pub enum ConnectError {
	/// A required builder option (e.g. `"address"` or `"network"`) was not set.
	#[error("missing info '{0}' to connect")]
	MissingInfo(&'static str),
	/// The access token is not a valid ASCII gRPC metadata value.
	/// (`#[from]` already marks the field as the error source.)
	#[error("invalid access token: {0}")]
	InvalidAccessToken(#[from] InvalidMetadataValue),
	/// Building or connecting the transport failed.
	#[error(transparent)]
	CreateEndpoint(#[from] CreateEndpointError),
	/// The handshake RPC returned an error status.
	#[error("handshake request failed: {0}")]
	Handshake(tonic::Status),
	/// The server's minimum protocol version exceeds the client's maximum.
	#[error("version mismatch. Client max is: {client_max}, server min is: {server_min}")]
	ProtocolVersionMismatchClientTooOld { client_max: u64, server_min: u64 },
	/// The server's maximum protocol version is below the client's minimum.
	#[error("version mismatch. Client min is: {client_min}, server max is: {server_max}")]
	ProtocolVersionMismatchServerTooOld { client_min: u64, server_max: u64 },
	/// The `get_ark_info` RPC returned an error status.
	#[error("error getting ark info: {0}")]
	GetArkInfo(tonic::Status),
	/// The server's ark info could not be converted into [`ArkInfo`].
	#[error("invalid ark info from ark server: {0}")]
	InvalidArkInfo(#[from] ConvertError),
	/// The server runs on a different Bitcoin network than requested.
	#[error("network mismatch. Expected: {expected}, Got: {got}")]
	NetworkMismatch { expected: Network, got: Network },
	/// The `get_offboard_fee_rate` RPC returned an error status.
	#[error("error getting offboard fee rate: {0}")]
	GetOffboardFeeRate(tonic::Status),
	/// Receiving from a tokio oneshot channel failed.
	#[error("tokio channel error: {0}")]
	Tokio(#[from] tokio::sync::oneshot::error::RecvError),
}
/// Interceptor that stamps every outgoing request with a fixed protocol version.
///
/// Deprecated and not meant to be used directly; [`ArkServiceInterceptor`]
/// performs the same stamping (once a version is known) together with
/// access-token injection.
#[derive(Clone)]
#[deprecated(since = "0.1.3", note = "should not be used directly")]
pub struct ProtocolVersionInterceptor {
	pver: u64,
}

#[allow(deprecated)]
impl tonic::service::Interceptor for ProtocolVersionInterceptor {
	fn call(&mut self, mut request: tonic::Request<()>) -> Result<tonic::Request<()>, tonic::Status> {
		let version = self.pver;
		#[allow(deprecated)]
		request.set_pver(version);
		Ok(request)
	}
}
/// Interceptor shared by the Ark and mailbox clients.
///
/// On every request it attaches the negotiated protocol version (when one has
/// been set) and the access token under [`ACCESS_TOKEN_HEADER`] (when one was
/// configured).
#[derive(Clone)]
pub struct ArkServiceInterceptor {
	pver: Option<u64>,
	access_token: Option<AsciiMetadataValue>,
}

impl tonic::service::Interceptor for ArkServiceInterceptor {
	fn call(&mut self, mut request: tonic::Request<()>) -> Result<tonic::Request<()>, tonic::Status> {
		// `None` before the handshake has negotiated a version.
		if let Some(version) = self.pver {
			request.set_pver(version);
		}
		if let Some(token) = self.access_token.as_ref() {
			let header_value = token.clone();
			request.metadata_mut().insert(ACCESS_TOKEN_HEADER, header_value);
		}
		Ok(request)
	}
}
/// An [`ArkInfo`] value paired with an optional one-shot receiver.
///
/// Dereferences to the contained `info`.
// NOTE(review): `waiter` presumably delivers a later/updated ArkInfo result;
// it is not used in this file — confirm against callers.
pub struct ArkInfoHandle {
	pub info: ArkInfo,
	pub waiter: Option<tokio::sync::oneshot::Receiver<Result<ArkInfo, ConnectError>>>,
}

impl Deref for ArkInfoHandle {
	type Target = ArkInfo;

	fn deref(&self) -> &ArkInfo {
		let Self { info, .. } = self;
		info
	}
}
/// Server details captured when the connection was established.
pub struct ServerInfo {
	/// Protocol version negotiated during the handshake.
	pub pver: u64,
	/// Ark parameters reported by the server.
	pub info: ArkInfo,
}

impl ServerInfo {
	/// Bundles a negotiated protocol version with the server's [`ArkInfo`].
	pub fn new(pver: u64, info: ArkInfo) -> Self {
		ServerInfo { info, pver }
	}
}
/// Step-by-step configuration for a [`ServerConnection`].
///
/// `address` and `network` are required; [`connect`](Self::connect) fails with
/// [`ConnectError::MissingInfo`] if either is missing.
#[derive(Default)]
pub struct ServerConnectionBuilder {
	address: Option<String>,
	network: Option<Network>,
	#[cfg(feature = "socks5-proxy")]
	proxy: Option<String>,
	access_token: Option<String>,
}

impl ServerConnectionBuilder {
	/// Sets the Ark server address (the native backend requires an `http` or
	/// `https` URI).
	pub fn address(self, address: impl Into<String>) -> Self {
		Self { address: Some(address.into()), ..self }
	}

	/// Sets the Bitcoin network the server is expected to run on.
	pub fn network(self, network: Network) -> Self {
		Self { network: Some(network), ..self }
	}

	/// Routes the connection through the SOCKS5 proxy at `proxy`.
	#[cfg(feature = "socks5-proxy")]
	pub fn proxy(self, proxy: impl Into<String>) -> Self {
		Self { proxy: Some(proxy.into()), ..self }
	}

	/// Sets the access token sent under [`ACCESS_TOKEN_HEADER`] on every request.
	///
	/// It is validated at connect time; a non-ASCII token yields
	/// [`ConnectError::InvalidAccessToken`].
	pub fn access_token(self, access_token: impl Into<String>) -> Self {
		Self { access_token: Some(access_token.into()), ..self }
	}

	/// Consumes the builder and establishes the connection.
	pub async fn connect(self) -> Result<ServerConnection, ConnectError> {
		ServerConnection::inner_connect(self).await
	}
}
/// A connection to an Ark server, created via [`ServerConnection::builder`].
///
/// Clones share the cached [`ServerInfo`] through an `Arc`; the gRPC clients
/// are cloned alongside it.
#[derive(Clone)]
pub struct ServerConnection {
	/// Negotiated protocol version and the ark info fetched at connect time.
	info: Arc<RwLock<ServerInfo>>,
	/// Ark service client; its interceptor attaches the negotiated protocol
	/// version and the access token (if configured) to every request.
	pub client: ArkServiceClient<InterceptedService<transport::Transport, ArkServiceInterceptor>>,
	/// Mailbox service client using the same transport and interceptor.
	pub mailbox_client: mailbox::MailboxServiceClient<InterceptedService<transport::Transport, ArkServiceInterceptor>>,
}
impl ServerConnection {
fn handshake_req() -> protos::HandshakeRequest {
protos::HandshakeRequest {
bark_version: Some(env!("CARGO_PKG_VERSION").into()),
}
}
pub fn builder() -> ServerConnectionBuilder {
ServerConnectionBuilder::default()
}
async fn inner_connect(builder: ServerConnectionBuilder) -> Result<ServerConnection, ConnectError> {
let address = builder.address.ok_or(ConnectError::MissingInfo("address"))?;
let network = builder.network.ok_or(ConnectError::MissingInfo("network"))?;
#[cfg(feature = "socks5-proxy")]
let transport = if let Some(proxy) = builder.proxy {
transport::connect_with_proxy(&address, &proxy).await?
} else {
transport::connect(&address).await?
};
#[cfg(not(feature = "socks5-proxy"))]
let transport = transport::connect(&address).await?;
let mut interceptor = ArkServiceInterceptor {
pver: None,
access_token: builder.access_token.map(|t| t.try_into()).transpose()?,
};
let mut handshake_client = ArkServiceClient::with_interceptor(transport.clone(), interceptor.clone());
let handshake = handshake_client.handshake(Self::handshake_req()).await
.map_err(ConnectError::Handshake)?.into_inner();
let pver = check_handshake(handshake)?;
interceptor.pver = Some(pver);
let mut client = ArkServiceClient::with_interceptor(transport.clone(), interceptor.clone())
.max_decoding_message_size(64 * 1024 * 1024);
let info = client.ark_info(network).await?;
let mailbox_client = mailbox::MailboxServiceClient::with_interceptor(transport, interceptor)
.max_decoding_message_size(64 * 1024 * 1024);
let info = Arc::new(RwLock::new(ServerInfo::new(pver, info)));
Ok(ServerConnection {
info,
client,
mailbox_client,
})
}
#[deprecated(since = "0.1.3", note = "use builder() instead")]
pub async fn connect(
address: &str,
network: Network,
) -> Result<ServerConnection, ConnectError> {
Self::builder().address(address).network(network).connect().await
}
#[cfg(feature = "socks5-proxy")]
#[deprecated(since = "0.1.3", note = "use builder() instead")]
pub async fn connect_via_proxy(
address: &str,
network: Network,
proxy: &str,
) -> Result<ServerConnection, ConnectError> {
Self::builder().address(address).network(network).proxy(proxy).connect().await
}
pub async fn check_connection(&self) -> Result<(), ConnectError> {
let mut client = self.client.clone();
let handshake = client.handshake(Self::handshake_req()).await
.map_err(ConnectError::Handshake)?.into_inner();
check_handshake(handshake)?;
Ok(())
}
pub async fn ark_info(&self) -> ArkInfo {
self.info.read().await.info.clone()
}
pub async fn offboard_feerate(&self) -> Result<FeeRate, ConnectError> {
let resp = self.client.clone()
.get_offboard_fee_rate(protos::Empty {}).await
.map_err(ConnectError::GetOffboardFeeRate)?
.into_inner();
Ok(FeeRate::from_sat_per_kwu(resp.sat_vkb / 4))
}
}
/// Convenience helpers layered on the generated Ark service client.
trait ArkServiceClientExt {
	/// Fetches the server's [`ArkInfo`] and checks it targets `network`.
	async fn ark_info(&mut self, network: Network) -> Result<ArkInfo, ConnectError>;
}

impl<I> ArkServiceClientExt for ArkServiceClient<InterceptedService<transport::Transport, I>>
where
	I: Interceptor,
{
	async fn ark_info(&mut self, network: Network) -> Result<ArkInfo, ConnectError> {
		let response = self
			.get_ark_info(protos::Empty {})
			.await
			.map_err(ConnectError::GetArkInfo)?;
		let info = ArkInfo::try_from(response.into_inner()).map_err(ConnectError::InvalidArkInfo)?;

		if info.network == network {
			Ok(info)
		} else {
			Err(ConnectError::NetworkMismatch { expected: network, got: info.network })
		}
	}
}
/// Validates a handshake response and picks the protocol version to use.
///
/// Any server announcement (`psa`) is logged as a warning. The server's range
/// `[min_protocol_version, max_protocol_version]` must overlap the client's
/// `[MIN_PROTOCOL_VERSION, MAX_PROTOCOL_VERSION]`. The chosen version is the
/// highest one inside both ranges.
///
/// # Errors
///
/// - [`ConnectError::ProtocolVersionMismatchClientTooOld`] if the server's
///   minimum exceeds the client's maximum.
/// - [`ConnectError::ProtocolVersionMismatchServerTooOld`] if the server's
///   maximum is below the client's minimum.
///
/// # Panics
///
/// Panics if the chosen version lies outside either range. After the checks
/// above, that requires the server to report an inverted range
/// (`min_protocol_version > max_protocol_version`).
fn check_handshake(handshake: protos::HandshakeResponse) -> Result<u64, ConnectError> {
	if let Some(announcement) = handshake.psa.as_ref() {
		warn!("Message from Ark server: \"{}\"", announcement);
	}

	let server_min = handshake.min_protocol_version;
	let server_max = handshake.max_protocol_version;

	if server_min > MAX_PROTOCOL_VERSION {
		return Err(ConnectError::ProtocolVersionMismatchClientTooOld {
			client_max: MAX_PROTOCOL_VERSION,
			server_min,
		});
	}
	if server_max < MIN_PROTOCOL_VERSION {
		return Err(ConnectError::ProtocolVersionMismatchServerTooOld {
			client_min: MIN_PROTOCOL_VERSION,
			server_max,
		});
	}

	let pver = cmp::min(MAX_PROTOCOL_VERSION, server_max);
	assert!((MIN_PROTOCOL_VERSION..=MAX_PROTOCOL_VERSION).contains(&pver));
	assert!((handshake.min_protocol_version..=handshake.max_protocol_version).contains(&pver));
	Ok(pver)
}
#[cfg(test)]
mod tests {
	use super::CreateEndpointError;
	use super::NO_TRANSPORT_BACKEND_MESSAGE;

	/// The "no backend" error must tell users which features to enable.
	#[test]
	fn no_transport_backend_error_mentions_feature_selection() {
		let rendered = CreateEndpointError::NoTransportBackend.to_string();
		assert_eq!(rendered, NO_TRANSPORT_BACKEND_MESSAGE);
		for feature in ["bark-server-rpc/tonic-native", "bark-server-rpc/tonic-web"] {
			assert!(rendered.contains(feature), "missing mention of `{}`", feature);
		}
	}
}