// oauth2_client/authorization_code_grant/access_token_endpoint.rs

1use http_api_client_endpoint::{Body, Endpoint, Request, Response};
2use oauth2_core::{
3    access_token_request::{
4        Body as REQ_Body, BodyWithAuthorizationCodeGrant, CONTENT_TYPE as REQ_CONTENT_TYPE,
5        METHOD as REQ_METHOD,
6    },
7    access_token_response::{CONTENT_TYPE as RES_CONTENT_TYPE, GENERAL_ERROR_BODY_KEY_ERROR},
8    authorization_code_grant::access_token_response::{
9        ErrorBody as RES_ErrorBody, SuccessfulBody as RES_SuccessfulBody,
10    },
11    http::{
12        header::{ACCEPT, CONTENT_TYPE},
13        Error as HttpError,
14    },
15    serde::{de::DeserializeOwned, Serialize},
16    types::{Code, CodeVerifier, Scope},
17};
18use serde_json::{Error as SerdeJsonError, Map, Value};
19use serde_urlencoded::ser::Error as SerdeUrlencodedSerError;
20
21use crate::ProviderExtAuthorizationCodeGrant;
22
23//
24//
25//
26#[derive(Clone)]
27pub struct AccessTokenEndpoint<'a, SCOPE>
28where
29    SCOPE: Scope,
30{
31    provider: &'a (dyn ProviderExtAuthorizationCodeGrant<Scope = SCOPE> + Send + Sync),
32    code: Code,
33    pub code_verifier: Option<CodeVerifier>,
34}
35impl<'a, SCOPE> AccessTokenEndpoint<'a, SCOPE>
36where
37    SCOPE: Scope,
38{
39    pub fn new(
40        provider: &'a (dyn ProviderExtAuthorizationCodeGrant<Scope = SCOPE> + Send + Sync),
41        code: Code,
42    ) -> Self {
43        Self {
44            provider,
45            code,
46            code_verifier: None,
47        }
48    }
49
50    pub fn configure<F>(mut self, mut f: F) -> Self
51    where
52        F: FnMut(&mut Self),
53    {
54        f(&mut self);
55        self
56    }
57
58    pub fn set_code_verifier(&mut self, code_verifier: CodeVerifier) {
59        self.code_verifier = Some(code_verifier);
60    }
61}
62
63impl<'a, SCOPE> Endpoint for AccessTokenEndpoint<'a, SCOPE>
64where
65    SCOPE: Scope + Serialize + DeserializeOwned,
66{
67    type RenderRequestError = AccessTokenEndpointError;
68
69    type ParseResponseOutput = Result<RES_SuccessfulBody<SCOPE>, RES_ErrorBody>;
70    type ParseResponseError = AccessTokenEndpointError;
71
72    fn render_request(&self) -> Result<Request<Body>, Self::RenderRequestError> {
73        let mut body = BodyWithAuthorizationCodeGrant::new(
74            self.code.to_owned(),
75            self.provider.redirect_uri().map(|x| x.to_string()),
76            self.provider.client_id().cloned(),
77            self.provider.client_secret().cloned(),
78        );
79
80        if let Some(code_verifier) = &self.code_verifier {
81            body.code_verifier = Some(code_verifier.to_owned());
82        }
83
84        if let Some(extra) = self.provider.access_token_request_body_extra(&body) {
85            body.set_extra(extra);
86        }
87
88        if let Some(request_ret) = self.provider.access_token_request_rendering(&body) {
89            let request = request_ret
90                .map_err(|err| AccessTokenEndpointError::CustomRenderingRequestFailed(err))?;
91
92            return Ok(request);
93        }
94
95        //
96        let body = REQ_Body::<SCOPE>::AuthorizationCodeGrant(body);
97
98        let body_str = serde_urlencoded::to_string(body)
99            .map_err(AccessTokenEndpointError::SerRequestBodyFailed)?;
100
101        let request = Request::builder()
102            .method(REQ_METHOD)
103            .uri(self.provider.token_endpoint_url().as_str())
104            .header(CONTENT_TYPE, REQ_CONTENT_TYPE.to_string())
105            .header(ACCEPT, RES_CONTENT_TYPE.to_string())
106            .body(body_str.as_bytes().to_vec())
107            .map_err(AccessTokenEndpointError::MakeRequestFailed)?;
108
109        Ok(request)
110    }
111
112    fn parse_response(
113        &self,
114        response: Response<Body>,
115    ) -> Result<Self::ParseResponseOutput, Self::ParseResponseError> {
116        if let Some(body_ret_ret) = self.provider.access_token_response_parsing(&response) {
117            let body_ret = body_ret_ret
118                .map_err(|err| AccessTokenEndpointError::CustomParsingResponseFailed(err))?;
119
120            return Ok(body_ret);
121        }
122
123        //
124        if response.status().is_success() {
125            let map = serde_json::from_slice::<Map<String, Value>>(response.body())
126                .map_err(AccessTokenEndpointError::DeResponseBodyFailed)?;
127            if !map.contains_key(GENERAL_ERROR_BODY_KEY_ERROR) {
128                let body = serde_json::from_slice::<RES_SuccessfulBody<SCOPE>>(response.body())
129                    .map_err(AccessTokenEndpointError::DeResponseBodyFailed)?;
130
131                return Ok(Ok(body));
132            }
133        }
134
135        let body = serde_json::from_slice::<RES_ErrorBody>(response.body())
136            .map_err(AccessTokenEndpointError::DeResponseBodyFailed)?;
137        Ok(Err(body))
138    }
139}
140
141#[derive(thiserror::Error, Debug)]
142pub enum AccessTokenEndpointError {
143    #[error("CustomRenderingRequestFailed {0}")]
144    CustomRenderingRequestFailed(Box<dyn std::error::Error + Send + Sync>),
145    //
146    #[error("SerRequestBodyFailed {0}")]
147    SerRequestBodyFailed(SerdeUrlencodedSerError),
148    #[error("MakeRequestFailed {0}")]
149    MakeRequestFailed(HttpError),
150    //
151    #[error("CustomParsingResponseFailed {0}")]
152    CustomParsingResponseFailed(Box<dyn std::error::Error + Send + Sync>),
153    //
154    #[error("DeResponseBodyFailed {0}")]
155    DeResponseBodyFailed(SerdeJsonError),
156}