oxide_auth_async/endpoint/
refresh.rs1use std::{borrow::Cow, marker::PhantomData, str::from_utf8};
2
3use base64::{engine::general_purpose::STANDARD, Engine};
4use oxide_auth::{
5 code_grant::refresh::{Error, Request},
6 endpoint::{WebRequest, WebResponse, OAuthError, QueryParameter, Template, NormalizedParameter},
7};
8
9use super::Endpoint;
10use crate::{
11 code_grant::refresh::{refresh, Endpoint as RefreshEndpoint},
12 primitives::{Issuer, Registrar},
13};
14
15pub struct RefreshFlow<E, R>
17where
18 E: Endpoint<R>,
19 R: WebRequest,
20{
21 endpoint: WrappedRefresh<E, R>,
22}
23
24struct WrappedRefresh<E, R>
25where
26 E: Endpoint<R>,
27 R: WebRequest,
28{
29 inner: E,
30 r_type: PhantomData<R>,
31}
32
33struct WrappedRequest<R: WebRequest> {
34 body: NormalizedParameter,
36
37 authorization: Option<Authorization>,
39
40 error: Option<Option<R::Error>>,
42}
43
/// Credentials decoded from an HTTP `Basic` authorization header.
struct Authorization(
    /// Client identifier: the UTF-8 text before the first `:`.
    String,
    /// Raw password bytes: everything after the first `:`.
    Vec<u8>,
);
45
46impl<E, R> RefreshFlow<E, R>
47where
48 E: Endpoint<R> + Send + Sync,
49 R: WebRequest + Send + Sync,
50 <R as WebRequest>::Error: Send + Sync,
51{
52 pub fn prepare(mut endpoint: E) -> Result<Self, E::Error> {
66 if endpoint.registrar().is_none() {
67 return Err(endpoint.error(OAuthError::PrimitiveError));
68 }
69
70 if endpoint.issuer_mut().is_none() {
71 return Err(endpoint.error(OAuthError::PrimitiveError));
72 }
73
74 Ok(RefreshFlow {
75 endpoint: WrappedRefresh {
76 inner: endpoint,
77 r_type: PhantomData,
78 },
79 })
80 }
81
82 pub async fn execute(&mut self, mut request: R) -> Result<R::Response, E::Error> {
83 let refreshed = refresh(&mut self.endpoint, &WrappedRequest::new(&mut request)).await;
84
85 let token = match refreshed {
86 Err(error) => return token_error(&mut self.endpoint.inner, &mut request, error),
87 Ok(token) => token,
88 };
89
90 let mut response = self.endpoint.inner.response(&mut request, Template::new_ok())?;
91 response
92 .body_json(&token.to_json())
93 .map_err(|err| self.endpoint.inner.web_error(err))?;
94 Ok(response)
95 }
96}
97
98fn token_error<E, R>(endpoint: &mut E, request: &mut R, error: Error) -> Result<R::Response, E::Error>
99where
100 E: Endpoint<R>,
101 R: WebRequest,
102{
103 Ok(match error {
104 Error::Invalid(mut json) => {
105 let mut response =
106 endpoint.response(request, Template::new_bad(Some(json.description())))?;
107 response.client_error().map_err(|err| endpoint.web_error(err))?;
108 response
109 .body_json(&json.to_json())
110 .map_err(|err| endpoint.web_error(err))?;
111 response
112 }
113 Error::Unauthorized(mut json, scheme) => {
114 let mut response = endpoint.response(
115 request,
116 Template::new_unauthorized(None, Some(json.description())),
117 )?;
118 response
119 .unauthorized(&scheme)
120 .map_err(|err| endpoint.web_error(err))?;
121 response
122 .body_json(&json.to_json())
123 .map_err(|err| endpoint.web_error(err))?;
124 response
125 }
126 Error::Primitive => {
127 return Err(endpoint.error(OAuthError::PrimitiveError));
129 }
130 })
131}
132
133impl<'a, R: WebRequest> WrappedRequest<R> {
134 pub fn new(request: &'a mut R) -> Self {
135 Self::new_or_fail(request).unwrap_or_else(Self::from_err)
136 }
137
138 fn new_or_fail(request: &'a mut R) -> Result<Self, Option<R::Error>> {
139 let authorization = match request.authheader() {
141 Err(err) => return Err(Some(err)),
142 Ok(Some(header)) => Self::parse_header(header).map(Some)?,
143 Ok(None) => None,
144 };
145
146 Ok(WrappedRequest {
147 body: request.urlbody()?.into_owned(),
148 authorization,
149 error: None,
150 })
151 }
152
153 fn from_err(err: Option<R::Error>) -> Self {
154 WrappedRequest {
155 body: Default::default(),
156 authorization: None,
157 error: Some(err),
158 }
159 }
160
161 fn parse_header(header: Cow<str>) -> Result<Authorization, Option<R::Error>> {
162 let authorization = {
163 if !header.starts_with("Basic ") {
164 return Err(None);
165 }
166
167 let combined = match STANDARD.decode(&header[6..]) {
168 Err(_) => return Err(None),
169 Ok(vec) => vec,
170 };
171
172 let mut split = combined.splitn(2, |&c| c == b':');
173 let client_bin = match split.next() {
174 None => return Err(None),
175 Some(client) => client,
176 };
177 let passwd = match split.next() {
178 None => return Err(None),
179 Some(passwd64) => passwd64,
180 };
181
182 let client = match from_utf8(client_bin) {
183 Err(_) => return Err(None),
184 Ok(client) => client,
185 };
186
187 Authorization(client.to_string(), passwd.to_vec())
188 };
189
190 Ok(authorization)
191 }
192}
193
194impl<E, R> RefreshEndpoint for WrappedRefresh<E, R>
195where
196 E: Endpoint<R>,
197 R: WebRequest,
198{
199 fn registrar(&self) -> &(dyn Registrar + Sync) {
200 self.inner.registrar().unwrap()
201 }
202
203 fn issuer(&mut self) -> &mut (dyn Issuer + Send) {
204 self.inner.issuer_mut().unwrap()
205 }
206}
207
208impl<R: WebRequest> Request for WrappedRequest<R> {
209 fn valid(&self) -> bool {
210 self.error.is_none()
211 }
212
213 fn refresh_token(&self) -> Option<Cow<str>> {
214 self.body.unique_value("refresh_token")
215 }
216
217 fn authorization(&self) -> Option<(Cow<str>, Cow<[u8]>)> {
218 self.authorization
219 .as_ref()
220 .map(|auth| (auth.0.as_str().into(), auth.1.as_slice().into()))
221 }
222
223 fn scope(&self) -> Option<Cow<str>> {
224 self.body.unique_value("scope")
225 }
226
227 fn grant_type(&self) -> Option<Cow<str>> {
228 self.body.unique_value("grant_type")
229 }
230
231 fn extension(&self, key: &str) -> Option<Cow<str>> {
232 self.body.unique_value(key)
233 }
234}