// oauth2_client/jwt_authorization_grant/access_token_endpoint.rs

use http_api_client_endpoint::{Body, Endpoint, Request, Response};
use oauth2_core::{
    access_token_request::{
        Body as REQ_Body, BodyWithJwtAuthorizationGrant, CONTENT_TYPE as REQ_CONTENT_TYPE,
        METHOD as REQ_METHOD,
    },
    access_token_response::{CONTENT_TYPE as RES_CONTENT_TYPE, GENERAL_ERROR_BODY_KEY_ERROR},
    http::{
        header::{ACCEPT, CONTENT_TYPE},
        Error as HttpError,
    },
    jwt_authorization_grant::access_token_response::{
        ErrorBody as RES_ErrorBody, SuccessfulBody as RES_SuccessfulBody,
    },
    serde::{de::DeserializeOwned, Serialize},
    types::Scope,
};
use serde_json::{Error as SerdeJsonError, Map, Value};
use serde_urlencoded::ser::Error as SerdeUrlencodedSerError;

use crate::ProviderExtJwtAuthorizationGrant;

23//
24//
25//
26#[derive(Clone)]
27pub struct AccessTokenEndpoint<'a, SCOPE>
28where
29    SCOPE: Scope,
30{
31    provider: &'a (dyn ProviderExtJwtAuthorizationGrant<Scope = SCOPE> + Send + Sync),
32    scopes: Option<Vec<SCOPE>>,
33}
34impl<'a, SCOPE> AccessTokenEndpoint<'a, SCOPE>
35where
36    SCOPE: Scope,
37{
38    pub fn new(
39        provider: &'a (dyn ProviderExtJwtAuthorizationGrant<Scope = SCOPE> + Send + Sync),
40        scopes: impl Into<Option<Vec<SCOPE>>>,
41    ) -> Self {
42        Self {
43            provider,
44            scopes: scopes.into(),
45        }
46    }
47}
48
49impl<'a, SCOPE> Endpoint for AccessTokenEndpoint<'a, SCOPE>
50where
51    SCOPE: Scope + Serialize + DeserializeOwned,
52{
53    type RenderRequestError = AccessTokenEndpointError;
54
55    type ParseResponseOutput = Result<RES_SuccessfulBody<SCOPE>, RES_ErrorBody>;
56    type ParseResponseError = AccessTokenEndpointError;
57
58    fn render_request(&self) -> Result<Request<Body>, Self::RenderRequestError> {
59        //
60        let mut url = self.provider.token_endpoint_url().to_owned();
61
62        //
63        self.provider.access_token_request_url_modifying(&mut url);
64
65        //
66        let mut body = BodyWithJwtAuthorizationGrant::new(
67            self.provider.assertion().to_owned(),
68            self.scopes.to_owned().map(Into::into),
69            self.provider.client_id().map(Into::into),
70        );
71        if let Some(extra_ret) = self.provider.access_token_request_body_extra(&body) {
72            match extra_ret {
73                Ok(extra) => {
74                    body.set_extra(extra);
75                }
76                Err(err) => {
77                    return Err(AccessTokenEndpointError::MakeRequestBodyExtraFailed(err));
78                }
79            }
80        }
81
82        //
83        let body = REQ_Body::<SCOPE>::JwtAuthorizationGrant(body);
84
85        let body_str = serde_urlencoded::to_string(body)
86            .map_err(AccessTokenEndpointError::SerRequestBodyFailed)?;
87
88        //
89        let mut request = Request::builder()
90            .method(REQ_METHOD)
91            .uri(url.as_str())
92            .header(CONTENT_TYPE, REQ_CONTENT_TYPE.to_string())
93            .header(ACCEPT, RES_CONTENT_TYPE.to_string())
94            .body(body_str.as_bytes().to_vec())
95            .map_err(AccessTokenEndpointError::MakeRequestFailed)?;
96
97        //
98        self.provider.access_token_request_modifying(&mut request);
99
100        Ok(request)
101    }
102
103    fn parse_response(
104        &self,
105        response: Response<Body>,
106    ) -> Result<Self::ParseResponseOutput, Self::ParseResponseError> {
107        //
108        if response.status().is_success() {
109            let map = serde_json::from_slice::<Map<String, Value>>(response.body())
110                .map_err(AccessTokenEndpointError::DeResponseBodyFailed)?;
111            if !map.contains_key(GENERAL_ERROR_BODY_KEY_ERROR) {
112                let body = serde_json::from_slice::<RES_SuccessfulBody<SCOPE>>(response.body())
113                    .map_err(AccessTokenEndpointError::DeResponseBodyFailed)?;
114
115                return Ok(Ok(body));
116            }
117        }
118
119        let body = serde_json::from_slice::<RES_ErrorBody>(response.body())
120            .map_err(AccessTokenEndpointError::DeResponseBodyFailed)?;
121        Ok(Err(body))
122    }
123}
124
125#[derive(thiserror::Error, Debug)]
126pub enum AccessTokenEndpointError {
127    #[error("MakeRequestBodyExtraFailed {0}")]
128    MakeRequestBodyExtraFailed(Box<dyn std::error::Error + Send + Sync>),
129    //
130    #[error("SerRequestBodyFailed {0}")]
131    SerRequestBodyFailed(SerdeUrlencodedSerError),
132    #[error("MakeRequestFailed {0}")]
133    MakeRequestFailed(HttpError),
134    //
135    #[error("DeResponseBodyFailed {0}")]
136    DeResponseBodyFailed(SerdeJsonError),
137}