// oauth2_client/authorization_code_grant/authorization_endpoint.rs

1use core::convert::Infallible;
2
3use http_api_client_endpoint::{Body, Endpoint, Request, Response};
4use oauth2_core::{
5    access_token_response::GENERAL_ERROR_BODY_KEY_ERROR,
6    authorization_code_grant::{
7        authorization_request::{Query as REQ_Query, METHOD as REQ_METHOD},
8        authorization_response::{
9            ErrorQuery as RES_ErrorQuery, SuccessfulQuery as RES_SuccessfulQuery,
10        },
11    },
12    http::Error as HttpError,
13    serde::Serialize,
14    types::{CodeChallenge, CodeChallengeMethod, Nonce, Scope, State},
15};
16use serde_json::{Map, Value};
17use serde_qs::Error as SerdeQsError;
18
19use crate::ProviderExtAuthorizationCodeGrant;
20
21//
22//
23//
24#[derive(Clone)]
25pub struct AuthorizationEndpoint<'a, SCOPE>
26where
27    SCOPE: Scope,
28{
29    provider: &'a dyn ProviderExtAuthorizationCodeGrant<Scope = SCOPE>,
30    scopes: Option<Vec<SCOPE>>,
31    pub state: Option<State>,
32    pub code_challenge: Option<(CodeChallenge, CodeChallengeMethod)>,
33    pub nonce: Option<Nonce>,
34}
35impl<'a, SCOPE> AuthorizationEndpoint<'a, SCOPE>
36where
37    SCOPE: Scope,
38{
39    pub fn new(
40        provider: &'a dyn ProviderExtAuthorizationCodeGrant<Scope = SCOPE>,
41        scopes: impl Into<Option<Vec<SCOPE>>>,
42    ) -> Self {
43        Self {
44            provider,
45            scopes: scopes.into(),
46            state: None,
47            code_challenge: None,
48            nonce: None,
49        }
50    }
51
52    pub fn configure<F>(mut self, mut f: F) -> Self
53    where
54        F: FnMut(&mut Self),
55    {
56        f(&mut self);
57        self
58    }
59
60    pub fn set_state(&mut self, state: State) {
61        self.state = Some(state);
62    }
63
64    pub fn set_code_challenge(
65        &mut self,
66        code_challenge: CodeChallenge,
67        code_challenge_method: CodeChallengeMethod,
68    ) {
69        self.code_challenge = Some((code_challenge, code_challenge_method));
70    }
71
72    pub fn set_nonce(&mut self, nonce: Nonce) {
73        self.nonce = Some(nonce);
74    }
75}
76
77impl<'a, SCOPE> Endpoint for AuthorizationEndpoint<'a, SCOPE>
78where
79    SCOPE: Scope + Serialize,
80{
81    type RenderRequestError = AuthorizationEndpointError;
82
83    type ParseResponseOutput = ();
84    type ParseResponseError = Infallible;
85
86    fn render_request(&self) -> Result<Request<Body>, Self::RenderRequestError> {
87        let mut query = REQ_Query::new(
88            self.provider
89                .client_id()
90                .cloned()
91                .ok_or(AuthorizationEndpointError::ClientIdMissing)?,
92            self.provider.redirect_uri().map(|x| x.to_string()),
93            self.scopes.to_owned().map(Into::into),
94            self.state.to_owned(),
95        );
96        if let Some((code_challenge, code_challenge_method)) = &self.code_challenge {
97            query.code_challenge = Some(code_challenge.to_owned());
98            query.code_challenge_method = Some(code_challenge_method.to_owned());
99        }
100        query.nonce = self.nonce.to_owned();
101
102        if let Some(extra) = self.provider.authorization_request_query_extra() {
103            query.set_extra(extra);
104        }
105
106        let query_str = if let Some(query_str_ret) = self
107            .provider
108            .authorization_request_query_serializing(&query)
109        {
110            query_str_ret
111                .map_err(|err| AuthorizationEndpointError::CustomSerRequestQueryFailed(err))?
112        } else {
113            serde_qs::to_string(&query)
114                .map_err(AuthorizationEndpointError::SerRequestQueryFailed)?
115        };
116
117        let mut url = self.provider.authorization_endpoint_url().to_owned();
118        url.set_query(Some(query_str.as_str()));
119
120        //
121        self.provider.authorization_request_url_modifying(&mut url);
122
123        //
124        let request = Request::builder()
125            .method(REQ_METHOD)
126            .uri(url.as_str())
127            .body(vec![])
128            .map_err(AuthorizationEndpointError::MakeRequestFailed)?;
129
130        Ok(request)
131    }
132
133    fn parse_response(
134        &self,
135        _response: Response<Body>,
136    ) -> Result<Self::ParseResponseOutput, Self::ParseResponseError> {
137        unreachable!()
138    }
139}
140
141#[derive(thiserror::Error, Debug)]
142pub enum AuthorizationEndpointError {
143    #[error("ClientIdMissing")]
144    ClientIdMissing,
145    //
146    #[error("CustomSerRequestQueryFailed {0}")]
147    CustomSerRequestQueryFailed(Box<dyn std::error::Error + Send + Sync>),
148    //
149    #[error("SerRequestQueryFailed {0}")]
150    SerRequestQueryFailed(SerdeQsError),
151    #[error("MakeRequestFailed {0}")]
152    MakeRequestFailed(HttpError),
153}
154
155//
156//
157//
158pub fn parse_redirect_uri_query(
159    query_str: impl AsRef<str>,
160) -> Result<Result<RES_SuccessfulQuery, RES_ErrorQuery>, ParseRedirectUriQueryError> {
161    let map = serde_qs::from_str::<Map<String, Value>>(query_str.as_ref())?;
162    if !map.contains_key(GENERAL_ERROR_BODY_KEY_ERROR) {
163        let query = serde_qs::from_str::<RES_SuccessfulQuery>(query_str.as_ref())?;
164
165        return Ok(Ok(query));
166    }
167
168    let query = serde_qs::from_str::<RES_ErrorQuery>(query_str.as_ref())?;
169
170    Ok(Err(query))
171}
172
173pub type ParseRedirectUriQueryError = SerdeQsError;