qcs_api_client_grpc/tonic/
refresh.rs1use std::{
2 future::{poll_fn, Future},
3 pin::Pin,
4 task::{Context, Poll},
5};
6
7use http::StatusCode;
8use tonic::{
9 body::BoxBody,
10 client::GrpcService,
11 codegen::http::{Request, Response},
12};
13use tower::Layer;
14
15use qcs_api_client_common::configuration::{ClientConfiguration, TokenError, TokenRefresher};
16
17use super::error::Error;
18
19#[derive(Clone, Debug)]
23pub struct RefreshService<S: GrpcService<BoxBody>, T: TokenRefresher> {
24 service: S,
25 token_refresher: T,
26}
27
28#[derive(Clone, Debug)]
30pub struct RefreshLayer<T: TokenRefresher> {
31 token_refresher: T,
32}
33
34impl<T: TokenRefresher> RefreshLayer<T> {
35 pub const fn with_refresher(token_refresher: T) -> Self {
37 Self { token_refresher }
38 }
39}
40
41impl RefreshLayer<ClientConfiguration> {
42 pub fn new() -> Result<Self, Error<TokenError>> {
48 let config = ClientConfiguration::load_default()?;
49 Ok(Self::with_config(config))
50 }
51
52 pub fn with_profile(profile: String) -> Result<Self, Error<TokenError>> {
58 let config = ClientConfiguration::load_profile(profile)?;
59 Ok(Self::with_config(config))
60 }
61
62 #[must_use]
64 pub const fn with_config(config: ClientConfiguration) -> Self {
65 Self::with_refresher(config)
66 }
67}
68
69impl<S, T> Layer<S> for RefreshLayer<T>
70where
71 S: GrpcService<BoxBody>,
72 T: TokenRefresher + Clone,
73{
74 type Service = RefreshService<S, T>;
75
76 fn layer(&self, inner: S) -> Self::Service {
77 RefreshService {
78 token_refresher: self.token_refresher.clone(),
79 service: inner,
80 }
81 }
82}
83
84impl<S, T> GrpcService<BoxBody> for RefreshService<S, T>
85where
86 S: GrpcService<BoxBody> + Clone + Send + 'static,
87 <S as GrpcService<BoxBody>>::Future: Send,
88 <S as GrpcService<BoxBody>>::ResponseBody: Send,
89 T: TokenRefresher + Clone + Send + 'static,
90 T::Error: std::error::Error + Sync,
91 Error<T::Error>: From<S::Error>,
92 <T as TokenRefresher>::Error: Send,
93{
94 type ResponseBody = <S as GrpcService<BoxBody>>::ResponseBody;
95 type Error = Error<T::Error>;
96 type Future =
97 Pin<Box<dyn Future<Output = Result<Response<Self::ResponseBody>, Self::Error>> + Send>>;
98
99 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
100 self.service.poll_ready(cx).map_err(Error::from)
101 }
102
103 fn call(&mut self, req: Request<BoxBody>) -> Self::Future {
104 let service = self.service.clone();
105 let service = std::mem::replace(&mut self.service, service);
111 let token_refresher = self.token_refresher.clone();
112 super::common::pin_future_with_otel_context_if_available(service_call(
113 req,
114 token_refresher,
115 service,
116 ))
117 }
118}
119
120async fn service_call<C, T>(
121 req: Request<BoxBody>,
122 token_refresher: T,
123 mut channel: C,
124) -> Result<Response<<C as GrpcService<BoxBody>>::ResponseBody>, Error<T::Error>>
125where
126 C: GrpcService<BoxBody> + Send,
127 <C as GrpcService<BoxBody>>::ResponseBody: Send,
128 <C as GrpcService<BoxBody>>::Future: Send,
129 T: TokenRefresher + Send,
130 T::Error: std::error::Error,
131 Error<T::Error>: From<C::Error>,
132{
133 let token = token_refresher
134 .validated_access_token()
135 .await
136 .map_err(Error::Refresh)?;
137 let (req, retry_req) = super::build_duplicate_request(req).await?;
138 let resp = make_request(&mut channel, req, token).await?;
139
140 let grpc_authnz_failure = matches!(
141 super::common::get_status_code_from_headers(resp.headers()).ok(),
142 Some(tonic::Code::Unauthenticated) | Some(tonic::Code::PermissionDenied)
143 );
144 let http_authnz_failure =
145 resp.status() == StatusCode::UNAUTHORIZED || resp.status() == StatusCode::FORBIDDEN;
146
147 if grpc_authnz_failure || http_authnz_failure {
148 #[cfg(feature = "tracing")]
149 {
150 tracing::info!("refreshing token after receiving unauthorized or forbidden status",);
151 }
152
153 let token = token_refresher
155 .validated_access_token()
156 .await
157 .map_err(Error::Refresh)?;
158
159 #[cfg(feature = "tracing")]
160 {
161 tracing::info!("token refreshed");
162 }
163 poll_fn(|cx| channel.poll_ready(cx))
166 .await
167 .map_err(super::error::Error::from)?;
168 make_request(&mut channel, retry_req, token).await
169 } else {
170 Ok(resp)
171 }
172}
173
174async fn make_request<C, E: std::error::Error>(
175 service: &mut C,
176 mut request: Request<BoxBody>,
177 token: String,
178) -> Result<Response<<C as GrpcService<BoxBody>>::ResponseBody>, Error<E>>
179where
180 C: GrpcService<BoxBody> + Send,
181 <C as GrpcService<BoxBody>>::ResponseBody: Send,
182 <C as GrpcService<BoxBody>>::Future: Send,
183 Error<E>: From<C::Error>,
184{
185 let header_val = format!("Bearer {token}")
186 .try_into()
187 .map_err(Error::InvalidAccessToken)?;
188 request.headers_mut().insert("authorization", header_val);
189 service.call(request).await.map_err(Error::from)
190}