raos/token/
request.rs

1use base64::{prelude::BASE64_STANDARD, Engine};
2
3use crate::common::{
4    frontend::{FrontendRequest, FrontendRequestMethod, OAuthValidationError},
5    syntax::{ValidateSyntax, CLIENT_ID_SYNTAX},
6    util::NoneIfEmpty,
7};
8
9/// A parsed request to exchange an authorization code, refresh code or client credentials for an access token.
10#[derive(Debug)]
11pub struct TokenRequest {
12    /// The client ID.
13    pub client_id: String,
14    /// The client secret.
15    pub client_secret: Option<String>,
16    /// The type of grant requested by the client.
17    pub grant_type: RequestedGrantType,
18    /// The redirect_uri that is repeated in the token request, for compatibility with OAuth 2.0.
19    pub redirect_uri: Option<String>,
20    /// The requested scope, used when refreshing a token using the refresh token grant type.
21    pub scope: Option<Vec<String>>,
22}
23
/// The grant a client asked for in a token request, carrying the grant-specific parameters.
#[derive(Debug)]
pub enum RequestedGrantType {
    /// An access token is requested on the strength of the client's own credentials.
    ClientCredentials,
    /// An access token is requested in exchange for an authorization code.
    AuthorizationCode {
        /// The authorization code being exchanged.
        code: String,
        /// The PKCE code verifier that answers the earlier code challenge.
        code_verifier: String,
    },
    /// An access token is requested in exchange for a refresh token.
    RefreshToken {
        /// The refresh token being exchanged.
        refresh_token: String,
    },
}
42
43impl TryFrom<&dyn FrontendRequest> for TokenRequest {
44    type Error = OAuthValidationError;
45
46    fn try_from(request: &dyn FrontendRequest) -> Result<Self, Self::Error> {
47        if !matches!(request.request_method(), FrontendRequestMethod::POST) {
48            return Err(OAuthValidationError::InvalidRequestMethod {
49                expected: FrontendRequestMethod::POST,
50                actual: request.request_method(),
51            });
52        }
53
54        // We should treat empty values as if they were omitted from the request
55        let body_param = |key| request.body_param(key).none_if_empty();
56
57        let header_credentials = get_credentials_from_header(request)?;
58
59        let (mut client_id, mut client_secret) = header_credentials.unzip();
60        if let Some(body_client_id) = body_param("client_id") {
61            if let Some(header_client_id) = client_id {
62                if body_client_id != header_client_id {
63                    return Err(OAuthValidationError::MismatchedClientCredentials);
64                }
65            }
66
67            client_id = Some(body_client_id);
68        };
69        if let Some(body_client_secret) = body_param("client_secret") {
70            if client_secret.is_none() {
71                client_secret = Some(body_client_secret);
72            }
73        };
74
75        let Some(client_id) = client_id else {
76            return Err(OAuthValidationError::MissingRequiredParameter("client_id"));
77        };
78        client_id.validate_syntax("client_id", &CLIENT_ID_SYNTAX)?;
79
80        let Some(grant_type_str) = body_param("grant_type") else {
81            return Err(OAuthValidationError::MissingRequiredParameter("grant_type"));
82        };
83
84        let grant_type = match grant_type_str.as_str() {
85            "client_credentials" => RequestedGrantType::ClientCredentials,
86            "authorization_code" => {
87                let code = match body_param("code") {
88                    Some(code) => code,
89                    None => return Err(OAuthValidationError::MissingRequiredParameter("code")),
90                };
91                let code_verifier = match body_param("code_verifier") {
92                    Some(code_verifier) => code_verifier,
93                    None => {
94                        return Err(OAuthValidationError::MissingRequiredParameter("code_verifier"))
95                    }
96                };
97                RequestedGrantType::AuthorizationCode { code, code_verifier }
98            }
99            "refresh_token" => {
100                let refresh_token = match body_param("refresh_token") {
101                    Some(refresh_token) => refresh_token,
102                    None => {
103                        return Err(OAuthValidationError::MissingRequiredParameter("refresh_token"))
104                    }
105                };
106                RequestedGrantType::RefreshToken { refresh_token }
107            }
108            _ => {
109                return Err(OAuthValidationError::InvalidGrantType {
110                    requested: grant_type_str.to_string(),
111                });
112            }
113        };
114
115        let scope = body_param("scope").map(|s| s.split(" ").map(str::to_string).collect());
116
117        Ok(TokenRequest {
118            client_id,
119            client_secret,
120            grant_type,
121            scope,
122            redirect_uri: body_param("redirect_uri"),
123        })
124    }
125}
126
127fn get_credentials_from_header(
128    request: &dyn FrontendRequest,
129) -> Result<Option<(String, String)>, OAuthValidationError> {
130    static MASKED: &str = "<masked>";
131
132    if let Some(authorization_header) = request.header_param("authorization") {
133        // Parse the authorization header
134        let parts: Vec<&str> = authorization_header.split_whitespace().collect();
135        if parts.len() != 2 {
136            return Err(OAuthValidationError::InvalidParameterValue(
137                "authorization",
138                MASKED.to_string(),
139            ));
140        }
141        if parts[0].to_lowercase() != "Basic" {
142            return Err(OAuthValidationError::InvalidParameterValue(
143                "authorization",
144                MASKED.to_string(),
145            ));
146        }
147        let decoded = BASE64_STANDARD.decode(parts[1]).map_err(|_| {
148            OAuthValidationError::InvalidParameterValue("authorization", MASKED.to_string())
149        })?;
150        let decoded_str = std::str::from_utf8(&decoded).map_err(|_| {
151            OAuthValidationError::InvalidParameterValue("authorization", MASKED.to_string())
152        })?;
153        let parts: Vec<&str> = decoded_str.split(':').collect();
154        if parts.len() != 2 {
155            return Err(OAuthValidationError::InvalidParameterValue(
156                "authorization",
157                MASKED.to_string(),
158            ));
159        }
160        return Ok(Some((parts[0].to_string(), parts[1].to_string())));
161    }
162
163    Ok(None)
164}