qcs_api_client_grpc/tonic/
refresh.rs

1use std::{
2    future::{poll_fn, Future},
3    pin::Pin,
4    task::{Context, Poll},
5};
6
7use http::StatusCode;
8use tonic::{
9    body::BoxBody,
10    client::GrpcService,
11    codegen::http::{Request, Response},
12};
13use tower::Layer;
14
15use qcs_api_client_common::configuration::{ClientConfiguration, TokenError, TokenRefresher};
16
17use super::error::Error;
18
19/// The [`GrpcService`] that wraps the gRPC client in order to provide QCS authentication.
20///
21/// See also: [`RefreshLayer`].
22#[derive(Clone, Debug)]
23pub struct RefreshService<S: GrpcService<BoxBody>, T: TokenRefresher> {
24    service: S,
25    token_refresher: T,
26}
27
28/// The [`Layer`] used to apply QCS authentication to requests.
29#[derive(Clone, Debug)]
30pub struct RefreshLayer<T: TokenRefresher> {
31    token_refresher: T,
32}
33
34impl<T: TokenRefresher> RefreshLayer<T> {
35    /// Create a new [`RefreshLayer`] with the given [`TokenRefresher`]
36    pub const fn with_refresher(token_refresher: T) -> Self {
37        Self { token_refresher }
38    }
39}
40
41impl RefreshLayer<ClientConfiguration> {
42    /// Create a new [`RefreshLayer`].
43    ///
44    /// # Errors
45    ///
46    /// Will fail with error if loading the [`ClientConfiguration`] fails.
47    pub fn new() -> Result<Self, Error<TokenError>> {
48        let config = ClientConfiguration::load_default()?;
49        Ok(Self::with_config(config))
50    }
51
52    /// Create a new [`RefreshLayer`] using the given QCS configuration profile.
53    ///
54    /// # Errors
55    ///
56    /// Will fail if loading the [`ClientConfiguration`] fails.
57    pub fn with_profile(profile: String) -> Result<Self, Error<TokenError>> {
58        let config = ClientConfiguration::load_profile(profile)?;
59        Ok(Self::with_config(config))
60    }
61
62    /// Create a [`RefreshLayer`] from an existing [`ClientConfiguration`].
63    #[must_use]
64    pub const fn with_config(config: ClientConfiguration) -> Self {
65        Self::with_refresher(config)
66    }
67}
68
69impl<S, T> Layer<S> for RefreshLayer<T>
70where
71    S: GrpcService<BoxBody>,
72    T: TokenRefresher + Clone,
73{
74    type Service = RefreshService<S, T>;
75
76    fn layer(&self, inner: S) -> Self::Service {
77        RefreshService {
78            token_refresher: self.token_refresher.clone(),
79            service: inner,
80        }
81    }
82}
83
84impl<S, T> GrpcService<BoxBody> for RefreshService<S, T>
85where
86    S: GrpcService<BoxBody> + Clone + Send + 'static,
87    <S as GrpcService<BoxBody>>::Future: Send,
88    <S as GrpcService<BoxBody>>::ResponseBody: Send,
89    T: TokenRefresher + Clone + Send + 'static,
90    T::Error: std::error::Error + Sync,
91    Error<T::Error>: From<S::Error>,
92    <T as TokenRefresher>::Error: Send,
93{
94    type ResponseBody = <S as GrpcService<BoxBody>>::ResponseBody;
95    type Error = Error<T::Error>;
96    type Future =
97        Pin<Box<dyn Future<Output = Result<Response<Self::ResponseBody>, Self::Error>> + Send>>;
98
99    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
100        self.service.poll_ready(cx).map_err(Error::from)
101    }
102
103    fn call(&mut self, req: Request<BoxBody>) -> Self::Future {
104        let service = self.service.clone();
105        // It is necessary to replace self.service with the above clone
106        // because the cloned version may not be "ready".
107        //
108        // See this github issue for more context:
109        // https://github.com/tower-rs/tower/issues/547
110        let service = std::mem::replace(&mut self.service, service);
111        let token_refresher = self.token_refresher.clone();
112        super::common::pin_future_with_otel_context_if_available(service_call(
113            req,
114            token_refresher,
115            service,
116        ))
117    }
118}
119
120async fn service_call<C, T>(
121    req: Request<BoxBody>,
122    token_refresher: T,
123    mut channel: C,
124) -> Result<Response<<C as GrpcService<BoxBody>>::ResponseBody>, Error<T::Error>>
125where
126    C: GrpcService<BoxBody> + Send,
127    <C as GrpcService<BoxBody>>::ResponseBody: Send,
128    <C as GrpcService<BoxBody>>::Future: Send,
129    T: TokenRefresher + Send,
130    T::Error: std::error::Error,
131    Error<T::Error>: From<C::Error>,
132{
133    let token = token_refresher
134        .validated_access_token()
135        .await
136        .map_err(Error::Refresh)?;
137    let (req, retry_req) = super::build_duplicate_request(req).await?;
138    let resp = make_request(&mut channel, req, token).await?;
139
140    let grpc_authnz_failure = matches!(
141        super::common::get_status_code_from_headers(resp.headers()).ok(),
142        Some(tonic::Code::Unauthenticated) | Some(tonic::Code::PermissionDenied)
143    );
144    let http_authnz_failure =
145        resp.status() == StatusCode::UNAUTHORIZED || resp.status() == StatusCode::FORBIDDEN;
146
147    if grpc_authnz_failure || http_authnz_failure {
148        #[cfg(feature = "tracing")]
149        {
150            tracing::info!("refreshing token after receiving unauthorized or forbidden status",);
151        }
152
153        // Refresh token and try again
154        let token = token_refresher
155            .validated_access_token()
156            .await
157            .map_err(Error::Refresh)?;
158
159        #[cfg(feature = "tracing")]
160        {
161            tracing::info!("token refreshed");
162        }
163        // Ensure that the service is ready before trying to use it.
164        // Failure to do this *will* cause a panic.
165        poll_fn(|cx| channel.poll_ready(cx))
166            .await
167            .map_err(super::error::Error::from)?;
168        make_request(&mut channel, retry_req, token).await
169    } else {
170        Ok(resp)
171    }
172}
173
174async fn make_request<C, E: std::error::Error>(
175    service: &mut C,
176    mut request: Request<BoxBody>,
177    token: String,
178) -> Result<Response<<C as GrpcService<BoxBody>>::ResponseBody>, Error<E>>
179where
180    C: GrpcService<BoxBody> + Send,
181    <C as GrpcService<BoxBody>>::ResponseBody: Send,
182    <C as GrpcService<BoxBody>>::Future: Send,
183    Error<E>: From<C::Error>,
184{
185    let header_val = format!("Bearer {token}")
186        .try_into()
187        .map_err(Error::InvalidAccessToken)?;
188    request.headers_mut().insert("authorization", header_val);
189    service.call(request).await.map_err(Error::from)
190}