oxide_auth_async/endpoint/
refresh.rs

1use std::{borrow::Cow, marker::PhantomData, str::from_utf8};
2
3use base64::{engine::general_purpose::STANDARD, Engine};
4use oxide_auth::{
5    code_grant::refresh::{Error, Request},
6    endpoint::{WebRequest, WebResponse, OAuthError, QueryParameter, Template, NormalizedParameter},
7};
8
9use super::Endpoint;
10use crate::{
11    code_grant::refresh::{refresh, Endpoint as RefreshEndpoint},
12    primitives::{Issuer, Registrar},
13};
14
15/// Takes requests from clients to refresh their access tokens.
16pub struct RefreshFlow<E, R>
17where
18    E: Endpoint<R>,
19    R: WebRequest,
20{
21    endpoint: WrappedRefresh<E, R>,
22}
23
24struct WrappedRefresh<E, R>
25where
26    E: Endpoint<R>,
27    R: WebRequest,
28{
29    inner: E,
30    r_type: PhantomData<R>,
31}
32
33struct WrappedRequest<R: WebRequest> {
34    /// The query in the body.
35    body: NormalizedParameter,
36
37    /// The authorization token.
38    authorization: Option<Authorization>,
39
40    /// An error if one occurred.
41    error: Option<Option<R::Error>>,
42}
43
/// Credentials decoded from an authorization header: the client identifier as text, followed
/// by the raw bytes of the passphrase.
struct Authorization(String, Vec<u8>);
45
46impl<E, R> RefreshFlow<E, R>
47where
48    E: Endpoint<R> + Send + Sync,
49    R: WebRequest + Send + Sync,
50    <R as WebRequest>::Error: Send + Sync,
51{
52    /// Wrap the endpoint if it supports handling refresh requests.
53    ///
54    /// Also binds the endpoint to the particular `WebRequest` type through the type system. The
55    /// endpoint needs to provide (return `Some`):
56    ///
57    /// * a `Registrar` from `registrar`
58    /// * an `Issuer` from `issuer_mut`
59    ///
60    /// ## Panics
61    ///
62    /// Indirectly `execute` may panic when this flow is instantiated with an inconsistent
63    /// endpoint, for details see the documentation of `Endpoint` and `execute`. For
64    /// consistent endpoints, the panic is instead caught as an error here.
65    pub fn prepare(mut endpoint: E) -> Result<Self, E::Error> {
66        if endpoint.registrar().is_none() {
67            return Err(endpoint.error(OAuthError::PrimitiveError));
68        }
69
70        if endpoint.issuer_mut().is_none() {
71            return Err(endpoint.error(OAuthError::PrimitiveError));
72        }
73
74        Ok(RefreshFlow {
75            endpoint: WrappedRefresh {
76                inner: endpoint,
77                r_type: PhantomData,
78            },
79        })
80    }
81
82    pub async fn execute(&mut self, mut request: R) -> Result<R::Response, E::Error> {
83        let refreshed = refresh(&mut self.endpoint, &WrappedRequest::new(&mut request)).await;
84
85        let token = match refreshed {
86            Err(error) => return token_error(&mut self.endpoint.inner, &mut request, error),
87            Ok(token) => token,
88        };
89
90        let mut response = self.endpoint.inner.response(&mut request, Template::new_ok())?;
91        response
92            .body_json(&token.to_json())
93            .map_err(|err| self.endpoint.inner.web_error(err))?;
94        Ok(response)
95    }
96}
97
98fn token_error<E, R>(endpoint: &mut E, request: &mut R, error: Error) -> Result<R::Response, E::Error>
99where
100    E: Endpoint<R>,
101    R: WebRequest,
102{
103    Ok(match error {
104        Error::Invalid(mut json) => {
105            let mut response =
106                endpoint.response(request, Template::new_bad(Some(json.description())))?;
107            response.client_error().map_err(|err| endpoint.web_error(err))?;
108            response
109                .body_json(&json.to_json())
110                .map_err(|err| endpoint.web_error(err))?;
111            response
112        }
113        Error::Unauthorized(mut json, scheme) => {
114            let mut response = endpoint.response(
115                request,
116                Template::new_unauthorized(None, Some(json.description())),
117            )?;
118            response
119                .unauthorized(&scheme)
120                .map_err(|err| endpoint.web_error(err))?;
121            response
122                .body_json(&json.to_json())
123                .map_err(|err| endpoint.web_error(err))?;
124            response
125        }
126        Error::Primitive => {
127            // FIXME: give the context for restoration.
128            return Err(endpoint.error(OAuthError::PrimitiveError));
129        }
130    })
131}
132
133impl<'a, R: WebRequest> WrappedRequest<R> {
134    pub fn new(request: &'a mut R) -> Self {
135        Self::new_or_fail(request).unwrap_or_else(Self::from_err)
136    }
137
138    fn new_or_fail(request: &'a mut R) -> Result<Self, Option<R::Error>> {
139        // If there is a header, it must parse correctly.
140        let authorization = match request.authheader() {
141            Err(err) => return Err(Some(err)),
142            Ok(Some(header)) => Self::parse_header(header).map(Some)?,
143            Ok(None) => None,
144        };
145
146        Ok(WrappedRequest {
147            body: request.urlbody()?.into_owned(),
148            authorization,
149            error: None,
150        })
151    }
152
153    fn from_err(err: Option<R::Error>) -> Self {
154        WrappedRequest {
155            body: Default::default(),
156            authorization: None,
157            error: Some(err),
158        }
159    }
160
161    fn parse_header(header: Cow<str>) -> Result<Authorization, Option<R::Error>> {
162        let authorization = {
163            if !header.starts_with("Basic ") {
164                return Err(None);
165            }
166
167            let combined = match STANDARD.decode(&header[6..]) {
168                Err(_) => return Err(None),
169                Ok(vec) => vec,
170            };
171
172            let mut split = combined.splitn(2, |&c| c == b':');
173            let client_bin = match split.next() {
174                None => return Err(None),
175                Some(client) => client,
176            };
177            let passwd = match split.next() {
178                None => return Err(None),
179                Some(passwd64) => passwd64,
180            };
181
182            let client = match from_utf8(client_bin) {
183                Err(_) => return Err(None),
184                Ok(client) => client,
185            };
186
187            Authorization(client.to_string(), passwd.to_vec())
188        };
189
190        Ok(authorization)
191    }
192}
193
194impl<E, R> RefreshEndpoint for WrappedRefresh<E, R>
195where
196    E: Endpoint<R>,
197    R: WebRequest,
198{
199    fn registrar(&self) -> &(dyn Registrar + Sync) {
200        self.inner.registrar().unwrap()
201    }
202
203    fn issuer(&mut self) -> &mut (dyn Issuer + Send) {
204        self.inner.issuer_mut().unwrap()
205    }
206}
207
208impl<R: WebRequest> Request for WrappedRequest<R> {
209    fn valid(&self) -> bool {
210        self.error.is_none()
211    }
212
213    fn refresh_token(&self) -> Option<Cow<str>> {
214        self.body.unique_value("refresh_token")
215    }
216
217    fn authorization(&self) -> Option<(Cow<str>, Cow<[u8]>)> {
218        self.authorization
219            .as_ref()
220            .map(|auth| (auth.0.as_str().into(), auth.1.as_slice().into()))
221    }
222
223    fn scope(&self) -> Option<Cow<str>> {
224        self.body.unique_value("scope")
225    }
226
227    fn grant_type(&self) -> Option<Cow<str>> {
228        self.body.unique_value("grant_type")
229    }
230
231    fn extension(&self, key: &str) -> Option<Cow<str>> {
232        self.body.unique_value(key)
233    }
234}