1use std::time::Duration;
5
6use backoff::ExponentialBackoff;
7use http::{uri::InvalidUri, Uri};
8use hyper_socks2::{Auth, SocksConnector};
9use hyper_util::client::legacy::connect::HttpConnector;
10use tonic::{
11 body::Body,
12 client::GrpcService,
13 transport::{Channel, ClientTlsConfig, Endpoint},
14};
15use tower::{Layer, ServiceBuilder};
16use url::Url;
17
18use qcs_api_client_common::{
19 backoff::{self, default_backoff},
20 configuration::{tokens::TokenRefresher, ClientConfiguration, LoadError, TokenError},
21};
22
23#[cfg(feature = "tracing")]
24use qcs_api_client_common::tracing_configuration::TracingConfiguration;
25
26use rigetti_hyper_proxy::{Intercept, Proxy, ProxyConnector};
27
28#[cfg(feature = "tracing")]
29use super::trace::{build_trace_layer, CustomTraceLayer, CustomTraceService};
30use super::{Error, RefreshLayer, RefreshService, RetryLayer, RetryService};
31
/// Errors that can occur while building a gRPC [`Channel`] in this module.
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum ChannelError {
    /// A URI could not be parsed, e.g. the value of a proxy environment variable.
    #[error("Failed to parse URI: {0}")]
    InvalidUri(#[from] InvalidUri),
    /// A URI could not be re-parsed as a URL, e.g. while extracting SOCKS5 credentials.
    #[error("Failed to parse URL: {0}")]
    InvalidUrl(#[from] url::ParseError),
    /// A proxy URI has no scheme, or a scheme other than `socks5`, `http` or `https`.
    #[error("Protocol is missing or not supported: {0:?}")]
    UnsupportedProtocol(Option<String>),
    /// An I/O error raised while constructing an HTTP(S) proxy connector
    /// (presumably during its TLS setup — NOTE(review): confirm against `rigetti_hyper_proxy`).
    #[error("HTTP proxy ssl verification failed: {0}")]
    SslFailure(#[from] std::io::Error),
    /// Different HTTPS and HTTP proxies were configured and at least one of them is not an
    /// `http`/`https` proxy (a `socks5` proxy, or any other scheme).
    #[error("Cannot set separate https and http proxies if one of them is socks5")]
    Mismatch {
        /// The proxy configured for HTTPS traffic (`HTTPS_PROXY`).
        https_proxy: Uri,
        /// The proxy configured for HTTP traffic (`HTTP_PROXY`).
        http_proxy: Uri,
    },
}
57
/// Converts a set of channel options into a concrete gRPC service wrapping a channel.
///
/// Implementations compose: `()` returns the channel unchanged, while option types such as
/// `RefreshOptions` and `RetryOptions` first convert their inner options and then apply
/// their own layer on top of the result.
pub trait IntoService<C: GrpcService<Body>> {
    /// The service produced by wrapping a channel of type `C`.
    type Service: GrpcService<Body>;

    /// Wraps `channel` according to these options.
    fn into_service(self, channel: C) -> Self::Service;
}
66
67impl<C> IntoService<C> for ()
68where
69 C: GrpcService<Body>,
70{
71 type Service = C;
72 fn into_service(self, channel: C) -> Self::Service {
73 channel
74 }
75}
76
/// Channel options that wrap the service built from `other` in a token-refresh layer.
#[derive(Clone, Debug)]
pub struct RefreshOptions<O, R>
where
    R: TokenRefresher + Clone + Send + Sync,
{
    /// Layer applied on top of the service produced by `other`.
    layer: RefreshLayer<R>,
    /// Inner options; converted into a service before `layer` is applied.
    other: O,
}
86
87impl<T> From<T> for RefreshOptions<(), T>
88where
89 T: TokenRefresher + Clone + Send + Sync,
90{
91 fn from(refresher: T) -> Self {
92 Self {
93 layer: RefreshLayer::with_refresher(refresher),
94 other: (),
95 }
96 }
97}
98
99impl<C, T, O> IntoService<C> for RefreshOptions<O, T>
100where
101 C: GrpcService<Body>,
102 O: IntoService<C>,
103 O::Service: GrpcService<Body>,
104 RefreshService<O::Service, T>: GrpcService<Body>,
105 T: TokenRefresher + Clone + Send + Sync + 'static,
106{
107 type Service = RefreshService<O::Service, T>;
108 fn into_service(self, channel: C) -> Self::Service {
109 let service = self.other.into_service(channel);
110 self.layer.layer(service)
111 }
112}
113
/// Channel options that wrap the service built from `other` in a [`RetryLayer`].
#[derive(Clone, Debug)]
pub struct RetryOptions<O = ()> {
    /// Layer applied on top of the service produced by `other`.
    layer: RetryLayer,
    /// Inner options; converted into a service before `layer` is applied.
    other: O,
}
120
121impl From<ExponentialBackoff> for RetryOptions<()> {
122 fn from(backoff: ExponentialBackoff) -> Self {
123 Self {
124 layer: RetryLayer { backoff },
125 other: (),
126 }
127 }
128}
129
130impl<C, O> IntoService<C> for RetryOptions<O>
131where
132 C: GrpcService<Body>,
133 O: IntoService<C>,
134 O::Service: GrpcService<Body>,
135 RetryService<O::Service>: GrpcService<Body>,
136{
137 type Service = RetryService<O::Service>;
138 fn into_service(self, channel: C) -> Self::Service {
139 let service = self.other.into_service(channel);
140 self.layer.layer(service)
141 }
142}
143
/// Builder for a gRPC channel wrapped in optional layers (token refresh, retry and, with the
/// `tracing` feature, tracing).
///
/// The type parameter `O` records the layer options configured so far; see [`IntoService`].
#[derive(Clone, Debug)]
pub struct ChannelBuilder<O = ()> {
    /// Endpoint the channel will connect to.
    endpoint: Endpoint,
    /// Trace layer applied directly around the channel when it is built.
    #[cfg(feature = "tracing")]
    trace_layer: CustomTraceLayer,
    /// Layer options applied on top of the (possibly traced) channel when it is built.
    options: O,
}
152
153impl From<Endpoint> for ChannelBuilder<()> {
154 fn from(endpoint: Endpoint) -> Self {
155 #[cfg(feature = "tracing")]
156 {
157 let base_url = endpoint.uri().to_string();
158 Self {
159 endpoint,
160 trace_layer: build_trace_layer(base_url, None),
161 options: (),
162 }
163 }
164
165 #[cfg(not(feature = "tracing"))]
166 return Self {
167 endpoint,
168 options: (),
169 };
170 }
171}
172
173impl ChannelBuilder<()> {
174 pub fn from_uri(uri: Uri) -> Self {
176 #[cfg(feature = "tracing")]
177 {
178 let base_url = uri.to_string();
179 Self {
180 endpoint: get_endpoint(uri),
181 trace_layer: build_trace_layer(base_url, None),
182 options: (),
183 }
184 }
185
186 #[cfg(not(feature = "tracing"))]
187 return Self {
188 endpoint: get_endpoint(uri),
189 options: (),
190 };
191 }
192}
193
/// Innermost service passed to [`IntoService::into_service`] by `ChannelBuilder::build`:
/// the trace-wrapped channel when the `tracing` feature is enabled.
#[cfg(feature = "tracing")]
type TargetService = CustomTraceService;
/// Innermost service passed to [`IntoService::into_service`] by `ChannelBuilder::build`:
/// the raw channel when the `tracing` feature is disabled.
#[cfg(not(feature = "tracing"))]
type TargetService = Channel;
198
199impl<O> ChannelBuilder<O>
200where
201 O: IntoService<TargetService>,
202{
203 #[must_use]
205 pub fn with_timeout(mut self, timeout: Duration) -> Self {
206 self.endpoint = self.endpoint.timeout(timeout);
207 self
208 }
209
210 pub fn with_refresh_layer<T>(
212 self,
213 layer: RefreshLayer<T>,
214 ) -> ChannelBuilder<RefreshOptions<O, T>>
215 where
216 T: TokenRefresher + Clone + Send + Sync,
217 {
218 #[cfg(feature = "tracing")]
219 return ChannelBuilder {
220 endpoint: self.endpoint,
221 trace_layer: self.trace_layer,
222 options: RefreshOptions {
223 layer,
224 other: self.options,
225 },
226 };
227 #[cfg(not(feature = "tracing"))]
228 return ChannelBuilder {
229 endpoint: self.endpoint,
230 options: RefreshOptions {
231 layer,
232 other: self.options,
233 },
234 };
235 }
236
237 pub fn with_token_refresher<T>(self, refresher: T) -> ChannelBuilder<RefreshOptions<O, T>>
239 where
240 T: TokenRefresher + Clone + Send + Sync,
241 {
242 self.with_refresh_layer(RefreshLayer::with_refresher(refresher))
243 }
244
245 pub fn with_qcs_config(
247 self,
248 config: ClientConfiguration,
249 ) -> ChannelBuilder<RefreshOptions<O, ClientConfiguration>> {
250 #[cfg(feature = "tracing")]
251 {
252 let base_url = self.endpoint.uri().to_string();
253 let trace_layer = build_trace_layer(base_url, config.tracing_configuration());
254 let mut builder = self.with_token_refresher(config);
255 builder.trace_layer = trace_layer;
256 builder
257 }
258 #[cfg(not(feature = "tracing"))]
259 {
260 self.with_token_refresher(config)
261 }
262 }
263
264 pub fn with_qcs_profile(
270 self,
271 profile: Option<String>,
272 ) -> Result<ChannelBuilder<RefreshOptions<O, ClientConfiguration>>, LoadError> {
273 let config = match profile {
274 Some(profile) => ClientConfiguration::load_profile(profile)?,
275 None => ClientConfiguration::load_default()?,
276 };
277
278 Ok(self.with_qcs_config(config))
279 }
280
281 pub fn with_retry_layer(self, layer: RetryLayer) -> ChannelBuilder<RetryOptions<O>> {
283 #[cfg(feature = "tracing")]
284 return ChannelBuilder {
285 endpoint: self.endpoint,
286 trace_layer: self.trace_layer,
287 options: RetryOptions {
288 layer,
289 other: self.options,
290 },
291 };
292 #[cfg(not(feature = "tracing"))]
293 return ChannelBuilder {
294 endpoint: self.endpoint,
295 options: RetryOptions {
296 layer,
297 other: self.options,
298 },
299 };
300 }
301
302 pub fn with_retry_backoff(
304 self,
305 backoff: ExponentialBackoff,
306 ) -> ChannelBuilder<RetryOptions<O>> {
307 self.with_retry_layer(RetryLayer { backoff })
308 }
309
310 pub fn with_default_retry(self) -> ChannelBuilder<RetryOptions<O>> {
312 self.with_retry_backoff(default_backoff())
313 }
314
315 #[allow(clippy::result_large_err)]
321 pub fn build(self) -> Result<O::Service, ChannelError> {
322 let channel = get_channel_with_endpoint(&self.endpoint)?;
323 #[cfg(feature = "tracing")]
324 {
325 let traced_channel = self.trace_layer.layer(channel);
326 Ok(self.options.into_service(traced_channel))
327 }
328
329 #[cfg(not(feature = "tracing"))]
330 Ok(self.options.into_service(channel))
331 }
332}
333
334#[allow(clippy::result_large_err)]
342pub fn parse_uri(s: &str) -> Result<Uri, Error<TokenError>> {
343 s.parse().map_err(Error::from)
344}
345
346#[allow(clippy::missing_panics_doc)]
348pub fn get_endpoint(uri: Uri) -> Endpoint {
349 Channel::builder(uri)
350 .user_agent(concat!(
351 "QCS gRPC Client (Rust)/",
352 env!("CARGO_PKG_VERSION")
353 ))
354 .expect("user agent string should be valid")
355 .http2_adaptive_window(true)
356 .tls_config(ClientTlsConfig::new().with_enabled_roots())
357 .expect("tls setup should succeed")
358}
359
360pub fn get_endpoint_with_timeout(uri: Uri, timeout: Option<Duration>) -> Endpoint {
362 if let Some(duration) = timeout {
363 get_endpoint(uri).timeout(duration)
364 } else {
365 get_endpoint(uri)
366 }
367}
368
369fn get_env_uri(key: &str) -> Result<Option<Uri>, InvalidUri> {
372 std::env::var(key)
373 .or_else(|_| std::env::var(key.to_lowercase()))
374 .ok()
375 .map(Uri::try_from)
376 .transpose()
377}
378
379fn get_uri_socks_auth(uri: &Uri) -> Result<Option<Auth>, url::ParseError> {
381 let full_url = uri.to_string().parse::<Url>()?;
382 let user = full_url.username();
383 let auth = if user.is_empty() {
384 None
385 } else {
386 let pass = full_url.password().unwrap_or_default();
387 Some(Auth::new(user, pass))
388 };
389 Ok(auth)
390}
391
392#[allow(clippy::result_large_err)]
408pub fn get_channel(uri: Uri) -> Result<Channel, ChannelError> {
409 let endpoint = get_endpoint(uri);
410 get_channel_with_endpoint(&endpoint)
411}
412
413#[allow(clippy::result_large_err)]
430pub fn get_channel_with_timeout(
431 uri: Uri,
432 timeout: Option<Duration>,
433) -> Result<Channel, ChannelError> {
434 let endpoint = get_endpoint_with_timeout(uri, timeout);
435 get_channel_with_endpoint(&endpoint)
436}
437
/// Builds a lazily-connecting [`Channel`] for `endpoint`. Traffic is routed through a proxy
/// when the `HTTPS_PROXY` / `HTTP_PROXY` environment variables (or their lowercase forms) are set.
///
/// Proxy selection:
/// - neither variable set: connect directly;
/// - only `HTTPS_PROXY` set: the proxy intercepts https traffic (a `socks5` proxy carries
///   all traffic, since the intercept is not used for SOCKS);
/// - only `HTTP_PROXY` set: likewise, intercepting http traffic;
/// - both set to the same URI: that proxy intercepts all traffic;
/// - both set to different URIs: allowed only when both are `http`/`https` proxies.
///
/// `socks5` proxies go through [`SocksConnector`], using any credentials in the proxy URI.
/// `http`/`https` proxies go through [`ProxyConnector`].
///
/// # Errors
///
/// - [`ChannelError::InvalidUri`] if a proxy variable does not hold a valid URI.
/// - [`ChannelError::InvalidUrl`] if SOCKS credentials cannot be parsed from the proxy URI.
/// - [`ChannelError::UnsupportedProtocol`] if a single proxy has a scheme other than
///   `socks5`, `http` or `https`.
/// - [`ChannelError::SslFailure`] (presumably) if constructing the HTTP proxy connector fails.
/// - [`ChannelError::Mismatch`] if two different proxies are set and either is not
///   `http`/`https`.
#[allow(
    clippy::similar_names,
    reason = "http(s)_proxy are similar but precise in this case"
)]
#[allow(clippy::result_large_err)]
pub fn get_channel_with_endpoint(endpoint: &Endpoint) -> Result<Channel, ChannelError> {
    // The uppercase name is checked first, then the lowercase one (see `get_env_uri`).
    let https_proxy = get_env_uri("HTTPS_PROXY")?;
    let http_proxy = get_env_uri("HTTP_PROXY")?;

    let mut connector = HttpConnector::new();
    // By default HttpConnector only accepts `http` URIs; allow `https` targets as well.
    connector.enforce_http(false);

    // Builds a lazily-connecting channel that reaches the target through the proxy at `uri`.
    // `intercept` selects which traffic an HTTP(S) proxy handles.
    let connect_to = |uri: http::Uri, intercept: Intercept| {
        let connector = connector.clone();
        match uri.scheme_str() {
            Some("socks5") => {
                // `intercept` is not consulted for SOCKS: the proxy carries all traffic.
                let socks_connector = SocksConnector {
                    auth: get_uri_socks_auth(&uri)?,
                    proxy_addr: uri,
                    connector,
                };
                Ok(endpoint.connect_with_connector_lazy(socks_connector))
            }
            Some("https" | "http") => {
                let is_http = uri.scheme() == Some(&http::uri::Scheme::HTTP);
                let proxy = Proxy::new(intercept, uri);
                let mut proxy_connector = ProxyConnector::from_proxy(connector, proxy)?;
                // For an `http://` proxy, the proxy connector's TLS is disabled.
                // NOTE(review): the two-proxy branch below never does this, even if one of
                // the proxies is `http://` — confirm this asymmetry is intended.
                if is_http {
                    proxy_connector.set_tls(None);
                }
                Ok(endpoint.connect_with_connector_lazy(proxy_connector))
            }
            scheme => Err(ChannelError::UnsupportedProtocol(scheme.map(String::from))),
        }
    };

    let channel = match (https_proxy, http_proxy) {
        // No proxy configured: connect directly.
        (None, None) => endpoint.connect_lazy(),

        (Some(https_proxy), None) => connect_to(https_proxy, Intercept::Https)?,
        (None, Some(http_proxy)) => connect_to(http_proxy, Intercept::Http)?,

        (Some(https_proxy), Some(http_proxy)) => {
            // Same proxy for both schemes: route all traffic through it.
            if https_proxy == http_proxy {
                connect_to(https_proxy, Intercept::All)?
            } else {
                // Two distinct proxies can only be combined in one ProxyConnector, so both
                // must be HTTP(S) proxies. Any other scheme (socks5 or unsupported) is
                // reported as `Mismatch`.
                let accepted = [https_proxy.scheme_str(), http_proxy.scheme_str()]
                    .into_iter()
                    .all(|scheme| matches!(scheme, Some("https" | "http")));
                if accepted {
                    let mut proxy_connector = ProxyConnector::new(connector)?;
                    proxy_connector.extend_proxies(vec![
                        Proxy::new(Intercept::Https, https_proxy),
                        Proxy::new(Intercept::Http, http_proxy),
                    ]);
                    endpoint.connect_with_connector_lazy(proxy_connector)
                } else {
                    return Err(ChannelError::Mismatch {
                        https_proxy,
                        http_proxy,
                    });
                }
            }
        }
    };

    Ok(channel)
}
528
529#[allow(clippy::result_large_err)]
535pub fn get_wrapped_channel(
536 uri: Uri,
537) -> Result<RefreshService<Channel, ClientConfiguration>, Error<TokenError>> {
538 wrap_channel(get_channel(uri)?)
539}
540
541#[must_use]
543pub fn wrap_channel_with<C>(
544 channel: C,
545 config: ClientConfiguration,
546) -> RefreshService<C, ClientConfiguration>
547where
548 C: GrpcService<Body>,
549{
550 ServiceBuilder::new()
551 .layer(RefreshLayer::with_config(config))
552 .service(channel)
553}
554
555pub fn wrap_channel_with_token_refresher<C, T>(
559 channel: C,
560 token_refresher: T,
561) -> RefreshService<C, T>
562where
563 C: GrpcService<Body>,
564 T: TokenRefresher + Clone + Send + Sync,
565{
566 ServiceBuilder::new()
567 .layer(RefreshLayer::with_refresher(token_refresher))
568 .service(channel)
569}
570
571#[allow(clippy::result_large_err)]
577pub fn wrap_channel<C>(
578 channel: C,
579) -> Result<RefreshService<C, ClientConfiguration>, Error<TokenError>>
580where
581 C: GrpcService<Body>,
582{
583 Ok(wrap_channel_with(channel, {
584 ClientConfiguration::load_default()?
585 }))
586}
587
588#[allow(clippy::result_large_err)]
594pub fn wrap_channel_with_profile<C>(
595 channel: C,
596 profile: String,
597) -> Result<RefreshService<C, ClientConfiguration>, Error<TokenError>>
598where
599 C: GrpcService<Body>,
600{
601 Ok(wrap_channel_with(
602 channel,
603 ClientConfiguration::load_profile(profile)?,
604 ))
605}
606
607pub fn wrap_channel_with_retry<C>(channel: C) -> RetryService<C>
609where
610 C: GrpcService<Body>,
611{
612 ServiceBuilder::new()
613 .layer(RetryLayer::default())
614 .service(channel)
615}
616
/// Wraps `channel` in a trace layer built from `base_url` and `configuration`.
#[cfg(feature = "tracing")]
pub fn wrap_channel_with_tracing(
    channel: Channel,
    base_url: String,
    configuration: TracingConfiguration,
) -> CustomTraceService {
    // A single-layer ServiceBuilder is equivalent to applying the layer directly.
    build_trace_layer(base_url, Some(&configuration)).layer(channel)
}