// qcs_api_client_grpc/tonic/refresh.rs

1use std::{
2    future::{poll_fn, Future},
3    pin::Pin,
4    task::{Context, Poll},
5};
6
7use super::Body;
8use http::StatusCode;
9use tonic::{
10    client::GrpcService,
11    codegen::http::{Request, Response},
12};
13use tower::Layer;
14
15use qcs_api_client_common::configuration::{
16    secrets::SecretAccessToken, tokens::TokenRefresher, ClientConfiguration, TokenError,
17};
18
19use super::error::Error;
20
/// The [`GrpcService`] that wraps the gRPC client in order to provide QCS authentication.
///
/// Trait bounds are stated on the `impl` blocks that need them rather than on the struct
/// itself, so the type can be named, cloned, and debug-printed without extra constraints.
///
/// See also: [`RefreshLayer`].
#[derive(Clone, Debug)]
pub struct RefreshService<S, T> {
    /// The wrapped inner service that performs the actual requests.
    service: S,
    /// Source of the access tokens attached to outgoing requests.
    token_refresher: T,
}
29
/// The [`Layer`] used to apply QCS authentication to requests.
///
/// Trait bounds are stated on the `impl` blocks that need them rather than on the struct
/// itself.
#[derive(Clone, Debug)]
pub struct RefreshLayer<T> {
    /// Source of access tokens; cloned into each [`RefreshService`] this layer produces.
    token_refresher: T,
}
35
36impl<T: TokenRefresher> RefreshLayer<T> {
37    /// Create a new [`RefreshLayer`] with the given [`TokenRefresher`]
38    pub const fn with_refresher(token_refresher: T) -> Self {
39        Self { token_refresher }
40    }
41}
42
43impl RefreshLayer<ClientConfiguration> {
44    /// Create a new [`RefreshLayer`].
45    ///
46    /// # Errors
47    ///
48    /// Will fail with error if loading the [`ClientConfiguration`] fails.
49    #[allow(clippy::result_large_err)]
50    pub fn new() -> Result<Self, Error<TokenError>> {
51        let config = ClientConfiguration::load_default()?;
52        Ok(Self::with_config(config))
53    }
54
55    /// Create a new [`RefreshLayer`] using the given QCS configuration profile.
56    ///
57    /// # Errors
58    ///
59    /// Will fail if loading the [`ClientConfiguration`] fails.
60    #[allow(clippy::result_large_err)]
61    pub fn with_profile(profile: String) -> Result<Self, Error<TokenError>> {
62        let config = ClientConfiguration::load_profile(profile)?;
63        Ok(Self::with_config(config))
64    }
65
66    /// Create a [`RefreshLayer`] from an existing [`ClientConfiguration`].
67    #[must_use]
68    pub const fn with_config(config: ClientConfiguration) -> Self {
69        Self::with_refresher(config)
70    }
71}
72
73impl<S, T> Layer<S> for RefreshLayer<T>
74where
75    S: GrpcService<Body>,
76    T: TokenRefresher + Clone,
77{
78    type Service = RefreshService<S, T>;
79
80    fn layer(&self, inner: S) -> Self::Service {
81        RefreshService {
82            token_refresher: self.token_refresher.clone(),
83            service: inner,
84        }
85    }
86}
87
88impl<S, T> GrpcService<Body> for RefreshService<S, T>
89where
90    S: GrpcService<Body> + Clone + Send + 'static,
91    <S as GrpcService<Body>>::Future: Send,
92    <S as GrpcService<Body>>::ResponseBody: Send,
93    T: TokenRefresher + Clone + Send + 'static,
94    T::Error: std::error::Error + Sync,
95    Error<T::Error>: From<S::Error>,
96    <T as TokenRefresher>::Error: Send,
97{
98    type ResponseBody = <S as GrpcService<Body>>::ResponseBody;
99    type Error = Error<T::Error>;
100    type Future =
101        Pin<Box<dyn Future<Output = Result<Response<Self::ResponseBody>, Self::Error>> + Send>>;
102
103    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
104        self.service.poll_ready(cx).map_err(Error::from)
105    }
106
107    fn call(&mut self, req: Request<Body>) -> Self::Future {
108        let service = self.service.clone();
109        // It is necessary to replace self.service with the above clone
110        // because the cloned version may not be "ready".
111        //
112        // See this github issue for more context:
113        // https://github.com/tower-rs/tower/issues/547
114        let service = std::mem::replace(&mut self.service, service);
115        let token_refresher = self.token_refresher.clone();
116        super::common::pin_future_with_otel_context_if_available(service_call(
117            req,
118            token_refresher,
119            service,
120        ))
121    }
122}
123
124async fn service_call<C, T>(
125    req: Request<Body>,
126    token_refresher: T,
127    mut channel: C,
128) -> Result<Response<<C as GrpcService<Body>>::ResponseBody>, Error<T::Error>>
129where
130    C: GrpcService<Body> + Send,
131    <C as GrpcService<Body>>::ResponseBody: Send,
132    <C as GrpcService<Body>>::Future: Send,
133    T: TokenRefresher + Send,
134    T::Error: std::error::Error,
135    Error<T::Error>: From<C::Error>,
136{
137    let token = token_refresher
138        .validated_access_token()
139        .await
140        .map_err(Error::Refresh)?;
141    let (req, retry_req) = super::build_duplicate_request(req).await?;
142    let resp = make_request(&mut channel, req, token).await?;
143
144    let grpc_authnz_failure = matches!(
145        super::common::get_status_code_from_headers(resp.headers()).ok(),
146        Some(tonic::Code::Unauthenticated) | Some(tonic::Code::PermissionDenied)
147    );
148    let http_authnz_failure =
149        resp.status() == StatusCode::UNAUTHORIZED || resp.status() == StatusCode::FORBIDDEN;
150
151    if grpc_authnz_failure || http_authnz_failure {
152        #[cfg(feature = "tracing")]
153        {
154            tracing::info!("refreshing token after receiving unauthorized or forbidden status",);
155        }
156
157        // Refresh token and try again
158        let token = token_refresher
159            .validated_access_token()
160            .await
161            .map_err(Error::Refresh)?;
162
163        #[cfg(feature = "tracing")]
164        {
165            tracing::info!("token refreshed");
166        }
167        // Ensure that the service is ready before trying to use it.
168        // Failure to do this *will* cause a panic.
169        poll_fn(|cx| -> Poll<Result<(), _>> { channel.poll_ready(cx) })
170            .await
171            .map_err(super::error::Error::from)?;
172        make_request(&mut channel, retry_req, token).await
173    } else {
174        Ok(resp)
175    }
176}
177
178async fn make_request<C, E: std::error::Error>(
179    service: &mut C,
180    mut request: Request<Body>,
181    access_token: SecretAccessToken,
182) -> Result<Response<<C as GrpcService<Body>>::ResponseBody>, Error<E>>
183where
184    C: GrpcService<Body> + Send,
185    <C as GrpcService<Body>>::ResponseBody: Send,
186    <C as GrpcService<Body>>::Future: Send,
187    Error<E>: From<C::Error>,
188{
189    let header_val = format!("Bearer {}", access_token.secret())
190        .try_into()
191        .map_err(Error::InvalidAccessToken)?;
192    request.headers_mut().insert("authorization", header_val);
193    service.call(request).await.map_err(Error::from)
194}