1use std::borrow::Cow;
2use std::marker::PhantomData;
3use std::str::from_utf8;
4
5use base64::Engine;
6use base64::engine::general_purpose::STANDARD;
7
8use crate::code_grant::refresh::{refresh, Error, Endpoint as RefreshEndpoint, Request};
9use crate::primitives::{registrar::Registrar, issuer::Issuer};
10use super::{
11 Endpoint, InnerTemplate, OAuthError, QueryParameter, WebRequest, WebResponse,
12 is_authorization_method,
13};
14
15pub struct RefreshFlow<E, R>
17where
18 E: Endpoint<R>,
19 R: WebRequest,
20{
21 endpoint: WrappedRefresh<E, R>,
22}
23
24struct WrappedRefresh<E: Endpoint<R>, R: WebRequest> {
25 inner: E,
26 r_type: PhantomData<R>,
27}
28
29struct WrappedRequest<'a, R: WebRequest + 'a> {
30 request: PhantomData<R>,
32
33 body: Cow<'a, dyn QueryParameter + 'static>,
35
36 authorization: Option<Authorization>,
38
39 error: Option<InitError<R::Error>>,
41}
42
/// Reasons a [`WrappedRequest`] could not be built from the incoming request.
enum InitError<E> {
    /// The underlying request failed to yield its authorization header or body.
    Internal(E),
    /// An authorization header was present but could not be decoded as `Basic`
    /// client credentials.
    Malformed,
}
47
/// Client credentials extracted from an HTTP `Basic` authorization header.
struct Authorization(
    /// The client identifier: the UTF-8 text before the first `:` of the decoded payload.
    String,
    /// The password: the raw bytes after the first `:` (may contain further colons).
    Vec<u8>,
);
49
50impl<E, R> RefreshFlow<E, R>
51where
52 E: Endpoint<R>,
53 R: WebRequest,
54{
55 pub fn prepare(mut endpoint: E) -> Result<Self, E::Error> {
69 if endpoint.registrar().is_none() {
70 return Err(endpoint.error(OAuthError::PrimitiveError));
71 }
72
73 if endpoint.issuer_mut().is_none() {
74 return Err(endpoint.error(OAuthError::PrimitiveError));
75 }
76
77 Ok(RefreshFlow {
78 endpoint: WrappedRefresh {
79 inner: endpoint,
80 r_type: PhantomData,
81 },
82 })
83 }
84
85 pub fn execute(&mut self, mut request: R) -> Result<R::Response, E::Error> {
92 let refreshed = refresh(&mut self.endpoint, &WrappedRequest::new(&mut request));
93
94 let token = match refreshed {
95 Err(error) => return token_error(&mut self.endpoint.inner, &mut request, error),
96 Ok(token) => token,
97 };
98
99 let mut response = self
100 .endpoint
101 .inner
102 .response(&mut request, InnerTemplate::Ok.into())?;
103 response
104 .body_json(&token.to_json())
105 .map_err(|err| self.endpoint.inner.web_error(err))?;
106 Ok(response)
107 }
108}
109
110fn token_error<E: Endpoint<R>, R: WebRequest>(
111 endpoint: &mut E, request: &mut R, error: Error,
112) -> Result<R::Response, E::Error> {
113 Ok(match error {
114 Error::Invalid(mut json) => {
115 let mut response = endpoint.response(
116 request,
117 InnerTemplate::BadRequest {
118 access_token_error: Some(json.description()),
119 }
120 .into(),
121 )?;
122 response.client_error().map_err(|err| endpoint.web_error(err))?;
123 response
124 .body_json(&json.to_json())
125 .map_err(|err| endpoint.web_error(err))?;
126 response
127 }
128 Error::Unauthorized(mut json, scheme) => {
129 let mut response = endpoint.response(
130 request,
131 InnerTemplate::Unauthorized {
132 error: None,
133 access_token_error: Some(json.description()),
134 }
135 .into(),
136 )?;
137 response
138 .unauthorized(&scheme)
139 .map_err(|err| endpoint.web_error(err))?;
140 response
141 .body_json(&json.to_json())
142 .map_err(|err| endpoint.web_error(err))?;
143 response
144 }
145 Error::Primitive => {
146 return Err(endpoint.error(OAuthError::PrimitiveError));
148 }
149 })
150}
151
152impl<'a, R: WebRequest + 'a> WrappedRequest<'a, R> {
153 pub fn new(request: &'a mut R) -> Self {
154 Self::new_or_fail(request).unwrap_or_else(Self::from_err)
155 }
156
157 fn new_or_fail(request: &'a mut R) -> Result<Self, InitError<R::Error>> {
158 let authorization = match request.authheader() {
160 Err(err) => return Err(InitError::Internal(err)),
161 Ok(Some(header)) => Self::parse_header(header).map(Some)?,
162 Ok(None) => None,
163 };
164
165 Ok(WrappedRequest {
166 request: PhantomData,
167 body: request.urlbody().map_err(InitError::Internal)?,
168 authorization,
169 error: None,
170 })
171 }
172
173 fn from_err(err: InitError<R::Error>) -> Self {
174 WrappedRequest {
175 request: PhantomData,
176 body: Cow::Owned(Default::default()),
177 authorization: None,
178 error: Some(err),
179 }
180 }
181
182 fn parse_header(header: Cow<str>) -> Result<Authorization, InitError<R::Error>> {
183 let authorization = {
184 let auth_data = match is_authorization_method(&header, "Basic ") {
185 None => return Err(InitError::Malformed),
186 Some(data) => data,
187 };
188
189 let combined = match STANDARD.decode(auth_data) {
190 Err(_) => return Err(InitError::Malformed),
191 Ok(vec) => vec,
192 };
193
194 let mut split = combined.splitn(2, |&c| c == b':');
195 let client_bin = match split.next() {
196 None => return Err(InitError::Malformed),
197 Some(client) => client,
198 };
199 let passwd = match split.next() {
200 None => return Err(InitError::Malformed),
201 Some(passwd64) => passwd64,
202 };
203
204 let client = match from_utf8(client_bin) {
205 Err(_) => return Err(InitError::Malformed),
206 Ok(client) => client,
207 };
208
209 Authorization(client.to_string(), passwd.to_vec())
210 };
211
212 Ok(authorization)
213 }
214}
215
216impl<E: Endpoint<R>, R: WebRequest> RefreshEndpoint for WrappedRefresh<E, R> {
217 fn registrar(&self) -> &dyn Registrar {
218 self.inner.registrar().unwrap()
219 }
220
221 fn issuer(&mut self) -> &mut dyn Issuer {
222 self.inner.issuer_mut().unwrap()
223 }
224}
225
226impl<'a, R: WebRequest> Request for WrappedRequest<'a, R> {
227 fn valid(&self) -> bool {
228 self.error.is_none()
229 }
230
231 fn refresh_token(&self) -> Option<Cow<str>> {
232 self.body.unique_value("refresh_token")
233 }
234
235 fn authorization(&self) -> Option<(Cow<str>, Cow<[u8]>)> {
236 self.authorization
237 .as_ref()
238 .map(|auth| (auth.0.as_str().into(), auth.1.as_slice().into()))
239 }
240
241 fn scope(&self) -> Option<Cow<str>> {
242 self.body.unique_value("scope")
243 }
244
245 fn grant_type(&self) -> Option<Cow<str>> {
246 self.body.unique_value("grant_type")
247 }
248
249 fn extension(&self, key: &str) -> Option<Cow<str>> {
250 self.body.unique_value(key)
251 }
252}