1use std::borrow::Cow;
2use std::str::from_utf8;
3use std::marker::PhantomData;
4
5use base64::Engine;
6use base64::engine::general_purpose::STANDARD;
7
8use crate::code_grant::accesstoken::{
9 access_token, Error as TokenError, Extension, Endpoint as TokenEndpoint, Request as TokenRequest,
10 Authorization as TokenAuthorization,
11};
12use crate::primitives::{authorizer::Authorizer, registrar::Registrar, issuer::Issuer};
13use super::{
14 Endpoint, InnerTemplate, OAuthError, QueryParameter, WebRequest, WebResponse,
15 is_authorization_method,
16};
17
18pub struct AccessTokenFlow<E, R>
31where
32 E: Endpoint<R>,
33 R: WebRequest,
34{
35 endpoint: WrappedToken<E, R>,
36 allow_credentials_in_body: bool,
37}
38
39struct WrappedToken<E: Endpoint<R>, R: WebRequest> {
40 inner: E,
41 extension_fallback: (),
42 r_type: PhantomData<R>,
43}
44
45struct WrappedRequest<'a, R: WebRequest + 'a> {
46 request: PhantomData<R>,
48
49 body: Cow<'a, dyn QueryParameter + 'static>,
51
52 authorization: Option<Authorization>,
54
55 error: Option<FailParse<R::Error>>,
57
58 allow_credentials_in_body: bool,
60}
61
/// Marker error: an authorization header could not be decoded as `Basic` credentials.
///
/// Carries no data, so it is cheap to copy and compare (e.g. `assert_eq!(r, Err(Invalid))`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct Invalid;
64
/// Reason a `WrappedRequest` could not be built from a web request.
#[derive(Debug)]
enum FailParse<E> {
    /// The authorization header was present but is not valid `Basic` credentials.
    Invalid,
    /// The underlying request reported an error of its own.
    Err(E),
}
69
/// Credentials from a `Basic` authorization header.
///
/// Holds the client id and the raw password bytes; the password is `None` when the header
/// carried an empty one.
#[derive(Debug, Clone, PartialEq, Eq)]
struct Authorization(String, Option<Vec<u8>>);
72
73impl<E, R> AccessTokenFlow<E, R>
74where
75 E: Endpoint<R>,
76 R: WebRequest,
77{
78 pub fn prepare(mut endpoint: E) -> Result<Self, E::Error> {
89 if endpoint.registrar().is_none() {
90 return Err(endpoint.error(OAuthError::PrimitiveError));
91 }
92
93 if endpoint.authorizer_mut().is_none() {
94 return Err(endpoint.error(OAuthError::PrimitiveError));
95 }
96
97 if endpoint.issuer_mut().is_none() {
98 return Err(endpoint.error(OAuthError::PrimitiveError));
99 }
100
101 Ok(AccessTokenFlow {
102 endpoint: WrappedToken {
103 inner: endpoint,
104 extension_fallback: (),
105 r_type: PhantomData,
106 },
107 allow_credentials_in_body: false,
108 })
109 }
110
111 pub fn allow_credentials_in_body(&mut self, allow: bool) {
119 self.allow_credentials_in_body = allow;
120 }
121
122 pub fn execute(&mut self, mut request: R) -> Result<R::Response, E::Error> {
129 let issued = access_token(
130 &mut self.endpoint,
131 &WrappedRequest::new(&mut request, self.allow_credentials_in_body),
132 );
133
134 let token = match issued {
135 Err(error) => return token_error(&mut self.endpoint.inner, &mut request, error),
136 Ok(token) => token,
137 };
138
139 let mut response = self
140 .endpoint
141 .inner
142 .response(&mut request, InnerTemplate::Ok.into())?;
143 response
144 .body_json(&token.to_json())
145 .map_err(|err| self.endpoint.inner.web_error(err))?;
146 Ok(response)
147 }
148}
149
150fn token_error<E: Endpoint<R>, R: WebRequest>(
151 endpoint: &mut E, request: &mut R, error: TokenError,
152) -> Result<R::Response, E::Error> {
153 Ok(match error {
154 TokenError::Invalid(mut json) => {
155 let mut response = endpoint.response(
156 request,
157 InnerTemplate::BadRequest {
158 access_token_error: Some(json.description()),
159 }
160 .into(),
161 )?;
162 response.client_error().map_err(|err| endpoint.web_error(err))?;
163 response
164 .body_json(&json.to_json())
165 .map_err(|err| endpoint.web_error(err))?;
166 response
167 }
168 TokenError::Unauthorized(mut json, scheme) => {
169 let mut response = endpoint.response(
170 request,
171 InnerTemplate::Unauthorized {
172 error: None,
173 access_token_error: Some(json.description()),
174 }
175 .into(),
176 )?;
177 response
178 .unauthorized(&scheme)
179 .map_err(|err| endpoint.web_error(err))?;
180 response
181 .body_json(&json.to_json())
182 .map_err(|err| endpoint.web_error(err))?;
183 response
184 }
185 TokenError::Primitive(_) => {
186 return Err(endpoint.error(OAuthError::PrimitiveError));
188 }
189 })
190}
191
192impl<E: Endpoint<R>, R: WebRequest> TokenEndpoint for WrappedToken<E, R> {
193 fn registrar(&self) -> &dyn Registrar {
194 self.inner.registrar().unwrap()
195 }
196
197 fn authorizer(&mut self) -> &mut dyn Authorizer {
198 self.inner.authorizer_mut().unwrap()
199 }
200
201 fn issuer(&mut self) -> &mut dyn Issuer {
202 self.inner.issuer_mut().unwrap()
203 }
204
205 fn extension(&mut self) -> &mut dyn Extension {
206 self.inner
207 .extension()
208 .and_then(super::Extension::access_token)
209 .unwrap_or(&mut self.extension_fallback)
210 }
211}
212
213impl<'a, R: WebRequest + 'a> WrappedRequest<'a, R> {
214 pub fn new(request: &'a mut R, credentials: bool) -> Self {
215 Self::new_or_fail(request, credentials).unwrap_or_else(Self::from_err)
216 }
217
218 fn new_or_fail(request: &'a mut R, credentials: bool) -> Result<Self, FailParse<R::Error>> {
219 let authorization = match request.authheader() {
221 Err(err) => return Err(FailParse::Err(err)),
222 Ok(Some(header)) => Self::parse_header(header).map(Some)?,
223 Ok(None) => None,
224 };
225
226 Ok(WrappedRequest {
227 request: PhantomData,
228 body: request.urlbody().map_err(FailParse::Err)?,
229 authorization,
230 error: None,
231 allow_credentials_in_body: credentials,
232 })
233 }
234
235 fn from_err(err: FailParse<R::Error>) -> Self {
236 WrappedRequest {
237 request: PhantomData,
238 body: Cow::Owned(Default::default()),
239 authorization: None,
240 error: Some(err),
241 allow_credentials_in_body: false,
242 }
243 }
244
245 fn parse_header(header: Cow<str>) -> Result<Authorization, Invalid> {
246 let authorization = {
247 let auth_data = match is_authorization_method(&header, "Basic ") {
248 None => return Err(Invalid),
249 Some(data) => data,
250 };
251
252 let combined = match STANDARD.decode(auth_data) {
253 Err(_) => return Err(Invalid),
254 Ok(vec) => vec,
255 };
256
257 let mut split = combined.splitn(2, |&c| c == b':');
258 let client_bin = match split.next() {
259 None => return Err(Invalid),
260 Some(client) => client,
261 };
262 let passwd = match split.next() {
263 None => return Err(Invalid),
264 Some([]) => None,
265 Some(passwd64) => Some(passwd64),
266 };
267
268 let client = match from_utf8(client_bin) {
269 Err(_) => return Err(Invalid),
270 Ok(client) => client,
271 };
272
273 Authorization(client.to_string(), passwd.map(|passwd| passwd.to_vec()))
274 };
275
276 Ok(authorization)
277 }
278}
279
280impl<'a, R: WebRequest> TokenRequest for WrappedRequest<'a, R> {
281 fn valid(&self) -> bool {
282 self.error.is_none()
283 }
284
285 fn code(&self) -> Option<Cow<str>> {
286 self.body.unique_value("code")
287 }
288
289 fn authorization(&self) -> TokenAuthorization {
290 match &self.authorization {
291 None => TokenAuthorization::None,
292 Some(Authorization(username, None)) => TokenAuthorization::Username(username.into()),
293 Some(Authorization(username, Some(password))) => {
294 TokenAuthorization::UsernamePassword(username.into(), password.into())
295 }
296 }
297 }
298
299 fn client_id(&self) -> Option<Cow<str>> {
300 self.body.unique_value("client_id")
301 }
302
303 fn redirect_uri(&self) -> Option<Cow<str>> {
304 self.body.unique_value("redirect_uri")
305 }
306
307 fn grant_type(&self) -> Option<Cow<str>> {
308 self.body.unique_value("grant_type")
309 }
310
311 fn extension(&self, key: &str) -> Option<Cow<str>> {
312 self.body.unique_value(key)
313 }
314
315 fn allow_credentials_in_body(&self) -> bool {
316 self.allow_credentials_in_body
317 }
318}
319
320impl<E> From<Invalid> for FailParse<E> {
321 fn from(_: Invalid) -> Self {
322 FailParse::Invalid
323 }
324}
325
#[cfg(test)]
mod test {
    use super::*;
    use crate::frontends::simple::request::Request;

    /// Runs `parse_header` with `WrappedRequest` instantiated for the simple frontend.
    fn parse(header: &str) -> Result<Authorization, Invalid> {
        WrappedRequest::<Request>::parse_header(header.into())
    }

    #[test]
    fn test_client_id_only() {
        // "foo:" — an empty password is reported as `None`.
        let result = parse("Basic Zm9vOg==");
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), Authorization("foo".into(), None));
    }

    #[test]
    fn test_client_id_and_secret() {
        // "foo:bar"
        let result = parse("Basic Zm9vOmJhcg==");
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), Authorization("foo".into(), Some("bar".into())));
    }

    #[test]
    fn test_secret_may_contain_colon() {
        // "foo:bar:baz" — only the first colon separates id from secret.
        let result = parse("Basic Zm9vOmJhcjpiYXo=");
        assert_eq!(result.unwrap(), Authorization("foo".into(), Some("bar:baz".into())));
    }

    #[test]
    fn test_missing_separator_is_invalid() {
        // "foo" — no colon at all.
        assert!(matches!(parse("Basic Zm9v"), Err(Invalid)));
    }

    #[test]
    fn test_malformed_base64_is_invalid() {
        assert!(matches!(parse("Basic not base64!"), Err(Invalid)));
    }

    #[test]
    fn test_other_scheme_is_invalid() {
        assert!(matches!(parse("Bearer Zm9vOmJhcg=="), Err(Invalid)));
    }
}