1use base64::{prelude::BASE64_STANDARD, Engine};
2
3use crate::common::{
4 frontend::{FrontendRequest, FrontendRequestMethod, OAuthValidationError},
5 syntax::{ValidateSyntax, CLIENT_ID_SYNTAX},
6 util::NoneIfEmpty,
7};
8
9#[derive(Debug)]
11pub struct TokenRequest {
12 pub client_id: String,
14 pub client_secret: Option<String>,
16 pub grant_type: RequestedGrantType,
18 pub redirect_uri: Option<String>,
20 pub scope: Option<Vec<String>>,
22}
23
/// The grant requested at the token endpoint, with the parameters that grant requires.
///
/// `Debug` is implemented by hand so that credential-bearing values (authorization code,
/// code verifier, refresh token) are masked and never end up in log output.
pub enum RequestedGrantType {
    /// `grant_type=client_credentials`; carries no extra parameters.
    ClientCredentials,
    /// `grant_type=authorization_code`.
    AuthorizationCode {
        /// The `code` body parameter.
        code: String,
        /// The `code_verifier` body parameter (PKCE).
        code_verifier: String,
    },
    /// `grant_type=refresh_token`.
    RefreshToken {
        /// The `refresh_token` body parameter.
        refresh_token: String,
    },
}

impl std::fmt::Debug for RequestedGrantType {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const MASKED: &str = "<masked>";
        match self {
            Self::ClientCredentials => f.write_str("ClientCredentials"),
            Self::AuthorizationCode { .. } => f
                .debug_struct("AuthorizationCode")
                .field("code", &MASKED)
                .field("code_verifier", &MASKED)
                .finish(),
            Self::RefreshToken { .. } => f
                .debug_struct("RefreshToken")
                .field("refresh_token", &MASKED)
                .finish(),
        }
    }
}
42
43impl TryFrom<&dyn FrontendRequest> for TokenRequest {
44 type Error = OAuthValidationError;
45
46 fn try_from(request: &dyn FrontendRequest) -> Result<Self, Self::Error> {
47 if !matches!(request.request_method(), FrontendRequestMethod::POST) {
48 return Err(OAuthValidationError::InvalidRequestMethod {
49 expected: FrontendRequestMethod::POST,
50 actual: request.request_method(),
51 });
52 }
53
54 let body_param = |key| request.body_param(key).none_if_empty();
56
57 let header_credentials = get_credentials_from_header(request)?;
58
59 let (mut client_id, mut client_secret) = header_credentials.unzip();
60 if let Some(body_client_id) = body_param("client_id") {
61 if let Some(header_client_id) = client_id {
62 if body_client_id != header_client_id {
63 return Err(OAuthValidationError::MismatchedClientCredentials);
64 }
65 }
66
67 client_id = Some(body_client_id);
68 };
69 if let Some(body_client_secret) = body_param("client_secret") {
70 if client_secret.is_none() {
71 client_secret = Some(body_client_secret);
72 }
73 };
74
75 let Some(client_id) = client_id else {
76 return Err(OAuthValidationError::MissingRequiredParameter("client_id"));
77 };
78 client_id.validate_syntax("client_id", &CLIENT_ID_SYNTAX)?;
79
80 let Some(grant_type_str) = body_param("grant_type") else {
81 return Err(OAuthValidationError::MissingRequiredParameter("grant_type"));
82 };
83
84 let grant_type = match grant_type_str.as_str() {
85 "client_credentials" => RequestedGrantType::ClientCredentials,
86 "authorization_code" => {
87 let code = match body_param("code") {
88 Some(code) => code,
89 None => return Err(OAuthValidationError::MissingRequiredParameter("code")),
90 };
91 let code_verifier = match body_param("code_verifier") {
92 Some(code_verifier) => code_verifier,
93 None => {
94 return Err(OAuthValidationError::MissingRequiredParameter("code_verifier"))
95 }
96 };
97 RequestedGrantType::AuthorizationCode { code, code_verifier }
98 }
99 "refresh_token" => {
100 let refresh_token = match body_param("refresh_token") {
101 Some(refresh_token) => refresh_token,
102 None => {
103 return Err(OAuthValidationError::MissingRequiredParameter("refresh_token"))
104 }
105 };
106 RequestedGrantType::RefreshToken { refresh_token }
107 }
108 _ => {
109 return Err(OAuthValidationError::InvalidGrantType {
110 requested: grant_type_str.to_string(),
111 });
112 }
113 };
114
115 let scope = body_param("scope").map(|s| s.split(" ").map(str::to_string).collect());
116
117 Ok(TokenRequest {
118 client_id,
119 client_secret,
120 grant_type,
121 scope,
122 redirect_uri: body_param("redirect_uri"),
123 })
124 }
125}
126
127fn get_credentials_from_header(
128 request: &dyn FrontendRequest,
129) -> Result<Option<(String, String)>, OAuthValidationError> {
130 static MASKED: &str = "<masked>";
131
132 if let Some(authorization_header) = request.header_param("authorization") {
133 let parts: Vec<&str> = authorization_header.split_whitespace().collect();
135 if parts.len() != 2 {
136 return Err(OAuthValidationError::InvalidParameterValue(
137 "authorization",
138 MASKED.to_string(),
139 ));
140 }
141 if parts[0].to_lowercase() != "Basic" {
142 return Err(OAuthValidationError::InvalidParameterValue(
143 "authorization",
144 MASKED.to_string(),
145 ));
146 }
147 let decoded = BASE64_STANDARD.decode(parts[1]).map_err(|_| {
148 OAuthValidationError::InvalidParameterValue("authorization", MASKED.to_string())
149 })?;
150 let decoded_str = std::str::from_utf8(&decoded).map_err(|_| {
151 OAuthValidationError::InvalidParameterValue("authorization", MASKED.to_string())
152 })?;
153 let parts: Vec<&str> = decoded_str.split(':').collect();
154 if parts.len() != 2 {
155 return Err(OAuthValidationError::InvalidParameterValue(
156 "authorization",
157 MASKED.to_string(),
158 ));
159 }
160 return Ok(Some((parts[0].to_string(), parts[1].to_string())));
161 }
162
163 Ok(None)
164}