oauth2_client/authorization_code_grant/
access_token_endpoint.rs1use http_api_client_endpoint::{Body, Endpoint, Request, Response};
2use oauth2_core::{
3 access_token_request::{
4 Body as REQ_Body, BodyWithAuthorizationCodeGrant, CONTENT_TYPE as REQ_CONTENT_TYPE,
5 METHOD as REQ_METHOD,
6 },
7 access_token_response::{CONTENT_TYPE as RES_CONTENT_TYPE, GENERAL_ERROR_BODY_KEY_ERROR},
8 authorization_code_grant::access_token_response::{
9 ErrorBody as RES_ErrorBody, SuccessfulBody as RES_SuccessfulBody,
10 },
11 http::{
12 header::{ACCEPT, CONTENT_TYPE},
13 Error as HttpError,
14 },
15 serde::{de::DeserializeOwned, Serialize},
16 types::{Code, CodeVerifier, Scope},
17};
18use serde_json::{Error as SerdeJsonError, Map, Value};
19use serde_urlencoded::ser::Error as SerdeUrlencodedSerError;
20
21use crate::ProviderExtAuthorizationCodeGrant;
22
23#[derive(Clone)]
27pub struct AccessTokenEndpoint<'a, SCOPE>
28where
29 SCOPE: Scope,
30{
31 provider: &'a (dyn ProviderExtAuthorizationCodeGrant<Scope = SCOPE> + Send + Sync),
32 code: Code,
33 pub code_verifier: Option<CodeVerifier>,
34}
35impl<'a, SCOPE> AccessTokenEndpoint<'a, SCOPE>
36where
37 SCOPE: Scope,
38{
39 pub fn new(
40 provider: &'a (dyn ProviderExtAuthorizationCodeGrant<Scope = SCOPE> + Send + Sync),
41 code: Code,
42 ) -> Self {
43 Self {
44 provider,
45 code,
46 code_verifier: None,
47 }
48 }
49
50 pub fn configure<F>(mut self, mut f: F) -> Self
51 where
52 F: FnMut(&mut Self),
53 {
54 f(&mut self);
55 self
56 }
57
58 pub fn set_code_verifier(&mut self, code_verifier: CodeVerifier) {
59 self.code_verifier = Some(code_verifier);
60 }
61}
62
63impl<'a, SCOPE> Endpoint for AccessTokenEndpoint<'a, SCOPE>
64where
65 SCOPE: Scope + Serialize + DeserializeOwned,
66{
67 type RenderRequestError = AccessTokenEndpointError;
68
69 type ParseResponseOutput = Result<RES_SuccessfulBody<SCOPE>, RES_ErrorBody>;
70 type ParseResponseError = AccessTokenEndpointError;
71
72 fn render_request(&self) -> Result<Request<Body>, Self::RenderRequestError> {
73 let mut body = BodyWithAuthorizationCodeGrant::new(
74 self.code.to_owned(),
75 self.provider.redirect_uri().map(|x| x.to_string()),
76 self.provider.client_id().cloned(),
77 self.provider.client_secret().cloned(),
78 );
79
80 if let Some(code_verifier) = &self.code_verifier {
81 body.code_verifier = Some(code_verifier.to_owned());
82 }
83
84 if let Some(extra) = self.provider.access_token_request_body_extra(&body) {
85 body.set_extra(extra);
86 }
87
88 if let Some(request_ret) = self.provider.access_token_request_rendering(&body) {
89 let request = request_ret
90 .map_err(|err| AccessTokenEndpointError::CustomRenderingRequestFailed(err))?;
91
92 return Ok(request);
93 }
94
95 let body = REQ_Body::<SCOPE>::AuthorizationCodeGrant(body);
97
98 let body_str = serde_urlencoded::to_string(body)
99 .map_err(AccessTokenEndpointError::SerRequestBodyFailed)?;
100
101 let request = Request::builder()
102 .method(REQ_METHOD)
103 .uri(self.provider.token_endpoint_url().as_str())
104 .header(CONTENT_TYPE, REQ_CONTENT_TYPE.to_string())
105 .header(ACCEPT, RES_CONTENT_TYPE.to_string())
106 .body(body_str.as_bytes().to_vec())
107 .map_err(AccessTokenEndpointError::MakeRequestFailed)?;
108
109 Ok(request)
110 }
111
112 fn parse_response(
113 &self,
114 response: Response<Body>,
115 ) -> Result<Self::ParseResponseOutput, Self::ParseResponseError> {
116 if let Some(body_ret_ret) = self.provider.access_token_response_parsing(&response) {
117 let body_ret = body_ret_ret
118 .map_err(|err| AccessTokenEndpointError::CustomParsingResponseFailed(err))?;
119
120 return Ok(body_ret);
121 }
122
123 if response.status().is_success() {
125 let map = serde_json::from_slice::<Map<String, Value>>(response.body())
126 .map_err(AccessTokenEndpointError::DeResponseBodyFailed)?;
127 if !map.contains_key(GENERAL_ERROR_BODY_KEY_ERROR) {
128 let body = serde_json::from_slice::<RES_SuccessfulBody<SCOPE>>(response.body())
129 .map_err(AccessTokenEndpointError::DeResponseBodyFailed)?;
130
131 return Ok(Ok(body));
132 }
133 }
134
135 let body = serde_json::from_slice::<RES_ErrorBody>(response.body())
136 .map_err(AccessTokenEndpointError::DeResponseBodyFailed)?;
137 Ok(Err(body))
138 }
139}
140
141#[derive(thiserror::Error, Debug)]
142pub enum AccessTokenEndpointError {
143 #[error("CustomRenderingRequestFailed {0}")]
144 CustomRenderingRequestFailed(Box<dyn std::error::Error + Send + Sync>),
145 #[error("SerRequestBodyFailed {0}")]
147 SerRequestBodyFailed(SerdeUrlencodedSerError),
148 #[error("MakeRequestFailed {0}")]
149 MakeRequestFailed(HttpError),
150 #[error("CustomParsingResponseFailed {0}")]
152 CustomParsingResponseFailed(Box<dyn std::error::Error + Send + Sync>),
153 #[error("DeResponseBodyFailed {0}")]
155 DeResponseBodyFailed(SerdeJsonError),
156}