1use std::time::Duration;
5
6use backoff::ExponentialBackoff;
7use hyper_socks2::{Auth, SocksConnector};
8use hyper_util::client::legacy::connect::HttpConnector;
9use qcs_dependencies_client::http::{Uri, uri::InvalidUri};
10use qcs_dependencies_client::tonic::{
11 body::Body,
12 client::GrpcService,
13 transport::{Channel, ClientTlsConfig, Endpoint},
14};
15use qcs_dependencies_client::tower::{Layer, ServiceBuilder};
16use url::Url;
17
18use qcs_api_client_common::{
19 backoff::{self, default_backoff},
20 configuration::{ClientConfiguration, LoadError, TokenError, tokens::TokenRefresher},
21};
22
23#[cfg(feature = "tracing")]
24use qcs_api_client_common::tracing_configuration::TracingConfiguration;
25
26use rigetti_hyper_proxy::{Intercept, Proxy, ProxyConnector};
27
28#[cfg(feature = "tracing")]
29use super::trace::{CustomTraceLayer, CustomTraceService, build_trace_layer};
30use super::{Error, RefreshLayer, RefreshService, RetryLayer, RetryService};
31
32#[derive(Debug, thiserror::Error)]
34#[non_exhaustive]
35pub enum ChannelError {
36 #[error("Failed to parse URI: {0}")]
38 InvalidUri(#[from] InvalidUri),
39 #[error("Failed to parse URL: {0}")]
41 InvalidUrl(#[from] url::ParseError),
42 #[error("Protocol is missing or not supported: {0:?}")]
44 UnsupportedProtocol(Option<String>),
45 #[error("HTTP proxy ssl verification failed: {0}")]
47 SslFailure(#[from] std::io::Error),
48 #[error("Cannot set separate https and http proxies if one of them is socks5")]
50 Mismatch {
51 https_proxy: Uri,
53 http_proxy: Uri,
55 },
56}
57
58pub trait IntoService<C: GrpcService<Body>> {
60 type Service: GrpcService<Body>;
62
63 fn into_service(self, channel: C) -> Self::Service;
65}
66
67impl<C> IntoService<C> for ()
68where
69 C: GrpcService<Body>,
70{
71 type Service = C;
72 fn into_service(self, channel: C) -> Self::Service {
73 channel
74 }
75}
76
77#[derive(Clone, Debug)]
79pub struct RefreshOptions<O, R>
80where
81 R: TokenRefresher + Clone + Send + Sync,
82{
83 layer: RefreshLayer<R>,
84 other: O,
85}
86
87impl<T> From<T> for RefreshOptions<(), T>
88where
89 T: TokenRefresher + Clone + Send + Sync,
90{
91 fn from(refresher: T) -> Self {
92 Self {
93 layer: RefreshLayer::with_refresher(refresher),
94 other: (),
95 }
96 }
97}
98
99impl<C, T, O> IntoService<C> for RefreshOptions<O, T>
100where
101 C: GrpcService<Body>,
102 O: IntoService<C>,
103 O::Service: GrpcService<Body>,
104 RefreshService<O::Service, T>: GrpcService<Body>,
105 T: TokenRefresher + Clone + Send + Sync + 'static,
106{
107 type Service = RefreshService<O::Service, T>;
108 fn into_service(self, channel: C) -> Self::Service {
109 let service = self.other.into_service(channel);
110 self.layer.layer(service)
111 }
112}
113
114#[derive(Clone, Debug)]
116pub struct RetryOptions<O = ()> {
117 layer: RetryLayer,
118 other: O,
119}
120
121impl From<ExponentialBackoff> for RetryOptions<()> {
122 fn from(backoff: ExponentialBackoff) -> Self {
123 Self {
124 layer: RetryLayer { backoff },
125 other: (),
126 }
127 }
128}
129
130impl<C, O> IntoService<C> for RetryOptions<O>
131where
132 C: GrpcService<Body>,
133 O: IntoService<C>,
134 O::Service: GrpcService<Body>,
135 RetryService<O::Service>: GrpcService<Body>,
136{
137 type Service = RetryService<O::Service>;
138 fn into_service(self, channel: C) -> Self::Service {
139 let service = self.other.into_service(channel);
140 self.layer.layer(service)
141 }
142}
143
144#[derive(Clone, Debug)]
146pub struct ChannelBuilder<O = ()> {
147 endpoint: Endpoint,
148 #[cfg(feature = "tracing")]
149 trace_layer: CustomTraceLayer,
150 options: O,
151}
152
153impl From<Endpoint> for ChannelBuilder<()> {
154 fn from(endpoint: Endpoint) -> Self {
155 #[cfg(feature = "tracing")]
156 {
157 let base_url = endpoint.uri().to_string();
158 Self {
159 endpoint,
160 trace_layer: build_trace_layer(base_url, None),
161 options: (),
162 }
163 }
164
165 #[cfg(not(feature = "tracing"))]
166 return Self {
167 endpoint,
168 options: (),
169 };
170 }
171}
172
173impl ChannelBuilder<()> {
174 pub fn from_uri(uri: Uri) -> Self {
176 #[cfg(feature = "tracing")]
177 {
178 let base_url = uri.to_string();
179 Self {
180 endpoint: get_endpoint(uri),
181 trace_layer: build_trace_layer(base_url, None),
182 options: (),
183 }
184 }
185
186 #[cfg(not(feature = "tracing"))]
187 return Self {
188 endpoint: get_endpoint(uri),
189 options: (),
190 };
191 }
192}
193
194#[cfg(feature = "tracing")]
195type TargetService = CustomTraceService;
196#[cfg(not(feature = "tracing"))]
197type TargetService = Channel;
198
199impl<O> ChannelBuilder<O>
200where
201 O: IntoService<TargetService>,
202{
203 #[must_use]
205 pub fn with_timeout(mut self, timeout: Duration) -> Self {
206 self.endpoint = self.endpoint.timeout(timeout);
207 self
208 }
209
210 pub fn with_refresh_layer<T>(
212 self,
213 layer: RefreshLayer<T>,
214 ) -> ChannelBuilder<RefreshOptions<O, T>>
215 where
216 T: TokenRefresher + Clone + Send + Sync,
217 {
218 #[cfg(feature = "tracing")]
219 return ChannelBuilder {
220 endpoint: self.endpoint,
221 trace_layer: self.trace_layer,
222 options: RefreshOptions {
223 layer,
224 other: self.options,
225 },
226 };
227 #[cfg(not(feature = "tracing"))]
228 return ChannelBuilder {
229 endpoint: self.endpoint,
230 options: RefreshOptions {
231 layer,
232 other: self.options,
233 },
234 };
235 }
236
237 pub fn with_token_refresher<T>(self, refresher: T) -> ChannelBuilder<RefreshOptions<O, T>>
239 where
240 T: TokenRefresher + Clone + Send + Sync,
241 {
242 self.with_refresh_layer(RefreshLayer::with_refresher(refresher))
243 }
244
245 pub fn with_qcs_config(
247 self,
248 config: ClientConfiguration,
249 ) -> ChannelBuilder<RefreshOptions<O, ClientConfiguration>> {
250 #[cfg(feature = "tracing")]
251 {
252 let base_url = self.endpoint.uri().to_string();
253 let trace_layer = build_trace_layer(base_url, config.tracing_configuration());
254 let mut builder = self.with_token_refresher(config);
255 builder.trace_layer = trace_layer;
256 builder
257 }
258 #[cfg(not(feature = "tracing"))]
259 {
260 self.with_token_refresher(config)
261 }
262 }
263
264 pub fn with_qcs_profile(
270 self,
271 profile: Option<String>,
272 ) -> Result<ChannelBuilder<RefreshOptions<O, ClientConfiguration>>, LoadError> {
273 let config = match profile {
274 Some(profile) => ClientConfiguration::load_profile(profile)?,
275 None => ClientConfiguration::load_default()?,
276 };
277
278 Ok(self.with_qcs_config(config))
279 }
280
281 pub fn with_retry_layer(self, layer: RetryLayer) -> ChannelBuilder<RetryOptions<O>> {
283 #[cfg(feature = "tracing")]
284 return ChannelBuilder {
285 endpoint: self.endpoint,
286 trace_layer: self.trace_layer,
287 options: RetryOptions {
288 layer,
289 other: self.options,
290 },
291 };
292 #[cfg(not(feature = "tracing"))]
293 return ChannelBuilder {
294 endpoint: self.endpoint,
295 options: RetryOptions {
296 layer,
297 other: self.options,
298 },
299 };
300 }
301
302 pub fn with_retry_backoff(
304 self,
305 backoff: ExponentialBackoff,
306 ) -> ChannelBuilder<RetryOptions<O>> {
307 self.with_retry_layer(RetryLayer { backoff })
308 }
309
310 pub fn with_default_retry(self) -> ChannelBuilder<RetryOptions<O>> {
312 self.with_retry_backoff(default_backoff())
313 }
314
315 #[allow(clippy::result_large_err)]
321 pub fn build(self) -> Result<O::Service, ChannelError> {
322 let channel = get_channel_with_endpoint(&self.endpoint)?;
323 #[cfg(feature = "tracing")]
324 {
325 let traced_channel = self.trace_layer.layer(channel);
326 Ok(self.options.into_service(traced_channel))
327 }
328
329 #[cfg(not(feature = "tracing"))]
330 Ok(self.options.into_service(channel))
331 }
332}
333
334#[allow(clippy::result_large_err)]
342pub fn parse_uri(s: &str) -> Result<Uri, Error<TokenError>> {
343 s.parse().map_err(Error::from)
344}
345
346#[allow(clippy::missing_panics_doc)]
348pub fn get_endpoint(uri: Uri) -> Endpoint {
349 Channel::builder(uri)
350 .user_agent(concat!(
351 "QCS gRPC Client (Rust)/",
352 env!("CARGO_PKG_VERSION")
353 ))
354 .expect("user agent string should be valid")
355 .http2_adaptive_window(true)
356 .tls_config(ClientTlsConfig::new().with_enabled_roots())
357 .expect("tls setup should succeed")
358}
359
360pub fn get_endpoint_with_timeout(uri: Uri, timeout: Option<Duration>) -> Endpoint {
362 if let Some(duration) = timeout {
363 get_endpoint(uri).timeout(duration)
364 } else {
365 get_endpoint(uri)
366 }
367}
368
369fn get_env_uri(key: &str) -> Result<Option<Uri>, InvalidUri> {
372 std::env::var(key)
373 .or_else(|_| std::env::var(key.to_lowercase()))
374 .ok()
375 .map(Uri::try_from)
376 .transpose()
377}
378
379fn get_uri_socks_auth(uri: &Uri) -> Result<Option<Auth>, url::ParseError> {
381 let full_url = uri.to_string().parse::<Url>()?;
382 let user = full_url.username();
383 let auth = if user.is_empty() {
384 None
385 } else {
386 let pass = full_url.password().unwrap_or_default();
387 Some(Auth::new(user, pass))
388 };
389 Ok(auth)
390}
391
392#[allow(clippy::result_large_err)]
408pub fn get_channel(uri: Uri) -> Result<Channel, ChannelError> {
409 let endpoint = get_endpoint(uri);
410 get_channel_with_endpoint(&endpoint)
411}
412
413#[allow(clippy::result_large_err)]
430pub fn get_channel_with_timeout(
431 uri: Uri,
432 timeout: Option<Duration>,
433) -> Result<Channel, ChannelError> {
434 let endpoint = get_endpoint_with_timeout(uri, timeout);
435 get_channel_with_endpoint(&endpoint)
436}
437
/// Create a lazily-connecting [`Channel`] for `endpoint`, routing it through
/// proxies configured in the `HTTPS_PROXY` / `HTTP_PROXY` environment
/// variables (or their lowercase forms, via `get_env_uri`).
///
/// - No proxy variables: connect directly.
/// - One variable set: proxy only that traffic class (`Intercept::Https` or
///   `Intercept::Http`) through it.
/// - Both set to the same URI: proxy all traffic (`Intercept::All`).
/// - Both set to different URIs: only allowed when both use `http`/`https`;
///   otherwise [`ChannelError::Mismatch`] is returned.
///
/// Supported proxy schemes are `socks5`, `http`, and `https`.
///
/// # Errors
///
/// - [`ChannelError::InvalidUri`] if a proxy variable is not a valid URI.
/// - [`ChannelError::InvalidUrl`] if SOCKS credentials cannot be extracted.
/// - [`ChannelError::UnsupportedProtocol`] for a missing or unknown proxy scheme.
/// - [`ChannelError::SslFailure`] if the HTTP proxy connector cannot be created.
/// - [`ChannelError::Mismatch`] for distinct proxies that are not both HTTP(S).
#[allow(
    clippy::similar_names,
    reason = "http(s)_proxy are similar but precise in this case"
)]
#[allow(clippy::result_large_err)]
pub fn get_channel_with_endpoint(endpoint: &Endpoint) -> Result<Channel, ChannelError> {
    let https_proxy = get_env_uri("HTTPS_PROXY")?;
    let http_proxy = get_env_uri("HTTP_PROXY")?;

    let mut connector = HttpConnector::new();
    // Allow non-`http` URIs to reach this connector (it sits underneath the
    // proxy connectors).
    connector.enforce_http(false);

    // Build a channel that tunnels through the single proxy `uri`, intercepting
    // the traffic selected by `intercept`. Errors from `?` convert into
    // `ChannelError` via its `#[from]` impls (url::ParseError -> InvalidUrl,
    // io::Error -> SslFailure).
    let connect_to = |uri: qcs_dependencies_client::http::Uri, intercept: Intercept| {
        let connector = connector.clone();
        match uri.scheme_str() {
            Some("socks5") => {
                // SOCKS credentials, if any, come from the URI's userinfo.
                let socks_connector = SocksConnector {
                    auth: get_uri_socks_auth(&uri)?,
                    proxy_addr: uri,
                    connector,
                };
                Ok(endpoint.connect_with_connector_lazy(socks_connector))
            }
            Some("https" | "http") => {
                let is_http =
                    uri.scheme() == Some(&qcs_dependencies_client::http::uri::Scheme::HTTP);
                let proxy = Proxy::new(intercept, uri);
                let mut proxy_connector = ProxyConnector::from_proxy(connector, proxy)?;
                // For a plain `http://` proxy, clear the connector's TLS setting
                // (see `ProxyConnector::set_tls`).
                if is_http {
                    proxy_connector.set_tls(None);
                }
                Ok(endpoint.connect_with_connector_lazy(proxy_connector))
            }
            scheme => Err(ChannelError::UnsupportedProtocol(scheme.map(String::from))),
        }
    };

    let channel = match (https_proxy, http_proxy) {
        (None, None) => endpoint.connect_lazy(),

        (Some(https_proxy), None) => connect_to(https_proxy, Intercept::Https)?,
        (None, Some(http_proxy)) => connect_to(http_proxy, Intercept::Http)?,

        (Some(https_proxy), Some(http_proxy)) => {
            // Identical proxies: one connector handles all traffic, whatever its scheme.
            if https_proxy == http_proxy {
                connect_to(https_proxy, Intercept::All)?
            } else {
                // Distinct proxies can only be combined in one `ProxyConnector`,
                // so both must be HTTP(S) proxies.
                let accepted = [https_proxy.scheme_str(), http_proxy.scheme_str()]
                    .into_iter()
                    .all(|scheme| matches!(scheme, Some("https" | "http")));
                if accepted {
                    // NOTE(review): unlike the single-proxy path above, `set_tls(None)`
                    // is never applied here even if one proxy is `http://` — confirm
                    // this asymmetry is intended.
                    let mut proxy_connector = ProxyConnector::new(connector)?;
                    proxy_connector.extend_proxies(vec![
                        Proxy::new(Intercept::Https, https_proxy),
                        Proxy::new(Intercept::Http, http_proxy),
                    ]);
                    endpoint.connect_with_connector_lazy(proxy_connector)
                } else {
                    return Err(ChannelError::Mismatch {
                        https_proxy,
                        http_proxy,
                    });
                }
            }
        }
    };

    Ok(channel)
}
529
530#[allow(clippy::result_large_err)]
536pub fn get_wrapped_channel(
537 uri: Uri,
538) -> Result<RefreshService<Channel, ClientConfiguration>, Error<TokenError>> {
539 wrap_channel(get_channel(uri)?)
540}
541
542#[must_use]
544pub fn wrap_channel_with<C>(
545 channel: C,
546 config: ClientConfiguration,
547) -> RefreshService<C, ClientConfiguration>
548where
549 C: GrpcService<Body>,
550{
551 ServiceBuilder::new()
552 .layer(RefreshLayer::with_config(config))
553 .service(channel)
554}
555
556pub fn wrap_channel_with_token_refresher<C, T>(
560 channel: C,
561 token_refresher: T,
562) -> RefreshService<C, T>
563where
564 C: GrpcService<Body>,
565 T: TokenRefresher + Clone + Send + Sync,
566{
567 ServiceBuilder::new()
568 .layer(RefreshLayer::with_refresher(token_refresher))
569 .service(channel)
570}
571
572#[allow(clippy::result_large_err)]
578pub fn wrap_channel<C>(
579 channel: C,
580) -> Result<RefreshService<C, ClientConfiguration>, Error<TokenError>>
581where
582 C: GrpcService<Body>,
583{
584 Ok(wrap_channel_with(channel, {
585 ClientConfiguration::load_default()?
586 }))
587}
588
589#[allow(clippy::result_large_err)]
595pub fn wrap_channel_with_profile<C>(
596 channel: C,
597 profile: String,
598) -> Result<RefreshService<C, ClientConfiguration>, Error<TokenError>>
599where
600 C: GrpcService<Body>,
601{
602 Ok(wrap_channel_with(
603 channel,
604 ClientConfiguration::load_profile(profile)?,
605 ))
606}
607
608pub fn wrap_channel_with_retry<C>(channel: C) -> RetryService<C>
610where
611 C: GrpcService<Body>,
612{
613 ServiceBuilder::new()
614 .layer(RetryLayer::default())
615 .service(channel)
616}
617
/// Wrap `channel` with the crate's trace layer, built from `base_url` and
/// the given tracing `configuration`.
#[cfg(feature = "tracing")]
pub fn wrap_channel_with_tracing(
    channel: Channel,
    base_url: String,
    configuration: TracingConfiguration,
) -> CustomTraceService {
    let tracing = build_trace_layer(base_url, Some(&configuration));
    tracing.layer(channel)
}