use std::time::Duration;
use backoff::ExponentialBackoff;
use http::{uri::InvalidUri, Uri};
use hyper_proxy2::{Intercept, Proxy, ProxyConnector};
use hyper_socks2::{Auth, SocksConnector};
use hyper_util::client::legacy::connect::HttpConnector;
use tonic::{
body::BoxBody,
client::GrpcService,
transport::{Channel, ClientTlsConfig, Endpoint},
};
use tower::{Layer, ServiceBuilder};
use url::Url;
use qcs_api_client_common::{
backoff::{self, default_backoff},
configuration::{ClientConfiguration, LoadError, TokenError, TokenRefresher},
};
#[cfg(feature = "tracing")]
use qcs_api_client_common::tracing_configuration::TracingConfiguration;
#[cfg(feature = "tracing")]
use super::trace::{build_trace_layer, CustomTraceLayer, CustomTraceService};
use super::{Error, RefreshLayer, RefreshService, RetryLayer, RetryService};
/// Errors that can occur while creating a gRPC [`Channel`] in this module.
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum ChannelError {
    /// A URI failed to parse, e.g. the value of a proxy environment variable such as
    /// `HTTPS_PROXY`.
    #[error("Failed to parse URI: {0}")]
    InvalidUri(#[from] InvalidUri),
    /// A URL failed to parse, e.g. when a SOCKS proxy URI is re-parsed to extract its
    /// credentials.
    #[error("Failed to parse URL: {0}")]
    InvalidUrl(#[from] url::ParseError),
    /// A proxy URI's scheme was missing or not one of `socks5`, `http` or `https`.
    /// Holds the offending scheme, if there was one.
    #[error("Protocol is missing or not supported: {0:?}")]
    UnsupportedProtocol(Option<String>),
    /// An I/O error was raised while constructing the HTTP proxy connector.
    ///
    /// NOTE(review): any `std::io::Error` converted with `?` lands here, so the
    /// "ssl verification" wording of the message may be narrower than the real cause.
    #[error("HTTP proxy ssl verification failed: {0}")]
    SslFailure(#[from] std::io::Error),
    /// The HTTPS and HTTP proxy variables were both set to different URIs, and at least
    /// one of them is not an `http`/`https` proxy (so the two cannot be combined).
    #[error("Cannot set separate https and http proxies if one of them is socks5")]
    Mismatch {
        /// The URI read from the HTTPS proxy variable.
        https_proxy: Uri,
        /// The URI read from the HTTP proxy variable.
        http_proxy: Uri,
    },
}
/// Turns a stack of builder options into a gRPC service wrapping a channel of type `C`.
///
/// The option types in this module delegate to their inner options first and then wrap
/// the result, so an option added later ends up outside the ones added before it.
pub trait IntoService<C: GrpcService<BoxBody>> {
    /// The service produced once every configured layer has been applied.
    type Service: GrpcService<BoxBody>;
    /// Consumes the options and wraps `channel` in the configured layers.
    fn into_service(self, channel: C) -> Self::Service;
}
/// The empty option stack: no layers are added and the channel is returned unchanged.
impl<C> IntoService<C> for ()
where
    C: GrpcService<BoxBody>,
{
    type Service = C;

    fn into_service(self, channel: C) -> C {
        // Nothing to wrap; hand the channel straight back.
        channel
    }
}
/// Builder option that wraps the service produced by the inner options in a
/// [`RefreshLayer`].
#[derive(Clone, Debug)]
pub struct RefreshOptions<O, R>
where
    R: TokenRefresher + Clone + Send + Sync,
{
    /// Layer applied around the service built from `other`.
    layer: RefreshLayer<R>,
    /// Previously configured options; these are applied to the channel first.
    other: O,
}
impl<T> From<T> for RefreshOptions<(), T>
where
    T: TokenRefresher + Clone + Send + Sync,
{
    /// Wraps `refresher` in a [`RefreshLayer`] with no further options beneath it.
    fn from(refresher: T) -> Self {
        let layer = RefreshLayer::with_refresher(refresher);
        Self { layer, other: () }
    }
}
impl<C, T, O> IntoService<C> for RefreshOptions<O, T>
where
    C: GrpcService<BoxBody>,
    O: IntoService<C>,
    O::Service: GrpcService<BoxBody>,
    RefreshService<O::Service, T>: GrpcService<BoxBody>,
    T: TokenRefresher + Clone + Send + Sync + 'static,
{
    type Service = RefreshService<O::Service, T>;

    /// Builds the inner options' service first, then wraps it in the refresh layer.
    fn into_service(self, channel: C) -> Self::Service {
        let Self { layer, other } = self;
        layer.layer(other.into_service(channel))
    }
}
/// Builder option that wraps the service produced by the inner options in a
/// [`RetryLayer`].
#[derive(Clone, Debug)]
pub struct RetryOptions<O = ()> {
    /// Layer applied around the service built from `other`.
    layer: RetryLayer,
    /// Previously configured options; these are applied to the channel first.
    other: O,
}
impl From<ExponentialBackoff> for RetryOptions<()> {
    /// Wraps `backoff` in a [`RetryLayer`] with no further options beneath it.
    fn from(backoff: ExponentialBackoff) -> Self {
        let layer = RetryLayer { backoff };
        Self { layer, other: () }
    }
}
impl<C, O> IntoService<C> for RetryOptions<O>
where
    C: GrpcService<BoxBody>,
    O: IntoService<C>,
    O::Service: GrpcService<BoxBody>,
    RetryService<O::Service>: GrpcService<BoxBody>,
{
    type Service = RetryService<O::Service>;

    /// Builds the inner options' service first, then wraps it in the retry layer.
    fn into_service(self, channel: C) -> Self::Service {
        let Self { layer, other } = self;
        layer.layer(other.into_service(channel))
    }
}
/// Builder for a gRPC [`Channel`] wrapped in optional token-refresh and retry layers
/// (plus a trace layer when the `tracing` feature is enabled).
///
/// `O` is the stack of option layers added so far. It starts as `()` (no layers) and
/// grows with each `with_refresh_layer` / `with_retry_layer` call.
#[derive(Clone, Debug)]
pub struct ChannelBuilder<O = ()> {
    /// Endpoint the channel is created from.
    endpoint: Endpoint,
    /// Trace layer applied directly around the raw channel in `build`.
    #[cfg(feature = "tracing")]
    trace_layer: CustomTraceLayer,
    /// Option layers applied on top of the (possibly traced) channel in `build`.
    options: O,
}
impl From<Endpoint> for ChannelBuilder<()> {
    /// Starts a builder from a pre-configured [`Endpoint`] with no option layers.
    ///
    /// With the `tracing` feature enabled, a trace layer without any tracing
    /// configuration is derived from the endpoint's URI.
    fn from(endpoint: Endpoint) -> Self {
        // Computed before `endpoint` is moved into the struct below.
        #[cfg(feature = "tracing")]
        let trace_layer = build_trace_layer(endpoint.uri().to_string(), None);
        Self {
            endpoint,
            #[cfg(feature = "tracing")]
            trace_layer,
            options: (),
        }
    }
}
impl ChannelBuilder<()> {
    /// Starts a builder for `uri` using the default endpoint from [`get_endpoint`].
    ///
    /// This is equivalent to converting the result of [`get_endpoint`] with `From`, so
    /// both entry points share one construction path.
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as [`get_endpoint`].
    pub fn from_uri(uri: Uri) -> Self {
        Self::from(get_endpoint(uri))
    }
}
// The service type the builder's option stack wraps. With the `tracing` feature the raw
// `Channel` is first wrapped in the trace layer; otherwise the `Channel` is used as-is.
#[cfg(feature = "tracing")]
type TargetService = CustomTraceService;
#[cfg(not(feature = "tracing"))]
type TargetService = Channel;
impl<O> ChannelBuilder<O>
where
    O: IntoService<TargetService>,
{
    /// Applies `timeout` to the underlying [`Endpoint`] via `Endpoint::timeout`.
    #[must_use]
    pub fn with_timeout(mut self, timeout: Duration) -> Self {
        self.endpoint = self.endpoint.timeout(timeout);
        self
    }

    /// Adds `layer` on top of the options configured so far.
    ///
    /// The new layer wraps the service produced by every previously added option.
    #[must_use]
    pub fn with_refresh_layer<T>(
        self,
        layer: RefreshLayer<T>,
    ) -> ChannelBuilder<RefreshOptions<O, T>>
    where
        T: TokenRefresher + Clone + Send + Sync,
    {
        ChannelBuilder {
            endpoint: self.endpoint,
            #[cfg(feature = "tracing")]
            trace_layer: self.trace_layer,
            options: RefreshOptions {
                layer,
                other: self.options,
            },
        }
    }

    /// Adds a [`RefreshLayer`] built from `refresher` with `RefreshLayer::with_refresher`.
    #[must_use]
    pub fn with_token_refresher<T>(self, refresher: T) -> ChannelBuilder<RefreshOptions<O, T>>
    where
        T: TokenRefresher + Clone + Send + Sync,
    {
        self.with_refresh_layer(RefreshLayer::with_refresher(refresher))
    }

    /// Uses `config` as the token refresher.
    ///
    /// With the `tracing` feature enabled, the trace layer is also rebuilt from the
    /// configuration's tracing settings, replacing the one the builder started with.
    #[must_use]
    pub fn with_qcs_config(
        self,
        config: ClientConfiguration,
    ) -> ChannelBuilder<RefreshOptions<O, ClientConfiguration>> {
        #[cfg(feature = "tracing")]
        {
            // Build the trace layer first: `config` is moved into the refresher below.
            let base_url = self.endpoint.uri().to_string();
            let trace_layer = build_trace_layer(base_url, config.tracing_configuration());
            ChannelBuilder {
                trace_layer,
                ..self.with_token_refresher(config)
            }
        }
        #[cfg(not(feature = "tracing"))]
        {
            self.with_token_refresher(config)
        }
    }

    /// Loads a [`ClientConfiguration`] and applies it with [`Self::with_qcs_config`].
    ///
    /// `Some(profile)` loads that named profile; `None` loads the default configuration.
    ///
    /// # Errors
    ///
    /// Returns a [`LoadError`] if the configuration cannot be loaded.
    pub fn with_qcs_profile(
        self,
        profile: Option<String>,
    ) -> Result<ChannelBuilder<RefreshOptions<O, ClientConfiguration>>, LoadError> {
        let config = match profile {
            Some(profile) => ClientConfiguration::load_profile(profile)?,
            None => ClientConfiguration::load_default()?,
        };
        Ok(self.with_qcs_config(config))
    }

    /// Adds `layer` on top of the options configured so far.
    ///
    /// The new layer wraps the service produced by every previously added option.
    #[must_use]
    pub fn with_retry_layer(self, layer: RetryLayer) -> ChannelBuilder<RetryOptions<O>> {
        ChannelBuilder {
            endpoint: self.endpoint,
            #[cfg(feature = "tracing")]
            trace_layer: self.trace_layer,
            options: RetryOptions {
                layer,
                other: self.options,
            },
        }
    }

    /// Adds a [`RetryLayer`] that uses `backoff`.
    #[must_use]
    pub fn with_retry_backoff(
        self,
        backoff: ExponentialBackoff,
    ) -> ChannelBuilder<RetryOptions<O>> {
        self.with_retry_layer(RetryLayer { backoff })
    }

    /// Adds a [`RetryLayer`] that uses `default_backoff()`.
    #[must_use]
    pub fn with_default_retry(self) -> ChannelBuilder<RetryOptions<O>> {
        self.with_retry_backoff(default_backoff())
    }

    /// Creates the channel and wraps it in every configured layer.
    ///
    /// With the `tracing` feature, the raw channel is wrapped in the trace layer before
    /// the option layers are applied, so tracing sits innermost.
    ///
    /// # Errors
    ///
    /// Returns a [`ChannelError`] if the proxy settings read from the environment are
    /// invalid; see [`get_channel_with_endpoint`].
    pub fn build(self) -> Result<O::Service, ChannelError> {
        let channel = get_channel_with_endpoint(&self.endpoint)?;
        #[cfg(feature = "tracing")]
        let channel = self.trace_layer.layer(channel);
        Ok(self.options.into_service(channel))
    }
}
pub fn parse_uri(s: &str) -> Result<Uri, Error<TokenError>> {
s.parse().map_err(Error::from)
}
/// Builds an [`Endpoint`] for `uri` with the QCS client user agent and TLS enabled
/// using `ClientTlsConfig::with_enabled_roots`.
///
/// # Panics
///
/// Panics if `Endpoint::user_agent` rejects the built-in user-agent string or if
/// `Endpoint::tls_config` fails; both are treated as programming errors.
pub fn get_endpoint(uri: Uri) -> Endpoint {
    const USER_AGENT: &str = concat!("QCS gRPC Client (Rust)/", env!("CARGO_PKG_VERSION"));

    let tls = ClientTlsConfig::new().with_enabled_roots();
    Channel::builder(uri)
        .user_agent(USER_AGENT)
        .expect("user agent string should be valid")
        .tls_config(tls)
        .expect("tls setup should succeed")
}
/// Like [`get_endpoint`], additionally applying `timeout` when one is given.
///
/// # Panics
///
/// Panics under the same conditions as [`get_endpoint`].
pub fn get_endpoint_with_timeout(uri: Uri, timeout: Option<Duration>) -> Endpoint {
    let endpoint = get_endpoint(uri);
    match timeout {
        Some(duration) => endpoint.timeout(duration),
        None => endpoint,
    }
}
/// Reads the environment variable `key` (falling back to its lowercase form) and parses
/// it as a [`Uri`].
///
/// The variable named exactly `key` takes precedence. The lowercase one is consulted only
/// when that lookup fails (unset or not valid Unicode). A value that is empty or
/// whitespace-only is treated as unset. An empty assignment such as `HTTPS_PROXY=`
/// therefore disables the proxy instead of failing channel creation.
///
/// # Errors
///
/// Returns [`InvalidUri`] if the selected value is non-blank but is not a valid URI.
fn get_env_uri(key: &str) -> Result<Option<Uri>, InvalidUri> {
    std::env::var(key)
        .or_else(|_| std::env::var(key.to_lowercase()))
        .ok()
        // An empty string is not a valid `Uri` and would otherwise abort with an error.
        .filter(|value| !value.trim().is_empty())
        .map(Uri::try_from)
        .transpose()
}
/// Extracts SOCKS5 credentials from the userinfo part of `uri`, if any.
///
/// Returns `Ok(None)` when the URI has no username. A username without a password yields
/// an empty password. [`Url::username`] and [`Url::password`] return percent-encoded
/// text, so both parts are decoded before being handed to the SOCKS connector. Otherwise
/// a password written as `p%40ss` would be sent literally instead of as `p@ss`.
///
/// # Errors
///
/// Returns [`url::ParseError`] if `uri` cannot be re-parsed as a [`Url`].
fn get_uri_socks_auth(uri: &Uri) -> Result<Option<Auth>, url::ParseError> {
    let full_url = uri.to_string().parse::<Url>()?;
    let encoded_user = full_url.username();
    if encoded_user.is_empty() {
        return Ok(None);
    }
    let user = percent_decode_userinfo(encoded_user);
    let pass = percent_decode_userinfo(full_url.password().unwrap_or_default());
    Ok(Some(Auth::new(user.as_str(), pass.as_str())))
}

/// Decodes `%XX` escapes in a percent-encoded URL userinfo component.
///
/// A `%` that is not followed by two hex digits is kept verbatim. Decoded bytes that are
/// not valid UTF-8 are replaced with U+FFFD.
fn percent_decode_userinfo(encoded: &str) -> String {
    fn hex_value(byte: u8) -> Option<u8> {
        char::from(byte)
            .to_digit(16)
            .and_then(|digit| u8::try_from(digit).ok())
    }

    let bytes = encoded.as_bytes();
    let mut decoded = Vec::with_capacity(bytes.len());
    let mut index = 0;
    while index < bytes.len() {
        let byte = bytes[index];
        if byte == b'%' {
            let high = bytes.get(index + 1).copied().and_then(hex_value);
            let low = bytes.get(index + 2).copied().and_then(hex_value);
            if let (Some(high), Some(low)) = (high, low) {
                decoded.push((high << 4) | low);
                index += 3;
                continue;
            }
        }
        decoded.push(byte);
        index += 1;
    }
    String::from_utf8_lossy(&decoded).into_owned()
}
/// Creates a lazily-connecting [`Channel`] for `uri` using the default endpoint from
/// [`get_endpoint`] and the proxy handling of [`get_channel_with_endpoint`].
///
/// # Errors
///
/// Returns a [`ChannelError`] if the proxy environment configuration is invalid.
pub fn get_channel(uri: Uri) -> Result<Channel, ChannelError> {
    get_channel_with_endpoint(&get_endpoint(uri))
}
/// Like [`get_channel`], additionally applying `timeout` to the endpoint when given.
///
/// # Errors
///
/// Returns a [`ChannelError`] if the proxy environment configuration is invalid.
pub fn get_channel_with_timeout(
    uri: Uri,
    timeout: Option<Duration>,
) -> Result<Channel, ChannelError> {
    get_channel_with_endpoint(&get_endpoint_with_timeout(uri, timeout))
}
/// Creates a lazily-connecting [`Channel`] for `endpoint`. The channel is routed through
/// a proxy when the `HTTPS_PROXY` / `HTTP_PROXY` environment variables (or their
/// lowercase forms) are set.
///
/// Proxy selection:
/// - neither variable set: connect directly;
/// - only one set: that proxy intercepts its own traffic class (`Intercept::Https` or
///   `Intercept::Http`);
/// - both set to the same URI: that proxy intercepts all traffic;
/// - both set to different URIs: both must be `http`/`https` proxies, and they are
///   combined in one [`ProxyConnector`].
///
/// A `socks5://` proxy goes through a [`SocksConnector`], with credentials taken from the
/// URI's userinfo if present. `http://` and `https://` proxies go through a
/// [`ProxyConnector`].
///
/// # Errors
///
/// - [`ChannelError::InvalidUri`] if a proxy variable is not a valid URI;
/// - [`ChannelError::InvalidUrl`] if a SOCKS proxy URI cannot be re-parsed as a URL;
/// - [`ChannelError::UnsupportedProtocol`] if a proxy scheme is missing or is not
///   `socks5`/`http`/`https` (single-proxy or same-proxy-for-both cases);
/// - [`ChannelError::SslFailure`] if creating a [`ProxyConnector`] fails;
/// - [`ChannelError::Mismatch`] if two different proxies are set and either is not
///   `http`/`https`.
// `https_proxy` / `http_proxy` are deliberately parallel names.
#[allow(clippy::similar_names)]
pub fn get_channel_with_endpoint(endpoint: &Endpoint) -> Result<Channel, ChannelError> {
    let https_proxy = get_env_uri("HTTPS_PROXY")?;
    let http_proxy = get_env_uri("HTTP_PROXY")?;
    let mut connector = HttpConnector::new();
    // Let non-`http` URIs (e.g. `https://`) through the plain TCP connector; any TLS is
    // expected to be layered on top of it.
    connector.enforce_http(false);
    // Builds a channel that reaches the endpoint through the single proxy `uri`, which
    // handles the traffic selected by `intercept`.
    let connect_to = |uri: http::Uri, intercept: Intercept| {
        let connector = connector.clone();
        match uri.scheme_str() {
            Some("socks5") => {
                let socks_connector = SocksConnector {
                    auth: get_uri_socks_auth(&uri)?,
                    proxy_addr: uri,
                    connector,
                };
                Ok(endpoint.connect_with_connector_lazy(socks_connector))
            }
            Some("https" | "http") => {
                let is_http = uri.scheme() == Some(&http::uri::Scheme::HTTP);
                let proxy = Proxy::new(intercept, uri);
                let mut proxy_connector = ProxyConnector::from_proxy(connector, proxy)?;
                // For an `http://` proxy, clear the proxy connector's own TLS setting.
                // NOTE(review): presumably this avoids a second TLS layer under the
                // endpoint's TLS config; confirm against hyper-proxy2 / tonic behavior.
                if is_http {
                    proxy_connector.set_tls(None);
                }
                Ok(endpoint.connect_with_connector_lazy(proxy_connector))
            }
            scheme => Err(ChannelError::UnsupportedProtocol(scheme.map(String::from))),
        }
    };
    let channel = match (https_proxy, http_proxy) {
        (None, None) => endpoint.connect_lazy(),
        (Some(https_proxy), None) => connect_to(https_proxy, Intercept::Https)?,
        (None, Some(http_proxy)) => connect_to(http_proxy, Intercept::Http)?,
        (Some(https_proxy), Some(http_proxy)) => {
            if https_proxy == http_proxy {
                // One proxy handles all traffic.
                connect_to(https_proxy, Intercept::All)?
            } else {
                // Two distinct proxies are combined in one `ProxyConnector`; this path
                // only accepts `http`/`https` proxies, anything else is a mismatch.
                let accepted = [https_proxy.scheme_str(), http_proxy.scheme_str()]
                    .into_iter()
                    .all(|scheme| matches!(scheme, Some("https" | "http")));
                if accepted {
                    // NOTE(review): unlike the single-proxy path above, the connector's
                    // TLS setting is not cleared here for `http://` proxies — confirm
                    // this asymmetry is intended.
                    let mut proxy_connector = ProxyConnector::new(connector)?;
                    proxy_connector.extend_proxies(vec![
                        Proxy::new(Intercept::Https, https_proxy),
                        Proxy::new(Intercept::Http, http_proxy),
                    ]);
                    endpoint.connect_with_connector_lazy(proxy_connector)
                } else {
                    return Err(ChannelError::Mismatch {
                        https_proxy,
                        http_proxy,
                    });
                }
            }
        }
    };
    Ok(channel)
}
/// Creates a channel for `uri` (see [`get_channel`]) and wraps it with the default
/// [`ClientConfiguration`] (see [`wrap_channel`]).
///
/// # Errors
///
/// Returns an error if the channel cannot be created or the default configuration
/// cannot be loaded.
pub fn get_wrapped_channel(
    uri: Uri,
) -> Result<RefreshService<Channel, ClientConfiguration>, Error<TokenError>> {
    let channel = get_channel(uri)?;
    wrap_channel(channel)
}
/// Wraps `channel` in a [`RefreshLayer`] built from `config` with
/// `RefreshLayer::with_config`.
#[must_use]
pub fn wrap_channel_with<C>(
    channel: C,
    config: ClientConfiguration,
) -> RefreshService<C, ClientConfiguration>
where
    C: GrpcService<BoxBody>,
{
    let refresh = RefreshLayer::with_config(config);
    refresh.layer(channel)
}
/// Wraps `channel` in a [`RefreshLayer`] driven by `token_refresher`.
///
/// Marked `#[must_use]` like [`wrap_channel_with`]: the wrapped service is the only
/// output, so discarding it is always a mistake.
#[must_use]
pub fn wrap_channel_with_token_refresher<C, T>(
    channel: C,
    token_refresher: T,
) -> RefreshService<C, T>
where
    C: GrpcService<BoxBody>,
    T: TokenRefresher + Clone + Send + Sync,
{
    ServiceBuilder::new()
        .layer(RefreshLayer::with_refresher(token_refresher))
        .service(channel)
}
/// Wraps `channel` using the default [`ClientConfiguration`] (see [`wrap_channel_with`]).
///
/// # Errors
///
/// Returns an error if the default configuration cannot be loaded.
pub fn wrap_channel<C>(
    channel: C,
) -> Result<RefreshService<C, ClientConfiguration>, Error<TokenError>>
where
    C: GrpcService<BoxBody>,
{
    let config = ClientConfiguration::load_default()?;
    Ok(wrap_channel_with(channel, config))
}
/// Wraps `channel` using the named configuration `profile` (see [`wrap_channel_with`]).
///
/// # Errors
///
/// Returns an error if the profile cannot be loaded.
pub fn wrap_channel_with_profile<C>(
    channel: C,
    profile: String,
) -> Result<RefreshService<C, ClientConfiguration>, Error<TokenError>>
where
    C: GrpcService<BoxBody>,
{
    let config = ClientConfiguration::load_profile(profile)?;
    Ok(wrap_channel_with(channel, config))
}
/// Wraps `channel` in a [`RetryLayer`] using `RetryLayer::default()`.
///
/// Marked `#[must_use]` like [`wrap_channel_with`]: the wrapped service is the only
/// output, so discarding it is always a mistake.
#[must_use]
pub fn wrap_channel_with_retry<C>(channel: C) -> RetryService<C>
where
    C: GrpcService<BoxBody>,
{
    ServiceBuilder::new()
        .layer(RetryLayer::default())
        .service(channel)
}
/// Wraps `channel` in a trace layer built from `base_url` and `configuration`.
///
/// Marked `#[must_use]` like [`wrap_channel_with`]: the wrapped service is the only
/// output, so discarding it is always a mistake.
#[cfg(feature = "tracing")]
#[must_use]
pub fn wrap_channel_with_tracing(
    channel: Channel,
    base_url: String,
    configuration: TracingConfiguration,
) -> CustomTraceService {
    ServiceBuilder::new()
        .layer(build_trace_layer(base_url, Some(&configuration)))
        .service(channel)
}